gabp-sparse-inv 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gabp_sparse_inv/__init__.py +146 -0
- gabp_sparse_inv/_linalg.py +41 -0
- gabp_sparse_inv/autodiff.py +226 -0
- gabp_sparse_inv/bench/__init__.py +1 -0
- gabp_sparse_inv/bench/confirmatory.py +432 -0
- gabp_sparse_inv/bench/deq_breakeven.py +485 -0
- gabp_sparse_inv/bench/deq_cross_eval.py +508 -0
- gabp_sparse_inv/bench/deq_gradient_isolation.py +355 -0
- gabp_sparse_inv/bench/deq_gradient_isolation_analysis.py +345 -0
- gabp_sparse_inv/bench/gmrf_scaling.py +195 -0
- gabp_sparse_inv/bench/harness.py +501 -0
- gabp_sparse_inv/bench/matched_compute.py +234 -0
- gabp_sparse_inv/bench/maze_cross_eval.py +165 -0
- gabp_sparse_inv/bench/maze_extrapolation.py +339 -0
- gabp_sparse_inv/bench/maze_symmetric_swap_control.py +468 -0
- gabp_sparse_inv/bench/metrics.py +492 -0
- gabp_sparse_inv/bench/nonsym_stability.py +254 -0
- gabp_sparse_inv/bench/operator_cross_eval_analysis.py +452 -0
- gabp_sparse_inv/bench/phase1_analysis.py +340 -0
- gabp_sparse_inv/bench/precision.py +517 -0
- gabp_sparse_inv/bench/run.py +207 -0
- gabp_sparse_inv/bench/seeds.py +127 -0
- gabp_sparse_inv/bench/stability.py +197 -0
- gabp_sparse_inv/chain.py +140 -0
- gabp_sparse_inv/demos/__init__.py +1 -0
- gabp_sparse_inv/demos/deltanet_chunk.py +258 -0
- gabp_sparse_inv/demos/deq_fixedpoint.py +354 -0
- gabp_sparse_inv/demos/maze_baselines.py +380 -0
- gabp_sparse_inv/demos/maze_grid.py +335 -0
- gabp_sparse_inv/demos/maze_tree.py +256 -0
- gabp_sparse_inv/generators.py +573 -0
- gabp_sparse_inv/gmrf.py +369 -0
- gabp_sparse_inv/gmrf_grid.py +180 -0
- gabp_sparse_inv/junction.py +752 -0
- gabp_sparse_inv/junction_autodiff.py +373 -0
- gabp_sparse_inv/layout.py +832 -0
- gabp_sparse_inv/nonsym.py +354 -0
- gabp_sparse_inv/nonsym_junction.py +397 -0
- gabp_sparse_inv/sampling.py +134 -0
- gabp_sparse_inv/star.py +140 -0
- gabp_sparse_inv/tree.py +248 -0
- gabp_sparse_inv-0.3.0.dist-info/METADATA +545 -0
- gabp_sparse_inv-0.3.0.dist-info/RECORD +46 -0
- gabp_sparse_inv-0.3.0.dist-info/WHEEL +5 -0
- gabp_sparse_inv-0.3.0.dist-info/licenses/LICENSE +21 -0
- gabp_sparse_inv-0.3.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
"""Selected inversion for sparse block-structured matrices.
|
|
2
|
+
|
|
3
|
+
Returns the blocks of ``A^{-1}`` on a chosen sparsity pattern in ``O(n)`` time / ``O(fill)``
|
|
4
|
+
memory, without ever forming the dense inverse. Kernels (all validated against dense fp64
|
|
5
|
+
oracles; see ``docs/PROJECT_STATUS.md`` for the authoritative scope):
|
|
6
|
+
|
|
7
|
+
SPD
|
|
8
|
+
``selected_inverse_chain`` (block-tridiagonal), ``selected_inverse_star`` (arrowhead),
|
|
9
|
+
``selected_inverse_tree`` (general tree), ``selinv_tree`` (the differentiable tree kernel
|
|
10
|
+
with an analytic backward), and ``selected_inverse_junction`` / ``selinv_junction`` (general
|
|
11
|
+
sparse SPD via the filled / chordal pattern, handling loopy graphs *exactly*).
|
|
12
|
+
Non-symmetric
|
|
13
|
+
``selected_inverse_bidiag`` (block lower-bidiagonal, DeltaNet chunk),
|
|
14
|
+
``selected_inverse_nonsym_tree`` (tree pattern, ``M_{uv} != M_{vu}``), and
|
|
15
|
+
``selected_inverse_nonsym_junction`` / ``selinv_nonsym_junction`` (general sparse
|
|
16
|
+
non-symmetric on the filled pattern, via block LDU + the two-sided Takahashi /
|
|
17
|
+
Erisman-Tinney recurrence; no pivoting -- the static-pattern regime), with
|
|
18
|
+
``nonsym_junction_solve`` the sibling linear solve ``A^{-1} b`` / ``A^{-T} b`` (the
|
|
19
|
+
fixed-point / DEQ implicit-differentiation primitive).
|
|
20
|
+
Applications
|
|
21
|
+
differentiable tree- and grid/loopy-GMRF learning (:mod:`gabp_sparse_inv.gmrf`,
|
|
22
|
+
:mod:`gabp_sparse_inv.gmrf_grid`) and Gaussian sampling (:mod:`gabp_sparse_inv.sampling`).
|
|
23
|
+
|
|
24
|
+
Out of scope unless noted otherwise here: iterative / loopy GaBP, pivoting for the
|
|
25
|
+
non-symmetric kernels (the static-pattern regime is assumed), and indefinite or
|
|
26
|
+
complex-Hermitian matrices.
|
|
27
|
+
"""
|
|
28
|
+
|
|
29
|
+
from __future__ import annotations
|
|
30
|
+
|
|
31
|
+
from .chain import ChainFactors, selected_inverse_chain
|
|
32
|
+
from .generators import (
|
|
33
|
+
condition_number,
|
|
34
|
+
grid_edges,
|
|
35
|
+
random_nonsym_bidiag,
|
|
36
|
+
random_spd_chain,
|
|
37
|
+
random_spd_graph,
|
|
38
|
+
random_spd_laplacian,
|
|
39
|
+
random_spd_star,
|
|
40
|
+
random_spd_tree,
|
|
41
|
+
)
|
|
42
|
+
from .autodiff import SelInvTree, selinv_tree
|
|
43
|
+
from .junction import (
|
|
44
|
+
elimination_order_min_degree,
|
|
45
|
+
elimination_order_nested_dissection,
|
|
46
|
+
junction_logdet,
|
|
47
|
+
junction_solve,
|
|
48
|
+
selected_inverse_junction,
|
|
49
|
+
selinv_junction,
|
|
50
|
+
)
|
|
51
|
+
from .nonsym import (
|
|
52
|
+
SelInvBidiag,
|
|
53
|
+
SelInvTril,
|
|
54
|
+
selected_inverse_bidiag,
|
|
55
|
+
selected_inverse_nonsym_tree,
|
|
56
|
+
selected_inverse_tril,
|
|
57
|
+
selinv_bidiag,
|
|
58
|
+
selinv_tril,
|
|
59
|
+
)
|
|
60
|
+
from .nonsym_junction import (
|
|
61
|
+
nonsym_junction_solve,
|
|
62
|
+
selected_inverse_nonsym_junction,
|
|
63
|
+
selinv_nonsym_junction,
|
|
64
|
+
)
|
|
65
|
+
from .junction_autodiff import (
|
|
66
|
+
selinv_junction_analytic,
|
|
67
|
+
selinv_nonsym_junction_analytic,
|
|
68
|
+
)
|
|
69
|
+
from .gmrf import (
|
|
70
|
+
fit_marginal_likelihood,
|
|
71
|
+
marginal_log_likelihood,
|
|
72
|
+
posterior_marginal_variances,
|
|
73
|
+
sample_tree_gmrf,
|
|
74
|
+
tree_gmrf_precision,
|
|
75
|
+
tree_logdet,
|
|
76
|
+
tree_solve,
|
|
77
|
+
)
|
|
78
|
+
from .gmrf_grid import (
|
|
79
|
+
fit_grid_marginal_likelihood,
|
|
80
|
+
grid_gmrf_precision,
|
|
81
|
+
grid_node_degrees,
|
|
82
|
+
junction_marginal_log_likelihood,
|
|
83
|
+
junction_posterior_marginal_variances,
|
|
84
|
+
)
|
|
85
|
+
from .layout import BlockBidiag, BlockSparseSym, BlockStar, BlockTree, BlockTridiag
|
|
86
|
+
from .sampling import sample_gaussian_junction, sample_gaussian_tree
|
|
87
|
+
from .star import StarFactors, selected_inverse_star
|
|
88
|
+
from .tree import TreeFactors, selected_inverse_tree
|
|
89
|
+
|
|
90
|
+
# Public API of the package. Every name below is imported at the top of this module;
# the grouping comments name the submodule each one comes from.
__all__ = [
    # Block-matrix layout containers (.layout).
    "BlockTridiag",
    "BlockBidiag",
    "BlockStar",
    "BlockTree",
    "BlockSparseSym",
    # Factor containers returned by the chain / star / tree kernels.
    "ChainFactors",
    "StarFactors",
    "TreeFactors",
    # SPD chain / star / tree kernels (.chain, .star, .tree, .autodiff).
    "selected_inverse_chain",
    "selected_inverse_star",
    "selected_inverse_tree",
    "selinv_tree",
    "SelInvTree",
    # General sparse SPD via the junction / filled pattern (.junction).
    "selected_inverse_junction",
    "selinv_junction",
    "junction_solve",
    "junction_logdet",
    "elimination_order_min_degree",
    "elimination_order_nested_dissection",
    # Non-symmetric bidiagonal / triangular / tree kernels (.nonsym).
    "selected_inverse_bidiag",
    "selinv_bidiag",
    "SelInvBidiag",
    "selected_inverse_tril",
    "selinv_tril",
    "SelInvTril",
    "selected_inverse_nonsym_tree",
    # Non-symmetric junction kernels and solve (.nonsym_junction).
    "selected_inverse_nonsym_junction",
    "selinv_nonsym_junction",
    "nonsym_junction_solve",
    # Junction kernels with analytic backward (.junction_autodiff).
    "selinv_junction_analytic",
    "selinv_nonsym_junction_analytic",
    # Tree GMRF learning (.gmrf).
    "tree_gmrf_precision",
    "tree_logdet",
    "tree_solve",
    "marginal_log_likelihood",
    "posterior_marginal_variances",
    "sample_tree_gmrf",
    # Gaussian sampling (.sampling).
    "sample_gaussian_tree",
    "sample_gaussian_junction",
    # Tree GMRF fitting (.gmrf).
    "fit_marginal_likelihood",
    # Grid / loopy GMRF learning (.gmrf_grid).
    "grid_gmrf_precision",
    "grid_node_degrees",
    "junction_marginal_log_likelihood",
    "junction_posterior_marginal_variances",
    "fit_grid_marginal_likelihood",
    # Random test-matrix generators and helpers (.generators).
    "random_spd_chain",
    "random_spd_star",
    "random_spd_tree",
    "random_nonsym_bidiag",
    "random_spd_graph",
    "random_spd_laplacian",
    "grid_edges",
    "condition_number",
]

# Package version string.
__version__ = "0.3.0"
|
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
"""Shared batched SPD linear-algebra helpers.
|
|
2
|
+
|
|
3
|
+
These back every selected-inverse kernel (chain, star, ...): a batched Cholesky
|
|
4
|
+
with an informative SPD-failure message and an explicit SPD-block inverse from a
|
|
5
|
+
Cholesky factor. Keeping them in one place means there is a single, tested
|
|
6
|
+
batched-Cholesky path regardless of the matrix structure being inverted.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from __future__ import annotations
|
|
10
|
+
|
|
11
|
+
import torch
|
|
12
|
+
from torch import Tensor
|
|
13
|
+
|
|
14
|
+
# Public helpers: a batched SPD Cholesky with a descriptive failure, and the explicit
# SPD-block inverse computed from such a factor.
__all__ = ["cholesky_spd", "inv_via_chol"]
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def cholesky_spd(block: Tensor, *, name: str) -> Tensor:
    """Batched Cholesky via ``cholesky_ex`` with an informative SPD failure message.

    ``cholesky_ex`` avoids the CUDA-side error-check synchronization of
    ``torch.linalg.cholesky`` and returns an ``info`` code per batch element.
    Note that the Python-level ``if`` on ``torch.any(...)`` below still reads a
    device value on the host, so the failure check itself synchronizes.

    Parameters
    ----------
    block:
        Square matrix or batch of square matrices (trailing two dims) to factor.
    name:
        Identifies the failing pivot in the error message (e.g. ``"D_3"`` or
        ``"S (center Schur complement)"``).

    Returns
    -------
    Tensor
        The lower-triangular Cholesky factor, same shape as ``block``.

    Raises
    ------
    torch.linalg.LinAlgError
        If any batch element fails to factor (``info != 0``). For batched input
        the message lists the failing batch indices; for an unbatched (2-D)
        input it says so instead of printing an empty index list.
    """
    chol, info = torch.linalg.cholesky_ex(block)
    failed = info != 0
    if torch.any(failed):
        if failed.dim() == 0:
            # Unbatched input: ``info`` is 0-d, and ``nonzero(...).tolist()`` would
            # render as the meaningless ``[[]]``.
            where = "the (unbatched) block"
        else:
            where = f"batch indices {torch.nonzero(failed, as_tuple=False).tolist()}"
        raise torch.linalg.LinAlgError(
            f"pivot block {name} lost positive-definiteness during factorization "
            f"(cholesky_ex info != 0 at {where}); matrix is "
            "not SPD or is too ill-conditioned"
        )
    return chol
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def inv_via_chol(chol: Tensor) -> Tensor:
    """Return the explicit inverse of an SPD block (or batch) from its Cholesky factor.

    ``chol`` is read as a lower-triangular factor (``cholesky_solve``'s default
    ``upper=False``); the inverse is obtained by solving against an identity
    right-hand side broadcast to ``chol``'s shape.
    """
    size = chol.shape[-1]
    rhs = torch.eye(size, dtype=chol.dtype, device=chol.device)
    return torch.cholesky_solve(rhs.expand_as(chol), chol)
|
|
@@ -0,0 +1,226 @@
|
|
|
1
|
+
"""Differentiable selected inverse on SPD block trees.
|
|
2
|
+
|
|
3
|
+
:class:`SelInvTree` wraps the forward kernel
|
|
4
|
+
(:func:`gabp_sparse_inv.tree.selected_inverse_tree`) with a **hand-written analytic
|
|
5
|
+
backward**: the reverse two-pass collect/distribute of ``docs/derivations.md`` §8.3,
|
|
6
|
+
reusing the forward factors ``(chol_D, ell)`` and the selected blocks ``G_diag``.
|
|
7
|
+
|
|
8
|
+
The backward is the *self-adjoint* schedule proved in §8.2 -- a reverse-distribute
|
|
9
|
+
sweep (leaves->root) followed by a reverse-collect sweep (root->leaves) on the same
|
|
10
|
+
elimination tree. It costs ``O((|V|+|E|) b^3)`` time and ``O((|V|+|E|) b^2)`` memory,
|
|
11
|
+
identical to the forward pass, and never tapes the per-node loop through autograd
|
|
12
|
+
(the shipped forward loop is in fact *not* autograd-traceable -- it writes into
|
|
13
|
+
preallocated buffers -- which is part of why this analytic backward exists).
|
|
14
|
+
|
|
15
|
+
Both passes have a per-node ``loop`` form (the correctness reference) and a level-set
|
|
16
|
+
``batched`` form (one batched solve + one ``index_add_`` per height antichain, the
|
|
17
|
+
same batching the forward uses); the two agree block-for-block.
|
|
18
|
+
|
|
19
|
+
Convention (``docs/derivations.md`` §8.1): the differentiable inputs are the stored
|
|
20
|
+
blocks ``diag`` (node diagonals ``A_vv``, symmetric) and ``edge`` (``edge[v] =
|
|
21
|
+
A_{p(v),v}``). The returned cotangents ``barAd[v] = d f / d diag[v]`` and
|
|
22
|
+
``barAe[v] = d f / d edge[v]`` are gradients in the autograd sense w.r.t. those input
|
|
23
|
+
tensors -- exactly what :func:`torch.autograd.gradcheck` verifies.
|
|
24
|
+
|
|
25
|
+
**First-order only for this module.** ``backward`` runs under no-grad (the standard
|
|
26
|
+
custom-Function contract). For Hessian-vector products use the functional junction kernels
|
|
27
|
+
or ``selected_inverse_tree(batched=True)``; both pass ``gradgradcheck``
|
|
28
|
+
(``tests/test_double_backward.py``).
|
|
29
|
+
"""
|
|
30
|
+
|
|
31
|
+
from __future__ import annotations
|
|
32
|
+
|
|
33
|
+
from typing import Sequence
|
|
34
|
+
|
|
35
|
+
import torch
|
|
36
|
+
from torch import Tensor
|
|
37
|
+
|
|
38
|
+
from ._linalg import inv_via_chol
|
|
39
|
+
from .layout import _as_parent_tensor, tree_levels, tree_orders
|
|
40
|
+
from .tree import _level_index_tensors, selected_inverse_tree
|
|
41
|
+
|
|
42
|
+
# Public surface: the functional entry point and the autograd Function it routes through.
__all__ = ["selinv_tree", "SelInvTree"]
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _sym(x: Tensor) -> Tensor:
|
|
46
|
+
"""Symmetrize the trailing ``b x b`` blocks: ``(X + X^T)/2``."""
|
|
47
|
+
return 0.5 * (x + x.mT)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _backward_loop(ctx, barGd, barGe):
    """Per-node reverse two-pass (eqs. 13-14); the correctness reference.

    Parameters
    ----------
    ctx:
        Autograd context filled by :meth:`SelInvTree.forward`: saved tensors
        ``(edge, G_diag, ell, chol_D)`` plus the plain attributes ``root``,
        ``children``, ``collect_order`` and ``plist`` (parent of each node).
    barGd, barGe:
        Incoming cotangents for ``G_diag`` and ``G_edge``; either may be ``None``
        and is then treated as zero.

    Returns
    -------
    (barAd, barAe):
        Cotangents w.r.t. ``diag`` and ``edge``. The root slot of ``barAe`` is
        never written and stays zero.
    """
    edge, G_diag, ell, chol_D = ctx.saved_tensors
    root, children, collect_order, plist = ctx.root, ctx.children, ctx.collect_order, ctx.plist
    Dinv = inv_via_chol(chol_D)

    zeros = torch.zeros_like(G_diag)
    # Accumulators are fresh clones, so the slice writes below never alias the
    # incoming cotangents or one another.
    bG = barGd.clone() if barGd is not None else zeros.clone()
    if barGe is None:
        barGe = zeros  # only read below, so sharing ``zeros`` is safe
    bDinv = zeros.clone()
    bell = zeros.clone()
    bU = zeros.clone()

    # Pass 1: reverse-distribute (leaves -> root), eq. (13). A node's bG must be
    # complete (all children already folded in) before it is consumed here.
    for v in collect_order:
        M = _sym(bG[..., v, :, :])
        bDinv[..., v, :, :] = bDinv[..., v, :, :] + M
        if v != root:
            p = plist[v]
            Gpp = G_diag[..., p, :, :]
            ell_v = ell[..., v, :, :]
            bge = barGe[..., v, :, :]
            bell[..., v, :, :] = bell[..., v, :, :] + Gpp @ ell_v @ (M + M.mT) - Gpp @ bge
            bG[..., p, :, :] = bG[..., p, :, :] + ell_v @ M @ ell_v.mT - bge @ ell_v.mT

    # Pass 2: reverse-collect (root -> leaves), eq. (14). A node's bell / bU must
    # already hold its parent's contribution before it is processed.
    barAd = zeros.clone()
    for v in reversed(collect_order):
        Dinv_v = Dinv[..., v, :, :]
        if v != root:
            U_v = edge[..., v, :, :]
            bell_v = bell[..., v, :, :]
            bU[..., v, :, :] = bU[..., v, :, :] + bell_v @ Dinv_v
            bDinv[..., v, :, :] = bDinv[..., v, :, :] + U_v.mT @ bell_v
        bP = _sym(-(Dinv_v @ bDinv[..., v, :, :] @ Dinv_v))
        barAd[..., v, :, :] = bP
        # Push this node's pivot cotangent down to each child.
        for c in children[v]:
            bell[..., c, :, :] = bell[..., c, :, :] - bP @ edge[..., c, :, :]
            bU[..., c, :, :] = bU[..., c, :, :] - bP @ ell[..., c, :, :]

    return barAd, bU
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def _backward_batched(ctx, barGd, barGe):
    """Level-set reverse two-pass: vectorizes (13)-(14) over height antichains.

    Same arithmetic as :func:`_backward_loop`, scheduled by :func:`tree_levels` so each
    height level is one batched block op + one ``index_add_`` -- the backward mirror of
    the batched forward (the schedule is self-adjoint).

    Parameters
    ----------
    ctx:
        Autograd context filled by :meth:`SelInvTree.forward`: saved tensors
        ``(edge, G_diag, ell, chol_D)`` plus ``root``, ``plist`` and ``levels``.
    barGd, barGe:
        Incoming cotangents for ``G_diag`` / ``G_edge``; ``None`` is treated as zero.

    Returns
    -------
    (barAd, barAe):
        Cotangents w.r.t. ``diag`` and ``edge`` (``bU``); the root slot of ``bU``
        is never written and stays zero.
    """
    edge, G_diag, ell, chol_D = ctx.saved_tensors
    root, plist, levels = ctx.root, ctx.plist, ctx.levels
    Dinv = inv_via_chol(chol_D)
    dev = G_diag.device
    # NOTE(review): each entry of ``lvl`` is unpacked below as
    # (node_idx, nonroot_idx, parent_idx) -- presumably all nodes of one height
    # level, the non-root subset, and their parents aligned with ``nonroot_idx``.
    # Confirm against ``_level_index_tensors``.
    lvl = _level_index_tensors(levels, root, plist, dev)

    zeros = torch.zeros_like(G_diag)
    bG = barGd.clone() if barGd is not None else zeros.clone()
    barGe = barGe if barGe is not None else zeros
    bDinv = zeros.clone()
    bell = zeros.clone()
    bU = zeros.clone()
    bP = zeros.clone()
    barAd = zeros.clone()

    # Pass 1: reverse-distribute, increasing height.
    for node_idx, nonroot_idx, parent_idx in lvl:
        bGl = bG.index_select(-3, node_idx)
        Ml = _sym(bGl)
        bDinv.index_copy_(-3, node_idx, bDinv.index_select(-3, node_idx) + Ml)
        if nonroot_idx.numel() > 0:
            M = _sym(bG.index_select(-3, nonroot_idx))
            ell_v = ell.index_select(-3, nonroot_idx)
            Gpp = G_diag.index_select(-3, parent_idx)
            bge = barGe.index_select(-3, nonroot_idx)
            add_bell = Gpp @ ell_v @ (M + M.mT) - Gpp @ bge
            bell.index_copy_(-3, nonroot_idx, bell.index_select(-3, nonroot_idx) + add_bell)
            # Siblings on this level share a parent, so ``parent_idx`` can repeat:
            # ``index_add`` sums duplicate targets (``index_copy_`` would not).
            bG = bG.index_add(-3, parent_idx, ell_v @ M @ ell_v.mT - bge @ ell_v.mT)

    # Pass 2: reverse-collect, decreasing height. Root (top level, alone) first.
    ridx = torch.tensor([root], dtype=torch.long, device=dev)
    Dinv_r = Dinv.index_select(-3, ridx)
    bP_r = _sym(-(Dinv_r @ bDinv.index_select(-3, ridx) @ Dinv_r))
    bP.index_copy_(-3, ridx, bP_r)
    barAd.index_copy_(-3, ridx, bP_r)
    for node_idx, nonroot_idx, parent_idx in reversed(lvl):
        # The root's level has no non-root nodes and was handled above.
        if nonroot_idx.numel() == 0:
            continue
        ell_v = ell.index_select(-3, nonroot_idx)
        U_v = edge.index_select(-3, nonroot_idx)
        Dinv_v = Dinv.index_select(-3, nonroot_idx)
        # Parents sit on a strictly higher level, so their bP is already final.
        bP_par = bP.index_select(-3, parent_idx)
        bell_v = bell.index_select(-3, nonroot_idx) - bP_par @ U_v  # (14d)
        bU_v = bU.index_select(-3, nonroot_idx) - bP_par @ ell_v  # (14d)
        bU_v = bU_v + bell_v @ Dinv_v  # (14a)
        bDinv_v = bDinv.index_select(-3, nonroot_idx) + U_v.mT @ bell_v  # (14a)
        bP_v = _sym(-(Dinv_v @ bDinv_v @ Dinv_v))  # (14b)
        bell.index_copy_(-3, nonroot_idx, bell_v)
        bU.index_copy_(-3, nonroot_idx, bU_v)
        bP.index_copy_(-3, nonroot_idx, bP_v)
        barAd.index_copy_(-3, nonroot_idx, bP_v)

    return barAd, bU
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
class SelInvTree(torch.autograd.Function):
    """Custom autograd Function for the SPD block-tree selected inverse.

    The forward pass delegates to :func:`selected_inverse_tree` (asking for its
    factors); the backward pass is the hand-written reverse two-pass. Prefer
    :func:`selinv_tree` over calling ``apply`` directly -- it resolves and
    validates the tree topology once before routing here.
    """

    @staticmethod
    def forward(ctx, diag, edge, parent_t, root, children, collect_order, batched):
        # Height levels are only consumed by the batched backward.
        levels = tree_levels(parent_t) if batched else None
        G_diag, G_edge, factors = selected_inverse_tree(
            diag, edge, parent_t, return_factors=True, batched=batched
        )
        # Tensors the reverse pass reads: the edge blocks, the selected diagonal
        # blocks, and the forward factors.
        ctx.save_for_backward(edge, G_diag, factors.ell, factors.chol_D)
        # Static (non-tensor) topology, stashed as plain attributes.
        ctx.root, ctx.children, ctx.collect_order = root, children, collect_order
        ctx.plist = parent_t.tolist()
        ctx.levels, ctx.batched = levels, batched
        return G_diag, G_edge

    @staticmethod
    def backward(ctx, barGd, barGe):
        reverse_two_pass = _backward_batched if ctx.batched else _backward_loop
        barAd, barAe = reverse_two_pass(ctx, barGd, barGe)
        # One slot per forward input; only ``diag`` and ``edge`` are differentiable.
        return (barAd, barAe) + (None,) * 5
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def selinv_tree(
    diag: Tensor,
    edge: Tensor,
    parent: Tensor | Sequence[int],
    *,
    check: bool = False,
    batched: bool = False,
) -> tuple[Tensor, Tensor]:
    """Autograd-connected selected inverse of an SPD block tree.

    Same forward result as :func:`gabp_sparse_inv.selected_inverse_tree`, but the
    returned ``(G_diag, G_edge)`` are connected to ``diag`` and ``edge`` through the
    analytic backward of :class:`SelInvTree`. Gradients flow to ``diag`` and ``edge``;
    ``parent`` is a static (non-differentiable) topology argument.

    Parameters
    ----------
    diag, edge, parent:
        As in :func:`gabp_sparse_inv.selected_inverse_tree`.
    check:
        If ``True``, validate the block-tree inputs (via ``BlockTree.validate``)
        before computing.
    batched:
        If ``True``, use the level-set batched forward and backward (same result;
        amortizes launch latency on GPU). Default is the per-node reference loop.

    Returns
    -------
    (G_diag, G_edge):
        ``G_diag[v] = G_vv`` and ``G_edge[v] = G_{p(v),v}`` (root slot zero), with
        autograd support w.r.t. ``diag`` and ``edge``.
    """
    if check:
        # Imported only when validation is requested.
        from .layout import BlockTree

        BlockTree(diag=diag, edge=edge, parent=parent).validate()
    parent_t = _as_parent_tensor(parent)
    # Resolve the topology once here; SelInvTree stores it for its backward.
    root, children, collect_order = tree_orders(parent_t)
    return SelInvTree.apply(diag, edge, parent_t, root, children, collect_order, batched)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Benchmarks: the timing / memory / accuracy harness for chain and star kernels, plus experiment scripts (DEQ, maze, GMRF scaling, precision, stability)."""
|