hoptorch 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hoptorch-0.1.0/LICENSE +21 -0
- hoptorch-0.1.0/PKG-INFO +109 -0
- hoptorch-0.1.0/README.md +80 -0
- hoptorch-0.1.0/pyproject.toml +39 -0
- hoptorch-0.1.0/setup.cfg +4 -0
- hoptorch-0.1.0/src/hoptorch/__init__.py +8 -0
- hoptorch-0.1.0/src/hoptorch/_scan_backport_27.py +144 -0
- hoptorch-0.1.0/src/hoptorch/_scan_patch.py +251 -0
- hoptorch-0.1.0/src/hoptorch/_scan_probe.py +65 -0
- hoptorch-0.1.0/src/hoptorch/scan.py +298 -0
- hoptorch-0.1.0/src/hoptorch.egg-info/PKG-INFO +109 -0
- hoptorch-0.1.0/src/hoptorch.egg-info/SOURCES.txt +15 -0
- hoptorch-0.1.0/src/hoptorch.egg-info/dependency_links.txt +1 -0
- hoptorch-0.1.0/src/hoptorch.egg-info/requires.txt +2 -0
- hoptorch-0.1.0/src/hoptorch.egg-info/top_level.txt +1 -0
- hoptorch-0.1.0/tests/test_public_api.py +99 -0
- hoptorch-0.1.0/tests/test_scan_backward.py +167 -0
hoptorch-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2026 Vincent Moens
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
hoptorch-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: hoptorch
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Compatibility helpers for PyTorch higher-order operators.
|
|
5
|
+
Author: Vincent Moens
|
|
6
|
+
License-Expression: MIT
|
|
7
|
+
Project-URL: Homepage, https://github.com/vmoens/hoptorch
|
|
8
|
+
Project-URL: Repository, https://github.com/vmoens/hoptorch
|
|
9
|
+
Project-URL: Issues, https://github.com/vmoens/hoptorch/issues
|
|
10
|
+
Keywords: pytorch,torch,scan,higher-order-operators,autograd
|
|
11
|
+
Classifier: Development Status :: 3 - Alpha
|
|
12
|
+
Classifier: Intended Audience :: Developers
|
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
|
14
|
+
Classifier: Programming Language :: Python :: 3
|
|
15
|
+
Classifier: Programming Language :: Python :: 3 :: Only
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
20
|
+
Classifier: Programming Language :: Python :: 3.14
|
|
21
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
22
|
+
Classifier: Topic :: Software Development :: Libraries :: Python Modules
|
|
23
|
+
Requires-Python: >=3.10
|
|
24
|
+
Description-Content-Type: text/markdown
|
|
25
|
+
License-File: LICENSE
|
|
26
|
+
Requires-Dist: torch>=2.7
|
|
27
|
+
Requires-Dist: pyvers>=0.2.2
|
|
28
|
+
Dynamic: license-file
|
|
29
|
+
|
|
30
|
+
# hoptorch
|
|
31
|
+
|
|
32
|
+
`hoptorch` is a small compatibility package for PyTorch higher-order operators.
|
|
33
|
+
Its first helper is a safe wrapper around `torch._higher_order_ops.scan` that
|
|
34
|
+
checks whether scan backward is healthy on the requested device and lazily
|
|
35
|
+
installs version-specific compatibility patches for affected PyTorch internals.
|
|
36
|
+
For PyTorch 2.7, where scan backward is not implemented, `hoptorch` backports
|
|
37
|
+
a small eager scan implementation with ordinary autograd support.
|
|
38
|
+
|
|
39
|
+
## Install
|
|
40
|
+
|
|
41
|
+
```bash
|
|
42
|
+
pip install hoptorch
|
|
43
|
+
```
|
|
44
|
+
|
|
45
|
+
Runtime dependency:
|
|
46
|
+
|
|
47
|
+
```text
|
|
48
|
+
torch>=2.7
|
|
49
|
+
pyvers>=0.2.2
|
|
50
|
+
```
|
|
51
|
+
|
|
52
|
+
## Usage
|
|
53
|
+
|
|
54
|
+
```python
|
|
55
|
+
import torch
|
|
56
|
+
from hoptorch import scan
|
|
57
|
+
from hoptorch.scan import ensure_scan_backward, scan_unavailable_reason
|
|
58
|
+
|
|
59
|
+
if ensure_scan_backward("cpu"):
|
|
60
|
+
xs = torch.arange(4.0)
|
|
61
|
+
|
|
62
|
+
def step(carry, x):
|
|
63
|
+
next_carry = carry + x
|
|
64
|
+
return next_carry, next_carry.clone()
|
|
65
|
+
|
|
66
|
+
carry, ys = scan(step, torch.zeros(()), xs)
|
|
67
|
+
else:
|
|
68
|
+
print(scan_unavailable_reason("cpu"))
|
|
69
|
+
```
|
|
70
|
+
|
|
71
|
+
For compiled code, warm the health check before entering `torch.compile`:
|
|
72
|
+
|
|
73
|
+
```python
|
|
74
|
+
import torch
from hoptorch.scan import ensure_scan_backward
|
|
75
|
+
|
|
76
|
+
ensure_scan_backward("cpu")
|
|
77
|
+
compiled_fn = torch.compile(fn)
|
|
78
|
+
```
|
|
79
|
+
|
|
80
|
+
If a wrapper call happens during Dynamo tracing before the health result is
|
|
81
|
+
cached, `hoptorch` fails closed instead of tracing the probe.
|
|
82
|
+
|
|
83
|
+
## Public API
|
|
84
|
+
|
|
85
|
+
```python
|
|
86
|
+
from hoptorch import scan
|
|
87
|
+
from hoptorch.scan import (
|
|
88
|
+
ensure_scan_backward,
|
|
89
|
+
has_scan,
|
|
90
|
+
patch_scan_backward,
|
|
91
|
+
scan_unavailable_reason,
|
|
92
|
+
)
|
|
93
|
+
```
|
|
94
|
+
|
|
95
|
+
- `scan(fn, init, xs, *, dim=0, **kwargs)`: calls PyTorch scan only when scan
|
|
96
|
+
backward is known to be healthy for the inferred device.
|
|
97
|
+
- `has_scan()`: reports whether `torch._higher_order_ops.scan.scan` exists.
|
|
98
|
+
- `ensure_scan_backward(device=None)`: runs or reads the cached health check,
|
|
99
|
+
patching lazily if needed, and returns `True` only after a passing probe.
|
|
100
|
+
- `scan_unavailable_reason(device=None)`: returns `None` when scan backward is
|
|
101
|
+
usable, otherwise a stable human-readable reason.
|
|
102
|
+
- `patch_scan_backward()`: attempts to install the compatibility patch and is
|
|
103
|
+
intended mainly for diagnostics.
|
|
104
|
+
|
|
105
|
+
## Tests
|
|
106
|
+
|
|
107
|
+
```bash
|
|
108
|
+
python -m unittest discover -s tests
|
|
109
|
+
```
|
hoptorch-0.1.0/README.md
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
# hoptorch
|
|
2
|
+
|
|
3
|
+
`hoptorch` is a small compatibility package for PyTorch higher-order operators.
|
|
4
|
+
Its first helper is a safe wrapper around `torch._higher_order_ops.scan` that
|
|
5
|
+
checks whether scan backward is healthy on the requested device and lazily
|
|
6
|
+
installs version-specific compatibility patches for affected PyTorch internals.
|
|
7
|
+
For PyTorch 2.7, where scan backward is not implemented, `hoptorch` backports
|
|
8
|
+
a small eager scan implementation with ordinary autograd support.
|
|
9
|
+
|
|
10
|
+
## Install
|
|
11
|
+
|
|
12
|
+
```bash
|
|
13
|
+
pip install hoptorch
|
|
14
|
+
```
|
|
15
|
+
|
|
16
|
+
Runtime dependency:
|
|
17
|
+
|
|
18
|
+
```text
|
|
19
|
+
torch>=2.7
|
|
20
|
+
pyvers>=0.2.2
|
|
21
|
+
```
|
|
22
|
+
|
|
23
|
+
## Usage
|
|
24
|
+
|
|
25
|
+
```python
|
|
26
|
+
import torch
|
|
27
|
+
from hoptorch import scan
|
|
28
|
+
from hoptorch.scan import ensure_scan_backward, scan_unavailable_reason
|
|
29
|
+
|
|
30
|
+
if ensure_scan_backward("cpu"):
|
|
31
|
+
xs = torch.arange(4.0)
|
|
32
|
+
|
|
33
|
+
def step(carry, x):
|
|
34
|
+
next_carry = carry + x
|
|
35
|
+
return next_carry, next_carry.clone()
|
|
36
|
+
|
|
37
|
+
carry, ys = scan(step, torch.zeros(()), xs)
|
|
38
|
+
else:
|
|
39
|
+
print(scan_unavailable_reason("cpu"))
|
|
40
|
+
```
|
|
41
|
+
|
|
42
|
+
For compiled code, warm the health check before entering `torch.compile`:
|
|
43
|
+
|
|
44
|
+
```python
|
|
45
|
+
import torch
from hoptorch.scan import ensure_scan_backward
|
|
46
|
+
|
|
47
|
+
ensure_scan_backward("cpu")
|
|
48
|
+
compiled_fn = torch.compile(fn)
|
|
49
|
+
```
|
|
50
|
+
|
|
51
|
+
If a wrapper call happens during Dynamo tracing before the health result is
|
|
52
|
+
cached, `hoptorch` fails closed instead of tracing the probe.
|
|
53
|
+
|
|
54
|
+
## Public API
|
|
55
|
+
|
|
56
|
+
```python
|
|
57
|
+
from hoptorch import scan
|
|
58
|
+
from hoptorch.scan import (
|
|
59
|
+
ensure_scan_backward,
|
|
60
|
+
has_scan,
|
|
61
|
+
patch_scan_backward,
|
|
62
|
+
scan_unavailable_reason,
|
|
63
|
+
)
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
- `scan(fn, init, xs, *, dim=0, **kwargs)`: calls PyTorch scan only when scan
|
|
67
|
+
backward is known to be healthy for the inferred device.
|
|
68
|
+
- `has_scan()`: reports whether `torch._higher_order_ops.scan.scan` exists.
|
|
69
|
+
- `ensure_scan_backward(device=None)`: runs or reads the cached health check,
|
|
70
|
+
patching lazily if needed, and returns `True` only after a passing probe.
|
|
71
|
+
- `scan_unavailable_reason(device=None)`: returns `None` when scan backward is
|
|
72
|
+
usable, otherwise a stable human-readable reason.
|
|
73
|
+
- `patch_scan_backward()`: attempts to install the compatibility patch and is
|
|
74
|
+
intended mainly for diagnostics.
|
|
75
|
+
|
|
76
|
+
## Tests
|
|
77
|
+
|
|
78
|
+
```bash
|
|
79
|
+
python -m unittest discover -s tests
|
|
80
|
+
```
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=68", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "hoptorch"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Compatibility helpers for PyTorch higher-order operators."
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.10"
|
|
11
|
+
license = "MIT"
|
|
12
|
+
authors = [{ name = "Vincent Moens" }]
|
|
13
|
+
keywords = ["pytorch", "torch", "scan", "higher-order-operators", "autograd"]
|
|
14
|
+
classifiers = [
|
|
15
|
+
"Development Status :: 3 - Alpha",
|
|
16
|
+
"Intended Audience :: Developers",
|
|
17
|
+
"Intended Audience :: Science/Research",
|
|
18
|
+
"Programming Language :: Python :: 3",
|
|
19
|
+
"Programming Language :: Python :: 3 :: Only",
|
|
20
|
+
"Programming Language :: Python :: 3.10",
|
|
21
|
+
"Programming Language :: Python :: 3.11",
|
|
22
|
+
"Programming Language :: Python :: 3.12",
|
|
23
|
+
"Programming Language :: Python :: 3.13",
|
|
24
|
+
"Programming Language :: Python :: 3.14",
|
|
25
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
26
|
+
"Topic :: Software Development :: Libraries :: Python Modules",
|
|
27
|
+
]
|
|
28
|
+
dependencies = ["torch>=2.7", "pyvers>=0.2.2"]
|
|
29
|
+
|
|
30
|
+
[project.urls]
|
|
31
|
+
Homepage = "https://github.com/vmoens/hoptorch"
|
|
32
|
+
Repository = "https://github.com/vmoens/hoptorch"
|
|
33
|
+
Issues = "https://github.com/vmoens/hoptorch/issues"
|
|
34
|
+
|
|
35
|
+
[tool.setuptools]
|
|
36
|
+
package-dir = {"" = "src"}
|
|
37
|
+
|
|
38
|
+
[tool.setuptools.packages.find]
|
|
39
|
+
where = ["src"]
|
hoptorch-0.1.0/setup.cfg
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
"""PyTorch 2.7 scan backport.
|
|
2
|
+
|
|
3
|
+
PyTorch 2.7 exposes ``torch._higher_order_ops.scan.scan`` but registers its
|
|
4
|
+
Autograd dispatch key as ``autograd_not_implemented``. For that version, patch
|
|
5
|
+
the public scan function to a small eager Python implementation whose backward
|
|
6
|
+
is handled by ordinary PyTorch autograd through the unrolled loop.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
from typing import Any, Callable
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _canonicalize_dim(ndim: int, dim: int) -> int:
|
|
15
|
+
if not isinstance(dim, int):
|
|
16
|
+
raise RuntimeError(f"Dim must be an int, but got {type(dim)}")
|
|
17
|
+
if ndim == 0:
|
|
18
|
+
raise RuntimeError("Cannot scan over a scalar xs tensor")
|
|
19
|
+
if dim < 0:
|
|
20
|
+
dim += ndim
|
|
21
|
+
if dim < 0 or dim >= ndim:
|
|
22
|
+
raise IndexError(f"Dimension out of range (expected 0 <= dim < {ndim}, got {dim})")
|
|
23
|
+
return dim
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _validate_tensor_leaves(torch: Any, leaves: list[Any], name: str) -> None:
|
|
27
|
+
if not leaves:
|
|
28
|
+
raise RuntimeError(f"scan() operator requires {name} leaves.")
|
|
29
|
+
for leaf in leaves:
|
|
30
|
+
if not isinstance(leaf, torch.Tensor):
|
|
31
|
+
raise RuntimeError(f"All {name} leaves must be tensors, but got {leaf!r}")
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _check_carry_matches_init(
    torch: Any,
    pytree: Any,
    carry: Any,
    leaves_init: list[Any],
    spec_init: Any,
) -> None:
    """Raise ``RuntimeError`` unless ``carry`` has init's pytree structure and leaf shape/dtype."""
    flat_carry, carry_spec = pytree.tree_flatten(carry)
    if carry_spec != spec_init:
        raise RuntimeError("scan carry pytree structure must match init")
    for init_leaf, carry_leaf in zip(leaves_init, flat_carry):
        if not isinstance(carry_leaf, torch.Tensor):
            raise RuntimeError("All carry leaves must be tensors")
        if init_leaf.shape != carry_leaf.shape or init_leaf.dtype != carry_leaf.dtype:
            raise RuntimeError("scan carry tensor metadata must match init")


def scan_27_backport(
    combine_fn: Callable[[Any, Any], tuple[Any, Any]],
    init: Any,
    xs: Any,
    *,
    dim: int = 0,
    reverse: bool = False,
) -> tuple[Any, Any]:
    """Eager scan implementation for PyTorch 2.7 with normal autograd support.

    The scan is unrolled in Python: ``combine_fn(carry, x_slice)`` is called once
    per slice of ``xs`` along ``dim`` and must return ``(next_carry, output)``.
    Every step is an ordinary eager PyTorch call, so gradients flow through
    regular autograd.

    Args:
        combine_fn: Step function ``(carry, x) -> (next_carry, output)``.
        init: Initial carry; a pytree with tensor leaves.
        xs: Scanned inputs; a pytree with tensor leaves that all have the same
            size along ``dim``.
        dim: Dimension of the ``xs`` leaves to scan over. Negative values are
            resolved against the rank of the first ``xs`` leaf.
        reverse: If ``True``, slices are processed from last to first. The
            stacked outputs are flipped back so they stay aligned with ``xs``.

    Returns:
        ``(final_carry, ys)``. Each leaf of ``ys`` stacks the per-step outputs
        along dimension 0. If ``xs`` has no leaves, ``(init, [])`` is returned
        without calling ``combine_fn``. ``(init, [])`` is also returned when the
        scan length is zero.

    Raises:
        RuntimeError: on a non-callable ``combine_fn``, non-bool ``reverse``,
            empty/non-tensor leaves, mismatched scan sizes, non-tensor outputs,
            output structure changes, or a carry that stops matching ``init``.
            The carry is checked after every step.
        IndexError: if ``dim`` is out of range.
    """

    import torch
    import torch.utils._pytree as pytree

    if not callable(combine_fn):
        raise RuntimeError(f"Combine_fn must be callable, but got {combine_fn!r}")
    if not isinstance(reverse, bool):
        raise RuntimeError(f"Reverse must be a bool, but got {type(reverse)}")

    leaves_init, spec_init = pytree.tree_flatten(init)
    leaves_xs_orig, spec_xs = pytree.tree_flatten(xs)
    if not leaves_xs_orig:
        return init, []

    _validate_tensor_leaves(torch, leaves_init, "init")
    _validate_tensor_leaves(torch, leaves_xs_orig, "xs")

    ndim = leaves_xs_orig[0].ndim
    dim = _canonicalize_dim(ndim, dim)
    scan_length = leaves_xs_orig[0].shape[dim]
    for leaf in leaves_xs_orig:
        if leaf.ndim <= dim:
            raise RuntimeError(
                f"All xs leaves must have dimension {dim}, but got shape {tuple(leaf.shape)}"
            )
        if leaf.shape[dim] != scan_length:
            raise RuntimeError("All xs leaves must have the same scan dimension size")

    # Always iterate over dimension 0; flip up front for reverse scans.
    leaves_xs = [torch.movedim(leaf, dim, 0) for leaf in leaves_xs_orig]
    if reverse:
        leaves_xs = [torch.flip(leaf, [0]) for leaf in leaves_xs]

    carry = init
    out_spec = None
    flat_outputs_by_step: list[list[Any]] = []
    for index in range(scan_length):
        x_slice = pytree.tree_unflatten(
            [leaf.select(0, index) for leaf in leaves_xs], spec_xs
        )
        carry, output = combine_fn(carry, x_slice)
        flat_output, current_out_spec = pytree.tree_flatten(output)
        if out_spec is None:
            out_spec = current_out_spec
        elif current_out_spec != out_spec:
            raise RuntimeError("scan output pytree structure changed across iterations")
        for leaf in flat_output:
            if not isinstance(leaf, torch.Tensor):
                raise RuntimeError(
                    "hoptorch's PyTorch 2.7 scan backport only supports tensor output leaves"
                )
        # Validate the carry after every step, not only the final one. The carry
        # is fed back into the next combine_fn call, so a drift in structure or
        # shape/dtype must be caught at the offending iteration. Otherwise it
        # could go unnoticed if it drifted back before the end.
        _check_carry_matches_init(torch, pytree, carry, leaves_init, spec_init)
        flat_outputs_by_step.append(flat_output)

    if out_spec is None:
        # Zero-length scan: combine_fn was never called, so there is no output structure.
        return carry, []

    flat_stacked_outputs = []
    # Every step produced the same output spec, so zip pairs up matching leaves.
    for per_step_leaves in zip(*flat_outputs_by_step):
        stacked = torch.stack(per_step_leaves, dim=0)
        if reverse:
            # Undo the input flip so ys[i] corresponds to the i-th slice of xs.
            stacked = torch.flip(stacked, [0])
        flat_stacked_outputs.append(stacked)

    return carry, pytree.tree_unflatten(flat_stacked_outputs, out_spec)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def install_scan_27_backport(scan_module: Any) -> bool:
    """Patch ``torch._higher_order_ops.scan.scan`` for PyTorch 2.7.

    Replaces ``scan_module.scan`` with ``scan_27_backport`` and records the
    replaced function so a rollback can restore it.

    Args:
        scan_module: Module (or module-like object) exposing a ``scan`` attribute.

    Returns:
        ``True`` if the backport is installed (now or previously). ``False`` if
        ``scan_module`` has no ``scan`` attribute.
    """

    if getattr(scan_module, "_hoptorch_scan_27_backport", False):
        return True
    original_scan = getattr(scan_module, "scan", None)
    if original_scan is None:
        return False
    if original_scan is scan_27_backport:
        # Already patched but the flag is missing. Do not record the backport as
        # its own "original", which would turn a later rollback into a no-op.
        scan_module._hoptorch_scan_27_backport = True
        return True
    # Keep the original per module, so that patching several modules cannot
    # overwrite each other's saved original.
    scan_module._hoptorch_scan_27_original = original_scan
    # Also keep the legacy attribute on the function for rollback code that reads it.
    scan_27_backport._hoptorch_scan_backward_original = original_scan  # type: ignore[attr-defined]
    scan_module.scan = scan_27_backport
    scan_module._hoptorch_scan_27_backport = True
    return True
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def rollback_scan_27_backport(scan_module: Any) -> bool:
    """Restore the original PyTorch 2.7 scan function if this patch was applied.

    The per-module saved original (``_hoptorch_scan_27_original``) is preferred.
    Otherwise the ``_hoptorch_scan_backward_original`` attribute on the
    currently installed ``scan`` function is used.

    Args:
        scan_module: Module previously passed to the install function.

    Returns:
        ``True`` if the original ``scan`` was restored. ``False`` if the backport
        is not marked as installed or no original could be found.
    """

    if not getattr(scan_module, "_hoptorch_scan_27_backport", False):
        return False
    original_scan = getattr(scan_module, "_hoptorch_scan_27_original", None)
    if original_scan is None:
        # Legacy location: the original stored on the installed function object.
        original_scan = getattr(
            getattr(scan_module, "scan", None), "_hoptorch_scan_backward_original", None
        )
    if original_scan is None:
        return False
    scan_module.scan = original_scan
    scan_module._hoptorch_scan_27_backport = False
    # Drop the saved reference so a later install records a fresh original.
    scan_module._hoptorch_scan_27_original = None
    return True
|
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
"""Private PyTorch monkey patches for scan backward compatibility.
|
|
2
|
+
|
|
3
|
+
All direct interaction with PyTorch scan autograd internals is isolated here so
|
|
4
|
+
callers do not need to import or reason about those private symbols.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from pyvers import implement_for
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
def _differentiable_sample_requires_grad(value: Any) -> Any:
|
|
15
|
+
import torch
|
|
16
|
+
|
|
17
|
+
if not isinstance(value, torch.Tensor):
|
|
18
|
+
return value
|
|
19
|
+
sample = value.detach().clone()
|
|
20
|
+
if value.dtype.is_floating_point or value.dtype.is_complex:
|
|
21
|
+
sample.requires_grad_(True)
|
|
22
|
+
return sample
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _map_differentiable_samples(value: Any) -> Any:
|
|
26
|
+
import torch
|
|
27
|
+
|
|
28
|
+
if isinstance(value, torch.Tensor):
|
|
29
|
+
return _differentiable_sample_requires_grad(value)
|
|
30
|
+
if isinstance(value, tuple):
|
|
31
|
+
return tuple(_map_differentiable_samples(item) for item in value)
|
|
32
|
+
if isinstance(value, list):
|
|
33
|
+
return [_map_differentiable_samples(item) for item in value]
|
|
34
|
+
if isinstance(value, dict):
|
|
35
|
+
return {key: _map_differentiable_samples(item) for key, item in value.items()}
|
|
36
|
+
return value
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def _register_scan_autograd_impl(scan_module: Any, scan_autograd: Any) -> bool:
|
|
40
|
+
import torch
|
|
41
|
+
|
|
42
|
+
scan_op = getattr(scan_module, "scan_op", None)
|
|
43
|
+
if scan_op is None or not hasattr(scan_op, "py_autograd_impl"):
|
|
44
|
+
return False
|
|
45
|
+
|
|
46
|
+
py_kernels = getattr(scan_op, "py_kernels", None)
|
|
47
|
+
if py_kernels is None:
|
|
48
|
+
return False
|
|
49
|
+
|
|
50
|
+
autograd_key = torch._C.DispatchKey.Autograd
|
|
51
|
+
previous = py_kernels.pop(autograd_key, None)
|
|
52
|
+
if hasattr(scan_op, "_dispatch_cache"):
|
|
53
|
+
scan_op._dispatch_cache.clear()
|
|
54
|
+
try:
|
|
55
|
+
scan_op.py_autograd_impl(scan_autograd)
|
|
56
|
+
except Exception:
|
|
57
|
+
py_kernels.pop(autograd_key, None)
|
|
58
|
+
if previous is not None:
|
|
59
|
+
py_kernels[autograd_key] = previous
|
|
60
|
+
if hasattr(scan_op, "_dispatch_cache"):
|
|
61
|
+
scan_op._dispatch_cache.clear()
|
|
62
|
+
return False
|
|
63
|
+
|
|
64
|
+
scan_module.scan_autograd = scan_autograd
|
|
65
|
+
return True
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def _patch_backward_input_output_aliasing(scan_module: Any) -> bool:
    """Patch ``ScanAutogradImpl`` so backward-graph outputs do not alias its inputs.

    Adds a ``_break_bw_input_output_aliasing`` method to
    ``scan_module.ScanAutogradImpl`` and wraps the class ``__init__`` so that
    method runs right after the original constructor. The method rewrites the
    backward graph (``self.hop_partitioned_graph.bw_gm``). Every output node
    that is a placeholder, or whose ``meta["val"]`` tensor shares storage with a
    placeholder's ``meta["val"]``, is replaced by the node returned from
    ``self._insert_clone``.

    Args:
        scan_module: module expected to expose ``ScanAutogradImpl`` (presumably
            ``torch._higher_order_ops.scan`` — confirm at call sites).

    Returns:
        ``False`` if ``ScanAutogradImpl`` is missing. Otherwise ``True``, either
        after installing the patch or because it was already installed.

    NOTE(review): the original ``__init__`` is only captured in a closure, so
    this patch cannot be rolled back.
    """
    import torch
    import torch.fx

    try:
        from torch.multiprocessing.reductions import StorageWeakRef
    except Exception:
        # Without StorageWeakRef the hook below becomes a no-op.
        StorageWeakRef = None

    scan_autograd_impl = getattr(scan_module, "ScanAutogradImpl", None)
    if scan_autograd_impl is None:
        return False
    if getattr(scan_autograd_impl, "_hoptorch_scan_backward_alias_patch", False):
        # Already patched; avoid wrapping __init__ a second time.
        return True

    def _break_bw_input_output_aliasing(self) -> None:
        # Skip silently when storage comparison or the clone helper is unavailable.
        if StorageWeakRef is None or not hasattr(self, "_insert_clone"):
            return

        bw_gm = self.hop_partitioned_graph.bw_gm
        bw_output_node = next(iter(bw_gm.graph.find_nodes(op="output")))
        if len(bw_output_node.args) != 1:
            raise AssertionError(
                "expected bw_gm output to have 1 arg, got "
                f"{len(bw_output_node.args)}"
            )

        # Normalize a single (non-sequence) output into a 1-tuple.
        bw_outputs = bw_output_node.args[0]
        if not isinstance(bw_outputs, (tuple, list)):
            bw_outputs = (bw_outputs,)

        # Storages backing every placeholder's recorded ``val`` tensor.
        placeholder_storages = set()
        for placeholder in bw_gm.graph.find_nodes(op="placeholder"):
            val = placeholder.meta.get("val", None) if hasattr(placeholder, "meta") else None
            if isinstance(val, torch.Tensor):
                placeholder_storages.add(StorageWeakRef(val._typed_storage()))

        def aliases_placeholder(node: torch.fx.Node) -> bool:
            # A node aliases a graph input if it is itself a placeholder, or if its
            # ``val`` tensor shares storage with some placeholder's ``val``.
            if node.op == "placeholder":
                return True
            val = node.meta.get("val", None) if hasattr(node, "meta") else None
            if isinstance(val, torch.Tensor):
                return StorageWeakRef(val._typed_storage()) in placeholder_storages
            return False

        new_bw_outputs = []
        rewrote = False
        for output in bw_outputs:
            if isinstance(output, torch.fx.Node) and aliases_placeholder(output):
                # Substitute the node produced by ``self._insert_clone`` (given the
                # output node as anchor) for the aliasing output.
                new_bw_outputs.append(self._insert_clone(output, bw_output_node))
                rewrote = True
            else:
                new_bw_outputs.append(output)

        if rewrote:
            # The rewritten outputs are always stored as a tuple, then the graph is
            # validated and the GraphModule code regenerated.
            bw_output_node.args = (tuple(new_bw_outputs),)
            bw_gm.graph.lint()
            bw_gm.recompile()

    original_init = scan_autograd_impl.__init__

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        original_init(self, *args, **kwargs)
        # Rewrite the backward graph once the original constructor has run.
        self._break_bw_input_output_aliasing()

    scan_autograd_impl._break_bw_input_output_aliasing = _break_bw_input_output_aliasing
    scan_autograd_impl.__init__ = __init__
    scan_autograd_impl._hoptorch_scan_backward_alias_patch = True
    return True
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def _patch_current_scan_autograd(scan_module: Any) -> bool:
    """Install a replacement scan Autograd kernel for the newer scan internals.

    Requires ``scan_module`` to expose ``HopGraphMinCutPartitioner``,
    ``ScanAutogradImpl``, ``ScanAutogradOp`` and ``disable_proxy_modes_tracing``.
    If any of them is missing, returns ``False`` without patching anything.
    Otherwise it applies the backward input/output aliasing patch, then
    registers a ``scan_autograd`` function. That function partitions
    ``combine_fn`` using grad-enabled detached samples of ``init`` and runs
    ``ScanAutogradOp``.

    Returns:
        The result of ``_register_scan_autograd_impl`` (``True`` on success).
    """
    required = (
        "HopGraphMinCutPartitioner",
        "ScanAutogradImpl",
        "ScanAutogradOp",
        "disable_proxy_modes_tracing",
    )
    if not all(hasattr(scan_module, name) for name in required):
        return False

    # Return value ignored: it is only False when ScanAutogradImpl is missing,
    # which the ``required`` check above already rules out.
    _patch_backward_input_output_aliasing(scan_module)

    def scan_autograd(combine_fn, init, xs, additional_inputs):
        # Build the partitioned forward/backward graph with proxy tracing disabled.
        with scan_module.disable_proxy_modes_tracing():
            # Detached copies of the initial carry; floating/complex ones require grad.
            sample_init = [_differentiable_sample_requires_grad(t) for t in init]
            # The first slice of each scanned input stands in for one step's x
            # (assumes xs leaves are already scanned along dim 0 — TODO confirm).
            sample_args = (*sample_init, *[x[0] for x in xs], *additional_inputs)
            try:
                hop_partitioned_graph = (
                    scan_module.HopGraphMinCutPartitioner.create_partitioned_graph(
                        combine_fn,
                        sample_args,
                        always_recompute_complex_exprs=True,
                    )
                )
            except TypeError:
                # Retry with a partitioner signature that takes the real operands
                # before the samples (presumably a different PyTorch release).
                # NOTE(review): a TypeError raised while tracing combine_fn itself is
                # also caught here and triggers this retry, which can mask the
                # original error message.
                real_args = (*init, *[x[0] for x in xs], *additional_inputs)
                hop_partitioned_graph = (
                    scan_module.HopGraphMinCutPartitioner.create_partitioned_graph(
                        combine_fn,
                        real_args,
                        sample_args,
                        always_recompute_complex_exprs=True,
                    )
                )

        return scan_module.ScanAutogradOp.apply(
            hop_partitioned_graph,
            len(init),
            len(xs),
            len(additional_inputs),
            *init,
            *xs,
            *additional_inputs,
        )

    return _register_scan_autograd_impl(scan_module, scan_autograd)
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def _patch_older_scan_autograd(scan_module: Any) -> bool:
|
|
188
|
+
materialize_as_graph = getattr(scan_module, "materialize_as_graph", None)
|
|
189
|
+
if materialize_as_graph is None:
|
|
190
|
+
return False
|
|
191
|
+
if getattr(materialize_as_graph, "_hoptorch_scan_backward_sample_patch", False):
|
|
192
|
+
return True
|
|
193
|
+
|
|
194
|
+
def patched_materialize_as_graph(fn, args, *other_args, **kwargs):
|
|
195
|
+
return materialize_as_graph(
|
|
196
|
+
fn,
|
|
197
|
+
_map_differentiable_samples(args),
|
|
198
|
+
*other_args,
|
|
199
|
+
**kwargs,
|
|
200
|
+
)
|
|
201
|
+
|
|
202
|
+
patched_materialize_as_graph._hoptorch_scan_backward_sample_patch = True
|
|
203
|
+
patched_materialize_as_graph._hoptorch_scan_backward_original = materialize_as_graph
|
|
204
|
+
scan_module.materialize_as_graph = patched_materialize_as_graph
|
|
205
|
+
return True
|
|
206
|
+
|
|
207
|
+
|
|
208
|
+
@implement_for("torch")
def install_scan_backward_patch(scan_module: Any) -> bool:
    """Attempt to install a scan backward compatibility patch.

    Base implementation: installs nothing and always reports ``False``.
    Version-specific variants are attached with
    ``install_scan_backward_patch.register``. Presumably this body is used for
    torch versions outside every registered range — confirm against pyvers.
    """

    patched = False
    return patched
|
|
213
|
+
|
|
214
|
+
|
|
215
|
+
@install_scan_backward_patch.register(from_version="2.7", to_version="2.8")
def _(scan_module: Any) -> bool:
    """Install the eager scan backport (registered for ``2.7`` to ``2.8``).

    Delegates to ``install_scan_27_backport`` from the sibling
    ``_scan_backport_27`` module and returns its result.
    """
    from . import _scan_backport_27 as backport_27

    return backport_27.install_scan_27_backport(scan_module)
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
@install_scan_backward_patch.register(from_version="2.8")
def _(scan_module: Any) -> bool:
    """Patch scan backward for the range registered with ``from_version="2.8"``.

    Tries the Autograd-kernel patch first. If that reports failure, it falls
    back to wrapping ``materialize_as_graph``. This is deliberately best effort:
    any exception raised while patching is swallowed and reported as ``False``.
    """
    try:
        result = _patch_current_scan_autograd(scan_module)
        if result:
            return result
        return _patch_older_scan_autograd(scan_module)
    except Exception:
        return False
|
|
228
|
+
|
|
229
|
+
|
|
230
|
+
@implement_for("torch")
def rollback_failed_scan_backward_patch(scan_module: Any) -> bool:
    """Best-effort rollback for reversible compatibility patches.

    Base implementation: there is nothing to undo, so it always reports
    ``False``. Version-specific variants are attached with
    ``rollback_failed_scan_backward_patch.register``.
    """

    rolled_back = False
    return rolled_back
|
|
235
|
+
|
|
236
|
+
|
|
237
|
+
@rollback_failed_scan_backward_patch.register(from_version="2.7", to_version="2.8")
def _(scan_module: Any) -> bool:
    """Undo the eager scan backport (registered for ``2.7`` to ``2.8``).

    Delegates to ``rollback_scan_27_backport`` from the sibling
    ``_scan_backport_27`` module and returns its result.
    """
    from . import _scan_backport_27 as backport_27

    return backport_27.rollback_scan_27_backport(scan_module)
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
@rollback_failed_scan_backward_patch.register(from_version="2.8")
def _(scan_module: Any) -> bool:
    """Restore the unwrapped ``materialize_as_graph`` if the wrapper is installed.

    Only the ``materialize_as_graph`` wrapper is reverted. NOTE(review): the
    Autograd-kernel registration and the ``ScanAutogradImpl.__init__`` wrap made
    by the newer patch path are not undone here.

    Returns:
        ``True`` if the original function was restored, otherwise ``False``.
    """
    wrapper = getattr(scan_module, "materialize_as_graph", None)
    saved = getattr(wrapper, "_hoptorch_scan_backward_original", None)
    if saved is not None:
        scan_module.materialize_as_graph = saved
        return True
    return False
|