cudf_polars_cu13-25.10.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. cudf_polars/GIT_COMMIT +1 -0
  2. cudf_polars/VERSION +1 -0
  3. cudf_polars/__init__.py +28 -0
  4. cudf_polars/_version.py +21 -0
  5. cudf_polars/callback.py +318 -0
  6. cudf_polars/containers/__init__.py +13 -0
  7. cudf_polars/containers/column.py +495 -0
  8. cudf_polars/containers/dataframe.py +361 -0
  9. cudf_polars/containers/datatype.py +137 -0
  10. cudf_polars/dsl/__init__.py +8 -0
  11. cudf_polars/dsl/expr.py +66 -0
  12. cudf_polars/dsl/expressions/__init__.py +8 -0
  13. cudf_polars/dsl/expressions/aggregation.py +226 -0
  14. cudf_polars/dsl/expressions/base.py +272 -0
  15. cudf_polars/dsl/expressions/binaryop.py +120 -0
  16. cudf_polars/dsl/expressions/boolean.py +326 -0
  17. cudf_polars/dsl/expressions/datetime.py +271 -0
  18. cudf_polars/dsl/expressions/literal.py +97 -0
  19. cudf_polars/dsl/expressions/rolling.py +643 -0
  20. cudf_polars/dsl/expressions/selection.py +74 -0
  21. cudf_polars/dsl/expressions/slicing.py +46 -0
  22. cudf_polars/dsl/expressions/sorting.py +85 -0
  23. cudf_polars/dsl/expressions/string.py +1002 -0
  24. cudf_polars/dsl/expressions/struct.py +137 -0
  25. cudf_polars/dsl/expressions/ternary.py +49 -0
  26. cudf_polars/dsl/expressions/unary.py +517 -0
  27. cudf_polars/dsl/ir.py +2607 -0
  28. cudf_polars/dsl/nodebase.py +164 -0
  29. cudf_polars/dsl/to_ast.py +359 -0
  30. cudf_polars/dsl/tracing.py +16 -0
  31. cudf_polars/dsl/translate.py +939 -0
  32. cudf_polars/dsl/traversal.py +224 -0
  33. cudf_polars/dsl/utils/__init__.py +8 -0
  34. cudf_polars/dsl/utils/aggregations.py +481 -0
  35. cudf_polars/dsl/utils/groupby.py +98 -0
  36. cudf_polars/dsl/utils/naming.py +34 -0
  37. cudf_polars/dsl/utils/replace.py +61 -0
  38. cudf_polars/dsl/utils/reshape.py +74 -0
  39. cudf_polars/dsl/utils/rolling.py +121 -0
  40. cudf_polars/dsl/utils/windows.py +192 -0
  41. cudf_polars/experimental/__init__.py +8 -0
  42. cudf_polars/experimental/base.py +386 -0
  43. cudf_polars/experimental/benchmarks/__init__.py +4 -0
  44. cudf_polars/experimental/benchmarks/pdsds.py +220 -0
  45. cudf_polars/experimental/benchmarks/pdsds_queries/__init__.py +4 -0
  46. cudf_polars/experimental/benchmarks/pdsds_queries/q1.py +88 -0
  47. cudf_polars/experimental/benchmarks/pdsds_queries/q10.py +225 -0
  48. cudf_polars/experimental/benchmarks/pdsds_queries/q2.py +244 -0
  49. cudf_polars/experimental/benchmarks/pdsds_queries/q3.py +65 -0
  50. cudf_polars/experimental/benchmarks/pdsds_queries/q4.py +359 -0
  51. cudf_polars/experimental/benchmarks/pdsds_queries/q5.py +462 -0
  52. cudf_polars/experimental/benchmarks/pdsds_queries/q6.py +92 -0
  53. cudf_polars/experimental/benchmarks/pdsds_queries/q7.py +79 -0
  54. cudf_polars/experimental/benchmarks/pdsds_queries/q8.py +524 -0
  55. cudf_polars/experimental/benchmarks/pdsds_queries/q9.py +137 -0
  56. cudf_polars/experimental/benchmarks/pdsh.py +814 -0
  57. cudf_polars/experimental/benchmarks/utils.py +832 -0
  58. cudf_polars/experimental/dask_registers.py +200 -0
  59. cudf_polars/experimental/dispatch.py +156 -0
  60. cudf_polars/experimental/distinct.py +197 -0
  61. cudf_polars/experimental/explain.py +157 -0
  62. cudf_polars/experimental/expressions.py +590 -0
  63. cudf_polars/experimental/groupby.py +327 -0
  64. cudf_polars/experimental/io.py +943 -0
  65. cudf_polars/experimental/join.py +391 -0
  66. cudf_polars/experimental/parallel.py +423 -0
  67. cudf_polars/experimental/repartition.py +69 -0
  68. cudf_polars/experimental/scheduler.py +155 -0
  69. cudf_polars/experimental/select.py +188 -0
  70. cudf_polars/experimental/shuffle.py +354 -0
  71. cudf_polars/experimental/sort.py +609 -0
  72. cudf_polars/experimental/spilling.py +151 -0
  73. cudf_polars/experimental/statistics.py +795 -0
  74. cudf_polars/experimental/utils.py +169 -0
  75. cudf_polars/py.typed +0 -0
  76. cudf_polars/testing/__init__.py +8 -0
  77. cudf_polars/testing/asserts.py +448 -0
  78. cudf_polars/testing/io.py +122 -0
  79. cudf_polars/testing/plugin.py +236 -0
  80. cudf_polars/typing/__init__.py +219 -0
  81. cudf_polars/utils/__init__.py +8 -0
  82. cudf_polars/utils/config.py +741 -0
  83. cudf_polars/utils/conversion.py +40 -0
  84. cudf_polars/utils/dtypes.py +118 -0
  85. cudf_polars/utils/sorting.py +53 -0
  86. cudf_polars/utils/timer.py +39 -0
  87. cudf_polars/utils/versions.py +27 -0
  88. cudf_polars_cu13-25.10.0.dist-info/METADATA +136 -0
  89. cudf_polars_cu13-25.10.0.dist-info/RECORD +92 -0
  90. cudf_polars_cu13-25.10.0.dist-info/WHEEL +5 -0
  91. cudf_polars_cu13-25.10.0.dist-info/licenses/LICENSE +201 -0
  92. cudf_polars_cu13-25.10.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,423 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Multi-partition evaluation."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import itertools
8
+ import operator
9
+ from functools import partial, reduce
10
+ from typing import TYPE_CHECKING, Any
11
+
12
+ import cudf_polars.experimental.distinct
13
+ import cudf_polars.experimental.groupby
14
+ import cudf_polars.experimental.io
15
+ import cudf_polars.experimental.join
16
+ import cudf_polars.experimental.select
17
+ import cudf_polars.experimental.shuffle
18
+ import cudf_polars.experimental.sort # noqa: F401
19
+ from cudf_polars.dsl.ir import (
20
+ IR,
21
+ Cache,
22
+ Filter,
23
+ HConcat,
24
+ HStack,
25
+ MapFunction,
26
+ Projection,
27
+ Slice,
28
+ Union,
29
+ )
30
+ from cudf_polars.dsl.traversal import CachingVisitor, traversal
31
+ from cudf_polars.experimental.base import PartitionInfo, get_key_name
32
+ from cudf_polars.experimental.dispatch import (
33
+ generate_ir_tasks,
34
+ lower_ir_node,
35
+ )
36
+ from cudf_polars.experimental.io import _clear_source_info_cache
37
+ from cudf_polars.experimental.repartition import Repartition
38
+ from cudf_polars.experimental.statistics import collect_statistics
39
+ from cudf_polars.experimental.utils import _concat, _contains_over, _lower_ir_fallback
40
+
41
+ if TYPE_CHECKING:
42
+ from collections.abc import MutableMapping
43
+ from typing import Any
44
+
45
+ from cudf_polars.containers import DataFrame
46
+ from cudf_polars.experimental.dispatch import LowerIRTransformer, State
47
+ from cudf_polars.utils.config import ConfigOptions
48
+
49
+
50
+ @lower_ir_node.register(IR)
51
+ def _(
52
+ ir: IR, rec: LowerIRTransformer
53
+ ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]: # pragma: no cover
54
+ # Default logic - Requires single partition
55
+ return _lower_ir_fallback(
56
+ ir, rec, msg=f"Class {type(ir)} does not support multiple partitions."
57
+ )
58
+
59
+
60
+ def lower_ir_graph(
61
+ ir: IR, config_options: ConfigOptions
62
+ ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
63
+ """
64
+ Rewrite an IR graph and extract partitioning information.
65
+
66
+ Parameters
67
+ ----------
68
+ ir
69
+ Root of the graph to rewrite.
70
+ config_options
71
+ GPUEngine configuration options.
72
+
73
+ Returns
74
+ -------
75
+ new_ir, partition_info
76
+ The rewritten graph, and a mapping from unique nodes
77
+ in the new graph to associated partitioning information.
78
+
79
+ Notes
80
+ -----
81
+ This function traverses the unique nodes of the graph with
82
+ root `ir`, and applies :func:`lower_ir_node` to each node.
83
+
84
+ See Also
85
+ --------
86
+ lower_ir_node
87
+ """
88
+ state: State = {
89
+ "config_options": config_options,
90
+ "stats": collect_statistics(ir, config_options),
91
+ }
92
+ mapper: LowerIRTransformer = CachingVisitor(lower_ir_node, state=state)
93
+ return mapper(ir)
94
+
95
+
96
+ def task_graph(
97
+ ir: IR,
98
+ partition_info: MutableMapping[IR, PartitionInfo],
99
+ config_options: ConfigOptions,
100
+ ) -> tuple[MutableMapping[Any, Any], str | tuple[str, int]]:
101
+ """
102
+ Construct a task graph for evaluation of an IR graph.
103
+
104
+ Parameters
105
+ ----------
106
+ ir
107
+ Root of the graph to rewrite.
108
+ partition_info
109
+ A mapping from all unique IR nodes to the
110
+ associated partitioning information.
111
+ config_options
112
+ GPUEngine configuration options.
113
+
114
+ Returns
115
+ -------
116
+ graph, key
117
+ A Dask-compatible task graph for the entire
118
+ IR graph with root `ir`, and the output key for the final result.
119
+
120
+ Notes
121
+ -----
122
+ This function traverses the unique nodes of the
123
+ graph with root `ir`, and extracts the tasks for
124
+ each node with :func:`generate_ir_tasks`.
125
+
126
+ The output is passed into :func:`post_process_task_graph` to
127
+ add any additional processing that is specific to the executor.
128
+
129
+ See Also
130
+ --------
131
+ generate_ir_tasks
132
+ """
133
+ graph = reduce(
134
+ operator.or_,
135
+ (generate_ir_tasks(node, partition_info) for node in traversal([ir])),
136
+ )
137
+
138
+ key_name = get_key_name(ir)
139
+ partition_count = partition_info[ir].count
140
+
141
+ key: str | tuple[str, int]
142
+ if partition_count > 1:
143
+ graph[key_name] = (_concat, *partition_info[ir].keys(ir))
144
+ key = key_name
145
+ else:
146
+ key = (key_name, 0)
147
+
148
+ graph = post_process_task_graph(graph, key, config_options)
149
+ return graph, key
150
+
151
+
152
+ # The true type signature for get_scheduler() needs an overload. Not worth it.
153
+
154
+
155
+ def get_scheduler(config_options: ConfigOptions) -> Any:
156
+ """Get appropriate task scheduler."""
157
+ assert config_options.executor.name == "streaming", (
158
+ "'in-memory' executor not supported in 'get_scheduler'"
159
+ )
160
+
161
+ scheduler = config_options.executor.scheduler
162
+
163
+ if (
164
+ scheduler == "distributed"
165
+ ): # pragma: no cover; block depends on executor type and Distributed cluster
166
+ from distributed import get_client
167
+
168
+ from cudf_polars.experimental.dask_registers import DaskRegisterManager
169
+
170
+ client = get_client()
171
+ DaskRegisterManager.register_once()
172
+ DaskRegisterManager.run_on_cluster(client)
173
+ return client.get
174
+ elif scheduler == "synchronous":
175
+ from cudf_polars.experimental.scheduler import synchronous_scheduler
176
+
177
+ return synchronous_scheduler
178
+ else: # pragma: no cover
179
+ raise ValueError(f"{scheduler} not a supported scheduler option.")
180
+
181
+
182
+ def post_process_task_graph(
183
+ graph: MutableMapping[Any, Any],
184
+ key: str | tuple[str, int],
185
+ config_options: ConfigOptions,
186
+ ) -> MutableMapping[Any, Any]:
187
+ """
188
+ Post-process the task graph.
189
+
190
+ Parameters
191
+ ----------
192
+ graph
193
+ Task graph to post-process.
194
+ key
195
+ Output key for the graph.
196
+ config_options
197
+ GPUEngine configuration options.
198
+
199
+ Returns
200
+ -------
201
+ graph
202
+ A Dask-compatible task graph.
203
+ """
204
+ assert config_options.executor.name == "streaming", (
205
+ "'in-memory' executor not supported in 'post_process_task_graph'"
206
+ )
207
+
208
+ if config_options.executor.rapidsmpf_spill: # pragma: no cover
209
+ from cudf_polars.experimental.spilling import wrap_dataframe_in_spillable
210
+
211
+ return wrap_dataframe_in_spillable(
212
+ graph, ignore_key=key, config_options=config_options
213
+ )
214
+ return graph
215
+
216
+
217
+ def evaluate_streaming(
218
+ ir: IR,
219
+ config_options: ConfigOptions,
220
+ ) -> DataFrame:
221
+ """
222
+ Evaluate an IR graph with partitioning.
223
+
224
+ Parameters
225
+ ----------
226
+ ir
227
+ Logical plan to evaluate.
228
+ config_options
229
+ GPUEngine configuration options.
230
+
231
+ Returns
232
+ -------
233
+ A cudf-polars DataFrame object.
234
+ """
235
+ # Clear source info cache in case data was overwritten
236
+ _clear_source_info_cache()
237
+
238
+ ir, partition_info = lower_ir_graph(ir, config_options)
239
+
240
+ graph, key = task_graph(ir, partition_info, config_options)
241
+
242
+ return get_scheduler(config_options)(graph, key)
243
+
244
+
245
+ @generate_ir_tasks.register(IR)
246
+ def _(
247
+ ir: IR, partition_info: MutableMapping[IR, PartitionInfo]
248
+ ) -> MutableMapping[Any, Any]:
249
+ # Generate pointwise (embarrassingly-parallel) tasks by default
250
+ child_names = [get_key_name(c) for c in ir.children]
251
+ bcast_child = [partition_info[c].count == 1 for c in ir.children]
252
+
253
+ return {
254
+ key: (
255
+ ir.do_evaluate,
256
+ *ir._non_child_args,
257
+ *[
258
+ (child_name, 0 if bcast_child[j] else i)
259
+ for j, child_name in enumerate(child_names)
260
+ ],
261
+ )
262
+ for i, key in enumerate(partition_info[ir].keys(ir))
263
+ }
264
+
265
+
266
+ @lower_ir_node.register(Union)
267
+ def _(
268
+ ir: Union, rec: LowerIRTransformer
269
+ ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
270
+ # Check zlice
271
+ if ir.zlice is not None:
272
+ return rec(
273
+ Slice(
274
+ ir.schema,
275
+ *ir.zlice,
276
+ Union(ir.schema, None, *ir.children),
277
+ )
278
+ )
279
+
280
+ # Lower children
281
+ children, _partition_info = zip(*(rec(c) for c in ir.children), strict=True)
282
+ partition_info = reduce(operator.or_, _partition_info)
283
+
284
+ # Partition count is the sum of all child partitions
285
+ count = sum(partition_info[c].count for c in children)
286
+
287
+ # Return reconstructed node and partition-info dict
288
+ new_node = ir.reconstruct(children)
289
+ partition_info[new_node] = PartitionInfo(count=count)
290
+ return new_node, partition_info
291
+
292
+
293
+ @generate_ir_tasks.register(Union)
294
+ def _(
295
+ ir: Union, partition_info: MutableMapping[IR, PartitionInfo]
296
+ ) -> MutableMapping[Any, Any]:
297
+ key_name = get_key_name(ir)
298
+ partition = itertools.count()
299
+ return {
300
+ (key_name, next(partition)): child_key
301
+ for child in ir.children
302
+ for child_key in partition_info[child].keys(child)
303
+ }
304
+
305
+
306
+ @lower_ir_node.register(MapFunction)
307
+ def _(
308
+ ir: MapFunction, rec: LowerIRTransformer
309
+ ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
310
+ # Allow pointwise operations
311
+ if ir.name in ("rename", "explode"):
312
+ return _lower_ir_pwise(ir, rec)
313
+
314
+ # Fallback for everything else
315
+ return _lower_ir_fallback(
316
+ ir, rec, msg=f"{ir.name} is not supported for multiple partitions."
317
+ )
318
+
319
+
320
+ def _lower_ir_pwise(
321
+ ir: IR, rec: LowerIRTransformer, *, preserve_partitioning: bool = False
322
+ ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
323
+ # Lower a partition-wise (i.e. embarrassingly-parallel) IR node
324
+
325
+ # Lower children
326
+ children, _partition_info = zip(*(rec(c) for c in ir.children), strict=True)
327
+ partition_info = reduce(operator.or_, _partition_info)
328
+ counts = {partition_info[c].count for c in children}
329
+
330
+ # Check that child partitioning is supported
331
+ if len(counts) > 1: # pragma: no cover
332
+ return _lower_ir_fallback(
333
+ ir,
334
+ rec,
335
+ msg=f"Class {type(ir)} does not support children with mismatched partition counts.",
336
+ )
337
+
338
+ # Preserve child partition_info if possible
339
+ if preserve_partitioning and len(children) == 1:
340
+ partition = partition_info[children[0]]
341
+ else:
342
+ partition = PartitionInfo(count=max(counts))
343
+
344
+ # Return reconstructed node and partition-info dict
345
+ new_node = ir.reconstruct(children)
346
+ partition_info[new_node] = partition
347
+ return new_node, partition_info
348
+
349
+
350
+ _lower_ir_pwise_preserve = partial(_lower_ir_pwise, preserve_partitioning=True)
351
+ lower_ir_node.register(Projection, _lower_ir_pwise_preserve)
352
+ lower_ir_node.register(Cache, _lower_ir_pwise)
353
+ lower_ir_node.register(HConcat, _lower_ir_pwise)
354
+
355
+
356
+ @lower_ir_node.register(Filter)
357
+ def _(
358
+ ir: Filter, rec: LowerIRTransformer
359
+ ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
360
+ child, partition_info = rec(ir.children[0])
361
+
362
+ if partition_info[child].count > 1 and _contains_over([ir.mask.value]):
363
+ # mask contains .over(...), collapse to single partition
364
+ return _lower_ir_fallback(
365
+ ir.reconstruct([child]),
366
+ rec,
367
+ msg=(
368
+ "over(...) inside filter is not supported for multiple partitions; "
369
+ "falling back to in-memory evaluation."
370
+ ),
371
+ )
372
+
373
+ if partition_info[child].count > 1 and not all(
374
+ expr.is_pointwise for expr in traversal([ir.mask.value])
375
+ ):
376
+ # TODO: Use expression decomposition to lower Filter
377
+ # See: https://github.com/rapidsai/cudf/issues/20076
378
+ return _lower_ir_fallback(
379
+ ir, rec, msg="This filter is not supported for multiple partitions."
380
+ )
381
+
382
+ new_node = ir.reconstruct([child])
383
+ partition_info[new_node] = partition_info[child]
384
+ return new_node, partition_info
385
+
386
+
387
+ @lower_ir_node.register(Slice)
388
+ def _(
389
+ ir: Slice, rec: LowerIRTransformer
390
+ ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
391
+ if ir.offset == 0:
392
+ # Taking the first N rows.
393
+ # We don't know how large each partition is, so we reduce.
394
+ new_node, partition_info = _lower_ir_pwise(ir, rec)
395
+ if partition_info[new_node].count > 1:
396
+ # Collapse down to single partition
397
+ inter = Repartition(new_node.schema, new_node)
398
+ partition_info[inter] = PartitionInfo(count=1)
399
+ # Slice reduced partition
400
+ new_node = ir.reconstruct([inter])
401
+ partition_info[new_node] = PartitionInfo(count=1)
402
+ return new_node, partition_info
403
+
404
+ # Fallback
405
+ return _lower_ir_fallback(
406
+ ir, rec, msg="This slice is not supported for multiple partitions."
407
+ )
408
+
409
+
410
+ @lower_ir_node.register(HStack)
411
+ def _(
412
+ ir: HStack, rec: LowerIRTransformer
413
+ ) -> tuple[IR, MutableMapping[IR, PartitionInfo]]:
414
+ if not all(expr.is_pointwise for expr in traversal([e.value for e in ir.columns])):
415
+ # TODO: Avoid fallback if/when possible
416
+ return _lower_ir_fallback(
417
+ ir, rec, msg="This HStack is not supported for multiple partitions."
418
+ )
419
+
420
+ child, partition_info = rec(ir.children[0])
421
+ new_node = ir.reconstruct([child])
422
+ partition_info[new_node] = partition_info[child]
423
+ return new_node, partition_info
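
The file above, cudf_polars/experimental/parallel.py, is the entry point for multi-partition ("streaming") execution: evaluate_streaming lowers the IR graph, builds a Dask-compatible task graph, and runs it on the scheduler returned by get_scheduler. A minimal usage sketch, assuming the executor and executor_options keywords that cudf-polars documents for polars.GPUEngine (the data path is hypothetical):

    import polars as pl

    q = (
        pl.scan_parquet("data/*.parquet")  # hypothetical input
        .group_by("key")
        .agg(pl.col("x").sum())
    )
    # Routing the query through the GPU engine's streaming executor causes the
    # cudf-polars callback to translate the Polars plan to IR and (roughly) call
    # evaluate_streaming, which dispatches to the synchronous or distributed
    # scheduler selected via executor_options.
    result = q.collect(
        engine=pl.GPUEngine(
            executor="streaming",
            executor_options={"scheduler": "synchronous"},
        )
    )
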
@@ -0,0 +1,69 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Repartitioning Logic."""
4
+
5
+ from __future__ import annotations
6
+
7
+ import itertools
8
+ from typing import TYPE_CHECKING, Any
9
+
10
+ from cudf_polars.dsl.ir import IR
11
+ from cudf_polars.experimental.base import get_key_name
12
+ from cudf_polars.experimental.dispatch import generate_ir_tasks
13
+ from cudf_polars.experimental.utils import _concat
14
+
15
+ if TYPE_CHECKING:
16
+ from collections.abc import MutableMapping
17
+
18
+ from cudf_polars.experimental.parallel import PartitionInfo
19
+ from cudf_polars.typing import Schema
20
+
21
+
22
+ class Repartition(IR):
23
+ """
24
+ Repartition a DataFrame.
25
+
26
+ Notes
27
+ -----
28
+ Repartitioning means that we are not modifying any
29
+ data, nor are we reordering or shuffling rows. We
30
+ are only changing the overall partition count. For
31
+ now, we only support an N -> [1...N] repartitioning
32
+ (inclusive). The output partition count is tracked
33
+ separately using PartitionInfo.
34
+ """
35
+
36
+ __slots__ = ()
37
+ _non_child = ("schema",)
38
+
39
+ def __init__(self, schema: Schema, df: IR):
40
+ self.schema = schema
41
+ self._non_child_args = ()
42
+ self.children = (df,)
43
+
44
+
45
+ @generate_ir_tasks.register(Repartition)
46
+ def _(
47
+ ir: Repartition, partition_info: MutableMapping[IR, PartitionInfo]
48
+ ) -> MutableMapping[Any, Any]:
49
+ # Repartition an IR node.
50
+ # Only supports repartitioning to fewer partitions (for now).
51
+
52
+ (child,) = ir.children
53
+ count_in = partition_info[child].count
54
+ count_out = partition_info[ir].count
55
+
56
+ if count_out > count_in: # pragma: no cover
57
+ raise NotImplementedError(
58
+ f"Repartition {count_in} -> {count_out} not supported."
59
+ )
60
+
61
+ key_name = get_key_name(ir)
62
+ n, remainder = divmod(count_in, count_out)
63
+ # Spread remainder evenly over the partitions.
64
+ offsets = [0, *itertools.accumulate(n + (i < remainder) for i in range(count_out))]
65
+ child_keys = tuple(partition_info[child].keys(child))
66
+ return {
67
+ (key_name, i): (_concat, *child_keys[offsets[i] : offsets[i + 1]])
68
+ for i in range(count_out)
69
+ }
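
The remainder-spreading arithmetic in the Repartition task generator is easiest to see with concrete numbers (illustrative values, not taken from the package):

    import itertools

    count_in, count_out = 10, 3
    n, remainder = divmod(count_in, count_out)  # n = 3, remainder = 1
    offsets = [0, *itertools.accumulate(n + (i < remainder) for i in range(count_out))]
    # offsets == [0, 4, 7, 10]: output partition 0 concatenates input partitions
    # 0-3, output partition 1 gets 4-6, and output partition 2 gets 7-9.
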
@@ -0,0 +1,155 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ """Synchronous task scheduler."""
4
+
5
+ from __future__ import annotations
6
+
7
+ from collections import Counter
8
+ from collections.abc import MutableMapping
9
+ from itertools import chain
10
+ from typing import TYPE_CHECKING, Any, TypeVar
11
+
12
+ from typing_extensions import Unpack
13
+
14
+ if TYPE_CHECKING:
15
+ from collections.abc import Mapping
16
+ from typing import TypeAlias
17
+
18
+
19
+ Key: TypeAlias = str | tuple[str, Unpack[tuple[int, ...]]]
20
+ Graph: TypeAlias = MutableMapping[Key, Any]
21
+ T_ = TypeVar("T_")
22
+
23
+
24
+ # NOTE: This is a slimmed-down version of the single-threaded
25
+ # (synchronous) scheduler in `dask.core`.
26
+ #
27
+ # Key Differences:
28
+ # * We do not allow a task to contain a list of key names.
29
+ # Keys must be distinct elements of the task.
30
+ # * We do not support nested tasks.
31
+
32
+
33
+ def istask(x: Any) -> bool:
34
+ """Check if x is a callable task."""
35
+ return isinstance(x, tuple) and bool(x) and callable(x[0])
36
+
37
+
38
+ def is_hashable(x: Any) -> bool:
39
+ """Check if x is hashable."""
40
+ try:
41
+ hash(x)
42
+ except BaseException:
43
+ return False
44
+ else:
45
+ return True
46
+
47
+
48
+ def _execute_task(arg: Any, cache: Mapping) -> Any:
49
+ """Execute a compute task."""
50
+ if istask(arg):
51
+ return arg[0](*(_execute_task(a, cache) for a in arg[1:]))
52
+ elif is_hashable(arg):
53
+ return cache.get(arg, arg)
54
+ else:
55
+ return arg
56
+
57
+
58
+ def required_keys(key: Key, graph: Graph) -> list[Key]:
59
+ """
60
+ Return the dependencies to extract a key from the graph.
61
+
62
+ Parameters
63
+ ----------
64
+ key
65
+ Root key we want to extract.
66
+ graph
67
+ The full task graph.
68
+
69
+ Returns
70
+ -------
71
+ List of other keys needed to extract ``key``.
72
+ """
73
+ maybe_task = graph[key]
74
+ return [
75
+ k
76
+ for k in (
77
+ maybe_task[1:]
78
+ if istask(maybe_task)
79
+ else [maybe_task] # maybe_task might be a key
80
+ )
81
+ if is_hashable(k) and k in graph
82
+ ]
83
+
84
+
85
+ def toposort(graph: Graph, dependencies: Mapping[Key, list[Key]]) -> list[Key]:
86
+ """Return a list of task keys sorted in topological order."""
87
+ # Stack-based depth-first search traversal. This is based on Tarjan's
88
+ # algorithm for strongly-connected components
89
+ # (https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm)
90
+ ordered: list[Key] = []
91
+ completed: set[Key] = set()
92
+
93
+ for key in graph:
94
+ if key in completed:
95
+ continue
96
+ nodes = [key]
97
+ while nodes:
98
+ # Keep current node on the stack until all descendants are visited
99
+ current = nodes[-1]
100
+ if current in completed: # pragma: no cover
101
+ # Already fully traversed descendants of current
102
+ nodes.pop()
103
+ continue
104
+
105
+ # Add direct descendants of current to nodes stack
106
+ next_nodes = set(dependencies[current]) - completed
107
+ if next_nodes:
108
+ nodes.extend(next_nodes)
109
+ else:
110
+ # Current has no more descendants to explore
111
+ ordered.append(current)
112
+ completed.add(current)
113
+ nodes.pop()
114
+
115
+ return ordered
116
+
117
+
118
+ def synchronous_scheduler(
119
+ graph: Graph,
120
+ key: Key,
121
+ *,
122
+ cache: MutableMapping | None = None,
123
+ ) -> Any:
124
+ """
125
+ Execute the task graph for a given key.
126
+
127
+ Parameters
128
+ ----------
129
+ graph
130
+ The task graph to execute.
131
+ key
132
+ The final output key to extract from the graph.
133
+ cache
134
+ Intermediate-data cache.
135
+
136
+ Returns
137
+ -------
138
+ Executed task-graph result for ``key``.
139
+ """
140
+ if key not in graph: # pragma: no cover
141
+ raise KeyError(f"{key} is not a key in the graph")
142
+ if cache is None:
143
+ cache = {}
144
+
145
+ dependencies = {k: required_keys(k, graph) for k in graph}
146
+ refcount = Counter(chain.from_iterable(dependencies.values()))
147
+
148
+ for k in toposort(graph, dependencies):
149
+ cache[k] = _execute_task(graph[k], cache)
150
+ for dep in dependencies[k]:
151
+ refcount[dep] -= 1
152
+ if refcount[dep] == 0 and dep != key:
153
+ del cache[dep]
154
+
155
+ return cache[key]
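
The scheduler above consumes the same plain-dict task-graph convention that generate_ir_tasks produces: a key maps either to a literal value or to a (callable, *args) tuple, any argument that is itself a key in the graph is replaced by its computed value before the callable runs, and intermediates are dropped from the cache once their reference count reaches zero. A small self-contained example (toy graph, not produced by the IR lowering):

    import operator

    from cudf_polars.experimental.scheduler import synchronous_scheduler

    graph = {
        ("x", 0): 1,
        ("x", 1): 2,
        ("add", 0): (operator.add, ("x", 0), ("x", 1)),
        "total": (operator.add, ("add", 0), 10),
    }
    # toposort orders the keys, _execute_task substitutes ("x", 0), ("x", 1), and
    # ("add", 0) with their computed values, and non-key arguments (10) pass through.
    assert synchronous_scheduler(graph, "total") == 13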