cudf-polars-cu13 25.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. cudf_polars/GIT_COMMIT +1 -0
  2. cudf_polars/VERSION +1 -0
  3. cudf_polars/__init__.py +28 -0
  4. cudf_polars/_version.py +21 -0
  5. cudf_polars/callback.py +318 -0
  6. cudf_polars/containers/__init__.py +13 -0
  7. cudf_polars/containers/column.py +495 -0
  8. cudf_polars/containers/dataframe.py +361 -0
  9. cudf_polars/containers/datatype.py +137 -0
  10. cudf_polars/dsl/__init__.py +8 -0
  11. cudf_polars/dsl/expr.py +66 -0
  12. cudf_polars/dsl/expressions/__init__.py +8 -0
  13. cudf_polars/dsl/expressions/aggregation.py +226 -0
  14. cudf_polars/dsl/expressions/base.py +272 -0
  15. cudf_polars/dsl/expressions/binaryop.py +120 -0
  16. cudf_polars/dsl/expressions/boolean.py +326 -0
  17. cudf_polars/dsl/expressions/datetime.py +271 -0
  18. cudf_polars/dsl/expressions/literal.py +97 -0
  19. cudf_polars/dsl/expressions/rolling.py +643 -0
  20. cudf_polars/dsl/expressions/selection.py +74 -0
  21. cudf_polars/dsl/expressions/slicing.py +46 -0
  22. cudf_polars/dsl/expressions/sorting.py +85 -0
  23. cudf_polars/dsl/expressions/string.py +1002 -0
  24. cudf_polars/dsl/expressions/struct.py +137 -0
  25. cudf_polars/dsl/expressions/ternary.py +49 -0
  26. cudf_polars/dsl/expressions/unary.py +517 -0
  27. cudf_polars/dsl/ir.py +2607 -0
  28. cudf_polars/dsl/nodebase.py +164 -0
  29. cudf_polars/dsl/to_ast.py +359 -0
  30. cudf_polars/dsl/tracing.py +16 -0
  31. cudf_polars/dsl/translate.py +939 -0
  32. cudf_polars/dsl/traversal.py +224 -0
  33. cudf_polars/dsl/utils/__init__.py +8 -0
  34. cudf_polars/dsl/utils/aggregations.py +481 -0
  35. cudf_polars/dsl/utils/groupby.py +98 -0
  36. cudf_polars/dsl/utils/naming.py +34 -0
  37. cudf_polars/dsl/utils/replace.py +61 -0
  38. cudf_polars/dsl/utils/reshape.py +74 -0
  39. cudf_polars/dsl/utils/rolling.py +121 -0
  40. cudf_polars/dsl/utils/windows.py +192 -0
  41. cudf_polars/experimental/__init__.py +8 -0
  42. cudf_polars/experimental/base.py +386 -0
  43. cudf_polars/experimental/benchmarks/__init__.py +4 -0
  44. cudf_polars/experimental/benchmarks/pdsds.py +220 -0
  45. cudf_polars/experimental/benchmarks/pdsds_queries/__init__.py +4 -0
  46. cudf_polars/experimental/benchmarks/pdsds_queries/q1.py +88 -0
  47. cudf_polars/experimental/benchmarks/pdsds_queries/q10.py +225 -0
  48. cudf_polars/experimental/benchmarks/pdsds_queries/q2.py +244 -0
  49. cudf_polars/experimental/benchmarks/pdsds_queries/q3.py +65 -0
  50. cudf_polars/experimental/benchmarks/pdsds_queries/q4.py +359 -0
  51. cudf_polars/experimental/benchmarks/pdsds_queries/q5.py +462 -0
  52. cudf_polars/experimental/benchmarks/pdsds_queries/q6.py +92 -0
  53. cudf_polars/experimental/benchmarks/pdsds_queries/q7.py +79 -0
  54. cudf_polars/experimental/benchmarks/pdsds_queries/q8.py +524 -0
  55. cudf_polars/experimental/benchmarks/pdsds_queries/q9.py +137 -0
  56. cudf_polars/experimental/benchmarks/pdsh.py +814 -0
  57. cudf_polars/experimental/benchmarks/utils.py +832 -0
  58. cudf_polars/experimental/dask_registers.py +200 -0
  59. cudf_polars/experimental/dispatch.py +156 -0
  60. cudf_polars/experimental/distinct.py +197 -0
  61. cudf_polars/experimental/explain.py +157 -0
  62. cudf_polars/experimental/expressions.py +590 -0
  63. cudf_polars/experimental/groupby.py +327 -0
  64. cudf_polars/experimental/io.py +943 -0
  65. cudf_polars/experimental/join.py +391 -0
  66. cudf_polars/experimental/parallel.py +423 -0
  67. cudf_polars/experimental/repartition.py +69 -0
  68. cudf_polars/experimental/scheduler.py +155 -0
  69. cudf_polars/experimental/select.py +188 -0
  70. cudf_polars/experimental/shuffle.py +354 -0
  71. cudf_polars/experimental/sort.py +609 -0
  72. cudf_polars/experimental/spilling.py +151 -0
  73. cudf_polars/experimental/statistics.py +795 -0
  74. cudf_polars/experimental/utils.py +169 -0
  75. cudf_polars/py.typed +0 -0
  76. cudf_polars/testing/__init__.py +8 -0
  77. cudf_polars/testing/asserts.py +448 -0
  78. cudf_polars/testing/io.py +122 -0
  79. cudf_polars/testing/plugin.py +236 -0
  80. cudf_polars/typing/__init__.py +219 -0
  81. cudf_polars/utils/__init__.py +8 -0
  82. cudf_polars/utils/config.py +741 -0
  83. cudf_polars/utils/conversion.py +40 -0
  84. cudf_polars/utils/dtypes.py +118 -0
  85. cudf_polars/utils/sorting.py +53 -0
  86. cudf_polars/utils/timer.py +39 -0
  87. cudf_polars/utils/versions.py +27 -0
  88. cudf_polars_cu13-25.10.0.dist-info/METADATA +136 -0
  89. cudf_polars_cu13-25.10.0.dist-info/RECORD +92 -0
  90. cudf_polars_cu13-25.10.0.dist-info/WHEEL +5 -0
  91. cudf_polars_cu13-25.10.0.dist-info/licenses/LICENSE +201 -0
  92. cudf_polars_cu13-25.10.0.dist-info/top_level.txt +1 -0
cudf_polars/dsl/ir.py ADDED
@@ -0,0 +1,2607 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """
4
+ DSL nodes for the LogicalPlan of polars.
5
+
6
+ An IR node is either a source, a normal node, or a sink. Respectively,
7
+ they can be considered as functions:
8
+
9
+ - source: `IO () -> DataFrame`
10
+ - normal: `DataFrame -> DataFrame`
11
+ - sink: `DataFrame -> IO ()`
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import itertools
17
+ import json
18
+ import random
19
+ import time
20
+ from functools import cache
21
+ from pathlib import Path
22
+ from typing import TYPE_CHECKING, Any, ClassVar
23
+
24
+ from typing_extensions import assert_never
25
+
26
+ import polars as pl
27
+
28
+ import pylibcudf as plc
29
+
30
+ import cudf_polars.dsl.expr as expr
31
+ from cudf_polars.containers import Column, DataFrame, DataType
32
+ from cudf_polars.dsl.expressions import rolling, unary
33
+ from cudf_polars.dsl.expressions.base import ExecutionContext
34
+ from cudf_polars.dsl.nodebase import Node
35
+ from cudf_polars.dsl.to_ast import to_ast, to_parquet_filter
36
+ from cudf_polars.dsl.tracing import nvtx_annotate_cudf_polars
37
+ from cudf_polars.dsl.utils.reshape import broadcast
38
+ from cudf_polars.dsl.utils.windows import range_window_bounds
39
+ from cudf_polars.utils import dtypes
40
+ from cudf_polars.utils.versions import POLARS_VERSION_LT_131
41
+
42
+ if TYPE_CHECKING:
43
+ from collections.abc import Callable, Hashable, Iterable, Sequence
44
+ from typing import Literal
45
+
46
+ from typing_extensions import Self
47
+
48
+ from polars.polars import _expr_nodes as pl_expr
49
+
50
+ from cudf_polars.containers.dataframe import NamedColumn
51
+ from cudf_polars.typing import CSECache, ClosedInterval, Schema, Slice as Zlice
52
+ from cudf_polars.utils.config import ParquetOptions
53
+ from cudf_polars.utils.timer import Timer
54
+
55
+
56
+ __all__ = [
57
+ "IR",
58
+ "Cache",
59
+ "ConditionalJoin",
60
+ "DataFrameScan",
61
+ "Distinct",
62
+ "Empty",
63
+ "ErrorNode",
64
+ "Filter",
65
+ "GroupBy",
66
+ "HConcat",
67
+ "HStack",
68
+ "Join",
69
+ "MapFunction",
70
+ "MergeSorted",
71
+ "Projection",
72
+ "PythonScan",
73
+ "Reduce",
74
+ "Rolling",
75
+ "Scan",
76
+ "Select",
77
+ "Sink",
78
+ "Slice",
79
+ "Sort",
80
+ "Union",
81
+ ]
82
+
83
+
84
+ class IR(Node["IR"]):
85
+ """Abstract plan node, representing an unevaluated dataframe."""
86
+
87
+ __slots__ = ("_non_child_args", "schema")
88
+ # This annotation is needed because of https://github.com/python/mypy/issues/17981
89
+ _non_child: ClassVar[tuple[str, ...]] = ("schema",)
90
+ # Concrete classes should set this up with the arguments that will
91
+ # be passed to do_evaluate.
92
+ _non_child_args: tuple[Any, ...]
93
+ schema: Schema
94
+ """Mapping from column names to their data types."""
95
+
96
+ def get_hashable(self) -> Hashable:
97
+ """
98
+ Hashable representation of the node, treating the schema dictionary specially.
99
+
100
+ Since the schema is a dictionary, even though it is morally
101
+ immutable, it is not hashable. We therefore convert it to
102
+ tuples for hashing purposes.
103
+ """
104
+ # Schema is the first constructor argument
105
+ args = self._ctor_arguments(self.children)[1:]
106
+ schema_hash = tuple(self.schema.items())
107
+ return (type(self), schema_hash, args)
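+ # Illustratively: hash({"a": dtype}) raises TypeError because dicts are
+ # unhashable, whereas hash(tuple({"a": dtype}.items())) works as long as
+ # the dtype values are themselves hashable, which is what schema_hash
+ # relies on.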
108
+
109
+ # Hacky to avoid type-checking issues, just advertise the
110
+ # signature. Both mypy and pyright complain if we have an abstract
111
+ # method that takes arbitrary *args, but the subclasses have
112
+ # tighter signatures. This complaint is correct because the
113
+ # subclass is not Liskov-substitutable for the superclass.
114
+ # However, we know do_evaluate will only be called with the
115
+ # correct arguments by "construction".
116
+ do_evaluate: Callable[..., DataFrame]
117
+ """
118
+ Evaluate the node (given its evaluated children), and return a dataframe.
119
+
120
+ Parameters
121
+ ----------
122
+ args
123
+ Non-child arguments followed by any evaluated dataframe inputs.
124
+
125
+ Returns
126
+ -------
127
+ DataFrame (on device) representing the evaluation of this plan
128
+ node.
129
+
130
+ Raises
131
+ ------
132
+ NotImplementedError
133
+ If evaluation fails. Ideally this should not occur, since the
134
+ translation phase should fail earlier.
135
+ """
136
+
137
+ def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
138
+ """
139
+ Evaluate the node (recursively) and return a dataframe.
140
+
141
+ Parameters
142
+ ----------
143
+ cache
144
+ Mapping from cached node ids to constructed DataFrames.
145
+ Used to implement evaluation of the `Cache` node.
146
+ timer
147
+ If not None, a Timer object to record timings for the
148
+ evaluation of the node.
149
+
150
+ Notes
151
+ -----
152
+ Prefer not to override this method. Instead implement
153
+ :meth:`do_evaluate` which doesn't encode a recursion scheme
154
+ and just assumes already evaluated inputs.
155
+
156
+ Returns
157
+ -------
158
+ DataFrame (on device) representing the evaluation of this plan
159
+ node (and its children).
160
+
161
+ Raises
162
+ ------
163
+ NotImplementedError
164
+ If evaluation fails. Ideally this should not occur, since the
165
+ translation phase should fail earlier.
166
+ """
167
+ children = [child.evaluate(cache=cache, timer=timer) for child in self.children]
168
+ if timer is not None:
169
+ start = time.monotonic_ns()
170
+ result = self.do_evaluate(*self._non_child_args, *children)
171
+ end = time.monotonic_ns()
172
+ # TODO: Set better names on each class object.
173
+ timer.store(start, end, type(self).__name__)
174
+ return result
175
+ else:
176
+ return self.do_evaluate(*self._non_child_args, *children)
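+ # Note that this is a straightforward post-order traversal: children are
+ # evaluated before their parent, and (when a timer is provided) each node
+ # contributes one (start, end, name) record labelled with its class name,
+ # e.g. "Scan" or "Select".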
177
+
178
+
179
+ class ErrorNode(IR):
180
+ """Represents an error translating the IR."""
181
+
182
+ __slots__ = ("error",)
183
+ _non_child = (
184
+ "schema",
185
+ "error",
186
+ )
187
+ error: str
188
+ """The error."""
189
+
190
+ def __init__(self, schema: Schema, error: str):
191
+ self.schema = schema
192
+ self.error = error
193
+ self.children = ()
194
+
195
+
196
+ class PythonScan(IR):
197
+ """Representation of input from a python function."""
198
+
199
+ __slots__ = ("options", "predicate")
200
+ _non_child = ("schema", "options", "predicate")
201
+ options: Any
202
+ """Arbitrary options."""
203
+ predicate: expr.NamedExpr | None
204
+ """Filter to apply to the constructed dataframe before returning it."""
205
+
206
+ def __init__(self, schema: Schema, options: Any, predicate: expr.NamedExpr | None):
207
+ self.schema = schema
208
+ self.options = options
209
+ self.predicate = predicate
210
+ self._non_child_args = (schema, options, predicate)
211
+ self.children = ()
212
+ raise NotImplementedError("PythonScan not implemented")
213
+
214
+
215
+ def _align_parquet_schema(df: DataFrame, schema: Schema) -> DataFrame:
216
+ # TODO: Alternatively set the schema of the parquet reader to decimal128
217
+ plc_decimals_ids = {
218
+ plc.TypeId.DECIMAL32,
219
+ plc.TypeId.DECIMAL64,
220
+ plc.TypeId.DECIMAL128,
221
+ }
222
+ cast_list = []
223
+
224
+ for name, col in df.column_map.items():
225
+ src = col.obj.type()
226
+ dst = schema[name].plc
227
+ if (
228
+ src.id() in plc_decimals_ids
229
+ and dst.id() in plc_decimals_ids
230
+ and ((src.id() != dst.id()) or (src.scale != dst.scale))
231
+ ):
232
+ cast_list.append(
233
+ Column(plc.unary.cast(col.obj, dst), name=name, dtype=schema[name])
234
+ )
235
+
236
+ if cast_list:
237
+ df = df.with_columns(cast_list)
238
+
239
+ return df
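+ # For instance, if the parquet reader produces a DECIMAL64 column while the
+ # polars schema expects DECIMAL128 (or the scales differ), the helper above
+ # casts the column so the resulting DataFrame matches the schema exactly.
+ # Columns whose types already match, and non-decimal columns, are untouched.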
240
+
241
+
242
+ class Scan(IR):
243
+ """Input from files."""
244
+
245
+ __slots__ = (
246
+ "cloud_options",
247
+ "include_file_paths",
248
+ "n_rows",
249
+ "parquet_options",
250
+ "paths",
251
+ "predicate",
252
+ "reader_options",
253
+ "row_index",
254
+ "skip_rows",
255
+ "typ",
256
+ "with_columns",
257
+ )
258
+ _non_child = (
259
+ "schema",
260
+ "typ",
261
+ "reader_options",
262
+ "cloud_options",
263
+ "paths",
264
+ "with_columns",
265
+ "skip_rows",
266
+ "n_rows",
267
+ "row_index",
268
+ "include_file_paths",
269
+ "predicate",
270
+ "parquet_options",
271
+ )
272
+ typ: str
273
+ """What type of file are we reading? Parquet, CSV, etc..."""
274
+ reader_options: dict[str, Any]
275
+ """Reader-specific options, as dictionary."""
276
+ cloud_options: dict[str, Any] | None
277
+ """Cloud-related authentication options, currently ignored."""
278
+ paths: list[str]
279
+ """List of paths to read from."""
280
+ with_columns: list[str] | None
281
+ """Projected columns to return."""
282
+ skip_rows: int
283
+ """Rows to skip at the start when reading."""
284
+ n_rows: int
285
+ """Number of rows to read after skipping."""
286
+ row_index: tuple[str, int] | None
287
+ """If not None add an integer index column of the given name."""
288
+ include_file_paths: str | None
289
+ """Include the path of the source file(s) as a column with this name."""
290
+ predicate: expr.NamedExpr | None
291
+ """Mask to apply to the read dataframe."""
292
+ parquet_options: ParquetOptions
293
+ """Parquet-specific options."""
294
+
295
+ PARQUET_DEFAULT_CHUNK_SIZE: int = 0 # unlimited
296
+ PARQUET_DEFAULT_PASS_LIMIT: int = 16 * 1024**3 # 16GiB
297
+
298
+ def __init__(
299
+ self,
300
+ schema: Schema,
301
+ typ: str,
302
+ reader_options: dict[str, Any],
303
+ cloud_options: dict[str, Any] | None,
304
+ paths: list[str],
305
+ with_columns: list[str] | None,
306
+ skip_rows: int,
307
+ n_rows: int,
308
+ row_index: tuple[str, int] | None,
309
+ include_file_paths: str | None,
310
+ predicate: expr.NamedExpr | None,
311
+ parquet_options: ParquetOptions,
312
+ ):
313
+ self.schema = schema
314
+ self.typ = typ
315
+ self.reader_options = reader_options
316
+ self.cloud_options = cloud_options
317
+ self.paths = paths
318
+ self.with_columns = with_columns
319
+ self.skip_rows = skip_rows
320
+ self.n_rows = n_rows
321
+ self.row_index = row_index
322
+ self.include_file_paths = include_file_paths
323
+ self.predicate = predicate
324
+ self._non_child_args = (
325
+ schema,
326
+ typ,
327
+ reader_options,
328
+ paths,
329
+ with_columns,
330
+ skip_rows,
331
+ n_rows,
332
+ row_index,
333
+ include_file_paths,
334
+ predicate,
335
+ parquet_options,
336
+ )
337
+ self.children = ()
338
+ self.parquet_options = parquet_options
339
+ if self.typ not in ("csv", "parquet", "ndjson"): # pragma: no cover
340
+ # This line is unhittable ATM since IPC/Anonymous scan raise
341
+ # on the polars side
342
+ raise NotImplementedError(f"Unhandled scan type: {self.typ}")
343
+ if self.typ == "ndjson" and (self.n_rows != -1 or self.skip_rows != 0):
344
+ raise NotImplementedError("row limit in scan for json reader")
345
+ if self.skip_rows < 0:
346
+ # TODO: polars has this implemented for parquet,
347
+ # maybe we can do this too?
348
+ raise NotImplementedError("slice pushdown for negative slices")
349
+ if self.cloud_options is not None and any(
350
+ self.cloud_options.get(k) is not None for k in ("aws", "azure", "gcp")
351
+ ):
352
+ raise NotImplementedError(
353
+ "Read from cloud storage"
354
+ ) # pragma: no cover; no test yet
355
+ if (
356
+ any(str(p).startswith("https:/") for p in self.paths)
357
+ and POLARS_VERSION_LT_131
358
+ ): # pragma: no cover; polars passed us the wrong URI
359
+ # https://github.com/pola-rs/polars/issues/22766
360
+ raise NotImplementedError("Read from https")
361
+ if any(
362
+ str(p).startswith("file:/" if POLARS_VERSION_LT_131 else "file://")
363
+ for p in self.paths
364
+ ):
365
+ raise NotImplementedError("Read from file URI")
366
+ if self.typ == "csv":
367
+ if any(
368
+ plc.io.SourceInfo._is_remote_uri(p) for p in self.paths
369
+ ): # pragma: no cover; no test yet
370
+ # This works fine when the file has no leading blank lines,
371
+ # but currently we do some file introspection
372
+ # to skip blanks before parsing the header.
373
+ # For remote files we cannot determine if leading blank lines
374
+ # exist, so we're punting on CSV support.
375
+ # TODO: Once the CSV reader supports skipping leading
376
+ # blank lines natively, we can remove this guard.
377
+ raise NotImplementedError(
378
+ "Reading CSV from remote is not yet supported"
379
+ )
380
+
381
+ if self.reader_options["skip_rows_after_header"] != 0:
382
+ raise NotImplementedError("Skipping rows after header in CSV reader")
383
+ parse_options = self.reader_options["parse_options"]
384
+ if (
385
+ null_values := parse_options["null_values"]
386
+ ) is not None and "Named" in null_values:
387
+ raise NotImplementedError(
388
+ "Per column null value specification not supported for CSV reader"
389
+ )
390
+ if (
391
+ comment := parse_options["comment_prefix"]
392
+ ) is not None and "Multi" in comment:
393
+ raise NotImplementedError(
394
+ "Multi-character comment prefix not supported for CSV reader"
395
+ )
396
+ if not self.reader_options["has_header"]:
397
+ # TODO: To support reading headerless CSV files without requiring new
398
+ # column names, we would need to do file introspection to infer the number
399
+ # of columns so column projection works right.
400
+ reader_schema = self.reader_options.get("schema")
401
+ if not (
402
+ reader_schema
403
+ and isinstance(schema, dict)
404
+ and "fields" in reader_schema
405
+ ):
406
+ raise NotImplementedError(
407
+ "Reading CSV without header requires user-provided column names via new_columns"
408
+ )
409
+ elif self.typ == "ndjson":
410
+ # TODO: consider handling the low memory option here
411
+ # (maybe use chunked JSON reader)
412
+ if self.reader_options["ignore_errors"]:
413
+ raise NotImplementedError(
414
+ "ignore_errors is not supported in the JSON reader"
415
+ )
416
+ if include_file_paths is not None:
417
+ # TODO: Need to populate num_rows_per_source in read_json in libcudf
418
+ raise NotImplementedError("Including file paths in a json scan.")
419
+ elif (
420
+ self.typ == "parquet"
421
+ and self.row_index is not None
422
+ and self.with_columns is not None
423
+ and len(self.with_columns) == 0
424
+ ):
425
+ raise NotImplementedError(
426
+ "Reading only parquet metadata to produce row index."
427
+ )
428
+
429
+ def get_hashable(self) -> Hashable:
430
+ """
431
+ Hashable representation of the node.
432
+
433
+ The options dictionaries are serialised for hashing purposes
434
+ as json strings.
435
+ """
436
+ schema_hash = tuple(self.schema.items())
437
+ return (
438
+ type(self),
439
+ schema_hash,
440
+ self.typ,
441
+ json.dumps(self.reader_options),
442
+ json.dumps(self.cloud_options),
443
+ tuple(self.paths),
444
+ tuple(self.with_columns) if self.with_columns is not None else None,
445
+ self.skip_rows,
446
+ self.n_rows,
447
+ self.row_index,
448
+ self.include_file_paths,
449
+ self.predicate,
450
+ self.parquet_options,
451
+ )
452
+
453
+ @staticmethod
454
+ def add_file_paths(
455
+ name: str, paths: list[str], rows_per_path: list[int], df: DataFrame
456
+ ) -> DataFrame:
457
+ """
458
+ Add a Column of file paths to the DataFrame.
459
+
460
+ Each path is repeated according to the number of rows read from it.
461
+ """
462
+ (filepaths,) = plc.filling.repeat(
463
+ plc.Table([plc.Column.from_arrow(pl.Series(values=map(str, paths)))]),
464
+ plc.Column.from_arrow(
465
+ pl.Series(values=rows_per_path, dtype=pl.datatypes.Int32())
466
+ ),
467
+ ).columns()
468
+ dtype = DataType(pl.String())
469
+ return df.with_columns([Column(filepaths, name=name, dtype=dtype)])
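+ # Illustrative example: with paths ["a.parquet", "b.parquet"] and
+ # rows_per_path [2, 3], the added column contains
+ # ["a.parquet", "a.parquet", "b.parquet", "b.parquet", "b.parquet"].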
470
+
471
+ def fast_count(self) -> int: # pragma: no cover
472
+ """Get the number of rows in a Parquet Scan."""
473
+ meta = plc.io.parquet_metadata.read_parquet_metadata(
474
+ plc.io.SourceInfo(self.paths)
475
+ )
476
+ total_rows = meta.num_rows() - self.skip_rows
477
+ if self.n_rows != -1:
478
+ total_rows = min(total_rows, self.n_rows)
479
+ return max(total_rows, 0)
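+ # Worked example: if the parquet metadata reports 1_000 rows,
+ # skip_rows == 100 and n_rows == 500, then total_rows is first 900 and
+ # then clamped by n_rows to 500. With n_rows == -1 (no limit) the result
+ # would be 900.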
480
+
481
+ @classmethod
482
+ @nvtx_annotate_cudf_polars(message="Scan")
483
+ def do_evaluate(
484
+ cls,
485
+ schema: Schema,
486
+ typ: str,
487
+ reader_options: dict[str, Any],
488
+ paths: list[str],
489
+ with_columns: list[str] | None,
490
+ skip_rows: int,
491
+ n_rows: int,
492
+ row_index: tuple[str, int] | None,
493
+ include_file_paths: str | None,
494
+ predicate: expr.NamedExpr | None,
495
+ parquet_options: ParquetOptions,
496
+ ) -> DataFrame:
497
+ """Evaluate and return a dataframe."""
498
+ if typ == "csv":
499
+
500
+ def read_csv_header(
501
+ path: Path | str, sep: str
502
+ ) -> list[str]: # pragma: no cover
503
+ with Path(path).open() as f:
504
+ for line in f:
505
+ stripped = line.strip()
506
+ if stripped:
507
+ return stripped.split(sep)
508
+ return []
509
+
510
+ parse_options = reader_options["parse_options"]
511
+ sep = chr(parse_options["separator"])
512
+ quote = chr(parse_options["quote_char"])
513
+ eol = chr(parse_options["eol_char"])
514
+ if reader_options["schema"] is not None:
515
+ # Reader schema provides names
516
+ column_names = list(reader_options["schema"]["fields"].keys())
517
+ else:
518
+ # file provides column names
519
+ column_names = None
520
+ usecols = with_columns
521
+ has_header = reader_options["has_header"]
522
+ header = 0 if has_header else -1
523
+
524
+ # polars defaults to no null recognition
525
+ null_values = [""]
526
+ if parse_options["null_values"] is not None:
527
+ ((typ, nulls),) = parse_options["null_values"].items()
528
+ if typ == "AllColumnsSingle":
529
+ # Single value
530
+ null_values.append(nulls)
531
+ else:
532
+ # List of values
533
+ null_values.extend(nulls)
534
+ if parse_options["comment_prefix"] is not None:
535
+ comment = chr(parse_options["comment_prefix"]["Single"])
536
+ else:
537
+ comment = None
538
+ decimal = "," if parse_options["decimal_comma"] else "."
539
+
540
+ # polars skips blank lines at the beginning of the file
541
+ pieces = []
542
+ seen_paths = []
543
+ read_partial = n_rows != -1
544
+ for p in paths:
545
+ skiprows = reader_options["skip_rows"]
546
+ path = Path(p)
547
+ with path.open() as f:
548
+ while f.readline() == "\n":
549
+ skiprows += 1
550
+ options = (
551
+ plc.io.csv.CsvReaderOptions.builder(plc.io.SourceInfo([path]))
552
+ .nrows(n_rows)
553
+ .skiprows(skiprows + skip_rows)
554
+ .lineterminator(str(eol))
555
+ .quotechar(str(quote))
556
+ .decimal(decimal)
557
+ .keep_default_na(keep_default_na=False)
558
+ .na_filter(na_filter=True)
559
+ .delimiter(str(sep))
560
+ .build()
561
+ )
562
+ if column_names is not None:
563
+ options.set_names([str(name) for name in column_names])
564
+ else:
565
+ if header > -1 and skip_rows > header: # pragma: no cover
566
+ # We need to read the header otherwise we would skip it
567
+ column_names = read_csv_header(path, str(sep))
568
+ options.set_names(column_names)
569
+ options.set_header(header)
570
+ options.set_dtypes({name: dtype.plc for name, dtype in schema.items()})
571
+ if usecols is not None:
572
+ options.set_use_cols_names([str(name) for name in usecols])
573
+ options.set_na_values(null_values)
574
+ if comment is not None:
575
+ options.set_comment(comment)
576
+ tbl_w_meta = plc.io.csv.read_csv(options)
577
+ pieces.append(tbl_w_meta)
578
+ if include_file_paths is not None:
579
+ seen_paths.append(p)
580
+ if read_partial:
581
+ n_rows -= tbl_w_meta.tbl.num_rows()
582
+ if n_rows <= 0:
583
+ break
584
+ tables, (colnames, *_) = zip(
585
+ *(
586
+ (piece.tbl, piece.column_names(include_children=False))
587
+ for piece in pieces
588
+ ),
589
+ strict=True,
590
+ )
591
+ df = DataFrame.from_table(
592
+ plc.concatenate.concatenate(list(tables)),
593
+ colnames,
594
+ [schema[colname] for colname in colnames],
595
+ )
596
+ if include_file_paths is not None:
597
+ df = Scan.add_file_paths(
598
+ include_file_paths,
599
+ seen_paths,
600
+ [t.num_rows() for t in tables],
601
+ df,
602
+ )
603
+ elif typ == "parquet":
604
+ filters = None
605
+ if predicate is not None and row_index is None:
606
+ # Can't apply filters during read if we have a row index.
607
+ filters = to_parquet_filter(predicate.value)
608
+ options = plc.io.parquet.ParquetReaderOptions.builder(
609
+ plc.io.SourceInfo(paths)
610
+ ).build()
611
+ if with_columns is not None:
612
+ options.set_columns(with_columns)
613
+ if filters is not None:
614
+ options.set_filter(filters)
615
+ if n_rows != -1:
616
+ options.set_num_rows(n_rows)
617
+ if skip_rows != 0:
618
+ options.set_skip_rows(skip_rows)
619
+ if parquet_options.chunked:
620
+ reader = plc.io.parquet.ChunkedParquetReader(
621
+ options,
622
+ chunk_read_limit=parquet_options.chunk_read_limit,
623
+ pass_read_limit=parquet_options.pass_read_limit,
624
+ )
625
+ chunk = reader.read_chunk()
626
+ tbl = chunk.tbl
627
+ # TODO: Nested column names
628
+ names = chunk.column_names(include_children=False)
629
+ concatenated_columns = tbl.columns()
630
+ while reader.has_next():
631
+ chunk = reader.read_chunk()
632
+ tbl = chunk.tbl
633
+ for i in range(tbl.num_columns()):
634
+ concatenated_columns[i] = plc.concatenate.concatenate(
635
+ [concatenated_columns[i], tbl._columns[i]]
636
+ )
637
+ # Drop residual columns to save memory
638
+ tbl._columns[i] = None
639
+ df = DataFrame.from_table(
640
+ plc.Table(concatenated_columns),
641
+ names=names,
642
+ dtypes=[schema[name] for name in names],
643
+ )
644
+ df = _align_parquet_schema(df, schema)
645
+ if include_file_paths is not None:
646
+ df = Scan.add_file_paths(
647
+ include_file_paths, paths, chunk.num_rows_per_source, df
648
+ )
649
+ else:
650
+ tbl_w_meta = plc.io.parquet.read_parquet(options)
651
+ # TODO: consider nested column names?
652
+ col_names = tbl_w_meta.column_names(include_children=False)
653
+ df = DataFrame.from_table(
654
+ tbl_w_meta.tbl,
655
+ col_names,
656
+ [schema[name] for name in col_names],
657
+ )
658
+ df = _align_parquet_schema(df, schema)
659
+ if include_file_paths is not None:
660
+ df = Scan.add_file_paths(
661
+ include_file_paths, paths, tbl_w_meta.num_rows_per_source, df
662
+ )
663
+ if filters is not None:
664
+ # Mask must have been applied.
665
+ return df
666
+ elif typ == "ndjson":
667
+ json_schema: list[plc.io.json.NameAndType] = [
668
+ (name, typ.plc, []) for name, typ in schema.items()
669
+ ]
670
+ plc_tbl_w_meta = plc.io.json.read_json(
671
+ plc.io.json._setup_json_reader_options(
672
+ plc.io.SourceInfo(paths),
673
+ lines=True,
674
+ dtypes=json_schema,
675
+ prune_columns=True,
676
+ )
677
+ )
678
+ # TODO: I don't think cudf-polars supports nested types in general right now
679
+ # (but when it does, we should pass child column names from nested columns in)
680
+ col_names = plc_tbl_w_meta.column_names(include_children=False)
681
+ df = DataFrame.from_table(
682
+ plc_tbl_w_meta.tbl,
683
+ col_names,
684
+ [schema[name] for name in col_names],
685
+ )
686
+ col_order = list(schema.keys())
687
+ if row_index is not None:
688
+ col_order.remove(row_index[0])
689
+ df = df.select(col_order)
690
+ else:
691
+ raise NotImplementedError(
692
+ f"Unhandled scan type: {typ}"
693
+ ) # pragma: no cover; post init trips first
694
+ if row_index is not None:
695
+ name, offset = row_index
696
+ offset += skip_rows
697
+ dtype = schema[name]
698
+ step = plc.Scalar.from_py(1, dtype.plc)
699
+ init = plc.Scalar.from_py(offset, dtype.plc)
700
+ index_col = Column(
701
+ plc.filling.sequence(df.num_rows, init, step),
702
+ is_sorted=plc.types.Sorted.YES,
703
+ order=plc.types.Order.ASCENDING,
704
+ null_order=plc.types.NullOrder.AFTER,
705
+ name=name,
706
+ dtype=dtype,
707
+ )
708
+ df = DataFrame([index_col, *df.columns])
709
+ if next(iter(schema)) != name:
710
+ df = df.select(schema)
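+ # e.g. with row_index == ("index", 0) and skip_rows == 10, the added
+ # column is the sequence 10, 11, 12, ... (one value per row read),
+ # marked as sorted ascending so later operations can exploit that.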
711
+ assert all(
712
+ c.obj.type() == schema[name].plc for name, c in df.column_map.items()
713
+ )
714
+ if predicate is None:
715
+ return df
716
+ else:
717
+ (mask,) = broadcast(predicate.evaluate(df), target_length=df.num_rows)
718
+ return df.filter(mask)
719
+
720
+
721
+ class Sink(IR):
722
+ """Sink a dataframe to a file."""
723
+
724
+ __slots__ = ("cloud_options", "kind", "options", "parquet_options", "path")
725
+ _non_child = (
726
+ "schema",
727
+ "kind",
728
+ "path",
729
+ "parquet_options",
730
+ "options",
731
+ "cloud_options",
732
+ )
733
+
734
+ kind: str
735
+ """The type of file to write to. Eg. Parquet, CSV, etc."""
736
+ path: str
737
+ """The path to write to"""
738
+ parquet_options: ParquetOptions
739
+ """GPU-specific configuration options"""
740
+ cloud_options: dict[str, Any] | None
741
+ """Cloud-related authentication options, currently ignored."""
742
+ options: dict[str, Any]
743
+ """Sink options from Polars"""
744
+
745
+ def __init__(
746
+ self,
747
+ schema: Schema,
748
+ kind: str,
749
+ path: str,
750
+ parquet_options: ParquetOptions,
751
+ options: dict[str, Any],
752
+ cloud_options: dict[str, Any],
753
+ df: IR,
754
+ ):
755
+ self.schema = schema
756
+ self.kind = kind
757
+ self.path = path
758
+ self.parquet_options = parquet_options
759
+ self.options = options
760
+ self.cloud_options = cloud_options
761
+ self.children = (df,)
762
+ self._non_child_args = (schema, kind, path, parquet_options, options)
763
+ if self.cloud_options is not None and any(
764
+ self.cloud_options.get(k) is not None
765
+ for k in ("config", "credential_provider")
766
+ ):
767
+ raise NotImplementedError(
768
+ "Write to cloud storage"
769
+ ) # pragma: no cover; no test yet
770
+ sync_on_close = options.get("sync_on_close")
771
+ if sync_on_close not in {"None", None}:
772
+ raise NotImplementedError(
773
+ f"sync_on_close='{sync_on_close}' is not supported."
774
+ ) # pragma: no cover; no test yet
775
+ child_schema = df.schema.values()
776
+ if kind == "Csv":
777
+ if not all(
778
+ plc.io.csv.is_supported_write_csv(dtype.plc) for dtype in child_schema
779
+ ):
780
+ # Nested types are unsupported in polars and libcudf
781
+ raise NotImplementedError(
782
+ "Contains unsupported types for CSV writing"
783
+ ) # pragma: no cover
784
+ serialize = options["serialize_options"]
785
+ if options["include_bom"]:
786
+ raise NotImplementedError("include_bom is not supported.")
787
+ for key in (
788
+ "date_format",
789
+ "time_format",
790
+ "datetime_format",
791
+ "float_scientific",
792
+ "float_precision",
793
+ ):
794
+ if serialize[key] is not None:
795
+ raise NotImplementedError(f"{key} is not supported.")
796
+ if serialize["quote_style"] != "Necessary":
797
+ raise NotImplementedError("Only quote_style='Necessary' is supported.")
798
+ if chr(serialize["quote_char"]) != '"':
799
+ raise NotImplementedError("Only quote_char='\"' is supported.")
800
+ elif kind == "Parquet":
801
+ compression = options["compression"]
802
+ if isinstance(compression, dict):
803
+ if len(compression) != 1:
804
+ raise NotImplementedError(
805
+ "Compression dict with more than one entry."
806
+ ) # pragma: no cover
807
+ compression, compression_level = next(iter(compression.items()))
808
+ options["compression"] = compression
809
+ if compression_level is not None:
810
+ raise NotImplementedError(
811
+ "Setting compression_level is not supported."
812
+ )
813
+ if compression == "Lz4Raw":
814
+ compression = "Lz4"
815
+ options["compression"] = compression
816
+ if (
817
+ compression != "Uncompressed"
818
+ and not plc.io.parquet.is_supported_write_parquet(
819
+ getattr(plc.io.types.CompressionType, compression.upper())
820
+ )
821
+ ):
822
+ raise NotImplementedError(
823
+ f"Compression type '{compression}' is not supported."
824
+ )
825
+ elif (
826
+ kind == "Json"
827
+ ): # pragma: no cover; options are validated on the polars side
828
+ if not all(
829
+ plc.io.json.is_supported_write_json(dtype.plc) for dtype in child_schema
830
+ ):
831
+ # Nested types are unsupported in polars and libcudf
832
+ raise NotImplementedError(
833
+ "Contains unsupported types for JSON writing"
834
+ ) # pragma: no cover
835
+ shared_writer_options = {"sync_on_close", "maintain_order", "mkdir"}
836
+ if set(options) - shared_writer_options:
837
+ raise NotImplementedError("Unsupported options passed JSON writer.")
838
+ else:
839
+ raise NotImplementedError(
840
+ f"Unhandled sink kind: {kind}"
841
+ ) # pragma: no cover
842
+
843
+ def get_hashable(self) -> Hashable:
844
+ """
845
+ Hashable representation of the node.
846
+
847
+ The option dictionary is serialised for hashing purposes.
848
+ """
849
+ schema_hash = tuple(self.schema.items()) # pragma: no cover
850
+ return (
851
+ type(self),
852
+ schema_hash,
853
+ self.kind,
854
+ self.path,
855
+ self.parquet_options,
856
+ json.dumps(self.options),
857
+ json.dumps(self.cloud_options),
858
+ ) # pragma: no cover
859
+
860
+ @classmethod
861
+ def _write_csv(
862
+ cls, target: plc.io.SinkInfo, options: dict[str, Any], df: DataFrame
863
+ ) -> None:
864
+ """Write CSV data to a sink."""
865
+ serialize = options["serialize_options"]
866
+ options = (
867
+ plc.io.csv.CsvWriterOptions.builder(target, df.table)
868
+ .include_header(options["include_header"])
869
+ .names(df.column_names if options["include_header"] else [])
870
+ .na_rep(serialize["null"])
871
+ .line_terminator(serialize["line_terminator"])
872
+ .inter_column_delimiter(chr(serialize["separator"]))
873
+ .build()
874
+ )
875
+ plc.io.csv.write_csv(options)
876
+
877
+ @classmethod
878
+ def _write_json(cls, target: plc.io.SinkInfo, df: DataFrame) -> None:
879
+ """Write Json data to a sink."""
880
+ metadata = plc.io.TableWithMetadata(
881
+ df.table, [(col, []) for col in df.column_names]
882
+ )
883
+ options = (
884
+ plc.io.json.JsonWriterOptions.builder(target, df.table)
885
+ .lines(val=True)
886
+ .na_rep("null")
887
+ .include_nulls(val=True)
888
+ .metadata(metadata)
889
+ .utf8_escaped(val=False)
890
+ .build()
891
+ )
892
+ plc.io.json.write_json(options)
893
+
894
+ @staticmethod
895
+ def _make_parquet_metadata(df: DataFrame) -> plc.io.types.TableInputMetadata:
896
+ """Create TableInputMetadata and set column names."""
897
+ metadata = plc.io.types.TableInputMetadata(df.table)
898
+ for i, name in enumerate(df.column_names):
899
+ metadata.column_metadata[i].set_name(name)
900
+ return metadata
901
+
902
+ @staticmethod
903
+ def _apply_parquet_writer_options(
904
+ builder: plc.io.parquet.ChunkedParquetWriterOptionsBuilder
905
+ | plc.io.parquet.ParquetWriterOptionsBuilder,
906
+ options: dict[str, Any],
907
+ ) -> (
908
+ plc.io.parquet.ChunkedParquetWriterOptionsBuilder
909
+ | plc.io.parquet.ParquetWriterOptionsBuilder
910
+ ):
911
+ """Apply writer options to the builder."""
912
+ compression = options.get("compression")
913
+ if compression and compression != "Uncompressed":
914
+ compression_type = getattr(
915
+ plc.io.types.CompressionType, compression.upper()
916
+ )
917
+ builder = builder.compression(compression_type)
918
+
919
+ if (data_page_size := options.get("data_page_size")) is not None:
920
+ builder = builder.max_page_size_bytes(data_page_size)
921
+
922
+ if (row_group_size := options.get("row_group_size")) is not None:
923
+ builder = builder.row_group_size_rows(row_group_size)
924
+
925
+ return builder
926
+
927
+ @classmethod
928
+ def _write_parquet(
929
+ cls,
930
+ target: plc.io.SinkInfo,
931
+ parquet_options: ParquetOptions,
932
+ options: dict[str, Any],
933
+ df: DataFrame,
934
+ ) -> None:
935
+ metadata: plc.io.types.TableInputMetadata = cls._make_parquet_metadata(df)
936
+
937
+ builder: (
938
+ plc.io.parquet.ChunkedParquetWriterOptionsBuilder
939
+ | plc.io.parquet.ParquetWriterOptionsBuilder
940
+ )
941
+
942
+ if (
943
+ parquet_options.chunked
944
+ and parquet_options.n_output_chunks != 1
945
+ and df.table.num_rows() != 0
946
+ ):
947
+ builder = plc.io.parquet.ChunkedParquetWriterOptions.builder(
948
+ target
949
+ ).metadata(metadata)
950
+ builder = cls._apply_parquet_writer_options(builder, options)
951
+ writer_options = builder.build()
952
+ writer = plc.io.parquet.ChunkedParquetWriter.from_options(writer_options)
953
+
954
+ # TODO: Can be based on a heuristic that estimates chunk size
955
+ # from the input table size and available GPU memory.
956
+ num_chunks = parquet_options.n_output_chunks
957
+ table_chunks = plc.copying.split(
958
+ df.table,
959
+ [i * df.table.num_rows() // num_chunks for i in range(1, num_chunks)],
960
+ )
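+ # For example, with num_chunks == 4 and a 100-row table the offsets
+ # computed above are [25, 50, 75], so plc.copying.split yields four
+ # 25-row chunks, which are then written one at a time below.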
961
+ for chunk in table_chunks:
962
+ writer.write(chunk)
963
+ writer.close([])
964
+
965
+ else:
966
+ builder = plc.io.parquet.ParquetWriterOptions.builder(
967
+ target, df.table
968
+ ).metadata(metadata)
969
+ builder = cls._apply_parquet_writer_options(builder, options)
970
+ writer_options = builder.build()
971
+ plc.io.parquet.write_parquet(writer_options)
972
+
973
+ @classmethod
974
+ @nvtx_annotate_cudf_polars(message="Sink")
975
+ def do_evaluate(
976
+ cls,
977
+ schema: Schema,
978
+ kind: str,
979
+ path: str,
980
+ parquet_options: ParquetOptions,
981
+ options: dict[str, Any],
982
+ df: DataFrame,
983
+ ) -> DataFrame:
984
+ """Write the dataframe to a file."""
985
+ target = plc.io.SinkInfo([path])
986
+
987
+ if options.get("mkdir", False):
988
+ Path(path).parent.mkdir(parents=True, exist_ok=True)
989
+ if kind == "Csv":
990
+ cls._write_csv(target, options, df)
991
+ elif kind == "Parquet":
992
+ cls._write_parquet(target, parquet_options, options, df)
993
+ elif kind == "Json":
994
+ cls._write_json(target, df)
995
+
996
+ return DataFrame([])
997
+
998
+
999
+ class Cache(IR):
1000
+ """
1001
+ Return a cached plan node.
1002
+
1003
+ Used for CSE at the plan level.
1004
+ """
1005
+
1006
+ __slots__ = ("key", "refcount")
1007
+ _non_child = ("schema", "key", "refcount")
1008
+ key: int
1009
+ """The cache key."""
1010
+ refcount: int | None
1011
+ """The number of cache hits."""
1012
+
1013
+ def __init__(self, schema: Schema, key: int, refcount: int | None, value: IR):
1014
+ self.schema = schema
1015
+ self.key = key
1016
+ self.refcount = refcount
1017
+ self.children = (value,)
1018
+ self._non_child_args = (key, refcount)
1019
+
1020
+ def get_hashable(self) -> Hashable: # noqa: D102
1021
+ # Polars arranges that the keys are unique across all cache
1022
+ # nodes that reference the same child, so we don't need to
1023
+ # hash the child.
1024
+ return (type(self), self.key, self.refcount)
1025
+
1026
+ def is_equal(self, other: Self) -> bool: # noqa: D102
1027
+ if self.key == other.key and self.refcount == other.refcount:
1028
+ self.children = other.children
1029
+ return True
1030
+ return False
1031
+
1032
+ @classmethod
1033
+ @nvtx_annotate_cudf_polars(message="Cache")
1034
+ def do_evaluate(
1035
+ cls, key: int, refcount: int | None, df: DataFrame
1036
+ ) -> DataFrame: # pragma: no cover; basic evaluation never calls this
1037
+ """Evaluate and return a dataframe."""
1038
+ # Our value has already been computed for us, so let's just
1039
+ # return it.
1040
+ return df
1041
+
1042
+ def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
1043
+ """Evaluate and return a dataframe."""
1044
+ # We must override the recursion scheme because we don't want
1045
+ # to recurse if we're in the cache.
1046
+ try:
1047
+ (result, hits) = cache[self.key]
1048
+ except KeyError:
1049
+ (value,) = self.children
1050
+ result = value.evaluate(cache=cache, timer=timer)
1051
+ cache[self.key] = (result, 0)
1052
+ return result
1053
+ else:
1054
+ if self.refcount is None:
1055
+ return result
1056
+
1057
+ hits += 1 # pragma: no cover
1058
+ if hits == self.refcount: # pragma: no cover
1059
+ del cache[self.key]
1060
+ else: # pragma: no cover
1061
+ cache[self.key] = (result, hits)
1062
+ return result # pragma: no cover
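+ # Sketch of the lifecycle, assuming refcount == 2: the first evaluation
+ # stores (result, 0) in the cache; the first hit bumps the count to 1 and
+ # keeps the entry; the second hit reaches the refcount and evicts the
+ # entry so the result is not kept alive longer than needed. With
+ # refcount None the entry is never evicted here.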
1063
+
1064
+
1065
+ class DataFrameScan(IR):
1066
+ """
1067
+ Input from an existing polars DataFrame.
1068
+
1069
+ This typically arises from ``q.collect().lazy()``
1070
+ """
1071
+
1072
+ __slots__ = ("_id_for_hash", "df", "projection")
1073
+ _non_child = ("schema", "df", "projection")
1074
+ df: Any
1075
+ """Polars internal PyDataFrame object."""
1076
+ projection: tuple[str, ...] | None
1077
+ """List of columns to project out."""
1078
+
1079
+ def __init__(
1080
+ self,
1081
+ schema: Schema,
1082
+ df: Any,
1083
+ projection: Sequence[str] | None,
1084
+ ):
1085
+ self.schema = schema
1086
+ self.df = df
1087
+ self.projection = tuple(projection) if projection is not None else None
1088
+ self._non_child_args = (
1089
+ schema,
1090
+ pl.DataFrame._from_pydf(df),
1091
+ self.projection,
1092
+ )
1093
+ self.children = ()
1094
+ self._id_for_hash = random.randint(0, 2**64 - 1)
1095
+
1096
+ def get_hashable(self) -> Hashable:
1097
+ """
1098
+ Hashable representation of the node.
1099
+
1100
+ The (heavy) dataframe object is not hashed. No two instances of
1101
+ ``DataFrameScan`` will have the same hash, even if they have the
1102
+ same schema, projection, config options, and data.
1103
+ """
1104
+ schema_hash = tuple(self.schema.items())
1105
+ return (
1106
+ type(self),
1107
+ schema_hash,
1108
+ self._id_for_hash,
1109
+ self.projection,
1110
+ )
1111
+
1112
+ @classmethod
1113
+ @nvtx_annotate_cudf_polars(message="DataFrameScan")
1114
+ def do_evaluate(
1115
+ cls,
1116
+ schema: Schema,
1117
+ df: Any,
1118
+ projection: tuple[str, ...] | None,
1119
+ ) -> DataFrame:
1120
+ """Evaluate and return a dataframe."""
1121
+ if projection is not None:
1122
+ df = df.select(projection)
1123
+ df = DataFrame.from_polars(df)
1124
+ assert all(
1125
+ c.obj.type() == dtype.plc
1126
+ for c, dtype in zip(df.columns, schema.values(), strict=True)
1127
+ )
1128
+ return df
1129
+
1130
+
1131
+ class Select(IR):
1132
+ """Produce a new dataframe selecting given expressions from an input."""
1133
+
1134
+ __slots__ = ("exprs", "should_broadcast")
1135
+ _non_child = ("schema", "exprs", "should_broadcast")
1136
+ exprs: tuple[expr.NamedExpr, ...]
1137
+ """List of expressions to evaluate to form the new dataframe."""
1138
+ should_broadcast: bool
1139
+ """Should columns be broadcast?"""
1140
+
1141
+ def __init__(
1142
+ self,
1143
+ schema: Schema,
1144
+ exprs: Sequence[expr.NamedExpr],
1145
+ should_broadcast: bool, # noqa: FBT001
1146
+ df: IR,
1147
+ ):
1148
+ self.schema = schema
1149
+ self.exprs = tuple(exprs)
1150
+ self.should_broadcast = should_broadcast
1151
+ self.children = (df,)
1152
+ self._non_child_args = (self.exprs, should_broadcast)
1153
+ if (
1154
+ Select._is_len_expr(self.exprs)
1155
+ and isinstance(df, Scan)
1156
+ and df.typ != "parquet"
1157
+ ): # pragma: no cover
1158
+ raise NotImplementedError(f"Unsupported scan type: {df.typ}")
1159
+
1160
+ @staticmethod
1161
+ def _is_len_expr(exprs: tuple[expr.NamedExpr, ...]) -> bool: # pragma: no cover
1162
+ if len(exprs) == 1:
1163
+ expr0 = exprs[0].value
1164
+ return (
1165
+ isinstance(expr0, expr.Cast)
1166
+ and len(expr0.children) == 1
1167
+ and isinstance(expr0.children[0], expr.Len)
1168
+ )
1169
+ return False
1170
+
1171
+ @classmethod
1172
+ @nvtx_annotate_cudf_polars(message="Select")
1173
+ def do_evaluate(
1174
+ cls,
1175
+ exprs: tuple[expr.NamedExpr, ...],
1176
+ should_broadcast: bool, # noqa: FBT001
1177
+ df: DataFrame,
1178
+ ) -> DataFrame:
1179
+ """Evaluate and return a dataframe."""
1180
+ # Handle any broadcasting
1181
+ columns = [e.evaluate(df) for e in exprs]
1182
+ if should_broadcast:
1183
+ columns = broadcast(*columns)
1184
+ return DataFrame(columns)
1185
+
1186
+ def evaluate(self, *, cache: CSECache, timer: Timer | None) -> DataFrame:
1187
+ """
1188
+ Evaluate the Select node with special handling for fast count queries.
1189
+
1190
+ Parameters
1191
+ ----------
1192
+ cache
1193
+ Mapping from cached node ids to constructed DataFrames.
1194
+ Used to implement evaluation of the `Cache` node.
1195
+ timer
1196
+ If not None, a Timer object to record timings for the
1197
+ evaluation of the node.
1198
+
1199
+ Returns
1200
+ -------
1201
+ DataFrame
1202
+ Result of evaluating this Select node. If the expression is a
1203
+ count over a parquet scan, returns a constant row count directly
1204
+ without evaluating the scan.
1205
+
1206
+ Raises
1207
+ ------
1208
+ NotImplementedError
1209
+ If evaluation fails. Ideally this should not occur, since the
1210
+ translation phase should fail earlier.
1211
+ """
1212
+ if (
1213
+ isinstance(self.children[0], Scan)
1214
+ and Select._is_len_expr(self.exprs)
1215
+ and self.children[0].typ == "parquet"
1216
+ and self.children[0].predicate is None
1217
+ ):
1218
+ scan = self.children[0] # pragma: no cover
1219
+ effective_rows = scan.fast_count() # pragma: no cover
1220
+ dtype = DataType(pl.UInt32()) # pragma: no cover
1221
+ col = Column(
1222
+ plc.Column.from_scalar(
1223
+ plc.Scalar.from_py(effective_rows, dtype.plc),
1224
+ 1,
1225
+ ),
1226
+ name=self.exprs[0].name or "len",
1227
+ dtype=dtype,
1228
+ ) # pragma: no cover
1229
+ return DataFrame([col]) # pragma: no cover
1230
+
1231
+ return super().evaluate(cache=cache, timer=timer)
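+ # For example, a count-only query such as scan_parquet(...).select(pl.len())
+ # can take the branch above: the row count comes from the parquet metadata
+ # via Scan.fast_count() and is materialised as a one-row UInt32 column,
+ # without reading any column data from the files.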
1232
+
1233
+
1234
+ class Reduce(IR):
1235
+ """
1236
+ Produce a new dataframe selecting given expressions from an input.
1237
+
1238
+ This is a special case of :class:`Select` where all outputs are a single row.
1239
+ """
1240
+
1241
+ __slots__ = ("exprs",)
1242
+ _non_child = ("schema", "exprs")
1243
+ exprs: tuple[expr.NamedExpr, ...]
1244
+ """List of expressions to evaluate to form the new dataframe."""
1245
+
1246
+ def __init__(
1247
+ self, schema: Schema, exprs: Sequence[expr.NamedExpr], df: IR
1248
+ ): # pragma: no cover; polars doesn't emit this node yet
1249
+ self.schema = schema
1250
+ self.exprs = tuple(exprs)
1251
+ self.children = (df,)
1252
+ self._non_child_args = (self.exprs,)
1253
+
1254
+ @classmethod
1255
+ @nvtx_annotate_cudf_polars(message="Reduce")
1256
+ def do_evaluate(
1257
+ cls,
1258
+ exprs: tuple[expr.NamedExpr, ...],
1259
+ df: DataFrame,
1260
+ ) -> DataFrame: # pragma: no cover; not exposed by polars yet
1261
+ """Evaluate and return a dataframe."""
1262
+ columns = broadcast(*(e.evaluate(df) for e in exprs))
1263
+ assert all(column.size == 1 for column in columns)
1264
+ return DataFrame(columns)
1265
+
1266
+
1267
+ class Rolling(IR):
1268
+ """Perform a (possibly grouped) rolling aggregation."""
1269
+
1270
+ __slots__ = (
1271
+ "agg_requests",
1272
+ "closed_window",
1273
+ "following",
1274
+ "index",
1275
+ "keys",
1276
+ "preceding",
1277
+ "zlice",
1278
+ )
1279
+ _non_child = (
1280
+ "schema",
1281
+ "index",
1282
+ "preceding",
1283
+ "following",
1284
+ "closed_window",
1285
+ "keys",
1286
+ "agg_requests",
1287
+ "zlice",
1288
+ )
1289
+ index: expr.NamedExpr
1290
+ """Column being rolled over."""
1291
+ preceding: plc.Scalar
1292
+ """Preceding window extent defining start of window."""
1293
+ following: plc.Scalar
1294
+ """Following window extent defining end of window."""
1295
+ closed_window: ClosedInterval
1296
+ """Treatment of window endpoints."""
1297
+ keys: tuple[expr.NamedExpr, ...]
1298
+ """Grouping keys."""
1299
+ agg_requests: tuple[expr.NamedExpr, ...]
1300
+ """Aggregation expressions."""
1301
+ zlice: Zlice | None
1302
+ """Optional slice"""
1303
+
1304
+ def __init__(
1305
+ self,
1306
+ schema: Schema,
1307
+ index: expr.NamedExpr,
1308
+ preceding: plc.Scalar,
1309
+ following: plc.Scalar,
1310
+ closed_window: ClosedInterval,
1311
+ keys: Sequence[expr.NamedExpr],
1312
+ agg_requests: Sequence[expr.NamedExpr],
1313
+ zlice: Zlice | None,
1314
+ df: IR,
1315
+ ):
1316
+ self.schema = schema
1317
+ self.index = index
1318
+ self.preceding = preceding
1319
+ self.following = following
1320
+ self.closed_window = closed_window
1321
+ self.keys = tuple(keys)
1322
+ self.agg_requests = tuple(agg_requests)
1323
+ if not all(
1324
+ plc.rolling.is_valid_rolling_aggregation(
1325
+ agg.value.dtype.plc, agg.value.agg_request
1326
+ )
1327
+ for agg in self.agg_requests
1328
+ ):
1329
+ raise NotImplementedError("Unsupported rolling aggregation")
1330
+ if any(
1331
+ agg.value.agg_request.kind() == plc.aggregation.Kind.COLLECT_LIST
1332
+ for agg in self.agg_requests
1333
+ ):
1334
+ raise NotImplementedError(
1335
+ "Incorrect handling of empty groups for list collection"
1336
+ )
1337
+
1338
+ self.zlice = zlice
1339
+ self.children = (df,)
1340
+ self._non_child_args = (
1341
+ index,
1342
+ preceding,
1343
+ following,
1344
+ closed_window,
1345
+ keys,
1346
+ agg_requests,
1347
+ zlice,
1348
+ )
1349
+
1350
+ @classmethod
1351
+ @nvtx_annotate_cudf_polars(message="Rolling")
1352
+ def do_evaluate(
1353
+ cls,
1354
+ index: expr.NamedExpr,
1355
+ preceding: plc.Scalar,
1356
+ following: plc.Scalar,
1357
+ closed_window: ClosedInterval,
1358
+ keys_in: Sequence[expr.NamedExpr],
1359
+ aggs: Sequence[expr.NamedExpr],
1360
+ zlice: Zlice | None,
1361
+ df: DataFrame,
1362
+ ) -> DataFrame:
1363
+ """Evaluate and return a dataframe."""
1364
+ keys = broadcast(*(k.evaluate(df) for k in keys_in), target_length=df.num_rows)
1365
+ orderby = index.evaluate(df)
1366
+ # Polars casts integral orderby to int64, but only for calculating window bounds
1367
+ if (
1368
+ plc.traits.is_integral(orderby.obj.type())
1369
+ and orderby.obj.type().id() != plc.TypeId.INT64
1370
+ ):
1371
+ orderby_obj = plc.unary.cast(orderby.obj, plc.DataType(plc.TypeId.INT64))
1372
+ else:
1373
+ orderby_obj = orderby.obj
1374
+ preceding_window, following_window = range_window_bounds(
1375
+ preceding, following, closed_window
1376
+ )
1377
+ if orderby.obj.null_count() != 0:
1378
+ raise RuntimeError(
1379
+ f"Index column '{index.name}' in rolling may not contain nulls"
1380
+ )
1381
+ if len(keys_in) > 0:
1382
+ # Must always check sortedness
1383
+ table = plc.Table([*(k.obj for k in keys), orderby_obj])
1384
+ n = table.num_columns()
1385
+ if not plc.sorting.is_sorted(
1386
+ table, [plc.types.Order.ASCENDING] * n, [plc.types.NullOrder.BEFORE] * n
1387
+ ):
1388
+ raise RuntimeError("Input for grouped rolling is not sorted")
1389
+ else:
1390
+ if not orderby.check_sorted(
1391
+ order=plc.types.Order.ASCENDING, null_order=plc.types.NullOrder.BEFORE
1392
+ ):
1393
+ raise RuntimeError(
1394
+ f"Index column '{index.name}' in rolling is not sorted, please sort first"
1395
+ )
1396
+ values = plc.rolling.grouped_range_rolling_window(
1397
+ plc.Table([k.obj for k in keys]),
1398
+ orderby_obj,
1399
+ plc.types.Order.ASCENDING, # Polars requires ascending orderby.
1400
+ plc.types.NullOrder.BEFORE, # Doesn't matter, polars doesn't allow nulls in orderby
1401
+ preceding_window,
1402
+ following_window,
1403
+ [rolling.to_request(request.value, orderby, df) for request in aggs],
1404
+ )
1405
+ return DataFrame(
1406
+ itertools.chain(
1407
+ keys,
1408
+ [orderby],
1409
+ (
1410
+ Column(col, name=request.name, dtype=request.value.dtype)
1411
+ for col, request in zip(values.columns(), aggs, strict=True)
1412
+ ),
1413
+ )
1414
+ ).slice(zlice)
1415
+
1416
+
1417
+ class GroupBy(IR):
1418
+ """Perform a groupby."""
1419
+
1420
+ __slots__ = (
1421
+ "agg_requests",
1422
+ "keys",
1423
+ "maintain_order",
1424
+ "zlice",
1425
+ )
1426
+ _non_child = (
1427
+ "schema",
1428
+ "keys",
1429
+ "agg_requests",
1430
+ "maintain_order",
1431
+ "zlice",
1432
+ )
1433
+ keys: tuple[expr.NamedExpr, ...]
1434
+ """Grouping keys."""
1435
+ agg_requests: tuple[expr.NamedExpr, ...]
1436
+ """Aggregation expressions."""
1437
+ maintain_order: bool
1438
+ """Preserve order in groupby."""
1439
+ zlice: Zlice | None
1440
+ """Optional slice to apply after grouping."""
1441
+
1442
+ def __init__(
1443
+ self,
1444
+ schema: Schema,
1445
+ keys: Sequence[expr.NamedExpr],
1446
+ agg_requests: Sequence[expr.NamedExpr],
1447
+ maintain_order: bool, # noqa: FBT001
1448
+ zlice: Zlice | None,
1449
+ df: IR,
1450
+ ):
1451
+ self.schema = schema
1452
+ self.keys = tuple(keys)
1453
+ for request in agg_requests:
1454
+ expr = request.value
1455
+ if isinstance(expr, unary.UnaryFunction) and expr.name == "value_counts":
1456
+ raise NotImplementedError("value_counts is not supported in groupby")
1457
+ if any(
1458
+ isinstance(child, unary.UnaryFunction) and child.name == "value_counts"
1459
+ for child in expr.children
1460
+ ):
1461
+ raise NotImplementedError("value_counts is not supported in groupby")
1462
+ self.agg_requests = tuple(agg_requests)
1463
+ self.maintain_order = maintain_order
1464
+ self.zlice = zlice
1465
+ self.children = (df,)
1466
+ self._non_child_args = (
1467
+ schema,
1468
+ self.keys,
1469
+ self.agg_requests,
1470
+ maintain_order,
1471
+ self.zlice,
1472
+ )
1473
+
1474
+ @classmethod
1475
+ @nvtx_annotate_cudf_polars(message="GroupBy")
1476
+ def do_evaluate(
1477
+ cls,
1478
+ schema: Schema,
1479
+ keys_in: Sequence[expr.NamedExpr],
1480
+ agg_requests: Sequence[expr.NamedExpr],
1481
+ maintain_order: bool, # noqa: FBT001
1482
+ zlice: Zlice | None,
1483
+ df: DataFrame,
1484
+ ) -> DataFrame:
1485
+ """Evaluate and return a dataframe."""
1486
+ keys = broadcast(*(k.evaluate(df) for k in keys_in), target_length=df.num_rows)
1487
+ sorted = (
1488
+ plc.types.Sorted.YES
1489
+ if all(k.is_sorted for k in keys)
1490
+ else plc.types.Sorted.NO
1491
+ )
1492
+ grouper = plc.groupby.GroupBy(
1493
+ plc.Table([k.obj for k in keys]),
1494
+ null_handling=plc.types.NullPolicy.INCLUDE,
1495
+ keys_are_sorted=sorted,
1496
+ column_order=[k.order for k in keys],
1497
+ null_precedence=[k.null_order for k in keys],
1498
+ )
1499
+ requests = []
1500
+ names = []
1501
+ for request in agg_requests:
1502
+ name = request.name
1503
+ value = request.value
1504
+ if isinstance(value, expr.Len):
1505
+ # A count aggregation, we need a column so use a key column
1506
+ col = keys[0].obj
1507
+ elif isinstance(value, expr.Agg):
1508
+ if value.name == "quantile":
1509
+ child = value.children[0]
1510
+ else:
1511
+ (child,) = value.children
1512
+ col = child.evaluate(df, context=ExecutionContext.GROUPBY).obj
1513
+ else:
1514
+ # Anything else, we pre-evaluate
1515
+ col = value.evaluate(df, context=ExecutionContext.GROUPBY).obj
1516
+ requests.append(plc.groupby.GroupByRequest(col, [value.agg_request]))
1517
+ names.append(name)
1518
+ group_keys, raw_tables = grouper.aggregate(requests)
1519
+ results = [
1520
+ Column(column, name=name, dtype=schema[name])
1521
+ for name, column, request in zip(
1522
+ names,
1523
+ itertools.chain.from_iterable(t.columns() for t in raw_tables),
1524
+ agg_requests,
1525
+ strict=True,
1526
+ )
1527
+ ]
1528
+ result_keys = [
1529
+ Column(grouped_key, name=key.name, dtype=key.dtype)
1530
+ for key, grouped_key in zip(keys, group_keys.columns(), strict=True)
1531
+ ]
1532
+ broadcasted = broadcast(*result_keys, *results)
1533
+ # Handle order preservation of groups
1534
+ if maintain_order and not sorted:
1535
+ # The order we want
1536
+ want = plc.stream_compaction.stable_distinct(
1537
+ plc.Table([k.obj for k in keys]),
1538
+ list(range(group_keys.num_columns())),
1539
+ plc.stream_compaction.DuplicateKeepOption.KEEP_FIRST,
1540
+ plc.types.NullEquality.EQUAL,
1541
+ plc.types.NanEquality.ALL_EQUAL,
1542
+ )
1543
+ # The order we have
1544
+ have = plc.Table([key.obj for key in broadcasted[: len(keys)]])
1545
+
1546
+ # We know an inner join is OK because by construction
1547
+ # want and have are permutations of each other.
1548
+ left_order, right_order = plc.join.inner_join(
1549
+ want, have, plc.types.NullEquality.EQUAL
1550
+ )
1551
+ # Now left_order is an arbitrary permutation of the ordering we
1552
+ # want, and right_order is a matching permutation of the ordering
1553
+ # we have. To get to the original ordering, we need
1554
+ # left_order == iota(nrows), with right_order permuted
1555
+ # appropriately. This can be obtained by sorting
1556
+ # right_order by left_order.
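+ # Worked example: keys evaluate to ["b", "a", "b", "c"], so `want`
+ # (first-appearance order) is ["b", "a", "c"] while the grouper may
+ # return `have` as, say, ["a", "c", "b"]. The inner join pairs matching
+ # rows, e.g. left_order = [1, 2, 0] with right_order = [0, 1, 2];
+ # sorting right_order by left_order gives the gather map [2, 0, 1],
+ # and gathering `have` (and the aggregated columns) with it restores
+ # the ["b", "a", "c"] ordering.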
1557
+ (right_order,) = plc.sorting.sort_by_key(
1558
+ plc.Table([right_order]),
1559
+ plc.Table([left_order]),
1560
+ [plc.types.Order.ASCENDING],
1561
+ [plc.types.NullOrder.AFTER],
1562
+ ).columns()
1563
+ ordered_table = plc.copying.gather(
1564
+ plc.Table([col.obj for col in broadcasted]),
1565
+ right_order,
1566
+ plc.copying.OutOfBoundsPolicy.DONT_CHECK,
1567
+ )
1568
+ broadcasted = [
1569
+ Column(reordered, name=old.name, dtype=old.dtype)
1570
+ for reordered, old in zip(
1571
+ ordered_table.columns(), broadcasted, strict=True
1572
+ )
1573
+ ]
1574
+ return DataFrame(broadcasted).slice(zlice)
1575
+
1576
+
1577
+ class ConditionalJoin(IR):
1578
+ """A conditional inner join of two dataframes on a predicate."""
1579
+
1580
+ class Predicate:
1581
+ """Serializable wrapper for a predicate expression."""
1582
+
1583
+ predicate: expr.Expr
1584
+ ast: plc.expressions.Expression
1585
+
1586
+ def __init__(self, predicate: expr.Expr):
1587
+ self.predicate = predicate
1588
+ self.ast = to_ast(predicate)
1589
+
1590
+ def __reduce__(self) -> tuple[Any, ...]:
1591
+ """Pickle a Predicate object."""
1592
+ return (type(self), (self.predicate,))
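+ # The ast field (a pylibcudf Expression) is rebuilt rather than pickled.
+ # A hedged sketch of the round trip, assuming the wrapped expression is
+ # itself picklable:
+ #   wrapper = ConditionalJoin.Predicate(predicate)
+ #   clone = pickle.loads(pickle.dumps(wrapper))  # re-runs __init__
+ #   # clone.ast is recomputed via to_ast(predicate) on the receiving side,
+ #   # so the non-picklable expression object never crosses the wire.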
1593
+
1594
+ __slots__ = ("ast_predicate", "options", "predicate")
1595
+ _non_child = ("schema", "predicate", "options")
1596
+ predicate: expr.Expr
1597
+ """Expression predicate to join on"""
1598
+ options: tuple[
1599
+ tuple[
1600
+ str,
1601
+ pl_expr.Operator | Iterable[pl_expr.Operator],
1602
+ ],
1603
+ bool,
1604
+ Zlice | None,
1605
+ str,
1606
+ bool,
1607
+ Literal["none", "left", "right", "left_right", "right_left"],
1608
+ ]
1609
+ """
1610
+ tuple of options:
1611
+ - predicates: tuple of IR join type (e.g. ie_join) and (in)equality conditions
1612
+ - nulls_equal: do nulls compare equal?
1613
+ - slice: optional slice to perform after joining.
1614
+ - suffix: string suffix for right columns if names match
1615
+ - coalesce: should key columns be coalesced (only makes sense for outer joins)
1616
+ - maintain_order: which DataFrame row order to preserve, if any
1617
+ """
1618
+
1619
+ def __init__(
1620
+ self, schema: Schema, predicate: expr.Expr, options: tuple, left: IR, right: IR
1621
+ ) -> None:
1622
+ self.schema = schema
1623
+ self.predicate = predicate
1624
+ self.options = options
1625
+ self.children = (left, right)
1626
+ predicate_wrapper = self.Predicate(predicate)
1627
+ _, nulls_equal, zlice, suffix, coalesce, maintain_order = self.options
1628
+ # Preconditions from polars
1629
+ assert not nulls_equal
1630
+ assert not coalesce
1631
+ assert maintain_order == "none"
1632
+ if predicate_wrapper.ast is None:
1633
+ raise NotImplementedError(
1634
+ f"Conditional join with predicate {predicate}"
1635
+ ) # pragma: no cover; polars never delivers expressions we can't handle
1636
+ self._non_child_args = (predicate_wrapper, zlice, suffix, maintain_order)
1637
+
1638
+ @classmethod
1639
+ @nvtx_annotate_cudf_polars(message="ConditionalJoin")
1640
+ def do_evaluate(
1641
+ cls,
1642
+ predicate_wrapper: Predicate,
1643
+ zlice: Zlice | None,
1644
+ suffix: str,
1645
+ maintain_order: Literal["none", "left", "right", "left_right", "right_left"],
1646
+ left: DataFrame,
1647
+ right: DataFrame,
1648
+ ) -> DataFrame:
1649
+ """Evaluate and return a dataframe."""
1650
+ lg, rg = plc.join.conditional_inner_join(
1651
+ left.table,
1652
+ right.table,
1653
+ predicate_wrapper.ast,
1654
+ )
1655
+ left = DataFrame.from_table(
1656
+ plc.copying.gather(
1657
+ left.table, lg, plc.copying.OutOfBoundsPolicy.DONT_CHECK
1658
+ ),
1659
+ left.column_names,
1660
+ left.dtypes,
1661
+ )
1662
+ right = DataFrame.from_table(
1663
+ plc.copying.gather(
1664
+ right.table, rg, plc.copying.OutOfBoundsPolicy.DONT_CHECK
1665
+ ),
1666
+ right.column_names,
1667
+ right.dtypes,
1668
+ )
1669
+ right = right.rename_columns(
1670
+ {
1671
+ name: f"{name}{suffix}"
1672
+ for name in right.column_names
1673
+ if name in left.column_names_set
1674
+ }
1675
+ )
1676
+ result = left.with_columns(right.columns)
1677
+ return result.slice(zlice)
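+ # Rough picture of the evaluation above (hypothetical predicate a < b):
+ # conditional_inner_join returns two equal-length gather maps, one row
+ # index per matching (left, right) pair. Both inputs are gathered with
+ # those maps, clashing right-hand column names get the suffix, and the
+ # two halves are glued together column-wise before the optional slice.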
1678
+
1679
+
1680
+ class Join(IR):
1681
+ """A join of two dataframes."""
1682
+
1683
+ __slots__ = ("left_on", "options", "right_on")
1684
+ _non_child = ("schema", "left_on", "right_on", "options")
1685
+ left_on: tuple[expr.NamedExpr, ...]
1686
+ """List of expressions used as keys in the left frame."""
1687
+ right_on: tuple[expr.NamedExpr, ...]
1688
+ """List of expressions used as keys in the right frame."""
1689
+ options: tuple[
1690
+ Literal["Inner", "Left", "Right", "Full", "Semi", "Anti", "Cross"],
1691
+ bool,
1692
+ Zlice | None,
1693
+ str,
1694
+ bool,
1695
+ Literal["none", "left", "right", "left_right", "right_left"],
1696
+ ]
1697
+ """
1698
+ tuple of options:
1699
+ - how: join type
1700
+ - nulls_equal: do nulls compare equal?
1701
+ - slice: optional slice to perform after joining.
1702
+ - suffix: string suffix for right columns if names match
1703
+ - coalesce: should key columns be coalesced (only makes sense for outer joins)
1704
+ - maintain_order: which DataFrame row order to preserve, if any
1705
+ """
1706
+
1707
+ def __init__(
1708
+ self,
1709
+ schema: Schema,
1710
+ left_on: Sequence[expr.NamedExpr],
1711
+ right_on: Sequence[expr.NamedExpr],
1712
+ options: Any,
1713
+ left: IR,
1714
+ right: IR,
1715
+ ):
1716
+ self.schema = schema
1717
+ self.left_on = tuple(left_on)
1718
+ self.right_on = tuple(right_on)
1719
+ self.options = options
1720
+ self.children = (left, right)
1721
+ self._non_child_args = (self.left_on, self.right_on, self.options)
1722
+ # TODO: Implement maintain_order
1723
+ if options[5] != "none":
1724
+ raise NotImplementedError("maintain_order not implemented yet")
1725
+
1726
+ @staticmethod
1727
+ @cache
1728
+ def _joiners(
1729
+ how: Literal["Inner", "Left", "Right", "Full", "Semi", "Anti"],
1730
+ ) -> tuple[
1731
+ Callable, plc.copying.OutOfBoundsPolicy, plc.copying.OutOfBoundsPolicy | None
1732
+ ]:
1733
+ if how == "Inner":
1734
+ return (
1735
+ plc.join.inner_join,
1736
+ plc.copying.OutOfBoundsPolicy.DONT_CHECK,
1737
+ plc.copying.OutOfBoundsPolicy.DONT_CHECK,
1738
+ )
1739
+ elif how == "Left" or how == "Right":
1740
+ return (
1741
+ plc.join.left_join,
1742
+ plc.copying.OutOfBoundsPolicy.DONT_CHECK,
1743
+ plc.copying.OutOfBoundsPolicy.NULLIFY,
1744
+ )
1745
+ elif how == "Full":
1746
+ return (
1747
+ plc.join.full_join,
1748
+ plc.copying.OutOfBoundsPolicy.NULLIFY,
1749
+ plc.copying.OutOfBoundsPolicy.NULLIFY,
1750
+ )
1751
+ elif how == "Semi":
1752
+ return (
1753
+ plc.join.left_semi_join,
1754
+ plc.copying.OutOfBoundsPolicy.DONT_CHECK,
1755
+ None,
1756
+ )
1757
+ elif how == "Anti":
1758
+ return (
1759
+ plc.join.left_anti_join,
1760
+ plc.copying.OutOfBoundsPolicy.DONT_CHECK,
1761
+ None,
1762
+ )
1763
+ assert_never(how) # pragma: no cover
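+ # For example, _joiners("Full") -> (plc.join.full_join, NULLIFY, NULLIFY):
+ # unmatched rows on either side show up as out-of-range entries in the
+ # gather maps and are turned into nulls when gathering. Semi/Anti return a
+ # single map (right policy None), and "Right" reuses the left join with
+ # the inputs swapped in do_evaluate below.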
1764
+
1765
+ @staticmethod
1766
+ def _reorder_maps(
1767
+ left_rows: int,
1768
+ lg: plc.Column,
1769
+ left_policy: plc.copying.OutOfBoundsPolicy,
1770
+ right_rows: int,
1771
+ rg: plc.Column,
1772
+ right_policy: plc.copying.OutOfBoundsPolicy,
1773
+ ) -> list[plc.Column]:
1774
+ """
1775
+ Reorder gather maps to satisfy polars join order restrictions.
1776
+
1777
+ Parameters
1778
+ ----------
1779
+ left_rows
1780
+ Number of rows in left table
1781
+ lg
1782
+ Left gather map
1783
+ left_policy
1784
+ Nullify policy for left map
1785
+ right_rows
1786
+ Number of rows in right table
1787
+ rg
1788
+ Right gather map
1789
+ right_policy
1790
+ Nullify policy for right map
1791
+
1792
+ Returns
1793
+ -------
1794
+ list of reordered left and right gather maps.
1795
+
1796
+ Notes
1797
+ -----
1798
+ For a left join, the polars result preserves the order of the
1799
+ left keys, and is stable wrt the right keys. For all other
1800
+ joins, there is no order obligation.
1801
+ """
1802
+ init = plc.Scalar.from_py(0, plc.types.SIZE_TYPE)
1803
+ step = plc.Scalar.from_py(1, plc.types.SIZE_TYPE)
1804
+ left_order = plc.copying.gather(
1805
+ plc.Table([plc.filling.sequence(left_rows, init, step)]), lg, left_policy
1806
+ )
1807
+ right_order = plc.copying.gather(
1808
+ plc.Table([plc.filling.sequence(right_rows, init, step)]), rg, right_policy
1809
+ )
1810
+ return plc.sorting.stable_sort_by_key(
1811
+ plc.Table([lg, rg]),
1812
+ plc.Table([*left_order.columns(), *right_order.columns()]),
1813
+ [plc.types.Order.ASCENDING, plc.types.Order.ASCENDING],
1814
+ [plc.types.NullOrder.AFTER, plc.types.NullOrder.AFTER],
1815
+ ).columns()
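+ # Tiny worked example of the reordering above (hypothetical left join,
+ # 3 left rows):
+ #   lg = [2, 0, 2], rg = [1, 0, 2]
+ # left_order/right_order are just the gather maps re-expressed as row
+ # numbers (with nulls for unmatched rows under NULLIFY), so the stable
+ # sort by (left row, right row) returns
+ #   lg = [0, 2, 2], rg = [0, 1, 2]
+ # i.e. left-table order is restored, ties keep the right-hand order, and
+ # null right rows sort after their matched neighbours.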
1816
+
1817
+ @staticmethod
1818
+ def _build_columns(
1819
+ columns: Iterable[plc.Column],
1820
+ template: Iterable[NamedColumn],
1821
+ *,
1822
+ left: bool = True,
1823
+ empty: bool = False,
1824
+ rename: Callable[[str], str] = lambda name: name,
1825
+ ) -> list[Column]:
1826
+ if empty:
1827
+ return [
1828
+ Column(
1829
+ plc.column_factories.make_empty_column(col.dtype.plc),
1830
+ col.dtype,
1831
+ name=rename(col.name),
1832
+ )
1833
+ for col in template
1834
+ ]
1835
+
1836
+ columns = [
1837
+ Column(new, col.dtype, name=rename(col.name))
1838
+ for new, col in zip(columns, template, strict=True)
1839
+ ]
1840
+
1841
+ if left:
1842
+ columns = [
1843
+ col.sorted_like(orig)
1844
+ for col, orig in zip(columns, template, strict=True)
1845
+ ]
1846
+
1847
+ return columns
1848
+
1849
+ @classmethod
1850
+ @nvtx_annotate_cudf_polars(message="Join")
1851
+ def do_evaluate(
1852
+ cls,
1853
+ left_on_exprs: Sequence[expr.NamedExpr],
1854
+ right_on_exprs: Sequence[expr.NamedExpr],
1855
+ options: tuple[
1856
+ Literal["Inner", "Left", "Right", "Full", "Semi", "Anti", "Cross"],
1857
+ bool,
1858
+ Zlice | None,
1859
+ str,
1860
+ bool,
1861
+ Literal["none", "left", "right", "left_right", "right_left"],
1862
+ ],
1863
+ left: DataFrame,
1864
+ right: DataFrame,
1865
+ ) -> DataFrame:
1866
+ """Evaluate and return a dataframe."""
1867
+ how, nulls_equal, zlice, suffix, coalesce, _ = options
1868
+ if how == "Cross":
1869
+ # Separate implementation, since cross_join returns the
1870
+ # result, not the gather maps
1871
+ if right.num_rows == 0:
1872
+ left_cols = Join._build_columns([], left.columns, empty=True)
1873
+ right_cols = Join._build_columns(
1874
+ [],
1875
+ right.columns,
1876
+ left=False,
1877
+ empty=True,
1878
+ rename=lambda name: name
1879
+ if name not in left.column_names_set
1880
+ else f"{name}{suffix}",
1881
+ )
1882
+ return DataFrame([*left_cols, *right_cols])
1883
+
1884
+ columns = plc.join.cross_join(left.table, right.table).columns()
1885
+ left_cols = Join._build_columns(
1886
+ columns[: left.num_columns],
1887
+ left.columns,
1888
+ )
1889
+ right_cols = Join._build_columns(
1890
+ columns[left.num_columns :],
1891
+ right.columns,
1892
+ rename=lambda name: name
1893
+ if name not in left.column_names_set
1894
+ else f"{name}{suffix}",
1895
+ left=False,
1896
+ )
1897
+ return DataFrame([*left_cols, *right_cols]).slice(zlice)
1898
+ # TODO: Waiting on clarity based on https://github.com/pola-rs/polars/issues/17184
1899
+ left_on = DataFrame(broadcast(*(e.evaluate(left) for e in left_on_exprs)))
1900
+ right_on = DataFrame(broadcast(*(e.evaluate(right) for e in right_on_exprs)))
1901
+ null_equality = (
1902
+ plc.types.NullEquality.EQUAL
1903
+ if nulls_equal
1904
+ else plc.types.NullEquality.UNEQUAL
1905
+ )
1906
+ join_fn, left_policy, right_policy = cls._joiners(how)
1907
+ if right_policy is None:
1908
+ # Semi join
1909
+ lg = join_fn(left_on.table, right_on.table, null_equality)
1910
+ table = plc.copying.gather(left.table, lg, left_policy)
1911
+ result = DataFrame.from_table(table, left.column_names, left.dtypes)
1912
+ else:
1913
+ if how == "Right":
1914
+ # Right join is a left join with the tables swapped
1915
+ left, right = right, left
1916
+ left_on, right_on = right_on, left_on
1917
+ lg, rg = join_fn(left_on.table, right_on.table, null_equality)
1918
+ if how == "Left" or how == "Right":
1919
+ # Order of left table is preserved
1920
+ lg, rg = cls._reorder_maps(
1921
+ left.num_rows, lg, left_policy, right.num_rows, rg, right_policy
1922
+ )
1923
+ if coalesce:
1924
+ if how == "Full":
1925
+ # In this case, keys must be column references,
1926
+ # possibly with dtype casting. We should use them in
1927
+ # preference to the columns from the original tables.
1928
+ left = left.with_columns(left_on.columns, replace_only=True)
1929
+ right = right.with_columns(right_on.columns, replace_only=True)
1930
+ else:
1931
+ right = right.discard_columns(right_on.column_names_set)
1932
+ left = DataFrame.from_table(
1933
+ plc.copying.gather(left.table, lg, left_policy),
1934
+ left.column_names,
1935
+ left.dtypes,
1936
+ )
1937
+ right = DataFrame.from_table(
1938
+ plc.copying.gather(right.table, rg, right_policy),
1939
+ right.column_names,
1940
+ right.dtypes,
1941
+ )
1942
+ if coalesce and how == "Full":
1943
+ left = left.with_columns(
1944
+ (
1945
+ Column(
1946
+ plc.replace.replace_nulls(left_col.obj, right_col.obj),
1947
+ name=left_col.name,
1948
+ dtype=left_col.dtype,
1949
+ )
1950
+ for left_col, right_col in zip(
1951
+ left.select_columns(left_on.column_names_set),
1952
+ right.select_columns(right_on.column_names_set),
1953
+ strict=True,
1954
+ )
1955
+ ),
1956
+ replace_only=True,
1957
+ )
1958
+ right = right.discard_columns(right_on.column_names_set)
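+ # Sketch of the coalesce step above (hypothetical key "k"): after a full
+ # join, a left-only row has k_left = value, k_right = null and a
+ # right-only row has k_left = null, k_right = value; replace_nulls merges
+ # these into a single key column, e.g. [1, null, 3] and [null, 2, 3]
+ # become [1, 2, 3], and the now-redundant right-hand key columns are
+ # dropped.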
1959
+ if how == "Right":
1960
+ # Undo the swap for right join before gluing together.
1961
+ left, right = right, left
1962
+ right = right.rename_columns(
1963
+ {
1964
+ name: f"{name}{suffix}"
1965
+ for name in right.column_names
1966
+ if name in left.column_names_set
1967
+ }
1968
+ )
1969
+ result = left.with_columns(right.columns)
1970
+ return result.slice(zlice)
1971
+
1972
+
1973
+ class HStack(IR):
1974
+ """Add new columns to a dataframe."""
1975
+
1976
+ __slots__ = ("columns", "should_broadcast")
1977
+ _non_child = ("schema", "columns", "should_broadcast")
1978
+ should_broadcast: bool
1979
+ """Should the resulting evaluated columns be broadcast to the same length."""
1980
+
1981
+ def __init__(
1982
+ self,
1983
+ schema: Schema,
1984
+ columns: Sequence[expr.NamedExpr],
1985
+ should_broadcast: bool, # noqa: FBT001
1986
+ df: IR,
1987
+ ):
1988
+ self.schema = schema
1989
+ self.columns = tuple(columns)
1990
+ self.should_broadcast = should_broadcast
1991
+ self._non_child_args = (self.columns, self.should_broadcast)
1992
+ self.children = (df,)
1993
+
1994
+ @classmethod
1995
+ @nvtx_annotate_cudf_polars(message="HStack")
1996
+ def do_evaluate(
1997
+ cls,
1998
+ exprs: Sequence[expr.NamedExpr],
1999
+ should_broadcast: bool, # noqa: FBT001
2000
+ df: DataFrame,
2001
+ ) -> DataFrame:
2002
+ """Evaluate and return a dataframe."""
2003
+ columns = [c.evaluate(df) for c in exprs]
2004
+ if should_broadcast:
2005
+ columns = broadcast(
2006
+ *columns, target_length=df.num_rows if df.num_columns != 0 else None
2007
+ )
2008
+ else:
2009
+ # Polars ensures this is true, but let's make sure nothing
2010
+ # went wrong. In this case, the parent node is
2011
+ # guaranteed to be a Select which will take care of making
2012
+ # sure that everything is the same length. The result
2013
+ # table that might have mismatching column lengths will
2014
+ # never be turned into a pylibcudf Table with all columns
2015
+ # by the Select, which is why this is safe.
2016
+ assert all(e.name.startswith("__POLARS_CSER_0x") for e in exprs)
2017
+ return df.with_columns(columns)
2018
+
2019
+
2020
+ class Distinct(IR):
2021
+ """Produce a new dataframe with distinct rows."""
2022
+
2023
+ __slots__ = ("keep", "stable", "subset", "zlice")
2024
+ _non_child = ("schema", "keep", "subset", "zlice", "stable")
2025
+ keep: plc.stream_compaction.DuplicateKeepOption
2026
+ """Which distinct value to keep."""
2027
+ subset: frozenset[str] | None
2028
+ """Which columns should be used to define distinctness. If None,
2029
+ then all columns are used."""
2030
+ zlice: Zlice | None
2031
+ """Optional slice to apply to the result."""
2032
+ stable: bool
2033
+ """Should the result maintain ordering."""
2034
+
2035
+ def __init__(
2036
+ self,
2037
+ schema: Schema,
2038
+ keep: plc.stream_compaction.DuplicateKeepOption,
2039
+ subset: frozenset[str] | None,
2040
+ zlice: Zlice | None,
2041
+ stable: bool, # noqa: FBT001
2042
+ df: IR,
2043
+ ):
2044
+ self.schema = schema
2045
+ self.keep = keep
2046
+ self.subset = subset
2047
+ self.zlice = zlice
2048
+ self.stable = stable
2049
+ self._non_child_args = (keep, subset, zlice, stable)
2050
+ self.children = (df,)
2051
+
2052
+ _KEEP_MAP: ClassVar[dict[str, plc.stream_compaction.DuplicateKeepOption]] = {
2053
+ "first": plc.stream_compaction.DuplicateKeepOption.KEEP_FIRST,
2054
+ "last": plc.stream_compaction.DuplicateKeepOption.KEEP_LAST,
2055
+ "none": plc.stream_compaction.DuplicateKeepOption.KEEP_NONE,
2056
+ "any": plc.stream_compaction.DuplicateKeepOption.KEEP_ANY,
2057
+ }
2058
+
2059
+ @classmethod
2060
+ @nvtx_annotate_cudf_polars(message="Distinct")
2061
+ def do_evaluate(
2062
+ cls,
2063
+ keep: plc.stream_compaction.DuplicateKeepOption,
2064
+ subset: frozenset[str] | None,
2065
+ zlice: Zlice | None,
2066
+ stable: bool, # noqa: FBT001
2067
+ df: DataFrame,
2068
+ ) -> DataFrame:
2069
+ """Evaluate and return a dataframe."""
2070
+ if subset is None:
2071
+ indices = list(range(df.num_columns))
2072
+ keys_sorted = all(c.is_sorted for c in df.column_map.values())
2073
+ else:
2074
+ indices = [i for i, k in enumerate(df.column_names) if k in subset]
2075
+ keys_sorted = all(df.column_map[name].is_sorted for name in subset)
2076
+ if keys_sorted:
2077
+ table = plc.stream_compaction.unique(
2078
+ df.table,
2079
+ indices,
2080
+ keep,
2081
+ plc.types.NullEquality.EQUAL,
2082
+ )
2083
+ else:
2084
+ distinct = (
2085
+ plc.stream_compaction.stable_distinct
2086
+ if stable
2087
+ else plc.stream_compaction.distinct
2088
+ )
2089
+ table = distinct(
2090
+ df.table,
2091
+ indices,
2092
+ keep,
2093
+ plc.types.NullEquality.EQUAL,
2094
+ plc.types.NanEquality.ALL_EQUAL,
2095
+ )
2096
+ # TODO: Is this sortedness setting correct?
2097
+ result = DataFrame(
2098
+ [
2099
+ Column(new, name=old.name, dtype=old.dtype).sorted_like(old)
2100
+ for new, old in zip(table.columns(), df.columns, strict=True)
2101
+ ]
2102
+ )
2103
+ if keys_sorted or stable:
2104
+ result = result.sorted_like(df)
2105
+ return result.slice(zlice)
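+ # In short, the branch above: when every key column is already known to
+ # be sorted, duplicates can only be adjacent, so the simpler `unique`
+ # pass suffices; otherwise a hash-based `distinct` (the stable variant
+ # when input order must be preserved) does the work. A hypothetical
+ # sorted column [1, 1, 2, 2, 3] with keep="first" reduces to [1, 2, 3].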
2106
+
2107
+
2108
+ class Sort(IR):
2109
+ """Sort a dataframe."""
2110
+
2111
+ __slots__ = ("by", "null_order", "order", "stable", "zlice")
2112
+ _non_child = ("schema", "by", "order", "null_order", "stable", "zlice")
2113
+ by: tuple[expr.NamedExpr, ...]
2114
+ """Sort keys."""
2115
+ order: tuple[plc.types.Order, ...]
2116
+ """Sort order for each sort key."""
2117
+ null_order: tuple[plc.types.NullOrder, ...]
2118
+ """Null sorting location for each sort key."""
2119
+ stable: bool
2120
+ """Should the sort be stable?"""
2121
+ zlice: Zlice | None
2122
+ """Optional slice to apply to the result."""
2123
+
2124
+ def __init__(
2125
+ self,
2126
+ schema: Schema,
2127
+ by: Sequence[expr.NamedExpr],
2128
+ order: Sequence[plc.types.Order],
2129
+ null_order: Sequence[plc.types.NullOrder],
2130
+ stable: bool, # noqa: FBT001
2131
+ zlice: Zlice | None,
2132
+ df: IR,
2133
+ ):
2134
+ self.schema = schema
2135
+ self.by = tuple(by)
2136
+ self.order = tuple(order)
2137
+ self.null_order = tuple(null_order)
2138
+ self.stable = stable
2139
+ self.zlice = zlice
2140
+ self._non_child_args = (
2141
+ self.by,
2142
+ self.order,
2143
+ self.null_order,
2144
+ self.stable,
2145
+ self.zlice,
2146
+ )
2147
+ self.children = (df,)
2148
+
2149
+ @classmethod
2150
+ @nvtx_annotate_cudf_polars(message="Sort")
2151
+ def do_evaluate(
2152
+ cls,
2153
+ by: Sequence[expr.NamedExpr],
2154
+ order: Sequence[plc.types.Order],
2155
+ null_order: Sequence[plc.types.NullOrder],
2156
+ stable: bool, # noqa: FBT001
2157
+ zlice: Zlice | None,
2158
+ df: DataFrame,
2159
+ ) -> DataFrame:
2160
+ """Evaluate and return a dataframe."""
2161
+ sort_keys = broadcast(*(k.evaluate(df) for k in by), target_length=df.num_rows)
2162
+ do_sort = plc.sorting.stable_sort_by_key if stable else plc.sorting.sort_by_key
2163
+ table = do_sort(
2164
+ df.table,
2165
+ plc.Table([k.obj for k in sort_keys]),
2166
+ list(order),
2167
+ list(null_order),
2168
+ )
2169
+ result = DataFrame.from_table(table, df.column_names, df.dtypes)
2170
+ first_key = sort_keys[0]
2171
+ name = by[0].name
2172
+ first_key_in_result = (
2173
+ name in df.column_map and first_key.obj is df.column_map[name].obj
2174
+ )
2175
+ if first_key_in_result:
2176
+ result.column_map[name].set_sorted(
2177
+ is_sorted=plc.types.Sorted.YES, order=order[0], null_order=null_order[0]
2178
+ )
2179
+ return result.slice(zlice)
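+ # A small example of why the sortedness flag matters (hypothetical plan):
+ # after sort(by="a"), the output column "a" is marked sorted, so a later
+ # node that checks Column.is_sorted (e.g. the Distinct fast path above)
+ # can take the cheaper code path instead of re-sorting.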
2180
+
2181
+
2182
+ class Slice(IR):
2183
+ """Slice a dataframe."""
2184
+
2185
+ __slots__ = ("length", "offset")
2186
+ _non_child = ("schema", "offset", "length")
2187
+ offset: int
2188
+ """Start of the slice."""
2189
+ length: int | None
2190
+ """Length of the slice."""
2191
+
2192
+ def __init__(self, schema: Schema, offset: int, length: int | None, df: IR):
2193
+ self.schema = schema
2194
+ self.offset = offset
2195
+ self.length = length
2196
+ self._non_child_args = (offset, length)
2197
+ self.children = (df,)
2198
+
2199
+ @classmethod
2200
+ @nvtx_annotate_cudf_polars(message="Slice")
2201
+ def do_evaluate(cls, offset: int, length: int, df: DataFrame) -> DataFrame:
2202
+ """Evaluate and return a dataframe."""
2203
+ return df.slice((offset, length))
2204
+
2205
+
2206
+ class Filter(IR):
2207
+ """Filter a dataframe with a boolean mask."""
2208
+
2209
+ __slots__ = ("mask",)
2210
+ _non_child = ("schema", "mask")
2211
+ mask: expr.NamedExpr
2212
+ """Expression to produce the filter mask."""
2213
+
2214
+ def __init__(self, schema: Schema, mask: expr.NamedExpr, df: IR):
2215
+ self.schema = schema
2216
+ self.mask = mask
2217
+ self._non_child_args = (mask,)
2218
+ self.children = (df,)
2219
+
2220
+ @classmethod
2221
+ @nvtx_annotate_cudf_polars(message="Filter")
2222
+ def do_evaluate(cls, mask_expr: expr.NamedExpr, df: DataFrame) -> DataFrame:
2223
+ """Evaluate and return a dataframe."""
2224
+ (mask,) = broadcast(mask_expr.evaluate(df), target_length=df.num_rows)
2225
+ return df.filter(mask)
2226
+
2227
+
2228
+ class Projection(IR):
2229
+ """Select a subset of columns from a dataframe."""
2230
+
2231
+ __slots__ = ()
2232
+ _non_child = ("schema",)
2233
+
2234
+ def __init__(self, schema: Schema, df: IR):
2235
+ self.schema = schema
2236
+ self._non_child_args = (schema,)
2237
+ self.children = (df,)
2238
+
2239
+ @classmethod
2240
+ @nvtx_annotate_cudf_polars(message="Projection")
2241
+ def do_evaluate(cls, schema: Schema, df: DataFrame) -> DataFrame:
2242
+ """Evaluate and return a dataframe."""
2243
+ # This can reorder things.
2244
+ columns = broadcast(
2245
+ *(df.column_map[name] for name in schema), target_length=df.num_rows
2246
+ )
2247
+ return DataFrame(columns)
2248
+
2249
+
2250
+ class MergeSorted(IR):
2251
+ """Merge sorted operation."""
2252
+
2253
+ __slots__ = ("key",)
2254
+ _non_child = ("schema", "key")
2255
+ key: str
2256
+ """Key that is sorted."""
2257
+
2258
+ def __init__(self, schema: Schema, key: str, left: IR, right: IR):
2259
+ # Children must be Sort or Repartition(Sort).
2260
+ # The Repartition(Sort) case happens during fallback.
2261
+ left_sort_child = left if isinstance(left, Sort) else left.children[0]
2262
+ right_sort_child = right if isinstance(right, Sort) else right.children[0]
2263
+ assert isinstance(left_sort_child, Sort)
2264
+ assert isinstance(right_sort_child, Sort)
2265
+ assert left_sort_child.order == right_sort_child.order
2266
+ assert len(left.schema.keys()) <= len(right.schema.keys())
2267
+ self.schema = schema
2268
+ self.key = key
2269
+ self.children = (left, right)
2270
+ self._non_child_args = (key,)
2271
+
2272
+ @classmethod
2273
+ @nvtx_annotate_cudf_polars(message="MergeSorted")
2274
+ def do_evaluate(cls, key: str, *dfs: DataFrame) -> DataFrame:
2275
+ """Evaluate and return a dataframe."""
2276
+ left, right = dfs
2277
+ right = right.discard_columns(right.column_names_set - left.column_names_set)
2278
+ on_col_left = left.select_columns({key})[0]
2279
+ on_col_right = right.select_columns({key})[0]
2280
+ return DataFrame.from_table(
2281
+ plc.merge.merge(
2282
+ [right.table, left.table],
2283
+ [left.column_names.index(key), right.column_names.index(key)],
2284
+ [on_col_left.order, on_col_right.order],
2285
+ [on_col_left.null_order, on_col_right.null_order],
2286
+ ),
2287
+ left.column_names,
2288
+ left.dtypes,
2289
+ )
2290
+
2291
+
2292
+ class MapFunction(IR):
2293
+ """Apply some function to a dataframe."""
2294
+
2295
+ __slots__ = ("name", "options")
2296
+ _non_child = ("schema", "name", "options")
2297
+ name: str
2298
+ """Name of the function to apply"""
2299
+ options: Any
2300
+ """Arbitrary name-specific options"""
2301
+
2302
+ _NAMES: ClassVar[frozenset[str]] = frozenset(
2303
+ [
2304
+ "rechunk",
2305
+ "rename",
2306
+ "explode",
2307
+ "unpivot",
2308
+ "row_index",
2309
+ "fast_count",
2310
+ ]
2311
+ )
2312
+
2313
+ def __init__(self, schema: Schema, name: str, options: Any, df: IR):
2314
+ self.schema = schema
2315
+ self.name = name
2316
+ self.options = options
2317
+ self.children = (df,)
2318
+ if (
2319
+ self.name not in MapFunction._NAMES
2320
+ ): # pragma: no cover; need more polars rust functions
2321
+ raise NotImplementedError(
2322
+ f"Unhandled map function {self.name}"
2323
+ ) # pragma: no cover
2324
+ if self.name == "explode":
2325
+ (to_explode,) = self.options
2326
+ if len(to_explode) > 1:
2327
+ # TODO: straightforward, but need to error check
2328
+ # polars requires that all to-explode columns have the
2329
+ # same sub-shapes
2330
+ raise NotImplementedError("Explode with more than one column")
2331
+ self.options = (tuple(to_explode),)
2332
+ elif POLARS_VERSION_LT_131 and self.name == "rename": # pragma: no cover
2333
+ # As of 1.31, polars validates renaming in the IR
2334
+ old, new, strict = self.options
2335
+ if len(new) != len(set(new)) or (
2336
+ set(new) & (set(df.schema.keys()) - set(old))
2337
+ ):
2338
+ raise NotImplementedError(
2339
+ "Duplicate new names in rename."
2340
+ ) # pragma: no cover
2341
+ self.options = (tuple(old), tuple(new), strict)
2342
+ elif self.name == "unpivot":
2343
+ indices, pivotees, variable_name, value_name = self.options
2344
+ value_name = "value" if value_name is None else value_name
2345
+ variable_name = "variable" if variable_name is None else variable_name
2346
+ if len(pivotees) == 0:
2347
+ index = frozenset(indices)
2348
+ pivotees = [name for name in df.schema if name not in index]
2349
+ if not all(
2350
+ dtypes.can_cast(df.schema[p].plc, self.schema[value_name].plc)
2351
+ for p in pivotees
2352
+ ):
2353
+ raise NotImplementedError(
2354
+ "Unpivot cannot cast all input columns to "
2355
+ f"{self.schema[value_name].id()}"
2356
+ ) # pragma: no cover
2357
+ self.options = (
2358
+ tuple(indices),
2359
+ tuple(pivotees),
2360
+ variable_name,
2361
+ value_name,
2362
+ )
2363
+ elif self.name == "row_index":
2364
+ col_name, offset = options
2365
+ self.options = (col_name, offset)
2366
+ elif self.name == "fast_count":
2367
+ # TODO: Remove this once all scan types support projections
2368
+ # using Select + Len. Currently, CSV is the only format that
2369
+ # uses the legacy MapFunction(FastCount) path because it is
2370
+ # faster than the new-streaming path for large files.
2371
+ # See https://github.com/pola-rs/polars/pull/22363#issue-3010224808
2372
+ raise NotImplementedError(
2373
+ "Fast count unsupported for CSV scans"
2374
+ ) # pragma: no cover
2375
+ self._non_child_args = (schema, name, self.options)
2376
+
2377
+ def get_hashable(self) -> Hashable:
2378
+ """
2379
+ Hashable representation of the node.
2380
+
2381
+ The options dictionaries are serialised for hashing purposes
2382
+ as json strings.
2383
+ """
2384
+ return (
2385
+ type(self),
2386
+ self.name,
2387
+ json.dumps(self.options),
2388
+ tuple(self.schema.items()),
2389
+ self._ctor_arguments(self.children)[1:],
2390
+ )
2391
+
2392
+ @classmethod
2393
+ @nvtx_annotate_cudf_polars(message="MapFunction")
2394
+ def do_evaluate(
2395
+ cls, schema: Schema, name: str, options: Any, df: DataFrame
2396
+ ) -> DataFrame:
2397
+ """Evaluate and return a dataframe."""
2398
+ if name == "rechunk":
2399
+ # No-op in our data model
2400
+ # Don't think this appears in a plan tree from python
2401
+ return df # pragma: no cover
2402
+ elif POLARS_VERSION_LT_131 and name == "rename": # pragma: no cover
2403
+ # final tag is "swapping" which is useful for the
2404
+ # optimiser (it blocks some pushdown operations)
2405
+ old, new, _ = options
2406
+ return df.rename_columns(dict(zip(old, new, strict=True)))
2407
+ elif name == "explode":
2408
+ ((to_explode,),) = options
2409
+ index = df.column_names.index(to_explode)
2410
+ subset = df.column_names_set - {to_explode}
2411
+ return DataFrame.from_table(
2412
+ plc.lists.explode_outer(df.table, index), df.column_names, df.dtypes
2413
+ ).sorted_like(df, subset=subset)
2414
+ elif name == "unpivot":
2415
+ (
2416
+ indices,
2417
+ pivotees,
2418
+ variable_name,
2419
+ value_name,
2420
+ ) = options
2421
+ npiv = len(pivotees)
2422
+ selected = df.select(indices)
2423
+ index_columns = [
2424
+ Column(tiled, name=name, dtype=old.dtype)
2425
+ for tiled, name, old in zip(
2426
+ plc.reshape.tile(selected.table, npiv).columns(),
2427
+ indices,
2428
+ selected.columns,
2429
+ strict=True,
2430
+ )
2431
+ ]
2432
+ (variable_column,) = plc.filling.repeat(
2433
+ plc.Table(
2434
+ [
2435
+ plc.Column.from_arrow(
2436
+ pl.Series(
2437
+ values=pivotees, dtype=schema[variable_name].polars
2438
+ )
2439
+ )
2440
+ ]
2441
+ ),
2442
+ df.num_rows,
2443
+ ).columns()
2444
+ value_column = plc.concatenate.concatenate(
2445
+ [
2446
+ df.column_map[pivotee].astype(schema[value_name]).obj
2447
+ for pivotee in pivotees
2448
+ ]
2449
+ )
2450
+ return DataFrame(
2451
+ [
2452
+ *index_columns,
2453
+ Column(
2454
+ variable_column, name=variable_name, dtype=schema[variable_name]
2455
+ ),
2456
+ Column(value_column, name=value_name, dtype=schema[value_name]),
2457
+ ]
2458
+ )
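+ # Worked sketch of the unpivot branch (hypothetical frame):
+ # indices=["id"], pivotees=["a", "b"], two input rows. tile repeats the
+ # index block once per pivotee ([id0, id1, id0, id1]), repeat builds the
+ # variable column (["a", "a", "b", "b"]), and concatenate stacks the
+ # pivotee values ([a0, a1, b0, b1]), giving the long-format rows
+ # (id0, "a", a0), (id1, "a", a1), (id0, "b", b0), (id1, "b", b1).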
2459
+ elif name == "row_index":
2460
+ col_name, offset = options
2461
+ dtype = schema[col_name]
2462
+ step = plc.Scalar.from_py(1, dtype.plc)
2463
+ init = plc.Scalar.from_py(offset, dtype.plc)
2464
+ index_col = Column(
2465
+ plc.filling.sequence(df.num_rows, init, step),
2466
+ is_sorted=plc.types.Sorted.YES,
2467
+ order=plc.types.Order.ASCENDING,
2468
+ null_order=plc.types.NullOrder.AFTER,
2469
+ name=col_name,
2470
+ dtype=dtype,
2471
+ )
2472
+ return DataFrame([index_col, *df.columns])
2473
+ else:
2474
+ raise AssertionError("Should never be reached") # pragma: no cover
2475
+
2476
+
2477
+ class Union(IR):
2478
+ """Concatenate dataframes vertically."""
2479
+
2480
+ __slots__ = ("zlice",)
2481
+ _non_child = ("schema", "zlice")
2482
+ zlice: Zlice | None
2483
+ """Optional slice to apply to the result."""
2484
+
2485
+ def __init__(self, schema: Schema, zlice: Zlice | None, *children: IR):
2486
+ self.schema = schema
2487
+ self.zlice = zlice
2488
+ self._non_child_args = (zlice,)
2489
+ self.children = children
2490
+ schema = self.children[0].schema
2491
+
2492
+ @classmethod
2493
+ @nvtx_annotate_cudf_polars(message="Union")
2494
+ def do_evaluate(cls, zlice: Zlice | None, *dfs: DataFrame) -> DataFrame:
2495
+ """Evaluate and return a dataframe."""
2496
+ # TODO: only evaluate what we need if we have a slice?
2497
+ return DataFrame.from_table(
2498
+ plc.concatenate.concatenate([df.table for df in dfs]),
2499
+ dfs[0].column_names,
2500
+ dfs[0].dtypes,
2501
+ ).slice(zlice)
2502
+
2503
+
2504
+ class HConcat(IR):
2505
+ """Concatenate dataframes horizontally."""
2506
+
2507
+ __slots__ = ("should_broadcast",)
2508
+ _non_child = ("schema", "should_broadcast")
2509
+
2510
+ def __init__(
2511
+ self,
2512
+ schema: Schema,
2513
+ should_broadcast: bool, # noqa: FBT001
2514
+ *children: IR,
2515
+ ):
2516
+ self.schema = schema
2517
+ self.should_broadcast = should_broadcast
2518
+ self._non_child_args = (should_broadcast,)
2519
+ self.children = children
2520
+
2521
+ @staticmethod
2522
+ def _extend_with_nulls(table: plc.Table, *, nrows: int) -> plc.Table:
2523
+ """
2524
+ Extend a table with nulls.
2525
+
2526
+ Parameters
2527
+ ----------
2528
+ table
2529
+ Table to extend
2530
+ nrows
2531
+ Number of additional rows
2532
+
2533
+ Returns
2534
+ -------
2535
+ New pylibcudf table.
2536
+ """
2537
+ return plc.concatenate.concatenate(
2538
+ [
2539
+ table,
2540
+ plc.Table(
2541
+ [
2542
+ plc.Column.all_null_like(column, nrows)
2543
+ for column in table.columns()
2544
+ ]
2545
+ ),
2546
+ ]
2547
+ )
2548
+
2549
+ @classmethod
2550
+ @nvtx_annotate_cudf_polars(message="HConcat")
2551
+ def do_evaluate(
2552
+ cls,
2553
+ should_broadcast: bool, # noqa: FBT001
2554
+ *dfs: DataFrame,
2555
+ ) -> DataFrame:
2556
+ """Evaluate and return a dataframe."""
2557
+ # Special should_broadcast case.
2558
+ # Used to recombine decomposed expressions
2559
+ if should_broadcast:
2560
+ return DataFrame(
2561
+ broadcast(*itertools.chain.from_iterable(df.columns for df in dfs))
2562
+ )
2563
+
2564
+ max_rows = max(df.num_rows for df in dfs)
2565
+ # Horizontal concatenation extends shorter tables with nulls
2566
+ return DataFrame(
2567
+ itertools.chain.from_iterable(
2568
+ df.columns
2569
+ for df in (
2570
+ df
2571
+ if df.num_rows == max_rows
2572
+ else DataFrame.from_table(
2573
+ cls._extend_with_nulls(df.table, nrows=max_rows - df.num_rows),
2574
+ df.column_names,
2575
+ df.dtypes,
2576
+ )
2577
+ for df in dfs
2578
+ )
2579
+ )
2580
+ )
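+ # For example (hypothetical inputs): horizontally concatenating a 5-row
+ # frame with a 3-row frame pads the shorter one with two all-null rows so
+ # that every output column has 5 rows, mirroring polars' horizontal
+ # concat semantics.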
2581
+
2582
+
2583
+ class Empty(IR):
2584
+ """Represents an empty DataFrame with a known schema."""
2585
+
2586
+ __slots__ = ("schema",)
2587
+ _non_child = ("schema",)
2588
+
2589
+ def __init__(self, schema: Schema):
2590
+ self.schema = schema
2591
+ self._non_child_args = (schema,)
2592
+ self.children = ()
2593
+
2594
+ @classmethod
2595
+ @nvtx_annotate_cudf_polars(message="Empty")
2596
+ def do_evaluate(cls, schema: Schema) -> DataFrame: # pragma: no cover
2597
+ """Evaluate and return a dataframe."""
2598
+ return DataFrame(
2599
+ [
2600
+ Column(
2601
+ plc.column_factories.make_empty_column(dtype.plc),
2602
+ dtype=dtype,
2603
+ name=name,
2604
+ )
2605
+ for name, dtype in schema.items()
2606
+ ]
2607
+ )