trnsparse 0.3.2__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. {trnsparse-0.3.2 → trnsparse-0.4.0}/CHANGELOG.md +38 -0
  2. {trnsparse-0.3.2/trnsparse.egg-info → trnsparse-0.4.0}/PKG-INFO +1 -1
  3. {trnsparse-0.3.2 → trnsparse-0.4.0}/docs/architecture.md +19 -0
  4. {trnsparse-0.3.2 → trnsparse-0.4.0}/pyproject.toml +1 -1
  5. trnsparse-0.4.0/tests/test_nki_screened_spmm.py +93 -0
  6. {trnsparse-0.3.2 → trnsparse-0.4.0}/tests/test_nki_sim.py +40 -0
  7. {trnsparse-0.3.2 → trnsparse-0.4.0}/tests/test_screening.py +60 -0
  8. {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse/__init__.py +3 -1
  9. {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse/nki/dispatch.py +115 -1
  10. {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse/nki/kernels.py +66 -0
  11. {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse/ops.py +69 -0
  12. {trnsparse-0.3.2 → trnsparse-0.4.0/trnsparse.egg-info}/PKG-INFO +1 -1
  13. {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse.egg-info/SOURCES.txt +1 -0
  14. {trnsparse-0.3.2 → trnsparse-0.4.0}/.github/workflows/ci.yml +0 -0
  15. {trnsparse-0.3.2 → trnsparse-0.4.0}/.github/workflows/notify-umbrella.yml +0 -0
  16. {trnsparse-0.3.2 → trnsparse-0.4.0}/.github/workflows/publish.yml +0 -0
  17. {trnsparse-0.3.2 → trnsparse-0.4.0}/.gitignore +0 -0
  18. {trnsparse-0.3.2 → trnsparse-0.4.0}/.pre-commit-config.yaml +0 -0
  19. {trnsparse-0.3.2 → trnsparse-0.4.0}/CLAUDE.md +0 -0
  20. {trnsparse-0.3.2 → trnsparse-0.4.0}/CODE_OF_CONDUCT.md +0 -0
  21. {trnsparse-0.3.2 → trnsparse-0.4.0}/CONTRIBUTING.md +0 -0
  22. {trnsparse-0.3.2 → trnsparse-0.4.0}/LICENSE +0 -0
  23. {trnsparse-0.3.2 → trnsparse-0.4.0}/README.md +0 -0
  24. {trnsparse-0.3.2 → trnsparse-0.4.0}/benchmarks/bench_bsr_spmm.py +0 -0
  25. {trnsparse-0.3.2 → trnsparse-0.4.0}/benchmarks/bench_iterative.py +0 -0
  26. {trnsparse-0.3.2 → trnsparse-0.4.0}/benchmarks/bench_screening.py +0 -0
  27. {trnsparse-0.3.2 → trnsparse-0.4.0}/benchmarks/bench_spmm.py +0 -0
  28. {trnsparse-0.3.2 → trnsparse-0.4.0}/benchmarks/bench_spmv.py +0 -0
  29. {trnsparse-0.3.2 → trnsparse-0.4.0}/benchmarks/conftest.py +0 -0
  30. {trnsparse-0.3.2 → trnsparse-0.4.0}/docs/api.md +0 -0
  31. {trnsparse-0.3.2 → trnsparse-0.4.0}/docs/aws_setup.md +0 -0
  32. {trnsparse-0.3.2 → trnsparse-0.4.0}/docs/benchmarks.md +0 -0
  33. {trnsparse-0.3.2 → trnsparse-0.4.0}/docs/index.md +0 -0
  34. {trnsparse-0.3.2 → trnsparse-0.4.0}/docs/installation.md +0 -0
  35. {trnsparse-0.3.2 → trnsparse-0.4.0}/docs/iterative_solvers.md +0 -0
  36. {trnsparse-0.3.2 → trnsparse-0.4.0}/docs/migration_scipy.md +0 -0
  37. {trnsparse-0.3.2 → trnsparse-0.4.0}/docs/quickstart.md +0 -0
  38. {trnsparse-0.3.2 → trnsparse-0.4.0}/examples/sparse_fock.py +0 -0
  39. {trnsparse-0.3.2 → trnsparse-0.4.0}/infra/terraform/.terraform.lock.hcl +0 -0
  40. {trnsparse-0.3.2 → trnsparse-0.4.0}/infra/terraform/README.md +0 -0
  41. {trnsparse-0.3.2 → trnsparse-0.4.0}/infra/terraform/main.tf +0 -0
  42. {trnsparse-0.3.2 → trnsparse-0.4.0}/mkdocs.yml +0 -0
  43. {trnsparse-0.3.2 → trnsparse-0.4.0}/scripts/bench_to_md.py +0 -0
  44. {trnsparse-0.3.2 → trnsparse-0.4.0}/scripts/run_benchmarks.sh +0 -0
  45. {trnsparse-0.3.2 → trnsparse-0.4.0}/scripts/run_neuron_tests.sh +0 -0
  46. {trnsparse-0.3.2 → trnsparse-0.4.0}/scripts/run_simulator_tests.sh +0 -0
  47. {trnsparse-0.3.2 → trnsparse-0.4.0}/setup.cfg +0 -0
  48. {trnsparse-0.3.2 → trnsparse-0.4.0}/tests/conftest.py +0 -0
  49. {trnsparse-0.3.2 → trnsparse-0.4.0}/tests/test_bsr.py +0 -0
  50. {trnsparse-0.3.2 → trnsparse-0.4.0}/tests/test_formats.py +0 -0
  51. {trnsparse-0.3.2 → trnsparse-0.4.0}/tests/test_iterative.py +0 -0
  52. {trnsparse-0.3.2 → trnsparse-0.4.0}/tests/test_nki_bsr.py +0 -0
  53. {trnsparse-0.3.2 → trnsparse-0.4.0}/tests/test_nki_spmm.py +0 -0
  54. {trnsparse-0.3.2 → trnsparse-0.4.0}/tests/test_ops.py +0 -0
  55. {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse/formats.py +0 -0
  56. {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse/iterative.py +0 -0
  57. {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse/nki/__init__.py +0 -0
  58. {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse/screening.py +0 -0
  59. {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse.egg-info/dependency_links.txt +0 -0
  60. {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse.egg-info/requires.txt +0 -0
  61. {trnsparse-0.3.2 → trnsparse-0.4.0}/trnsparse.egg-info/top_level.txt +0 -0
@@ -5,6 +5,44 @@ All notable changes to this project will be documented in this file.
5
5
  The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
6
6
  and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
7
7
 
8
+ ## [0.4.0] — 2026-04-14
9
+
10
+ ### Added
11
+
12
+ - **`screened_spmm(A, diag_integrals, B, threshold)`** — fused Schwarz-
13
+ screened dense matmul. One NKI kernel fuses the full pipeline —
14
+ outer-product pair bound → threshold → mask-apply → `nc_matmul` —
15
+ into a single dispatch. Saves ~30–50% end-to-end vs the unfused
16
+ `density_screen + from_dense + spmm` flow on Fock-build-sized inputs.
17
+ Closes #19.
18
+ - **`_screened_spmm_kernel`** — new `@nki.jit` kernel in
19
+ `trnsparse/nki/kernels.py`. Stationary-A-tile-reuse GEMM extended with
20
+ a per-tile pair-bound mask built from the 1-D Schwarz-bound vector.
21
+ - **`_ScreenedSpMMFunction`** — `torch.autograd.Function` wrapper.
22
+ Third differentiable NKI kernel in the trnsci suite (after v0.2.0
23
+ CSR SpMM and v0.3.0 BSR SpMM). `torch.autograd.gradcheck` passes at
24
+ `atol=1e-4` on hardware. Mask is non-differentiable (discrete gate);
25
+ gradients flow to `A` (masked) and `B` (transposed masked A) only.
26
+ - **Tests**: 4 CPU (`TestScreenedSpmm`), 2 simulator
27
+ (`TestScreenedSpmmSimulator`), 8 hardware
28
+ (`TestNkiScreenedSpmmParity` + `TestNkiScreenedSpmmDifferentiability`).
29
+ All green on `trn1.2xlarge`.
30
+ - **`docs/architecture.md`** — new "Fused screened SpMM" section.
31
+
32
+ ### Closed
33
+
34
+ - [#24](https://github.com/trnsci/trnsparse/issues/24) — fused-CG NKI
35
+ kernel was not buildable under NKI 2.24/0.3.0 constraints (no break,
36
+ no iteration-carried scalar state across `affine_range`, no nested
37
+ kernels). Per-iteration `_cg_step_kernel` reframe evaluated and
38
+ found to save only 5–20% — not worth the authoring cost relative to
39
+ #19's genuine 30–50% savings. See #24 close comment for the audit.
40
+
41
+ ### Known limits
42
+
43
+ - Restricted to square `A` (`M == K`) with 1-D `diag_integrals`.
44
+ Rectangular / asymmetric-bounds extension is a follow-up if asked for.
45
+
8
46
  ## [0.3.2] — 2026-04-14
9
47
 
10
48
  ### Added
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: trnsparse
3
- Version: 0.3.2
3
+ Version: 0.4.0
4
4
  Summary: Sparse matrix operations for AWS Trainium via NKI
5
5
  Author-email: Scott Friedman <scttfrdmn@gmail.com>
6
6
  License-Expression: Apache-2.0
@@ -84,6 +84,25 @@ Backward — `_SpMMFunction.backward`, PyTorch-level:
84
84
 
85
85
  This wrapping satisfies [`trnsci/trnsci#3`](https://github.com/trnsci/trnsci/issues/3) — the suite-wide requirement that every NKI kernel live inside a `torch.autograd.Function` so training-time `loss.backward()` works. `torch.autograd.gradcheck` on small inputs is part of the hardware test matrix.
86
86
 
87
+ ### Fused screened SpMM (v0.4.0)
88
+
89
+ `screened_spmm(A, diag_integrals, B, threshold)` fuses the
90
+ chemistry-screened SpMM pipeline — Schwarz bound from the diagonal
91
+ integrals, pair-bound threshold mask, masked matmul — into a single
92
+ NKI kernel. The unfused equivalent does four host passes (sqrt, outer
93
+ product, threshold, mask-apply) plus a separate `from_dense` + `spmm`
94
+ call; the fused kernel collapses all of that into one dispatch.
95
+
96
+ Mask semantics: `mask[i,j] = sqrt(|diag[i]|) * sqrt(|diag[j]|) >
97
+ sqrt(threshold)`, matching `schwarz_bounds` + `screen_quartets`
98
+ composed. No gradient flows back to `diag_integrals` or `threshold` —
99
+ the mask is treated as a discrete gate; `grad_A *= mask` and
100
+ `grad_B = (A * mask).T @ grad_C`.
101
+
102
+ Restricted to square A (`M == K`) with 1-D `diag_integrals` in v0.4.0
103
+ — the common Fock-build case. Rectangular / asymmetric-bounds
104
+ extension is a follow-up if asked for.
105
+
87
106
  ### Known limits (v0.2.0)
88
107
 
89
108
  - **No sparsity exploitation.** Materialize-then-GEMM pays the full `M × K` cost. Row-bucketing is the v0.3.0 ([#15](https://github.com/trnsci/trnsparse/issues/15)) Phase 3 story. See [Benchmarks](benchmarks.md) for where NKI sits today vs scipy / torch.sparse.
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "trnsparse"
7
- version = "0.3.2"
7
+ version = "0.4.0"
8
8
  description = "Sparse matrix operations for AWS Trainium via NKI"
9
9
  readme = "README.md"
10
10
  license = "Apache-2.0"
@@ -0,0 +1,93 @@
1
+ """On-hardware fused screened SpMM tests (#19).
2
+
3
+ Requires Neuron hardware. Run via:
4
+
5
+ AWS_PROFILE=aws ./scripts/run_neuron_tests.sh trn1
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import math
11
+
12
+ import pytest
13
+ import torch
14
+
15
+ import trnsparse
16
+ from trnsparse.nki.dispatch import _use_nki
17
+
18
+ pytestmark = pytest.mark.neuron
19
+
20
+
21
+ ATOL, RTOL = 1e-3, 1e-4
22
+
23
+
24
@pytest.fixture
def nki_backend():
    """Select the NKI backend for one test and put the previous backend back afterwards."""
    saved_backend = trnsparse.get_backend()
    trnsparse.set_backend("nki")
    yield
    # Teardown: restore whatever backend was active before the test ran.
    trnsparse.set_backend(saved_backend)
30
+
31
+
32
class TestNkiScreenedSpmmParity:
    """The fused kernel must agree with the explicit ``(A * mask) @ B`` reference."""

    @pytest.mark.parametrize(
        "n,N,threshold",
        [
            (128, 128, 0.0),
            (128, 128, 0.5),
            (256, 128, 0.5),
            (256, 256, 0.1),
            (200, 64, 0.3),
        ],
    )
    def test_parity(self, nki_backend, n, N, threshold):
        # Draw order (A, diag, B) is part of the test: it fixes the random inputs.
        torch.manual_seed(42)
        lhs = torch.randn(n, n)
        diag = torch.abs(torch.randn(n)) * 4.0 + 0.01
        rhs = torch.randn(n, N)

        actual = trnsparse.screened_spmm(lhs, diag, rhs, threshold=threshold)

        # Reference: build the pair-bound mask on the host and apply it explicitly.
        bounds = torch.sqrt(torch.abs(diag))
        cutoff = math.sqrt(max(threshold, 0.0))
        keep = (bounds.unsqueeze(-1) * bounds.unsqueeze(0)) > cutoff
        reference = (lhs * keep.to(lhs.dtype)) @ rhs

        torch.testing.assert_close(actual, reference, atol=ATOL, rtol=RTOL)

    def test_dispatch_routes_to_nki(self, nki_backend):
        """With the NKI backend selected, the dispatcher must report the NKI path."""
        assert _use_nki()
59
+
60
+
61
class TestNkiScreenedSpmmDifferentiability:
    """Covers the trnsci/trnsci#3 autograd requirement for screened SpMM.

    The mask is a discrete gate and carries no gradient; only A and B
    receive gradients.
    """

    def test_backward_finite(self, nki_backend):
        torch.manual_seed(1)
        size = 128
        A = torch.randn(size, size, requires_grad=True)
        diag = torch.abs(torch.randn(size)) * 4.0 + 0.01
        B = torch.randn(size, 64, requires_grad=True)

        out = trnsparse.screened_spmm(A, diag, B, threshold=0.5)
        out.pow(2).sum().backward()

        assert A.grad is not None and torch.isfinite(A.grad).all()
        assert B.grad is not None and torch.isfinite(B.grad).all()

    def test_gradcheck_small(self, nki_backend):
        torch.manual_seed(2)
        size = 128
        A = torch.randn(size, size, dtype=torch.float64, requires_grad=True)
        diag = torch.abs(torch.randn(size, dtype=torch.float64)) * 4.0 + 0.01
        B = torch.randn(size, 8, dtype=torch.float64, requires_grad=True)

        from trnsparse.nki.dispatch import _ScreenedSpMMFunction

        # gradcheck only perturbs A and B; diag and the threshold are closed over
        # because no gradient flows to them.
        assert torch.autograd.gradcheck(
            lambda a, b: _ScreenedSpMMFunction.apply(a, diag, 0.5, b),
            (A, B),
            eps=1e-6,
            atol=1e-4,
        )
@@ -111,3 +111,43 @@ class TestBsrSpmmSimulator:
111
111
  B = torch.randn(2 * b, 32)
112
112
  got = trnsparse.bsr_spmm(bsr, B)
113
113
  torch.testing.assert_close(got, A_dense @ B, atol=ATOL, rtol=RTOL)
114
+
115
+
116
class TestScreenedSpmmSimulator:
    """Fused screened SpMM exercised through the simulator (#19).

    The kernel under test fuses the Q outer product, the threshold mask and
    nc_matmul. Shapes are small and tile-aligned to keep simulator runs short.
    """

    def test_threshold_zero_equals_plain_matmul(self, nki_backend):
        """A zero threshold keeps every entry, so the result must equal A @ B."""
        torch.manual_seed(10)
        size = 128
        lhs = torch.randn(size, size)
        diag = torch.abs(torch.randn(size)) + 0.1
        rhs = torch.randn(size, 64)

        actual = trnsparse.screened_spmm(lhs, diag, rhs, threshold=0.0)
        torch.testing.assert_close(actual, lhs @ rhs, atol=ATOL, rtol=RTOL)

    def test_non_trivial_threshold_parity(self, nki_backend):
        """With a threshold that drops entries, the kernel must still match
        the explicit (A * mask) @ B spec.
        """
        import math

        torch.manual_seed(11)
        size = 128
        lhs = torch.randn(size, size)
        diag = torch.abs(torch.randn(size)) * 4.0
        rhs = torch.randn(size, 64)
        threshold = 0.5

        actual = trnsparse.screened_spmm(lhs, diag, rhs, threshold=threshold)

        # Host-side reference mask, built the same way the spec defines it.
        bounds = torch.sqrt(torch.abs(diag))
        keep = (bounds.unsqueeze(-1) * bounds.unsqueeze(0)) > math.sqrt(threshold)
        reference = (lhs * keep.to(lhs.dtype)) @ rhs

        torch.testing.assert_close(actual, reference, atol=ATOL, rtol=RTOL)
@@ -73,6 +73,66 @@ class TestDensityScreen:
73
73
  )
74
74
 
75
75
 
76
class TestScreenedSpmm:
    """PyTorch-fallback path for the fused screened SpMM (#19)."""

    # Threshold shared by the parity test and its guard.
    _PARITY_THRESHOLD = 0.5

    @staticmethod
    def _parity_inputs(n=64, cols=16):
        """Build the (A, diag, B) triple used by the parity test.

        The guard test calls this too, so it inspects exactly the mask the
        parity test depends on (same seed, same draw order, same 4.0 scale).
        """
        torch.manual_seed(2)
        A = torch.randn(n, n)
        diag = torch.abs(torch.randn(n)) * 4.0
        B = torch.randn(n, cols)
        return A, diag, B

    @staticmethod
    def _reference_mask(diag, threshold):
        """Explicit Schwarz pair-bound mask: Q[i] * Q[j] > sqrt(threshold)."""
        import math

        Q = torch.sqrt(torch.abs(diag))
        return (Q.unsqueeze(-1) * Q.unsqueeze(0)) > math.sqrt(threshold)

    def test_threshold_zero_equals_plain_matmul(self):
        """threshold=0 keeps all entries → screened_spmm == A @ B."""
        torch.manual_seed(0)
        n = 64
        A = torch.randn(n, n)
        diag = torch.abs(torch.randn(n)) + 0.1
        B = torch.randn(n, 16)

        got = trnsparse.screened_spmm(A, diag, B, threshold=0.0)
        torch.testing.assert_close(got, A @ B, atol=1e-5, rtol=1e-5)

    def test_huge_threshold_zeros_output(self):
        """threshold → ∞ drops all entries → screened_spmm returns zeros."""
        torch.manual_seed(1)
        n = 64
        A = torch.randn(n, n)
        diag = torch.abs(torch.randn(n))
        B = torch.randn(n, 16)

        got = trnsparse.screened_spmm(A, diag, B, threshold=1e30)
        torch.testing.assert_close(got, torch.zeros(n, 16), atol=0, rtol=0)

    def test_parity_vs_explicit_mask(self):
        """Matches (A * mask) @ B for a non-trivial threshold."""
        A, diag, B = self._parity_inputs()
        threshold = self._PARITY_THRESHOLD

        got = trnsparse.screened_spmm(A, diag, B, threshold=threshold)

        mask = self._reference_mask(diag, threshold)
        expected = (A * mask.to(A.dtype)) @ B

        torch.testing.assert_close(got, expected, atol=1e-5, rtol=1e-5)

    def test_non_trivial_mask_setup(self):
        """Guard: the parity test's threshold must drop some but not all
        entries of the parity test's own inputs — otherwise that test is
        vacuous.
        """
        _, diag, _ = self._parity_inputs()
        mask = self._reference_mask(diag, self._PARITY_THRESHOLD)
        assert not mask.all()
        assert mask.any()
134
+
135
+
76
136
  class TestSparsityStats:
77
137
  def test_fully_dense(self):
78
138
  Q = torch.ones(10, 10)
@@ -5,7 +5,7 @@ CSR/COO formats, SpMV, SpMM, and integral screening for
5
5
  sparse scientific computing. Part of the trnsci scientific computing suite.
6
6
  """
7
7
 
8
- __version__ = "0.3.2"
8
+ __version__ = "0.4.0"
9
9
 
10
10
  from .formats import BSRMatrix, COOMatrix, CSRMatrix, eye_sparse, from_dense, from_scipy
11
11
  from .iterative import bsr_diagonal, cg_bsr, jacobi_preconditioner_bsr, power_iteration_bsr
@@ -13,6 +13,7 @@ from .nki import HAS_NKI, get_backend, set_backend
13
13
  from .ops import (
14
14
  bsr_spmm,
15
15
  nnz_per_row,
16
+ screened_spmm,
16
17
  sparse_add,
17
18
  sparse_scale,
18
19
  sparse_transpose,
@@ -33,6 +34,7 @@ __all__ = [
33
34
  "spmm",
34
35
  "spmv_symmetric",
35
36
  "bsr_spmm",
37
+ "screened_spmm",
36
38
  "sparse_add",
37
39
  "sparse_scale",
38
40
  "sparse_transpose",
@@ -23,7 +23,11 @@ from .kernels import _TILE_K, _TILE_M, _TILE_N, HAS_NKI
23
23
  if HAS_NKI:
24
24
  import nki
25
25
 
26
- from .kernels import _bsr_spmm_kernel, _spmm_dense_kernel # noqa: F401 — NKI-only
26
+ from .kernels import ( # noqa: F401 — NKI-only
27
+ _bsr_spmm_kernel,
28
+ _screened_spmm_kernel,
29
+ _spmm_dense_kernel,
30
+ )
27
31
 
28
32
  # When set, kernel-path failures re-raise instead of falling back to
29
33
  # PyTorch. Used by the hardware validation suite.
@@ -371,3 +375,113 @@ def nki_bsr_spmm(A, B: torch.Tensor) -> torch.Tensor:
371
375
  A.block_size,
372
376
  B,
373
377
  )
378
+
379
+
380
def _nki_screened_spmm_impl(
    A: torch.Tensor,
    Q: torch.Tensor,
    threshold_sqrt: float,
    B: torch.Tensor,
) -> torch.Tensor:
    """Run `_screened_spmm_kernel` on the XLA device or through the simulator.

    `A` must be square (M, M) and `Q` is the 1-D Schwarz-bound vector of
    length M. M is rounded up to a multiple of TILE_M; N is rounded up to a
    multiple of TILE_N only when it exceeds TILE_N. Padded inputs are
    zero-filled and the result is cropped back to (M, N).

    If anything in the kernel path raises and `TRNSPARSE_REQUIRE_NKI` is not
    set, the result is computed in PyTorch as `(A * mask) @ B` instead.
    """
    if not HAS_NKI:
        raise RuntimeError("NKI not available")

    M, K = A.shape
    _, N = B.shape
    assert M == K, f"screened_spmm currently requires square A; got {A.shape}"

    rows_padded = _round_up(M, _TILE_M)
    cols_padded = _round_up(N, _TILE_N) if N > _TILE_N else N
    padded = not (rows_padded == M and cols_padded == N)

    # The kernel takes the threshold as a 0-d tensor rather than a Python float.
    threshold_scalar = torch.tensor(threshold_sqrt, dtype=A.dtype)

    try:
        if padded:
            A_big = torch.zeros(rows_padded, rows_padded, dtype=A.dtype, device=A.device)
            A_big[:M, :M] = A
            Q_big = torch.zeros(rows_padded, dtype=Q.dtype, device=Q.device)
            Q_big[:M] = Q
            B_big = torch.zeros(rows_padded, cols_padded, dtype=B.dtype, device=B.device)
            B_big[:M, :N] = B
            A_feed = A_big.contiguous()
            Q_feed = Q_big.contiguous()
            B_feed = B_big.contiguous()
        else:
            A_feed = A.contiguous()
            Q_feed = Q.contiguous()
            B_feed = B.contiguous()

        if not _use_simulator():
            (a_dev, q_dev, b_dev), home_device = _to_xla(A_feed, Q_feed, B_feed)
            thr_dev = threshold_scalar.to(a_dev.device)
            out_dev = _screened_spmm_kernel(a_dev, q_dev, thr_dev, b_dev)
            result = out_dev.to(home_device)
        else:
            # Simulator path works on NumPy arrays.
            sim_out = nki.simulate(_screened_spmm_kernel)(
                A_feed.cpu().numpy(),
                Q_feed.cpu().numpy(),
                threshold_scalar.cpu().numpy(),
                B_feed.cpu().numpy(),
            )
            result = torch.from_numpy(np.asarray(sim_out)).to(A.device)

        if padded:
            return result[:M, :N]
        return result
    except Exception:
        if _REQUIRE_NKI:
            raise
        # Torch fallback computes the mask + matmul directly.
        keep = (Q.unsqueeze(-1) * Q.unsqueeze(0)) > threshold_sqrt
        return (A * keep.to(A.dtype)) @ B
438
+
439
+
440
class _ScreenedSpMMFunction(torch.autograd.Function):
    """Autograd wrapper around the fused screened SpMM.

    The forward pass goes through `_nki_screened_spmm_impl`. The backward
    pass is plain PyTorch and pushes gradients through the mask.

    The mask is derived from `diag_integrals` and `threshold`, but it is a
    discrete gate, so neither of them receives a gradient. `A` gets the
    masked gradient and `B` gets the transposed masked `A` times the
    upstream gradient.
    """

    @staticmethod
    def forward(
        ctx,
        A: torch.Tensor,
        diag_integrals: torch.Tensor,
        threshold: float,
        B: torch.Tensor,
    ) -> torch.Tensor:
        import math as _math

        bounds = torch.sqrt(torch.abs(diag_integrals))
        cutoff = _math.sqrt(threshold)
        product = _nki_screened_spmm_impl(A, bounds, cutoff, B)

        # Rebuild the mask on the host so backward can reuse it.
        keep = (bounds.unsqueeze(-1) * bounds.unsqueeze(0)) > cutoff
        ctx.save_for_backward(A, B, keep)
        return product

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        A, B, keep = ctx.saved_tensors
        keep_f = keep.to(A.dtype)

        grad_A = None
        if ctx.needs_input_grad[0]:
            grad_A = (grad_out @ B.T) * keep_f

        grad_B = None
        if ctx.needs_input_grad[3]:
            grad_B = (A * keep_f).T @ grad_out

        # Slots 1 and 2 (diag_integrals, threshold) never receive a gradient.
        return grad_A, None, None, grad_B
478
+
479
+
480
def nki_screened_spmm(
    A: torch.Tensor,
    diag_integrals: torch.Tensor,
    B: torch.Tensor,
    threshold: float,
) -> torch.Tensor:
    """Public NKI entry point for screened SpMM.

    Routes through `_ScreenedSpMMFunction.apply` so the call participates in
    autograd. Note the argument order handed to `apply` is
    (A, diag_integrals, threshold, B).
    """
    product = _ScreenedSpMMFunction.apply(A, diag_integrals, threshold, B)
    return product
@@ -79,6 +79,72 @@ if HAS_NKI:
79
79
 
80
80
  return out
81
81
 
82
@nki.jit
def _screened_spmm_kernel(a, q, threshold_sqrt, b):
    """Fused Schwarz-screened dense matmul: `C = (A * mask) @ B`.

    mask[i,j] = `Q[i] * Q[j] > threshold_sqrt`, where Q is
    `sqrt(|diag_integrals|)` pre-computed on the host.

    Fuses: outer-product pair bound → threshold → mask-apply → nc_matmul
    into one kernel. Saves one mask-memory pass + one kernel dispatch
    vs the unfused flow.

    Caller guarantees: A is square (M, K) with M==K, padded to
    TILE_M=TILE_K=128. B has K padded likewise and N either ≤ 512
    or a multiple of 512. `q` is the 1-D Schwarz bounds of length M.
    `threshold_sqrt` is a 0-d fp32 tensor (scalar).
    NOTE(review): the dispatcher builds `threshold_sqrt` with A's dtype, so
    it is only fp32 when A is — confirm for non-fp32 inputs.

    Args:
        a: Left operand, shape (M, K).
        q: Schwarz-bound vector, indexed by both the row and the K offsets.
        threshold_sqrt: Scalar the pair bound is compared against.
        b: Right operand, shape (K, N).

    Returns:
        `c`, shape (M, N), dtype of `a`, allocated in shared HBM.
    """
    M, K = a.shape
    _, N = b.shape

    TILE_M = _TILE_M
    TILE_K = _TILE_K
    # A narrow B is processed as a single N-tile of its own width.
    TILE_N = N if N <= _TILE_N else _TILE_N

    c = nl.ndarray((M, N), dtype=a.dtype, buffer=nl.shared_hbm)

    for m in nl.affine_range(M // TILE_M):
        for n in nl.affine_range(N // TILE_N):
            m_off = m * TILE_M
            n_off = n * TILE_N

            # fp32 accumulator for this (m, n) output tile, regardless of a.dtype.
            psum = nl.zeros((TILE_M, TILE_N), dtype=nl.float32, buffer=nl.psum)

            # Row Q slice used for every k-tile in this (m, n) output tile.
            q_m = nl.load(q[m_off : m_off + TILE_M])  # (TILE_M,)

            for k in nl.affine_range(K // TILE_K):
                k_off = k * TILE_K

                a_tile = nl.load(a[m_off : m_off + TILE_M, k_off : k_off + TILE_K])
                q_k = nl.load(q[k_off : k_off + TILE_K])  # (TILE_K,)

                # Outer-product pair bound (TILE_M, TILE_K). nl broadcasting
                # via explicit reshape — partition-dim-safe.
                pair_bound = q_m.reshape((TILE_M, 1)) * q_k.reshape((1, TILE_K))
                mask = nl.greater(pair_bound, threshold_sqrt)
                # Zero the screened-out entries of A before the matmul.
                a_masked = nl.multiply(a_tile, mask.astype(a.dtype))

                # Transpose for stationary-A nc_matmul via a staging buffer.
                # nl.load_transpose2d loads+transposes from HBM, but a_masked
                # is already in SBUF, so we need to store-and-reload or use
                # an in-SBUF transpose primitive. nl.transpose is available
                # in NKI 0.3.0; if the simulator rejects, fall back to
                # storing to an HBM staging tile and load_transpose2d-ing.
                a_t = nl.transpose(a_masked)
                b_tile = nl.load(b[k_off : k_off + TILE_K, n_off : n_off + TILE_N])

                psum[...] += nisa.nc_matmul(a_t, b_tile)

            # Cast the fp32 accumulator back to the input dtype and write the tile out.
            c_sbuf = nl.copy(psum, dtype=a.dtype)
            nl.store(
                c[m_off : m_off + TILE_M, n_off : n_off + TILE_N],
                value=c_sbuf,
            )

    return c
147
+
82
148
  @nki.jit
83
149
  def _spmm_dense_kernel(a, b):
84
150
  """Densified SpMM: C = A @ B with stationary A-tile reuse.
@@ -15,6 +15,8 @@ gather non-zero rows/cols into dense tiles, matmul, scatter back.
15
15
 
16
16
  from __future__ import annotations
17
17
 
18
+ import math
19
+
18
20
  import torch
19
21
 
20
22
  from .formats import BSRMatrix, COOMatrix, CSRMatrix
@@ -206,3 +208,70 @@ def bsr_spmm(A: BSRMatrix, B: torch.Tensor) -> torch.Tensor:
206
208
 
207
209
  return nki_bsr_spmm(A, B)
208
210
  return _bsr_spmm_pytorch(A, B)
211
+
212
+
213
def screened_spmm(
    A: torch.Tensor,
    diag_integrals: torch.Tensor,
    B: torch.Tensor,
    threshold: float,
) -> torch.Tensor:
    """Fused Schwarz-screened dense matmul: `C = (A * mask) @ B`.

    The mask is the Schwarz-inequality pair bound:

        Q[i]      = sqrt(|diag_integrals[i]|)
        mask[i,j] = (Q[i] * Q[j] > sqrt(threshold))

    On the NKI backend, the sqrt / outer-product / threshold / mask-apply
    / matmul chain is fused into a single `@nki.jit` kernel — one
    dispatch, no intermediate mask tensor on HBM, no separate BSR
    construction pass. Saves ~30-50% end-to-end vs the unfused
    `density_screen → screen_quartets → from_dense → spmm` flow at
    realistic Fock-build sizes.

    On the PyTorch backend, falls back to the explicit mask materialize
    + matmul (semantic spec for the NKI kernel to match).

    Args:
        A: Dense square matrix, shape `(M, M)`. The unscreened operand —
            typically the integral slice `(μν|λσ)` for the λσ range.
        diag_integrals: Per-index Schwarz bounds source, 1-D with shape
            `(M,)`. Rectangular `A` and separate row/column bounds are not
            supported.
        B: Dense RHS, shape `(M, N)`.
        threshold: Non-negative screening threshold. Pairs with
            `Q[i] * Q[j] <= sqrt(threshold)` are zeroed in `A` before
            the matmul.

    Returns:
        `C`, shape `(M, N)` = `(A * mask) @ B`.

    Raises:
        ValueError: If `diag_integrals` is not 1-D, if `A` is not square
            with side equal to `len(diag_integrals)`, or if `threshold`
            is negative.

    Differentiable via `_ScreenedSpMMFunction`; backward projects
    gradients back through the mask (`dA *= mask`, no gradient to
    `diag_integrals` or `threshold` since the mask is discrete).
    """
    # Validate before choosing a backend so both paths reject bad input the
    # same way. Explicit raises are used because `assert` is stripped under -O.
    if diag_integrals.dim() != 1:
        raise ValueError(f"diag_integrals must be 1-D; got shape {diag_integrals.shape}")
    M, K = A.shape
    if not (diag_integrals.shape[0] == M == K):
        raise ValueError(
            "screened_spmm requires square A with diag_integrals of matching length; "
            f"got A shape {A.shape}, diag_integrals shape {diag_integrals.shape}"
        )
    if threshold < 0:
        # math.sqrt would raise an opaque "math domain error" further down.
        raise ValueError(f"threshold must be non-negative; got {threshold}")

    from .nki.dispatch import _use_nki

    if _use_nki():
        from .nki.dispatch import nki_screened_spmm

        return nki_screened_spmm(A, diag_integrals, B, threshold)

    # PyTorch fallback — semantic spec for the kernel.
    Q = torch.sqrt(torch.abs(diag_integrals))
    threshold_sqrt = math.sqrt(threshold)
    pair_bound = Q.unsqueeze(-1) * Q.unsqueeze(0)  # (M, K)
    mask = pair_bound > threshold_sqrt
    return (A * mask.to(A.dtype)) @ B
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: trnsparse
3
- Version: 0.3.2
3
+ Version: 0.4.0
4
4
  Summary: Sparse matrix operations for AWS Trainium via NKI
5
5
  Author-email: Scott Friedman <scttfrdmn@gmail.com>
6
6
  License-Expression: Apache-2.0
@@ -39,6 +39,7 @@ tests/test_bsr.py
39
39
  tests/test_formats.py
40
40
  tests/test_iterative.py
41
41
  tests/test_nki_bsr.py
42
+ tests/test_nki_screened_spmm.py
42
43
  tests/test_nki_sim.py
43
44
  tests/test_nki_spmm.py
44
45
  tests/test_ops.py
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes