cudf-polars-cu12 25.6.0__py3-none-any.whl → 25.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. cudf_polars/VERSION +1 -1
  2. cudf_polars/callback.py +21 -12
  3. cudf_polars/containers/__init__.py +4 -2
  4. cudf_polars/containers/column.py +87 -42
  5. cudf_polars/containers/dataframe.py +62 -22
  6. cudf_polars/containers/datatype.py +135 -0
  7. cudf_polars/dsl/expr.py +2 -0
  8. cudf_polars/dsl/expressions/aggregation.py +31 -15
  9. cudf_polars/dsl/expressions/base.py +5 -5
  10. cudf_polars/dsl/expressions/binaryop.py +26 -5
  11. cudf_polars/dsl/expressions/boolean.py +58 -37
  12. cudf_polars/dsl/expressions/datetime.py +29 -35
  13. cudf_polars/dsl/expressions/literal.py +23 -11
  14. cudf_polars/dsl/expressions/rolling.py +37 -15
  15. cudf_polars/dsl/expressions/selection.py +7 -7
  16. cudf_polars/dsl/expressions/slicing.py +4 -5
  17. cudf_polars/dsl/expressions/sorting.py +5 -4
  18. cudf_polars/dsl/expressions/string.py +449 -60
  19. cudf_polars/dsl/expressions/struct.py +138 -0
  20. cudf_polars/dsl/expressions/ternary.py +6 -3
  21. cudf_polars/dsl/expressions/unary.py +127 -25
  22. cudf_polars/dsl/ir.py +284 -225
  23. cudf_polars/dsl/nodebase.py +10 -3
  24. cudf_polars/dsl/to_ast.py +60 -21
  25. cudf_polars/dsl/tracing.py +16 -0
  26. cudf_polars/dsl/translate.py +53 -61
  27. cudf_polars/dsl/traversal.py +64 -15
  28. cudf_polars/dsl/utils/aggregations.py +12 -3
  29. cudf_polars/dsl/utils/groupby.py +2 -6
  30. cudf_polars/dsl/utils/replace.py +19 -4
  31. cudf_polars/dsl/utils/reshape.py +74 -0
  32. cudf_polars/dsl/utils/rolling.py +5 -3
  33. cudf_polars/dsl/utils/windows.py +1 -1
  34. cudf_polars/experimental/base.py +114 -2
  35. cudf_polars/experimental/benchmarks/pdsds.py +216 -0
  36. cudf_polars/experimental/benchmarks/pdsds_queries/__init__.py +4 -0
  37. cudf_polars/experimental/benchmarks/pdsds_queries/q1.py +88 -0
  38. cudf_polars/experimental/benchmarks/pdsh.py +11 -478
  39. cudf_polars/experimental/benchmarks/utils.py +725 -0
  40. cudf_polars/experimental/dask_registers.py +13 -9
  41. cudf_polars/experimental/dispatch.py +22 -7
  42. cudf_polars/experimental/distinct.py +39 -19
  43. cudf_polars/experimental/expressions.py +49 -23
  44. cudf_polars/experimental/groupby.py +79 -43
  45. cudf_polars/experimental/io.py +617 -69
  46. cudf_polars/experimental/join.py +51 -15
  47. cudf_polars/experimental/parallel.py +76 -12
  48. cudf_polars/experimental/select.py +41 -1
  49. cudf_polars/experimental/shuffle.py +33 -25
  50. cudf_polars/experimental/utils.py +13 -1
  51. cudf_polars/testing/asserts.py +85 -26
  52. cudf_polars/testing/plugin.py +64 -67
  53. cudf_polars/typing/__init__.py +41 -22
  54. cudf_polars/utils/config.py +335 -83
  55. cudf_polars/utils/dtypes.py +3 -123
  56. cudf_polars/utils/versions.py +6 -4
  57. {cudf_polars_cu12-25.6.0.dist-info → cudf_polars_cu12-25.8.0.dist-info}/METADATA +12 -19
  58. cudf_polars_cu12-25.8.0.dist-info/RECORD +81 -0
  59. cudf_polars_cu12-25.6.0.dist-info/RECORD +0 -73
  60. {cudf_polars_cu12-25.6.0.dist-info → cudf_polars_cu12-25.8.0.dist-info}/WHEEL +0 -0
  61. {cudf_polars_cu12-25.6.0.dist-info → cudf_polars_cu12-25.8.0.dist-info}/licenses/LICENSE +0 -0
  62. {cudf_polars_cu12-25.6.0.dist-info → cudf_polars_cu12-25.8.0.dist-info}/top_level.txt +0 -0
cudf_polars/VERSION CHANGED
@@ -1 +1 @@
1
- 25.06.00
1
+ 25.08.00
cudf_polars/callback.py CHANGED
@@ -7,6 +7,7 @@ from __future__ import annotations
7
7
 
8
8
  import contextlib
9
9
  import os
10
+ import textwrap
10
11
  import time
11
12
  import warnings
12
13
  from functools import cache, partial
@@ -21,7 +22,9 @@ import pylibcudf
21
22
  import rmm
22
23
  from rmm._cuda import gpu
23
24
 
25
+ from cudf_polars.dsl.tracing import CUDF_POLARS_NVTX_DOMAIN
24
26
  from cudf_polars.dsl.translate import Translator
27
+ from cudf_polars.utils.config import _env_get_int, get_total_device_memory
25
28
  from cudf_polars.utils.timer import Timer
26
29
 
27
30
  if TYPE_CHECKING:
@@ -45,13 +48,6 @@ _SUPPORTED_PREFETCHES = {
45
48
  }
46
49
 
47
50
 
48
- def _env_get_int(name: str, default: int) -> int:
49
- try:
50
- return int(os.getenv(name, default))
51
- except (ValueError, TypeError): # pragma: no cover
52
- return default # pragma: no cover
53
-
54
-
55
51
  @cache
56
52
  def default_memory_resource(
57
53
  device: int,
@@ -102,8 +98,7 @@ def default_memory_resource(
102
98
  ):
103
99
  raise ComputeError(
104
100
  "GPU engine requested, but incorrect cudf-polars package installed. "
105
- "If your system has a CUDA 11 driver, please uninstall `cudf-polars-cu12` "
106
- "and install `cudf-polars-cu11`"
101
+ "cudf-polars requires CUDA 12.0+ to installed."
107
102
  ) from None
108
103
  else:
109
104
  raise
@@ -140,7 +135,11 @@ def set_memory_resource(
140
135
  mr = default_memory_resource(
141
136
  device=device,
142
137
  cuda_managed_memory=bool(
143
- _env_get_int("POLARS_GPU_ENABLE_CUDA_MANAGED_MEMORY", default=1) != 0
138
+ _env_get_int(
139
+ "POLARS_GPU_ENABLE_CUDA_MANAGED_MEMORY",
140
+ default=1 if get_total_device_memory() is not None else 0,
141
+ )
142
+ != 0
144
143
  ),
145
144
  )
146
145
  rmm.mr.set_current_device_resource(mr)
@@ -222,7 +221,7 @@ def _callback(
222
221
  if timer is not None:
223
222
  assert should_time
224
223
  with (
225
- nvtx.annotate(message="ExecuteIR", domain="cudf_polars"),
224
+ nvtx.annotate(message="ExecuteIR", domain=CUDF_POLARS_NVTX_DOMAIN),
226
225
  # Device must be set before memory resource is obtained.
227
226
  set_device(config_options.device),
228
227
  set_memory_resource(memory_resource),
@@ -236,6 +235,16 @@ def _callback(
236
235
  elif config_options.executor.name == "streaming":
237
236
  from cudf_polars.experimental.parallel import evaluate_streaming
238
237
 
238
+ if timer is not None:
239
+ msg = textwrap.dedent("""\
240
+ LazyFrame.profile() is not supported with the streaming executor.
241
+ To profile execution with the streaming executor, use:
242
+
243
+ - NVIDIA NSight Systems with the 'streaming' scheduler.
244
+ - Dask's built-in profiling tools with the 'distributed' scheduler.
245
+ """)
246
+ raise NotImplementedError(msg)
247
+
239
248
  return evaluate_streaming(ir, config_options).to_polars()
240
249
  assert_never(f"Unknown executor '{config_options.executor}'")
241
250
 
@@ -277,7 +286,7 @@ def execute_with_cudf(
277
286
 
278
287
  memory_resource = config.memory_resource
279
288
 
280
- with nvtx.annotate(message="ConvertIR", domain="cudf_polars"):
289
+ with nvtx.annotate(message="ConvertIR", domain=CUDF_POLARS_NVTX_DOMAIN):
281
290
  translator = Translator(nt, config)
282
291
  ir = translator.translate_ir()
283
292
  ir_translation_errors = translator.errors
@@ -1,11 +1,13 @@
1
- # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
2
2
  # SPDX-License-Identifier: Apache-2.0
3
3
 
4
4
  """Containers of concrete data."""
5
5
 
6
6
  from __future__ import annotations
7
7
 
8
- __all__: list[str] = ["Column", "DataFrame"]
8
+ __all__: list[str] = ["Column", "DataFrame", "DataType"]
9
9
 
10
+ # dataframe.py & column.py imports DataType, so import in this order to avoid circular import
11
+ from cudf_polars.containers.datatype import DataType # noqa: I001
10
12
  from cudf_polars.containers.column import Column
11
13
  from cudf_polars.containers.dataframe import DataFrame
@@ -8,6 +8,8 @@ from __future__ import annotations
8
8
  import functools
9
9
  from typing import TYPE_CHECKING
10
10
 
11
+ import polars as pl
12
+ import polars.datatypes.convert
11
13
  from polars.exceptions import InvalidOperationError
12
14
 
13
15
  import pylibcudf as plc
@@ -19,19 +21,39 @@ from pylibcudf.strings.convert.convert_integers import (
19
21
  )
20
22
  from pylibcudf.traits import is_floating_point
21
23
 
24
+ from cudf_polars.containers import DataType
22
25
  from cudf_polars.utils import conversion
23
26
  from cudf_polars.utils.dtypes import is_order_preserving_cast
24
27
 
25
28
  if TYPE_CHECKING:
26
29
  from typing_extensions import Self
27
30
 
28
- import polars as pl
29
-
30
- from cudf_polars.typing import ColumnHeader, ColumnOptions, Slice
31
+ from cudf_polars.typing import (
32
+ ColumnHeader,
33
+ ColumnOptions,
34
+ DeserializedColumnOptions,
35
+ Slice,
36
+ )
31
37
 
32
38
  __all__: list[str] = ["Column"]
33
39
 
34
40
 
41
+ def _dtype_short_repr_to_dtype(dtype_str: str) -> pl.DataType:
42
+ """Convert a Polars dtype short repr to a Polars dtype."""
43
+ # limitations of dtype_short_repr_to_dtype described in
44
+ # py-polars/polars/datatypes/convert.py#L299
45
+ if dtype_str.startswith("list["):
46
+ stripped = dtype_str.removeprefix("list[").removesuffix("]")
47
+ return pl.List(_dtype_short_repr_to_dtype(stripped))
48
+ pl_type = polars.datatypes.convert.dtype_short_repr_to_dtype(dtype_str)
49
+ if pl_type is None:
50
+ raise ValueError(f"{dtype_str} was not able to be parsed by Polars.")
51
+ if isinstance(pl_type, polars.datatypes.DataTypeClass):
52
+ return pl_type()
53
+ else:
54
+ return pl_type
55
+
56
+
35
57
  class Column:
36
58
  """An immutable column with sortedness metadata."""
37
59
 
@@ -43,10 +65,12 @@ class Column:
43
65
  # Optional name, only ever set by evaluation of NamedExpr nodes
44
66
  # The internal evaluation should not care about the name.
45
67
  name: str | None
68
+ dtype: DataType
46
69
 
47
70
  def __init__(
48
71
  self,
49
72
  column: plc.Column,
73
+ dtype: DataType,
50
74
  *,
51
75
  is_sorted: plc.types.Sorted = plc.types.Sorted.NO,
52
76
  order: plc.types.Order = plc.types.Order.ASCENDING,
@@ -56,6 +80,7 @@ class Column:
56
80
  self.obj = column
57
81
  self.is_scalar = self.size == 1
58
82
  self.name = name
83
+ self.dtype = dtype
59
84
  self.set_sorted(is_sorted=is_sorted, order=order, null_order=null_order)
60
85
 
61
86
  @classmethod
@@ -81,7 +106,23 @@ class Column:
81
106
  (plc_column,) = plc.contiguous_split.unpack_from_memoryviews(
82
107
  packed_metadata, packed_gpu_data
83
108
  ).columns()
84
- return cls(plc_column, **header["column_kwargs"])
109
+ return cls(plc_column, **cls.deserialize_ctor_kwargs(header["column_kwargs"]))
110
+
111
+ @staticmethod
112
+ def deserialize_ctor_kwargs(
113
+ column_kwargs: ColumnOptions,
114
+ ) -> DeserializedColumnOptions:
115
+ """Deserialize the constructor kwargs for a Column."""
116
+ dtype = DataType( # pragma: no cover
117
+ _dtype_short_repr_to_dtype(column_kwargs["dtype"])
118
+ )
119
+ return {
120
+ "is_sorted": column_kwargs["is_sorted"],
121
+ "order": column_kwargs["order"],
122
+ "null_order": column_kwargs["null_order"],
123
+ "name": column_kwargs["name"],
124
+ "dtype": dtype,
125
+ }
85
126
 
86
127
  def serialize(
87
128
  self,
@@ -105,17 +146,21 @@ class Column:
105
146
  Two-tuple of frames suitable for passing to `plc.contiguous_split.unpack_from_memoryviews`
106
147
  """
107
148
  packed = plc.contiguous_split.pack(plc.Table([self.obj]))
108
- column_kwargs: ColumnOptions = {
149
+ header: ColumnHeader = {
150
+ "column_kwargs": self.serialize_ctor_kwargs(),
151
+ "frame_count": 2,
152
+ }
153
+ return header, packed.release()
154
+
155
+ def serialize_ctor_kwargs(self) -> ColumnOptions:
156
+ """Serialize the constructor kwargs for self."""
157
+ return {
109
158
  "is_sorted": self.is_sorted,
110
159
  "order": self.order,
111
160
  "null_order": self.null_order,
112
161
  "name": self.name,
162
+ "dtype": pl.polars.dtype_str_repr(self.dtype.polars),
113
163
  }
114
- header: ColumnHeader = {
115
- "column_kwargs": column_kwargs,
116
- "frame_count": 2,
117
- }
118
- return header, packed.release()
119
164
 
120
165
  @functools.cached_property
121
166
  def obj_scalar(self) -> plc.Scalar:
@@ -172,6 +217,7 @@ class Column:
172
217
  return type(self)(
173
218
  self.obj,
174
219
  name=self.name,
220
+ dtype=self.dtype,
175
221
  is_sorted=like.is_sorted,
176
222
  order=like.order,
177
223
  null_order=like.null_order,
@@ -202,11 +248,11 @@ class Column:
202
248
  If the sortedness flag is not set, this launches a kernel to
203
249
  check sortedness.
204
250
  """
205
- if self.obj.size() <= 1 or self.obj.size() == self.obj.null_count():
251
+ if self.size <= 1 or self.size == self.null_count:
206
252
  return True
207
253
  if self.is_sorted == plc.types.Sorted.YES:
208
254
  return self.order == order and (
209
- self.obj.null_count() == 0 or self.null_order == null_order
255
+ self.null_count == 0 or self.null_order == null_order
210
256
  )
211
257
  if plc.sorting.is_sorted(plc.Table([self.obj]), [order], [null_order]):
212
258
  self.sorted = plc.types.Sorted.YES
@@ -215,7 +261,7 @@ class Column:
215
261
  return True
216
262
  return False
217
263
 
218
- def astype(self, dtype: plc.DataType) -> Column:
264
+ def astype(self, dtype: DataType) -> Column:
219
265
  """
220
266
  Cast the column to as the requested dtype.
221
267
 
@@ -238,14 +284,18 @@ class Column:
238
284
  This only produces a copy if the requested dtype doesn't match
239
285
  the current one.
240
286
  """
241
- if self.obj.type() == dtype:
287
+ plc_dtype = dtype.plc
288
+ if self.obj.type() == plc_dtype:
242
289
  return self
243
290
 
244
- if dtype.id() == plc.TypeId.STRING or self.obj.type().id() == plc.TypeId.STRING:
245
- return Column(self._handle_string_cast(dtype))
291
+ if (
292
+ plc_dtype.id() == plc.TypeId.STRING
293
+ or self.obj.type().id() == plc.TypeId.STRING
294
+ ):
295
+ return Column(self._handle_string_cast(plc_dtype), dtype=dtype)
246
296
  else:
247
- result = Column(plc.unary.cast(self.obj, dtype))
248
- if is_order_preserving_cast(self.obj.type(), dtype):
297
+ result = Column(plc.unary.cast(self.obj, plc_dtype), dtype=dtype)
298
+ if is_order_preserving_cast(self.obj.type(), plc_dtype):
249
299
  return result.sorted_like(self)
250
300
  return result
251
301
 
@@ -258,24 +308,20 @@ class Column:
258
308
  else:
259
309
  if is_floating_point(dtype):
260
310
  floats = is_float(self.obj)
261
- if not plc.interop.to_arrow(
262
- plc.reduce.reduce(
263
- floats,
264
- plc.aggregation.all(),
265
- plc.DataType(plc.TypeId.BOOL8),
266
- )
267
- ).as_py():
311
+ if not plc.reduce.reduce(
312
+ floats,
313
+ plc.aggregation.all(),
314
+ plc.DataType(plc.TypeId.BOOL8),
315
+ ).to_py():
268
316
  raise InvalidOperationError("Conversion from `str` failed.")
269
317
  return to_floats(self.obj, dtype)
270
318
  else:
271
319
  integers = is_integer(self.obj)
272
- if not plc.interop.to_arrow(
273
- plc.reduce.reduce(
274
- integers,
275
- plc.aggregation.all(),
276
- plc.DataType(plc.TypeId.BOOL8),
277
- )
278
- ).as_py():
320
+ if not plc.reduce.reduce(
321
+ integers,
322
+ plc.aggregation.all(),
323
+ plc.DataType(plc.TypeId.BOOL8),
324
+ ).to_py():
279
325
  raise InvalidOperationError("Conversion from `str` failed.")
280
326
  return to_integers(self.obj, dtype)
281
327
 
@@ -361,6 +407,7 @@ class Column:
361
407
  order=self.order,
362
408
  null_order=self.null_order,
363
409
  name=self.name,
410
+ dtype=self.dtype,
364
411
  )
365
412
 
366
413
  def mask_nans(self) -> Self:
@@ -368,7 +415,7 @@ class Column:
368
415
  if plc.traits.is_floating_point(self.obj.type()):
369
416
  old_count = self.null_count
370
417
  mask, new_count = plc.transform.nans_to_nulls(self.obj)
371
- result = type(self)(self.obj.with_mask(mask, new_count))
418
+ result = type(self)(self.obj.with_mask(mask, new_count), self.dtype)
372
419
  if old_count == new_count:
373
420
  return result.sorted_like(self)
374
421
  return result
@@ -377,14 +424,12 @@ class Column:
377
424
  @functools.cached_property
378
425
  def nan_count(self) -> int:
379
426
  """Return the number of NaN values in the column."""
380
- if plc.traits.is_floating_point(self.obj.type()):
381
- return plc.interop.to_arrow(
382
- plc.reduce.reduce(
383
- plc.unary.is_nan(self.obj),
384
- plc.aggregation.sum(),
385
- plc.types.SIZE_TYPE,
386
- )
387
- ).as_py()
427
+ if self.size > 0 and plc.traits.is_floating_point(self.obj.type()):
428
+ return plc.reduce.reduce(
429
+ plc.unary.is_nan(self.obj),
430
+ plc.aggregation.sum(),
431
+ plc.types.SIZE_TYPE,
432
+ ).to_py()
388
433
  return 0
389
434
 
390
435
  @property
@@ -418,4 +463,4 @@ class Column:
418
463
  conversion.from_polars_slice(zlice, num_rows=self.size),
419
464
  )
420
465
  (column,) = table.columns()
421
- return type(self)(column, name=self.name).sorted_like(self)
466
+ return type(self)(column, name=self.name, dtype=self.dtype).sorted_like(self)
@@ -12,20 +12,51 @@ import polars as pl
12
12
 
13
13
  import pylibcudf as plc
14
14
 
15
- from cudf_polars.containers import Column
15
+ from cudf_polars.containers import Column, DataType
16
16
  from cudf_polars.utils import conversion
17
17
 
18
18
  if TYPE_CHECKING:
19
19
  from collections.abc import Iterable, Mapping, Sequence, Set
20
20
 
21
- from typing_extensions import Any, Self
21
+ from typing_extensions import Any, CapsuleType, Self
22
22
 
23
- from cudf_polars.typing import ColumnOptions, DataFrameHeader, Slice
23
+ from cudf_polars.typing import ColumnOptions, DataFrameHeader, PolarsDataType, Slice
24
24
 
25
25
 
26
26
  __all__: list[str] = ["DataFrame"]
27
27
 
28
28
 
29
+ def _create_polars_column_metadata(
30
+ name: str, dtype: PolarsDataType
31
+ ) -> plc.interop.ColumnMetadata:
32
+ """Create ColumnMetadata preserving pl.Struct field names."""
33
+ if isinstance(dtype, pl.Struct):
34
+ children_meta = [
35
+ _create_polars_column_metadata(field.name, field.dtype)
36
+ for field in dtype.fields
37
+ ]
38
+ else:
39
+ children_meta = []
40
+ timezone = dtype.time_zone if isinstance(dtype, pl.Datetime) else None
41
+ return plc.interop.ColumnMetadata(
42
+ name=name, timezone=timezone or "", children_meta=children_meta
43
+ )
44
+
45
+
46
+ # This is also defined in pylibcudf.interop
47
+ class _ObjectWithArrowMetadata:
48
+ def __init__(
49
+ self, obj: plc.Table, metadata: list[plc.interop.ColumnMetadata]
50
+ ) -> None:
51
+ self.obj = obj
52
+ self.metadata = metadata
53
+
54
+ def __arrow_c_array__(
55
+ self, requested_schema: None = None
56
+ ) -> tuple[CapsuleType, CapsuleType]:
57
+ return self.obj._to_schema(self.metadata), self.obj._to_host_array()
58
+
59
+
29
60
  # Pacify the type checker. DataFrame init asserts that all the columns
30
61
  # have a string name, so let's narrow the type.
31
62
  class NamedColumn(Column):
@@ -44,6 +75,7 @@ class DataFrame:
44
75
  if any(c.name is None for c in columns):
45
76
  raise ValueError("All columns must have a name")
46
77
  self.columns = [cast(NamedColumn, c) for c in columns]
78
+ self.dtypes = [c.dtype for c in self.columns]
47
79
  self.column_map = {c.name: c for c in self.columns}
48
80
  self.table = plc.Table([c.obj for c in self.columns])
49
81
 
@@ -60,11 +92,12 @@ class DataFrame:
60
92
  # To guarantee we produce correct names, we therefore
61
93
  # serialise with names we control and rename with that map.
62
94
  name_map = {f"column_{i}": name for i, name in enumerate(self.column_map)}
63
- table = plc.interop.to_arrow(
64
- self.table,
65
- [plc.interop.ColumnMetadata(name=name) for name in name_map],
66
- )
67
- df: pl.DataFrame = pl.from_arrow(table)
95
+ metadata = [
96
+ _create_polars_column_metadata(name, dtype.polars)
97
+ for name, dtype in zip(name_map, self.dtypes, strict=True)
98
+ ]
99
+ table_with_metadata = _ObjectWithArrowMetadata(self.table, metadata)
100
+ df = pl.DataFrame(table_with_metadata)
68
101
  return df.rename(name_map).with_columns(
69
102
  pl.col(c.name).set_sorted(descending=c.order == plc.types.Order.DESCENDING)
70
103
  if c.is_sorted
@@ -106,16 +139,18 @@ class DataFrame:
106
139
  -------
107
140
  New dataframe representing the input.
108
141
  """
109
- plc_table = plc.Table(df)
142
+ plc_table = plc.Table.from_arrow(df)
110
143
  return cls(
111
- Column(d_col, name=name).copy_metadata(h_col)
144
+ Column(d_col, name=name, dtype=DataType(h_col.dtype)).copy_metadata(h_col)
112
145
  for d_col, h_col, name in zip(
113
146
  plc_table.columns(), df.iter_columns(), df.columns, strict=True
114
147
  )
115
148
  )
116
149
 
117
150
  @classmethod
118
- def from_table(cls, table: plc.Table, names: Sequence[str]) -> Self:
151
+ def from_table(
152
+ cls, table: plc.Table, names: Sequence[str], dtypes: Sequence[DataType]
153
+ ) -> Self:
119
154
  """
120
155
  Create from a pylibcudf table.
121
156
 
@@ -125,6 +160,8 @@ class DataFrame:
125
160
  Pylibcudf table to obtain columns from
126
161
  names
127
162
  Names for the columns
163
+ dtypes
164
+ Dtypes for the columns
128
165
 
129
166
  Returns
130
167
  -------
@@ -139,7 +176,8 @@ class DataFrame:
139
176
  if table.num_columns() != len(names):
140
177
  raise ValueError("Mismatching name and table length.")
141
178
  return cls(
142
- Column(c, name=name) for c, name in zip(table.columns(), names, strict=True)
179
+ Column(c, name=name, dtype=dtype)
180
+ for c, name, dtype in zip(table.columns(), names, dtypes, strict=True)
143
181
  )
144
182
 
145
183
  @classmethod
@@ -166,7 +204,7 @@ class DataFrame:
166
204
  packed_metadata, packed_gpu_data
167
205
  )
168
206
  return cls(
169
- Column(c, **kw)
207
+ Column(c, **Column.deserialize_ctor_kwargs(kw))
170
208
  for c, kw in zip(table.columns(), header["columns_kwargs"], strict=True)
171
209
  )
172
210
 
@@ -195,13 +233,7 @@ class DataFrame:
195
233
 
196
234
  # Keyword arguments for `Column.__init__`.
197
235
  columns_kwargs: list[ColumnOptions] = [
198
- {
199
- "is_sorted": col.is_sorted,
200
- "order": col.order,
201
- "null_order": col.null_order,
202
- "name": col.name,
203
- }
204
- for col in self.columns
236
+ col.serialize_ctor_kwargs() for col in self.columns
205
237
  ]
206
238
  header: DataFrameHeader = {
207
239
  "columns_kwargs": columns_kwargs,
@@ -288,7 +320,11 @@ class DataFrame:
288
320
  def filter(self, mask: Column) -> Self:
289
321
  """Return a filtered table given a mask."""
290
322
  table = plc.stream_compaction.apply_boolean_mask(self.table, mask.obj)
291
- return type(self).from_table(table, self.column_names).sorted_like(self)
323
+ return (
324
+ type(self)
325
+ .from_table(table, self.column_names, self.dtypes)
326
+ .sorted_like(self)
327
+ )
292
328
 
293
329
  def slice(self, zlice: Slice | None) -> Self:
294
330
  """
@@ -309,4 +345,8 @@ class DataFrame:
309
345
  (table,) = plc.copying.slice(
310
346
  self.table, conversion.from_polars_slice(zlice, num_rows=self.num_rows)
311
347
  )
312
- return type(self).from_table(table, self.column_names).sorted_like(self)
348
+ return (
349
+ type(self)
350
+ .from_table(table, self.column_names, self.dtypes)
351
+ .sorted_like(self)
352
+ )
@@ -0,0 +1,135 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """A datatype, preserving polars metadata."""
5
+
6
+ from __future__ import annotations
7
+
8
+ from functools import cache
9
+
10
+ from typing_extensions import assert_never
11
+
12
+ import polars as pl
13
+
14
+ import pylibcudf as plc
15
+
16
+ __all__ = ["DataType"]
17
+
18
+
19
+ @cache
20
+ def _from_polars(dtype: pl.DataType) -> plc.DataType:
21
+ """
22
+ Convert a polars datatype to a pylibcudf one.
23
+
24
+ Parameters
25
+ ----------
26
+ dtype
27
+ Polars dtype to convert
28
+
29
+ Returns
30
+ -------
31
+ Matching pylibcudf DataType object.
32
+
33
+ Raises
34
+ ------
35
+ NotImplementedError
36
+ For unsupported conversions.
37
+ """
38
+ if isinstance(dtype, pl.Boolean):
39
+ return plc.DataType(plc.TypeId.BOOL8)
40
+ elif isinstance(dtype, pl.Int8):
41
+ return plc.DataType(plc.TypeId.INT8)
42
+ elif isinstance(dtype, pl.Int16):
43
+ return plc.DataType(plc.TypeId.INT16)
44
+ elif isinstance(dtype, pl.Int32):
45
+ return plc.DataType(plc.TypeId.INT32)
46
+ elif isinstance(dtype, pl.Int64):
47
+ return plc.DataType(plc.TypeId.INT64)
48
+ if isinstance(dtype, pl.UInt8):
49
+ return plc.DataType(plc.TypeId.UINT8)
50
+ elif isinstance(dtype, pl.UInt16):
51
+ return plc.DataType(plc.TypeId.UINT16)
52
+ elif isinstance(dtype, pl.UInt32):
53
+ return plc.DataType(plc.TypeId.UINT32)
54
+ elif isinstance(dtype, pl.UInt64):
55
+ return plc.DataType(plc.TypeId.UINT64)
56
+ elif isinstance(dtype, pl.Float32):
57
+ return plc.DataType(plc.TypeId.FLOAT32)
58
+ elif isinstance(dtype, pl.Float64):
59
+ return plc.DataType(plc.TypeId.FLOAT64)
60
+ elif isinstance(dtype, pl.Date):
61
+ return plc.DataType(plc.TypeId.TIMESTAMP_DAYS)
62
+ elif isinstance(dtype, pl.Time):
63
+ raise NotImplementedError("Time of day dtype not implemented")
64
+ elif isinstance(dtype, pl.Datetime):
65
+ if dtype.time_unit == "ms":
66
+ return plc.DataType(plc.TypeId.TIMESTAMP_MILLISECONDS)
67
+ elif dtype.time_unit == "us":
68
+ return plc.DataType(plc.TypeId.TIMESTAMP_MICROSECONDS)
69
+ elif dtype.time_unit == "ns":
70
+ return plc.DataType(plc.TypeId.TIMESTAMP_NANOSECONDS)
71
+ assert dtype.time_unit is not None # pragma: no cover
72
+ assert_never(dtype.time_unit)
73
+ elif isinstance(dtype, pl.Duration):
74
+ if dtype.time_unit == "ms":
75
+ return plc.DataType(plc.TypeId.DURATION_MILLISECONDS)
76
+ elif dtype.time_unit == "us":
77
+ return plc.DataType(plc.TypeId.DURATION_MICROSECONDS)
78
+ elif dtype.time_unit == "ns":
79
+ return plc.DataType(plc.TypeId.DURATION_NANOSECONDS)
80
+ assert dtype.time_unit is not None # pragma: no cover
81
+ assert_never(dtype.time_unit)
82
+ elif isinstance(dtype, pl.String):
83
+ return plc.DataType(plc.TypeId.STRING)
84
+ elif isinstance(dtype, pl.Null):
85
+ # TODO: Hopefully
86
+ return plc.DataType(plc.TypeId.EMPTY)
87
+ elif isinstance(dtype, pl.List):
88
+ # Recurse to catch unsupported inner types
89
+ _ = _from_polars(dtype.inner)
90
+ return plc.DataType(plc.TypeId.LIST)
91
+ elif isinstance(dtype, pl.Struct):
92
+ # Recurse to catch unsupported field types
93
+ for field in dtype.fields:
94
+ _ = _from_polars(field.dtype)
95
+ return plc.DataType(plc.TypeId.STRUCT)
96
+ else:
97
+ raise NotImplementedError(f"{dtype=} conversion not supported")
98
+
99
+
100
+ class DataType:
101
+ """A datatype, preserving polars metadata."""
102
+
103
+ polars: pl.datatypes.DataType
104
+ plc: plc.DataType
105
+
106
+ def __init__(self, polars_dtype: pl.DataType) -> None:
107
+ self.polars = polars_dtype
108
+ self.plc = _from_polars(polars_dtype)
109
+
110
+ def id(self) -> plc.TypeId:
111
+ """The pylibcudf.TypeId of this DataType."""
112
+ return self.plc.id()
113
+
114
+ @property
115
+ def children(self) -> list[DataType]:
116
+ """The children types of this DataType."""
117
+ if self.plc.id() == plc.TypeId.STRUCT:
118
+ return [DataType(field.dtype) for field in self.polars.fields]
119
+ elif self.plc.id() == plc.TypeId.LIST:
120
+ return [DataType(self.polars.inner)]
121
+ return []
122
+
123
+ def __eq__(self, other: object) -> bool:
124
+ """Equality of DataTypes."""
125
+ if not isinstance(other, DataType):
126
+ return False
127
+ return self.polars == other.polars
128
+
129
+ def __hash__(self) -> int:
130
+ """Hash of the DataType."""
131
+ return hash(self.polars)
132
+
133
+ def __repr__(self) -> str:
134
+ """Representation of the DataType."""
135
+ return f"<DataType(polars={self.polars}, plc={self.id()!r})>"
cudf_polars/dsl/expr.py CHANGED
@@ -33,6 +33,7 @@ from cudf_polars.dsl.expressions.selection import Filter, Gather
33
33
  from cudf_polars.dsl.expressions.slicing import Slice
34
34
  from cudf_polars.dsl.expressions.sorting import Sort, SortBy
35
35
  from cudf_polars.dsl.expressions.string import StringFunction
36
+ from cudf_polars.dsl.expressions.struct import StructFunction
36
37
  from cudf_polars.dsl.expressions.ternary import Ternary
37
38
  from cudf_polars.dsl.expressions.unary import Cast, Len, UnaryFunction
38
39
 
@@ -58,6 +59,7 @@ __all__ = [
58
59
  "Sort",
59
60
  "SortBy",
60
61
  "StringFunction",
62
+ "StructFunction",
61
63
  "TemporalFunction",
62
64
  "Ternary",
63
65
  "UnaryFunction",