cudf-polars-cu12 24.8.0a281__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cudf_polars/VERSION +1 -0
- cudf_polars/__init__.py +22 -0
- cudf_polars/_version.py +21 -0
- cudf_polars/callback.py +66 -0
- cudf_polars/containers/__init__.py +11 -0
- cudf_polars/containers/column.py +189 -0
- cudf_polars/containers/dataframe.py +226 -0
- cudf_polars/dsl/__init__.py +8 -0
- cudf_polars/dsl/expr.py +1422 -0
- cudf_polars/dsl/ir.py +1053 -0
- cudf_polars/dsl/translate.py +535 -0
- cudf_polars/py.typed +0 -0
- cudf_polars/testing/__init__.py +8 -0
- cudf_polars/testing/asserts.py +118 -0
- cudf_polars/typing/__init__.py +106 -0
- cudf_polars/utils/__init__.py +8 -0
- cudf_polars/utils/dtypes.py +159 -0
- cudf_polars/utils/sorting.py +53 -0
- cudf_polars_cu12-24.8.0a281.dist-info/LICENSE +201 -0
- cudf_polars_cu12-24.8.0a281.dist-info/METADATA +126 -0
- cudf_polars_cu12-24.8.0a281.dist-info/RECORD +23 -0
- cudf_polars_cu12-24.8.0a281.dist-info/WHEEL +5 -0
- cudf_polars_cu12-24.8.0a281.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,535 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""Translate polars IR representation to ours."""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import json
|
|
9
|
+
from contextlib import AbstractContextManager, nullcontext
|
|
10
|
+
from functools import singledispatch
|
|
11
|
+
from typing import Any
|
|
12
|
+
|
|
13
|
+
import pyarrow as pa
|
|
14
|
+
from typing_extensions import assert_never
|
|
15
|
+
|
|
16
|
+
import polars as pl
|
|
17
|
+
import polars.polars as plrs
|
|
18
|
+
from polars.polars import _expr_nodes as pl_expr, _ir_nodes as pl_ir
|
|
19
|
+
|
|
20
|
+
import cudf._lib.pylibcudf as plc
|
|
21
|
+
|
|
22
|
+
from cudf_polars.dsl import expr, ir
|
|
23
|
+
from cudf_polars.typing import NodeTraverser
|
|
24
|
+
from cudf_polars.utils import dtypes
|
|
25
|
+
|
|
26
|
+
__all__ = ["translate_ir", "translate_named_expr"]
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class set_node(AbstractContextManager[None]):
    """
    Run a block with current node set in the visitor.

    Parameters
    ----------
    visitor
        The internal Rust visitor object
    n
        The node to set as the current root.

    Notes
    -----
    This is useful for translating expressions with a given node
    active, restoring the node when the block exits.

    The node that was active on entry is remembered separately from
    ``n``, so ``n`` is never overwritten and the same instance can be
    entered more than once (sequentially or nested); every entry
    activates ``n`` and every exit restores the node seen by the
    matching entry.
    """

    __slots__ = ("_saved", "n", "visitor")
    visitor: NodeTraverser
    n: int
    # Stack of the nodes that were active at each (possibly nested) entry.
    _saved: list[int]

    def __init__(self, visitor: NodeTraverser, n: int) -> None:
        self.visitor = visitor
        self.n = n
        self._saved = []

    def __enter__(self) -> None:
        # Remember what was active before switching so __exit__ can undo it.
        self._saved.append(self.visitor.get_node())
        self.visitor.set_node(self.n)

    def __exit__(self, *args: Any) -> None:
        # Restore the node recorded by the matching __enter__.
        self.visitor.set_node(self._saved.pop())
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
# Shared do-nothing context manager; translate_ir uses it in place of
# set_node when it is called without an explicit node.
noop_context: nullcontext[None] = nullcontext()
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@singledispatch
|
|
67
|
+
def _translate_ir(
|
|
68
|
+
node: Any, visitor: NodeTraverser, schema: dict[str, plc.DataType]
|
|
69
|
+
) -> ir.IR:
|
|
70
|
+
raise NotImplementedError(
|
|
71
|
+
f"Translation for {type(node).__name__}"
|
|
72
|
+
) # pragma: no cover
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@_translate_ir.register
def _(
    node: pl_ir.PythonScan, visitor: NodeTraverser, schema: dict[str, plc.DataType]
) -> ir.IR:
    """Translate a ``PythonScan`` node; its predicate is translated only if present."""
    options = node.options
    if node.predicate is None:
        predicate = None
    else:
        predicate = translate_named_expr(visitor, n=node.predicate)
    return ir.PythonScan(schema, options, predicate)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@_translate_ir.register
def _(
    node: pl_ir.Scan, visitor: NodeTraverser, schema: dict[str, plc.DataType]
) -> ir.IR:
    """
    Translate a ``Scan`` node.

    ``node.scan_type`` is unpacked as a type tag followed by JSON-encoded
    option strings: exactly one (reader options) for ``"ndjson"``, and
    exactly two (reader options, cloud options) for every other type.
    """
    typ, *encoded = node.scan_type
    if typ != "ndjson":
        reader_options, cloud_options = map(json.loads, encoded)
    else:
        (reader_options,) = map(json.loads, encoded)
        cloud_options = None
    paths = node.paths
    file_options = node.file_options
    if node.predicate is None:
        predicate = None
    else:
        predicate = translate_named_expr(visitor, n=node.predicate)
    return ir.Scan(
        schema,
        typ,
        reader_options,
        cloud_options,
        paths,
        file_options,
        predicate,
    )
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
@_translate_ir.register
def _(
    node: pl_ir.Cache, visitor: NodeTraverser, schema: dict[str, plc.DataType]
) -> ir.IR:
    """Translate a ``Cache`` node together with its input."""
    cache_id = node.id_
    child = translate_ir(visitor, n=node.input)
    return ir.Cache(schema, cache_id, child)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
@_translate_ir.register
def _(
    node: pl_ir.DataFrameScan, visitor: NodeTraverser, schema: dict[str, plc.DataType]
) -> ir.IR:
    """Translate a ``DataFrameScan`` node; its selection is translated only if present."""
    frame = node.df
    projection = node.projection
    if node.selection is None:
        selection = None
    else:
        selection = translate_named_expr(visitor, n=node.selection)
    return ir.DataFrameScan(schema, frame, projection, selection)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
@_translate_ir.register
def _(
    node: pl_ir.Select, visitor: NodeTraverser, schema: dict[str, plc.DataType]
) -> ir.IR:
    """Translate a ``Select`` node, translating its expressions with the input node active."""
    with set_node(visitor, node.input):
        child = translate_ir(visitor, n=None)
        columns = []
        for named in node.expr:
            columns.append(translate_named_expr(visitor, n=named))
    return ir.Select(schema, child, columns, node.should_broadcast)
|
|
140
|
+
|
|
141
|
+
|
|
142
|
+
@_translate_ir.register
def _(
    node: pl_ir.GroupBy, visitor: NodeTraverser, schema: dict[str, plc.DataType]
) -> ir.IR:
    """Translate a ``GroupBy`` node, translating aggregations and keys with the input node active."""
    with set_node(visitor, node.input):
        child = translate_ir(visitor, n=None)
        # Aggregations are translated before keys, as in the node's field order here.
        agg_exprs = [translate_named_expr(visitor, n=agg) for agg in node.aggs]
        key_exprs = [translate_named_expr(visitor, n=key) for key in node.keys]
    maintain_order = node.maintain_order
    options = node.options
    return ir.GroupBy(schema, child, agg_exprs, key_exprs, maintain_order, options)
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
@_translate_ir.register
def _(
    node: pl_ir.Join, visitor: NodeTraverser, schema: dict[str, plc.DataType]
) -> ir.IR:
    """Translate a ``Join`` node, translating each side's keys with that side's input active."""
    # The dtypes of the join keys depend on the schema of the side they
    # belong to, so each key list is translated while its own input is
    # the visitor's current node.
    with set_node(visitor, node.input_left):
        left = translate_ir(visitor, n=None)
        left_keys = [translate_named_expr(visitor, n=key) for key in node.left_on]
    with set_node(visitor, node.input_right):
        right = translate_ir(visitor, n=None)
        right_keys = [translate_named_expr(visitor, n=key) for key in node.right_on]
    return ir.Join(schema, left, right, left_keys, right_keys, node.options)
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
@_translate_ir.register
def _(
    node: pl_ir.HStack, visitor: NodeTraverser, schema: dict[str, plc.DataType]
) -> ir.IR:
    """Translate an ``HStack`` node, translating its expressions with the input node active."""
    with set_node(visitor, node.input):
        child = translate_ir(visitor, n=None)
        columns = []
        for named in node.exprs:
            columns.append(translate_named_expr(visitor, n=named))
    return ir.HStack(schema, child, columns, node.should_broadcast)
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
@_translate_ir.register
def _(
    node: pl_ir.Reduce, visitor: NodeTraverser, schema: dict[str, plc.DataType]
) -> ir.IR:  # pragma: no cover; polars doesn't emit this node yet
    """Translate a ``Reduce`` node, translating its expressions with the input node active."""
    with set_node(visitor, node.input):
        child = translate_ir(visitor, n=None)
        reductions = []
        for named in node.expr:
            reductions.append(translate_named_expr(visitor, n=named))
    return ir.Reduce(schema, child, reductions)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
@_translate_ir.register
def _(
    node: pl_ir.Distinct, visitor: NodeTraverser, schema: dict[str, plc.DataType]
) -> ir.IR:
    """Translate a ``Distinct`` node together with its input."""
    child = translate_ir(visitor, n=node.input)
    options = node.options
    return ir.Distinct(schema, child, options)
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
@_translate_ir.register
def _(
    node: pl_ir.Sort, visitor: NodeTraverser, schema: dict[str, plc.DataType]
) -> ir.IR:
    """Translate a ``Sort`` node, translating the sort keys with the input node active."""
    with set_node(visitor, node.input):
        child = translate_ir(visitor, n=None)
        sort_keys = []
        for named in node.by_column:
            sort_keys.append(translate_named_expr(visitor, n=named))
    return ir.Sort(schema, child, sort_keys, node.sort_options, node.slice)
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
@_translate_ir.register
def _(
    node: pl_ir.Slice, visitor: NodeTraverser, schema: dict[str, plc.DataType]
) -> ir.IR:
    """Translate a ``Slice`` node, forwarding its offset and length."""
    child = translate_ir(visitor, n=node.input)
    offset = node.offset
    length = node.len
    return ir.Slice(schema, child, offset, length)
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
@_translate_ir.register
def _(
    node: pl_ir.Filter, visitor: NodeTraverser, schema: dict[str, plc.DataType]
) -> ir.IR:
    """Translate a ``Filter`` node, translating the predicate with the input node active."""
    with set_node(visitor, node.input):
        child = translate_ir(visitor, n=None)
        predicate = translate_named_expr(visitor, n=node.predicate)
    return ir.Filter(schema, child, predicate)
|
|
232
|
+
|
|
233
|
+
|
|
234
|
+
@_translate_ir.register
def _(
    node: pl_ir.SimpleProjection,
    visitor: NodeTraverser,
    schema: dict[str, plc.DataType],
) -> ir.IR:
    """Translate a ``SimpleProjection`` node into a ``Projection`` over its input."""
    child = translate_ir(visitor, n=node.input)
    return ir.Projection(schema, child)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
@_translate_ir.register
def _(
    node: pl_ir.MapFunction, visitor: NodeTraverser, schema: dict[str, plc.DataType]
) -> ir.IR:
    """
    Translate a ``MapFunction`` node.

    ``node.function`` is unpacked as the function name followed by its
    options, which are passed on as a list.
    """
    name, *options = node.function
    # TODO: merge_sorted breaks this pattern
    child = translate_ir(visitor, n=node.input)
    return ir.MapFunction(schema, child, name, options)
|
|
255
|
+
|
|
256
|
+
|
|
257
|
+
@_translate_ir.register
def _(
    node: pl_ir.Union, visitor: NodeTraverser, schema: dict[str, plc.DataType]
) -> ir.IR:
    """Translate a ``Union`` node and each of its inputs, in order."""
    children = []
    for child in node.inputs:
        children.append(translate_ir(visitor, n=child))
    return ir.Union(schema, children, node.options)
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
@_translate_ir.register
def _(
    node: pl_ir.HConcat, visitor: NodeTraverser, schema: dict[str, plc.DataType]
) -> ir.IR:
    """Translate an ``HConcat`` node and each of its inputs, in order."""
    children = []
    for child in node.inputs:
        children.append(translate_ir(visitor, n=child))
    return ir.HConcat(schema, children)
|
|
271
|
+
|
|
272
|
+
|
|
273
|
+
def translate_ir(visitor: NodeTraverser, *, n: int | None = None) -> ir.IR:
    """
    Convert a polars-internal IR node into the cudf-polars IR.

    Parameters
    ----------
    visitor
        Polars NodeTraverser object
    n
        Node to translate. When omitted (``None``) the node that is
        currently active in the visitor is translated.

    Returns
    -------
    Translated IR object

    Raises
    ------
    NotImplementedError
        If we can't translate the nodes due to unsupported functionality.
    """
    ctx: AbstractContextManager[None]
    if n is None:
        # Nothing to switch to: stay on the visitor's current node.
        ctx = noop_context
    else:
        ctx = set_node(visitor, n)
    with ctx:
        current = visitor.view_current_node()
        # Convert the polars schema of the active node to pylibcudf dtypes.
        schema = {}
        for name, polars_dtype in visitor.get_schema().items():
            schema[name] = dtypes.from_polars(polars_dtype)
        return _translate_ir(current, visitor, schema)
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
def translate_named_expr(
    visitor: NodeTraverser, *, n: pl_expr.PyExprIR
) -> expr.NamedExpr:
    """
    Convert a polars-internal named expression into a ``NamedExpr``.

    Parameters
    ----------
    visitor
        Polars NodeTraverser object
    n
        Node to translate, a named expression node.

    Returns
    -------
    Translated IR object.

    Notes
    -----
    The dtype of the wrapped expression is looked up through the
    visitor (``get_dtype``). For that lookup to work, the caller must
    translate the expression while the node it refers to is the
    visitor's "active" node (see :class:`set_node`).

    Raises
    ------
    NotImplementedError
        If any translation fails due to unsupported functionality.
    """
    output_name = n.output_name
    value = translate_expr(visitor, n=n.node)
    return expr.NamedExpr(output_name, value)
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
@singledispatch
|
|
336
|
+
def _translate_expr(
|
|
337
|
+
node: Any, visitor: NodeTraverser, dtype: plc.DataType
|
|
338
|
+
) -> expr.Expr:
|
|
339
|
+
raise NotImplementedError(
|
|
340
|
+
f"Translation for {type(node).__name__}"
|
|
341
|
+
) # pragma: no cover
|
|
342
|
+
|
|
343
|
+
|
|
344
|
+
@_translate_expr.register
def _(node: pl_expr.Function, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr:
    """
    Translate a ``Function`` expression node.

    ``node.function_data`` is unpacked as the function identifier
    followed by its options. String, boolean and temporal function
    identifiers map to the matching expression class and a plain ``str``
    identifier maps to ``UnaryFunction``. ``IsBetween`` is expanded into
    the logical AND of two comparisons.

    Raises
    ------
    NotImplementedError
        If the function identifier is of none of the handled kinds.
    """
    name, *rest = node.function_data
    options = tuple(rest)

    def children():
        # Lazy: inputs are only translated by the branch that consumes them.
        return (translate_expr(visitor, n=child) for child in node.input)

    if isinstance(name, pl_expr.StringFunction):
        return expr.StringFunction(dtype, name, options, *children())
    if isinstance(name, pl_expr.BooleanFunction):
        if name == pl_expr.BooleanFunction.IsBetween:
            column, lo, hi = children()
            (closed,) = options
            # Pair of comparison operators selected by the interval closure.
            lop, rop = expr.BooleanFunction._BETWEEN_OPS[closed]
            lower_bound = expr.BinOp(dtype, lop, column, lo)
            upper_bound = expr.BinOp(dtype, rop, column, hi)
            return expr.BinOp(
                dtype,
                plc.binaryop.BinaryOperator.LOGICAL_AND,
                lower_bound,
                upper_bound,
            )
        return expr.BooleanFunction(dtype, name, options, *children())
    if isinstance(name, pl_expr.TemporalFunction):
        return expr.TemporalFunction(dtype, name, options, *children())
    if isinstance(name, str):
        return expr.UnaryFunction(dtype, name, options, *children())
    raise NotImplementedError(
        f"No handler for Expr function node with {name=}"
    )  # pragma: no cover; polars raises on the rust side for now
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
@_translate_expr.register
def _(node: pl_expr.Window, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr:
    """
    Translate a ``Window`` expression node.

    Rolling-group options produce a ``RollingWindow``; a window mapping
    produces a ``GroupedRollingWindow`` that also carries the translated
    ``partition_by`` expressions.
    """
    # TODO: raise in groupby?
    options = node.options
    if isinstance(options, pl_expr.RollingGroupOptions):
        # pl.col("a").rolling(...)
        function = translate_expr(visitor, n=node.function)
        return expr.RollingWindow(dtype, options, function)
    if isinstance(options, pl_expr.WindowMapping):
        # pl.col("a").over(...)
        function = translate_expr(visitor, n=node.function)
        return expr.GroupedRollingWindow(
            dtype,
            options,
            function,
            *(translate_expr(visitor, n=part) for part in node.partition_by),
        )
    assert_never(node.options)
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
@_translate_expr.register
def _(node: pl_expr.Literal, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr:
    """
    Translate a ``Literal`` expression node.

    A ``PySeries`` value becomes a ``LiteralColumn`` wrapping a polars
    Series; any other value becomes a ``Literal`` holding a pyarrow
    scalar of the arrow type corresponding to ``dtype``.
    """
    if not isinstance(node.value, plrs.PySeries):
        arrow_type = plc.interop.to_arrow(dtype)
        scalar = pa.scalar(node.value, type=arrow_type)
        return expr.Literal(dtype, scalar)
    series = pl.Series._from_pyseries(node.value)
    return expr.LiteralColumn(dtype, series)
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
@_translate_expr.register
def _(node: pl_expr.Sort, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr:
    """Translate a ``Sort`` expression node and the expression it sorts."""
    # TODO: raise in groupby
    options = node.options
    child = translate_expr(visitor, n=node.expr)
    return expr.Sort(dtype, options, child)
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
@_translate_expr.register
def _(node: pl_expr.SortBy, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr:
    """Translate a ``SortBy`` expression node: the sorted expression followed by the ``by`` expressions."""
    sort_options = node.sort_options
    target = translate_expr(visitor, n=node.expr)
    return expr.SortBy(
        dtype,
        sort_options,
        target,
        *(translate_expr(visitor, n=key) for key in node.by),
    )
|
|
432
|
+
|
|
433
|
+
|
|
434
|
+
@_translate_expr.register
def _(node: pl_expr.Gather, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr:
    """Translate a ``Gather`` expression node: the values first, then the indices."""
    values = translate_expr(visitor, n=node.expr)
    indices = translate_expr(visitor, n=node.idx)
    return expr.Gather(dtype, values, indices)
|
|
441
|
+
|
|
442
|
+
|
|
443
|
+
@_translate_expr.register
def _(node: pl_expr.Filter, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr:
    """Translate a ``Filter`` expression node: the input first, then the ``by`` expression."""
    values = translate_expr(visitor, n=node.input)
    mask = translate_expr(visitor, n=node.by)
    return expr.Filter(dtype, values, mask)
|
|
450
|
+
|
|
451
|
+
|
|
452
|
+
@_translate_expr.register
def _(node: pl_expr.Cast, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr:
    """
    Translate a ``Cast`` expression node.

    A cast of a literal is folded into a new ``Literal`` of the target
    dtype, and a cast of a cast is collapsed so only the outer cast
    remains.
    """
    child = translate_expr(visitor, n=node.expr)
    if isinstance(child, expr.Literal):
        # Push casts into literals so we can handle Cast(Literal(Null))
        arrow_type = plc.interop.to_arrow(dtype)
        return expr.Literal(dtype, child.value.cast(arrow_type))
    if isinstance(child, expr.Cast):
        # Translation of Len/Count-agg put in a cast, remove double
        # casts if we have one.
        (child,) = child.children
    return expr.Cast(dtype, child)
|
|
463
|
+
|
|
464
|
+
|
|
465
|
+
@_translate_expr.register
def _(node: pl_expr.Column, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr:
    """Translate a ``Column`` expression node into a ``Col`` referencing the same name."""
    column_name = node.name
    return expr.Col(dtype, column_name)
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
@_translate_expr.register
def _(node: pl_expr.Agg, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr:
    """
    Translate an ``Agg`` expression node.

    A ``"count"`` aggregation whose dtype is not INT32 is wrapped in a
    ``Cast`` to the aggregation's dtype; every other aggregation is
    returned as is.
    """
    agg = expr.Agg(
        dtype,
        node.name,
        node.options,
        *(translate_expr(visitor, n=arg) for arg in node.arguments),
    )
    needs_cast = agg.name == "count" and agg.dtype.id() != plc.TypeId.INT32
    if not needs_cast:
        return agg
    return expr.Cast(agg.dtype, agg)
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
@_translate_expr.register
def _(node: pl_expr.Ternary, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr:
    """Translate a ``Ternary`` expression node: predicate, then truthy, then falsy branch."""
    predicate = translate_expr(visitor, n=node.predicate)
    when_true = translate_expr(visitor, n=node.truthy)
    when_false = translate_expr(visitor, n=node.falsy)
    return expr.Ternary(dtype, predicate, when_true, when_false)
|
|
491
|
+
|
|
492
|
+
|
|
493
|
+
@_translate_expr.register
def _(
    node: pl_expr.BinaryExpr, visitor: NodeTraverser, dtype: plc.DataType
) -> expr.Expr:
    """Translate a ``BinaryExpr`` node, mapping its operator through ``BinOp._MAPPING``."""
    operator = expr.BinOp._MAPPING[node.op]
    lhs = translate_expr(visitor, n=node.left)
    rhs = translate_expr(visitor, n=node.right)
    return expr.BinOp(dtype, operator, lhs, rhs)
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
@_translate_expr.register
def _(node: pl_expr.Len, visitor: NodeTraverser, dtype: plc.DataType) -> expr.Expr:
    """Translate a ``Len`` expression node, wrapping it in a ``Cast`` unless ``dtype`` is INT32."""
    length = expr.Len(dtype)
    if dtype.id() == plc.TypeId.INT32:
        return length  # pragma: no cover; never reached since polars len has uint32 dtype
    return expr.Cast(dtype, length)
|
|
511
|
+
|
|
512
|
+
|
|
513
|
+
def translate_expr(visitor: NodeTraverser, *, n: int) -> expr.Expr:
    """
    Convert a polars-internal expression node into a cudf-polars expression.

    Parameters
    ----------
    visitor
        Polars NodeTraverser object
    n
        Node to translate, an integer referencing a polars internal node.

    Returns
    -------
    Translated IR object.

    Raises
    ------
    NotImplementedError
        If any translation fails due to unsupported functionality.
    """
    expression = visitor.view_expression(n)
    # The dtype is queried from the visitor and converted with dtypes.from_polars.
    polars_dtype = visitor.get_dtype(n)
    dtype = dtypes.from_polars(polars_dtype)
    return _translate_expr(expression, visitor, dtype)
|
cudf_polars/py.typed
ADDED
|
File without changes
|
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
|
|
2
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
3
|
+
|
|
4
|
+
"""Device-aware assertions."""
|
|
5
|
+
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from functools import partial
|
|
9
|
+
from typing import TYPE_CHECKING
|
|
10
|
+
|
|
11
|
+
from polars.testing.asserts import assert_frame_equal
|
|
12
|
+
|
|
13
|
+
from cudf_polars.callback import execute_with_cudf
|
|
14
|
+
from cudf_polars.dsl.translate import translate_ir
|
|
15
|
+
|
|
16
|
+
if TYPE_CHECKING:
|
|
17
|
+
from collections.abc import Mapping
|
|
18
|
+
|
|
19
|
+
import polars as pl
|
|
20
|
+
|
|
21
|
+
from cudf_polars.typing import OptimizationArgs
|
|
22
|
+
|
|
23
|
+
__all__: list[str] = ["assert_gpu_result_equal", "assert_ir_translation_raises"]
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def assert_gpu_result_equal(
    lazydf: pl.LazyFrame,
    *,
    collect_kwargs: Mapping[OptimizationArgs, bool] | None = None,
    check_row_order: bool = True,
    check_column_order: bool = True,
    check_dtypes: bool = True,
    check_exact: bool = True,
    rtol: float = 1e-05,
    atol: float = 1e-08,
    categorical_as_str: bool = False,
) -> None:
    """
    Check that collecting a lazyframe on GPU matches the default collection.

    The frame is collected twice with the same keyword arguments, once
    normally and once with ``execute_with_cudf`` installed as the
    ``post_opt_callback``, and the two results are compared.

    Parameters
    ----------
    lazydf
        frame to collect.
    collect_kwargs
        Keyword arguments to pass to collect. Useful for controlling
        optimization settings.
    check_row_order
        Expect rows to be in same order
    check_column_order
        Expect columns to be in same order
    check_dtypes
        Expect dtypes to match
    check_exact
        Require exact equality for floats, if `False` compare using
        rtol and atol.
    rtol
        Relative tolerance for float comparisons
    atol
        Absolute tolerance for float comparisons
    categorical_as_str
        Decat categoricals to strings before comparing

    Raises
    ------
    AssertionError
        If the GPU and CPU collection do not match.
    NotImplementedError
        If GPU collection failed in some way.
    """
    if collect_kwargs is None:
        collect_kwargs = {}
    expect = lazydf.collect(**collect_kwargs)
    # raise_on_fail=True is passed so a GPU execution failure surfaces as an error.
    gpu_callback = partial(execute_with_cudf, raise_on_fail=True)
    got = lazydf.collect(**collect_kwargs, post_opt_callback=gpu_callback)
    comparison_options = {
        "check_row_order": check_row_order,
        "check_column_order": check_column_order,
        "check_dtypes": check_dtypes,
        "check_exact": check_exact,
        "rtol": rtol,
        "atol": atol,
        "categorical_as_str": categorical_as_str,
    }
    assert_frame_equal(expect, got, **comparison_options)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def assert_ir_translation_raises(q: pl.LazyFrame, *exceptions: type[Exception]) -> None:
    """
    Check that translating a query fails with one of the given exceptions.

    Parameters
    ----------
    q
        Query to translate.
    exceptions
        Exceptions that one expects might be raised.

    Returns
    -------
    None
        If translation successfully raised the specified exceptions.

    Raises
    ------
    AssertionError
        If the specified exceptions were not raised, either because
        translation succeeded or because it raised a different
        exception (which is then chained as the cause).
    """
    try:
        translate_ir(q._ldf.visit())
    except exceptions:
        # One of the expected exceptions: the assertion holds.
        return
    except Exception as e:
        raise AssertionError(f"Translation DID NOT RAISE {exceptions}") from e
    # Reaching this point means translation completed without raising.
    raise AssertionError(f"Translation DID NOT RAISE {exceptions}")
|