cudf-polars-cu13 25.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. cudf_polars/GIT_COMMIT +1 -0
  2. cudf_polars/VERSION +1 -0
  3. cudf_polars/__init__.py +28 -0
  4. cudf_polars/_version.py +21 -0
  5. cudf_polars/callback.py +318 -0
  6. cudf_polars/containers/__init__.py +13 -0
  7. cudf_polars/containers/column.py +495 -0
  8. cudf_polars/containers/dataframe.py +361 -0
  9. cudf_polars/containers/datatype.py +137 -0
  10. cudf_polars/dsl/__init__.py +8 -0
  11. cudf_polars/dsl/expr.py +66 -0
  12. cudf_polars/dsl/expressions/__init__.py +8 -0
  13. cudf_polars/dsl/expressions/aggregation.py +226 -0
  14. cudf_polars/dsl/expressions/base.py +272 -0
  15. cudf_polars/dsl/expressions/binaryop.py +120 -0
  16. cudf_polars/dsl/expressions/boolean.py +326 -0
  17. cudf_polars/dsl/expressions/datetime.py +271 -0
  18. cudf_polars/dsl/expressions/literal.py +97 -0
  19. cudf_polars/dsl/expressions/rolling.py +643 -0
  20. cudf_polars/dsl/expressions/selection.py +74 -0
  21. cudf_polars/dsl/expressions/slicing.py +46 -0
  22. cudf_polars/dsl/expressions/sorting.py +85 -0
  23. cudf_polars/dsl/expressions/string.py +1002 -0
  24. cudf_polars/dsl/expressions/struct.py +137 -0
  25. cudf_polars/dsl/expressions/ternary.py +49 -0
  26. cudf_polars/dsl/expressions/unary.py +517 -0
  27. cudf_polars/dsl/ir.py +2607 -0
  28. cudf_polars/dsl/nodebase.py +164 -0
  29. cudf_polars/dsl/to_ast.py +359 -0
  30. cudf_polars/dsl/tracing.py +16 -0
  31. cudf_polars/dsl/translate.py +939 -0
  32. cudf_polars/dsl/traversal.py +224 -0
  33. cudf_polars/dsl/utils/__init__.py +8 -0
  34. cudf_polars/dsl/utils/aggregations.py +481 -0
  35. cudf_polars/dsl/utils/groupby.py +98 -0
  36. cudf_polars/dsl/utils/naming.py +34 -0
  37. cudf_polars/dsl/utils/replace.py +61 -0
  38. cudf_polars/dsl/utils/reshape.py +74 -0
  39. cudf_polars/dsl/utils/rolling.py +121 -0
  40. cudf_polars/dsl/utils/windows.py +192 -0
  41. cudf_polars/experimental/__init__.py +8 -0
  42. cudf_polars/experimental/base.py +386 -0
  43. cudf_polars/experimental/benchmarks/__init__.py +4 -0
  44. cudf_polars/experimental/benchmarks/pdsds.py +220 -0
  45. cudf_polars/experimental/benchmarks/pdsds_queries/__init__.py +4 -0
  46. cudf_polars/experimental/benchmarks/pdsds_queries/q1.py +88 -0
  47. cudf_polars/experimental/benchmarks/pdsds_queries/q10.py +225 -0
  48. cudf_polars/experimental/benchmarks/pdsds_queries/q2.py +244 -0
  49. cudf_polars/experimental/benchmarks/pdsds_queries/q3.py +65 -0
  50. cudf_polars/experimental/benchmarks/pdsds_queries/q4.py +359 -0
  51. cudf_polars/experimental/benchmarks/pdsds_queries/q5.py +462 -0
  52. cudf_polars/experimental/benchmarks/pdsds_queries/q6.py +92 -0
  53. cudf_polars/experimental/benchmarks/pdsds_queries/q7.py +79 -0
  54. cudf_polars/experimental/benchmarks/pdsds_queries/q8.py +524 -0
  55. cudf_polars/experimental/benchmarks/pdsds_queries/q9.py +137 -0
  56. cudf_polars/experimental/benchmarks/pdsh.py +814 -0
  57. cudf_polars/experimental/benchmarks/utils.py +832 -0
  58. cudf_polars/experimental/dask_registers.py +200 -0
  59. cudf_polars/experimental/dispatch.py +156 -0
  60. cudf_polars/experimental/distinct.py +197 -0
  61. cudf_polars/experimental/explain.py +157 -0
  62. cudf_polars/experimental/expressions.py +590 -0
  63. cudf_polars/experimental/groupby.py +327 -0
  64. cudf_polars/experimental/io.py +943 -0
  65. cudf_polars/experimental/join.py +391 -0
  66. cudf_polars/experimental/parallel.py +423 -0
  67. cudf_polars/experimental/repartition.py +69 -0
  68. cudf_polars/experimental/scheduler.py +155 -0
  69. cudf_polars/experimental/select.py +188 -0
  70. cudf_polars/experimental/shuffle.py +354 -0
  71. cudf_polars/experimental/sort.py +609 -0
  72. cudf_polars/experimental/spilling.py +151 -0
  73. cudf_polars/experimental/statistics.py +795 -0
  74. cudf_polars/experimental/utils.py +169 -0
  75. cudf_polars/py.typed +0 -0
  76. cudf_polars/testing/__init__.py +8 -0
  77. cudf_polars/testing/asserts.py +448 -0
  78. cudf_polars/testing/io.py +122 -0
  79. cudf_polars/testing/plugin.py +236 -0
  80. cudf_polars/typing/__init__.py +219 -0
  81. cudf_polars/utils/__init__.py +8 -0
  82. cudf_polars/utils/config.py +741 -0
  83. cudf_polars/utils/conversion.py +40 -0
  84. cudf_polars/utils/dtypes.py +118 -0
  85. cudf_polars/utils/sorting.py +53 -0
  86. cudf_polars/utils/timer.py +39 -0
  87. cudf_polars/utils/versions.py +27 -0
  88. cudf_polars_cu13-25.10.0.dist-info/METADATA +136 -0
  89. cudf_polars_cu13-25.10.0.dist-info/RECORD +92 -0
  90. cudf_polars_cu13-25.10.0.dist-info/WHEEL +5 -0
  91. cudf_polars_cu13-25.10.0.dist-info/licenses/LICENSE +201 -0
  92. cudf_polars_cu13-25.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1002 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ # TODO: Document StringFunction to remove noqa
4
+ # ruff: noqa: D101
5
+ """DSL nodes for string operations."""
6
+
7
+ from __future__ import annotations
8
+
9
+ import functools
10
+ import re
11
+ from datetime import datetime
12
+ from enum import IntEnum, auto
13
+ from typing import TYPE_CHECKING, Any, ClassVar
14
+
15
+ from polars.exceptions import InvalidOperationError
16
+ from polars.polars import dtype_str_repr
17
+
18
+ import pylibcudf as plc
19
+
20
+ from cudf_polars.containers import Column
21
+ from cudf_polars.dsl.expressions.base import ExecutionContext, Expr
22
+ from cudf_polars.dsl.expressions.literal import Literal, LiteralColumn
23
+ from cudf_polars.dsl.utils.reshape import broadcast
24
+ from cudf_polars.utils.versions import POLARS_VERSION_LT_132
25
+
26
+ if TYPE_CHECKING:
27
+ from typing_extensions import Self
28
+
29
+ from polars.polars import _expr_nodes as pl_expr
30
+
31
+ from cudf_polars.containers import DataFrame, DataType
32
+
33
+ __all__ = ["StringFunction"]
34
+
35
+ JsonDecodeType = list[tuple[str, plc.DataType, "JsonDecodeType"]]
36
+
37
+
38
+ def _dtypes_for_json_decode(dtype: DataType) -> JsonDecodeType:
39
+ """Get the dtypes for json decode."""
40
+ if dtype.id() == plc.TypeId.STRUCT:
41
+ return [
42
+ (field.name, child.plc, _dtypes_for_json_decode(child))
43
+ for field, child in zip(dtype.polars.fields, dtype.children, strict=True)
44
+ ]
45
+ else:
46
+ return []
47
+
48
+
49
+ class StringFunction(Expr):
50
+ class Name(IntEnum):
51
+ """Internal and picklable representation of polars' `StringFunction`."""
52
+
53
+ Base64Decode = auto()
54
+ Base64Encode = auto()
55
+ ConcatHorizontal = auto()
56
+ ConcatVertical = auto()
57
+ Contains = auto()
58
+ ContainsAny = auto()
59
+ CountMatches = auto()
60
+ EndsWith = auto()
61
+ EscapeRegex = auto()
62
+ Extract = auto()
63
+ ExtractAll = auto()
64
+ ExtractGroups = auto()
65
+ Find = auto()
66
+ Head = auto()
67
+ HexDecode = auto()
68
+ HexEncode = auto()
69
+ JsonDecode = auto()
70
+ JsonPathMatch = auto()
71
+ LenBytes = auto()
72
+ LenChars = auto()
73
+ Lowercase = auto()
74
+ Normalize = auto()
75
+ PadEnd = auto()
76
+ PadStart = auto()
77
+ Replace = auto()
78
+ ReplaceMany = auto()
79
+ Reverse = auto()
80
+ Slice = auto()
81
+ Split = auto()
82
+ SplitExact = auto()
83
+ SplitN = auto()
84
+ StartsWith = auto()
85
+ StripChars = auto()
86
+ StripCharsEnd = auto()
87
+ StripCharsStart = auto()
88
+ StripPrefix = auto()
89
+ StripSuffix = auto()
90
+ Strptime = auto()
91
+ Tail = auto()
92
+ Titlecase = auto()
93
+ ToDecimal = auto()
94
+ ToInteger = auto()
95
+ Uppercase = auto()
96
+ ZFill = auto()
97
+
98
+ @classmethod
99
+ def from_polars(cls, obj: pl_expr.StringFunction) -> Self:
100
+ """Convert from polars' `StringFunction`."""
101
+ try:
102
+ function, name = str(obj).split(".", maxsplit=1)
103
+ except ValueError:
104
+ # Failed to unpack string
105
+ function = None
106
+ if function != "StringFunction":
107
+ raise ValueError("StringFunction required")
108
+ return getattr(cls, name)
109
+
110
+ _valid_ops: ClassVar[set[Name]] = {
111
+ Name.ConcatHorizontal,
112
+ Name.ConcatVertical,
113
+ Name.ContainsAny,
114
+ Name.Contains,
115
+ Name.CountMatches,
116
+ Name.EndsWith,
117
+ Name.Extract,
118
+ Name.ExtractGroups,
119
+ Name.Find,
120
+ Name.Head,
121
+ Name.JsonDecode,
122
+ Name.JsonPathMatch,
123
+ Name.LenBytes,
124
+ Name.LenChars,
125
+ Name.Lowercase,
126
+ Name.PadEnd,
127
+ Name.PadStart,
128
+ Name.Replace,
129
+ Name.ReplaceMany,
130
+ Name.Slice,
131
+ Name.SplitN,
132
+ Name.SplitExact,
133
+ Name.Strptime,
134
+ Name.StartsWith,
135
+ Name.StripChars,
136
+ Name.StripCharsStart,
137
+ Name.StripCharsEnd,
138
+ Name.StripPrefix,
139
+ Name.StripSuffix,
140
+ Name.Uppercase,
141
+ Name.Reverse,
142
+ Name.Tail,
143
+ Name.Titlecase,
144
+ Name.ZFill,
145
+ }
146
+ __slots__ = ("_regex_program", "name", "options")
147
+ _non_child = ("dtype", "name", "options")
148
+
149
+ def __init__(
150
+ self,
151
+ dtype: DataType,
152
+ name: StringFunction.Name,
153
+ options: tuple[Any, ...],
154
+ *children: Expr,
155
+ ) -> None:
156
+ self.dtype = dtype
157
+ self.options = options
158
+ self.name = name
159
+ self.children = children
160
+ self.is_pointwise = self.name != StringFunction.Name.ConcatVertical
161
+ self._validate_input()
162
+
163
+ def _validate_input(self) -> None:
164
+ if self.name not in self._valid_ops:
165
+ raise NotImplementedError(f"String function {self.name!r}")
166
+ if self.name is StringFunction.Name.CountMatches:
167
+ (literal,) = self.options
168
+ if literal:
169
+ raise NotImplementedError(
170
+ f"{literal=} is not supported for count_matches"
171
+ )
172
+ literal_expr = self.children[1]
173
+ assert isinstance(literal_expr, Literal)
174
+ pattern = literal_expr.value
175
+ self._regex_program = self._create_regex_program(pattern)
176
+ elif self.name is StringFunction.Name.Contains:
177
+ literal, strict = self.options
178
+ if not literal:
179
+ if not strict:
180
+ raise NotImplementedError(
181
+ f"{strict=} is not supported for regex contains"
182
+ )
183
+ if not isinstance(self.children[1], Literal):
184
+ raise NotImplementedError(
185
+ "Regex contains only supports a scalar pattern"
186
+ )
187
+ pattern = self.children[1].value
188
+ self._regex_program = self._create_regex_program(pattern)
189
+ elif self.name is StringFunction.Name.Extract:
190
+ (group_index,) = self.options
191
+ if group_index == 0:
192
+ raise NotImplementedError(f"{group_index=} is not supported")
193
+ literal_expr = self.children[1]
194
+ assert isinstance(literal_expr, Literal)
195
+ pattern = literal_expr.value
196
+ self._regex_program = self._create_regex_program(pattern)
197
+ elif self.name is StringFunction.Name.ExtractGroups:
198
+ (_, pattern) = self.options
199
+ self._regex_program = self._create_regex_program(pattern)
200
+ elif self.name is StringFunction.Name.Find:
201
+ literal, strict = self.options
202
+ if not literal:
203
+ if not strict:
204
+ raise NotImplementedError(
205
+ f"{strict=} is not supported for regex contains"
206
+ )
207
+ if not isinstance(self.children[1], Literal):
208
+ raise NotImplementedError(
209
+ "Regex contains only supports a scalar pattern"
210
+ )
211
+ pattern = self.children[1].value
212
+ self._regex_program = self._create_regex_program(pattern)
213
+ elif self.name is StringFunction.Name.Replace:
214
+ _, literal = self.options
215
+ if not literal:
216
+ raise NotImplementedError("literal=False is not supported for replace")
217
+ if not all(isinstance(expr, Literal) for expr in self.children[1:]):
218
+ raise NotImplementedError("replace only supports scalar target")
219
+ target = self.children[1]
220
+ # Above, we raise NotImplementedError if the target is not a Literal,
221
+ # so we can safely access .value here.
222
+ if target.value == "": # type: ignore[attr-defined]
223
+ raise NotImplementedError(
224
+ "libcudf replace does not support empty strings"
225
+ )
226
+ elif self.name is StringFunction.Name.ReplaceMany:
227
+ (ascii_case_insensitive,) = self.options
228
+ if ascii_case_insensitive:
229
+ raise NotImplementedError(
230
+ "ascii_case_insensitive not implemented for replace_many"
231
+ )
232
+ if not all(
233
+ isinstance(expr, (LiteralColumn, Literal)) for expr in self.children[1:]
234
+ ):
235
+ raise NotImplementedError("replace_many only supports literal inputs")
236
+ target = self.children[1]
237
+ # Above, we raise NotImplementedError if the target is not a Literal,
238
+ # so we can safely access .value here.
239
+ if (isinstance(target, Literal) and target.value == "") or (
240
+ isinstance(target, LiteralColumn) and (target.value == "").any()
241
+ ):
242
+ raise NotImplementedError(
243
+ "libcudf replace_many is implemented differently from polars "
244
+ "for empty strings"
245
+ )
246
+ elif self.name is StringFunction.Name.Slice:
247
+ if not all(isinstance(child, Literal) for child in self.children[1:]):
248
+ raise NotImplementedError(
249
+ "Slice only supports literal start and stop values"
250
+ )
251
+ elif self.name is StringFunction.Name.SplitExact:
252
+ (_, inclusive) = self.options
253
+ if inclusive:
254
+ raise NotImplementedError(f"{inclusive=} is not supported for split")
255
+ elif self.name is StringFunction.Name.Strptime:
256
+ format, strict, exact, cache = self.options
257
+ if not format and not strict:
258
+ raise NotImplementedError("format inference requires strict checking")
259
+ if cache:
260
+ raise NotImplementedError("Strptime cache is a CPU feature")
261
+ if not exact:
262
+ raise NotImplementedError("Strptime does not support exact=False")
263
+ elif self.name in {
264
+ StringFunction.Name.StripChars,
265
+ StringFunction.Name.StripCharsStart,
266
+ StringFunction.Name.StripCharsEnd,
267
+ }:
268
+ if not isinstance(self.children[1], Literal):
269
+ raise NotImplementedError(
270
+ "strip operations only support scalar patterns"
271
+ )
272
+ elif self.name is StringFunction.Name.ZFill:
273
+ if isinstance(self.children[1], Literal):
274
+ _, width = self.children
275
+ assert isinstance(width, Literal)
276
+ if (
277
+ POLARS_VERSION_LT_132
278
+ and width.value is not None
279
+ and width.value < 0
280
+ ): # pragma: no cover
281
+ dtypestr = dtype_str_repr(width.dtype.polars)
282
+ raise InvalidOperationError(
283
+ f"conversion from `{dtypestr}` to `u64` "
284
+ f"failed in column 'literal' for 1 out of "
285
+ f"1 values: [{width.value}]"
286
+ ) from None
287
+
288
+ @staticmethod
289
+ def _create_regex_program(
290
+ pattern: str,
291
+ flags: plc.strings.regex_flags.RegexFlags = plc.strings.regex_flags.RegexFlags.DEFAULT,
292
+ ) -> plc.strings.regex_program.RegexProgram:
293
+ if pattern == "":
294
+ raise NotImplementedError("Empty regex pattern is not yet supported")
295
+ try:
296
+ return plc.strings.regex_program.RegexProgram.create(
297
+ pattern,
298
+ flags=flags,
299
+ )
300
+ except RuntimeError as e:
301
+ raise NotImplementedError(
302
+ f"Unsupported regex {pattern} for GPU engine."
303
+ ) from e
304
+
305
+ def do_evaluate(
306
+ self, df: DataFrame, *, context: ExecutionContext = ExecutionContext.FRAME
307
+ ) -> Column:
308
+ """Evaluate this expression given a dataframe for context."""
309
+ if self.name is StringFunction.Name.ConcatHorizontal:
310
+ columns = [
311
+ Column(
312
+ child.evaluate(df, context=context).obj, dtype=child.dtype
313
+ ).astype(self.dtype)
314
+ for child in self.children
315
+ ]
316
+
317
+ non_unit_sizes = [c.size for c in columns if c.size != 1]
318
+ broadcasted = broadcast(
319
+ *columns,
320
+ target_length=max(non_unit_sizes) if non_unit_sizes else None,
321
+ )
322
+
323
+ delimiter, ignore_nulls = self.options
324
+
325
+ return Column(
326
+ plc.strings.combine.concatenate(
327
+ plc.Table([col.obj for col in broadcasted]),
328
+ plc.Scalar.from_py(delimiter, self.dtype.plc),
329
+ None if ignore_nulls else plc.Scalar.from_py(None, self.dtype.plc),
330
+ None,
331
+ plc.strings.combine.SeparatorOnNulls.NO,
332
+ ),
333
+ dtype=self.dtype,
334
+ )
335
+ elif self.name is StringFunction.Name.ConcatVertical:
336
+ (child,) = self.children
337
+ column = child.evaluate(df, context=context).astype(self.dtype)
338
+ delimiter, ignore_nulls = self.options
339
+ if column.null_count > 0 and not ignore_nulls:
340
+ return Column(plc.Column.all_null_like(column.obj, 1), dtype=self.dtype)
341
+ return Column(
342
+ plc.strings.combine.join_strings(
343
+ column.obj,
344
+ plc.Scalar.from_py(delimiter, self.dtype.plc),
345
+ plc.Scalar.from_py(None, self.dtype.plc),
346
+ ),
347
+ dtype=self.dtype,
348
+ )
349
+ elif self.name is StringFunction.Name.ZFill:
350
+ # TODO: expensive validation
351
+ # polars pads based on bytes, libcudf by visual width
352
+ # only pass chars if the visual width matches the byte length
353
+ column = self.children[0].evaluate(df, context=context)
354
+ col_len_bytes = plc.strings.attributes.count_bytes(column.obj)
355
+ col_len_chars = plc.strings.attributes.count_characters(column.obj)
356
+ equal = plc.binaryop.binary_operation(
357
+ col_len_bytes,
358
+ col_len_chars,
359
+ plc.binaryop.BinaryOperator.NULL_EQUALS,
360
+ plc.DataType(plc.TypeId.BOOL8),
361
+ )
362
+ if not plc.reduce.reduce(
363
+ equal,
364
+ plc.aggregation.all(),
365
+ plc.DataType(plc.TypeId.BOOL8),
366
+ ).to_py():
367
+ raise InvalidOperationError(
368
+ "zfill only supports ascii strings with no unicode characters"
369
+ )
370
+ if isinstance(self.children[1], Literal):
371
+ width = self.children[1]
372
+ assert isinstance(width, Literal)
373
+ if width.value is None:
374
+ return Column(
375
+ plc.Column.from_scalar(
376
+ plc.Scalar.from_py(None, self.dtype.plc),
377
+ column.size,
378
+ ),
379
+ self.dtype,
380
+ )
381
+ return Column(
382
+ plc.strings.padding.zfill(column.obj, width.value), self.dtype
383
+ )
384
+ else:
385
+ col_width = self.children[1].evaluate(df, context=context)
386
+ assert isinstance(col_width, Column)
387
+ all_gt_0 = plc.binaryop.binary_operation(
388
+ col_width.obj,
389
+ plc.Scalar.from_py(0, plc.DataType(plc.TypeId.INT64)),
390
+ plc.binaryop.BinaryOperator.GREATER_EQUAL,
391
+ plc.DataType(plc.TypeId.BOOL8),
392
+ )
393
+
394
+ if (
395
+ POLARS_VERSION_LT_132
396
+ and not plc.reduce.reduce(
397
+ all_gt_0,
398
+ plc.aggregation.all(),
399
+ plc.DataType(plc.TypeId.BOOL8),
400
+ ).to_py()
401
+ ): # pragma: no cover
402
+ raise InvalidOperationError("fill conversion failed.")
403
+
404
+ return Column(
405
+ plc.strings.padding.zfill_by_widths(column.obj, col_width.obj),
406
+ self.dtype,
407
+ )
408
+
409
+ elif self.name is StringFunction.Name.Contains:
410
+ child, arg = self.children
411
+ column = child.evaluate(df, context=context)
412
+
413
+ literal, _ = self.options
414
+ if literal:
415
+ pat = arg.evaluate(df, context=context)
416
+ pattern = (
417
+ pat.obj_scalar
418
+ if pat.is_scalar and pat.size != column.size
419
+ else pat.obj
420
+ )
421
+ return Column(
422
+ plc.strings.find.contains(column.obj, pattern), dtype=self.dtype
423
+ )
424
+ else:
425
+ return Column(
426
+ plc.strings.contains.contains_re(column.obj, self._regex_program),
427
+ dtype=self.dtype,
428
+ )
429
+ elif self.name is StringFunction.Name.ContainsAny:
430
+ (ascii_case_insensitive,) = self.options
431
+ child, arg = self.children
432
+ column = child.evaluate(df, context=context).obj
433
+ targets = arg.evaluate(df, context=context).obj
434
+ if ascii_case_insensitive:
435
+ column = plc.strings.case.to_lower(column)
436
+ targets = plc.strings.case.to_lower(targets)
437
+ contains = plc.strings.find_multiple.contains_multiple(
438
+ column,
439
+ targets,
440
+ )
441
+ binary_or = functools.partial(
442
+ plc.binaryop.binary_operation,
443
+ op=plc.binaryop.BinaryOperator.BITWISE_OR,
444
+ output_type=self.dtype.plc,
445
+ )
446
+ return Column(
447
+ functools.reduce(binary_or, contains.columns()),
448
+ dtype=self.dtype,
449
+ )
450
+ elif self.name is StringFunction.Name.CountMatches:
451
+ (child, _) = self.children
452
+ column = child.evaluate(df, context=context).obj
453
+ return Column(
454
+ plc.unary.cast(
455
+ plc.strings.contains.count_re(column, self._regex_program),
456
+ self.dtype.plc,
457
+ ),
458
+ dtype=self.dtype,
459
+ )
460
+ elif self.name is StringFunction.Name.Extract:
461
+ (group_index,) = self.options
462
+ column = self.children[0].evaluate(df, context=context).obj
463
+ return Column(
464
+ plc.strings.extract.extract_single(
465
+ column, self._regex_program, group_index - 1
466
+ ),
467
+ dtype=self.dtype,
468
+ )
469
+ elif self.name is StringFunction.Name.ExtractGroups:
470
+ column = self.children[0].evaluate(df, context=context).obj
471
+ plc_table = plc.strings.extract.extract(
472
+ column,
473
+ self._regex_program,
474
+ )
475
+ return Column(
476
+ plc.Column.struct_from_children(plc_table.columns()),
477
+ dtype=self.dtype,
478
+ )
479
+ elif self.name is StringFunction.Name.Find:
480
+ literal, _ = self.options
481
+ (child, expr) = self.children
482
+ column = child.evaluate(df, context=context).obj
483
+ if literal:
484
+ assert isinstance(expr, Literal)
485
+ plc_column = plc.strings.find.find(
486
+ column,
487
+ plc.Scalar.from_py(expr.value, expr.dtype.plc),
488
+ )
489
+ else:
490
+ plc_column = plc.strings.findall.find_re(
491
+ column,
492
+ self._regex_program,
493
+ )
494
+ # Polars returns None for not found, libcudf returns -1
495
+ new_mask, null_count = plc.transform.bools_to_mask(
496
+ plc.binaryop.binary_operation(
497
+ plc_column,
498
+ plc.Scalar.from_py(-1, plc_column.type()),
499
+ plc.binaryop.BinaryOperator.NOT_EQUAL,
500
+ plc.DataType(plc.TypeId.BOOL8),
501
+ )
502
+ )
503
+ plc_column = plc.unary.cast(
504
+ plc_column.with_mask(new_mask, null_count), self.dtype.plc
505
+ )
506
+ return Column(plc_column, dtype=self.dtype)
507
+ elif self.name is StringFunction.Name.JsonDecode:
508
+ plc_column = self.children[0].evaluate(df, context=context).obj
509
+ plc_table_with_metadata = plc.io.json.read_json_from_string_column(
510
+ plc_column,
511
+ plc.Scalar.from_py("\n"),
512
+ plc.Scalar.from_py("NULL"),
513
+ _dtypes_for_json_decode(self.dtype),
514
+ )
515
+ return Column(
516
+ plc.Column.struct_from_children(plc_table_with_metadata.columns),
517
+ dtype=self.dtype,
518
+ )
519
+ elif self.name is StringFunction.Name.JsonPathMatch:
520
+ (child, expr) = self.children
521
+ column = child.evaluate(df, context=context).obj
522
+ assert isinstance(expr, Literal)
523
+ json_path = plc.Scalar.from_py(expr.value, expr.dtype.plc)
524
+ return Column(
525
+ plc.json.get_json_object(column, json_path),
526
+ dtype=self.dtype,
527
+ )
528
+ elif self.name is StringFunction.Name.LenBytes:
529
+ column = self.children[0].evaluate(df, context=context).obj
530
+ return Column(
531
+ plc.unary.cast(
532
+ plc.strings.attributes.count_bytes(column), self.dtype.plc
533
+ ),
534
+ dtype=self.dtype,
535
+ )
536
+ elif self.name is StringFunction.Name.LenChars:
537
+ column = self.children[0].evaluate(df, context=context).obj
538
+ return Column(
539
+ plc.unary.cast(
540
+ plc.strings.attributes.count_characters(column), self.dtype.plc
541
+ ),
542
+ dtype=self.dtype,
543
+ )
544
+ elif self.name is StringFunction.Name.Slice:
545
+ child, expr_offset, expr_length = self.children
546
+ assert isinstance(expr_offset, Literal)
547
+ assert isinstance(expr_length, Literal)
548
+
549
+ column = child.evaluate(df, context=context)
550
+ # libcudf slices via [start,stop).
551
+ # polars slices with offset + length where start == offset
552
+ # stop = start + length. Negative values for start look backward
553
+ # from the last element of the string. If the end index would be
554
+ # below zero, an empty string is returned.
555
+ # Do this maths on the host
556
+ start = expr_offset.value
557
+ length = expr_length.value
558
+
559
+ if length == 0:
560
+ stop = start
561
+ else:
562
+ # No length indicates a scan to the end
563
+ # The libcudf equivalent is a null stop
564
+ stop = start + length if length else None
565
+ if length and start < 0 and length >= -start:
566
+ stop = None
567
+ return Column(
568
+ plc.strings.slice.slice_strings(
569
+ column.obj,
570
+ plc.Scalar.from_py(start, plc.DataType(plc.TypeId.INT32)),
571
+ plc.Scalar.from_py(stop, plc.DataType(plc.TypeId.INT32)),
572
+ ),
573
+ dtype=self.dtype,
574
+ )
575
+ elif self.name in {
576
+ StringFunction.Name.SplitExact,
577
+ StringFunction.Name.SplitN,
578
+ }:
579
+ is_split_n = self.name is StringFunction.Name.SplitN
580
+ n = self.options[0]
581
+ child, expr = self.children
582
+ column = child.evaluate(df, context=context)
583
+ if n == 1 and self.name is StringFunction.Name.SplitN:
584
+ plc_column = plc.Column(
585
+ self.dtype.plc,
586
+ column.obj.size(),
587
+ None,
588
+ None,
589
+ 0,
590
+ column.obj.offset(),
591
+ [column.obj],
592
+ )
593
+ else:
594
+ assert isinstance(expr, Literal)
595
+ by = plc.Scalar.from_py(expr.value, expr.dtype.plc)
596
+ # See https://github.com/pola-rs/polars/issues/11640
597
+ # for SplitN vs SplitExact edge case behaviors
598
+ max_splits = n if is_split_n else 0
599
+ plc_table = plc.strings.split.split.split(
600
+ column.obj,
601
+ by,
602
+ max_splits - 1,
603
+ )
604
+ children = plc_table.columns()
605
+ ref_column = children[0]
606
+ if (remainder := n - len(children)) > 0:
607
+ # Reach expected number of splits by padding with nulls
608
+ children.extend(
609
+ plc.Column.all_null_like(ref_column, ref_column.size())
610
+ for _ in range(remainder + int(not is_split_n))
611
+ )
612
+ if not is_split_n:
613
+ children = children[: n + 1]
614
+ # TODO: Use plc.Column.struct_from_children once it is generalized
615
+ # to handle columns that don't share the same null_mask/null_count
616
+ plc_column = plc.Column(
617
+ self.dtype.plc,
618
+ ref_column.size(),
619
+ None,
620
+ None,
621
+ 0,
622
+ ref_column.offset(),
623
+ children,
624
+ )
625
+ return Column(plc_column, dtype=self.dtype)
626
+ elif self.name in {
627
+ StringFunction.Name.StripPrefix,
628
+ StringFunction.Name.StripSuffix,
629
+ }:
630
+ child, expr = self.children
631
+ column = child.evaluate(df, context=context).obj
632
+ assert isinstance(expr, Literal)
633
+ target = plc.Scalar.from_py(expr.value, expr.dtype.plc)
634
+ if self.name == StringFunction.Name.StripPrefix:
635
+ find = plc.strings.find.starts_with
636
+ start = len(expr.value)
637
+ end: int | None = None
638
+ else:
639
+ find = plc.strings.find.ends_with
640
+ start = 0
641
+ end = -len(expr.value)
642
+
643
+ mask = find(column, target)
644
+ sliced = plc.strings.slice.slice_strings(
645
+ column,
646
+ plc.Scalar.from_py(start, plc.DataType(plc.TypeId.INT32)),
647
+ plc.Scalar.from_py(end, plc.DataType(plc.TypeId.INT32)),
648
+ )
649
+ return Column(
650
+ plc.copying.copy_if_else(
651
+ sliced,
652
+ column,
653
+ mask,
654
+ ),
655
+ dtype=self.dtype,
656
+ )
657
+ elif self.name in {
658
+ StringFunction.Name.StripChars,
659
+ StringFunction.Name.StripCharsStart,
660
+ StringFunction.Name.StripCharsEnd,
661
+ }:
662
+ column, chars = (c.evaluate(df, context=context) for c in self.children)
663
+ if self.name is StringFunction.Name.StripCharsStart:
664
+ side = plc.strings.SideType.LEFT
665
+ elif self.name is StringFunction.Name.StripCharsEnd:
666
+ side = plc.strings.SideType.RIGHT
667
+ else:
668
+ side = plc.strings.SideType.BOTH
669
+ return Column(
670
+ plc.strings.strip.strip(column.obj, side, chars.obj_scalar),
671
+ dtype=self.dtype,
672
+ )
673
+
674
+ elif self.name is StringFunction.Name.Tail:
675
+ column = self.children[0].evaluate(df, context=context)
676
+
677
+ assert isinstance(self.children[1], Literal)
678
+ if self.children[1].value is None:
679
+ return Column(
680
+ plc.Column.from_scalar(
681
+ plc.Scalar.from_py(None, self.dtype.plc),
682
+ column.size,
683
+ ),
684
+ self.dtype,
685
+ )
686
+ elif self.children[1].value == 0:
687
+ result = plc.Column.from_scalar(
688
+ plc.Scalar.from_py("", self.dtype.plc),
689
+ column.size,
690
+ )
691
+ if column.obj.null_mask():
692
+ result = result.with_mask(
693
+ column.obj.null_mask(), column.obj.null_count()
694
+ )
695
+ return Column(result, self.dtype)
696
+
697
+ else:
698
+ start = -(self.children[1].value)
699
+ end = 2**31 - 1
700
+ return Column(
701
+ plc.strings.slice.slice_strings(
702
+ column.obj,
703
+ plc.Scalar.from_py(start, plc.DataType(plc.TypeId.INT32)),
704
+ plc.Scalar.from_py(end, plc.DataType(plc.TypeId.INT32)),
705
+ None,
706
+ ),
707
+ self.dtype,
708
+ )
709
+ elif self.name is StringFunction.Name.Head:
710
+ column = self.children[0].evaluate(df, context=context)
711
+
712
+ assert isinstance(self.children[1], Literal)
713
+
714
+ end = self.children[1].value
715
+ if end is None:
716
+ return Column(
717
+ plc.Column.from_scalar(
718
+ plc.Scalar.from_py(None, self.dtype.plc),
719
+ column.size,
720
+ ),
721
+ self.dtype,
722
+ )
723
+ return Column(
724
+ plc.strings.slice.slice_strings(
725
+ column.obj,
726
+ plc.Scalar.from_py(0, plc.DataType(plc.TypeId.INT32)),
727
+ plc.Scalar.from_py(end, plc.DataType(plc.TypeId.INT32)),
728
+ ),
729
+ self.dtype,
730
+ )
731
+
732
+ columns = [child.evaluate(df, context=context) for child in self.children]
733
+ if self.name is StringFunction.Name.Lowercase:
734
+ (column,) = columns
735
+ return Column(plc.strings.case.to_lower(column.obj), dtype=self.dtype)
736
+ elif self.name is StringFunction.Name.Uppercase:
737
+ (column,) = columns
738
+ return Column(plc.strings.case.to_upper(column.obj), dtype=self.dtype)
739
+ elif self.name is StringFunction.Name.EndsWith:
740
+ column, suffix = columns
741
+ return Column(
742
+ plc.strings.find.ends_with(
743
+ column.obj,
744
+ suffix.obj_scalar
745
+ if column.size != suffix.size and suffix.is_scalar
746
+ else suffix.obj,
747
+ ),
748
+ dtype=self.dtype,
749
+ )
750
+ elif self.name is StringFunction.Name.StartsWith:
751
+ column, prefix = columns
752
+ return Column(
753
+ plc.strings.find.starts_with(
754
+ column.obj,
755
+ prefix.obj_scalar
756
+ if column.size != prefix.size and prefix.is_scalar
757
+ else prefix.obj,
758
+ ),
759
+ dtype=self.dtype,
760
+ )
761
+ elif self.name is StringFunction.Name.Strptime:
762
+ # TODO: ignores ambiguous
763
+ format, strict, _, _ = self.options
764
+ col = self.children[0].evaluate(df, context=context)
765
+ plc_col = col.obj
766
+ if plc_col.null_count() == plc_col.size():
767
+ return Column(
768
+ plc.Column.from_scalar(
769
+ plc.Scalar.from_py(None, self.dtype.plc),
770
+ plc_col.size(),
771
+ ),
772
+ self.dtype,
773
+ )
774
+ if format is None:
775
+ # Polars begins inference with the first non null value
776
+ if plc_col.null_mask() is not None:
777
+ boolmask = plc.unary.is_valid(plc_col)
778
+ table = plc.stream_compaction.apply_boolean_mask(
779
+ plc.Table([plc_col]), boolmask
780
+ )
781
+ filtered = table.columns()[0]
782
+ first_valid_data = plc.copying.get_element(filtered, 0).to_py()
783
+ else:
784
+ first_valid_data = plc.copying.get_element(plc_col, 0).to_py()
785
+
786
+ format = _infer_datetime_format(first_valid_data)
787
+ if not format:
788
+ raise InvalidOperationError(
789
+ "Unable to infer datetime format from data"
790
+ )
791
+
792
+ is_timestamps = plc.strings.convert.convert_datetime.is_timestamp(
793
+ plc_col, format
794
+ )
795
+ if strict:
796
+ if not plc.reduce.reduce(
797
+ is_timestamps,
798
+ plc.aggregation.all(),
799
+ plc.DataType(plc.TypeId.BOOL8),
800
+ ).to_py():
801
+ raise InvalidOperationError("conversion from `str` failed.")
802
+ else:
803
+ not_timestamps = plc.unary.unary_operation(
804
+ is_timestamps, plc.unary.UnaryOperator.NOT
805
+ )
806
+ null = plc.Scalar.from_py(None, plc_col.type())
807
+ plc_col = plc.copying.boolean_mask_scatter(
808
+ [null], plc.Table([plc_col]), not_timestamps
809
+ ).columns()[0]
810
+
811
+ return Column(
812
+ plc.strings.convert.convert_datetime.to_timestamps(
813
+ plc_col, self.dtype.plc, format
814
+ ),
815
+ dtype=self.dtype,
816
+ )
817
+ elif self.name is StringFunction.Name.Replace:
818
+ column, target, repl = columns
819
+ n, _ = self.options
820
+ return Column(
821
+ plc.strings.replace.replace(
822
+ column.obj, target.obj_scalar, repl.obj_scalar, maxrepl=n
823
+ ),
824
+ dtype=self.dtype,
825
+ )
826
+ elif self.name is StringFunction.Name.ReplaceMany:
827
+ column, target, repl = columns
828
+ return Column(
829
+ plc.strings.replace.replace_multiple(column.obj, target.obj, repl.obj),
830
+ dtype=self.dtype,
831
+ )
832
+ elif self.name is StringFunction.Name.PadStart:
833
+ if POLARS_VERSION_LT_132: # pragma: no cover
834
+ (column,) = columns
835
+ width, char = self.options
836
+ else:
837
+ (column, width_col) = columns
838
+ (char,) = self.options
839
+ # TODO: Maybe accept a string scalar in
840
+ # cudf::strings::pad to avoid DtoH transfer
841
+ width = width_col.obj.to_scalar().to_py()
842
+ return Column(
843
+ plc.strings.padding.pad(
844
+ column.obj, width, plc.strings.SideType.LEFT, char
845
+ ),
846
+ dtype=self.dtype,
847
+ )
848
+ elif self.name is StringFunction.Name.PadEnd:
849
+ if POLARS_VERSION_LT_132: # pragma: no cover
850
+ (column,) = columns
851
+ width, char = self.options
852
+ else:
853
+ (column, width_col) = columns
854
+ (char,) = self.options
855
+ # TODO: Maybe accept a string scalar in
856
+ # cudf::strings::pad to avoid DtoH transfer
857
+ width = width_col.obj.to_scalar().to_py()
858
+ return Column(
859
+ plc.strings.padding.pad(
860
+ column.obj, width, plc.strings.SideType.RIGHT, char
861
+ ),
862
+ dtype=self.dtype,
863
+ )
864
+ elif self.name is StringFunction.Name.Reverse:
865
+ (column,) = columns
866
+ return Column(plc.strings.reverse.reverse(column.obj), dtype=self.dtype)
867
+ elif self.name is StringFunction.Name.Titlecase:
868
+ (column,) = columns
869
+ return Column(plc.strings.capitalize.title(column.obj), dtype=self.dtype)
870
+ raise NotImplementedError(
871
+ f"StringFunction {self.name}"
872
+ ) # pragma: no cover; handled by init raising
873
+
874
+
875
+ def _infer_datetime_format(val: str) -> str | None:
876
+ # port of parts of infer.rs and patterns.rs from polars rust
877
+ DATETIME_DMY_RE = re.compile(
878
+ r"""
879
+ ^
880
+ ['"]?
881
+ (\d{1,2})
882
+ [-/\.]
883
+ (?P<month>[01]?\d{1})
884
+ [-/\.]
885
+ (\d{4,})
886
+ (
887
+ [T\ ]
888
+ (\d{1,2})
889
+ :?
890
+ (\d{1,2})
891
+ (
892
+ :?
893
+ (\d{1,2})
894
+ (
895
+ \.(\d{1,9})
896
+ )?
897
+ )?
898
+ )?
899
+ ['"]?
900
+ $
901
+ """,
902
+ re.VERBOSE,
903
+ )
904
+
905
+ DATETIME_YMD_RE = re.compile(
906
+ r"""
907
+ ^
908
+ ['"]?
909
+ (\d{4,})
910
+ [-/\.]
911
+ (?P<month>[01]?\d{1})
912
+ [-/\.]
913
+ (\d{1,2})
914
+ (
915
+ [T\ ]
916
+ (\d{1,2})
917
+ :?
918
+ (\d{1,2})
919
+ (
920
+ :?
921
+ (\d{1,2})
922
+ (
923
+ \.(\d{1,9})
924
+ )?
925
+ )?
926
+ )?
927
+ ['"]?
928
+ $
929
+ """,
930
+ re.VERBOSE,
931
+ )
932
+
933
+ DATETIME_YMDZ_RE = re.compile(
934
+ r"""
935
+ ^
936
+ ['"]?
937
+ (\d{4,})
938
+ [-/\.]
939
+ (?P<month>[01]?\d{1})
940
+ [-/\.]
941
+ (\d{1,2})
942
+ [T\ ]
943
+ (\d{2})
944
+ :?
945
+ (\d{2})
946
+ (
947
+ :?
948
+ (\d{2})
949
+ (
950
+ \.(\d{1,9})
951
+ )?
952
+ )?
953
+ (
954
+ [+-](\d{2})(:?(\d{2}))? | Z
955
+ )
956
+ ['"]?
957
+ $
958
+ """,
959
+ re.VERBOSE,
960
+ )
961
+ PATTERN_FORMATS = {
962
+ "DATETIME_DMY": [
963
+ "%d-%m-%Y",
964
+ "%d/%m/%Y",
965
+ "%d.%m.%Y",
966
+ "%d-%m-%Y %H:%M:%S",
967
+ "%d/%m/%Y %H:%M:%S",
968
+ "%d.%m.%Y %H:%M:%S",
969
+ ],
970
+ "DATETIME_YMD": [
971
+ "%Y/%m/%d",
972
+ "%Y-%m-%d",
973
+ "%Y.%m.%d",
974
+ "%Y-%m-%d %H:%M:%S",
975
+ "%Y/%m/%d %H:%M:%S",
976
+ "%Y.%m.%d %H:%M:%S",
977
+ "%Y-%m-%dT%H:%M:%S",
978
+ ],
979
+ "DATETIME_YMDZ": [
980
+ "%Y-%m-%dT%H:%M:%S%z",
981
+ "%Y-%m-%dT%H:%M:%S.%f%z",
982
+ "%Y-%m-%d %H:%M:%S%z",
983
+ ],
984
+ }
985
+ for pattern_name, regex in [
986
+ ("DATETIME_DMY", DATETIME_DMY_RE),
987
+ ("DATETIME_YMD", DATETIME_YMD_RE),
988
+ ("DATETIME_YMDZ", DATETIME_YMDZ_RE),
989
+ ]:
990
+ m = regex.match(val)
991
+ if m:
992
+ month = int(m.group("month"))
993
+ if not (1 <= month <= 12):
994
+ continue
995
+ for fmt in PATTERN_FORMATS[pattern_name]:
996
+ try:
997
+ datetime.strptime(val, fmt)
998
+ except ValueError: # noqa: PERF203
999
+ continue
1000
+ else:
1001
+ return fmt
1002
+ return None