cudf-polars-cu13 25.10.0__py3-none-any.whl → 26.2.0__py3-none-any.whl

This diff shows the contents of two publicly released versions of this package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (76)
  1. cudf_polars/GIT_COMMIT +1 -1
  2. cudf_polars/VERSION +1 -1
  3. cudf_polars/callback.py +60 -15
  4. cudf_polars/containers/column.py +137 -77
  5. cudf_polars/containers/dataframe.py +123 -34
  6. cudf_polars/containers/datatype.py +134 -13
  7. cudf_polars/dsl/expr.py +0 -2
  8. cudf_polars/dsl/expressions/aggregation.py +80 -28
  9. cudf_polars/dsl/expressions/binaryop.py +34 -14
  10. cudf_polars/dsl/expressions/boolean.py +110 -37
  11. cudf_polars/dsl/expressions/datetime.py +59 -30
  12. cudf_polars/dsl/expressions/literal.py +11 -5
  13. cudf_polars/dsl/expressions/rolling.py +460 -119
  14. cudf_polars/dsl/expressions/selection.py +9 -8
  15. cudf_polars/dsl/expressions/slicing.py +1 -1
  16. cudf_polars/dsl/expressions/string.py +256 -114
  17. cudf_polars/dsl/expressions/struct.py +19 -7
  18. cudf_polars/dsl/expressions/ternary.py +33 -3
  19. cudf_polars/dsl/expressions/unary.py +126 -64
  20. cudf_polars/dsl/ir.py +1053 -350
  21. cudf_polars/dsl/to_ast.py +30 -13
  22. cudf_polars/dsl/tracing.py +194 -0
  23. cudf_polars/dsl/translate.py +307 -107
  24. cudf_polars/dsl/utils/aggregations.py +43 -30
  25. cudf_polars/dsl/utils/reshape.py +14 -2
  26. cudf_polars/dsl/utils/rolling.py +12 -8
  27. cudf_polars/dsl/utils/windows.py +35 -20
  28. cudf_polars/experimental/base.py +55 -2
  29. cudf_polars/experimental/benchmarks/pdsds.py +12 -126
  30. cudf_polars/experimental/benchmarks/pdsh.py +792 -2
  31. cudf_polars/experimental/benchmarks/utils.py +596 -39
  32. cudf_polars/experimental/dask_registers.py +47 -20
  33. cudf_polars/experimental/dispatch.py +9 -3
  34. cudf_polars/experimental/distinct.py +2 -0
  35. cudf_polars/experimental/explain.py +15 -2
  36. cudf_polars/experimental/expressions.py +30 -15
  37. cudf_polars/experimental/groupby.py +25 -4
  38. cudf_polars/experimental/io.py +156 -124
  39. cudf_polars/experimental/join.py +53 -23
  40. cudf_polars/experimental/parallel.py +68 -19
  41. cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
  42. cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
  43. cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
  44. cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
  45. cudf_polars/experimental/rapidsmpf/collectives/shuffle.py +253 -0
  46. cudf_polars/experimental/rapidsmpf/core.py +488 -0
  47. cudf_polars/experimental/rapidsmpf/dask.py +172 -0
  48. cudf_polars/experimental/rapidsmpf/dispatch.py +153 -0
  49. cudf_polars/experimental/rapidsmpf/io.py +696 -0
  50. cudf_polars/experimental/rapidsmpf/join.py +322 -0
  51. cudf_polars/experimental/rapidsmpf/lower.py +74 -0
  52. cudf_polars/experimental/rapidsmpf/nodes.py +735 -0
  53. cudf_polars/experimental/rapidsmpf/repartition.py +216 -0
  54. cudf_polars/experimental/rapidsmpf/union.py +115 -0
  55. cudf_polars/experimental/rapidsmpf/utils.py +374 -0
  56. cudf_polars/experimental/repartition.py +9 -2
  57. cudf_polars/experimental/select.py +177 -14
  58. cudf_polars/experimental/shuffle.py +46 -12
  59. cudf_polars/experimental/sort.py +100 -26
  60. cudf_polars/experimental/spilling.py +1 -1
  61. cudf_polars/experimental/statistics.py +24 -5
  62. cudf_polars/experimental/utils.py +25 -7
  63. cudf_polars/testing/asserts.py +13 -8
  64. cudf_polars/testing/io.py +2 -1
  65. cudf_polars/testing/plugin.py +93 -17
  66. cudf_polars/typing/__init__.py +86 -32
  67. cudf_polars/utils/config.py +473 -58
  68. cudf_polars/utils/cuda_stream.py +70 -0
  69. cudf_polars/utils/versions.py +5 -4
  70. cudf_polars_cu13-26.2.0.dist-info/METADATA +181 -0
  71. cudf_polars_cu13-26.2.0.dist-info/RECORD +108 -0
  72. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
  73. cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
  74. cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
  75. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
  76. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
cudf_polars/dsl/expressions/struct.py
@@ -8,7 +8,9 @@ from __future__ import annotations
 
 from enum import IntEnum, auto
 from io import StringIO
-from typing import TYPE_CHECKING, Any, ClassVar
+from typing import TYPE_CHECKING, Any, ClassVar, cast
+
+import polars as pl
 
 import pylibcudf as plc
 
@@ -18,7 +20,7 @@ from cudf_polars.dsl.expressions.base import ExecutionContext, Expr
 if TYPE_CHECKING:
     from typing_extensions import Self
 
-    from polars.polars import _expr_nodes as pl_expr
+    from polars import polars  # type: ignore[attr-defined]
 
     from cudf_polars.containers import DataFrame, DataType
 
@@ -42,7 +44,7 @@ class StructFunction(Expr):
     )  # https://github.com/pola-rs/polars/pull/23022#issuecomment-2933910958
 
     @classmethod
-    def from_polars(cls, obj: pl_expr.StructFunction) -> Self:
+    def from_polars(cls, obj: polars._expr_nodes.StructFunction) -> Self:
         """Convert from polars' `StructFunction`."""
         try:
             function, name = str(obj).split(".", maxsplit=1)
@@ -87,11 +89,14 @@ class StructFunction(Expr):
         """Evaluate this expression given a dataframe for context."""
         columns = [child.evaluate(df, context=context) for child in self.children]
         (column,) = columns
+        # Type checker doesn't know polars only calls StructFunction with struct types
         if self.name == StructFunction.Name.FieldByName:
             field_index = next(
                 (
                     i
-                    for i, field in enumerate(self.children[0].dtype.polars.fields)
+                    for i, field in enumerate(
+                        cast(pl.Struct, self.children[0].dtype.polars_type).fields
+                    )
                     if field.name == self.options[0]
                 ),
                 None,
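For reference, a `pl.Struct` dtype exposes its fields as objects carrying a `name` attribute; that is what the `cast(pl.Struct, ...)` lookup above iterates over. A minimal illustration (not part of the diff):

    import polars as pl

    dt = pl.Struct({"a": pl.Int64, "b": pl.String})
    assert [field.name for field in dt.fields] == ["a", "b"]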
@@ -109,7 +114,12 @@ class StructFunction(Expr):
             table = plc.Table(column.obj.children())
             metadata = plc.io.TableWithMetadata(
                 table,
-                [(field.name, []) for field in self.children[0].dtype.polars.fields],
+                [
+                    (field.name, [])
+                    for field in cast(
+                        pl.Struct, self.children[0].dtype.polars_type
+                    ).fields
+                ],
             )
             options = (
                 plc.io.json.JsonWriterOptions.builder(target, table)
@@ -120,9 +130,11 @@ class StructFunction(Expr):
                 .utf8_escaped(val=False)
                 .build()
             )
-            plc.io.json.write_json(options)
+            plc.io.json.write_json(options, stream=df.stream)
             return Column(
-                plc.Column.from_iterable_of_py(buff.getvalue().split()),
+                plc.Column.from_iterable_of_py(
+                    buff.getvalue().split(), stream=df.stream
+                ),
                 dtype=self.dtype,
             )
         elif self.name in {
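These hunks show the pattern that recurs throughout this release: pylibcudf calls that previously ran on the default CUDA stream now take an explicit `stream=` keyword, threaded from the stream carried by the evaluation `DataFrame` (`df.stream`). A minimal sketch of the pattern, assuming `stream` is such a stream object supplied by the caller:

    import pylibcudf as plc

    def tokens_column(text: str, stream) -> plc.Column:
        # Build a device column from host-side tokens on the caller's
        # stream instead of the default stream, as in the hunk above.
        return plc.Column.from_iterable_of_py(text.split(), stream=stream)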
cudf_polars/dsl/expressions/ternary.py
@@ -15,6 +15,7 @@ from cudf_polars.dsl.expressions.base import (
     ExecutionContext,
     Expr,
 )
+from cudf_polars.dsl.utils.reshape import broadcast
 
 if TYPE_CHECKING:
     from cudf_polars.containers import DataFrame, DataType
@@ -41,9 +42,38 @@ class Ternary(Expr):
         when, then, otherwise = (
             child.evaluate(df, context=context) for child in self.children
         )
-        then_obj = then.obj_scalar if then.is_scalar else then.obj
-        otherwise_obj = otherwise.obj_scalar if otherwise.is_scalar else otherwise.obj
+
+        if when.is_scalar:
+            # For scalar predicates: lowering to copy_if_else would require
+            # materializing an all true/false mask column. Instead, just pick
+            # the correct branch.
+            when_predicate = when.obj_scalar(stream=df.stream).to_py(stream=df.stream)
+            pick, other = (then, otherwise) if when_predicate else (otherwise, then)
+
+            pick_col = (
+                broadcast(
+                    pick,
+                    target_length=1 if other.is_scalar else other.size,
+                    stream=df.stream,
+                )[0]
+                if pick.is_scalar
+                else pick
+            )
+            return Column(pick_col.obj, dtype=self.dtype)
+
+        then_obj = then.obj_scalar(stream=df.stream) if then.is_scalar else then.obj
+        otherwise_obj = (
+            otherwise.obj_scalar(stream=df.stream)
+            if otherwise.is_scalar
+            else otherwise.obj
+        )
+
         return Column(
-            plc.copying.copy_if_else(then_obj, otherwise_obj, when.obj),
+            plc.copying.copy_if_else(
+                then_obj,
+                otherwise_obj,
+                when.obj,
+                stream=df.stream,
+            ),
             dtype=self.dtype,
         )
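Seen from the polars API, the new branch means a `when/then/otherwise` whose predicate is a literal no longer lowers to `copy_if_else` against a broadcast mask; the matching branch is returned directly. A small query of that shape (illustrative; `engine="gpu"` routes collection through cudf-polars):

    import polars as pl

    q = pl.LazyFrame({"a": [1, 2, 3]}).select(
        pl.when(pl.lit(True)).then(pl.col("a")).otherwise(-pl.col("a"))
    )
    q.collect(engine="gpu")  # scalar predicate: picks `then` without a mask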
cudf_polars/dsl/expressions/unary.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2026, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0
 # TODO: remove need for this
 """DSL nodes for unary operations."""
@@ -15,7 +15,6 @@ from cudf_polars.containers import Column
 from cudf_polars.dsl.expressions.base import ExecutionContext, Expr
 from cudf_polars.dsl.expressions.literal import Literal
 from cudf_polars.utils import dtypes
-from cudf_polars.utils.versions import POLARS_VERSION_LT_129
 
 if TYPE_CHECKING:
     from cudf_polars.containers import DataFrame, DataType
@@ -26,14 +25,15 @@ __all__ = ["Cast", "Len", "UnaryFunction"]
 class Cast(Expr):
     """Class representing a cast of an expression."""
 
-    __slots__ = ()
-    _non_child = ("dtype",)
+    __slots__ = ("strict",)
+    _non_child = ("dtype", "strict")
 
-    def __init__(self, dtype: DataType, value: Expr) -> None:
+    def __init__(self, dtype: DataType, strict: bool, value: Expr) -> None:  # noqa: FBT001
         self.dtype = dtype
+        self.strict = strict
         self.children = (value,)
         self.is_pointwise = True
-        if not dtypes.can_cast(value.dtype.plc, self.dtype.plc):
+        if not dtypes.can_cast(value.dtype.plc_type, self.dtype.plc_type):
             raise NotImplementedError(
                 f"Can't cast {value.dtype.id().name} to {self.dtype.id().name}"
             )
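The node now records whether the cast is strict, mirroring polars' `Expr.cast(..., strict=...)`: a strict cast raises on values that do not fit the target type, a non-strict cast nulls them out. Construction per the new signature (illustrative `dtype`/`child` values):

    node = Cast(dtype, True, child)  # 26.2.0: the strict flag is required
    # node = Cast(dtype, child)      # 25.10.0 signature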
@@ -44,7 +44,7 @@ class Cast(Expr):
         """Evaluate this expression given a dataframe for context."""
         (child,) = self.children
         column = child.evaluate(df, context=context)
-        return column.astype(self.dtype)
+        return column.astype(self.dtype, stream=df.stream, strict=self.strict)
 
 
 class Len(Expr):
@@ -61,8 +61,9 @@ class Len(Expr):
         """Evaluate this expression given a dataframe for context."""
         return Column(
             plc.Column.from_scalar(
-                plc.Scalar.from_py(df.num_rows, self.dtype.plc),
+                plc.Scalar.from_py(df.num_rows, self.dtype.plc_type, stream=df.stream),
                 1,
+                stream=df.stream,
             ),
             dtype=self.dtype,
         )
@@ -150,7 +151,7 @@ class UnaryFunction(Expr):
         )
 
         if self.name not in UnaryFunction._supported_fns:
-            raise NotImplementedError(f"Unary function {name=}")
+            raise NotImplementedError(f"Unary function {name=}")  # pragma: no cover
         if self.name in UnaryFunction._supported_cum_aggs:
             (reverse,) = self.options
             if reverse:
@@ -174,26 +175,25 @@ class UnaryFunction(Expr):
         """Evaluate this expression given a dataframe for context."""
         if self.name == "mask_nans":
             (child,) = self.children
-            return child.evaluate(df, context=context).mask_nans()
+            return child.evaluate(df, context=context).mask_nans(stream=df.stream)
         if self.name == "null_count":
             (column,) = (child.evaluate(df, context=context) for child in self.children)
             return Column(
                 plc.Column.from_scalar(
-                    plc.Scalar.from_py(column.null_count, self.dtype.plc),
+                    plc.Scalar.from_py(
+                        column.null_count, self.dtype.plc_type, stream=df.stream
+                    ),
                     1,
+                    stream=df.stream,
                 ),
                 dtype=self.dtype,
             )
+        arg: plc.Column | plc.Scalar
         if self.name == "round":
-            round_mode = "half_away_from_zero"
-            if POLARS_VERSION_LT_129:
-                (decimal_places,) = self.options  # pragma: no cover
-            else:
-                # pragma: no cover
-                (
-                    decimal_places,
-                    round_mode,
-                ) = self.options
+            (
+                decimal_places,
+                round_mode,
+            ) = self.options
             (values,) = (child.evaluate(df, context=context) for child in self.children)
             return Column(
                 plc.round.round(
@@ -204,6 +204,7 @@ class UnaryFunction(Expr):
                         if round_mode == "half_to_even"
                         else plc.round.RoundingMethod.HALF_UP
                     ),
+                    stream=df.stream,
                 ),
                 dtype=self.dtype,
             ).sorted_like(values)  # pragma: no cover
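With the `POLARS_VERSION_LT_129` fallback gone, the translator always receives both options: polars 1.29 and later pass `(decimal_places, round_mode)` together. On the polars side this corresponds to (keyword name assumed from the mode strings above):

    import polars as pl

    pl.col("x").round(2)                       # default: half away from zero
    pl.col("x").round(2, mode="half_to_even")  # maps to RoundingMethod.HALF_EVEN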
@@ -215,31 +216,34 @@ class UnaryFunction(Expr):
             keep = plc.stream_compaction.DuplicateKeepOption.KEEP_ANY
             if values.is_sorted:
                 maintain_order = True
-                result = plc.stream_compaction.unique(
+                (compacted,) = plc.stream_compaction.unique(
                     plc.Table([values.obj]),
                     [0],
                     keep,
                     plc.types.NullEquality.EQUAL,
-                )
+                    stream=df.stream,
+                ).columns()
             else:
                 distinct = (
                     plc.stream_compaction.stable_distinct
                     if maintain_order
                     else plc.stream_compaction.distinct
                 )
-                result = distinct(
+                (compacted,) = distinct(
                     plc.Table([values.obj]),
                     [0],
                     keep,
                     plc.types.NullEquality.EQUAL,
                     plc.types.NanEquality.ALL_EQUAL,
-                )
-            (column,) = result.columns()
-            result = Column(column, dtype=self.dtype)
+                    stream=df.stream,
+                ).columns()
+            column = Column(compacted, dtype=self.dtype)
             if maintain_order:
-                result = result.sorted_like(values)
-            return result
-        elif self.name == "set_sorted":
+                column = column.sorted_like(values)
+            return column
+        elif self.name == "set_sorted":  # pragma: no cover
+            # TODO: LazyFrame.set_sorted is a proper IR concept (i.e. FunctionIR::Hint)
+            # and is currently not implemented. We should reimplement it as a MapFunction.
             (column,) = (child.evaluate(df, context=context) for child in self.children)
             (asc,) = self.options
             order = (
@@ -250,10 +254,12 @@ class UnaryFunction(Expr):
             null_order = plc.types.NullOrder.BEFORE
             if column.null_count > 0 and (n := column.size) > 1:
                 # PERF: This invokes four stream synchronisations!
-                has_nulls_first = not plc.copying.get_element(column.obj, 0).is_valid()
+                has_nulls_first = not plc.copying.get_element(
+                    column.obj, 0, stream=df.stream
+                ).is_valid(df.stream)
                 has_nulls_last = not plc.copying.get_element(
-                    column.obj, n - 1
-                ).is_valid()
+                    column.obj, n - 1, stream=df.stream
+                ).is_valid(df.stream)
                 if (order == plc.types.Order.DESCENDING and has_nulls_first) or (
                     order == plc.types.Order.ASCENDING and has_nulls_last
                 ):
@@ -280,30 +286,43 @@ class UnaryFunction(Expr):
                 counts_table,
                 [plc.types.Order.DESCENDING],
                 [plc.types.NullOrder.BEFORE],
+                stream=df.stream,
             )
             counts_table = plc.copying.gather(
-                counts_table, sort_indices, plc.copying.OutOfBoundsPolicy.DONT_CHECK
+                counts_table,
+                sort_indices,
+                plc.copying.OutOfBoundsPolicy.DONT_CHECK,
+                stream=df.stream,
             )
             keys_table = plc.copying.gather(
-                keys_table, sort_indices, plc.copying.OutOfBoundsPolicy.DONT_CHECK
+                keys_table,
+                sort_indices,
+                plc.copying.OutOfBoundsPolicy.DONT_CHECK,
+                stream=df.stream,
             )
             keys_col = keys_table.columns()[0]
             counts_col = counts_table.columns()[0]
             if normalize:
                 total_counts = plc.reduce.reduce(
-                    counts_col, plc.aggregation.sum(), plc.DataType(plc.TypeId.UINT64)
+                    counts_col,
+                    plc.aggregation.sum(),
+                    plc.DataType(plc.TypeId.UINT64),
+                    stream=df.stream,
                 )
                 counts_col = plc.binaryop.binary_operation(
                     counts_col,
                     total_counts,
                     plc.binaryop.BinaryOperator.DIV,
                     plc.DataType(plc.TypeId.FLOAT64),
+                    stream=df.stream,
                 )
             elif counts_col.type().id() == plc.TypeId.INT32:
-                counts_col = plc.unary.cast(counts_col, plc.DataType(plc.TypeId.UINT32))
+                counts_col = plc.unary.cast(
+                    counts_col, plc.DataType(plc.TypeId.UINT32), stream=df.stream
+                )
 
             plc_column = plc.Column(
-                self.dtype.plc,
+                self.dtype.plc_type,
                 counts_col.size(),
                 None,
                 None,
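For orientation, this branch implements normalized `value_counts`: the UINT64 sum reduction and the `DIV` binary op above divide each count by the total, which in polars terms is:

    import polars as pl

    pl.col("x").value_counts(sort=True, normalize=True)  # counts as fractions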
@@ -318,7 +337,7 @@ class UnaryFunction(Expr):
                 return column
             return Column(
                 plc.stream_compaction.drop_nulls(
-                    plc.Table([column.obj]), [0], 1
+                    plc.Table([column.obj]), [0], 1, stream=df.stream
                 ).columns()[0],
                 dtype=self.dtype,
             )
@@ -328,19 +347,31 @@ class UnaryFunction(Expr):
                 return column
             fill_value = self.children[1]
             if isinstance(fill_value, Literal):
-                arg = plc.Scalar.from_py(fill_value.value, fill_value.dtype.plc)
+                arg = plc.Scalar.from_py(
+                    fill_value.value, fill_value.dtype.plc_type, stream=df.stream
+                )
             else:
                 evaluated = fill_value.evaluate(df, context=context)
-                arg = evaluated.obj_scalar if evaluated.is_scalar else evaluated.obj
+                arg = (
+                    evaluated.obj_scalar(stream=df.stream)
+                    if evaluated.is_scalar
+                    else evaluated.obj
+                )
             if isinstance(arg, plc.Scalar) and dtypes.can_cast(
-                column.dtype.plc, arg.type()
+                column.dtype.plc_type, arg.type()
             ):  # pragma: no cover
                 arg = (
-                    Column(plc.Column.from_scalar(arg, 1), dtype=fill_value.dtype)
-                    .astype(column.dtype)
-                    .obj.to_scalar()
+                    Column(
+                        plc.Column.from_scalar(arg, 1, stream=df.stream),
+                        dtype=fill_value.dtype,
+                    )
+                    .astype(column.dtype, stream=df.stream)
+                    .obj.to_scalar(stream=df.stream)
                 )
-            return Column(plc.replace.replace_nulls(column.obj, arg), dtype=self.dtype)
+            return Column(
+                plc.replace.replace_nulls(column.obj, arg, stream=df.stream),
+                dtype=self.dtype,
+            )
         elif self.name == "fill_null_with_strategy":
             column = self.children[0].evaluate(df, context=context)
             strategy, limit = self.options
@@ -352,6 +383,8 @@ class UnaryFunction(Expr):
                 )
             ):
                 return column
+
+            replacement: plc.replace.ReplacePolicy | plc.Scalar
             if strategy == "forward":
                 replacement = plc.replace.ReplacePolicy.PRECEDING
             elif strategy == "backward":
@@ -360,37 +393,49 @@ class UnaryFunction(Expr):
                 replacement = plc.reduce.reduce(
                     column.obj,
                     plc.aggregation.min(),
-                    column.dtype.plc,
+                    column.dtype.plc_type,
+                    stream=df.stream,
                 )
             elif strategy == "max":
                 replacement = plc.reduce.reduce(
                     column.obj,
                     plc.aggregation.max(),
-                    column.dtype.plc,
+                    column.dtype.plc_type,
+                    stream=df.stream,
                 )
             elif strategy == "mean":
                 replacement = plc.reduce.reduce(
                     column.obj,
                     plc.aggregation.mean(),
                     plc.DataType(plc.TypeId.FLOAT64),
+                    stream=df.stream,
                 )
             elif strategy == "zero":
-                replacement = plc.scalar.Scalar.from_py(0, dtype=column.dtype.plc)
+                replacement = plc.scalar.Scalar.from_py(
+                    0, dtype=column.dtype.plc_type, stream=df.stream
+                )
             elif strategy == "one":
-                replacement = plc.scalar.Scalar.from_py(1, dtype=column.dtype.plc)
+                replacement = plc.scalar.Scalar.from_py(
+                    1, dtype=column.dtype.plc_type, stream=df.stream
+                )
             else:
                 assert_never(strategy)  # pragma: no cover
 
             if strategy == "mean":
                 return Column(
                     plc.replace.replace_nulls(
-                        plc.unary.cast(column.obj, plc.DataType(plc.TypeId.FLOAT64)),
+                        plc.unary.cast(
+                            column.obj,
+                            plc.DataType(plc.TypeId.FLOAT64),
+                            stream=df.stream,
+                        ),
                         replacement,
+                        stream=df.stream,
                     ),
                     dtype=self.dtype,
-                ).astype(self.dtype)
+                ).astype(self.dtype, stream=df.stream)
             return Column(
-                plc.replace.replace_nulls(column.obj, replacement),
+                plc.replace.replace_nulls(column.obj, replacement, stream=df.stream),
                 dtype=self.dtype,
             )
         elif self.name == "as_struct":
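The strategy strings map one-to-one onto polars' `fill_null` API; as a quick cross-reference for the branches above:

    import polars as pl

    pl.col("x").fill_null(strategy="forward")   # ReplacePolicy.PRECEDING
    pl.col("x").fill_null(strategy="backward")  # ReplacePolicy.FOLLOWING
    pl.col("x").fill_null(strategy="mean")      # reduce(mean) via FLOAT64
    pl.col("x").fill_null(strategy="zero")      # Scalar.from_py(0, ...)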
@@ -399,7 +444,7 @@ class UnaryFunction(Expr):
             ]
             return Column(
                 plc.Column(
-                    data_type=self.dtype.plc,
+                    data_type=self.dtype.plc_type,
                     size=children[0].size(),
                     data=None,
                     mask=None,
@@ -432,19 +477,24 @@ class UnaryFunction(Expr):
                 plc.types.NullPolicy.EXCLUDE,
                 plc.types.NullOrder.BEFORE if descending else plc.types.NullOrder.AFTER,
                 percentage=False,
+                stream=df.stream,
             )
 
             # Min/Max/Dense/Ordinal -> IDX_DTYPE
             # See https://github.com/pola-rs/polars/blob/main/crates/polars-ops/src/series/ops/rank.rs
             if method_str in {"min", "max", "dense", "ordinal"}:
-                dest = self.dtype.plc.id()
+                dest = self.dtype.plc_type.id()
                 src = ranked.type().id()
                 if dest == plc.TypeId.UINT32 and src != plc.TypeId.UINT32:
-                    ranked = plc.unary.cast(ranked, plc.DataType(plc.TypeId.UINT32))
+                    ranked = plc.unary.cast(
+                        ranked, plc.DataType(plc.TypeId.UINT32), stream=df.stream
+                    )
                 elif (
                     dest == plc.TypeId.UINT64 and src != plc.TypeId.UINT64
                 ):  # pragma: no cover
-                    ranked = plc.unary.cast(ranked, plc.DataType(plc.TypeId.UINT64))
+                    ranked = plc.unary.cast(
+                        ranked, plc.DataType(plc.TypeId.UINT64), stream=df.stream
+                    )
 
             return Column(ranked, dtype=self.dtype)
         elif self.name == "top_k":
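The casts keep rank results in polars' index dtype: the `min`, `max`, `dense`, and `ordinal` methods return IDX_DTYPE (UInt32 normally, UInt64 on big-index builds, hence the two branches), while the default `average` method returns Float64 and needs no cast. For example:

    import polars as pl

    pl.col("x").rank(method="dense")  # IDX_DTYPE result, cast above if needed
    pl.col("x").rank()                # "average" method: Float64, no cast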
@@ -459,23 +509,26 @@ class UnaryFunction(Expr):
                     plc.types.Order.ASCENDING
                     if reverse
                     else plc.types.Order.DESCENDING,
+                    stream=df.stream,
                 ),
                 dtype=self.dtype,
             )
         elif self.name in self._OP_MAPPING:
             column = self.children[0].evaluate(df, context=context)
-            if column.dtype.plc.id() != self.dtype.id():
-                arg = plc.unary.cast(column.obj, self.dtype.plc)
+            if column.dtype.plc_type.id() != self.dtype.id():
+                arg = plc.unary.cast(column.obj, self.dtype.plc_type, stream=df.stream)
             else:
                 arg = column.obj
             return Column(
-                plc.unary.unary_operation(arg, self._OP_MAPPING[self.name]),
+                plc.unary.unary_operation(
+                    arg, self._OP_MAPPING[self.name], stream=df.stream
+                ),
                 dtype=self.dtype,
             )
         elif self.name in UnaryFunction._supported_cum_aggs:
             column = self.children[0].evaluate(df, context=context)
             plc_col = column.obj
-            col_type = column.dtype.plc
+            col_type = column.dtype.plc_type
             # cum_sum casts
             # Int8, UInt8, Int16, UInt16 -> Int64 for overflow prevention
             # Bool -> UInt32
@@ -496,9 +549,16 @@ class UnaryFunction(Expr):
                 and plc.traits.is_integral(col_type)
                 and plc.types.size_of(col_type) <= 4
             ):
-                plc_col = plc.unary.cast(plc_col, plc.DataType(plc.TypeId.INT64))
-            elif self.name == "cum_sum" and column.dtype.plc.id() == plc.TypeId.BOOL8:
-                plc_col = plc.unary.cast(plc_col, plc.DataType(plc.TypeId.UINT32))
+                plc_col = plc.unary.cast(
+                    plc_col, plc.DataType(plc.TypeId.INT64), stream=df.stream
+                )
+            elif (
+                self.name == "cum_sum"
+                and column.dtype.plc_type.id() == plc.TypeId.BOOL8
+            ):
+                plc_col = plc.unary.cast(
+                    plc_col, plc.DataType(plc.TypeId.UINT32), stream=df.stream
+                )
             if self.name == "cum_sum":
                 agg = plc.aggregation.sum()
             elif self.name == "cum_prod":
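These casts mirror polars' own cumulative-aggregation type promotion, so results agree across engines; e.g. per the comment above:

    import polars as pl

    s = pl.Series("x", [1, 2, 3], dtype=pl.Int8)
    assert s.cum_sum().dtype == pl.Int64  # small ints widen to Int64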
@@ -509,7 +569,9 @@ class UnaryFunction(Expr):
                 agg = plc.aggregation.max()
 
             return Column(
-                plc.reduce.scan(plc_col, agg, plc.reduce.ScanType.INCLUSIVE),
+                plc.reduce.scan(
+                    plc_col, agg, plc.reduce.ScanType.INCLUSIVE, stream=df.stream
+                ),
                 dtype=self.dtype,
             )
         raise NotImplementedError(