spacelearn 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacelearn/__init__.py +16 -0
- spacelearn/_combined.py +261 -0
- spacelearn/_find_subspaces.py +101 -0
- spacelearn/data/__init__.py +11 -0
- spacelearn/data/_compute_Q.py +70 -0
- spacelearn/data/_pipeline.py +55 -0
- spacelearn/data/_pool.py +35 -0
- spacelearn/data/_utils.py +115 -0
- spacelearn/loss/__init__.py +17 -0
- spacelearn/loss/_alignment_loss.py +59 -0
- spacelearn/loss/_disentanglement_loss.py +111 -0
- spacelearn/loss/_isotropy_loss.py +56 -0
- spacelearn/loss/_stability_loss.py +44 -0
- spacelearn/optim/__init__.py +19 -0
- spacelearn/optim/_fit_decoder.py +56 -0
- spacelearn/optim/_optimize_latent.py +92 -0
- spacelearn/optim/_optimize_subspaces.py +109 -0
- spacelearn/optim/_optimize_targets.py +119 -0
- spacelearn/optim/_subspace_change.py +52 -0
- spacelearn/settings.py +20 -0
- spacelearn/util/__init__.py +9 -0
- spacelearn/util/_eigen_fallback.py +126 -0
- spacelearn/util/_lowrank_fallback.py +25 -0
- spacelearn/util/_svd_fallback.py +24 -0
- spacelearn/util/svd_util.py +60 -0
- spacelearn-0.1.0.dist-info/METADATA +231 -0
- spacelearn-0.1.0.dist-info/RECORD +29 -0
- spacelearn-0.1.0.dist-info/WHEEL +4 -0
- spacelearn-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2026 Tobias Karusseit
|
|
3
|
+
This source code is licensed under the MIT license found in the
|
|
4
|
+
LICENSE file in the root directory of this source tree.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn.functional as F
|
|
9
|
+
|
|
10
|
+
def alignment_loss(
    Z: torch.Tensor,
    W: dict[str, torch.Tensor],
    A: dict[str, torch.Tensor],
    Y: dict[str, torch.Tensor],
) -> torch.Tensor:
    """
    Physics reconstruction loss: penalizes how poorly each subspace's
    decoded projection reproduces its pooled physical target.

    #### USAGE:
        loss = alignment_loss(Z, W, A, Y)
        loss.backward()

    #### HOW IT WORKS:
        The latents and every target are centered over the batch. For each
        quantity q in W the latents are projected into subspace q, decoded
        linearly, and compared to the centered target:
            S_q   = Z_c @ W_q.T          # (B, k) coordinates in subspace q
            y_hat = S_q @ A_q            # (B, n^2) decoded prediction
            err_q = MSE(y_hat, Y_q_c)
        The returned value is the mean of err_q over the quantities in W.

    #### ARGS:
        - Z: (B, D) latent vectors.
        - W: {name: (k, D)} subspace bases.
        - A: {name: (k, n^2)} linear decoders (e.g. as returned by
          `fit_decoders`).
        - Y: {name: (B, n^2) or (B, n, n) ...} pooled physical targets.
          Targets with more than 2 dims are flattened to (B, n^2).
          Every name in W must also be a key of A and Y.

    #### RETURNS:
        - A scalar torch.Tensor. Perfect reconstruction -> 0. When W is
          empty a zero tensor on Z's device is returned.
    """
    device = Z.device
    latents = Z - Z.mean(0, keepdim=True)
    quantities = list(W.keys())

    total = torch.tensor(0.0, device=device)
    for q in quantities:
        basis = W[q].to(device)
        decoder = A[q].to(device)
        coords = latents @ basis.T

        y = Y[q].to(device)
        y = y.flatten(1) if y.dim() > 2 else y
        y_centered = y - y.mean(0, keepdim=True)

        total = total + F.mse_loss(coords @ decoder, y_centered)

    return total / max(len(quantities), 1)
|
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2026 Tobias Karusseit
|
|
3
|
+
This source code is licensed under the MIT license found in the
|
|
4
|
+
LICENSE file in the root directory of this source tree.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
def disentanglement_loss_original(
    Z: torch.Tensor,
    W: dict[str, torch.Tensor]
) -> torch.Tensor:
    """
    [LEGACY] Penalizes correlation between latent vectors projected into
    different subspaces, using mean squared correlation per pair.

    We recommend `disentanglement_loss` instead — see that function's
    docstring for the difference.

    #### USAGE:
        loss = disentanglement_loss_original(Z, W)
        loss.backward()

    #### HOW IT WORKS:
        Each subspace is projected and scaled once:
            S_q   = Z_c @ W_q.T               # (B, k_q) coordinates
            S_q_n = S_q / std(S_q)            # unit (sample) std, no extra
                                              # centering (Z_c is centered)
        For every pair of subspaces (i, j):
            corr   = S_i_n.T @ S_j_n / B      # (k_i, k_j) correlation-like matrix
            total += mean(corr ** 2)          # squared correlation
        The sum is averaged over the number of subspace pairs.

        Because the std is the unbiased (÷ B-1) estimate while corr divides
        by B, a perfectly correlated coordinate pair scores ((B-1)/B)^2
        rather than exactly 1; `disentanglement_loss` uses consistent
        normalization.

    #### ARGS:
        - Z: (B, D) latent vectors.
        - W: {name: (k, D)} subspace bases (rows span each subspace).

    #### RETURNS:
        - A scalar torch.Tensor loss. Perfect disentanglement -> 0,
          fully correlated -> approximately 1. With fewer than two
          subspaces the loss is 0.
    """
    Z_c = Z - Z.mean(0, keepdim=True)

    # Project and scale each subspace once instead of once per pair.
    normed = []
    for name in W:
        S = Z_c @ W[name].to(Z.device).T
        # unit std for scale invariance of the correlation below
        normed.append(S / (S.std(dim=0, keepdim=True) + 1e-8))

    total = torch.tensor(0.0, device=Z.device)
    for i, Si_n in enumerate(normed):
        for Sj_n in normed[i + 1:]:
            corr = Si_n.T @ Sj_n / Z.shape[0]
            # squaring emphasises large correlations
            total = total + corr.pow(2).mean()

    n_pairs = len(normed) * (len(normed) - 1) / 2
    return total / max(n_pairs, 1)
|
|
55
|
+
|
|
56
|
+
def disentanglement_loss(
    Z: torch.Tensor,
    W: dict[str, torch.Tensor]
) -> torch.Tensor:
    """
    Penalizes correlation between latent vectors projected into different
    subspaces, using the normalized squared Frobenius norm of the
    cross-correlation matrix (== trace(Corr.T @ Corr)) per pair.

    #### USAGE:
        loss = disentanglement_loss(Z, W)
        loss.backward()

    #### HOW IT WORKS:
        Each subspace is projected and standardized once:
            S_q     = Z_c @ W_q.T                  # (B, k_q)
            S_q_std = (S_q - mean) / std           # population std (÷ B)
        For every pair of subspaces (i, j), i before j in W's order:
            Corr   = S_i_std.T @ S_j_std / B       # (k_i, k_j) Pearson correlations
            total += ||Corr||_F^2 / k_i            # == trace(Corr.T @ Corr) / k_i
        The sum is averaged over the number of subspace pairs.

        The population std (dividing by B, matching the / B of the
        covariance) makes every entry of Corr an exact Pearson correlation,
        so two identical subspaces whose coordinates are mutually
        uncorrelated score exactly 1.0.

    #### ARGS:
        - Z: (B, D) latent vectors.
        - W: {name: (k, D)} subspace bases (rows span each subspace).

    #### RETURNS:
        - A scalar torch.Tensor with Z's dtype and device. Perfect
          disentanglement -> 0; identical subspaces with uncorrelated
          coordinates -> 1. With fewer than two subspaces the loss is 0.
          The per-pair normalization uses k_i of the subspace that comes
          first in W, so with unequal k the score depends on W's key order.
    """
    Z_c = Z - Z.mean(0, keepdim=True)
    batch = Z.shape[0]

    # Standardize each subspace once rather than once per pair.
    standardized = []
    for name in W:
        S = Z_c @ W[name].to(Z.device).T
        S = S - S.mean(0, keepdim=True)
        standardized.append(S / (S.std(0, unbiased=False, keepdim=True) + 1e-8))

    # Out-of-place accumulation in Z's dtype: an in-place add onto a float32
    # zero would silently downcast float64 inputs.
    total = Z_c.new_zeros(())
    pairs = 0
    for i, Si in enumerate(standardized):
        for Sj in standardized[i + 1:]:
            corr = Si.T @ Sj / batch
            total = total + corr.pow(2).sum() / Si.shape[1]
            pairs += 1

    return total / max(pairs, 1)
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2026 Tobias Karusseit
|
|
3
|
+
This source code is licensed under the MIT license found in the
|
|
4
|
+
LICENSE file in the root directory of this source tree.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from ..util import singular_values
|
|
9
|
+
|
|
10
|
+
def isotropy_loss(
    Z: torch.Tensor,
    W: dict[str, torch.Tensor],
    n_dir: int = 32
) -> torch.Tensor:
    """
    Enforces isotropy in the latent residual space by penalizing variance
    among its leading singular values.

    #### USAGE:
        loss = isotropy_loss(Z, W, n_dir=32)
        loss.backward()

    #### HOW IT WORKS:
        Z_c     = Z - Z.mean(0)                  # center latents
        W_all   = concat(W.values())             # (sum_k, D) stacked bases
        Z_resid = Z_c - Z_c @ W_all.T @ W_all    # remove the known subspaces
        S       = singular_values(Z_resid)
        S_top   = S[:n_dir]                      # leading n_dir values
        S_norm  = S_top / S_top.mean()           # normalize to mean 1
        loss    = mean((S_norm - 1) ** 2)        # deviation from a flat spectrum

        The residual is an exact orthogonal projection only when the stacked
        rows of W_all are orthonormal (each basis orthonormal and the
        subspaces mutually orthogonal); otherwise it is an approximation.
        When W is empty nothing is projected out and the centered Z is used.
        `S[:n_dir]` is treated as the leading values on the assumption that
        `singular_values` returns them in descending order — TODO confirm.

    #### ARGS:
        - Z: (B, D) latent vectors.
        - W: {name: (k, D)} subspace bases already accounted for (e.g. by
          alignment/disentanglement losses). These are projected out of
          `Z` before computing isotropy on what remains. May be empty.
        - n_dir: Number of leading singular values of the residual to
          consider. Must be >= 1. Defaults to 32.

    #### RETURNS:
        - A scalar torch.Tensor loss. Perfectly isotropic residual
          directions (uniform singular value spectrum) -> 0.

    #### RAISES:
        - ValueError: if n_dir < 1 (an empty spectrum would produce NaN).
    """
    if n_dir < 1:
        raise ValueError(f"n_dir must be >= 1, got {n_dir}")

    z_c = Z - Z.mean(0, keepdim=True)  # centered for variance calc
    bases = [W[name].to(Z.device) for name in W]
    if bases:
        W_all = torch.cat(bases, dim=0)
        z_resid = z_c - z_c @ W_all.T @ W_all
    else:
        # torch.cat rejects an empty list; with no known subspaces the
        # whole centered latent space is the residual.
        z_resid = z_c

    S_top = singular_values(z_resid)[:n_dir]
    S_norm = S_top / (S_top.mean() + 1e-8)
    # punish the model if the singular values are not even
    return ((S_norm - 1.0) ** 2).mean()
|
|
@@ -0,0 +1,44 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2026 Tobias Karusseit
|
|
3
|
+
This source code is licensed under the MIT license found in the
|
|
4
|
+
LICENSE file in the root directory of this source tree.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
def stability_loss(
|
|
10
|
+
W: dict[str, torch.Tensor],
|
|
11
|
+
W_prev: dict[str, torch.Tensor]
|
|
12
|
+
) -> torch.Tensor:
|
|
13
|
+
"""
|
|
14
|
+
Punishes large or rapid changes in subspace geometry between training
|
|
15
|
+
steps, encouraging the model to evolve its latent subspace structure
|
|
16
|
+
gradually rather than jumping between geometries.
|
|
17
|
+
|
|
18
|
+
#### USAGE:
|
|
19
|
+
loss = stability_loss(W, W_prev)
|
|
20
|
+
loss.backward()
|
|
21
|
+
|
|
22
|
+
#### HOW IT WORKS:
|
|
23
|
+
If W_prev is provided:
|
|
24
|
+
for each quantity name in W:
|
|
25
|
+
loss += sum((W[name] - W_prev[name]) ** 2) # squared drift
|
|
26
|
+
loss /= number of quantities
|
|
27
|
+
Otherwise, the loss is 0 (e.g. on the first step, when there is no
|
|
28
|
+
previous subspace to compare against).
|
|
29
|
+
|
|
30
|
+
#### ARGS:
|
|
31
|
+
- W: {name: (D, k)} current subspace bases.
|
|
32
|
+
- W_prev: {name: (D, k)} subspace bases from the previous step, or
|
|
33
|
+
None if there is no previous step to compare against.
|
|
34
|
+
|
|
35
|
+
#### RETURNS:
|
|
36
|
+
- A scalar torch.Tensor loss. No drift -> 0; larger or more
|
|
37
|
+
rapid changes in subspace geometry -> larger values.
|
|
38
|
+
"""
|
|
39
|
+
loss_stab = torch.tensor(0.0, device=W[next(iter(W.keys()))].device)
|
|
40
|
+
if W_prev is not None:
|
|
41
|
+
for name in W.keys():
|
|
42
|
+
loss_stab = loss_stab + (W[name] - W_prev[name]).pow(2).sum()
|
|
43
|
+
loss_stab = loss_stab / len(W.keys())
|
|
44
|
+
return loss_stab
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2026 Tobias Karusseit
|
|
3
|
+
This source code is licensed under the MIT license found in the
|
|
4
|
+
LICENSE file in the root directory of this source tree.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from ._optimize_targets import optimal_bins
|
|
8
|
+
from ._optimize_latent import minimal_latent_dim
|
|
9
|
+
from ._fit_decoder import fit_decoders
|
|
10
|
+
from ._subspace_change import subspace_change
|
|
11
|
+
from ._optimize_subspaces import k_per_q
|
|
12
|
+
|
|
13
|
+
__all__ = [
|
|
14
|
+
"optimal_bins",
|
|
15
|
+
"minimal_latent_dim",
|
|
16
|
+
"fit_decoders",
|
|
17
|
+
"subspace_change",
|
|
18
|
+
"k_per_q",
|
|
19
|
+
]
|
|
@@ -0,0 +1,56 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2026 Tobias Karusseit
|
|
3
|
+
This source code is licensed under the MIT license found in the
|
|
4
|
+
LICENSE file in the root directory of this source tree.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
def fit_decoders(W: dict, Z: torch.Tensor, pooled_global: dict[str, torch.Tensor]) -> dict:
    """
    Learns a linear mapping (decoder) from each latent subspace's
    coordinates to its corresponding physical target space, via
    least-squares regression.

    #### USAGE:
        A = fit_decoders(W, Z, pooled_global)
        # A can then be passed directly to alignment_loss as the
        # 'A' (decoder) argument

    #### HOW IT WORKS:
        Z_c = Z - Z.mean(0)                      # center latents
        for each quantity name in W:
            S_q   = Z_c @ W_q.T                  # project latents into subspace q
            Y_q   = pooled_global[name]          # physical target, flattened if needed
            Y_q_c = Y_q - Y_q.mean(0)            # center target
            A_q   = lstsq(S_q, Y_q_c).solution   # minimizes ||Y_q_c - S_q @ A_q||_F

        All inputs are detached before solving, so the returned decoders
        carry no autograd history (they are fit by least squares, not by
        backpropagation) and no graph is built for the solve. Each target
        is moved to Z's device and cast to the projection's dtype, because
        torch.linalg.lstsq requires both operands to share a dtype and
        alignment_loss later multiplies the decoders with latent-dtype
        coordinates.

    #### ARGS:
        - W: {name: (k, D)} subspace basis matrices.
        - Z: (B, D) latent vectors.
        - pooled_global: {name: (B, n^2) or (B, n, n) ...} pooled
          physical targets. Targets with more than 2 dims are flattened
          to (B, n^2) before fitting. Must contain every name in W.

    #### RETURNS:
        - A dict {name: (k, n^2)} of least-squares linear decoder matrices,
          one per quantity in W, none of which requires grad.
    """
    Z_c = Z.detach()
    Z_c = Z_c - Z_c.mean(0, keepdim=True)  # center
    A = {}
    for name, W_q in W.items():
        # projection of Z into the subspace
        S_q = Z_c @ W_q.detach().to(Z_c.device).T
        y = pooled_global[name].detach().to(device=S_q.device, dtype=S_q.dtype)
        if y.dim() > 2:  # flatten
            y = y.flatten(1)
        y_c = y - y.mean(0, keepdim=True)  # center
        # solve (finds A such that y_c - S_q @ A is as small as possible)
        A[name] = torch.linalg.lstsq(S_q, y_c).solution
    return A
|
|
@@ -0,0 +1,92 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2026 Tobias Karusseit
|
|
3
|
+
This source code is licensed under the MIT license found in the
|
|
4
|
+
LICENSE file in the root directory of this source tree.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import math
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _align_latent_dim(dim: int, num_heads: int) -> int:
|
|
11
|
+
"""Round dim up to the next multiple of num_heads (unchanged if already divisible)."""
|
|
12
|
+
if dim < 1:
|
|
13
|
+
raise ValueError(f"dim must be >= 1, got {dim}")
|
|
14
|
+
if num_heads < 1:
|
|
15
|
+
raise ValueError(f"num_heads must be >= 1, got {num_heads}")
|
|
16
|
+
if num_heads == 1:
|
|
17
|
+
return dim
|
|
18
|
+
remainder = dim % num_heads
|
|
19
|
+
if remainder == 0:
|
|
20
|
+
return dim
|
|
21
|
+
return dim + (num_heads - remainder)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def minimal_latent_dim(
|
|
25
|
+
k: int | dict[str, int],
|
|
26
|
+
*,
|
|
27
|
+
Q: int | None = None,
|
|
28
|
+
free_frac: float = 0.30,
|
|
29
|
+
num_heads: int | None = None,
|
|
30
|
+
) -> int:
|
|
31
|
+
"""
|
|
32
|
+
Computes the minimum latent dimension needed to fit a set of physics
|
|
33
|
+
subspaces plus a fraction of free (unconstrained) capacity, optionally
|
|
34
|
+
aligned to a multiple of num_heads.
|
|
35
|
+
|
|
36
|
+
#### USAGE:
|
|
37
|
+
# Example 1: shared k across all quantities
|
|
38
|
+
D = minimal_latent_dim(k=4, Q=6, free_frac=0.30)
|
|
39
|
+
|
|
40
|
+
# Example 2: per-quantity k
|
|
41
|
+
D = minimal_latent_dim(k={"scaled": 4, "grammian": 6}, free_frac=0.30)
|
|
42
|
+
|
|
43
|
+
# Example 3: align result to attention head count
|
|
44
|
+
D = minimal_latent_dim(k=4, Q=6, num_heads=8)
|
|
45
|
+
|
|
46
|
+
#### HOW IT WORKS:
|
|
47
|
+
physics_dims = Q * k # if k is a shared scalar
|
|
48
|
+
physics_dims = sum(k.values()) # if k is per-quantity
|
|
49
|
+
D_min = ceil(physics_dims / (1 - free_frac))
|
|
50
|
+
if num_heads is set:
|
|
51
|
+
D_min = next multiple of num_heads >= D_min
|
|
52
|
+
|
|
53
|
+
#### ARGS:
|
|
54
|
+
- k: Subspace dimensionality. Either a single int shared across all
|
|
55
|
+
`Q` quantities, or a dict mapping each quantity name to its
|
|
56
|
+
own k.
|
|
57
|
+
- Q: Number of quantities. Required (and only used) when `k` is a
|
|
58
|
+
scalar; ignored when `k` is a dict, since the count is
|
|
59
|
+
inferred from its keys.
|
|
60
|
+
- free_frac: Fraction of the latent dimension to reserve as free,
|
|
61
|
+
unconstrained capacity beyond the physics subspaces.
|
|
62
|
+
Must be in [0, 1). Defaults to 0.30.
|
|
63
|
+
- num_heads: If set, the result is rounded up to the next multiple
|
|
64
|
+
of num_heads (e.g. to match an attention head count).
|
|
65
|
+
Defaults to None (no alignment).
|
|
66
|
+
|
|
67
|
+
#### RETURNS:
|
|
68
|
+
- The minimum latent dimension D_min satisfying the physics
|
|
69
|
+
subspace requirement, the free-capacity fraction, and (if
|
|
70
|
+
given) the num_heads alignment constraint.
|
|
71
|
+
"""
|
|
72
|
+
if not 0.0 <= free_frac < 1.0:
|
|
73
|
+
raise ValueError(f"free_frac must be in [0, 1), got {free_frac}")
|
|
74
|
+
|
|
75
|
+
if isinstance(k, dict):
|
|
76
|
+
physics_dims = sum(k.values())
|
|
77
|
+
detail = ", ".join(f"{name}={k_q}" for name, k_q in sorted(k.items()))
|
|
78
|
+
# print(f"[latent-sizing] k per quantity: {detail} (sum={physics_dims})")
|
|
79
|
+
else:
|
|
80
|
+
if Q is None:
|
|
81
|
+
raise ValueError("Q is required when k is a scalar")
|
|
82
|
+
physics_dims = Q * k
|
|
83
|
+
# print(f"[latent-sizing] Q={Q}, k={k}")
|
|
84
|
+
|
|
85
|
+
D_min = math.ceil(physics_dims / (1.0 - free_frac))
|
|
86
|
+
print(f"[latent-sizing] free={free_frac * 100:.0f}% -> D_min={D_min}")
|
|
87
|
+
if num_heads is not None:
|
|
88
|
+
aligned = _align_latent_dim(D_min, num_heads)
|
|
89
|
+
# if aligned != D_min:
|
|
90
|
+
# print(f"[latent-sizing] aligned D {D_min} -> {aligned} (num_heads={num_heads})")
|
|
91
|
+
D_min = aligned
|
|
92
|
+
return D_min
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2026 Tobias Karusseit
|
|
3
|
+
This source code is licensed under the MIT license found in the
|
|
4
|
+
LICENSE file in the root directory of this source tree.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from ..util import singular_values
|
|
9
|
+
|
|
10
|
+
def _k_from_target(target: torch.Tensor, threshold: float) -> int:
    """
    Number of leading principal components of `target` needed for the
    cumulative explained variance to reach `threshold`.

    1D targets are treated as (B, 1); targets with more than 2 dims are
    flattened to (B, -1). The target is centered over dim 0 before its
    singular values are taken.

    The result is clamped to [1, number of singular values]: an
    unreachable threshold (>= 1.0, or cumulative variance ending just below
    1.0 from rounding) can no longer request more components than exist.
    The variance fractions are normalized by the exact total (no absolute
    epsilon), so small-magnitude targets are not penalized; a target with
    zero variance yields 1.
    """
    if target.dim() == 1:
        target = target.unsqueeze(1)
    elif target.dim() > 2:
        target = target.flatten(1)
    y_c = target - target.mean(0, keepdim=True)
    S = singular_values(y_c)
    var = S.pow(2)
    total = var.sum()
    if not total > 0:
        # Nothing to explain; one component is the minimum meaningful k.
        return 1
    cumvar = var.cumsum(0) / total
    n_needed = int((cumvar < threshold).sum().item()) + 1
    return max(1, min(n_needed, S.numel()))
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def _resolve_threshold(
|
|
24
|
+
name: str,
|
|
25
|
+
threshold: float | dict[str, float],
|
|
26
|
+
default: float,
|
|
27
|
+
) -> float:
|
|
28
|
+
if isinstance(threshold, dict):
|
|
29
|
+
return float(threshold.get(name, default))
|
|
30
|
+
return float(threshold)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def k_per_q(
    pooled_global: dict[str, torch.Tensor],
    threshold: float | dict[str, float] = 0.90,
    *,
    default_threshold: float | None = None,
    per_quantity: bool = True,
) -> int | dict[str, int]:
    """
    Estimates, for each pooled target, the smallest subspace dimension k
    whose leading principal components retain a requested fraction of the
    target's variance.

    #### USAGE:
        # Example 1: shared threshold, one k per quantity
        k = k_per_q(pooled_global, threshold=0.90)
        # k = {"scaled": 3, "grammian": 5}

        # Example 2: per-quantity thresholds, with a fallback for missing names
        k = k_per_q(
            pooled_global,
            threshold={"scaled": 0.95},
            default_threshold=0.90,
        )

        # Example 3: collapse to a single global k (the max across quantities)
        k = k_per_q(pooled_global, threshold=0.90, per_quantity=False)
        # k = 5

    #### HOW IT WORKS:
        1. Choose the fallback threshold: `default_threshold` if given,
           else the scalar `threshold`, else 0.90 (when `threshold` is a
           dict).
        2. For every (name, target) pair, resolve that quantity's
           threshold and let `_k_from_target` count the leading spectral
           components of the centered target needed to reach it.
        3. Return the whole {name: k} map, or only its maximum.

    #### ARGS:
        - pooled_global: {name: (B, n^2) or (B, n, n) ...} pooled physical
          targets, one per quantity. Targets with more than 2 dims are
          flattened to (B, n^2); 1D targets are treated as (B, 1).
        - threshold: A single float shared across all quantities, or a
          {name: threshold} dict. Each value is the fraction of cumulative
          variance to retain (e.g. 0.90 -> 90%). Defaults to 0.90.
        - default_threshold: Fallback for names missing from a `threshold`
          dict. If None, the scalar `threshold` is used (or 0.90 when
          `threshold` is itself a dict). Keyword-only.
        - per_quantity: True -> return the per-quantity dict; False ->
          return the largest k. Defaults to True (keyword-only).

    #### RETURNS:
        - per_quantity=True: dict {name: int} of estimated k per quantity.
        - per_quantity=False: a single int, the maximum k across
          quantities (raises ValueError from max() if pooled_global is
          empty).
    """
    if default_threshold is not None:
        fallback = default_threshold
    elif isinstance(threshold, dict):
        fallback = 0.90
    else:
        fallback = float(threshold)

    k_map: dict[str, int] = {}
    for name, target in pooled_global.items():
        eta = _resolve_threshold(name, threshold, fallback)
        k_map[name] = _k_from_target(target, eta)

    return k_map if per_quantity else max(k_map.values())
|
|
@@ -0,0 +1,119 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2026 Tobias Karusseit
|
|
3
|
+
This source code is licensed under the MIT license found in the
|
|
4
|
+
LICENSE file in the root directory of this source tree.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import math
|
|
8
|
+
import torch
|
|
9
|
+
from ..util import singular_values
|
|
10
|
+
|
|
11
|
+
def _resolve_threshold(
|
|
12
|
+
name: str,
|
|
13
|
+
threshold: float | dict[str, float],
|
|
14
|
+
default: float,
|
|
15
|
+
) -> float:
|
|
16
|
+
if isinstance(threshold, dict):
|
|
17
|
+
return float(threshold.get(name, default))
|
|
18
|
+
return float(threshold)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def optimal_bins(
    pooled_targets_coarse: dict[str, torch.Tensor],
    max_bins: int = 16,
    variance_threshold: float | dict[str, float] = 0.95,
    *,
    default_threshold: float | None = None,
    per_quantity: bool = False,
) -> int | dict[str, int]:
    """
    Finds the smallest n such that an n x n spatial pooling grid captures
    a target fraction of the spatial variance in each physical target,
    via PCA-style spectral analysis on a coarse pooling of that target.

    Inspired by Nyquist sampling: this finds the resolution at which
    adding more bins gives diminishing returns on spatial information
    captured, rather than picking a pooling resolution by hand.

    #### USAGE:
        # Example 1: shared threshold, single global n
        n = optimal_bins(pooled_targets_coarse, max_bins=16, variance_threshold=0.95)

        # Example 2: per-quantity thresholds and per-quantity result
        n = optimal_bins(
            pooled_targets_coarse,
            variance_threshold={"scaled": 0.99},
            default_threshold=0.95,
            per_quantity=True,
        )
        # n = {"scaled": 6, "grammian": 4}

    #### HOW IT WORKS:
        For each quantity name, Y in pooled_targets_coarse:
            eta      = resolved variance_threshold for this quantity
            Y_c      = Y - Y.mean(0)                 # center (flattened if > 2D)
            S        = singular_values(Y_c)          # spectrum of Y
            cumvar   = cumsum(S^2) / sum(S^2)        # cumulative variance explained
            n_modes  = leading modes needed for cumvar >= eta,
                       clamped to [1, number of singular values]
            n_bins_q = ceil(sqrt(n_modes)), clamped to [2, max_bins]
        A one-line summary per quantity is printed.
        If per_quantity is True, returns {name: n_bins_q} for all quantities.
        Otherwise, returns max(n_bins_q for all quantities), so every
        quantity is adequately sampled by a single shared resolution.

        The variance fractions are normalized by the exact total (no
        absolute epsilon), so small-magnitude targets are handled the same
        as large ones; a zero-variance target counts as 1 mode. Clamping
        n_modes prevents an unreachable eta (e.g. 1.0) from requesting one
        more mode than the spectrum has.

    #### ARGS:
        - pooled_targets_coarse: {name: (N, n_coarse^2)} pooled physical
          targets at some coarse resolution, used to estimate the spatial
          spectrum. Targets with more than 2 dims are flattened to
          (N, n_coarse^2); 1D targets are treated as (N, 1).
        - max_bins: Maximum grid side length n to consider. Defaults to
          16 (i.e. up to a 16x16 = 256 bin grid).
        - variance_threshold: Either a single float shared across all
          quantities, or a {name: threshold} dict for per-quantity
          thresholds. Each threshold is the target fraction of spatial
          variance to capture. Defaults to 0.95.
        - default_threshold: Fallback threshold for any quantity name
          missing from a `variance_threshold` dict. If None, falls back
          to the scalar `variance_threshold` value (or 0.95 if
          `variance_threshold` is itself a dict with no scalar to fall
          back to). Defaults to None (keyword-only).
        - per_quantity: If True, returns a dict of n_optimal per
          quantity. If False, returns the single largest n_optimal found
          across all quantities. Defaults to False (keyword-only).

    #### RETURNS:
        - If per_quantity is True: a dict {name: int} mapping each
          quantity to its optimal grid side length n.
        - If per_quantity is False: a single int, the largest n across
          all quantities, ensuring every quantity is adequately sampled.
    """
    fallback = default_threshold if default_threshold is not None else (
        float(variance_threshold) if not isinstance(variance_threshold, dict) else 0.95
    )
    n_optimal_per_quantity = {}

    for name, Y in pooled_targets_coarse.items():
        eta = _resolve_threshold(name, variance_threshold, fallback)
        if Y.dim() == 1:
            Y = Y.unsqueeze(1)
        elif Y.dim() > 2:
            Y = Y.flatten(1)
        Y_c = Y - Y.mean(0, keepdim=True)

        S = singular_values(Y_c)
        var = S.pow(2)
        total = var.sum()
        if total > 0:
            cumvar = var.cumsum(0) / total
            n_modes = int((cumvar < eta).sum().item()) + 1
            # Never request more modes than the spectrum contains.
            n_modes = max(1, min(n_modes, S.numel()))
        else:
            n_modes = 1

        n_bins_q = math.ceil(math.sqrt(n_modes))
        n_bins_q = max(2, min(n_bins_q, max_bins))

        n_optimal_per_quantity[name] = n_bins_q
        print(f" [{name}] {n_modes} modes (η={eta:.2f}) -> n_bins={n_bins_q}")

    if per_quantity:
        return n_optimal_per_quantity

    # take maximum so every quantity is adequately sampled
    return max(n_optimal_per_quantity.values())
|