cudf_polars_cu12-24.8.0a281-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cudf_polars/dsl/ir.py ADDED
@@ -0,0 +1,1053 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """
4
+ DSL nodes for the LogicalPlan of polars.
5
+
6
+ An IR node is either a source, a normal node, or a sink. Respectively, they
7
+ can be considered as functions:
8
+
9
+ - source: `IO () -> DataFrame`
10
+ - normal: `DataFrame -> DataFrame`
11
+ - sink: `DataFrame -> IO ()`
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import dataclasses
17
+ import itertools
18
+ import types
19
+ from functools import cache
20
+ from pathlib import Path
21
+ from typing import TYPE_CHECKING, Any, Callable, ClassVar
22
+
23
+ import pyarrow as pa
24
+ from typing_extensions import assert_never
25
+
26
+ import polars as pl
27
+
28
+ import cudf
29
+ import cudf._lib.pylibcudf as plc
30
+
31
+ import cudf_polars.dsl.expr as expr
32
+ from cudf_polars.containers import DataFrame, NamedColumn
33
+ from cudf_polars.utils import dtypes, sorting
34
+
35
+ if TYPE_CHECKING:
36
+ from collections.abc import MutableMapping
37
+ from typing import Literal
38
+
39
+ from cudf_polars.typing import Schema
40
+
41
+
42
+ __all__ = [
43
+ "IR",
44
+ "PythonScan",
45
+ "Scan",
46
+ "Cache",
47
+ "DataFrameScan",
48
+ "Select",
49
+ "GroupBy",
50
+ "Join",
51
+ "HStack",
52
+ "Distinct",
53
+ "Sort",
54
+ "Slice",
55
+ "Filter",
56
+ "Projection",
57
+ "MapFunction",
58
+ "Union",
59
+ "HConcat",
60
+ ]
61
+
62
+
63
+ def broadcast(
64
+ *columns: NamedColumn, target_length: int | None = None
65
+ ) -> list[NamedColumn]:
66
+ """
67
+ Broadcast a sequence of columns to a common length.
68
+
69
+ Parameters
70
+ ----------
71
+ columns
72
+ Columns to broadcast.
73
+ target_length
74
+ Optional length to broadcast to. If not provided, uses the
75
+ non-unit length of existing columns.
76
+
77
+ Returns
78
+ -------
79
+ List of broadcasted columns all of the same length.
80
+
81
+ Raises
82
+ ------
83
+ RuntimeError
84
+ If broadcasting is not possible.
85
+
86
+ Notes
87
+ -----
88
+ In evaluation of a set of expressions, polars type-puns length-1
89
+ columns with scalars. When we insert these into a DataFrame
90
+ object, we need to ensure they are of equal length. This function
91
+ takes some columns, some of which may be length-1, and ensures that
92
+ all length-1 columns are broadcast to the length of the others.
93
+
94
+ Broadcasting is only possible if the set of lengths of the input
95
+ columns is a subset of ``{1, n}`` for some (fixed) ``n``. If
96
+ ``target_length`` is provided and not all columns are length-1
97
+ (i.e. ``n != 1``), then ``target_length`` must be equal to ``n``.
98
+ """
99
+ if len(columns) == 0:
100
+ return []
101
+ lengths: set[int] = {column.obj.size() for column in columns}
102
+ if lengths == {1}:
103
+ if target_length is None:
104
+ return list(columns)
105
+ nrows = target_length
106
+ else:
107
+ try:
108
+ (nrows,) = lengths.difference([1])
109
+ except ValueError as e:
110
+ raise RuntimeError("Mismatching column lengths") from e
111
+ if target_length is not None and nrows != target_length:
112
+ raise RuntimeError(
113
+ f"Cannot broadcast columns of length {nrows=} to {target_length=}"
114
+ )
115
+ return [
116
+ column
117
+ if column.obj.size() != 1
118
+ else NamedColumn(
119
+ plc.Column.from_scalar(column.obj_scalar, nrows),
120
+ column.name,
121
+ is_sorted=plc.types.Sorted.YES,
122
+ order=plc.types.Order.ASCENDING,
123
+ null_order=plc.types.NullOrder.BEFORE,
124
+ )
125
+ for column in columns
126
+ ]
127
+
128
+
129
+ @dataclasses.dataclass
130
+ class IR:
131
+ """Abstract plan node, representing an unevaluated dataframe."""
132
+
133
+ schema: Schema
134
+ """Mapping from column names to their data types."""
135
+
136
+ def __post_init__(self):
137
+ """Validate preconditions."""
138
+ if any(dtype.id() == plc.TypeId.EMPTY for dtype in self.schema.values()):
139
+ raise NotImplementedError("Cannot make empty columns.")
140
+
141
+ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
142
+ """
143
+ Evaluate the node and return a dataframe.
144
+
145
+ Parameters
146
+ ----------
147
+ cache
148
+ Mapping from cached node ids to constructed DataFrames.
149
+ Used to implement evaluation of the `Cache` node.
150
+
151
+ Returns
152
+ -------
153
+ DataFrame (on device) representing the evaluation of this plan
154
+ node.
155
+
156
+ Raises
157
+ ------
158
+ NotImplementedError
159
+ If we couldn't evaluate things. Ideally this should not occur,
160
+ since the translation phase should pick up things that we
161
+ cannot handle.
162
+ """
163
+ raise NotImplementedError(
164
+ f"Evaluation of plan {type(self).__name__}"
165
+ ) # pragma: no cover
166
+
167
+
168
+ @dataclasses.dataclass
169
+ class PythonScan(IR):
170
+ """Representation of input from a python function."""
171
+
172
+ options: Any
173
+ """Arbitrary options."""
174
+ predicate: expr.NamedExpr | None
175
+ """Filter to apply to the constructed dataframe before returning it."""
176
+
177
+ def __post_init__(self):
178
+ """Validate preconditions."""
179
+ raise NotImplementedError("PythonScan not implemented")
180
+
181
+
182
+ @dataclasses.dataclass
183
+ class Scan(IR):
184
+ """Input from files."""
185
+
186
+ typ: str
187
+ """What type of file are we reading? Parquet, CSV, etc..."""
188
+ reader_options: dict[str, Any]
189
+ """Reader-specific options, as dictionary."""
190
+ cloud_options: dict[str, Any] | None
191
+ """Cloud-related authentication options, currently ignored."""
192
+ paths: list[str]
193
+ """List of paths to read from."""
194
+ file_options: Any
195
+ """Options for reading the file.
196
+
197
+ Attributes are:
198
+ - ``with_columns: list[str]``: Projected columns to return.
199
+ - ``n_rows: int``: Number of rows to read.
200
+ - ``row_index: tuple[name, offset] | None``: Add an integer index
201
+ column with given name.
202
+ """
203
+ predicate: expr.NamedExpr | None
204
+ """Mask to apply to the read dataframe."""
205
+
206
+ def __post_init__(self) -> None:
207
+ """Validate preconditions."""
208
+ if self.file_options.n_rows is not None:
209
+ raise NotImplementedError("row limit in scan")
210
+ if self.typ not in ("csv", "parquet"):
211
+ raise NotImplementedError(f"Unhandled scan type: {self.typ}")
212
+ if self.cloud_options is not None and any(
213
+ self.cloud_options[k] is not None for k in ("aws", "azure", "gcp")
214
+ ):
215
+ raise NotImplementedError(
216
+ "Read from cloud storage"
217
+ ) # pragma: no cover; no test yet
218
+ if self.typ == "csv":
219
+ if self.reader_options["skip_rows_after_header"] != 0:
220
+ raise NotImplementedError("Skipping rows after header in CSV reader")
221
+ parse_options = self.reader_options["parse_options"]
222
+ if (
223
+ null_values := parse_options["null_values"]
224
+ ) is not None and "Named" in null_values:
225
+ raise NotImplementedError(
226
+ "Per column null value specification not supported for CSV reader"
227
+ )
228
+ if (
229
+ comment := parse_options["comment_prefix"]
230
+ ) is not None and "Multi" in comment:
231
+ raise NotImplementedError(
232
+ "Multi-character comment prefix not supported for CSV reader"
233
+ )
234
+ if not self.reader_options["has_header"]:
235
+ # Need to do some file introspection to get the number
236
+ # of columns so that column projection works right.
237
+ raise NotImplementedError("Reading CSV without header")
238
+
239
+ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
240
+ """Evaluate and return a dataframe."""
241
+ options = self.file_options
242
+ with_columns = options.with_columns
243
+ row_index = options.row_index
244
+ if self.typ == "csv":
245
+ dtype_map = {
246
+ name: cudf._lib.types.PYLIBCUDF_TO_SUPPORTED_NUMPY_TYPES[typ.id()]
247
+ for name, typ in self.schema.items()
248
+ }
249
+ parse_options = self.reader_options["parse_options"]
250
+ sep = chr(parse_options["separator"])
251
+ quote = chr(parse_options["quote_char"])
252
+ eol = chr(parse_options["eol_char"])
253
+ if self.reader_options["schema"] is not None:
254
+ # Reader schema provides names
255
+ column_names = list(self.reader_options["schema"]["inner"].keys())
256
+ else:
257
+ # file provides column names
258
+ column_names = None
259
+ usecols = with_columns
260
+ # TODO: support has_header=False
261
+ header = 0
262
+
263
+ # polars defaults to no null recognition
264
+ null_values = [""]
265
+ if parse_options["null_values"] is not None:
266
+ ((typ, nulls),) = parse_options["null_values"].items()
267
+ if typ == "AllColumnsSingle":
268
+ # Single value
269
+ null_values.append(nulls)
270
+ else:
271
+ # List of values
272
+ null_values.extend(nulls)
273
+ if parse_options["comment_prefix"] is not None:
274
+ comment = chr(parse_options["comment_prefix"]["Single"])
275
+ else:
276
+ comment = None
277
+ decimal = "," if parse_options["decimal_comma"] else "."
278
+
279
+ # polars skips blank lines at the beginning of the file
280
+ pieces = []
281
+ for p in self.paths:
282
+ skiprows = self.reader_options["skip_rows"]
283
+ # TODO: read_csv expands globs which we should not do,
284
+ # because polars will already have handled them.
285
+ path = Path(p)
286
+ with path.open() as f:
287
+ while f.readline() == "\n":
288
+ skiprows += 1
289
+ pieces.append(
290
+ cudf.read_csv(
291
+ path,
292
+ sep=sep,
293
+ quotechar=quote,
294
+ lineterminator=eol,
295
+ names=column_names,
296
+ header=header,
297
+ usecols=usecols,
298
+ na_filter=True,
299
+ na_values=null_values,
300
+ keep_default_na=False,
301
+ skiprows=skiprows,
302
+ comment=comment,
303
+ decimal=decimal,
304
+ dtype=dtype_map,
305
+ )
306
+ )
307
+ df = DataFrame.from_cudf(cudf.concat(pieces))
308
+ elif self.typ == "parquet":
309
+ cdf = cudf.read_parquet(self.paths, columns=with_columns)
310
+ assert isinstance(cdf, cudf.DataFrame)
311
+ df = DataFrame.from_cudf(cdf)
312
+ else:
313
+ raise NotImplementedError(
314
+ f"Unhandled scan type: {self.typ}"
315
+ ) # pragma: no cover; post init trips first
316
+ if row_index is not None:
317
+ name, offset = row_index
318
+ dtype = self.schema[name]
319
+ step = plc.interop.from_arrow(
320
+ pa.scalar(1, type=plc.interop.to_arrow(dtype))
321
+ )
322
+ init = plc.interop.from_arrow(
323
+ pa.scalar(offset, type=plc.interop.to_arrow(dtype))
324
+ )
325
+ index = NamedColumn(
326
+ plc.filling.sequence(df.num_rows, init, step),
327
+ name,
328
+ is_sorted=plc.types.Sorted.YES,
329
+ order=plc.types.Order.ASCENDING,
330
+ null_order=plc.types.NullOrder.AFTER,
331
+ )
332
+ df = DataFrame([index, *df.columns])
333
+ # TODO: should be true, but not the case until we get
334
+ # cudf-classic out of the loop for IO since it converts date32
335
+ # to datetime.
336
+ # assert all(
337
+ # c.obj.type() == dtype
338
+ # for c, dtype in zip(df.columns, self.schema.values())
339
+ # )
340
+ if self.predicate is None:
341
+ return df
342
+ else:
343
+ (mask,) = broadcast(self.predicate.evaluate(df), target_length=df.num_rows)
344
+ return df.filter(mask)
345
+
346
+
347
+ @dataclasses.dataclass
348
+ class Cache(IR):
349
+ """
350
+ Return a cached plan node.
351
+
352
+ Used for CSE at the plan level.
353
+ """
354
+
355
+ key: int
356
+ """The cache key."""
357
+ value: IR
358
+ """The unevaluated node to cache."""
359
+
360
+ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
361
+ """Evaluate and return a dataframe."""
362
+ try:
363
+ return cache[self.key]
364
+ except KeyError:
365
+ return cache.setdefault(self.key, self.value.evaluate(cache=cache))
366
+
367
+
368
+ @dataclasses.dataclass
369
+ class DataFrameScan(IR):
370
+ """
371
+ Input from an existing polars DataFrame.
372
+
373
+ This typically arises from ``q.collect().lazy()``.
374
+ """
375
+
376
+ df: Any
377
+ """Polars LazyFrame object."""
378
+ projection: list[str]
379
+ """List of columns to project out."""
380
+ predicate: expr.NamedExpr | None
381
+ """Mask to apply."""
382
+
383
+ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
384
+ """Evaluate and return a dataframe."""
385
+ pdf = pl.DataFrame._from_pydf(self.df)
386
+ if self.projection is not None:
387
+ pdf = pdf.select(self.projection)
388
+ table = pdf.to_arrow()
389
+ schema = table.schema
390
+ for i, field in enumerate(schema):
391
+ schema = schema.set(
392
+ i, pa.field(field.name, dtypes.downcast_arrow_lists(field.type))
393
+ )
394
+ # No-op if the schema is unchanged.
395
+ table = table.cast(schema)
396
+ df = DataFrame.from_table(
397
+ plc.interop.from_arrow(table), list(self.schema.keys())
398
+ )
399
+ assert all(
400
+ c.obj.type() == dtype for c, dtype in zip(df.columns, self.schema.values())
401
+ )
402
+ if self.predicate is not None:
403
+ (mask,) = broadcast(self.predicate.evaluate(df), target_length=df.num_rows)
404
+ return df.filter(mask)
405
+ else:
406
+ return df
407
+
408
+
409
+ @dataclasses.dataclass
410
+ class Select(IR):
411
+ """Produce a new dataframe selecting given expressions from an input."""
412
+
413
+ df: IR
414
+ """Input dataframe."""
415
+ expr: list[expr.NamedExpr]
416
+ """List of expressions to evaluate to form the new dataframe."""
417
+ should_broadcast: bool
418
+ """Should columns be broadcast?"""
419
+
420
+ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
421
+ """Evaluate and return a dataframe."""
422
+ df = self.df.evaluate(cache=cache)
423
+ # Handle any broadcasting
424
+ columns = [e.evaluate(df) for e in self.expr]
425
+ if self.should_broadcast:
426
+ columns = broadcast(*columns)
427
+ return DataFrame(columns)
428
+
429
+
430
+ @dataclasses.dataclass
431
+ class Reduce(IR):
432
+ """
433
+ Produce a new dataframe selecting given expressions from an input.
434
+
435
+ This is a special case of :class:`Select` where all outputs are a single row.
436
+ """
437
+
438
+ df: IR
439
+ """Input dataframe."""
440
+ expr: list[expr.NamedExpr]
441
+ """List of expressions to evaluate to form the new dataframe."""
442
+
443
+ def evaluate(
444
+ self, *, cache: MutableMapping[int, DataFrame]
445
+ ) -> DataFrame: # pragma: no cover; polars doesn't emit this node yet
446
+ """Evaluate and return a dataframe."""
447
+ df = self.df.evaluate(cache=cache)
448
+ columns = broadcast(*(e.evaluate(df) for e in self.expr))
449
+ assert all(column.obj.size() == 1 for column in columns)
450
+ return DataFrame(columns)
451
+
452
+
453
+ def placeholder_column(n: int) -> plc.Column:
454
+ """
455
+ Produce a placeholder pylibcudf column with NO BACKING DATA.
456
+
457
+ Parameters
458
+ ----------
459
+ n
460
+ Number of rows the column will advertise
461
+
462
+ Returns
463
+ -------
464
+ pylibcudf Column that is almost unusable. DO NOT ACCESS THE DATA BUFFER.
465
+
466
+ Notes
467
+ -----
468
+ This is used to avoid allocating data for count aggregations.
469
+ """
470
+ return plc.Column(
471
+ plc.DataType(plc.TypeId.INT8),
472
+ n,
473
+ plc.gpumemoryview(
474
+ types.SimpleNamespace(__cuda_array_interface__={"data": (1, True)})
475
+ ),
476
+ None,
477
+ 0,
478
+ 0,
479
+ [],
480
+ )
481
+
482
+
483
+ @dataclasses.dataclass
484
+ class GroupBy(IR):
485
+ """Perform a groupby."""
486
+
487
+ df: IR
488
+ """Input dataframe."""
489
+ agg_requests: list[expr.NamedExpr]
490
+ """List of expressions to evaluate groupwise."""
491
+ keys: list[expr.NamedExpr]
492
+ """List of expressions forming the keys."""
493
+ maintain_order: bool
494
+ """Should the order of the input dataframe be maintained?"""
495
+ options: Any
496
+ """Options controlling style of groupby."""
497
+ agg_infos: list[expr.AggInfo] = dataclasses.field(init=False)
498
+
499
+ @staticmethod
500
+ def check_agg(agg: expr.Expr) -> int:
501
+ """
502
+ Determine if we can handle an aggregation expression.
503
+
504
+ Parameters
505
+ ----------
506
+ agg
507
+ Expression to check
508
+
509
+ Returns
510
+ -------
511
+ depth of nesting
512
+
513
+ Raises
514
+ ------
515
+ NotImplementedError
516
+ For unsupported expression nodes.
517
+ """
518
+ if isinstance(agg, (expr.BinOp, expr.Cast, expr.UnaryFunction)):
519
+ return max(GroupBy.check_agg(child) for child in agg.children)
520
+ elif isinstance(agg, expr.Agg):
521
+ return 1 + max(GroupBy.check_agg(child) for child in agg.children)
522
+ elif isinstance(agg, (expr.Len, expr.Col, expr.Literal)):
523
+ return 0
524
+ else:
525
+ raise NotImplementedError(f"No handler for {agg=}")
526
+
527
+ def __post_init__(self) -> None:
528
+ """Check whether all the aggregations are implemented."""
529
+ if self.options.rolling is None and self.maintain_order:
530
+ raise NotImplementedError("Maintaining order in groupby")
531
+ if self.options.rolling:
532
+ raise NotImplementedError(
533
+ "rolling window/groupby"
534
+ ) # pragma: no cover; rollingwindow constructor has already raised
535
+ if any(GroupBy.check_agg(a.value) > 1 for a in self.agg_requests):
536
+ raise NotImplementedError("Nested aggregations in groupby")
537
+ self.agg_infos = [req.collect_agg(depth=0) for req in self.agg_requests]
538
+
539
+ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
540
+ """Evaluate and return a dataframe."""
541
+ df = self.df.evaluate(cache=cache)
542
+ keys = broadcast(
543
+ *(k.evaluate(df) for k in self.keys), target_length=df.num_rows
544
+ )
545
+ # TODO: use sorted information, need to expose column_order
546
+ # and null_precedence in pylibcudf groupby constructor
547
+ # sorted = (
548
+ # plc.types.Sorted.YES
549
+ # if all(k.is_sorted for k in keys)
550
+ # else plc.types.Sorted.NO
551
+ # )
552
+ grouper = plc.groupby.GroupBy(
553
+ plc.Table([k.obj for k in keys]),
554
+ null_handling=plc.types.NullPolicy.INCLUDE,
555
+ )
556
+ # TODO: uniquify
557
+ requests = []
558
+ replacements: list[expr.Expr] = []
559
+ for info in self.agg_infos:
560
+ for pre_eval, req, rep in info.requests:
561
+ if pre_eval is None:
562
+ col = placeholder_column(df.num_rows)
563
+ else:
564
+ col = pre_eval.evaluate(df).obj
565
+ requests.append(plc.groupby.GroupByRequest(col, [req]))
566
+ replacements.append(rep)
567
+ group_keys, raw_tables = grouper.aggregate(requests)
568
+ # TODO: names
569
+ raw_columns: list[NamedColumn] = []
570
+ for i, table in enumerate(raw_tables):
571
+ (column,) = table.columns()
572
+ raw_columns.append(NamedColumn(column, f"tmp{i}"))
573
+ mapping = dict(zip(replacements, raw_columns))
574
+ result_keys = [
575
+ NamedColumn(gk, k.name) for gk, k in zip(group_keys.columns(), keys)
576
+ ]
577
+ result_subs = DataFrame(raw_columns)
578
+ results = [
579
+ req.evaluate(result_subs, mapping=mapping) for req in self.agg_requests
580
+ ]
581
+ return DataFrame([*result_keys, *results]).slice(self.options.slice)
582
+
583
+
584
+ @dataclasses.dataclass
585
+ class Join(IR):
586
+ """A join of two dataframes."""
587
+
588
+ left: IR
589
+ """Left frame."""
590
+ right: IR
591
+ """Right frame."""
592
+ left_on: list[expr.NamedExpr]
593
+ """List of expressions used as keys in the left frame."""
594
+ right_on: list[expr.NamedExpr]
595
+ """List of expressions used as keys in the right frame."""
596
+ options: tuple[
597
+ Literal["inner", "left", "full", "leftsemi", "leftanti", "cross"],
598
+ bool,
599
+ tuple[int, int] | None,
600
+ str | None,
601
+ bool,
602
+ ]
603
+ """
604
+ tuple of options:
605
+ - how: join type
606
+ - join_nulls: do nulls compare equal?
607
+ - slice: optional slice to perform after joining.
608
+ - suffix: string suffix for right columns if names match
609
+ - coalesce: should key columns be coalesced (only makes sense for outer joins)
610
+ """
611
+
612
+ def __post_init__(self) -> None:
613
+ """Validate preconditions."""
614
+ if any(
615
+ isinstance(e.value, expr.Literal)
616
+ for e in itertools.chain(self.left_on, self.right_on)
617
+ ):
618
+ raise NotImplementedError("Join with literal as join key.")
619
+
620
+ @staticmethod
621
+ @cache
622
+ def _joiners(
623
+ how: Literal["inner", "left", "full", "leftsemi", "leftanti"],
624
+ ) -> tuple[
625
+ Callable, plc.copying.OutOfBoundsPolicy, plc.copying.OutOfBoundsPolicy | None
626
+ ]:
627
+ if how == "inner":
628
+ return (
629
+ plc.join.inner_join,
630
+ plc.copying.OutOfBoundsPolicy.DONT_CHECK,
631
+ plc.copying.OutOfBoundsPolicy.DONT_CHECK,
632
+ )
633
+ elif how == "left":
634
+ return (
635
+ plc.join.left_join,
636
+ plc.copying.OutOfBoundsPolicy.DONT_CHECK,
637
+ plc.copying.OutOfBoundsPolicy.NULLIFY,
638
+ )
639
+ elif how == "full":
640
+ return (
641
+ plc.join.full_join,
642
+ plc.copying.OutOfBoundsPolicy.NULLIFY,
643
+ plc.copying.OutOfBoundsPolicy.NULLIFY,
644
+ )
645
+ elif how == "leftsemi":
646
+ return (
647
+ plc.join.left_semi_join,
648
+ plc.copying.OutOfBoundsPolicy.DONT_CHECK,
649
+ None,
650
+ )
651
+ elif how == "leftanti":
652
+ return (
653
+ plc.join.left_anti_join,
654
+ plc.copying.OutOfBoundsPolicy.DONT_CHECK,
655
+ None,
656
+ )
657
+ else:
658
+ assert_never(how)
659
+
660
+ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
661
+ """Evaluate and return a dataframe."""
662
+ left = self.left.evaluate(cache=cache)
663
+ right = self.right.evaluate(cache=cache)
664
+ how, join_nulls, zlice, suffix, coalesce = self.options
665
+ suffix = "_right" if suffix is None else suffix
666
+ if how == "cross":
667
+ # Separate implementation, since cross_join returns the
668
+ # result, not the gather maps
669
+ columns = plc.join.cross_join(left.table, right.table).columns()
670
+ left_cols = [
671
+ NamedColumn(new, old.name).sorted_like(old)
672
+ for new, old in zip(columns[: left.num_columns], left.columns)
673
+ ]
674
+ right_cols = [
675
+ NamedColumn(
676
+ new,
677
+ old.name
678
+ if old.name not in left.column_names_set
679
+ else f"{old.name}{suffix}",
680
+ )
681
+ for new, old in zip(columns[left.num_columns :], right.columns)
682
+ ]
683
+ return DataFrame([*left_cols, *right_cols])
684
+ # TODO: Waiting on clarity based on https://github.com/pola-rs/polars/issues/17184
685
+ left_on = DataFrame(broadcast(*(e.evaluate(left) for e in self.left_on)))
686
+ right_on = DataFrame(broadcast(*(e.evaluate(right) for e in self.right_on)))
687
+ null_equality = (
688
+ plc.types.NullEquality.EQUAL
689
+ if join_nulls
690
+ else plc.types.NullEquality.UNEQUAL
691
+ )
692
+ join_fn, left_policy, right_policy = Join._joiners(how)
693
+ if right_policy is None:
694
+ # Semi join
695
+ lg = join_fn(left_on.table, right_on.table, null_equality)
696
+ table = plc.copying.gather(left.table, lg, left_policy)
697
+ result = DataFrame.from_table(table, left.column_names)
698
+ else:
699
+ lg, rg = join_fn(left_on.table, right_on.table, null_equality)
700
+ if coalesce and how == "inner":
701
+ right = right.discard_columns(right_on.column_names_set)
702
+ left = DataFrame.from_table(
703
+ plc.copying.gather(left.table, lg, left_policy), left.column_names
704
+ )
705
+ right = DataFrame.from_table(
706
+ plc.copying.gather(right.table, rg, right_policy), right.column_names
707
+ )
708
+ if coalesce and how != "inner":
709
+ left = left.replace_columns(
710
+ *(
711
+ NamedColumn(
712
+ plc.replace.replace_nulls(left_col.obj, right_col.obj),
713
+ left_col.name,
714
+ )
715
+ for left_col, right_col in zip(
716
+ left.select_columns(left_on.column_names_set),
717
+ right.select_columns(right_on.column_names_set),
718
+ )
719
+ )
720
+ )
721
+ right = right.discard_columns(right_on.column_names_set)
722
+ right = right.rename_columns(
723
+ {
724
+ name: f"{name}{suffix}"
725
+ for name in right.column_names
726
+ if name in left.column_names_set
727
+ }
728
+ )
729
+ result = left.with_columns(right.columns)
730
+ return result.slice(zlice)
731
+
732
+
733
+ @dataclasses.dataclass
734
+ class HStack(IR):
735
+ """Add new columns to a dataframe."""
736
+
737
+ df: IR
738
+ """Input dataframe."""
739
+ columns: list[expr.NamedExpr]
740
+ """List of expressions to produce new columns."""
741
+ should_broadcast: bool
742
+ """Should columns be broadcast?"""
743
+
744
+ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
745
+ """Evaluate and return a dataframe."""
746
+ df = self.df.evaluate(cache=cache)
747
+ columns = [c.evaluate(df) for c in self.columns]
748
+ if self.should_broadcast:
749
+ columns = broadcast(*columns, target_length=df.num_rows)
750
+ else:
751
+ # Polars ensures this is true, but let's make sure nothing
752
+ # went wrong. In this case, the parent node is
753
+ # guaranteed to be a Select which will take care of making
754
+ # sure that everything is the same length. The result
755
+ # table that might have mismatching column lengths will
756
+ # never be turned into a pylibcudf Table with all columns
757
+ # by the Select, which is why this is safe.
758
+ assert all(e.name.startswith("__POLARS_CSER_0x") for e in self.columns)
759
+ return df.with_columns(columns)
760
+
761
+
762
+ @dataclasses.dataclass
763
+ class Distinct(IR):
764
+ """Produce a new dataframe with distinct rows."""
765
+
766
+ df: IR
767
+ """Input dataframe."""
768
+ keep: plc.stream_compaction.DuplicateKeepOption
769
+ """Which rows to keep."""
770
+ subset: set[str] | None
771
+ """Which columns to inspect when computing distinct rows."""
772
+ zlice: tuple[int, int] | None
773
+ """Optional slice to perform after compaction."""
774
+ stable: bool
775
+ """Should order be preserved?"""
776
+
777
+ _KEEP_MAP: ClassVar[dict[str, plc.stream_compaction.DuplicateKeepOption]] = {
778
+ "first": plc.stream_compaction.DuplicateKeepOption.KEEP_FIRST,
779
+ "last": plc.stream_compaction.DuplicateKeepOption.KEEP_LAST,
780
+ "none": plc.stream_compaction.DuplicateKeepOption.KEEP_NONE,
781
+ "any": plc.stream_compaction.DuplicateKeepOption.KEEP_ANY,
782
+ }
783
+
784
+ def __init__(self, schema: Schema, df: IR, options: Any) -> None:
785
+ self.schema = schema
786
+ self.df = df
787
+ (keep, subset, maintain_order, zlice) = options
788
+ self.keep = Distinct._KEEP_MAP[keep]
789
+ self.subset = set(subset) if subset is not None else None
790
+ self.stable = maintain_order
791
+ self.zlice = zlice
792
+
793
+ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
794
+ """Evaluate and return a dataframe."""
795
+ df = self.df.evaluate(cache=cache)
796
+ if self.subset is None:
797
+ indices = list(range(df.num_columns))
798
+ else:
799
+ indices = [i for i, k in enumerate(df.column_names) if k in self.subset]
800
+ keys_sorted = all(df.columns[i].is_sorted for i in indices)
801
+ if keys_sorted:
802
+ table = plc.stream_compaction.unique(
803
+ df.table,
804
+ indices,
805
+ self.keep,
806
+ plc.types.NullEquality.EQUAL,
807
+ )
808
+ else:
809
+ distinct = (
810
+ plc.stream_compaction.stable_distinct
811
+ if self.stable
812
+ else plc.stream_compaction.distinct
813
+ )
814
+ table = distinct(
815
+ df.table,
816
+ indices,
817
+ self.keep,
818
+ plc.types.NullEquality.EQUAL,
819
+ plc.types.NanEquality.ALL_EQUAL,
820
+ )
821
+ result = DataFrame(
822
+ [
823
+ NamedColumn(c, old.name).sorted_like(old)
824
+ for c, old in zip(table.columns(), df.columns)
825
+ ]
826
+ )
827
+ if keys_sorted or self.stable:
828
+ result = result.sorted_like(df)
829
+ return result.slice(self.zlice)
830
+
831
+
832
+ @dataclasses.dataclass
833
+ class Sort(IR):
834
+ """Sort a dataframe."""
835
+
836
+ df: IR
837
+ """Input."""
838
+ by: list[expr.NamedExpr]
839
+ """List of expressions to produce sort keys."""
840
+ do_sort: Callable[..., plc.Table]
841
+ """pylibcudf sorting function."""
842
+ zlice: tuple[int, int] | None
843
+ """Optional slice to apply after sorting."""
844
+ order: list[plc.types.Order]
845
+ """Order keys should be sorted in."""
846
+ null_order: list[plc.types.NullOrder]
847
+ """Where nulls sort to."""
848
+
849
+ def __init__(
850
+ self,
851
+ schema: Schema,
852
+ df: IR,
853
+ by: list[expr.NamedExpr],
854
+ options: Any,
855
+ zlice: tuple[int, int] | None,
856
+ ) -> None:
857
+ self.schema = schema
858
+ self.df = df
859
+ self.by = by
860
+ self.zlice = zlice
861
+ stable, nulls_last, descending = options
862
+ self.order, self.null_order = sorting.sort_order(
863
+ descending, nulls_last=nulls_last, num_keys=len(by)
864
+ )
865
+ self.do_sort = (
866
+ plc.sorting.stable_sort_by_key if stable else plc.sorting.sort_by_key
867
+ )
868
+
869
+ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
870
+ """Evaluate and return a dataframe."""
871
+ df = self.df.evaluate(cache=cache)
872
+ sort_keys = broadcast(
873
+ *(k.evaluate(df) for k in self.by), target_length=df.num_rows
874
+ )
875
+ names = {c.name: i for i, c in enumerate(df.columns)}
876
+ # TODO: More robust identification here.
877
+ keys_in_result = [
878
+ i
879
+ for k in sort_keys
880
+ if (i := names.get(k.name)) is not None and k.obj is df.columns[i].obj
881
+ ]
882
+ table = self.do_sort(
883
+ df.table,
884
+ plc.Table([k.obj for k in sort_keys]),
885
+ self.order,
886
+ self.null_order,
887
+ )
888
+ columns = [
889
+ NamedColumn(c, old.name) for c, old in zip(table.columns(), df.columns)
890
+ ]
891
+ # If a sort key is in the result table, set the sortedness property
892
+ for k, i in enumerate(keys_in_result):
893
+ columns[i] = columns[i].set_sorted(
894
+ is_sorted=plc.types.Sorted.YES,
895
+ order=self.order[k],
896
+ null_order=self.null_order[k],
897
+ )
898
+ return DataFrame(columns).slice(self.zlice)
899
+
900
+
901
+ @dataclasses.dataclass
902
+ class Slice(IR):
903
+ """Slice a dataframe."""
904
+
905
+ df: IR
906
+ """Input."""
907
+ offset: int
908
+ """Start of the slice."""
909
+ length: int
910
+ """Length of the slice."""
911
+
912
+ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
913
+ """Evaluate and return a dataframe."""
914
+ df = self.df.evaluate(cache=cache)
915
+ return df.slice((self.offset, self.length))
916
+
917
+
918
+ @dataclasses.dataclass
919
+ class Filter(IR):
920
+ """Filter a dataframe with a boolean mask."""
921
+
922
+ df: IR
923
+ """Input."""
924
+ mask: expr.NamedExpr
925
+ """Expression evaluating to a mask."""
926
+
927
+ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
928
+ """Evaluate and return a dataframe."""
929
+ df = self.df.evaluate(cache=cache)
930
+ (mask,) = broadcast(self.mask.evaluate(df), target_length=df.num_rows)
931
+ return df.filter(mask)
932
+
933
+
934
+ @dataclasses.dataclass
935
+ class Projection(IR):
936
+ """Select a subset of columns from a dataframe."""
937
+
938
+ df: IR
939
+ """Input."""
940
+
941
+ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
942
+ """Evaluate and return a dataframe."""
943
+ df = self.df.evaluate(cache=cache)
944
+ # This can reorder things.
945
+ columns = broadcast(
946
+ *df.select(list(self.schema.keys())).columns, target_length=df.num_rows
947
+ )
948
+ return DataFrame(columns)
949
+
950
+
951
+ @dataclasses.dataclass
952
+ class MapFunction(IR):
953
+ """Apply some function to a dataframe."""
954
+
955
+ df: IR
956
+ """Input."""
957
+ name: str
958
+ """Function name."""
959
+ options: Any
960
+ """Arbitrary options, interpreted per function."""
961
+
962
+ _NAMES: ClassVar[frozenset[str]] = frozenset(
963
+ [
964
+ "rechunk",
965
+ # libcudf merge is not stable wrt order of inputs, since
966
+ # it uses a priority queue to manage the tables it produces.
967
+ # See: https://github.com/rapidsai/cudf/issues/16010
968
+ # "merge_sorted",
969
+ "rename",
970
+ "explode",
971
+ ]
972
+ )
973
+
974
+ def __post_init__(self) -> None:
975
+ """Validate preconditions."""
976
+ if self.name not in MapFunction._NAMES:
977
+ raise NotImplementedError(f"Unhandled map function {self.name}")
978
+ if self.name == "explode":
979
+ (to_explode,) = self.options
980
+ if len(to_explode) > 1:
981
+ # TODO: straightforward, but need to error check:
982
+ # polars requires that all to-explode columns have the
983
+ # same sub-shapes
984
+ raise NotImplementedError("Explode with more than one column")
985
+ elif self.name == "rename":
986
+ old, new, _ = self.options
987
+ # TODO: perhaps polars should validate renaming in the IR?
988
+ if len(new) != len(set(new)) or (
989
+ set(new) & (set(self.df.schema.keys() - set(old)))
990
+ ):
991
+ raise NotImplementedError("Duplicate new names in rename.")
992
+
993
+ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
994
+ """Evaluate and return a dataframe."""
995
+ if self.name == "rechunk":
996
+ # No-op in our data model
997
+ # Don't think this appears in a plan tree from python
998
+ return self.df.evaluate(cache=cache) # pragma: no cover
999
+ elif self.name == "rename":
1000
+ df = self.df.evaluate(cache=cache)
1001
+ # final tag is "swapping" which is useful for the
1002
+ # optimiser (it blocks some pushdown operations)
1003
+ old, new, _ = self.options
1004
+ return df.rename_columns(dict(zip(old, new)))
1005
+ elif self.name == "explode":
1006
+ df = self.df.evaluate(cache=cache)
1007
+ ((to_explode,),) = self.options
1008
+ index = df.column_names.index(to_explode)
1009
+ subset = df.column_names_set - {to_explode}
1010
+ return DataFrame.from_table(
1011
+ plc.lists.explode_outer(df.table, index), df.column_names
1012
+ ).sorted_like(df, subset=subset)
1013
+ else:
1014
+ raise AssertionError("Should never be reached") # pragma: no cover
1015
+
1016
+
1017
+ @dataclasses.dataclass
1018
+ class Union(IR):
1019
+ """Concatenate dataframes vertically."""
1020
+
1021
+ dfs: list[IR]
1022
+ """List of inputs."""
1023
+ zlice: tuple[int, int] | None
1024
+ """Optional slice to apply after concatenation."""
1025
+
1026
+ def __post_init__(self) -> None:
1027
+ """Validate preconditions."""
1028
+ schema = self.dfs[0].schema
1029
+ if not all(s.schema == schema for s in self.dfs[1:]):
1030
+ raise NotImplementedError("Schema mismatch")
1031
+
1032
+ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
1033
+ """Evaluate and return a dataframe."""
1034
+ # TODO: only evaluate what we need if we have a slice
1035
+ dfs = [df.evaluate(cache=cache) for df in self.dfs]
1036
+ return DataFrame.from_table(
1037
+ plc.concatenate.concatenate([df.table for df in dfs]), dfs[0].column_names
1038
+ ).slice(self.zlice)
1039
+
1040
+
1041
+ @dataclasses.dataclass
1042
+ class HConcat(IR):
1043
+ """Concatenate dataframes horizontally."""
1044
+
1045
+ dfs: list[IR]
1046
+ """List of inputs."""
1047
+
1048
+ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
1049
+ """Evaluate and return a dataframe."""
1050
+ dfs = [df.evaluate(cache=cache) for df in self.dfs]
1051
+ return DataFrame(
1052
+ list(itertools.chain.from_iterable(df.columns for df in dfs)),
1053
+ )
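
For orientation, the following is an illustrative sketch, not part of the wheel, of the length rule that the `broadcast` helper documents above: the set of input column lengths must be a subset of {1, n} for a single n, and an explicit `target_length` must agree with n. The helper name `broadcast_length` is hypothetical; the logic is restated with plain integers so it runs without a GPU.

def broadcast_length(lengths: set[int], target_length: int | None = None) -> int:
    # Hypothetical helper: return the common row count n for a set of column lengths.
    if lengths == {1}:
        # All columns are length-1 scalars; broadcast to the requested length, if any.
        return 1 if target_length is None else target_length
    try:
        # Exactly one non-unit length is allowed.
        (n,) = lengths - {1}
    except ValueError as e:
        raise RuntimeError("Mismatching column lengths") from e
    if target_length is not None and n != target_length:
        raise RuntimeError(f"Cannot broadcast columns of length {n} to {target_length}")
    return n

assert broadcast_length({1, 5}) == 5
assert broadcast_length({1}, target_length=3) == 3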
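Similarly, a minimal sketch (hypothetical stand-in classes, not part of the package) of how the `Cache` node's `evaluate` achieves plan-level CSE: the shared `cache` mapping passed through evaluation guarantees that a subplan referenced from several places is evaluated only once.

class CountingPlan:
    # Hypothetical stand-in for an IR node; counts how often it is evaluated.
    def __init__(self) -> None:
        self.evaluations = 0

    def evaluate(self, *, cache: dict[int, str]) -> str:
        self.evaluations += 1
        return "dataframe"


class CacheNode:
    # Mirrors Cache.evaluate: look up by key, else evaluate once and memoise.
    def __init__(self, key: int, value: CountingPlan) -> None:
        self.key, self.value = key, value

    def evaluate(self, *, cache: dict[int, str]) -> str:
        try:
            return cache[self.key]
        except KeyError:
            return cache.setdefault(self.key, self.value.evaluate(cache=cache))


plan = CountingPlan()
node = CacheNode(0, plan)
shared_cache: dict[int, str] = {}
assert node.evaluate(cache=shared_cache) == "dataframe"
assert node.evaluate(cache=shared_cache) == "dataframe"
assert plan.evaluations == 1  # the underlying plan ran only once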