cudf-polars-cu13 25.10.0-py3-none-any.whl → 25.12.0-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the package contents as they appear in that registry.
- cudf_polars/GIT_COMMIT +1 -1
- cudf_polars/VERSION +1 -1
- cudf_polars/callback.py +32 -8
- cudf_polars/containers/column.py +94 -59
- cudf_polars/containers/dataframe.py +123 -34
- cudf_polars/containers/datatype.py +134 -13
- cudf_polars/dsl/expr.py +0 -2
- cudf_polars/dsl/expressions/aggregation.py +80 -28
- cudf_polars/dsl/expressions/binaryop.py +34 -14
- cudf_polars/dsl/expressions/boolean.py +110 -37
- cudf_polars/dsl/expressions/datetime.py +59 -30
- cudf_polars/dsl/expressions/literal.py +11 -5
- cudf_polars/dsl/expressions/rolling.py +460 -119
- cudf_polars/dsl/expressions/selection.py +9 -8
- cudf_polars/dsl/expressions/slicing.py +1 -1
- cudf_polars/dsl/expressions/string.py +235 -102
- cudf_polars/dsl/expressions/struct.py +19 -7
- cudf_polars/dsl/expressions/ternary.py +9 -3
- cudf_polars/dsl/expressions/unary.py +117 -58
- cudf_polars/dsl/ir.py +923 -290
- cudf_polars/dsl/to_ast.py +30 -13
- cudf_polars/dsl/tracing.py +194 -0
- cudf_polars/dsl/translate.py +294 -97
- cudf_polars/dsl/utils/aggregations.py +34 -26
- cudf_polars/dsl/utils/reshape.py +14 -2
- cudf_polars/dsl/utils/rolling.py +12 -8
- cudf_polars/dsl/utils/windows.py +35 -20
- cudf_polars/experimental/base.py +45 -2
- cudf_polars/experimental/benchmarks/pdsds.py +12 -126
- cudf_polars/experimental/benchmarks/pdsh.py +791 -1
- cudf_polars/experimental/benchmarks/utils.py +515 -39
- cudf_polars/experimental/dask_registers.py +47 -20
- cudf_polars/experimental/dispatch.py +9 -3
- cudf_polars/experimental/explain.py +15 -2
- cudf_polars/experimental/expressions.py +22 -10
- cudf_polars/experimental/groupby.py +23 -4
- cudf_polars/experimental/io.py +93 -83
- cudf_polars/experimental/join.py +39 -22
- cudf_polars/experimental/parallel.py +60 -14
- cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
- cudf_polars/experimental/rapidsmpf/core.py +361 -0
- cudf_polars/experimental/rapidsmpf/dispatch.py +150 -0
- cudf_polars/experimental/rapidsmpf/io.py +604 -0
- cudf_polars/experimental/rapidsmpf/join.py +237 -0
- cudf_polars/experimental/rapidsmpf/lower.py +74 -0
- cudf_polars/experimental/rapidsmpf/nodes.py +494 -0
- cudf_polars/experimental/rapidsmpf/repartition.py +151 -0
- cudf_polars/experimental/rapidsmpf/shuffle.py +277 -0
- cudf_polars/experimental/rapidsmpf/union.py +96 -0
- cudf_polars/experimental/rapidsmpf/utils.py +162 -0
- cudf_polars/experimental/repartition.py +9 -2
- cudf_polars/experimental/select.py +177 -14
- cudf_polars/experimental/shuffle.py +28 -8
- cudf_polars/experimental/sort.py +92 -25
- cudf_polars/experimental/statistics.py +24 -5
- cudf_polars/experimental/utils.py +25 -7
- cudf_polars/testing/asserts.py +13 -8
- cudf_polars/testing/io.py +2 -1
- cudf_polars/testing/plugin.py +88 -15
- cudf_polars/typing/__init__.py +86 -32
- cudf_polars/utils/config.py +406 -58
- cudf_polars/utils/cuda_stream.py +70 -0
- cudf_polars/utils/versions.py +3 -2
- cudf_polars_cu13-25.12.0.dist-info/METADATA +182 -0
- cudf_polars_cu13-25.12.0.dist-info/RECORD +104 -0
- cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
- cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-25.12.0.dist-info}/WHEEL +0 -0
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-25.12.0.dist-info}/licenses/LICENSE +0 -0
- {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-25.12.0.dist-info}/top_level.txt +0 -0

cudf_polars/containers/dataframe.py

@@ -20,8 +20,9 @@ if TYPE_CHECKING:
 
     from typing_extensions import Any, CapsuleType, Self
 
-    from cudf_polars.typing import ColumnOptions, DataFrameHeader, Slice
+    from rmm.pylibrmm.stream import Stream
 
+    from cudf_polars.typing import ColumnOptions, DataFrameHeader, PolarsDataType, Slice
 
 __all__: list[str] = ["DataFrame"]
 
@@ -55,15 +56,21 @@ def _create_polars_column_metadata(
 # This is also defined in pylibcudf.interop
 class _ObjectWithArrowMetadata:
     def __init__(
-        self, obj: plc.Table | plc.Column, metadata: list[plc.interop.ColumnMetadata]
+        self,
+        obj: plc.Table | plc.Column,
+        metadata: list[plc.interop.ColumnMetadata],
+        stream: Stream,
     ) -> None:
         self.obj = obj
         self.metadata = metadata
+        self.stream = stream
 
     def __arrow_c_array__(
         self, requested_schema: None = None
     ) -> tuple[CapsuleType, CapsuleType]:
-        return self.obj._to_schema(self.metadata), self.obj._to_host_array()
+        return self.obj._to_schema(self.metadata), self.obj._to_host_array(
+            stream=self.stream
+        )
 
 
 # Pacify the type checker. DataFrame init asserts that all the columns
@@ -78,8 +85,9 @@ class DataFrame:
     column_map: dict[str, Column]
     table: plc.Table
     columns: list[NamedColumn]
+    stream: Stream
 
-    def __init__(self, columns: Iterable[Column]) -> None:
+    def __init__(self, columns: Iterable[Column], stream: Stream) -> None:
         columns = list(columns)
         if any(c.name is None for c in columns):
             raise ValueError("All columns must have a name")
@@ -87,10 +95,11 @@ class DataFrame:
         self.dtypes = [c.dtype for c in self.columns]
         self.column_map = {c.name: c for c in self.columns}
         self.table = plc.Table([c.obj for c in self.columns])
+        self.stream = stream
 
     def copy(self) -> Self:
         """Return a shallow copy of self."""
-        return type(self)(c.copy() for c in self.columns)
+        return type(self)((c.copy() for c in self.columns), stream=self.stream)
 
     def to_polars(self) -> pl.DataFrame:
         """Convert to a polars DataFrame."""
@@ -102,10 +111,12 @@ class DataFrame:
         # serialise with names we control and rename with that map.
         name_map = {f"column_{i}": name for i, name in enumerate(self.column_map)}
         metadata = [
-            _create_polars_column_metadata(name, dtype.polars)
+            _create_polars_column_metadata(name, dtype.polars_type)
             for name, dtype in zip(name_map, self.dtypes, strict=True)
         ]
-        table_with_metadata = _ObjectWithArrowMetadata(self.table, metadata)
+        table_with_metadata = _ObjectWithArrowMetadata(
+            self.table, metadata, self.stream
+        )
         df = pl.DataFrame(table_with_metadata)
         return df.rename(name_map).with_columns(
             pl.col(c.name).set_sorted(descending=c.order == plc.types.Order.DESCENDING)
@@ -135,7 +146,7 @@ class DataFrame:
         return self.table.num_rows() if self.column_map else 0
 
     @classmethod
-    def from_polars(cls, df: pl.DataFrame) -> Self:
+    def from_polars(cls, df: pl.DataFrame, stream: Stream) -> Self:
         """
         Create from a polars dataframe.
 
@@ -143,22 +154,34 @@ class DataFrame:
         ----------
         df
             Polars dataframe to convert
+        stream
+            CUDA stream used for device memory operations and kernel launches
+            on this dataframe.
 
         Returns
         -------
         New dataframe representing the input.
         """
-        plc_table = plc.Table.from_arrow(df)
+        plc_table = plc.Table.from_arrow(df, stream=stream)
         return cls(
-            Column(d_col, name=name, dtype=DataType(h_col.dtype)).copy_metadata(h_col)
-            for d_col, h_col, name in zip(
-                plc_table.columns(), df.iter_columns(), df.columns, strict=True
-            )
+            (
+                Column(d_col, name=name, dtype=DataType(h_col.dtype)).copy_metadata(
+                    h_col
+                )
+                for d_col, h_col, name in zip(
+                    plc_table.columns(), df.iter_columns(), df.columns, strict=True
+                )
+            ),
+            stream=stream,
         )
 
     @classmethod
     def from_table(
-        cls, table: plc.Table, names: Sequence[str], dtypes: Sequence[DataType]
+        cls,
+        table: plc.Table,
+        names: Sequence[str],
+        dtypes: Sequence[DataType],
+        stream: Stream,
     ) -> Self:
         """
         Create from a pylibcudf table.
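
The hunks above thread an explicit CUDA stream through `DataFrame` construction. A minimal usage sketch, not part of the diff, assuming `DataFrame` is importable from `cudf_polars.containers` and using RMM's default stream:

```python
# Illustrative sketch only; the stream-aware signatures follow the hunks above.
import polars as pl
from rmm.pylibrmm.stream import DEFAULT_STREAM

from cudf_polars.containers import DataFrame

pdf = pl.DataFrame({"a": [1, 2, 3], "b": ["x", "y", "z"]})

# The caller picks the stream; the resulting DataFrame remembers it and reuses
# it for later device work (copy, filter, slice, ...).
gdf = DataFrame.from_polars(pdf, stream=DEFAULT_STREAM)
assert gdf.stream is DEFAULT_STREAM
assert gdf.num_rows == 3
```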
@@ -171,6 +194,10 @@ class DataFrame:
             Names for the columns
         dtypes
             Dtypes for the columns
+        stream
+            CUDA stream used for device memory operations and kernel launches
+            on this dataframe. The caller is responsible for ensuring that
+            the data in ``table`` is valid on ``stream``.
 
         Returns
         -------
@@ -185,13 +212,19 @@ class DataFrame:
         if table.num_columns() != len(names):
             raise ValueError("Mismatching name and table length.")
         return cls(
-            Column(c, name=name, dtype=dtype)
-            for c, name, dtype in zip(table.columns(), names, dtypes, strict=True)
+            (
+                Column(c, name=name, dtype=dtype)
+                for c, name, dtype in zip(table.columns(), names, dtypes, strict=True)
+            ),
+            stream=stream,
         )
 
     @classmethod
     def deserialize(
-        cls, header: DataFrameHeader, frames: tuple[memoryview, plc.gpumemoryview]
+        cls,
+        header: DataFrameHeader,
+        frames: tuple[memoryview[bytes], plc.gpumemoryview],
+        stream: Stream,
     ) -> Self:
         """
         Create a DataFrame from a serialized representation returned by `.serialize()`.
@@ -202,6 +235,10 @@ class DataFrame:
             The (unpickled) metadata required to reconstruct the object.
         frames
             Two-tuple of frames (a memoryview and a gpumemoryview).
+        stream
+            CUDA stream used for device memory operations and kernel launches
+            on this dataframe. The caller is responsible for ensuring that
+            the data in ``frames`` is valid on ``stream``.
 
         Returns
         -------
@@ -210,16 +247,22 @@ class DataFrame:
         """
         packed_metadata, packed_gpu_data = frames
         table = plc.contiguous_split.unpack_from_memoryviews(
-            packed_metadata, packed_gpu_data
+            packed_metadata,
+            packed_gpu_data,
+            stream,
         )
         return cls(
-            Column(c, **Column.deserialize_ctor_kwargs(kw))
-            for c, kw in zip(table.columns(), header["columns_kwargs"], strict=True)
+            (
+                Column(c, **Column.deserialize_ctor_kwargs(kw))
+                for c, kw in zip(table.columns(), header["columns_kwargs"], strict=True)
+            ),
+            stream=stream,
         )
 
     def serialize(
         self,
-    ) -> tuple[DataFrameHeader, tuple[memoryview, plc.gpumemoryview]]:
+        stream: Stream | None = None,
+    ) -> tuple[DataFrameHeader, tuple[memoryview[bytes], plc.gpumemoryview]]:
         """
         Serialize the table into header and frames.
 
@@ -231,6 +274,12 @@ class DataFrame:
         >>> from cudf_polars.experimental.dask_serialize import register
         >>> register()
 
+        Parameters
+        ----------
+        stream
+            CUDA stream used for device memory operations and kernel launches
+            on this dataframe.
+
         Returns
         -------
         header
@@ -238,7 +287,7 @@ class DataFrame:
         frames
             Two-tuple of frames suitable for passing to `plc.contiguous_split.unpack_from_memoryviews`
         """
-        packed = plc.contiguous_split.pack(self.table)
+        packed = plc.contiguous_split.pack(self.table, stream=stream)
 
         # Keyword arguments for `Column.__init__`.
         columns_kwargs: list[ColumnOptions] = [
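
Serialization is stream-aware as well: `serialize` packs the table on the stream passed by the caller, and `deserialize` unpacks on the stream it is given. A hedged round-trip sketch under the same assumptions as the previous example:

```python
# Illustrative sketch only; signatures follow the serialize/deserialize hunks above.
import polars as pl
from rmm.pylibrmm.stream import DEFAULT_STREAM

from cudf_polars.containers import DataFrame

gdf = DataFrame.from_polars(pl.DataFrame({"a": [1, 2, 3]}), stream=DEFAULT_STREAM)

# Pack into (header, frames) on an explicit stream, then rebuild on the same stream.
header, frames = gdf.serialize(stream=gdf.stream)
restored = DataFrame.deserialize(header, frames, stream=gdf.stream)
assert restored.column_names == gdf.column_names
```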
@@ -276,12 +325,19 @@ class DataFrame:
             raise ValueError("Can only copy from identically named frame")
         subset = self.column_names_set if subset is None else subset
         return type(self)(
-            c.sorted_like(other) if c.name in subset else c
-            for c, other in zip(self.columns, like.columns, strict=True)
+            (
+                c.sorted_like(other) if c.name in subset else c
+                for c, other in zip(self.columns, like.columns, strict=True)
+            ),
+            stream=self.stream,
         )
 
     def with_columns(
-        self, columns: Iterable[Column], *, replace_only: bool = False
+        self,
+        columns: Iterable[Column],
+        *,
+        replace_only: bool = False,
+        stream: Stream,
     ) -> Self:
         """
         Return a new dataframe with extra columns.
@@ -292,6 +348,13 @@ class DataFrame:
             Columns to add
         replace_only
             If true, then only replacements are allowed (matching by name).
+        stream
+            CUDA stream used for device memory operations and kernel launches.
+            The caller is responsible for ensuring that
+
+            1. The data in ``columns`` is valid on ``stream``.
+            2. No additional operations occur on ``self.stream`` with the
+               original data in ``self``.
 
         Returns
         -------
@@ -305,33 +368,57 @@ class DataFrame:
         new = {c.name: c for c in columns}
         if replace_only and not self.column_names_set.issuperset(new.keys()):
             raise ValueError("Cannot replace with non-existing names")
-        return type(self)((self.column_map | new).values())
+        return type(self)((self.column_map | new).values(), stream=stream)
 
     def discard_columns(self, names: Set[str]) -> Self:
         """Drop columns by name."""
-        return type(self)(column for column in self.columns if column.name not in names)
+        return type(self)(
+            (column for column in self.columns if column.name not in names),
+            stream=self.stream,
+        )
 
     def select(self, names: Sequence[str] | Mapping[str, Any]) -> Self:
         """Select columns by name returning DataFrame."""
         try:
-            return type(self)(self.column_map[name] for name in names)
+            return type(self)(
+                (self.column_map[name] for name in names), stream=self.stream
+            )
         except KeyError as e:
             raise ValueError("Can't select missing names") from e
 
     def rename_columns(self, mapping: Mapping[str, str]) -> Self:
         """Rename some columns."""
-        return type(self)(c.rename(mapping.get(c.name, c.name)) for c in self.columns)
+        return type(self)(
+            (c.rename(mapping.get(c.name, c.name)) for c in self.columns),
+            stream=self.stream,
+        )
 
     def select_columns(self, names: Set[str]) -> list[Column]:
         """Select columns by name."""
         return [c for c in self.columns if c.name in names]
 
     def filter(self, mask: Column) -> Self:
-        """Return a filtered table given a mask."""
-        table = plc.stream_compaction.apply_boolean_mask(self.table, mask.obj)
+        """
+        Return a filtered table given a mask.
+
+        Parameters
+        ----------
+        mask
+            Boolean mask to apply to the dataframe. It is the caller's
+            responsibility to ensure that ``mask`` is valid on ``self.stream``.
+            A mask that is derived from ``self`` via a computation on ``self.stream``
+            automatically satisfies this requirement.
+
+        Returns
+        -------
+        Filtered dataframe
+        """
+        table = plc.stream_compaction.apply_boolean_mask(
+            self.table, mask.obj, stream=self.stream
+        )
         return (
             type(self)
-            .from_table(table, self.column_names, self.dtypes)
+            .from_table(table, self.column_names, self.dtypes, self.stream)
             .sorted_like(self)
         )
 
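
Derived frames are constructed with `stream=self.stream` in the methods above, so they stay on the parent's stream. A short sketch of that propagation, under the same import assumptions as the earlier examples:

```python
# Illustrative sketch: select/rename/discard/filter reuse the parent's stream.
import polars as pl
from rmm.pylibrmm.stream import DEFAULT_STREAM

from cudf_polars.containers import DataFrame

gdf = DataFrame.from_polars(
    pl.DataFrame({"a": [1, 2], "b": [3, 4]}), stream=DEFAULT_STREAM
)
renamed = gdf.rename_columns({"a": "x"}).select(["x"])
assert renamed.stream is gdf.stream
assert renamed.column_names == ["x"]
```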
@@ -352,10 +439,12 @@ class DataFrame:
         if zlice is None:
             return self
         (table,) = plc.copying.slice(
-            self.table, conversion.from_polars_slice(zlice, num_rows=self.num_rows)
+            self.table,
+            conversion.from_polars_slice(zlice, num_rows=self.num_rows),
+            stream=self.stream,
         )
         return (
             type(self)
-            .from_table(table, self.column_names, self.dtypes)
+            .from_table(table, self.column_names, self.dtypes, self.stream)
             .sorted_like(self)
         )

cudf_polars/containers/datatype.py

@@ -6,6 +6,7 @@
 from __future__ import annotations
 
 from functools import cache
+from typing import TYPE_CHECKING, Literal, cast
 
 from typing_extensions import assert_never
 
@@ -13,8 +14,103 @@ import polars as pl
 
 import pylibcudf as plc
 
+if TYPE_CHECKING:
+    from cudf_polars.typing import (
+        DataTypeHeader,
+        PolarsDataType,
+    )
+
 __all__ = ["DataType"]
 
+SCALAR_NAME_TO_POLARS_TYPE_MAP: dict[str, pl.DataType] = {
+    "Boolean": pl.Boolean(),
+    "Int8": pl.Int8(),
+    "Int16": pl.Int16(),
+    "Int32": pl.Int32(),
+    "Int64": pl.Int64(),
+    "Object": pl.Object(),
+    "UInt8": pl.UInt8(),
+    "UInt16": pl.UInt16(),
+    "UInt32": pl.UInt32(),
+    "UInt64": pl.UInt64(),
+    "Float32": pl.Float32(),
+    "Float64": pl.Float64(),
+    "String": pl.String(),
+    "Null": pl.Null(),
+    "Date": pl.Date(),
+    "Time": pl.Time(),
+}
+
+
+def _dtype_to_header(dtype: pl.DataType) -> DataTypeHeader:
+    name = type(dtype).__name__
+    if name in SCALAR_NAME_TO_POLARS_TYPE_MAP:
+        return {"kind": "scalar", "name": name}
+    if isinstance(dtype, pl.Decimal):
+        # TODO: Add version guard once we support polars 1.34
+        # Also keep in mind the typing change in polars:
+        # https://github.com/pola-rs/polars/pull/25227
+        precision = dtype.precision if dtype.precision is not None else 38
+        return {
+            "kind": "decimal",
+            "precision": precision,
+            "scale": dtype.scale,
+        }
+    if isinstance(dtype, pl.Datetime):
+        return {
+            "kind": "datetime",
+            "time_unit": dtype.time_unit,
+            "time_zone": dtype.time_zone,
+        }
+    if isinstance(dtype, pl.Duration):
+        return {"kind": "duration", "time_unit": dtype.time_unit}
+    if isinstance(dtype, pl.List):
+        # isinstance narrows dtype to pl.List, but .inner returns DataTypeClass | DataType
+        return {
+            "kind": "list",
+            "inner": _dtype_to_header(cast(pl.DataType, dtype.inner)),
+        }
+    if isinstance(dtype, pl.Struct):
+        # isinstance narrows dtype to pl.Struct, but field.dtype returns DataTypeClass | DataType
+        return {
+            "kind": "struct",
+            "fields": [
+                {"name": f.name, "dtype": _dtype_to_header(cast(pl.DataType, f.dtype))}
+                for f in dtype.fields
+            ],
+        }
+    raise NotImplementedError(f"Unsupported dtype {dtype!r}")
+
+
+def _dtype_from_header(header: DataTypeHeader) -> pl.DataType:
+    if header["kind"] == "scalar":
+        name = header["name"]
+        try:
+            return SCALAR_NAME_TO_POLARS_TYPE_MAP[name]
+        except KeyError as err:
+            raise NotImplementedError(f"Unknown scalar dtype name: {name}") from err
+    if header["kind"] == "decimal":
+        return pl.Decimal(header["precision"], header["scale"])
+    if header["kind"] == "datetime":
+        return pl.Datetime(
+            time_unit=cast(Literal["ns", "us", "ms"], header["time_unit"]),
+            time_zone=header["time_zone"],
+        )
+    if header["kind"] == "duration":
+        return pl.Duration(
+            time_unit=cast(Literal["ns", "us", "ms"], header["time_unit"])
+        )
+    if header["kind"] == "list":
+        return pl.List(_dtype_from_header(header["inner"]))
+    if header["kind"] == "struct":
+        return pl.Struct(
+            [
+                pl.Field(f["name"], _dtype_from_header(f["dtype"]))
+                for f in header["fields"]
+            ]
+        )
+    raise NotImplementedError(f"Unsupported kind {header['kind']!r}")
+
 
 @cache
 def _from_polars(dtype: pl.DataType) -> plc.DataType:
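
The `_dtype_to_header`/`_dtype_from_header` pair added above gives nested polars dtypes a plain-dict form that can travel in serialization headers. A small round-trip sketch; these are private helpers, so importing them directly from `cudf_polars.containers.datatype` is an assumption:

```python
# Illustrative round trip through the header helpers defined in the hunk above.
import polars as pl

from cudf_polars.containers.datatype import _dtype_from_header, _dtype_to_header

dtype = pl.Struct({"id": pl.Int64(), "tags": pl.List(pl.String())})
header = _dtype_to_header(dtype)
# e.g. {"kind": "struct", "fields": [{"name": "id", "dtype": {"kind": "scalar", ...}}, ...]}
assert _dtype_from_header(header) == dtype
```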
@@ -102,36 +198,61 @@ def _from_polars(dtype: pl.DataType) -> plc.DataType:
 class DataType:
     """A datatype, preserving polars metadata."""
 
-    polars: pl.DataType
-    plc: plc.DataType
+    polars_type: pl.datatypes.DataType
+    plc_type: plc.DataType
 
-    def __init__(self, polars_dtype: pl.DataType) -> None:
-        self.polars = polars_dtype
-        self.plc = _from_polars(polars_dtype)
+    def __init__(self, polars_dtype: PolarsDataType) -> None:
+        # Convert DataTypeClass to DataType instance if needed
+        # polars allows both pl.Int64 (class) and pl.Int64() (instance)
+        if isinstance(polars_dtype, type):
+            polars_dtype = polars_dtype()
+        # After conversion, it's guaranteed to be a DataType instance
+        self.polars_type = cast(pl.DataType, polars_dtype)
+        self.plc_type = _from_polars(self.polars_type)
 
     def id(self) -> plc.TypeId:
         """The pylibcudf.TypeId of this DataType."""
-        return self.plc.id()
+        return self.plc_type.id()
 
     @property
     def children(self) -> list[DataType]:
         """The children types of this DataType."""
-        if self.plc.id() == plc.TypeId.STRUCT:
-            return [DataType(field.dtype) for field in self.polars.fields]
-        elif self.plc.id() == plc.TypeId.LIST:
-            return [DataType(self.polars.inner)]
+        # Type checker doesn't narrow polars_type through plc_type.id() checks
+        if self.plc_type.id() == plc.TypeId.STRUCT:
+            # field.dtype returns DataTypeClass | DataType, need to cast to DataType
+            return [
+                DataType(cast(pl.DataType, field.dtype))
+                for field in cast(pl.Struct, self.polars_type).fields
+            ]
+        elif self.plc_type.id() == plc.TypeId.LIST:
+            # .inner returns DataTypeClass | DataType, need to cast to DataType
+            return [DataType(cast(pl.DataType, cast(pl.List, self.polars_type).inner))]
         return []
 
+    def scale(self) -> int:
+        """The scale of this DataType."""
+        return self.plc_type.scale()
+
+    @staticmethod
+    def common_decimal_dtype(left: DataType, right: DataType) -> DataType:
+        """Return a common decimal DataType for the two inputs."""
+        if not (
+            plc.traits.is_fixed_point(left.plc_type)
+            and plc.traits.is_fixed_point(right.plc_type)
+        ):
+            raise ValueError("Both inputs required to be decimal types.")
+        return DataType(pl.Decimal(38, abs(min(left.scale(), right.scale()))))
+
     def __eq__(self, other: object) -> bool:
         """Equality of DataTypes."""
         if not isinstance(other, DataType):
             return False
-        return self.polars == other.polars
+        return self.polars_type == other.polars_type
 
     def __hash__(self) -> int:
         """Hash of the DataType."""
-        return hash(self.polars)
+        return hash(self.polars_type)
 
     def __repr__(self) -> str:
         """Representation of the DataType."""
-        return f"<DataType(polars={self.polars}, plc={self.id()!r})>"
+        return f"<DataType(polars={self.polars_type}, plc={self.id()!r})>"