gabp-sparse-inv 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. gabp_sparse_inv/__init__.py +146 -0
  2. gabp_sparse_inv/_linalg.py +41 -0
  3. gabp_sparse_inv/autodiff.py +226 -0
  4. gabp_sparse_inv/bench/__init__.py +1 -0
  5. gabp_sparse_inv/bench/confirmatory.py +432 -0
  6. gabp_sparse_inv/bench/deq_breakeven.py +485 -0
  7. gabp_sparse_inv/bench/deq_cross_eval.py +508 -0
  8. gabp_sparse_inv/bench/deq_gradient_isolation.py +355 -0
  9. gabp_sparse_inv/bench/deq_gradient_isolation_analysis.py +345 -0
  10. gabp_sparse_inv/bench/gmrf_scaling.py +195 -0
  11. gabp_sparse_inv/bench/harness.py +501 -0
  12. gabp_sparse_inv/bench/matched_compute.py +234 -0
  13. gabp_sparse_inv/bench/maze_cross_eval.py +165 -0
  14. gabp_sparse_inv/bench/maze_extrapolation.py +339 -0
  15. gabp_sparse_inv/bench/maze_symmetric_swap_control.py +468 -0
  16. gabp_sparse_inv/bench/metrics.py +492 -0
  17. gabp_sparse_inv/bench/nonsym_stability.py +254 -0
  18. gabp_sparse_inv/bench/operator_cross_eval_analysis.py +452 -0
  19. gabp_sparse_inv/bench/phase1_analysis.py +340 -0
  20. gabp_sparse_inv/bench/precision.py +517 -0
  21. gabp_sparse_inv/bench/run.py +207 -0
  22. gabp_sparse_inv/bench/seeds.py +127 -0
  23. gabp_sparse_inv/bench/stability.py +197 -0
  24. gabp_sparse_inv/chain.py +140 -0
  25. gabp_sparse_inv/demos/__init__.py +1 -0
  26. gabp_sparse_inv/demos/deltanet_chunk.py +258 -0
  27. gabp_sparse_inv/demos/deq_fixedpoint.py +354 -0
  28. gabp_sparse_inv/demos/maze_baselines.py +380 -0
  29. gabp_sparse_inv/demos/maze_grid.py +335 -0
  30. gabp_sparse_inv/demos/maze_tree.py +256 -0
  31. gabp_sparse_inv/generators.py +573 -0
  32. gabp_sparse_inv/gmrf.py +369 -0
  33. gabp_sparse_inv/gmrf_grid.py +180 -0
  34. gabp_sparse_inv/junction.py +752 -0
  35. gabp_sparse_inv/junction_autodiff.py +373 -0
  36. gabp_sparse_inv/layout.py +832 -0
  37. gabp_sparse_inv/nonsym.py +354 -0
  38. gabp_sparse_inv/nonsym_junction.py +397 -0
  39. gabp_sparse_inv/sampling.py +134 -0
  40. gabp_sparse_inv/star.py +140 -0
  41. gabp_sparse_inv/tree.py +248 -0
  42. gabp_sparse_inv-0.3.0.dist-info/METADATA +545 -0
  43. gabp_sparse_inv-0.3.0.dist-info/RECORD +46 -0
  44. gabp_sparse_inv-0.3.0.dist-info/WHEEL +5 -0
  45. gabp_sparse_inv-0.3.0.dist-info/licenses/LICENSE +21 -0
  46. gabp_sparse_inv-0.3.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,146 @@
1
+ """Selected inversion for sparse block-structured matrices.
2
+
3
+ Returns the blocks of ``A^{-1}`` on a chosen sparsity pattern in ``O(n)`` time / ``O(fill)``
4
+ memory, without ever forming the dense inverse. Kernels (all validated against dense fp64
5
+ oracles; see ``docs/PROJECT_STATUS.md`` for the authoritative scope):
6
+
7
+ SPD
8
+ ``selected_inverse_chain`` (block-tridiagonal), ``selected_inverse_star`` (arrowhead),
9
+ ``selected_inverse_tree`` (general tree), ``selinv_tree`` (the differentiable tree kernel
10
+ with an analytic backward), and ``selected_inverse_junction`` / ``selinv_junction`` (general
11
+ sparse SPD via the filled / chordal pattern, handling loopy graphs *exactly*).
12
+ Non-symmetric
13
+ ``selected_inverse_bidiag`` (block lower-bidiagonal, DeltaNet chunk),
14
+ ``selected_inverse_nonsym_tree`` (tree pattern, ``M_{uv} != M_{vu}``), and
15
+ ``selected_inverse_nonsym_junction`` / ``selinv_nonsym_junction`` (general sparse
16
+ non-symmetric on the filled pattern, via block LDU + the two-sided Takahashi /
17
+ Erisman-Tinney recurrence; no pivoting -- the static-pattern regime), with
18
+ ``nonsym_junction_solve`` the sibling linear solve ``A^{-1} b`` / ``A^{-T} b`` (the
19
+ fixed-point / DEQ implicit-differentiation primitive).
20
+ Applications
21
+ differentiable tree- and grid/loopy-GMRF learning (:mod:`gabp_sparse_inv.gmrf`,
22
+ :mod:`gabp_sparse_inv.gmrf_grid`) and Gaussian sampling (:mod:`gabp_sparse_inv.sampling`).
23
+
24
+ Out of scope unless noted otherwise here: iterative / loopy GaBP, pivoting for the
25
+ non-symmetric kernels (the static-pattern regime is assumed), and indefinite or
26
+ complex-Hermitian matrices.
27
+ """
28
+
29
+ from __future__ import annotations
30
+
31
+ from .chain import ChainFactors, selected_inverse_chain
32
+ from .generators import (
33
+ condition_number,
34
+ grid_edges,
35
+ random_nonsym_bidiag,
36
+ random_spd_chain,
37
+ random_spd_graph,
38
+ random_spd_laplacian,
39
+ random_spd_star,
40
+ random_spd_tree,
41
+ )
42
+ from .autodiff import SelInvTree, selinv_tree
43
+ from .junction import (
44
+ elimination_order_min_degree,
45
+ elimination_order_nested_dissection,
46
+ junction_logdet,
47
+ junction_solve,
48
+ selected_inverse_junction,
49
+ selinv_junction,
50
+ )
51
+ from .nonsym import (
52
+ SelInvBidiag,
53
+ SelInvTril,
54
+ selected_inverse_bidiag,
55
+ selected_inverse_nonsym_tree,
56
+ selected_inverse_tril,
57
+ selinv_bidiag,
58
+ selinv_tril,
59
+ )
60
+ from .nonsym_junction import (
61
+ nonsym_junction_solve,
62
+ selected_inverse_nonsym_junction,
63
+ selinv_nonsym_junction,
64
+ )
65
+ from .junction_autodiff import (
66
+ selinv_junction_analytic,
67
+ selinv_nonsym_junction_analytic,
68
+ )
69
+ from .gmrf import (
70
+ fit_marginal_likelihood,
71
+ marginal_log_likelihood,
72
+ posterior_marginal_variances,
73
+ sample_tree_gmrf,
74
+ tree_gmrf_precision,
75
+ tree_logdet,
76
+ tree_solve,
77
+ )
78
+ from .gmrf_grid import (
79
+ fit_grid_marginal_likelihood,
80
+ grid_gmrf_precision,
81
+ grid_node_degrees,
82
+ junction_marginal_log_likelihood,
83
+ junction_posterior_marginal_variances,
84
+ )
85
+ from .layout import BlockBidiag, BlockSparseSym, BlockStar, BlockTree, BlockTridiag
86
+ from .sampling import sample_gaussian_junction, sample_gaussian_tree
87
+ from .star import StarFactors, selected_inverse_star
88
+ from .tree import TreeFactors, selected_inverse_tree
89
+
90
+ __all__ = [
91
+ "BlockTridiag",
92
+ "BlockBidiag",
93
+ "BlockStar",
94
+ "BlockTree",
95
+ "BlockSparseSym",
96
+ "ChainFactors",
97
+ "StarFactors",
98
+ "TreeFactors",
99
+ "selected_inverse_chain",
100
+ "selected_inverse_star",
101
+ "selected_inverse_tree",
102
+ "selinv_tree",
103
+ "SelInvTree",
104
+ "selected_inverse_junction",
105
+ "selinv_junction",
106
+ "junction_solve",
107
+ "junction_logdet",
108
+ "elimination_order_min_degree",
109
+ "elimination_order_nested_dissection",
110
+ "selected_inverse_bidiag",
111
+ "selinv_bidiag",
112
+ "SelInvBidiag",
113
+ "selected_inverse_tril",
114
+ "selinv_tril",
115
+ "SelInvTril",
116
+ "selected_inverse_nonsym_tree",
117
+ "selected_inverse_nonsym_junction",
118
+ "selinv_nonsym_junction",
119
+ "nonsym_junction_solve",
120
+ "selinv_junction_analytic",
121
+ "selinv_nonsym_junction_analytic",
122
+ "tree_gmrf_precision",
123
+ "tree_logdet",
124
+ "tree_solve",
125
+ "marginal_log_likelihood",
126
+ "posterior_marginal_variances",
127
+ "sample_tree_gmrf",
128
+ "sample_gaussian_tree",
129
+ "sample_gaussian_junction",
130
+ "fit_marginal_likelihood",
131
+ "grid_gmrf_precision",
132
+ "grid_node_degrees",
133
+ "junction_marginal_log_likelihood",
134
+ "junction_posterior_marginal_variances",
135
+ "fit_grid_marginal_likelihood",
136
+ "random_spd_chain",
137
+ "random_spd_star",
138
+ "random_spd_tree",
139
+ "random_nonsym_bidiag",
140
+ "random_spd_graph",
141
+ "random_spd_laplacian",
142
+ "grid_edges",
143
+ "condition_number",
144
+ ]
145
+
146
+ __version__ = "0.3.0"
@@ -0,0 +1,41 @@
1
+ """Shared batched SPD linear-algebra helpers.
2
+
3
+ These back every selected-inverse kernel (chain, star, ...): a batched Cholesky
4
+ with an informative SPD-failure message and an explicit SPD-block inverse from a
5
+ Cholesky factor. Keeping them in one place means there is a single, tested
6
+ batched-Cholesky path regardless of the matrix structure being inverted.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import torch
12
+ from torch import Tensor
13
+
14
+ __all__ = ["cholesky_spd", "inv_via_chol"]
15
+
16
+
17
def cholesky_spd(block: Tensor, *, name: str) -> Tensor:
    """Factor a (batched) block with ``cholesky_ex`` and fail loudly on breakdown.

    ``torch.linalg.cholesky_ex`` does not raise by itself; it reports success per
    batch element through an integer ``info`` tensor. This helper inspects that
    tensor and converts any non-zero entry into an exception.

    Parameters
    ----------
    block:
        Matrix (or batch of matrices) handed to ``cholesky_ex``.
    name:
        Label of the pivot block, interpolated into the error message (e.g.
        ``"D_3"`` or ``"S (center Schur complement)"``).

    Returns
    -------
    Tensor
        The Cholesky factor returned by ``cholesky_ex``.

    Raises
    ------
    torch.linalg.LinAlgError
        If ``info`` is non-zero for any batch element; the message lists the
        offending batch indices.
    """
    factor, status = torch.linalg.cholesky_ex(block)
    failed = status != 0
    # Happy path first: every batch element factorized cleanly.
    if not torch.any(failed):
        return factor
    bad = torch.nonzero(failed, as_tuple=False)
    raise torch.linalg.LinAlgError(
        f"pivot block {name} lost positive-definiteness during factorization "
        f"(cholesky_ex info != 0 at batch indices {bad.tolist()}); matrix is "
        f"not SPD or is too ill-conditioned"
    )
35
+
36
+
37
def inv_via_chol(chol: Tensor) -> Tensor:
    """Return the explicit inverse of an SPD block given its Cholesky factor.

    An identity matrix matching ``chol``'s dtype and device is broadcast to the
    shape of ``chol`` and used as the right-hand side of
    ``torch.cholesky_solve``, so the solution is the inverse of the matrix that
    ``chol`` factors.
    """
    block_size = chol.shape[-1]
    identity = torch.eye(
        block_size, dtype=chol.dtype, device=chol.device
    ).expand_as(chol)
    return torch.cholesky_solve(identity, chol)
@@ -0,0 +1,226 @@
1
+ """Differentiable selected inverse on SPD block trees.
2
+
3
+ :class:`SelInvTree` wraps the forward kernel
4
+ (:func:`gabp_sparse_inv.tree.selected_inverse_tree`) with a **hand-written analytic
5
+ backward**: the reverse two-pass collect/distribute of ``docs/derivations.md`` §8.3,
6
+ reusing the forward factors ``(chol_D, ell)`` and the selected blocks ``G_diag``.
7
+
8
+ The backward is the *self-adjoint* schedule proved in §8.2 -- a reverse-distribute
9
+ sweep (leaves->root) followed by a reverse-collect sweep (root->leaves) on the same
10
+ elimination tree. It costs ``O((|V|+|E|) b^3)`` time and ``O((|V|+|E|) b^2)`` memory,
11
+ identical to the forward pass, and never tapes the per-node loop through autograd
12
+ (the shipped forward loop is in fact *not* autograd-traceable -- it writes into
13
+ preallocated buffers -- which is part of why this analytic backward exists).
14
+
15
+ Both passes have a per-node ``loop`` form (the correctness reference) and a level-set
16
+ ``batched`` form (one batched solve + one ``index_add_`` per height antichain, the
17
+ same batching the forward uses); the two agree block-for-block.
18
+
19
+ Convention (``docs/derivations.md`` §8.1): the differentiable inputs are the stored
20
+ blocks ``diag`` (node diagonals ``A_vv``, symmetric) and ``edge`` (``edge[v] =
21
+ A_{p(v),v}``). The returned cotangents ``barAd[v] = d f / d diag[v]`` and
22
+ ``barAe[v] = d f / d edge[v]`` are gradients in the autograd sense w.r.t. those input
23
+ tensors -- exactly what :func:`torch.autograd.gradcheck` verifies.
24
+
25
+ **First-order only for this module.** ``backward`` runs under no-grad (the standard
26
+ custom-Function contract). For Hessian-vector products use the functional junction kernels
27
+ or ``selected_inverse_tree(batched=True)``; both pass ``gradgradcheck``
28
+ (``tests/test_double_backward.py``).
29
+ """
30
+
31
+ from __future__ import annotations
32
+
33
+ from typing import Sequence
34
+
35
+ import torch
36
+ from torch import Tensor
37
+
38
+ from ._linalg import inv_via_chol
39
+ from .layout import _as_parent_tensor, tree_levels, tree_orders
40
+ from .tree import _level_index_tensors, selected_inverse_tree
41
+
42
+ __all__ = ["selinv_tree", "SelInvTree"]
43
+
44
+
45
+ def _sym(x: Tensor) -> Tensor:
46
+ """Symmetrize the trailing ``b x b`` blocks: ``(X + X^T)/2``."""
47
+ return 0.5 * (x + x.mT)
48
+
49
+
50
+ def _backward_loop(ctx, barGd, barGe):  # returns (barAd, bU); NOTE(review): nesting is not visible in this rendering -- confirm against the file
51
+ """Per-node reverse two-pass (eqs. 13-14); the correctness reference."""
52
+ edge, G_diag, ell, chol_D = ctx.saved_tensors  # same order as save_for_backward in SelInvTree.forward
53
+ root, children, collect_order, plist = ctx.root, ctx.children, ctx.collect_order, ctx.plist  # static topology stashed on ctx by forward
54
+ Dinv = inv_via_chol(chol_D)  # explicit inverse blocks rebuilt from the saved Cholesky factors
55
+
56
+ zeros = torch.zeros_like(G_diag)  # template for every accumulator below
57
+ bG = barGd.clone() if barGd is not None else zeros.clone()  # cloned because entries are overwritten in pass 1
58
+ if barGe is None:  # a missing edge cotangent is treated as all-zero
59
+ barGe = zeros
60
+ bDinv = zeros.clone()  # accumulator, written in both passes
61
+ bell = zeros.clone()  # accumulator paired with ell, written in both passes
62
+ bU = zeros.clone()  # accumulator paired with edge; returned as the edge cotangent
63
+
64
+ # Pass 1: reverse-distribute (leaves -> root), eq. (13).
65
+ for v in collect_order:
66
+ M = _sym(bG[..., v, :, :])  # only the symmetric part of the running cotangent is used
67
+ bDinv[..., v, :, :] = bDinv[..., v, :, :] + M
68
+ if v != root:  # the root has no parent edge to propagate through
69
+ p = plist[v]  # plist is parent_t.tolist(), so p is v's parent
70
+ Gpp = G_diag[..., p, :, :]
71
+ ell_v = ell[..., v, :, :]
72
+ bge = barGe[..., v, :, :]
73
+ bell[..., v, :, :] = bell[..., v, :, :] + Gpp @ ell_v @ (M + M.mT) - Gpp @ bge  # M comes from _sym, so (M + M.mT) is 2*M
74
+ bG[..., p, :, :] = bG[..., p, :, :] + ell_v @ M @ ell_v.mT - bge @ ell_v.mT  # push v's contribution into the parent's slot
75
+
76
+ # Pass 2: reverse-collect (root -> leaves), eq. (14).
77
+ barAd = zeros.clone()  # cotangent w.r.t. diag, filled one node at a time
78
+ for v in reversed(collect_order):  # same order as pass 1, walked backwards
79
+ Dinv_v = Dinv[..., v, :, :]
80
+ if v != root:
81
+ ell_v = ell[..., v, :, :]
82
+ U_v = edge[..., v, :, :]
83
+ bell_v = bell[..., v, :, :]  # indexing view into bell, not a copy
84
+ bU[..., v, :, :] = bU[..., v, :, :] + bell_v @ Dinv_v
85
+ bDinv[..., v, :, :] = bDinv[..., v, :, :] + U_v.mT @ bell_v
86
+ bP = _sym(-(Dinv_v @ bDinv[..., v, :, :] @ Dinv_v))  # symmetrized -Dinv @ bDinv @ Dinv for node v
87
+ barAd[..., v, :, :] = bP
88
+ for c in children[v]:  # update each child's accumulators with this node's bP
89
+ bP_v = bP  # plain alias of bP; no copy is made
90
+ bell[..., c, :, :] = bell[..., c, :, :] - bP_v @ edge[..., c, :, :]
91
+ bU[..., c, :, :] = bU[..., c, :, :] - bP_v @ ell[..., c, :, :]
92
+
93
+ return barAd, bU  # SelInvTree.backward returns these as the gradients for (diag, edge)
94
+
95
+
96
+ def _backward_batched(ctx, barGd, barGe):  # returns (barAd, bU); NOTE(review): nesting is not visible in this rendering -- confirm against the file
97
+ """Level-set reverse two-pass: vectorizes (13)-(14) over height antichains.
98
+
99
+ Same arithmetic as :func:`_backward_loop`, scheduled by :func:`tree_levels` so each
100
+ height level is one batched block op + one ``index_add_`` -- the backward mirror of
101
+ the batched forward (the schedule is self-adjoint).
102
+ """
103
+ edge, G_diag, ell, chol_D = ctx.saved_tensors  # same order as save_for_backward in SelInvTree.forward
104
+ root, plist, levels = ctx.root, ctx.plist, ctx.levels  # ctx.levels is only populated when forward ran with batched=True
105
+ Dinv = inv_via_chol(chol_D)  # explicit inverse blocks rebuilt from the saved Cholesky factors
106
+ dev = G_diag.device
107
+ lvl = _level_index_tensors(levels, root, plist, dev)  # iterated below as (node_idx, nonroot_idx, parent_idx) triples
108
+
109
+ zeros = torch.zeros_like(G_diag)  # template for every accumulator below
110
+ bG = barGd.clone() if barGd is not None else zeros.clone()
111
+ barGe = barGe if barGe is not None else zeros  # a missing edge cotangent is treated as all-zero
112
+ bDinv = zeros.clone()
113
+ bell = zeros.clone()
114
+ bU = zeros.clone()  # returned as the edge cotangent
115
+ bP = zeros.clone()  # per-node bP is stored so later levels can read their parents' values
116
+ barAd = zeros.clone()  # returned as the diag cotangent
117
+
118
+ # Pass 1: reverse-distribute, increasing height.
119
+ for node_idx, nonroot_idx, parent_idx in lvl:
120
+ bGl = bG.index_select(-3, node_idx)  # dim -3 is the same axis _backward_loop indexes with v
121
+ Ml = _sym(bGl)
122
+ bDinv.index_copy_(-3, node_idx, bDinv.index_select(-3, node_idx) + Ml)  # in-place write-back for this level's nodes
123
+ if nonroot_idx.numel() > 0:  # nothing to propagate when the level has no non-root node
124
+ M = _sym(bG.index_select(-3, nonroot_idx))
125
+ ell_v = ell.index_select(-3, nonroot_idx)
126
+ Gpp = G_diag.index_select(-3, parent_idx)
127
+ bge = barGe.index_select(-3, nonroot_idx)
128
+ add_bell = Gpp @ ell_v @ (M + M.mT) - Gpp @ bge  # M comes from _sym, so (M + M.mT) is 2*M
129
+ bell.index_copy_(-3, nonroot_idx, bell.index_select(-3, nonroot_idx) + add_bell)
130
+ bG = bG.index_add(-3, parent_idx, ell_v @ M @ ell_v.mT - bge @ ell_v.mT)  # out-of-place index_add; sums entries when parent_idx repeats
131
+
132
+ # Pass 2: reverse-collect, decreasing height. Root (top level, alone) first.
133
+ ridx = torch.tensor([root], dtype=torch.long, device=dev)  # one-element index so the root can use index_select/index_copy_
134
+ Dinv_r = Dinv.index_select(-3, ridx)
135
+ bP_r = _sym(-(Dinv_r @ bDinv.index_select(-3, ridx) @ Dinv_r))
136
+ bP.index_copy_(-3, ridx, bP_r)
137
+ barAd.index_copy_(-3, ridx, bP_r)
138
+ for node_idx, nonroot_idx, parent_idx in reversed(lvl):
139
+ if nonroot_idx.numel() == 0:  # level holds no non-root node (the root was handled above)
140
+ continue
141
+ ell_v = ell.index_select(-3, nonroot_idx)
142
+ U_v = edge.index_select(-3, nonroot_idx)
143
+ Dinv_v = Dinv.index_select(-3, nonroot_idx)
144
+ bP_par = bP.index_select(-3, parent_idx)  # parents' bP, written on an earlier iteration or by the root step
145
+ bell_v = bell.index_select(-3, nonroot_idx) - bP_par @ U_v # (14d)
146
+ bU_v = bU.index_select(-3, nonroot_idx) - bP_par @ ell_v # (14d)
147
+ bU_v = bU_v + bell_v @ Dinv_v # (14a)
148
+ bDinv_v = bDinv.index_select(-3, nonroot_idx) + U_v.mT @ bell_v # (14a)
149
+ bP_v = _sym(-(Dinv_v @ bDinv_v @ Dinv_v)) # (14b)
150
+ bell.index_copy_(-3, nonroot_idx, bell_v)
151
+ bU.index_copy_(-3, nonroot_idx, bU_v)
152
+ bP.index_copy_(-3, nonroot_idx, bP_v)  # stored for the next (lower) level's bP_par lookup
153
+ barAd.index_copy_(-3, nonroot_idx, bP_v)
154
+
155
+ return barAd, bU  # SelInvTree.backward returns these as the gradients for (diag, edge)
156
+
157
+
158
class SelInvTree(torch.autograd.Function):
    """Autograd Function for the SPD block-tree selected inverse, analytic backward.

    Prefer :func:`selinv_tree` over calling ``apply`` directly: it resolves the
    tree topology once and forwards it here.
    """

    @staticmethod
    def forward(ctx, diag, edge, parent_t, root, children, collect_order, batched):
        """Run the forward kernel and stash what the analytic backward needs.

        Returns ``(G_diag, G_edge)`` as produced by ``selected_inverse_tree``.
        """
        # The level schedule is only built for the batched path.
        levels = tree_levels(parent_t) if batched else None
        outputs = selected_inverse_tree(
            diag, edge, parent_t, return_factors=True, batched=batched
        )
        G_diag, G_edge, factors = outputs

        # Tensor state; both backward implementations unpack exactly this order.
        ctx.save_for_backward(edge, G_diag, factors.ell, factors.chol_D)

        # Plain-Python topology travels as ordinary ctx attributes.
        ctx.batched = batched
        ctx.levels = levels
        ctx.plist = parent_t.tolist()
        ctx.collect_order = collect_order
        ctx.children = children
        ctx.root = root
        return G_diag, G_edge

    @staticmethod
    def backward(ctx, barGd, barGe):
        """Dispatch to the batched or per-node reverse pass chosen in forward.

        Only ``diag`` and ``edge`` receive gradients; the five remaining forward
        arguments get ``None``.
        """
        reverse_pass = _backward_batched if ctx.batched else _backward_loop
        barAd, barAe = reverse_pass(ctx, barGd, barGe)
        return (barAd, barAe) + (None,) * 5
187
+
188
+
189
def selinv_tree(
    diag: Tensor,
    edge: Tensor,
    parent: Tensor | Sequence[int],
    *,
    check: bool = False,
    batched: bool = False,
):
    """Selected inverse of an SPD block tree, wired into autograd.

    Produces the same forward values as
    :func:`gabp_sparse_inv.selected_inverse_tree`; the difference is that the
    returned ``(G_diag, G_edge)`` are attached to ``diag`` and ``edge`` through the
    analytic backward of :class:`SelInvTree`. ``parent`` describes the static
    topology and receives no gradient.

    Parameters
    ----------
    diag, edge, parent:
        As in :func:`gabp_sparse_inv.selected_inverse_tree`.
    check:
        If ``True``, build a ``BlockTree`` from the inputs and call its
        ``validate()`` before computing.
    batched:
        If ``True``, use the level-set batched forward and backward (same result;
        amortizes launch latency on GPU). Default is the per-node reference loop.

    Returns
    -------
    (G_diag, G_edge):
        ``G_diag[v] = G_vv`` and ``G_edge[v] = G_{p(v),v}`` (root slot zero), with
        autograd support w.r.t. ``diag`` and ``edge``.
    """
    if check:
        # Imported lazily, only on the validating path.
        from .layout import BlockTree

        tree = BlockTree(diag=diag, edge=edge, parent=parent)
        tree.validate()

    parent_t = _as_parent_tensor(parent)
    # Resolve the topology once; SelInvTree keeps it on ctx for the backward.
    root_node, child_lists, postorder = tree_orders(parent_t)
    return SelInvTree.apply(
        diag, edge, parent_t, root_node, child_lists, postorder, batched
    )
@@ -0,0 +1 @@
1
+ """Benchmark harness: time, memory, and accuracy for chain and star kernels."""