hoptorch 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
hoptorch-0.1.0/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Vincent Moens
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,109 @@
1
+ Metadata-Version: 2.4
2
+ Name: hoptorch
3
+ Version: 0.1.0
4
+ Summary: Compatibility helpers for PyTorch higher-order operators.
5
+ Author: Vincent Moens
6
+ License-Expression: MIT
7
+ Project-URL: Homepage, https://github.com/vmoens/hoptorch
8
+ Project-URL: Repository, https://github.com/vmoens/hoptorch
9
+ Project-URL: Issues, https://github.com/vmoens/hoptorch/issues
10
+ Keywords: pytorch,torch,scan,higher-order-operators,autograd
11
+ Classifier: Development Status :: 3 - Alpha
12
+ Classifier: Intended Audience :: Developers
13
+ Classifier: Intended Audience :: Science/Research
14
+ Classifier: Programming Language :: Python :: 3
15
+ Classifier: Programming Language :: Python :: 3 :: Only
16
+ Classifier: Programming Language :: Python :: 3.10
17
+ Classifier: Programming Language :: Python :: 3.11
18
+ Classifier: Programming Language :: Python :: 3.12
19
+ Classifier: Programming Language :: Python :: 3.13
20
+ Classifier: Programming Language :: Python :: 3.14
21
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
22
+ Classifier: Topic :: Software Development :: Libraries :: Python Modules
23
+ Requires-Python: >=3.10
24
+ Description-Content-Type: text/markdown
25
+ License-File: LICENSE
26
+ Requires-Dist: torch>=2.7
27
+ Requires-Dist: pyvers>=0.2.2
28
+ Dynamic: license-file
29
+
30
+ # hoptorch
31
+
32
+ `hoptorch` is a small compatibility package for PyTorch higher-order operators.
33
+ Its first helper is a safe wrapper around `torch._higher_order_ops.scan` that
34
+ checks whether scan backward is healthy on the requested device and lazily
35
+ installs version-specific compatibility patches for affected PyTorch internals.
36
+ For PyTorch 2.7, where scan backward is not implemented, `hoptorch` backports
37
+ a small eager scan implementation with ordinary autograd support.
38
+
39
+ ## Install
40
+
41
+ ```bash
42
+ pip install hoptorch
43
+ ```
44
+
45
+ Runtime dependencies:
46
+
47
+ ```text
48
+ torch>=2.7
49
+ pyvers>=0.2.2
50
+ ```
51
+
52
+ ## Usage
53
+
54
+ ```python
55
+ import torch
56
+ from hoptorch import scan
57
+ from hoptorch.scan import ensure_scan_backward, scan_unavailable_reason
58
+
59
+ if ensure_scan_backward("cpu"):
60
+ xs = torch.arange(4.0)
61
+
62
+ def step(carry, x):
63
+ next_carry = carry + x
64
+ return next_carry, next_carry.clone()
65
+
66
+ carry, ys = scan(step, torch.zeros(()), xs)
67
+ else:
68
+ print(scan_unavailable_reason("cpu"))
69
+ ```
70
+
71
+ For compiled code, warm the health check before entering `torch.compile`:
72
+
73
+ ```python
74
+ from hoptorch.scan import ensure_scan_backward
75
+
76
+ ensure_scan_backward("cpu")
77
+ compiled_fn = torch.compile(fn)
78
+ ```
79
+
80
+ If a wrapper call happens during Dynamo tracing before the health result is
81
+ cached, `hoptorch` fails closed instead of tracing the probe.
82
+
83
+ ## Public API
84
+
85
+ ```python
86
+ from hoptorch import scan
87
+ from hoptorch.scan import (
88
+ ensure_scan_backward,
89
+ has_scan,
90
+ patch_scan_backward,
91
+ scan_unavailable_reason,
92
+ )
93
+ ```
94
+
95
+ - `scan(fn, init, xs, *, dim=0, **kwargs)`: calls PyTorch scan only when scan
96
+ backward is known to be healthy for the inferred device.
97
+ - `has_scan()`: reports whether `torch._higher_order_ops.scan.scan` exists.
98
+ - `ensure_scan_backward(device=None)`: runs or reads the cached health check,
99
+ patching lazily if needed, and returns `True` only after a passing probe.
100
+ - `scan_unavailable_reason(device=None)`: returns `None` when scan backward is
101
+ usable, otherwise a stable human-readable reason.
102
+ - `patch_scan_backward()`: attempts to install the compatibility patch and is
103
+ intended mainly for diagnostics.
104
+
105
+ ## Tests
106
+
107
+ ```bash
108
+ python -m unittest discover -s tests
109
+ ```
@@ -0,0 +1,80 @@
1
+ # hoptorch
2
+
3
+ `hoptorch` is a small compatibility package for PyTorch higher-order operators.
4
+ Its first helper is a safe wrapper around `torch._higher_order_ops.scan` that
5
+ checks whether scan backward is healthy on the requested device and lazily
6
+ installs version-specific compatibility patches for affected PyTorch internals.
7
+ For PyTorch 2.7, where scan backward is not implemented, `hoptorch` backports
8
+ a small eager scan implementation with ordinary autograd support.
9
+
10
+ ## Install
11
+
12
+ ```bash
13
+ pip install hoptorch
14
+ ```
15
+
16
+ Runtime dependencies:
17
+
18
+ ```text
19
+ torch>=2.7
20
+ pyvers>=0.2.2
21
+ ```
22
+
23
+ ## Usage
24
+
25
+ ```python
26
+ import torch
27
+ from hoptorch import scan
28
+ from hoptorch.scan import ensure_scan_backward, scan_unavailable_reason
29
+
30
+ if ensure_scan_backward("cpu"):
31
+ xs = torch.arange(4.0)
32
+
33
+ def step(carry, x):
34
+ next_carry = carry + x
35
+ return next_carry, next_carry.clone()
36
+
37
+ carry, ys = scan(step, torch.zeros(()), xs)
38
+ else:
39
+ print(scan_unavailable_reason("cpu"))
40
+ ```
41
+
42
+ For compiled code, warm the health check before entering `torch.compile`:
43
+
44
+ ```python
45
+ from hoptorch.scan import ensure_scan_backward
46
+
47
+ ensure_scan_backward("cpu")
48
+ compiled_fn = torch.compile(fn)
49
+ ```
50
+
51
+ If a wrapper call happens during Dynamo tracing before the health result is
52
+ cached, `hoptorch` fails closed instead of tracing the probe.
53
+
54
+ ## Public API
55
+
56
+ ```python
57
+ from hoptorch import scan
58
+ from hoptorch.scan import (
59
+ ensure_scan_backward,
60
+ has_scan,
61
+ patch_scan_backward,
62
+ scan_unavailable_reason,
63
+ )
64
+ ```
65
+
66
+ - `scan(fn, init, xs, *, dim=0, **kwargs)`: calls PyTorch scan only when scan
67
+ backward is known to be healthy for the inferred device.
68
+ - `has_scan()`: reports whether `torch._higher_order_ops.scan.scan` exists.
69
+ - `ensure_scan_backward(device=None)`: runs or reads the cached health check,
70
+ patching lazily if needed, and returns `True` only after a passing probe.
71
+ - `scan_unavailable_reason(device=None)`: returns `None` when scan backward is
72
+ usable, otherwise a stable human-readable reason.
73
+ - `patch_scan_backward()`: attempts to install the compatibility patch and is
74
+ intended mainly for diagnostics.
75
+
76
+ ## Tests
77
+
78
+ ```bash
79
+ python -m unittest discover -s tests
80
+ ```
@@ -0,0 +1,39 @@
1
[build-system]
# setuptools 77 is the first release that accepts the PEP 639 SPDX string
# form used by ``project.license`` (``license = "MIT"``) in this file; older
# releases reject it during metadata validation.
requires = ["setuptools>=77", "wheel"]
build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "hoptorch"
7
+ version = "0.1.0"
8
+ description = "Compatibility helpers for PyTorch higher-order operators."
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+ license = "MIT"
12
+ authors = [{ name = "Vincent Moens" }]
13
+ keywords = ["pytorch", "torch", "scan", "higher-order-operators", "autograd"]
14
+ classifiers = [
15
+ "Development Status :: 3 - Alpha",
16
+ "Intended Audience :: Developers",
17
+ "Intended Audience :: Science/Research",
18
+ "Programming Language :: Python :: 3",
19
+ "Programming Language :: Python :: 3 :: Only",
20
+ "Programming Language :: Python :: 3.10",
21
+ "Programming Language :: Python :: 3.11",
22
+ "Programming Language :: Python :: 3.12",
23
+ "Programming Language :: Python :: 3.13",
24
+ "Programming Language :: Python :: 3.14",
25
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
26
+ "Topic :: Software Development :: Libraries :: Python Modules",
27
+ ]
28
+ dependencies = ["torch>=2.7", "pyvers>=0.2.2"]
29
+
30
+ [project.urls]
31
+ Homepage = "https://github.com/vmoens/hoptorch"
32
+ Repository = "https://github.com/vmoens/hoptorch"
33
+ Issues = "https://github.com/vmoens/hoptorch/issues"
34
+
35
+ [tool.setuptools]
36
+ package-dir = {"" = "src"}
37
+
38
+ [tool.setuptools.packages.find]
39
+ where = ["src"]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,8 @@
1
+ """Compatibility helpers for PyTorch higher-order operators."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from .scan import scan
6
+
7
__version__ = "0.1.0"  # keep in sync with ``version`` in pyproject.toml
# Public names re-exported by ``from hoptorch import *``.
__all__ = ["scan"]
@@ -0,0 +1,144 @@
1
+ """PyTorch 2.7 scan backport.
2
+
3
+ PyTorch 2.7 exposes ``torch._higher_order_ops.scan.scan`` but registers its
4
+ Autograd dispatch key as ``autograd_not_implemented``. For that version, patch
5
+ the public scan function to a small eager Python implementation whose backward
6
+ is handled by ordinary PyTorch autograd through the unrolled loop.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import Any, Callable
12
+
13
+
14
+ def _canonicalize_dim(ndim: int, dim: int) -> int:
15
+ if not isinstance(dim, int):
16
+ raise RuntimeError(f"Dim must be an int, but got {type(dim)}")
17
+ if ndim == 0:
18
+ raise RuntimeError("Cannot scan over a scalar xs tensor")
19
+ if dim < 0:
20
+ dim += ndim
21
+ if dim < 0 or dim >= ndim:
22
+ raise IndexError(f"Dimension out of range (expected 0 <= dim < {ndim}, got {dim})")
23
+ return dim
24
+
25
+
26
+ def _validate_tensor_leaves(torch: Any, leaves: list[Any], name: str) -> None:
27
+ if not leaves:
28
+ raise RuntimeError(f"scan() operator requires {name} leaves.")
29
+ for leaf in leaves:
30
+ if not isinstance(leaf, torch.Tensor):
31
+ raise RuntimeError(f"All {name} leaves must be tensors, but got {leaf!r}")
32
+
33
+
34
def _check_carry_matches_init(
    torch: Any, pytree: Any, carry: Any, spec_init: Any, leaves_init: list[Any]
) -> None:
    """Raise ``RuntimeError`` unless ``carry`` has the same layout as ``init``.

    ``carry`` must flatten to ``spec_init`` and each leaf must be a tensor whose
    shape and dtype equal those of the matching ``init`` leaf.
    """
    flat_carry, carry_spec = pytree.tree_flatten(carry)
    if carry_spec != spec_init:
        raise RuntimeError("scan carry pytree structure must match init")
    for init_leaf, carry_leaf in zip(leaves_init, flat_carry):
        if not isinstance(carry_leaf, torch.Tensor):
            raise RuntimeError("All carry leaves must be tensors")
        if init_leaf.shape != carry_leaf.shape or init_leaf.dtype != carry_leaf.dtype:
            raise RuntimeError("scan carry tensor metadata must match init")


def scan_27_backport(
    combine_fn: Callable[[Any, Any], tuple[Any, Any]],
    init: Any,
    xs: Any,
    *,
    dim: int = 0,
    reverse: bool = False,
) -> tuple[Any, Any]:
    """Eager scan implementation for PyTorch 2.7 with normal autograd support.

    Calls ``combine_fn(carry, x_slice)`` once per slice of ``xs`` along ``dim``
    (last slice first when ``reverse`` is true), threading the carry through
    the calls. The loop is plain eager code, so autograd records it normally.

    Args:
        combine_fn: callable ``(carry, x) -> (next_carry, output)``.
            ``next_carry`` must match ``init``'s pytree structure and per-leaf
            shape/dtype. ``output`` must have tensor leaves and the same pytree
            structure at every step.
        init: initial carry; a tensor or pytree with at least one tensor leaf.
        xs: tensor or pytree of tensors scanned along ``dim``. All leaves must
            have the same size along ``dim``.
        dim: dimension of ``xs`` to scan; negative values count from the end,
            relative to the rank of the first ``xs`` leaf.
        reverse: process slices from last to first.

    Returns:
        ``(final_carry, ys)``. Each leaf of ``ys`` stacks the per-step outputs
        along dim 0, ordered to match the slices of ``xs`` (also when
        ``reverse`` is true).
        If ``xs`` has no leaves, ``(init, [])`` is returned unchanged. If the
        scan dimension has length 0, ``(init, [])`` is returned because no
        output structure is known.

    Raises:
        RuntimeError: for invalid arguments, or when the carry or output
            structure/metadata is inconsistent. The carry is checked after
            every step, so a mismatch is reported at the step that produced it
            rather than surfacing inside the next ``combine_fn`` call.
        IndexError: if ``dim`` is out of range for ``xs``.
    """

    import torch
    import torch.utils._pytree as pytree

    if not callable(combine_fn):
        raise RuntimeError(f"Combine_fn must be callable, but got {combine_fn!r}")
    if not isinstance(reverse, bool):
        raise RuntimeError(f"Reverse must be a bool, but got {type(reverse)}")

    leaves_init, spec_init = pytree.tree_flatten(init)
    leaves_xs_orig, spec_xs = pytree.tree_flatten(xs)
    if not leaves_xs_orig:
        # Nothing to scan over: hand init back untouched.
        return init, []

    _validate_tensor_leaves(torch, leaves_init, "init")
    _validate_tensor_leaves(torch, leaves_xs_orig, "xs")

    dim = _canonicalize_dim(leaves_xs_orig[0].ndim, dim)
    scan_length = leaves_xs_orig[0].shape[dim]
    for leaf in leaves_xs_orig:
        if leaf.ndim <= dim:
            raise RuntimeError(
                f"All xs leaves must have dimension {dim}, but got shape {tuple(leaf.shape)}"
            )
        if leaf.shape[dim] != scan_length:
            raise RuntimeError("All xs leaves must have the same scan dimension size")

    # Always iterate over dim 0 in forward order; ``reverse`` is handled by
    # flipping the inputs here and the stacked outputs at the end.
    leaves_xs = [torch.movedim(leaf, dim, 0) for leaf in leaves_xs_orig]
    if reverse:
        leaves_xs = [torch.flip(leaf, [0]) for leaf in leaves_xs]

    carry = init
    out_spec = None
    flat_outputs_by_step: list[list[Any]] = []
    for index in range(scan_length):
        x_slice = pytree.tree_unflatten(
            [leaf.select(0, index) for leaf in leaves_xs], spec_xs
        )
        carry, output = combine_fn(carry, x_slice)

        flat_output, current_out_spec = pytree.tree_flatten(output)
        if out_spec is None:
            out_spec = current_out_spec
        elif current_out_spec != out_spec:
            raise RuntimeError("scan output pytree structure changed across iterations")
        for leaf in flat_output:
            if not isinstance(leaf, torch.Tensor):
                raise RuntimeError(
                    "hoptorch's PyTorch 2.7 scan backport only supports tensor output leaves"
                )
        flat_outputs_by_step.append(flat_output)

        # Validate before the carry is fed back into combine_fn.
        _check_carry_matches_init(torch, pytree, carry, spec_init, leaves_init)

    if out_spec is None:
        # Zero-length scan: combine_fn never ran, so no output structure exists.
        return carry, []

    flat_stacked_outputs = []
    for per_step in zip(*flat_outputs_by_step):
        stacked = torch.stack(per_step, dim=0)
        if reverse:
            # Undo the input flip so ys[i] corresponds to slice i of xs.
            stacked = torch.flip(stacked, [0])
        flat_stacked_outputs.append(stacked)

    return carry, pytree.tree_unflatten(flat_stacked_outputs, out_spec)
116
+
117
+
118
def install_scan_27_backport(scan_module: Any) -> bool:
    """Swap ``scan_module.scan`` for the eager ``scan_27_backport``.

    The replaced function is stored as ``_hoptorch_scan_backward_original`` on
    ``scan_27_backport`` itself, which is where ``rollback_scan_27_backport``
    looks for it. ``scan_module._hoptorch_scan_27_backport`` marks the module as
    patched so repeated calls are no-ops.

    Returns:
        ``True`` if the backport is installed (now or earlier), ``False`` if
        ``scan_module`` has no ``scan`` attribute.
    """

    already_installed = getattr(scan_module, "_hoptorch_scan_27_backport", False)
    if already_installed:
        return True

    replaced_scan = getattr(scan_module, "scan", None)
    if replaced_scan is None:
        return False

    setattr(scan_27_backport, "_hoptorch_scan_backward_original", replaced_scan)
    scan_module.scan = scan_27_backport
    scan_module._hoptorch_scan_27_backport = True
    return True
130
+
131
+
132
def rollback_scan_27_backport(scan_module: Any) -> bool:
    """Undo ``install_scan_27_backport`` on ``scan_module`` when possible.

    The original function is read from ``scan_module.scan``'s
    ``_hoptorch_scan_backward_original`` attribute. When found, it is put back
    and the module's ``_hoptorch_scan_27_backport`` flag is cleared.

    Returns:
        ``True`` if the original ``scan`` was restored. ``False`` if the module
        is not flagged as patched, or if no original function can be found.
    """

    if not getattr(scan_module, "_hoptorch_scan_27_backport", False):
        return False

    patched_scan = getattr(scan_module, "scan", None)
    saved_scan = getattr(patched_scan, "_hoptorch_scan_backward_original", None)
    if saved_scan is None:
        return False

    scan_module.scan = saved_scan
    scan_module._hoptorch_scan_27_backport = False
    return True
@@ -0,0 +1,251 @@
1
+ """Private PyTorch monkey patches for scan backward compatibility.
2
+
3
+ All direct interaction with PyTorch scan autograd internals is isolated here so
4
+ callers do not need to import or reason about those private symbols.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ from typing import Any
10
+
11
+ from pyvers import implement_for
12
+
13
+
14
+ def _differentiable_sample_requires_grad(value: Any) -> Any:
15
+ import torch
16
+
17
+ if not isinstance(value, torch.Tensor):
18
+ return value
19
+ sample = value.detach().clone()
20
+ if value.dtype.is_floating_point or value.dtype.is_complex:
21
+ sample.requires_grad_(True)
22
+ return sample
23
+
24
+
25
+ def _map_differentiable_samples(value: Any) -> Any:
26
+ import torch
27
+
28
+ if isinstance(value, torch.Tensor):
29
+ return _differentiable_sample_requires_grad(value)
30
+ if isinstance(value, tuple):
31
+ return tuple(_map_differentiable_samples(item) for item in value)
32
+ if isinstance(value, list):
33
+ return [_map_differentiable_samples(item) for item in value]
34
+ if isinstance(value, dict):
35
+ return {key: _map_differentiable_samples(item) for key, item in value.items()}
36
+ return value
37
+
38
+
39
+ def _register_scan_autograd_impl(scan_module: Any, scan_autograd: Any) -> bool:
40
+ import torch
41
+
42
+ scan_op = getattr(scan_module, "scan_op", None)
43
+ if scan_op is None or not hasattr(scan_op, "py_autograd_impl"):
44
+ return False
45
+
46
+ py_kernels = getattr(scan_op, "py_kernels", None)
47
+ if py_kernels is None:
48
+ return False
49
+
50
+ autograd_key = torch._C.DispatchKey.Autograd
51
+ previous = py_kernels.pop(autograd_key, None)
52
+ if hasattr(scan_op, "_dispatch_cache"):
53
+ scan_op._dispatch_cache.clear()
54
+ try:
55
+ scan_op.py_autograd_impl(scan_autograd)
56
+ except Exception:
57
+ py_kernels.pop(autograd_key, None)
58
+ if previous is not None:
59
+ py_kernels[autograd_key] = previous
60
+ if hasattr(scan_op, "_dispatch_cache"):
61
+ scan_op._dispatch_cache.clear()
62
+ return False
63
+
64
+ scan_module.scan_autograd = scan_autograd
65
+ return True
66
+
67
+
68
+ def _patch_backward_input_output_aliasing(scan_module: Any) -> bool:
69
+ import torch
70
+ import torch.fx
71
+
72
+ try:
73
+ from torch.multiprocessing.reductions import StorageWeakRef
74
+ except Exception:
75
+ StorageWeakRef = None
76
+
77
+ scan_autograd_impl = getattr(scan_module, "ScanAutogradImpl", None)
78
+ if scan_autograd_impl is None:
79
+ return False
80
+ if getattr(scan_autograd_impl, "_hoptorch_scan_backward_alias_patch", False):
81
+ return True
82
+
83
+ def _break_bw_input_output_aliasing(self) -> None:
84
+ if StorageWeakRef is None or not hasattr(self, "_insert_clone"):
85
+ return
86
+
87
+ bw_gm = self.hop_partitioned_graph.bw_gm
88
+ bw_output_node = next(iter(bw_gm.graph.find_nodes(op="output")))
89
+ if len(bw_output_node.args) != 1:
90
+ raise AssertionError(
91
+ "expected bw_gm output to have 1 arg, got "
92
+ f"{len(bw_output_node.args)}"
93
+ )
94
+
95
+ bw_outputs = bw_output_node.args[0]
96
+ if not isinstance(bw_outputs, (tuple, list)):
97
+ bw_outputs = (bw_outputs,)
98
+
99
+ placeholder_storages = set()
100
+ for placeholder in bw_gm.graph.find_nodes(op="placeholder"):
101
+ val = placeholder.meta.get("val", None) if hasattr(placeholder, "meta") else None
102
+ if isinstance(val, torch.Tensor):
103
+ placeholder_storages.add(StorageWeakRef(val._typed_storage()))
104
+
105
+ def aliases_placeholder(node: torch.fx.Node) -> bool:
106
+ if node.op == "placeholder":
107
+ return True
108
+ val = node.meta.get("val", None) if hasattr(node, "meta") else None
109
+ if isinstance(val, torch.Tensor):
110
+ return StorageWeakRef(val._typed_storage()) in placeholder_storages
111
+ return False
112
+
113
+ new_bw_outputs = []
114
+ rewrote = False
115
+ for output in bw_outputs:
116
+ if isinstance(output, torch.fx.Node) and aliases_placeholder(output):
117
+ new_bw_outputs.append(self._insert_clone(output, bw_output_node))
118
+ rewrote = True
119
+ else:
120
+ new_bw_outputs.append(output)
121
+
122
+ if rewrote:
123
+ bw_output_node.args = (tuple(new_bw_outputs),)
124
+ bw_gm.graph.lint()
125
+ bw_gm.recompile()
126
+
127
+ original_init = scan_autograd_impl.__init__
128
+
129
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
130
+ original_init(self, *args, **kwargs)
131
+ self._break_bw_input_output_aliasing()
132
+
133
+ scan_autograd_impl._break_bw_input_output_aliasing = _break_bw_input_output_aliasing
134
+ scan_autograd_impl.__init__ = __init__
135
+ scan_autograd_impl._hoptorch_scan_backward_alias_patch = True
136
+ return True
137
+
138
+
139
+ def _patch_current_scan_autograd(scan_module: Any) -> bool:
140
+ required = (
141
+ "HopGraphMinCutPartitioner",
142
+ "ScanAutogradImpl",
143
+ "ScanAutogradOp",
144
+ "disable_proxy_modes_tracing",
145
+ )
146
+ if not all(hasattr(scan_module, name) for name in required):
147
+ return False
148
+
149
+ _patch_backward_input_output_aliasing(scan_module)
150
+
151
+ def scan_autograd(combine_fn, init, xs, additional_inputs):
152
+ with scan_module.disable_proxy_modes_tracing():
153
+ sample_init = [_differentiable_sample_requires_grad(t) for t in init]
154
+ sample_args = (*sample_init, *[x[0] for x in xs], *additional_inputs)
155
+ try:
156
+ hop_partitioned_graph = (
157
+ scan_module.HopGraphMinCutPartitioner.create_partitioned_graph(
158
+ combine_fn,
159
+ sample_args,
160
+ always_recompute_complex_exprs=True,
161
+ )
162
+ )
163
+ except TypeError:
164
+ real_args = (*init, *[x[0] for x in xs], *additional_inputs)
165
+ hop_partitioned_graph = (
166
+ scan_module.HopGraphMinCutPartitioner.create_partitioned_graph(
167
+ combine_fn,
168
+ real_args,
169
+ sample_args,
170
+ always_recompute_complex_exprs=True,
171
+ )
172
+ )
173
+
174
+ return scan_module.ScanAutogradOp.apply(
175
+ hop_partitioned_graph,
176
+ len(init),
177
+ len(xs),
178
+ len(additional_inputs),
179
+ *init,
180
+ *xs,
181
+ *additional_inputs,
182
+ )
183
+
184
+ return _register_scan_autograd_impl(scan_module, scan_autograd)
185
+
186
+
187
+ def _patch_older_scan_autograd(scan_module: Any) -> bool:
188
+ materialize_as_graph = getattr(scan_module, "materialize_as_graph", None)
189
+ if materialize_as_graph is None:
190
+ return False
191
+ if getattr(materialize_as_graph, "_hoptorch_scan_backward_sample_patch", False):
192
+ return True
193
+
194
+ def patched_materialize_as_graph(fn, args, *other_args, **kwargs):
195
+ return materialize_as_graph(
196
+ fn,
197
+ _map_differentiable_samples(args),
198
+ *other_args,
199
+ **kwargs,
200
+ )
201
+
202
+ patched_materialize_as_graph._hoptorch_scan_backward_sample_patch = True
203
+ patched_materialize_as_graph._hoptorch_scan_backward_original = materialize_as_graph
204
+ scan_module.materialize_as_graph = patched_materialize_as_graph
205
+ return True
206
+
207
+
208
@implement_for("torch")
def install_scan_backward_patch(scan_module: Any) -> bool:
    """Default scan-backward patch installer, registered without a version range.

    This variant installs nothing and always reports failure. Version-specific
    variants are attached through ``install_scan_backward_patch.register``.
    """
    del scan_module  # intentionally unused by the no-op default
    return False
213
+
214
+
215
@install_scan_backward_patch.register(from_version="2.7", to_version="2.8")
def _(scan_module: Any) -> bool:
    """PyTorch 2.7 variant: replace ``scan_module.scan`` with the eager backport.

    The backport module is imported lazily, only when this variant is used.
    """
    from ._scan_backport_27 import install_scan_27_backport as install_backport

    return install_backport(scan_module)
220
+
221
+
222
@install_scan_backward_patch.register(from_version="2.8")
def _(scan_module: Any) -> bool:
    """Variant for PyTorch from 2.8 on: patch scan's Autograd support.

    Tries the newer-layout patch first and falls back to the
    ``materialize_as_graph`` wrapper. Best effort: any exception is swallowed
    and reported as ``False``.
    """
    try:
        if _patch_current_scan_autograd(scan_module):
            return True
        return _patch_older_scan_autograd(scan_module)
    except Exception:
        return False
228
+
229
+
230
@implement_for("torch")
def rollback_failed_scan_backward_patch(scan_module: Any) -> bool:
    """Default rollback, registered without a version range.

    This variant has nothing to undo and always returns ``False``.
    Version-specific variants are attached through
    ``rollback_failed_scan_backward_patch.register``.
    """
    del scan_module  # intentionally unused by the no-op default
    return False
235
+
236
+
237
@rollback_failed_scan_backward_patch.register(from_version="2.7", to_version="2.8")
def _(scan_module: Any) -> bool:
    """PyTorch 2.7 variant: restore the ``scan`` function replaced by the backport.

    The backport module is imported lazily, only when this variant is used.
    """
    from ._scan_backport_27 import rollback_scan_27_backport as rollback_backport

    return rollback_backport(scan_module)
242
+
243
+
244
@rollback_failed_scan_backward_patch.register(from_version="2.8")
def _(scan_module: Any) -> bool:
    """Variant for PyTorch from 2.8 on: unwrap ``materialize_as_graph``.

    Only the ``materialize_as_graph`` wrapper is reverted, using the function
    saved in its ``_hoptorch_scan_backward_original`` attribute. The Autograd
    kernel registration and the ``ScanAutogradImpl`` patch are left in place.

    Returns:
        ``True`` if the original function was restored, else ``False``.
    """
    current = getattr(scan_module, "materialize_as_graph", None)
    saved = getattr(current, "_hoptorch_scan_backward_original", None)
    if saved is not None:
        scan_module.materialize_as_graph = saved
        return True
    return False