trnsparse 0.3.2__tar.gz → 0.4.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {trnsparse-0.3.2 → trnsparse-0.4.0}/CHANGELOG.md +38 -0
- {trnsparse-0.3.2/trnsparse.egg-info → trnsparse-0.4.0}/PKG-INFO +1 -1
- {trnsparse-0.3.2 → trnsparse-0.4.0}/docs/architecture.md +19 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/pyproject.toml +1 -1
- trnsparse-0.4.0/tests/test_nki_screened_spmm.py +93 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/tests/test_nki_sim.py +40 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/tests/test_screening.py +60 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse/__init__.py +3 -1
- {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse/nki/dispatch.py +115 -1
- {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse/nki/kernels.py +66 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse/ops.py +69 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0/trnsparse.egg-info}/PKG-INFO +1 -1
- {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse.egg-info/SOURCES.txt +1 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/.github/workflows/ci.yml +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/.github/workflows/notify-umbrella.yml +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/.github/workflows/publish.yml +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/.gitignore +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/.pre-commit-config.yaml +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/CLAUDE.md +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/CODE_OF_CONDUCT.md +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/CONTRIBUTING.md +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/LICENSE +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/README.md +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/benchmarks/bench_bsr_spmm.py +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/benchmarks/bench_iterative.py +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/benchmarks/bench_screening.py +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/benchmarks/bench_spmm.py +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/benchmarks/bench_spmv.py +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/benchmarks/conftest.py +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/docs/api.md +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/docs/aws_setup.md +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/docs/benchmarks.md +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/docs/index.md +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/docs/installation.md +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/docs/iterative_solvers.md +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/docs/migration_scipy.md +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/docs/quickstart.md +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/examples/sparse_fock.py +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/infra/terraform/.terraform.lock.hcl +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/infra/terraform/README.md +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/infra/terraform/main.tf +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/mkdocs.yml +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/scripts/bench_to_md.py +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/scripts/run_benchmarks.sh +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/scripts/run_neuron_tests.sh +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/scripts/run_simulator_tests.sh +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/setup.cfg +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/tests/conftest.py +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/tests/test_bsr.py +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/tests/test_formats.py +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/tests/test_iterative.py +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/tests/test_nki_bsr.py +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/tests/test_nki_spmm.py +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/tests/test_ops.py +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse/formats.py +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse/iterative.py +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse/nki/__init__.py +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse/screening.py +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse.egg-info/dependency_links.txt +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse.egg-info/requires.txt +0 -0
- {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse.egg-info/top_level.txt +0 -0
|
@@ -5,6 +5,44 @@ All notable changes to this project will be documented in this file.
|
|
|
5
5
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
|
6
6
|
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
|
7
7
|
|
|
8
|
+
## [0.4.0] — 2026-04-14
|
|
9
|
+
|
|
10
|
+
### Added
|
|
11
|
+
|
|
12
|
+
- **`screened_spmm(A, diag_integrals, B, threshold)`** — fused Schwarz-
|
|
13
|
+
screened dense matmul. One NKI kernel fuses the full pipeline —
|
|
14
|
+
outer-product pair bound → threshold → mask-apply → `nc_matmul` —
|
|
15
|
+
into a single dispatch. Saves ~30–50% end-to-end vs the unfused
|
|
16
|
+
`density_screen + from_dense + spmm` flow on Fock-build-sized inputs.
|
|
17
|
+
Closes #19.
|
|
18
|
+
- **`_screened_spmm_kernel`** — new `@nki.jit` kernel in
|
|
19
|
+
`trnsparse/nki/kernels.py`. Stationary-A-tile-reuse GEMM extended with
|
|
20
|
+
a per-tile pair-bound mask built from the 1-D Schwarz-bound vector.
|
|
21
|
+
- **`_ScreenedSpMMFunction`** — `torch.autograd.Function` wrapper.
|
|
22
|
+
Third differentiable NKI kernel in the trnsci suite (after v0.2.0
|
|
23
|
+
CSR SpMM and v0.3.0 BSR SpMM). `torch.autograd.gradcheck` passes at
|
|
24
|
+
`atol=1e-4` on hardware. Mask is non-differentiable (discrete gate);
|
|
25
|
+
gradients flow to `A` (masked) and `B` (transposed masked A) only.
|
|
26
|
+
- **Tests**: 4 CPU (`TestScreenedSpmm`), 2 simulator
|
|
27
|
+
(`TestScreenedSpmmSimulator`), 8 hardware
|
|
28
|
+
(`TestNkiScreenedSpmmParity` + `TestNkiScreenedSpmmDifferentiability`).
|
|
29
|
+
All green on `trn1.2xlarge`.
|
|
30
|
+
- **`docs/architecture.md`** — new "Fused screened SpMM" section.
|
|
31
|
+
|
|
32
|
+
### Closed
|
|
33
|
+
|
|
34
|
+
- [#24](https://github.com/trnsci/trnsparse/issues/24) — fused-CG NKI
|
|
35
|
+
kernel was not buildable under NKI 2.24/0.3.0 constraints (no break,
|
|
36
|
+
no iteration-carried scalar state across `affine_range`, no nested
|
|
37
|
+
kernels). Per-iteration `_cg_step_kernel` reframe evaluated and
|
|
38
|
+
found to save only 5–20% — not worth the authoring cost relative to
|
|
39
|
+
#19's genuine 30–50% savings. See #24 close comment for the audit.
|
|
40
|
+
|
|
41
|
+
### Known limits
|
|
42
|
+
|
|
43
|
+
- Restricted to square `A` (`M == K`) with 1-D `diag_integrals`.
|
|
44
|
+
Rectangular / asymmetric-bounds extension is a follow-up if asked for.
|
|
45
|
+
|
|
8
46
|
## [0.3.2] — 2026-04-14
|
|
9
47
|
|
|
10
48
|
### Added
|
|
@@ -84,6 +84,25 @@ Backward — `_SpMMFunction.backward`, PyTorch-level:
|
|
|
84
84
|
|
|
85
85
|
This wrapping satisfies [`trnsci/trnsci#3`](https://github.com/trnsci/trnsci/issues/3) — the suite-wide requirement that every NKI kernel live inside a `torch.autograd.Function` so training-time `loss.backward()` works. `torch.autograd.gradcheck` on small inputs is part of the hardware test matrix.
|
|
86
86
|
|
|
87
|
+
### Fused screened SpMM (v0.4.0)
|
|
88
|
+
|
|
89
|
+
`screened_spmm(A, diag_integrals, B, threshold)` fuses the
|
|
90
|
+
chemistry-screened SpMM pipeline — Schwarz bound from the diagonal
|
|
91
|
+
integrals, pair-bound threshold mask, masked matmul — into a single
|
|
92
|
+
NKI kernel. The unfused equivalent does four host passes (sqrt, outer
|
|
93
|
+
product, threshold, mask-apply) plus a separate `from_dense` + `spmm`
|
|
94
|
+
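call; the fused kernel collapses the outer product, threshold, mask-apply, and matmul into one dispatch (the `sqrt` of the diagonal is still computed on the host).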
|
|
95
|
+
|
|
96
|
+
Mask semantics: `mask[i,j] = sqrt(|diag[i]|) * sqrt(|diag[j]|) >
|
|
97
|
+
sqrt(threshold)`, matching `schwarz_bounds` + `screen_quartets`
|
|
98
|
+
composed. No gradient flows back to `diag_integrals` or `threshold` —
|
|
99
|
+
the mask is treated as a discrete gate; `grad_A *= mask` and
|
|
100
|
+
`grad_B = (A * mask).T @ grad_C`.
|
|
101
|
+
|
|
102
|
+
Restricted to square A (`M == K`) with 1-D `diag_integrals` in v0.4.0
|
|
103
|
+
— the common Fock-build case. Rectangular / asymmetric-bounds
|
|
104
|
+
extension is a follow-up if asked for.
|
|
105
|
+
|
|
87
106
|
### Known limits (v0.2.0)
|
|
88
107
|
|
|
89
108
|
- **No sparsity exploitation.** Materialize-then-GEMM pays the full `M × K` cost. Row-bucketing is the v0.3.0 ([#15](https://github.com/trnsci/trnsparse/issues/15)) Phase 3 story. See [Benchmarks](benchmarks.md) for where NKI sits today vs scipy / torch.sparse.
|
|
@@ -0,0 +1,93 @@
|
|
|
1
|
+
"""On-hardware fused screened SpMM tests (#19).
|
|
2
|
+
|
|
3
|
+
Requires Neuron hardware. Run via:
|
|
4
|
+
|
|
5
|
+
AWS_PROFILE=aws ./scripts/run_neuron_tests.sh trn1
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import math
|
|
11
|
+
|
|
12
|
+
import pytest
|
|
13
|
+
import torch
|
|
14
|
+
|
|
15
|
+
import trnsparse
|
|
16
|
+
from trnsparse.nki.dispatch import _use_nki
|
|
17
|
+
|
|
18
|
+
# Every test in this module needs Neuron hardware (see the module docstring);
# the `neuron` marker lets the runner select/deselect them as a group.
pytestmark = pytest.mark.neuron


# Absolute / relative tolerances for `torch.testing.assert_close` when comparing
# the NKI result against the explicit PyTorch reference.
ATOL, RTOL = 1e-3, 1e-4
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@pytest.fixture
def nki_backend():
    """Run one test with trnsparse forced onto the "nki" backend.

    Whatever backend was active beforehand is remembered and reinstated
    during fixture teardown.
    """
    saved_backend = trnsparse.get_backend()
    trnsparse.set_backend("nki")
    yield
    trnsparse.set_backend(saved_backend)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class TestNkiScreenedSpmmParity:
    """Hardware parity: the fused kernel must agree with the explicit
    `(A * mask) @ B` reference for tile-aligned and non-aligned shapes."""

    @pytest.mark.parametrize(
        "n,N,threshold",
        [
            (128, 128, 0.0),
            (128, 128, 0.5),
            (256, 128, 0.5),
            (256, 256, 0.1),
            (200, 64, 0.3),  # n is not a multiple of 128 -> exercises padding
        ],
    )
    def test_parity(self, nki_backend, n, N, threshold):
        torch.manual_seed(42)
        A = torch.randn(n, n)
        diag = torch.abs(torch.randn(n)) * 4.0 + 0.01
        B = torch.randn(n, N)

        got = trnsparse.screened_spmm(A, diag, B, threshold=threshold)

        # Reference spec: keep (i, j) iff sqrt|d_i| * sqrt|d_j| > sqrt(threshold).
        bound = torch.sqrt(torch.abs(diag))
        cutoff = math.sqrt(max(threshold, 0.0))
        keep = torch.outer(bound, bound) > cutoff
        reference = (A * keep.to(A.dtype)) @ B

        torch.testing.assert_close(got, reference, atol=ATOL, rtol=RTOL)

    def test_dispatch_routes_to_nki(self, nki_backend):
        # With the backend forced to "nki", dispatch must choose the NKI path.
        assert _use_nki()
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
class TestNkiScreenedSpmmDifferentiability:
    """Autograd coverage for screened SpMM (the trnsci/trnsci#3 requirement).

    The mask acts as a discrete gate and receives no gradient; only A and B do.
    """

    def test_backward_finite(self, nki_backend):
        torch.manual_seed(1)
        size = 128
        A = torch.randn(size, size, requires_grad=True)
        diag = torch.abs(torch.randn(size)) * 4.0 + 0.01
        B = torch.randn(size, 64, requires_grad=True)

        out = trnsparse.screened_spmm(A, diag, B, threshold=0.5)
        (out**2).sum().backward()

        for leaf in (A, B):
            assert leaf.grad is not None
            assert torch.isfinite(leaf.grad).all()

    def test_gradcheck_small(self, nki_backend):
        from trnsparse.nki.dispatch import _ScreenedSpMMFunction

        torch.manual_seed(2)
        size = 128
        f64 = torch.float64
        A = torch.randn(size, size, dtype=f64, requires_grad=True)
        diag = torch.abs(torch.randn(size, dtype=f64)) * 4.0 + 0.01
        B = torch.randn(size, 8, dtype=f64, requires_grad=True)

        # diag and the threshold are closed over, so gradcheck perturbs only A and B.
        # NOTE(review): if the kernel rejects fp64 inputs, the dispatcher falls
        # back to PyTorch unless TRNSPARSE_REQUIRE_NKI is set — confirm this
        # case actually exercises the NKI kernel.
        def screened(a, b):
            return _ScreenedSpMMFunction.apply(a, diag, 0.5, b)

        assert torch.autograd.gradcheck(screened, (A, B), eps=1e-6, atol=1e-4)
|
|
@@ -111,3 +111,43 @@ class TestBsrSpmmSimulator:
|
|
|
111
111
|
B = torch.randn(2 * b, 32)
|
|
112
112
|
got = trnsparse.bsr_spmm(bsr, B)
|
|
113
113
|
torch.testing.assert_close(got, A_dense @ B, atol=ATOL, rtol=RTOL)
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
class TestScreenedSpmmSimulator:
    """Simulator run of the fused screened SpMM kernel (#19).

    Exercises the Q outer-product + threshold mask + nc_matmul fusion on
    small, tile-aligned inputs so each simulation finishes in seconds.
    """

    def test_threshold_zero_equals_plain_matmul(self, nki_backend):
        """threshold=0 with strictly positive bounds keeps every entry, so the
        result must equal a plain A @ B."""
        torch.manual_seed(10)
        size = 128
        A = torch.randn(size, size)
        diag = torch.abs(torch.randn(size)) + 0.1
        B = torch.randn(size, 64)

        torch.testing.assert_close(
            trnsparse.screened_spmm(A, diag, B, threshold=0.0),
            A @ B,
            atol=ATOL,
            rtol=RTOL,
        )

    def test_non_trivial_threshold_parity(self, nki_backend):
        """A threshold that drops some entries must still match the explicit
        `(A * mask) @ B` reference.
        """
        import math

        torch.manual_seed(11)
        size = 128
        A = torch.randn(size, size)
        diag = torch.abs(torch.randn(size)) * 4.0
        B = torch.randn(size, 64)
        threshold = 0.5

        got = trnsparse.screened_spmm(A, diag, B, threshold=threshold)

        bound = torch.sqrt(torch.abs(diag))
        keep = torch.outer(bound, bound) > math.sqrt(threshold)
        reference = (A * keep.to(A.dtype)) @ B

        torch.testing.assert_close(got, reference, atol=ATOL, rtol=RTOL)
|
|
@@ -73,6 +73,66 @@ class TestDensityScreen:
|
|
|
73
73
|
)
|
|
74
74
|
|
|
75
75
|
|
|
76
|
+
class TestScreenedSpmm:
    """PyTorch-fallback path for the fused screened SpMM (#19)."""

    # Seed and threshold for the non-trivial-threshold parity test. The guard
    # test reuses the same inputs so it checks the distribution the parity
    # test actually runs on; a separately seeded guard would guard nothing.
    _PARITY_SEED = 2
    _PARITY_THRESHOLD = 0.5

    @classmethod
    def _parity_inputs(cls):
        """Return `(A, diag, B, threshold)` for the non-trivial-threshold cases."""
        torch.manual_seed(cls._PARITY_SEED)
        n = 64
        A = torch.randn(n, n)
        diag = torch.abs(torch.randn(n)) * 4.0
        B = torch.randn(n, 16)
        return A, diag, B, cls._PARITY_THRESHOLD

    @staticmethod
    def _reference_mask(diag, threshold):
        """Explicit spec: keep (i, j) iff sqrt|d_i| * sqrt|d_j| > sqrt(threshold)."""
        import math

        Q = torch.sqrt(torch.abs(diag))
        return (Q.unsqueeze(-1) * Q.unsqueeze(0)) > math.sqrt(threshold)

    def test_threshold_zero_equals_plain_matmul(self):
        """threshold=0 keeps all entries → screened_spmm == A @ B."""
        torch.manual_seed(0)
        n = 64
        A = torch.randn(n, n)
        diag = torch.abs(torch.randn(n)) + 0.1
        B = torch.randn(n, 16)

        got = trnsparse.screened_spmm(A, diag, B, threshold=0.0)
        torch.testing.assert_close(got, A @ B, atol=1e-5, rtol=1e-5)

    def test_huge_threshold_zeros_output(self):
        """threshold → ∞ drops all entries → screened_spmm returns zeros."""
        torch.manual_seed(1)
        n = 64
        A = torch.randn(n, n)
        diag = torch.abs(torch.randn(n))
        B = torch.randn(n, 16)

        got = trnsparse.screened_spmm(A, diag, B, threshold=1e30)
        torch.testing.assert_close(got, torch.zeros(n, 16), atol=0, rtol=0)

    def test_parity_vs_explicit_mask(self):
        """Matches (A * mask) @ B for a non-trivial threshold."""
        A, diag, B, threshold = self._parity_inputs()

        got = trnsparse.screened_spmm(A, diag, B, threshold=threshold)

        mask = self._reference_mask(diag, threshold)
        expected = (A * mask.to(A.dtype)) @ B

        torch.testing.assert_close(got, expected, atol=1e-5, rtol=1e-5)

    def test_non_trivial_mask_setup(self):
        """Guard: on the parity test's own inputs, the threshold must drop some
        but not all entries — otherwise the parity test is vacuous.
        """
        _, diag, _, threshold = self._parity_inputs()
        mask = self._reference_mask(diag, threshold)
        assert not mask.all()
        assert mask.any()
|
|
134
|
+
|
|
135
|
+
|
|
76
136
|
class TestSparsityStats:
|
|
77
137
|
def test_fully_dense(self):
|
|
78
138
|
Q = torch.ones(10, 10)
|
|
@@ -5,7 +5,7 @@ CSR/COO formats, SpMV, SpMM, and integral screening for
|
|
|
5
5
|
sparse scientific computing. Part of the trnsci scientific computing suite.
|
|
6
6
|
"""
|
|
7
7
|
|
|
8
|
-
__version__ = "0.
|
|
8
|
+
# Release version string; should match the version declared in pyproject.toml.
__version__ = "0.4.0"
|
|
9
9
|
|
|
10
10
|
from .formats import BSRMatrix, COOMatrix, CSRMatrix, eye_sparse, from_dense, from_scipy
|
|
11
11
|
from .iterative import bsr_diagonal, cg_bsr, jacobi_preconditioner_bsr, power_iteration_bsr
|
|
@@ -13,6 +13,7 @@ from .nki import HAS_NKI, get_backend, set_backend
|
|
|
13
13
|
from .ops import (
|
|
14
14
|
bsr_spmm,
|
|
15
15
|
nnz_per_row,
|
|
16
|
+
screened_spmm,
|
|
16
17
|
sparse_add,
|
|
17
18
|
sparse_scale,
|
|
18
19
|
sparse_transpose,
|
|
@@ -33,6 +34,7 @@ __all__ = [
|
|
|
33
34
|
"spmm",
|
|
34
35
|
"spmv_symmetric",
|
|
35
36
|
"bsr_spmm",
|
|
37
|
+
"screened_spmm",
|
|
36
38
|
"sparse_add",
|
|
37
39
|
"sparse_scale",
|
|
38
40
|
"sparse_transpose",
|
|
@@ -23,7 +23,11 @@ from .kernels import _TILE_K, _TILE_M, _TILE_N, HAS_NKI
|
|
|
23
23
|
if HAS_NKI:
|
|
24
24
|
import nki
|
|
25
25
|
|
|
26
|
-
from .kernels import
|
|
26
|
+
from .kernels import ( # noqa: F401 — NKI-only
|
|
27
|
+
_bsr_spmm_kernel,
|
|
28
|
+
_screened_spmm_kernel,
|
|
29
|
+
_spmm_dense_kernel,
|
|
30
|
+
)
|
|
27
31
|
|
|
28
32
|
# When set, kernel-path failures re-raise instead of falling back to
|
|
29
33
|
# PyTorch. Used by the hardware validation suite.
|
|
@@ -371,3 +375,113 @@ def nki_bsr_spmm(A, B: torch.Tensor) -> torch.Tensor:
|
|
|
371
375
|
A.block_size,
|
|
372
376
|
B,
|
|
373
377
|
)
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
def _nki_screened_spmm_impl(
    A: torch.Tensor,
    Q: torch.Tensor,
    threshold_sqrt: float,
    B: torch.Tensor,
) -> torch.Tensor:
    """Dispatch `_screened_spmm_kernel` on the XLA device (or simulator).

    Args:
        A: Square dense operand, shape (M, M).
        Q: 1-D Schwarz-bound vector of length M (already `sqrt(|diag|)`).
        threshold_sqrt: Pair-bound cutoff; entries with
            `Q[i] * Q[j] <= threshold_sqrt` are dropped from A.
        B: Dense RHS, shape (M, N).

    Returns:
        `(A * mask) @ B`, shape (M, N).

    M is padded up to a multiple of `_TILE_M`. N is padded up to a multiple
    of `_TILE_N` only when `N > _TILE_N`; a narrower N is handled as a single
    tile by the kernel. Padded entries of A, Q and B are zero, so they
    contribute nothing, and the padding is sliced off the result.

    If the kernel path raises and `TRNSPARSE_REQUIRE_NKI` is not set, falls
    back to the equivalent PyTorch computation (`torch.matmul` on masked A).

    Raises:
        RuntimeError: NKI is not available.
        ValueError: A is not square, or Q / B do not match A's size.
    """
    if not HAS_NKI:
        raise RuntimeError("NKI not available")
    M, K = A.shape
    _, N = B.shape
    if M != K:
        raise ValueError(f"screened_spmm currently requires square A; got {A.shape}")
    if Q.dim() != 1 or Q.shape[0] != M:
        raise ValueError(f"Q must be 1-D with length {M} to match A; got shape {tuple(Q.shape)}")
    if B.shape[0] != K:
        # Without padding the kernel would read B's first K rows and silently
        # ignore or overrun the rest, so reject the mismatch up front.
        raise ValueError(f"B must have {K} rows to match A; got shape {tuple(B.shape)}")
    M_pad = _round_up(M, _TILE_M)
    N_pad = N if N <= _TILE_N else _round_up(N, _TILE_N)
    needs_pad = (M_pad != M) or (N_pad != N)

    # Build the cutoff in Q's dtype. The kernel compares it against products
    # of Q entries, and the autograd wrapper's host-side mask compares a Python
    # float against Q ⊗ Q, i.e. also in Q's dtype. Using A's dtype could round
    # the cutoff differently (e.g. bf16 A with fp32 Q), so the kernel's mask
    # would disagree with the mask used in backward.
    threshold_sqrt_t = torch.tensor(threshold_sqrt, dtype=Q.dtype)

    try:
        if needs_pad:
            A_p = torch.zeros(M_pad, M_pad, dtype=A.dtype, device=A.device)
            A_p[:M, :M] = A
            Q_p = torch.zeros(M_pad, dtype=Q.dtype, device=Q.device)
            Q_p[:M] = Q
            B_p = torch.zeros(M_pad, N_pad, dtype=B.dtype, device=B.device)
            B_p[:M, :N] = B
            A_feed, Q_feed, B_feed = A_p.contiguous(), Q_p.contiguous(), B_p.contiguous()
        else:
            A_feed, Q_feed, B_feed = A.contiguous(), Q.contiguous(), B.contiguous()

        if _use_simulator():
            out_np = nki.simulate(_screened_spmm_kernel)(
                A_feed.cpu().numpy(),
                Q_feed.cpu().numpy(),
                threshold_sqrt_t.cpu().numpy(),
                B_feed.cpu().numpy(),
            )
            result = torch.from_numpy(np.asarray(out_np)).to(A.device)
        else:
            (a, q, b), orig_device = _to_xla(A_feed, Q_feed, B_feed)
            ts = threshold_sqrt_t.to(a.device)
            c = _screened_spmm_kernel(a, q, ts, b)
            result = c.to(orig_device)

        return result[:M, :N] if needs_pad else result
    except Exception:
        # Deliberate best-effort: any kernel / compile / transfer failure falls
        # back to PyTorch unless TRNSPARSE_REQUIRE_NKI requests a hard failure.
        if _REQUIRE_NKI:
            raise
        # Torch fallback computes the mask + matmul directly.
        pair_bound = Q.unsqueeze(-1) * Q.unsqueeze(0)
        mask = pair_bound > threshold_sqrt
        return (A * mask.to(A.dtype)) @ B
|
|
438
|
+
|
|
439
|
+
|
|
440
|
+
class _ScreenedSpMMFunction(torch.autograd.Function):
    """Autograd wrapper for fused screened SpMM.

    Forward: NKI-dispatched (or PyTorch fallback). Backward: PyTorch-level,
    projecting gradients through the mask.

    The mask depends on `diag_integrals` and `threshold` but is discrete —
    no gradient flows back to them. Gradients flow to `A` (masked) and
    `B` (transposed masked A):

        grad_A = (grad_C @ B.T) * mask
        grad_B = (A * mask).T @ grad_C
    """

    @staticmethod
    def forward(
        ctx,
        A: torch.Tensor,
        diag_integrals: torch.Tensor,
        threshold: float,
        B: torch.Tensor,
    ) -> torch.Tensor:
        """Compute `(A * mask) @ B`; save `A`, `B` and the mask for backward."""
        import math as _math

        Q = torch.sqrt(torch.abs(diag_integrals))
        # Clamp negative thresholds to 0: `math.sqrt` raises ValueError on
        # negatives. With a 0 cutoff the mask keeps every pair whose bound is > 0.
        threshold_sqrt = _math.sqrt(max(threshold, 0.0))
        C = _nki_screened_spmm_impl(A, Q, threshold_sqrt, B)

        # Rebuild the mask on the host (same comparison the kernel performs)
        # so backward can project gradients through it.
        mask = (Q.unsqueeze(-1) * Q.unsqueeze(0)) > threshold_sqrt
        ctx.save_for_backward(A, B, mask)
        return C

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        """Return gradients for the inputs `(A, diag_integrals, threshold, B)`."""
        A, B, mask = ctx.saved_tensors
        m_f = mask.to(A.dtype)
        grad_A = (grad_out @ B.T) * m_f if ctx.needs_input_grad[0] else None
        grad_B = (A * m_f).T @ grad_out if ctx.needs_input_grad[3] else None
        # No gradient to diag_integrals (arg 1) or threshold (arg 2).
        return grad_A, None, None, grad_B
|
|
478
|
+
|
|
479
|
+
|
|
480
|
+
def nki_screened_spmm(
    A: torch.Tensor,
    diag_integrals: torch.Tensor,
    B: torch.Tensor,
    threshold: float,
) -> torch.Tensor:
    """Autograd-aware screened SpMM entry point for the NKI backend.

    The public argument order is `(A, diag_integrals, B, threshold)`, while
    `_ScreenedSpMMFunction.apply` expects `(A, diag_integrals, threshold, B)`.
    The last two are therefore swapped when forwarding.

    Returns:
        `(A * mask) @ B`, differentiable with respect to `A` and `B`.
    """
    forwarded = (A, diag_integrals, threshold, B)
    return _ScreenedSpMMFunction.apply(*forwarded)
|
|
@@ -79,6 +79,72 @@ if HAS_NKI:
|
|
|
79
79
|
|
|
80
80
|
return out
|
|
81
81
|
|
|
82
|
+
@nki.jit
def _screened_spmm_kernel(a, q, threshold_sqrt, b):
    """Fused Schwarz-screened dense matmul: `C = (A * mask) @ B`.

    mask[i,j] = `Q[i] * Q[j] > threshold_sqrt`, where Q is
    `sqrt(|diag_integrals|)` pre-computed on the host.

    Fuses: outer-product pair bound → threshold → mask-apply → nc_matmul
    into one kernel. Saves one mask-memory pass + one kernel dispatch
    vs the unfused flow.

    Caller guarantees: A is square (M, K) with M==K, padded to
    TILE_M=TILE_K=128. B has K padded likewise and N either ≤ 512
    or a multiple of 512. `q` is the 1-D Schwarz bounds of length M.
    `threshold_sqrt` is a 0-d fp32 tensor (scalar).
    NOTE(review): confirm the host dispatcher actually passes this scalar
    in the dtype the comparison below expects.

    Structure: for each (m, n) output tile, masked A tiles are multiplied
    against B tiles over all k-tiles, accumulated in an fp32 PSUM buffer,
    then cast to `a.dtype` and stored.

    Returns:
        `c`, an (M, N) tensor allocated in `nl.shared_hbm` with dtype `a.dtype`.
    """
    M, K = a.shape
    _, N = b.shape

    TILE_M = _TILE_M
    TILE_K = _TILE_K
    # One narrower N tile when N fits; otherwise full-width N tiles (the
    # caller guarantees N is then a multiple of _TILE_N).
    TILE_N = N if N <= _TILE_N else _TILE_N

    c = nl.ndarray((M, N), dtype=a.dtype, buffer=nl.shared_hbm)

    for m in nl.affine_range(M // TILE_M):
        for n in nl.affine_range(N // TILE_N):
            m_off = m * TILE_M
            n_off = n * TILE_N

            # fp32 accumulator for this (m, n) output tile.
            psum = nl.zeros((TILE_M, TILE_N), dtype=nl.float32, buffer=nl.psum)

            # Row Q slice used for every k-tile in this (m, n) output tile.
            # (It does not depend on n, but is reloaded once per n-tile.)
            q_m = nl.load(q[m_off : m_off + TILE_M])  # (TILE_M,)

            for k in nl.affine_range(K // TILE_K):
                k_off = k * TILE_K

                a_tile = nl.load(a[m_off : m_off + TILE_M, k_off : k_off + TILE_K])
                q_k = nl.load(q[k_off : k_off + TILE_K])  # (TILE_K,)

                # Outer-product pair bound (TILE_M, TILE_K). nl broadcasting
                # via explicit reshape — partition-dim-safe.
                pair_bound = q_m.reshape((TILE_M, 1)) * q_k.reshape((1, TILE_K))
                # NOTE(review): `threshold_sqrt` is a kernel argument used here
                # without an `nl.load`; confirm NKI accepts it directly as an
                # `nl.greater` operand, otherwise load it once before the loops.
                mask = nl.greater(pair_bound, threshold_sqrt)
                # Zero the screened-out entries of this A tile.
                a_masked = nl.multiply(a_tile, mask.astype(a.dtype))

                # Transpose for stationary-A nc_matmul via a staging buffer.
                # nl.load_transpose2d loads+transposes from HBM, but a_masked
                # is already in SBUF, so we need to store-and-reload or use
                # an in-SBUF transpose primitive. nl.transpose is available
                # in NKI 0.3.0; if the simulator rejects, fall back to
                # storing to an HBM staging tile and load_transpose2d-ing.
                a_t = nl.transpose(a_masked)
                b_tile = nl.load(b[k_off : k_off + TILE_K, n_off : n_off + TILE_N])

                # Accumulate this k-tile's contribution into the fp32 buffer.
                psum[...] += nisa.nc_matmul(a_t, b_tile)

            # Downcast the accumulator to the output dtype and write it back.
            c_sbuf = nl.copy(psum, dtype=a.dtype)
            nl.store(
                c[m_off : m_off + TILE_M, n_off : n_off + TILE_N],
                value=c_sbuf,
            )

    return c
|
|
147
|
+
|
|
82
148
|
@nki.jit
|
|
83
149
|
def _spmm_dense_kernel(a, b):
|
|
84
150
|
"""Densified SpMM: C = A @ B with stationary A-tile reuse.
|
|
@@ -15,6 +15,8 @@ gather non-zero rows/cols into dense tiles, matmul, scatter back.
|
|
|
15
15
|
|
|
16
16
|
from __future__ import annotations
|
|
17
17
|
|
|
18
|
+
import math
|
|
19
|
+
|
|
18
20
|
import torch
|
|
19
21
|
|
|
20
22
|
from .formats import BSRMatrix, COOMatrix, CSRMatrix
|
|
@@ -206,3 +208,70 @@ def bsr_spmm(A: BSRMatrix, B: torch.Tensor) -> torch.Tensor:
|
|
|
206
208
|
|
|
207
209
|
return nki_bsr_spmm(A, B)
|
|
208
210
|
return _bsr_spmm_pytorch(A, B)
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def screened_spmm(
    A: torch.Tensor,
    diag_integrals: torch.Tensor,
    B: torch.Tensor,
    threshold: float,
) -> torch.Tensor:
    """Fused Schwarz-screened dense matmul: `C = (A * mask) @ B`.

    The mask is the Schwarz-inequality pair bound:

        Q[i] = sqrt(|diag_integrals[i]|)
        mask[i,j] = (Q[i] * Q[j] > sqrt(threshold))

    A negative `threshold` is clamped to 0, so every pair with a non-zero
    bound is kept; `math.sqrt` would otherwise raise.

    On the NKI backend, `Q` is computed on the host. The outer-product /
    threshold / mask-apply / matmul chain then runs as a single `@nki.jit`
    kernel dispatch, with no separate BSR construction pass. The host still
    builds the `(M, M)` mask once so backward can project gradients through
    it. Saves ~30-50% end-to-end vs the unfused
    `density_screen → screen_quartets → from_dense → spmm` flow at
    realistic Fock-build sizes.

    On the PyTorch backend, falls back to the explicit mask materialize
    + matmul (semantic spec for the NKI kernel to match).

    Args:
        A: Square dense matrix, shape `(M, M)`. The unscreened operand —
            typically the integral slice `(μν|λσ)` for the λσ range.
        diag_integrals: 1-D per-index Schwarz-bound source, shape `(M,)`.
            Rectangular `A` / separate row and column bounds are not
            supported.
        B: Dense RHS with `M` rows, e.g. shape `(M, N)`.
        threshold: Screening threshold. Pairs with
            `Q[i] * Q[j] <= sqrt(threshold)` are zeroed in `A` before
            the matmul.

    Returns:
        `C` = `(A * mask) @ B`, shape `(M, N)` for a 2-D `B`.

    Raises:
        ValueError: If `diag_integrals` is not 1-D, or `A` is not square
            with side length `len(diag_integrals)`. Checked on both backends.

    Differentiable via `_ScreenedSpMMFunction`; backward projects
    gradients back through the mask (`dA *= mask`, no gradient to
    `diag_integrals` or `threshold` since the mask is discrete).
    """
    # Validate before dispatch so both backends reject bad shapes. Explicit
    # raises (rather than `assert`) still fire under `python -O`.
    if diag_integrals.dim() != 1:
        raise ValueError(f"diag_integrals must be 1-D; got shape {diag_integrals.shape}")
    M, K = A.shape
    if not (diag_integrals.shape[0] == M == K):
        raise ValueError(
            "screened_spmm requires square A with diag_integrals of matching length; "
            f"got A shape {A.shape}, diag_integrals shape {diag_integrals.shape}"
        )

    # Keep math.sqrt in its domain: clamp negative cutoffs to 0, matching the
    # `max(threshold, 0.0)` reference used by the hardware parity tests.
    threshold = max(threshold, 0.0)

    from .nki.dispatch import _use_nki

    if _use_nki():
        from .nki.dispatch import nki_screened_spmm

        return nki_screened_spmm(A, diag_integrals, B, threshold)

    # PyTorch fallback — semantic spec for the kernel.
    Q = torch.sqrt(torch.abs(diag_integrals))
    pair_bound = Q.unsqueeze(-1) * Q.unsqueeze(0)  # (M, M)
    mask = pair_bound > math.sqrt(threshold)
    return (A * mask.to(A.dtype)) @ B
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|