cudf-polars-cu13 25.10.0__py3-none-any.whl → 26.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. cudf_polars/GIT_COMMIT +1 -1
  2. cudf_polars/VERSION +1 -1
  3. cudf_polars/callback.py +60 -15
  4. cudf_polars/containers/column.py +137 -77
  5. cudf_polars/containers/dataframe.py +123 -34
  6. cudf_polars/containers/datatype.py +134 -13
  7. cudf_polars/dsl/expr.py +0 -2
  8. cudf_polars/dsl/expressions/aggregation.py +80 -28
  9. cudf_polars/dsl/expressions/binaryop.py +34 -14
  10. cudf_polars/dsl/expressions/boolean.py +110 -37
  11. cudf_polars/dsl/expressions/datetime.py +59 -30
  12. cudf_polars/dsl/expressions/literal.py +11 -5
  13. cudf_polars/dsl/expressions/rolling.py +460 -119
  14. cudf_polars/dsl/expressions/selection.py +9 -8
  15. cudf_polars/dsl/expressions/slicing.py +1 -1
  16. cudf_polars/dsl/expressions/string.py +256 -114
  17. cudf_polars/dsl/expressions/struct.py +19 -7
  18. cudf_polars/dsl/expressions/ternary.py +33 -3
  19. cudf_polars/dsl/expressions/unary.py +126 -64
  20. cudf_polars/dsl/ir.py +1053 -350
  21. cudf_polars/dsl/to_ast.py +30 -13
  22. cudf_polars/dsl/tracing.py +194 -0
  23. cudf_polars/dsl/translate.py +307 -107
  24. cudf_polars/dsl/utils/aggregations.py +43 -30
  25. cudf_polars/dsl/utils/reshape.py +14 -2
  26. cudf_polars/dsl/utils/rolling.py +12 -8
  27. cudf_polars/dsl/utils/windows.py +35 -20
  28. cudf_polars/experimental/base.py +55 -2
  29. cudf_polars/experimental/benchmarks/pdsds.py +12 -126
  30. cudf_polars/experimental/benchmarks/pdsh.py +792 -2
  31. cudf_polars/experimental/benchmarks/utils.py +596 -39
  32. cudf_polars/experimental/dask_registers.py +47 -20
  33. cudf_polars/experimental/dispatch.py +9 -3
  34. cudf_polars/experimental/distinct.py +2 -0
  35. cudf_polars/experimental/explain.py +15 -2
  36. cudf_polars/experimental/expressions.py +30 -15
  37. cudf_polars/experimental/groupby.py +25 -4
  38. cudf_polars/experimental/io.py +156 -124
  39. cudf_polars/experimental/join.py +53 -23
  40. cudf_polars/experimental/parallel.py +68 -19
  41. cudf_polars/experimental/rapidsmpf/__init__.py +8 -0
  42. cudf_polars/experimental/rapidsmpf/collectives/__init__.py +9 -0
  43. cudf_polars/experimental/rapidsmpf/collectives/allgather.py +90 -0
  44. cudf_polars/experimental/rapidsmpf/collectives/common.py +96 -0
  45. cudf_polars/experimental/rapidsmpf/collectives/shuffle.py +253 -0
  46. cudf_polars/experimental/rapidsmpf/core.py +488 -0
  47. cudf_polars/experimental/rapidsmpf/dask.py +172 -0
  48. cudf_polars/experimental/rapidsmpf/dispatch.py +153 -0
  49. cudf_polars/experimental/rapidsmpf/io.py +696 -0
  50. cudf_polars/experimental/rapidsmpf/join.py +322 -0
  51. cudf_polars/experimental/rapidsmpf/lower.py +74 -0
  52. cudf_polars/experimental/rapidsmpf/nodes.py +735 -0
  53. cudf_polars/experimental/rapidsmpf/repartition.py +216 -0
  54. cudf_polars/experimental/rapidsmpf/union.py +115 -0
  55. cudf_polars/experimental/rapidsmpf/utils.py +374 -0
  56. cudf_polars/experimental/repartition.py +9 -2
  57. cudf_polars/experimental/select.py +177 -14
  58. cudf_polars/experimental/shuffle.py +46 -12
  59. cudf_polars/experimental/sort.py +100 -26
  60. cudf_polars/experimental/spilling.py +1 -1
  61. cudf_polars/experimental/statistics.py +24 -5
  62. cudf_polars/experimental/utils.py +25 -7
  63. cudf_polars/testing/asserts.py +13 -8
  64. cudf_polars/testing/io.py +2 -1
  65. cudf_polars/testing/plugin.py +93 -17
  66. cudf_polars/typing/__init__.py +86 -32
  67. cudf_polars/utils/config.py +473 -58
  68. cudf_polars/utils/cuda_stream.py +70 -0
  69. cudf_polars/utils/versions.py +5 -4
  70. cudf_polars_cu13-26.2.0.dist-info/METADATA +181 -0
  71. cudf_polars_cu13-26.2.0.dist-info/RECORD +108 -0
  72. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/WHEEL +1 -1
  73. cudf_polars_cu13-25.10.0.dist-info/METADATA +0 -136
  74. cudf_polars_cu13-25.10.0.dist-info/RECORD +0 -92
  75. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/licenses/LICENSE +0 -0
  76. {cudf_polars_cu13-25.10.0.dist-info → cudf_polars_cu13-26.2.0.dist-info}/top_level.txt +0 -0
cudf_polars/containers/dataframe.py CHANGED
@@ -20,8 +20,9 @@ if TYPE_CHECKING:
 
     from typing_extensions import Any, CapsuleType, Self
 
-    from cudf_polars.typing import ColumnOptions, DataFrameHeader, PolarsDataType, Slice
+    from rmm.pylibrmm.stream import Stream
 
+    from cudf_polars.typing import ColumnOptions, DataFrameHeader, PolarsDataType, Slice
 
 __all__: list[str] = ["DataFrame"]
 
@@ -55,15 +56,21 @@ def _create_polars_column_metadata(
 # This is also defined in pylibcudf.interop
 class _ObjectWithArrowMetadata:
     def __init__(
-        self, obj: plc.Table | plc.Column, metadata: list[plc.interop.ColumnMetadata]
+        self,
+        obj: plc.Table | plc.Column,
+        metadata: list[plc.interop.ColumnMetadata],
+        stream: Stream,
     ) -> None:
         self.obj = obj
         self.metadata = metadata
+        self.stream = stream
 
     def __arrow_c_array__(
         self, requested_schema: None = None
     ) -> tuple[CapsuleType, CapsuleType]:
-        return self.obj._to_schema(self.metadata), self.obj._to_host_array()
+        return self.obj._to_schema(self.metadata), self.obj._to_host_array(
+            stream=self.stream
+        )
 
 
 # Pacify the type checker. DataFrame init asserts that all the columns
@@ -78,8 +85,9 @@ class DataFrame:
     column_map: dict[str, Column]
     table: plc.Table
     columns: list[NamedColumn]
+    stream: Stream
 
-    def __init__(self, columns: Iterable[Column]) -> None:
+    def __init__(self, columns: Iterable[Column], stream: Stream) -> None:
         columns = list(columns)
         if any(c.name is None for c in columns):
             raise ValueError("All columns must have a name")
@@ -87,10 +95,11 @@ class DataFrame:
         self.dtypes = [c.dtype for c in self.columns]
         self.column_map = {c.name: c for c in self.columns}
         self.table = plc.Table([c.obj for c in self.columns])
+        self.stream = stream
 
     def copy(self) -> Self:
         """Return a shallow copy of self."""
-        return type(self)(c.copy() for c in self.columns)
+        return type(self)((c.copy() for c in self.columns), stream=self.stream)
 
     def to_polars(self) -> pl.DataFrame:
         """Convert to a polars DataFrame."""
@@ -102,10 +111,12 @@ class DataFrame:
         # serialise with names we control and rename with that map.
         name_map = {f"column_{i}": name for i, name in enumerate(self.column_map)}
         metadata = [
-            _create_polars_column_metadata(name, dtype.polars)
+            _create_polars_column_metadata(name, dtype.polars_type)
             for name, dtype in zip(name_map, self.dtypes, strict=True)
         ]
-        table_with_metadata = _ObjectWithArrowMetadata(self.table, metadata)
+        table_with_metadata = _ObjectWithArrowMetadata(
+            self.table, metadata, self.stream
+        )
         df = pl.DataFrame(table_with_metadata)
         return df.rename(name_map).with_columns(
             pl.col(c.name).set_sorted(descending=c.order == plc.types.Order.DESCENDING)
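
The `_ObjectWithArrowMetadata` wrapper works because `pl.DataFrame(...)` accepts any object implementing the Arrow PyCapsule protocol, i.e. an `__arrow_c_array__` method returning a schema capsule and an array capsule, as the hunk above relies on. A minimal sketch of the same mechanism with pyarrow as the producer; the `Wrapper` class here is illustrative only and not part of cudf-polars:

    import polars as pl
    import pyarrow as pa


    class Wrapper:
        """Expose an Arrow table via the PyCapsule protocol."""

        def __init__(self, table: pa.Table) -> None:
            # RecordBatch already implements __arrow_c_array__ (pyarrow >= 14).
            self.batch = table.combine_chunks().to_batches()[0]

        def __arrow_c_array__(self, requested_schema=None):
            # Return (schema capsule, array capsule), as polars expects.
            return self.batch.__arrow_c_array__(requested_schema)


    # polars consumes the capsules directly, without a pandas/python detour.
    df = pl.DataFrame(Wrapper(pa.table({"a": [1, 2, 3]})))
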
@@ -135,7 +146,7 @@ class DataFrame:
         return self.table.num_rows() if self.column_map else 0
 
     @classmethod
-    def from_polars(cls, df: pl.DataFrame) -> Self:
+    def from_polars(cls, df: pl.DataFrame, stream: Stream) -> Self:
         """
         Create from a polars dataframe.
 
@@ -143,22 +154,34 @@ class DataFrame:
         ----------
         df
             Polars dataframe to convert
+        stream
+            CUDA stream used for device memory operations and kernel launches
+            on this dataframe.
 
         Returns
         -------
         New dataframe representing the input.
         """
-        plc_table = plc.Table.from_arrow(df)
+        plc_table = plc.Table.from_arrow(df, stream=stream)
         return cls(
-            Column(d_col, name=name, dtype=DataType(h_col.dtype)).copy_metadata(h_col)
-            for d_col, h_col, name in zip(
-                plc_table.columns(), df.iter_columns(), df.columns, strict=True
-            )
+            (
+                Column(d_col, name=name, dtype=DataType(h_col.dtype)).copy_metadata(
+                    h_col
+                )
+                for d_col, h_col, name in zip(
+                    plc_table.columns(), df.iter_columns(), df.columns, strict=True
+                )
+            ),
+            stream=stream,
         )
 
     @classmethod
     def from_table(
-        cls, table: plc.Table, names: Sequence[str], dtypes: Sequence[DataType]
+        cls,
+        table: plc.Table,
+        names: Sequence[str],
+        dtypes: Sequence[DataType],
+        stream: Stream,
     ) -> Self:
         """
         Create from a pylibcudf table.
@@ -171,6 +194,10 @@ class DataFrame:
             Names for the columns
         dtypes
             Dtypes for the columns
+        stream
+            CUDA stream used for device memory operations and kernel launches
+            on this dataframe. The caller is responsible for ensuring that
+            the data in ``table`` is valid on ``stream``.
 
         Returns
         -------
@@ -185,13 +212,19 @@ class DataFrame:
         if table.num_columns() != len(names):
             raise ValueError("Mismatching name and table length.")
         return cls(
-            Column(c, name=name, dtype=dtype)
-            for c, name, dtype in zip(table.columns(), names, dtypes, strict=True)
+            (
+                Column(c, name=name, dtype=dtype)
+                for c, name, dtype in zip(table.columns(), names, dtypes, strict=True)
+            ),
+            stream=stream,
         )
 
     @classmethod
     def deserialize(
-        cls, header: DataFrameHeader, frames: tuple[memoryview, plc.gpumemoryview]
+        cls,
+        header: DataFrameHeader,
+        frames: tuple[memoryview[bytes], plc.gpumemoryview],
+        stream: Stream,
     ) -> Self:
         """
         Create a DataFrame from a serialized representation returned by `.serialize()`.
@@ -202,6 +235,10 @@ class DataFrame:
             The (unpickled) metadata required to reconstruct the object.
         frames
             Two-tuple of frames (a memoryview and a gpumemoryview).
+        stream
+            CUDA stream used for device memory operations and kernel launches
+            on this dataframe. The caller is responsible for ensuring that
+            the data in ``frames`` is valid on ``stream``.
 
         Returns
         -------
@@ -210,16 +247,22 @@ class DataFrame:
         """
         packed_metadata, packed_gpu_data = frames
         table = plc.contiguous_split.unpack_from_memoryviews(
-            packed_metadata, packed_gpu_data
+            packed_metadata,
+            packed_gpu_data,
+            stream,
         )
         return cls(
-            Column(c, **Column.deserialize_ctor_kwargs(kw))
-            for c, kw in zip(table.columns(), header["columns_kwargs"], strict=True)
+            (
+                Column(c, **Column.deserialize_ctor_kwargs(kw))
+                for c, kw in zip(table.columns(), header["columns_kwargs"], strict=True)
+            ),
+            stream=stream,
        )
 
     def serialize(
         self,
-    ) -> tuple[DataFrameHeader, tuple[memoryview, plc.gpumemoryview]]:
+        stream: Stream | None = None,
+    ) -> tuple[DataFrameHeader, tuple[memoryview[bytes], plc.gpumemoryview]]:
         """
         Serialize the table into header and frames.
 
@@ -231,6 +274,12 @@ class DataFrame:
         >>> from cudf_polars.experimental.dask_serialize import register
         >>> register()
 
+        Parameters
+        ----------
+        stream
+            CUDA stream used for device memory operations and kernel launches
+            on this dataframe.
+
         Returns
         -------
         header
@@ -238,7 +287,7 @@ class DataFrame:
         frames
             Two-tuple of frames suitable for passing to `plc.contiguous_split.unpack_from_memoryviews`
         """
-        packed = plc.contiguous_split.pack(self.table)
+        packed = plc.contiguous_split.pack(self.table, stream=stream)
 
         # Keyword arguments for `Column.__init__`.
         columns_kwargs: list[ColumnOptions] = [
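
With `serialize` and `deserialize` both stream-aware, a round trip packs the table contiguously on one stream and unpacks it on a stream the caller guarantees the frames are valid on. A hedged sketch; the `cudf_polars.containers` import path and `DEFAULT_STREAM` (from rmm) are assumptions not shown in this diff:

    import polars as pl
    from rmm.pylibrmm.stream import DEFAULT_STREAM

    from cudf_polars.containers import DataFrame

    df = DataFrame.from_polars(pl.DataFrame({"a": [1, 2, 3]}), stream=DEFAULT_STREAM)

    # pack() the table contiguously on the frame's own stream ...
    header, frames = df.serialize(stream=df.stream)
    # ... then unpack; the caller must ensure `frames` is valid on the stream
    # handed to deserialize (trivially true here, since it is the same stream).
    roundtrip = DataFrame.deserialize(header, frames, stream=df.stream)
    assert roundtrip.column_names == df.column_names
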
@@ -276,12 +325,19 @@ class DataFrame:
             raise ValueError("Can only copy from identically named frame")
         subset = self.column_names_set if subset is None else subset
         return type(self)(
-            c.sorted_like(other) if c.name in subset else c
-            for c, other in zip(self.columns, like.columns, strict=True)
+            (
+                c.sorted_like(other) if c.name in subset else c
+                for c, other in zip(self.columns, like.columns, strict=True)
+            ),
+            stream=self.stream,
         )
 
     def with_columns(
-        self, columns: Iterable[Column], *, replace_only: bool = False
+        self,
+        columns: Iterable[Column],
+        *,
+        replace_only: bool = False,
+        stream: Stream,
     ) -> Self:
         """
         Return a new dataframe with extra columns.
@@ -292,6 +348,13 @@ class DataFrame:
             Columns to add
         replace_only
             If true, then only replacements are allowed (matching by name).
+        stream
+            CUDA stream used for device memory operations and kernel launches.
+            The caller is responsible for ensuring that
+
+            1. The data in ``columns`` is valid on ``stream``.
+            2. No additional operations occur on ``self.stream`` with the
+               original data in ``self``.
 
         Returns
         -------
@@ -305,33 +368,57 @@
         new = {c.name: c for c in columns}
         if replace_only and not self.column_names_set.issuperset(new.keys()):
             raise ValueError("Cannot replace with non-existing names")
-        return type(self)((self.column_map | new).values())
+        return type(self)((self.column_map | new).values(), stream=stream)
 
     def discard_columns(self, names: Set[str]) -> Self:
         """Drop columns by name."""
-        return type(self)(column for column in self.columns if column.name not in names)
+        return type(self)(
+            (column for column in self.columns if column.name not in names),
+            stream=self.stream,
+        )
 
     def select(self, names: Sequence[str] | Mapping[str, Any]) -> Self:
         """Select columns by name returning DataFrame."""
         try:
-            return type(self)(self.column_map[name] for name in names)
+            return type(self)(
+                (self.column_map[name] for name in names), stream=self.stream
+            )
         except KeyError as e:
             raise ValueError("Can't select missing names") from e
 
     def rename_columns(self, mapping: Mapping[str, str]) -> Self:
         """Rename some columns."""
-        return type(self)(c.rename(mapping.get(c.name, c.name)) for c in self.columns)
+        return type(self)(
+            (c.rename(mapping.get(c.name, c.name)) for c in self.columns),
+            stream=self.stream,
+        )
 
     def select_columns(self, names: Set[str]) -> list[Column]:
         """Select columns by name."""
         return [c for c in self.columns if c.name in names]
 
     def filter(self, mask: Column) -> Self:
-        """Return a filtered table given a mask."""
-        table = plc.stream_compaction.apply_boolean_mask(self.table, mask.obj)
+        """
+        Return a filtered table given a mask.
+
+        Parameters
+        ----------
+        mask
+            Boolean mask to apply to the dataframe. It is the caller's
+            responsibility to ensure that ``mask`` is valid on ``self.stream``.
+            A mask that is derived from ``self`` via a computation on ``self.stream``
+            automatically satisfies this requirement.
+
+        Returns
+        -------
+        Filtered dataframe
+        """
+        table = plc.stream_compaction.apply_boolean_mask(
+            self.table, mask.obj, stream=self.stream
+        )
         return (
             type(self)
-            .from_table(table, self.column_names, self.dtypes)
+            .from_table(table, self.column_names, self.dtypes, self.stream)
            .sorted_like(self)
         )
 
@@ -352,10 +439,12 @@ class DataFrame:
         if zlice is None:
             return self
         (table,) = plc.copying.slice(
-            self.table, conversion.from_polars_slice(zlice, num_rows=self.num_rows)
+            self.table,
+            conversion.from_polars_slice(zlice, num_rows=self.num_rows),
+            stream=self.stream,
         )
         return (
             type(self)
-            .from_table(table, self.column_names, self.dtypes)
+            .from_table(table, self.column_names, self.dtypes, self.stream)
             .sorted_like(self)
         )
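
The pattern throughout this file: each operation launches its kernels on `self.stream` and threads that same stream into the result, so frames derived from a parent stay on the parent's stream, while constructors and `with_columns` take an explicit stream because their input columns may have been produced elsewhere. A hedged usage sketch (the `slice` method name, its tuple argument, and the import paths are assumptions based on the hunks above):

    import polars as pl
    from rmm.pylibrmm.stream import DEFAULT_STREAM

    from cudf_polars.containers import DataFrame

    df = DataFrame.from_polars(
        pl.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]}), stream=DEFAULT_STREAM
    )

    # Derived frames inherit the parent's stream implicitly.
    sub = df.select(["a"]).slice((1, 2))  # assumed (offset, length) form
    assert sub.stream is df.stream
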
cudf_polars/containers/datatype.py CHANGED
@@ -6,6 +6,7 @@
 from __future__ import annotations
 
 from functools import cache
+from typing import TYPE_CHECKING, Literal, cast
 
 from typing_extensions import assert_never
 
@@ -13,8 +14,103 @@ import polars as pl
 
 import pylibcudf as plc
 
+if TYPE_CHECKING:
+    from cudf_polars.typing import (
+        DataTypeHeader,
+        PolarsDataType,
+    )
+
 __all__ = ["DataType"]
 
+SCALAR_NAME_TO_POLARS_TYPE_MAP: dict[str, pl.DataType] = {
+    "Boolean": pl.Boolean(),
+    "Int8": pl.Int8(),
+    "Int16": pl.Int16(),
+    "Int32": pl.Int32(),
+    "Int64": pl.Int64(),
+    "Object": pl.Object(),
+    "UInt8": pl.UInt8(),
+    "UInt16": pl.UInt16(),
+    "UInt32": pl.UInt32(),
+    "UInt64": pl.UInt64(),
+    "Float32": pl.Float32(),
+    "Float64": pl.Float64(),
+    "String": pl.String(),
+    "Null": pl.Null(),
+    "Date": pl.Date(),
+    "Time": pl.Time(),
+}
+
+
+def _dtype_to_header(dtype: pl.DataType) -> DataTypeHeader:
+    name = type(dtype).__name__
+    if name in SCALAR_NAME_TO_POLARS_TYPE_MAP:
+        return {"kind": "scalar", "name": name}
+    if isinstance(dtype, pl.Decimal):
+        # TODO: Add version guard once we support polars 1.34
+        # Also keep in mind the typing change in polars:
+        # https://github.com/pola-rs/polars/pull/25227
+        precision = dtype.precision if dtype.precision is not None else 38
+        return {
+            "kind": "decimal",
+            "precision": precision,
+            "scale": dtype.scale,
+        }
+    if isinstance(dtype, pl.Datetime):
+        return {
+            "kind": "datetime",
+            "time_unit": dtype.time_unit,
+            "time_zone": dtype.time_zone,
+        }
+    if isinstance(dtype, pl.Duration):
+        return {"kind": "duration", "time_unit": dtype.time_unit}
+    if isinstance(dtype, pl.List):
+        # isinstance narrows dtype to pl.List, but .inner returns DataTypeClass | DataType
+        return {
+            "kind": "list",
+            "inner": _dtype_to_header(cast(pl.DataType, dtype.inner)),
+        }
+    if isinstance(dtype, pl.Struct):
+        # isinstance narrows dtype to pl.Struct, but field.dtype returns DataTypeClass | DataType
+        return {
+            "kind": "struct",
+            "fields": [
+                {"name": f.name, "dtype": _dtype_to_header(cast(pl.DataType, f.dtype))}
+                for f in dtype.fields
+            ],
+        }
+    raise NotImplementedError(f"Unsupported dtype {dtype!r}")
+
+
+def _dtype_from_header(header: DataTypeHeader) -> pl.DataType:
+    if header["kind"] == "scalar":
+        name = header["name"]
+        try:
+            return SCALAR_NAME_TO_POLARS_TYPE_MAP[name]
+        except KeyError as err:
+            raise NotImplementedError(f"Unknown scalar dtype name: {name}") from err
+    if header["kind"] == "decimal":
+        return pl.Decimal(header["precision"], header["scale"])
+    if header["kind"] == "datetime":
+        return pl.Datetime(
+            time_unit=cast(Literal["ns", "us", "ms"], header["time_unit"]),
+            time_zone=header["time_zone"],
+        )
+    if header["kind"] == "duration":
+        return pl.Duration(
+            time_unit=cast(Literal["ns", "us", "ms"], header["time_unit"])
+        )
+    if header["kind"] == "list":
+        return pl.List(_dtype_from_header(header["inner"]))
+    if header["kind"] == "struct":
+        return pl.Struct(
+            [
+                pl.Field(f["name"], _dtype_from_header(f["dtype"]))
+                for f in header["fields"]
+            ]
+        )
+    raise NotImplementedError(f"Unsupported kind {header['kind']!r}")
+
 
 @cache
 def _from_polars(dtype: pl.DataType) -> plc.DataType:
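
The two new helpers give polars dtypes a plain-dict wire format (`DataTypeHeader`) that nests for list and struct types, so dtype metadata can travel alongside serialized frames. A sketch of the round trip, calling the private helpers directly purely for illustration:

    import polars as pl

    from cudf_polars.containers.datatype import _dtype_from_header, _dtype_to_header

    dtype = pl.Struct(
        {"x": pl.List(pl.Datetime(time_unit="us")), "y": pl.Decimal(38, 2)}
    )

    header = _dtype_to_header(dtype)
    # e.g. {"kind": "struct", "fields": [{"name": "x", "dtype": {"kind": "list", ...}}, ...]}
    assert _dtype_from_header(header) == dtype
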
@@ -102,36 +198,61 @@ def _from_polars(dtype: pl.DataType) -> plc.DataType:
 class DataType:
     """A datatype, preserving polars metadata."""
 
-    polars: pl.datatypes.DataType
-    plc: plc.DataType
+    polars_type: pl.datatypes.DataType
+    plc_type: plc.DataType
 
-    def __init__(self, polars_dtype: pl.DataType) -> None:
-        self.polars = polars_dtype
-        self.plc = _from_polars(polars_dtype)
+    def __init__(self, polars_dtype: PolarsDataType) -> None:
+        # Convert DataTypeClass to DataType instance if needed
+        # polars allows both pl.Int64 (class) and pl.Int64() (instance)
+        if isinstance(polars_dtype, type):
+            polars_dtype = polars_dtype()
+        # After conversion, it's guaranteed to be a DataType instance
+        self.polars_type = cast(pl.DataType, polars_dtype)
+        self.plc_type = _from_polars(self.polars_type)
 
     def id(self) -> plc.TypeId:
         """The pylibcudf.TypeId of this DataType."""
-        return self.plc.id()
+        return self.plc_type.id()
 
     @property
     def children(self) -> list[DataType]:
         """The children types of this DataType."""
-        if self.plc.id() == plc.TypeId.STRUCT:
-            return [DataType(field.dtype) for field in self.polars.fields]
-        elif self.plc.id() == plc.TypeId.LIST:
-            return [DataType(self.polars.inner)]
+        # Type checker doesn't narrow polars_type through plc_type.id() checks
+        if self.plc_type.id() == plc.TypeId.STRUCT:
+            # field.dtype returns DataTypeClass | DataType, need to cast to DataType
+            return [
+                DataType(cast(pl.DataType, field.dtype))
+                for field in cast(pl.Struct, self.polars_type).fields
+            ]
+        elif self.plc_type.id() == plc.TypeId.LIST:
+            # .inner returns DataTypeClass | DataType, need to cast to DataType
+            return [DataType(cast(pl.DataType, cast(pl.List, self.polars_type).inner))]
         return []
 
+    def scale(self) -> int:
+        """The scale of this DataType."""
+        return self.plc_type.scale()
+
+    @staticmethod
+    def common_decimal_dtype(left: DataType, right: DataType) -> DataType:
+        """Return a common decimal DataType for the two inputs."""
+        if not (
+            plc.traits.is_fixed_point(left.plc_type)
+            and plc.traits.is_fixed_point(right.plc_type)
+        ):
+            raise ValueError("Both inputs required to be decimal types.")
+        return DataType(pl.Decimal(38, abs(min(left.scale(), right.scale()))))
+
     def __eq__(self, other: object) -> bool:
         """Equality of DataTypes."""
         if not isinstance(other, DataType):
             return False
-        return self.polars == other.polars
+        return self.polars_type == other.polars_type
 
     def __hash__(self) -> int:
         """Hash of the DataType."""
-        return hash(self.polars)
+        return hash(self.polars_type)
 
     def __repr__(self) -> str:
         """Representation of the DataType."""
-        return f"<DataType(polars={self.polars}, plc={self.id()!r})>"
+        return f"<DataType(polars={self.polars_type}, plc={self.id()!r})>"
cudf_polars/dsl/expr.py CHANGED
@@ -1,7 +1,5 @@
 # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0
-# TODO: remove need for this
-# ruff: noqa: D101
 """
 DSL nodes for the polars expression language.
 