cudf-polars-cu13 25.10.0__py3-none-any.whl → 26.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. cudf_polars/GIT_COMMIT +1 -1
  2. cudf_polars/VERSION +1 -1
  3. cudf_polars/callback.py +60 -15
  4. cudf_polars/containers/column.py +137 -77
  5. cudf_polars/containers/dataframe.py +123 -34
  6. cudf_polars/containers/datatype.py +134 -13
  7. cudf_polars/dsl/expr.py +0 -2
  8. cudf_polars/dsl/expressions/aggregation.py +80 -28
  9. cudf_polars/dsl/expressions/binaryop.py +34 -14
  10. cudf_polars/dsl/expressions/boolean.py +110 -37
  11. cudf_polars/dsl/expressions/datetime.py +59 -30
  12. cudf_polars/dsl/expressions/literal.py +11 -5
  13. cudf_polars/dsl/expressions/rolling.py +460 -119
  14. cudf_polars/dsl/expressions/selection.py +9 -8
  15. cudf_polars/dsl/expressions/slicing.py +1 -1
  16. cudf_polars/dsl/expressions/string.py +256 -114
  17. cudf_polars/dsl/expressions/struct.py +19 -7
  18. cudf_polars/dsl/expressions/ternary.py +33 -3
  19. cudf_polars/dsl/expressions/unary.py +126 -64
  20. cudf_polars/dsl/ir.py +1053 -350
  21. cudf_polars/dsl/to_ast.py +30 -13
  22. cudf_polars/dsl/tracing.py +194 -0
  23. cudf_polars/dsl/translate.py +307 -107
  24. cudf_polars/dsl/utils/aggregations.py +43 -30
  25. cudf_polars/dsl/utils/reshape.py +14 -2
  26. cudf_polars/dsl/utils/rolling.py +12 -8
  27. cudf_polars/dsl/utils/windows.py +35 -20
  28. cudf_polars/experimental/base.py +55 -2
  29. cudf_polars/experimental/benchmarks/pdsds.py +12 -126
  30. cudf_polars/experimental/benchmarks/pdsh.py +792 -2
  31. cudf_polars/experimental/benchmarks/utils.py +596 -39
  32. cudf_polars/experimental/dask_registers.py +47 -20
  33. cudf_polars/experimental/dispatch.py +9 -3
  34. cudf_polars/experimental/distinct.py +2 -0
  35. cudf_polars/experimental/explain.py +15 -2
  36. cudf_polars/experimental/expressions.py +30 -15
  37. cudf_polars/experimental/groupby.py +25 -4
  38. cudf_polars/experimental/io.py +156 -124
  39. cudf_polars/experimental/join.py +53 -23
  40. cudf_polars/experimental/parallel.py +68 -19
  41. cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
  42. cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
  43. cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
  44. cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
  45. cudf_polars/experimental/rapidsmpf/collectives/shuffle.py +253 -0
  46. cudf_polars/experimental/rapidsmpf/core.py +488 -0
  47. cudf_polars/experimental/rapidsmpf/dask.py +172 -0
  48. cudf_polars/experimental/rapidsmpf/dispatch.py +153 -0
  49. cudf_polars/experimental/rapidsmpf/io.py +696 -0
  50. cudf_polars/experimental/rapidsmpf/join.py +322 -0
  51. cudf_polars/experimental/rapidsmpf/lower.py +74 -0
  52. cudf_polars/experimental/rapidsmpf/nodes.py +735 -0
  53. cudf_polars/experimental/rapidsmpf/repartition.py +216 -0
  54. cudf_polars/experimental/rapidsmpf/union.py +115 -0
  55. cudf_polars/experimental/rapidsmpf/utils.py +374 -0
  56. cudf_polars/experimental/repartition.py +9 -2
  57. cudf_polars/experimental/select.py +177 -14
  58. cudf_polars/experimental/shuffle.py +46 -12
  59. cudf_polars/experimental/sort.py +100 -26
  60. cudf_polars/experimental/spilling.py +1 -1
  61. cudf_polars/experimental/statistics.py +24 -5
  62. cudf_polars/experimental/utils.py +25 -7
  63. cudf_polars/testing/asserts.py +13 -8
  64. cudf_polars/testing/io.py +2 -1
  65. cudf_polars/testing/plugin.py +93 -17
  66. cudf_polars/typing/__init__.py +86 -32
  67. cudf_polars/utils/config.py +473 -58
  68. cudf_polars/utils/cuda_stream.py +70 -0
  69. cudf_polars/utils/versions.py +5 -4
  70. cudf_polars_cu13-26.2.0.dist-info/METADATA +181 -0
  71. cudf_polars_cu13-26.2.0.dist-info/RECORD +108 -0
  72. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
  73. cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
  74. cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
  75. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
  76. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
cudf_polars/experimental/io.py

@@ -5,13 +5,12 @@
 from __future__ import annotations
 
 import dataclasses
-import enum
 import functools
 import itertools
 import math
 import statistics
 from collections import defaultdict
-from enum import IntEnum
+from functools import partial
 from pathlib import Path
 from typing import TYPE_CHECKING, Any
 
@@ -19,24 +18,35 @@ import polars as pl
 
 import pylibcudf as plc
 
-from cudf_polars.dsl.ir import IR, DataFrameScan, Empty, Scan, Sink, Union
+from cudf_polars.dsl.ir import (
+    IR,
+    DataFrameScan,
+    Empty,
+    Scan,
+    Sink,
+    Union,
+)
 from cudf_polars.experimental.base import (
     ColumnSourceInfo,
     ColumnStat,
     ColumnStats,
     DataSourceInfo,
     DataSourcePair,
+    IOPartitionFlavor,
+    IOPartitionPlan,
     PartitionInfo,
     UniqueStats,
     get_key_name,
 )
 from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node
+from cudf_polars.utils.cuda_stream import get_cuda_stream
 
 if TYPE_CHECKING:
     from collections.abc import Hashable, MutableMapping
 
     from cudf_polars.containers import DataFrame
     from cudf_polars.dsl.expr import NamedExpr
+    from cudf_polars.dsl.ir import IRExecutionContext
     from cudf_polars.experimental.base import StatsCollector
     from cudf_polars.experimental.dispatch import LowerIRTransformer
     from cudf_polars.typing import Schema
@@ -80,73 +90,40 @@ def _(
     return ir, {ir: PartitionInfo(count=1)}
 
 
-class ScanPartitionFlavor(IntEnum):
-    """Flavor of Scan partitioning."""
-
-    SINGLE_FILE = enum.auto()  # 1:1 mapping between files and partitions
-    SPLIT_FILES = enum.auto()  # Split each file into >1 partition
-    FUSED_FILES = enum.auto()  # Fuse multiple files into each partition
-
-
-class ScanPartitionPlan:
-    """
-    Scan partitioning plan.
-
-    Notes
-    -----
-    The meaning of `factor` depends on the value of `flavor`:
-      - SINGLE_FILE: `factor` must be `1`.
-      - SPLIT_FILES: `factor` is the number of partitions per file.
-      - FUSED_FILES: `factor` is the number of files per partition.
-    """
-
-    __slots__ = ("factor", "flavor")
-    factor: int
-    flavor: ScanPartitionFlavor
-
-    def __init__(self, factor: int, flavor: ScanPartitionFlavor) -> None:
-        if (
-            flavor == ScanPartitionFlavor.SINGLE_FILE and factor != 1
-        ):  # pragma: no cover
-            raise ValueError(f"Expected factor == 1 for {flavor}, got: {factor}")
-        self.factor = factor
-        self.flavor = flavor
-
-    @staticmethod
-    def from_scan(
-        ir: Scan, stats: StatsCollector, config_options: ConfigOptions
-    ) -> ScanPartitionPlan:
-        """Extract the partitioning plan of a Scan operation."""
-        if ir.typ == "parquet":
-            # TODO: Use system info to set default blocksize
-            assert config_options.executor.name == "streaming", (
-                "'in-memory' executor not supported in 'generate_ir_tasks'"
-            )
+def scan_partition_plan(
+    ir: Scan, stats: StatsCollector, config_options: ConfigOptions
+) -> IOPartitionPlan:
+    """Extract the partitioning plan of a Scan operation."""
+    if ir.typ == "parquet":
+        # TODO: Use system info to set default blocksize
+        assert config_options.executor.name == "streaming", (
+            "'in-memory' executor not supported in 'generate_ir_tasks'"
+        )
 
-            blocksize: int = config_options.executor.target_partition_size
-            column_stats = stats.column_stats.get(ir, {})
-            column_sizes: list[int] = []
-            for cs in column_stats.values():
-                storage_size = cs.source_info.storage_size
-                if storage_size.value is not None:
-                    column_sizes.append(storage_size.value)
-
-            if (file_size := sum(column_sizes)) > 0:
-                if file_size > blocksize:
-                    # Split large files
-                    return ScanPartitionPlan(
-                        math.ceil(file_size / blocksize),
-                        ScanPartitionFlavor.SPLIT_FILES,
-                    )
-                else:
-                    # Fuse small files
-                    return ScanPartitionPlan(
-                        max(blocksize // int(file_size), 1),
-                        ScanPartitionFlavor.FUSED_FILES,
-                    )
+        blocksize: int = config_options.executor.target_partition_size
+        column_stats = stats.column_stats.get(ir, {})
+        column_sizes: list[int] = []
+        for cs in column_stats.values():
+            storage_size = cs.source_info.storage_size
+            if storage_size.value is not None:
+                column_sizes.append(storage_size.value)
+
+        if (file_size := sum(column_sizes)) > 0:
+            if file_size > blocksize:
+                # Split large files
+                return IOPartitionPlan(
+                    math.ceil(file_size / blocksize),
+                    IOPartitionFlavor.SPLIT_FILES,
+                )
+            else:
+                # Fuse small files
+                return IOPartitionPlan(
+                    max(blocksize // int(file_size), 1),
+                    IOPartitionFlavor.FUSED_FILES,
+                )
 
-        # TODO: Use file sizes for csv and json
-        return ScanPartitionPlan(1, ScanPartitionFlavor.SINGLE_FILE)
+    # TODO: Use file sizes for csv and json
+    return IOPartitionPlan(1, IOPartitionFlavor.SINGLE_FILE)
 
 
 class SplitScan(IR):
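
Note: a rough, illustrative sketch of the partition-plan arithmetic introduced above (not part of the package; the byte counts are made-up):

import math

# Assume a target partition size of 128 MiB.
blocksize = 128 * 1024**2

# A 1 GiB file gets split: ceil(1024 MiB / 128 MiB) = 8 partitions per file,
# i.e. IOPartitionPlan(8, IOPartitionFlavor.SPLIT_FILES).
assert math.ceil((1024 * 1024**2) / blocksize) == 8

# A 16 MiB file gets fused with neighbours: 128 MiB // 16 MiB = 8 files per
# partition, i.e. IOPartitionPlan(8, IOPartitionFlavor.FUSED_FILES).
assert max(blocksize // (16 * 1024**2), 1) == 8
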
@@ -222,6 +199,8 @@ class SplitScan(IR):
         include_file_paths: str | None,
         predicate: NamedExpr | None,
         parquet_options: ParquetOptions,
+        *,
+        context: IRExecutionContext,
     ) -> DataFrame:
         """Evaluate and return a dataframe."""
         if typ not in ("parquet",):  # pragma: no cover
@@ -282,6 +261,7 @@ class SplitScan(IR):
             include_file_paths,
             predicate,
             parquet_options,
+            context=context,
         )
 
 
@@ -304,9 +284,9 @@ def _(
         and ir.skip_rows == 0
        and ir.row_index is None
     ):
-        plan = ScanPartitionPlan.from_scan(ir, rec.state["stats"], config_options)
+        plan = scan_partition_plan(ir, rec.state["stats"], config_options)
         paths = list(ir.paths)
-        if plan.flavor == ScanPartitionFlavor.SPLIT_FILES:
+        if plan.flavor == IOPartitionFlavor.SPLIT_FILES:
             # Disable chunked reader when splitting files
             parquet_options = dataclasses.replace(
                 config_options.parquet_options,
@@ -435,9 +415,12 @@ def _sink_to_directory(
     options: dict[str, Any],
     df: DataFrame,
     ready: None,
+    context: IRExecutionContext,
 ) -> DataFrame:
     """Sink a partition to a new file."""
-    return Sink.do_evaluate(schema, kind, path, parquet_options, options, df)
+    return Sink.do_evaluate(
+        schema, kind, path, parquet_options, options, df, context=context
+    )
 
 
 def _sink_to_parquet_file(
@@ -456,7 +439,9 @@ def _sink_to_parquet_file(
         plc.io.parquet.ChunkedParquetWriterOptions.builder(sink), options
     )
     writer_options = builder.metadata(metadata).build()
-    writer = plc.io.parquet.ChunkedParquetWriter.from_options(writer_options)
+    writer = plc.io.parquet.ChunkedParquetWriter.from_options(
+        writer_options, stream=df.stream
+    )
 
     # Append to the open Parquet file.
     assert isinstance(writer, plc.io.parquet.ChunkedParquetWriter), (
@@ -499,12 +484,14 @@ def _sink_to_file(
             mode = "ab"
             use_options["include_header"] = False
         with Path.open(Path(path), mode) as f:
-            sink = plc.io.types.SinkInfo([f])
+            # Path.open returns IO[Any] but SinkInfo needs more specific IO types
+            sink = plc.io.types.SinkInfo([f])  # type: ignore[arg-type]
             Sink._write_csv(sink, use_options, df)
     elif kind == "Json":
        mode = "wb" if writer_state is None else "ab"
        with Path.open(Path(path), mode) as f:
-            sink = plc.io.types.SinkInfo([f])
+            # Path.open returns IO[Any] but SinkInfo needs more specific IO types
+            sink = plc.io.types.SinkInfo([f])  # type: ignore[arg-type]
            Sink._write_json(sink, df)
    else:  # pragma: no cover; Shouldn't get here.
        raise NotImplementedError(f"{kind} not yet supported in _sink_to_file")
@@ -516,7 +503,9 @@ def _sink_to_file(
 
 
 def _file_sink_graph(
-    ir: StreamingSink, partition_info: MutableMapping[IR, PartitionInfo]
+    ir: StreamingSink,
+    partition_info: MutableMapping[IR, PartitionInfo],
+    context: IRExecutionContext,
 ) -> MutableMapping[Any, Any]:
     """Sink to a single file."""
     name = get_key_name(ir)
@@ -526,7 +515,7 @@ def _file_sink_graph(
     if count == 1:
         return {
             (name, 0): (
-                sink.do_evaluate,
+                partial(sink.do_evaluate, context=context),
                 *sink._non_child_args,
                 (child_name, 0),
             )
@@ -552,7 +541,9 @@ def _file_sink_graph(
 
 
 def _directory_sink_graph(
-    ir: StreamingSink, partition_info: MutableMapping[IR, PartitionInfo]
+    ir: StreamingSink,
+    partition_info: MutableMapping[IR, PartitionInfo],
+    context: IRExecutionContext,
 ) -> MutableMapping[Any, Any]:
     """Sink to a directory of files."""
     name = get_key_name(ir)
@@ -573,6 +564,7 @@
             sink.options,
             (child_name, i),
             setup_name,
+            context,
         )
         for i in range(count)
     }
@@ -582,12 +574,14 @@
 
 @generate_ir_tasks.register(StreamingSink)
 def _(
-    ir: StreamingSink, partition_info: MutableMapping[IR, PartitionInfo]
+    ir: StreamingSink,
+    partition_info: MutableMapping[IR, PartitionInfo],
+    context: IRExecutionContext,
 ) -> MutableMapping[Any, Any]:
     if ir.executor_options.sink_to_directory:
-        return _directory_sink_graph(ir, partition_info)
+        return _directory_sink_graph(ir, partition_info, context=context)
     else:
-        return _file_sink_graph(ir, partition_info)
+        return _file_sink_graph(ir, partition_info, context=context)
 
 
 class ParquetMetadata:
@@ -715,6 +709,8 @@ class ParquetSourceInfo(DataSourceInfo):
         # Helper attributes
         self._key_columns: set[str] = set()  # Used to fuse lazy row-group sampling
         self._unique_stats: dict[str, UniqueStats] = {}
+        self._read_columns: set[str] = set()
+        self._real_rg_size: dict[str, int] = {}
 
     @functools.cached_property
     def metadata(self) -> ParquetMetadata:
@@ -737,11 +733,13 @@ class ParquetSourceInfo(DataSourceInfo):
             return
 
         column_names = self.metadata.column_names
-        if not (
-            key_columns := [key for key in self._key_columns if key in column_names]
-        ):  # pragma: no cover; should never get here
-            # No key columns found in the file
-            raise ValueError(f"None of {self._key_columns} in {column_names}")
+        key_columns = [key for key in self._key_columns if key in column_names]
+        read_columns = list(
+            self._read_columns.intersection(column_names).union(key_columns)
+        )
+        if not read_columns:  # pragma: no cover; should never get here
+            # No key columns or read columns found in the file
+            raise ValueError(f"None of {read_columns} in {column_names}")
 
         sampled_file_count = len(sample_paths)
         num_row_groups_per_file = self.metadata.num_row_groups_per_file
@@ -751,15 +749,15 @@ class ParquetSourceInfo(DataSourceInfo):
         ):
             raise ValueError("Parquet metadata sampling failed.")  # pragma: no cover
 
-        n = 0
+        n_sampled = 0
         samples: defaultdict[str, list[int]] = defaultdict(list)
         for path, num_rgs in zip(sample_paths, num_row_groups_per_file, strict=True):
             for rg_id in range(num_rgs):
-                n += 1
+                n_sampled += 1
                 samples[path].append(rg_id)
-                if n == self.max_row_group_samples:
+                if n_sampled == self.max_row_group_samples:
                     break
-            if n == self.max_row_group_samples:
+            if n_sampled == self.max_row_group_samples:
                 break
 
         exact = sampled_file_count == len(
@@ -769,36 +767,43 @@ class ParquetSourceInfo(DataSourceInfo):
         options = plc.io.parquet.ParquetReaderOptions.builder(
             plc.io.SourceInfo(list(samples))
         ).build()
-        options.set_columns(key_columns)
+        options.set_columns(read_columns)
         options.set_row_groups(list(samples.values()))
-        tbl_w_meta = plc.io.parquet.read_parquet(options)
+        stream = get_cuda_stream()
+        tbl_w_meta = plc.io.parquet.read_parquet(options, stream=stream)
         row_group_num_rows = tbl_w_meta.tbl.num_rows()
         for name, column in zip(
-            tbl_w_meta.column_names(), tbl_w_meta.columns, strict=True
+            tbl_w_meta.column_names(include_children=False),
+            tbl_w_meta.columns,
+            strict=True,
         ):
-            row_group_unique_count = plc.stream_compaction.distinct_count(
-                column,
-                plc.types.NullPolicy.INCLUDE,
-                plc.types.NanPolicy.NAN_IS_NULL,
-            )
-            fraction = row_group_unique_count / row_group_num_rows
-            # Assume that if every row is unique then this is a
-            # primary key otherwise it's a foreign key and we
-            # can't use the single row group count estimate.
-            # Example, consider a "foreign" key that has 100
-            # unique values. If we sample from a single row group,
-            # we likely obtain a unique count of 100. But we can't
-            # necessarily deduce that that means that the unique
-            # count is 100 / num_rows_in_group * num_rows_in_file
-            count: int | None = None
-            if exact:
-                count = row_group_unique_count
-            elif row_group_unique_count == row_group_num_rows:
-                count = self.row_count.value
-            self._unique_stats[name] = UniqueStats(
-                ColumnStat[int](value=count, exact=exact),
-                ColumnStat[float](value=fraction, exact=exact),
-            )
+            self._real_rg_size[name] = column.device_buffer_size() // n_sampled
+            if name in key_columns:
+                row_group_unique_count = plc.stream_compaction.distinct_count(
+                    column,
+                    plc.types.NullPolicy.INCLUDE,
+                    plc.types.NanPolicy.NAN_IS_NULL,
+                    stream=stream,
+                )
+                fraction = row_group_unique_count / row_group_num_rows
+                # Assume that if every row is unique then this is a
+                # primary key otherwise it's a foreign key and we
+                # can't use the single row group count estimate.
+                # Example, consider a "foreign" key that has 100
+                # unique values. If we sample from a single row group,
+                # we likely obtain a unique count of 100. But we can't
+                # necessarily deduce that that means that the unique
+                # count is 100 / num_rows_in_group * num_rows_in_file
+                count: int | None = None
+                if exact:
+                    count = row_group_unique_count
+                elif row_group_unique_count == row_group_num_rows:
+                    count = self.row_count.value
+                self._unique_stats[name] = UniqueStats(
+                    ColumnStat[int](value=count, exact=exact),
+                    ColumnStat[float](value=fraction, exact=exact),
+                )
+        stream.synchronize()
 
     def _update_unique_stats(self, column: str) -> None:
         if column not in self._unique_stats and column in self.metadata.column_names:
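
Note: the unique-count estimate above follows a simple rule; a standalone restatement (hypothetical helper, not part of the package):

def estimate_unique_count(
    rg_unique: int, rg_rows: int, file_rows: int, *, exact: bool
) -> int | None:
    # Every row group was sampled, so the distinct count is exact.
    if exact:
        return rg_unique
    # Every sampled row was distinct: assume a primary-key-like column and
    # extrapolate to the full file row count.
    if rg_unique == rg_rows:
        return file_rows
    # Otherwise (likely a foreign key) a single row group cannot be scaled
    # up safely, so no count estimate is recorded (only the fraction).
    return None

# e.g. 10_000 distinct values in a fully-unique 10_000-row sample of a
# 5_000_000-row file -> 5_000_000; 100 distinct in 10_000 sampled rows -> None.
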
@@ -813,7 +818,27 @@ class ParquetSourceInfo(DataSourceInfo):
 
     def storage_size(self, column: str) -> ColumnStat[int]:
         """Return the average column size for a single file."""
-        return self.metadata.mean_size_per_file.get(column, ColumnStat[int]())
+        file_count = len(self.paths)
+        row_count = self.row_count.value
+        partial_mean_size = self.metadata.mean_size_per_file.get(
+            column, ColumnStat[int]()
+        ).value
+        if file_count and row_count and partial_mean_size:
+            # NOTE: We set a lower bound on the estimated size using
+            # the row count, because dictionary encoding can make the
+            # in-memory size much larger.
+            min_value = max(1, row_count // file_count)
+            if partial_mean_size < min_value and column not in self._real_rg_size:
+                # If the metadata is suspiciously small,
+                # sample "real" data to get a better estimate.
+                self._sample_row_groups()
+            if column in self._real_rg_size:
+                partial_mean_size = int(
+                    self._real_rg_size[column]
+                    * statistics.mean(self.metadata.num_row_groups_per_file)
+                )
+            return ColumnStat[int](max(min_value, partial_mean_size))
+        return ColumnStat[int]()
 
     def add_unique_stats_column(self, column: str) -> None:
         """Add a column needing unique-value information."""
@@ -853,14 +878,19 @@ def _extract_scan_stats(
             config_options.parquet_options.max_row_group_samples,
             config_options.executor.stats_planning,
         )
-        return {
+        cstats = {
             name: ColumnStats(
                 name=name,
                 source_info=ColumnSourceInfo(DataSourcePair(table_source_info, name)),
             )
             for name in ir.schema
         }
-
+        # Mark all columns that we are reading in case
+        # we need to sample real data later.
+        if config_options.executor.stats_planning.use_sampling:
+            for name, cs in cstats.items():
+                cs.source_info.add_read_column(name)
+        return cstats
     else:
         return {name: ColumnStats(name=name) for name in ir.schema}
 
@@ -879,10 +909,10 @@ class DataFrameSourceInfo(DataSourceInfo):
 
     def __init__(
         self,
-        df: Any,
+        df: pl.DataFrame,
         stats_planning: StatsPlanningOptions,
     ):
-        self._df = df
+        self._pdf = df
         self._stats_planning = stats_planning
         self._key_columns: set[str] = set()
         self._unique_stats_columns = set()
@@ -891,17 +921,19 @@ class DataFrameSourceInfo(DataSourceInfo):
     @functools.cached_property
     def row_count(self) -> ColumnStat[int]:
         """Data source row-count estimate."""
-        return ColumnStat[int](value=self._df.height(), exact=True)
+        return ColumnStat[int](value=self._pdf.height, exact=True)
 
     def _update_unique_stats(self, column: str) -> None:
         if column not in self._unique_stats and self._stats_planning.use_sampling:
             row_count = self.row_count.value
             try:
                 unique_count = (
-                    self._df.get_column(column).approx_n_unique() if row_count else 0
+                    self._pdf._df.get_column(column).approx_n_unique()
+                    if row_count
+                    else 0
                 )
             except pl.exceptions.InvalidOperationError:  # pragma: no cover
-                unique_count = self._df.get_column(column).n_unique()
+                unique_count = self._pdf._df.get_column(column).n_unique()
             unique_fraction = min((unique_count / row_count), 1.0) if row_count else 1.0
             self._unique_stats[column] = UniqueStats(
                 ColumnStat[int](value=unique_count),
@@ -922,7 +954,7 @@ def _extract_dataframescan_stats(
         "Only streaming executor is supported in _extract_dataframescan_stats"
     )
     table_source_info = DataFrameSourceInfo(
-        ir.df,
+        pl.DataFrame._from_pydf(ir.df),
        config_options.executor.stats_planning,
     )
     return {
cudf_polars/experimental/join.py

@@ -5,7 +5,7 @@
 from __future__ import annotations
 
 import operator
-from functools import reduce
+from functools import partial, reduce
 from typing import TYPE_CHECKING, Any
 
 from cudf_polars.dsl.ir import ConditionalJoin, Join, Slice
@@ -19,9 +19,9 @@ if TYPE_CHECKING:
     from collections.abc import MutableMapping
 
     from cudf_polars.dsl.expr import NamedExpr
-    from cudf_polars.dsl.ir import IR
+    from cudf_polars.dsl.ir import IR, IRExecutionContext
     from cudf_polars.experimental.parallel import LowerIRTransformer
-    from cudf_polars.utils.config import ShuffleMethod
+    from cudf_polars.utils.config import ShuffleMethod, ShufflerInsertionMethod
 
 
 def _maybe_shuffle_frame(
@@ -30,6 +30,8 @@ def _maybe_shuffle_frame(
     partition_info: MutableMapping[IR, PartitionInfo],
     shuffle_method: ShuffleMethod,
     output_count: int,
+    *,
+    shuffler_insertion_method: ShufflerInsertionMethod,
 ) -> IR:
     # Shuffle `frame` if it isn't already shuffled.
     if (
@@ -44,6 +46,7 @@
             frame.schema,
             on,
             shuffle_method,
+            shuffler_insertion_method,
             frame,
         )
         partition_info[frame] = PartitionInfo(
@@ -60,6 +63,8 @@ def _make_hash_join(
     left: IR,
     right: IR,
     shuffle_method: ShuffleMethod,
+    *,
+    shuffler_insertion_method: ShufflerInsertionMethod,
 ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
     # Shuffle left and right dataframes (if necessary)
     new_left = _maybe_shuffle_frame(
@@ -68,6 +73,7 @@
         partition_info,
         shuffle_method,
         output_count,
+        shuffler_insertion_method=shuffler_insertion_method,
     )
     new_right = _maybe_shuffle_frame(
         right,
@@ -75,6 +81,7 @@
         partition_info,
         shuffle_method,
         output_count,
+        shuffler_insertion_method=shuffler_insertion_method,
     )
     if left != new_left or right != new_right:
         ir = ir.reconstruct([new_left, new_right])
@@ -144,6 +151,9 @@ def _make_bcast_join(
     left: IR,
     right: IR,
     shuffle_method: ShuffleMethod,
+    *,
+    streaming_runtime: str,
+    shuffler_insertion_method: ShufflerInsertionMethod,
 ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
     if ir.options[0] != "Inner":
         left_count = partition_info[left].count
@@ -162,22 +172,25 @@
     #   - In some cases, we can perform the partial joins
     #     sequentially. However, we are starting with a
     #     catch-all algorithm that works for all cases.
-    if left_count >= right_count:
-        right = _maybe_shuffle_frame(
-            right,
-            ir.right_on,
-            partition_info,
-            shuffle_method,
-            right_count,
-        )
-    else:
-        left = _maybe_shuffle_frame(
-            left,
-            ir.left_on,
-            partition_info,
-            shuffle_method,
-            left_count,
-        )
+    if streaming_runtime == "tasks":
+        if left_count >= right_count:
+            right = _maybe_shuffle_frame(
+                right,
+                ir.right_on,
+                partition_info,
+                shuffle_method,
+                right_count,
+                shuffler_insertion_method=shuffler_insertion_method,
+            )
+        else:
+            left = _maybe_shuffle_frame(
+                left,
+                ir.left_on,
+                partition_info,
+                shuffle_method,
+                left_count,
+                shuffler_insertion_method=shuffler_insertion_method,
+            )
 
     new_node = ir.reconstruct([left, right])
     partition_info[new_node] = PartitionInfo(count=output_count)
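
Note: for non-"Inner" broadcast joins on the "tasks" runtime, the block above shuffles whichever side has fewer partitions; a condensed restatement (hypothetical helper, not part of the package):

def pick_shuffle_side(left_count: int, right_count: int) -> str:
    # Shuffle the smaller side so the larger side can stay in place.
    return "right" if left_count >= right_count else "left"

assert pick_shuffle_side(left_count=32, right_count=4) == "right"
assert pick_shuffle_side(left_count=2, right_count=16) == "left"
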
@@ -263,6 +276,15 @@
     assert config_options.executor.name == "streaming", (
         "'in-memory' executor not supported in 'lower_join'"
     )
+
+    maintain_order = ir.options[5]
+    if maintain_order != "none" and output_count > 1:
+        return _lower_ir_fallback(
+            ir,
+            rec,
+            msg=f"Join({maintain_order=}) not supported for multiple partitions.",
+        )
+
     if _should_bcast_join(
         ir,
         left,
@@ -279,6 +301,8 @@
             left,
             right,
             config_options.executor.shuffle_method,
+            streaming_runtime=config_options.executor.runtime,
+            shuffler_insertion_method=config_options.executor.shuffler_insertion_method,
         )
     else:
         # Create a hash join
@@ -289,12 +313,15 @@
             left,
             right,
             config_options.executor.shuffle_method,
+            shuffler_insertion_method=config_options.executor.shuffler_insertion_method,
         )
 
 
 @generate_ir_tasks.register(Join)
 def _(
-    ir: Join, partition_info: MutableMapping[IR, PartitionInfo]
+    ir: Join,
+    partition_info: MutableMapping[IR, PartitionInfo],
+    context: IRExecutionContext,
 ) -> MutableMapping[Any, Any]:
     left, right = ir.children
     output_count = partition_info[ir].count
@@ -314,7 +341,7 @@ def _(
     right_name = get_key_name(right)
     return {
         key: (
-            ir.do_evaluate,
+            partial(ir.do_evaluate, context=context),
            *ir._non_child_args,
            (left_name, i),
            (right_name, i),
@@ -376,7 +403,7 @@ def _(
 
             inter_key = (inter_name, part_out, j)
             graph[(inter_name, part_out, j)] = (
-                ir.do_evaluate,
+                partial(ir.do_evaluate, context=context),
                 ir.left_on,
                 ir.right_on,
                 ir.options,
@@ -386,6 +413,9 @@
         if len(_concat_list) == 1:
             graph[(out_name, part_out)] = graph.pop(_concat_list[0])
         else:
-            graph[(out_name, part_out)] = (_concat, *_concat_list)
+            graph[(out_name, part_out)] = (
+                partial(_concat, context=context),
+                *_concat_list,
+            )
 
     return graph