cudf-polars-cu12 24.12.0__py3-none-any.whl → 25.2.0__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (37)
  1. cudf_polars/VERSION +1 -1
  2. cudf_polars/__init__.py +1 -1
  3. cudf_polars/callback.py +28 -3
  4. cudf_polars/containers/__init__.py +1 -1
  5. cudf_polars/dsl/expr.py +16 -16
  6. cudf_polars/dsl/expressions/aggregation.py +21 -4
  7. cudf_polars/dsl/expressions/base.py +7 -2
  8. cudf_polars/dsl/expressions/binaryop.py +1 -0
  9. cudf_polars/dsl/expressions/boolean.py +65 -22
  10. cudf_polars/dsl/expressions/datetime.py +82 -20
  11. cudf_polars/dsl/expressions/literal.py +2 -0
  12. cudf_polars/dsl/expressions/rolling.py +3 -1
  13. cudf_polars/dsl/expressions/selection.py +3 -1
  14. cudf_polars/dsl/expressions/sorting.py +2 -0
  15. cudf_polars/dsl/expressions/string.py +118 -39
  16. cudf_polars/dsl/expressions/ternary.py +1 -0
  17. cudf_polars/dsl/expressions/unary.py +11 -1
  18. cudf_polars/dsl/ir.py +173 -122
  19. cudf_polars/dsl/to_ast.py +4 -6
  20. cudf_polars/dsl/translate.py +53 -21
  21. cudf_polars/dsl/traversal.py +10 -10
  22. cudf_polars/experimental/base.py +43 -0
  23. cudf_polars/experimental/dispatch.py +84 -0
  24. cudf_polars/experimental/io.py +325 -0
  25. cudf_polars/experimental/parallel.py +253 -0
  26. cudf_polars/experimental/select.py +36 -0
  27. cudf_polars/testing/asserts.py +14 -5
  28. cudf_polars/testing/plugin.py +60 -4
  29. cudf_polars/typing/__init__.py +5 -5
  30. cudf_polars/utils/dtypes.py +9 -7
  31. cudf_polars/utils/versions.py +4 -7
  32. {cudf_polars_cu12-24.12.0.dist-info → cudf_polars_cu12-25.2.0.dist-info}/METADATA +6 -6
  33. cudf_polars_cu12-25.2.0.dist-info/RECORD +48 -0
  34. {cudf_polars_cu12-24.12.0.dist-info → cudf_polars_cu12-25.2.0.dist-info}/WHEEL +1 -1
  35. cudf_polars_cu12-24.12.0.dist-info/RECORD +0 -43
  36. {cudf_polars_cu12-24.12.0.dist-info → cudf_polars_cu12-25.2.0.dist-info}/LICENSE +0 -0
  37. {cudf_polars_cu12-24.12.0.dist-info → cudf_polars_cu12-25.2.0.dist-info}/top_level.txt +0 -0
cudf_polars/dsl/translate.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
 # SPDX-License-Identifier: Apache-2.0
 
 """Translate polars IR representation to ours."""
@@ -84,7 +84,7 @@ class Translator:
         # IR is versioned with major.minor, minor is bumped for backwards
         # compatible changes (e.g. adding new nodes), major is bumped for
         # incompatible changes (e.g. renaming nodes).
-        if (version := self.visitor.version()) >= (4, 0):
+        if (version := self.visitor.version()) >= (5, 1):
             e = NotImplementedError(
                 f"No support for polars IR {version=}"
             )  # pragma: no cover; no such version for now.
@@ -260,9 +260,7 @@ def _(
         schema,
         node.df,
         node.projection,
-        translate_named_expr(translator, n=node.selection)
-        if node.selection is not None
-        else None,
+        translator.config.config.copy(),
     )
 
 
@@ -301,25 +299,52 @@ def _(
     # Join key dtypes are dependent on the schema of the left and
     # right inputs, so these must be translated with the relevant
     # input active.
+    def adjust_literal_dtype(literal: expr.Literal) -> expr.Literal:
+        if literal.dtype.id() == plc.types.TypeId.INT32:
+            plc_int64 = plc.types.DataType(plc.types.TypeId.INT64)
+            return expr.Literal(
+                plc_int64,
+                pa.scalar(literal.value.as_py(), type=plc.interop.to_arrow(plc_int64)),
+            )
+        return literal
+
+    def maybe_adjust_binop(e) -> expr.Expr:
+        if isinstance(e.value, expr.BinOp):
+            left, right = e.value.children
+            if isinstance(left, expr.Col) and isinstance(right, expr.Literal):
+                e.value.children = (left, adjust_literal_dtype(right))
+            elif isinstance(left, expr.Literal) and isinstance(right, expr.Col):
+                e.value.children = (adjust_literal_dtype(left), right)
+        return e
+
+    def translate_expr_and_maybe_fix_binop_args(translator, exprs):
+        return [
+            maybe_adjust_binop(translate_named_expr(translator, n=e)) for e in exprs
+        ]
+
     with set_node(translator.visitor, node.input_left):
         inp_left = translator.translate_ir(n=None)
-        left_on = [translate_named_expr(translator, n=e) for e in node.left_on]
+        # TODO: There's bug in the polars type coercion phase. Use
+        # translate_named_expr directly once it is resolved.
+        # Tracking issue: https://github.com/pola-rs/polars/issues/20935
+        left_on = translate_expr_and_maybe_fix_binop_args(translator, node.left_on)
     with set_node(translator.visitor, node.input_right):
         inp_right = translator.translate_ir(n=None)
-        right_on = [translate_named_expr(translator, n=e) for e in node.right_on]
+        right_on = translate_expr_and_maybe_fix_binop_args(translator, node.right_on)
+
     if (how := node.options[0]) in {
-        "inner",
-        "left",
-        "right",
-        "full",
-        "cross",
-        "semi",
-        "anti",
+        "Inner",
+        "Left",
+        "Right",
+        "Full",
+        "Cross",
+        "Semi",
+        "Anti",
     }:
         return ir.Join(schema, left_on, right_on, node.options, inp_left, inp_right)
     else:
-        how, op1, op2 = how
-        if how != "ie_join":
+        how, op1, op2 = node.options[0]
+        if how != "IEJoin":
             raise NotImplementedError(
                 f"Unsupported join type {how}"
             )  # pragma: no cover; asof joins not yet exposed
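The new helpers above work around a polars type-coercion issue (pola-rs/polars#20935) by widening INT32 literals that appear next to a column in a binary-expression join key. A minimal sketch of a query that exercises this path, assuming a polars build with GPU engine support; the data and column names are illustrative:

import polars as pl

left = pl.LazyFrame({"a": [1, 2, 3]})   # Int64 column
right = pl.LazyFrame({"b": [2, 3, 4]})

# The join key mixes an Int64 column with an integer literal; polars may type
# the literal as Int32 in its IR, which the translator now widens to Int64
# before building the cudf-polars Join node.
q = left.join(right, left_on=pl.col("a") + 1, right_on="b")
print(q.collect(engine="gpu"))  # falls back to CPU execution if the GPU engine is unavailable
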
@@ -531,10 +556,16 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex
                         column.dtype,
                         pa.scalar("", type=plc.interop.to_arrow(column.dtype)),
                     )
-            return expr.StringFunction(dtype, name, options, column, chars)
+            return expr.StringFunction(
+                dtype,
+                expr.StringFunction.Name.from_polars(name),
+                options,
+                column,
+                chars,
+            )
         return expr.StringFunction(
             dtype,
-            name,
+            expr.StringFunction.Name.from_polars(name),
             options,
             *(translator.translate_expr(n=n) for n in node.input),
         )
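Several expression nodes (StringFunction, BooleanFunction, TemporalFunction) now convert the raw polars function name into a dedicated `Name` enum via a `from_polars` classmethod before constructing the expression. A rough sketch of that pattern, not the actual cudf-polars enum; the member names and error text are illustrative:

from enum import IntEnum, auto


class Name(IntEnum):
    # Hypothetical subset of supported string functions.
    Lowercase = auto()
    Uppercase = auto()
    StripChars = auto()

    @classmethod
    def from_polars(cls, name: str) -> "Name":
        # Map the polars-side function name onto our enum, failing loudly
        # for functions we do not support.
        try:
            return cls[name]
        except KeyError:
            raise ValueError(f"{name} is not a supported string function") from None


print(Name.from_polars("StripChars"))  # Name.StripChars
# Name.from_polars("Reverse")          # raises ValueError
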
@@ -551,7 +582,7 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex
             )
         return expr.BooleanFunction(
             dtype,
-            name,
+            expr.BooleanFunction.Name.from_polars(name),
             options,
             *(translator.translate_expr(n=n) for n in node.input),
         )
@@ -571,7 +602,7 @@ def _(node: pl_expr.Function, translator: Translator, dtype: plc.DataType) -> ex
         }
         result_expr = expr.TemporalFunction(
             dtype,
-            name,
+            expr.TemporalFunction.Name.from_polars(name),
             options,
             *(translator.translate_expr(n=n) for n in node.input),
         )
@@ -633,9 +664,10 @@ def _(node: pl_expr.Sort, translator: Translator, dtype: plc.DataType) -> expr.E
 
 
 @_translate_expr.register
 def _(node: pl_expr.SortBy, translator: Translator, dtype: plc.DataType) -> expr.Expr:
+    options = node.sort_options
     return expr.SortBy(
         dtype,
-        node.sort_options,
+        (options[0], tuple(options[1]), tuple(options[2])),
         translator.translate_expr(n=node.expr),
         *(translator.translate_expr(n=n) for n in node.by),
     )
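`SortBy` now normalizes `node.sort_options` so the inner lists become tuples, presumably to keep the expression node hashable (lists are not). A tiny illustration with a made-up options value; the exact shape of the polars sort options is an assumption here:

# Illustrative shape: (stable_flag, nulls_last_per_key, descending_per_key).
options = (True, [False, True], [True, False])

normalized = (options[0], tuple(options[1]), tuple(options[2]))
hash(normalized)  # fine: a tuple of bools and tuples is hashable
# hash(options)   # TypeError: unhashable type: 'list'
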
cudf_polars/dsl/traversal.py
@@ -10,35 +10,35 @@ from typing import TYPE_CHECKING, Any, Generic
 from cudf_polars.typing import U_contra, V_co
 
 if TYPE_CHECKING:
-    from collections.abc import Callable, Generator, Mapping, MutableMapping
+    from collections.abc import Callable, Generator, Mapping, MutableMapping, Sequence
 
     from cudf_polars.typing import GenericTransformer, NodeT
 
 
 __all__: list[str] = [
-    "traversal",
-    "reuse_if_unchanged",
-    "make_recursive",
     "CachingVisitor",
+    "make_recursive",
+    "reuse_if_unchanged",
+    "traversal",
 ]
 
 
-def traversal(node: NodeT) -> Generator[NodeT, None, None]:
+def traversal(nodes: Sequence[NodeT]) -> Generator[NodeT, None, None]:
     """
     Pre-order traversal of nodes in an expression.
 
     Parameters
     ----------
-    node
-        Root of expression to traverse.
+    nodes
+        Roots of expressions to traverse.
 
     Yields
     ------
-    Unique nodes in the expression, parent before child, children
+    Unique nodes in the expressions, parent before child, children
     in-order from left to right.
     """
-    seen = {node}
-    lifo = [node]
+    seen = set(nodes)
+    lifo = list(nodes)
 
     while lifo:
         node = lifo.pop()
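`traversal` now takes a sequence of root nodes instead of a single root. A self-contained sketch of the same algorithm (a seen-set and LIFO stack seeded with all roots), using a made-up `Node` class rather than the cudf-polars node types:

from collections.abc import Generator, Sequence


class Node:
    # Minimal stand-in node with a tuple of children (illustrative only).
    def __init__(self, name: str, *children: "Node") -> None:
        self.name = name
        self.children = children


def traversal(nodes: Sequence[Node]) -> Generator[Node, None, None]:
    # Seed the seen-set and LIFO stack with all roots, then pop, yield,
    # and push unseen children so shared subtrees are visited only once.
    seen = set(nodes)
    lifo = list(nodes)
    while lifo:
        node = lifo.pop()
        yield node
        for child in reversed(node.children):
            if child not in seen:
                seen.add(child)
                lifo.append(child)


shared = Node("shared")
roots = [Node("a", shared), Node("b", shared)]
print([n.name for n in traversal(roots)])  # "shared" appears exactly once
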
cudf_polars/experimental/base.py
@@ -0,0 +1,43 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+"""Multi-partition base classes."""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from cudf_polars.dsl.ir import Union
+
+if TYPE_CHECKING:
+    from collections.abc import Iterator, Sequence
+
+    from cudf_polars.containers import DataFrame
+    from cudf_polars.dsl.nodebase import Node
+
+
+class PartitionInfo:
+    """
+    Partitioning information.
+
+    This class only tracks the partition count (for now).
+    """
+
+    __slots__ = ("count",)
+
+    def __init__(self, count: int):
+        self.count = count
+
+    def keys(self, node: Node) -> Iterator[tuple[str, int]]:
+        """Return the partitioned keys for a given node."""
+        name = get_key_name(node)
+        yield from ((name, i) for i in range(self.count))
+
+
+def get_key_name(node: Node) -> str:
+    """Generate the key name for a Node."""
+    return f"{type(node).__name__.lower()}-{hash(node)}"
+
+
+def _concat(dfs: Sequence[DataFrame]) -> DataFrame:
+    # Concatenate a sequence of DataFrames vertically
+    return Union.do_evaluate(None, *dfs)
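`PartitionInfo.keys` and `get_key_name` define how task-graph keys are derived from an IR node: the lower-cased class name plus the node hash, paired with a partition index. A small sketch, assuming cudf-polars 25.2 is installed; `FakeScan` is a hypothetical stand-in for a real IR node:

from cudf_polars.experimental.base import PartitionInfo, get_key_name


class FakeScan:
    # Hypothetical stand-in: get_key_name only needs the type name and hash().
    pass


node = FakeScan()
pinfo = PartitionInfo(count=3)
print(get_key_name(node))      # e.g. "fakescan-8791027726235"
print(list(pinfo.keys(node)))  # [("fakescan-<hash>", 0), ("fakescan-<hash>", 1), ("fakescan-<hash>", 2)]
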
cudf_polars/experimental/dispatch.py
@@ -0,0 +1,84 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+"""Multi-partition dispatch functions."""
+
+from __future__ import annotations
+
+from functools import singledispatch
+from typing import TYPE_CHECKING, Any
+
+if TYPE_CHECKING:
+    from collections.abc import MutableMapping
+    from typing import TypeAlias
+
+    from cudf_polars.dsl.ir import IR
+    from cudf_polars.experimental.base import PartitionInfo
+    from cudf_polars.typing import GenericTransformer
+
+
+LowerIRTransformer: TypeAlias = (
+    "GenericTransformer[IR, tuple[IR, MutableMapping[IR, PartitionInfo]]]"
+)
+"""Protocol for Lowering IR nodes."""
+
+
+@singledispatch
+def lower_ir_node(
+    ir: IR, rec: LowerIRTransformer
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
+    """
+    Rewrite an IR node and extract partitioning information.
+
+    Parameters
+    ----------
+    ir
+        IR node to rewrite.
+    rec
+        Recursive LowerIRTransformer callable.
+
+    Returns
+    -------
+    new_ir, partition_info
+        The rewritten node, and a mapping from unique nodes in
+        the full IR graph to associated partitioning information.
+
+    Notes
+    -----
+    This function is used by `lower_ir_graph`.
+
+    See Also
+    --------
+    lower_ir_graph
+    """
+    raise AssertionError(f"Unhandled type {type(ir)}")  # pragma: no cover
+
+
+@singledispatch
+def generate_ir_tasks(
+    ir: IR, partition_info: MutableMapping[IR, PartitionInfo]
+) -> MutableMapping[Any, Any]:
+    """
+    Generate a task graph for evaluation of an IR node.
+
+    Parameters
+    ----------
+    ir
+        IR node to generate tasks for.
+    partition_info
+        Partitioning information, obtained from :func:`lower_ir_graph`.
+
+    Returns
+    -------
+    mapping
+        A (partial) dask task graph for the evaluation of an ir node.
+
+    Notes
+    -----
+    Task generation should only produce the tasks for the current node,
+    referring to child tasks by name.
+
+    See Also
+    --------
+    task_graph
+    """
+    raise AssertionError(f"Unhandled type {type(ir)}")  # pragma: no cover
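New IR node types participate in multi-partition execution by registering implementations of these two singledispatch hooks, as `cudf_polars/experimental/io.py` does below for `DataFrameScan` and `Scan`. A minimal sketch of the registration pattern, assuming cudf-polars 25.2 is installed; `FakeLeaf` and its single constant task are hypothetical:

from cudf_polars.experimental.base import PartitionInfo
from cudf_polars.experimental.dispatch import generate_ir_tasks, lower_ir_node


class FakeLeaf:
    # Hypothetical leaf "node": no children, evaluates to a constant.
    _non_child_args = (42,)

    @staticmethod
    def do_evaluate(value):
        return value


@lower_ir_node.register(FakeLeaf)
def _(ir, rec):
    # Nothing to rewrite; a leaf always produces a single partition.
    return ir, {ir: PartitionInfo(count=1)}


@generate_ir_tasks.register(FakeLeaf)
def _(ir, partition_info):
    # One dask-style task per partition, keyed by (name, partition_index).
    return {
        key: (ir.do_evaluate, *ir._non_child_args)
        for key in partition_info[ir].keys(ir)
    }


leaf = FakeLeaf()
lowered, info = lower_ir_node(leaf, rec=None)
print(generate_ir_tasks(lowered, info))  # {("fakeleaf-<hash>", 0): (<do_evaluate>, 42)}
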
cudf_polars/experimental/io.py
@@ -0,0 +1,325 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
+# SPDX-License-Identifier: Apache-2.0
+"""Multi-partition IO Logic."""
+
+from __future__ import annotations
+
+import enum
+import math
+import random
+from enum import IntEnum
+from typing import TYPE_CHECKING, Any
+
+import pylibcudf as plc
+
+from cudf_polars.dsl.ir import IR, DataFrameScan, Scan, Union
+from cudf_polars.experimental.base import PartitionInfo
+from cudf_polars.experimental.dispatch import lower_ir_node
+
+if TYPE_CHECKING:
+    from collections.abc import MutableMapping
+
+    from cudf_polars.dsl.expr import NamedExpr
+    from cudf_polars.experimental.dispatch import LowerIRTransformer
+    from cudf_polars.typing import Schema
+
+
+@lower_ir_node.register(DataFrameScan)
+def _(
+    ir: DataFrameScan, rec: LowerIRTransformer
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
+    rows_per_partition = ir.config_options.get("executor_options", {}).get(
+        "max_rows_per_partition", 1_000_000
+    )
+
+    nrows = max(ir.df.shape()[0], 1)
+    count = math.ceil(nrows / rows_per_partition)
+
+    if count > 1:
+        length = math.ceil(nrows / count)
+        slices = [
+            DataFrameScan(
+                ir.schema,
+                ir.df.slice(offset, length),
+                ir.projection,
+                ir.config_options,
+            )
+            for offset in range(0, nrows, length)
+        ]
+        new_node = Union(ir.schema, None, *slices)
+        return new_node, {slice: PartitionInfo(count=1) for slice in slices} | {
+            new_node: PartitionInfo(count=count)
+        }
+
+    return ir, {ir: PartitionInfo(count=1)}
+
+
+class ScanPartitionFlavor(IntEnum):
+    """Flavor of Scan partitioning."""
+
+    SINGLE_FILE = enum.auto()  # 1:1 mapping between files and partitions
+    SPLIT_FILES = enum.auto()  # Split each file into >1 partition
+    FUSED_FILES = enum.auto()  # Fuse multiple files into each partition
+
+
+class ScanPartitionPlan:
+    """
+    Scan partitioning plan.
+
+    Notes
+    -----
+    The meaning of `factor` depends on the value of `flavor`:
+      - SINGLE_FILE: `factor` must be `1`.
+      - SPLIT_FILES: `factor` is the number of partitions per file.
+      - FUSED_FILES: `factor` is the number of files per partition.
+    """
+
+    __slots__ = ("factor", "flavor")
+    factor: int
+    flavor: ScanPartitionFlavor
+
+    def __init__(self, factor: int, flavor: ScanPartitionFlavor) -> None:
+        if (
+            flavor == ScanPartitionFlavor.SINGLE_FILE and factor != 1
+        ):  # pragma: no cover
+            raise ValueError(f"Expected factor == 1 for {flavor}, got: {factor}")
+        self.factor = factor
+        self.flavor = flavor
+
+    @staticmethod
+    def from_scan(ir: Scan) -> ScanPartitionPlan:
+        """Extract the partitioning plan of a Scan operation."""
+        if ir.typ == "parquet":
+            # TODO: Use system info to set default blocksize
+            parallel_options = ir.config_options.get("executor_options", {})
+            blocksize: int = parallel_options.get("parquet_blocksize", 1024**3)
+            stats = _sample_pq_statistics(ir)
+            file_size = sum(float(stats[column]) for column in ir.schema)
+            if file_size > 0:
+                if file_size > blocksize:
+                    # Split large files
+                    return ScanPartitionPlan(
+                        math.ceil(file_size / blocksize),
+                        ScanPartitionFlavor.SPLIT_FILES,
+                    )
+                else:
+                    # Fuse small files
+                    return ScanPartitionPlan(
+                        max(blocksize // int(file_size), 1),
+                        ScanPartitionFlavor.FUSED_FILES,
+                    )
+
+        # TODO: Use file sizes for csv and json
+        return ScanPartitionPlan(1, ScanPartitionFlavor.SINGLE_FILE)
+
+
+class SplitScan(IR):
+    """
+    Input from a split file.
+
+    This class wraps a single-file `Scan` object. At
+    IO/evaluation time, this class will only perform
+    a partial read of the underlying file. The range
+    (skip_rows and n_rows) is calculated at IO time.
+    """
+
+    __slots__ = (
+        "base_scan",
+        "schema",
+        "split_index",
+        "total_splits",
+    )
+    _non_child = (
+        "schema",
+        "base_scan",
+        "split_index",
+        "total_splits",
+    )
+    base_scan: Scan
+    """Scan operation this node is based on."""
+    split_index: int
+    """Index of the current split."""
+    total_splits: int
+    """Total number of splits."""
+
+    def __init__(
+        self, schema: Schema, base_scan: Scan, split_index: int, total_splits: int
+    ):
+        self.schema = schema
+        self.base_scan = base_scan
+        self.split_index = split_index
+        self.total_splits = total_splits
+        self._non_child_args = (
+            split_index,
+            total_splits,
+            *base_scan._non_child_args,
+        )
+        self.children = ()
+        if base_scan.typ not in ("parquet",):  # pragma: no cover
+            raise NotImplementedError(
+                f"Unhandled Scan type for file splitting: {base_scan.typ}"
+            )
+
+    @classmethod
+    def do_evaluate(
+        cls,
+        split_index: int,
+        total_splits: int,
+        schema: Schema,
+        typ: str,
+        reader_options: dict[str, Any],
+        config_options: dict[str, Any],
+        paths: list[str],
+        with_columns: list[str] | None,
+        skip_rows: int,
+        n_rows: int,
+        row_index: tuple[str, int] | None,
+        predicate: NamedExpr | None,
+    ):
+        """Evaluate and return a dataframe."""
+        if typ not in ("parquet",):  # pragma: no cover
+            raise NotImplementedError(f"Unhandled Scan type for file splitting: {typ}")
+
+        if len(paths) > 1:  # pragma: no cover
+            raise ValueError(f"Expected a single path, got: {paths}")
+
+        # Parquet logic:
+        # - We are one of "total_splits" SplitScan nodes
+        #   assigned to the same file.
+        # - We know our index within this file ("split_index")
+        # - We can also use parquet metadata to query the
+        #   total number of rows in each row-group of the file.
+        # - We can use all this information to calculate the
+        #   "skip_rows" and "n_rows" options to use locally.
+
+        rowgroup_metadata = plc.io.parquet_metadata.read_parquet_metadata(
+            plc.io.SourceInfo(paths)
+        ).rowgroup_metadata()
+        total_row_groups = len(rowgroup_metadata)
+        if total_splits <= total_row_groups:
+            # We have enough row-groups in the file to align
+            # all "total_splits" of our reads with row-group
+            # boundaries. Calculate which row-groups to include
+            # in the current read, and use metadata to translate
+            # the row-group indices to "skip_rows" and "n_rows".
+            rg_stride = total_row_groups // total_splits
+            skip_rgs = rg_stride * split_index
+            skip_rows = sum(rg["num_rows"] for rg in rowgroup_metadata[:skip_rgs])
+            n_rows = sum(
+                rg["num_rows"]
+                for rg in rowgroup_metadata[skip_rgs : skip_rgs + rg_stride]
+            )
+        else:
+            # There are not enough row-groups to align
+            # all "total_splits" of our reads with row-group
+            # boundaries. Use metadata to directly calculate
+            # "skip_rows" and "n_rows" for the current read.
+            total_rows = sum(rg["num_rows"] for rg in rowgroup_metadata)
+            n_rows = total_rows // total_splits
+            skip_rows = n_rows * split_index
+
+        # Last split should always read to end of file
+        if split_index == (total_splits - 1):
+            n_rows = -1
+
+        # Perform the partial read
+        return Scan.do_evaluate(
+            schema,
+            typ,
+            reader_options,
+            config_options,
+            paths,
+            with_columns,
+            skip_rows,
+            n_rows,
+            row_index,
+            predicate,
+        )
+
+
+def _sample_pq_statistics(ir: Scan) -> dict[str, float]:
+    import numpy as np
+    import pyarrow.dataset as pa_ds
+
+    # Use average total_uncompressed_size of three files
+    # TODO: Use plc.io.parquet_metadata.read_parquet_metadata
+    n_sample = 3
+    column_sizes = {}
+    ds = pa_ds.dataset(random.sample(ir.paths, n_sample), format="parquet")
+    for i, frag in enumerate(ds.get_fragments()):
+        md = frag.metadata
+        for rg in range(md.num_row_groups):
+            row_group = md.row_group(rg)
+            for col in range(row_group.num_columns):
+                column = row_group.column(col)
+                name = column.path_in_schema
+                if name not in column_sizes:
+                    column_sizes[name] = np.zeros(n_sample, dtype="int64")
+                column_sizes[name][i] += column.total_uncompressed_size
+
+    return {name: np.mean(sizes) for name, sizes in column_sizes.items()}
+
+
+@lower_ir_node.register(Scan)
+def _(
+    ir: Scan, rec: LowerIRTransformer
+) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
+    partition_info: MutableMapping[IR, PartitionInfo]
+    if ir.typ in ("csv", "parquet", "ndjson") and ir.n_rows == -1 and ir.skip_rows == 0:
+        plan = ScanPartitionPlan.from_scan(ir)
+        paths = list(ir.paths)
+        if plan.flavor == ScanPartitionFlavor.SPLIT_FILES:
+            # Disable chunked reader when splitting files
+            config_options = ir.config_options.copy()
+            config_options["parquet_options"] = config_options.get(
+                "parquet_options", {}
+            ).copy()
+            config_options["parquet_options"]["chunked"] = False
+
+            slices: list[SplitScan] = []
+            for path in paths:
+                base_scan = Scan(
+                    ir.schema,
+                    ir.typ,
+                    ir.reader_options,
+                    ir.cloud_options,
+                    config_options,
+                    [path],
+                    ir.with_columns,
+                    ir.skip_rows,
+                    ir.n_rows,
+                    ir.row_index,
+                    ir.predicate,
+                )
+                slices.extend(
+                    SplitScan(ir.schema, base_scan, sindex, plan.factor)
+                    for sindex in range(plan.factor)
+                )
+            new_node = Union(ir.schema, None, *slices)
+            partition_info = {slice: PartitionInfo(count=1) for slice in slices} | {
+                new_node: PartitionInfo(count=len(slices))
+            }
+        else:
+            groups: list[Scan] = [
+                Scan(
+                    ir.schema,
+                    ir.typ,
+                    ir.reader_options,
+                    ir.cloud_options,
+                    ir.config_options,
+                    paths[i : i + plan.factor],
+                    ir.with_columns,
+                    ir.skip_rows,
+                    ir.n_rows,
+                    ir.row_index,
+                    ir.predicate,
+                )
+                for i in range(0, len(paths), plan.factor)
+            ]
+            new_node = Union(ir.schema, None, *groups)
+            partition_info = {group: PartitionInfo(count=1) for group in groups} | {
+                new_node: PartitionInfo(count=len(groups))
+            }
+        return new_node, partition_info
+
+    return ir, {ir: PartitionInfo(count=1)}  # pragma: no cover
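`ScanPartitionPlan.from_scan` sizes parquet scans against a `parquet_blocksize` budget (1 GiB by default): files larger than the budget are split into `ceil(size / blocksize)` pieces, while smaller files are fused `blocksize // size` at a time. A back-of-the-envelope check of that arithmetic, with illustrative file sizes:

import math

blocksize = 1024**3                         # default parquet_blocksize (1 GiB)

file_size = 2.5 * 1024**3                   # ~2.5 GiB of sampled column data
print(math.ceil(file_size / blocksize))     # 3 -> SPLIT_FILES: 3 partitions per file

file_size = 100 * 1024**2                   # ~100 MiB per file
print(max(blocksize // int(file_size), 1))  # 10 -> FUSED_FILES: 10 files per partition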