spacelearn 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,59 @@
1
+ """
2
+ Copyright (c) 2026 Tobias Karusseit
3
+ This source code is licensed under the MIT license found in the
4
+ LICENSE file in the root directory of this source tree.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
def alignment_loss(
    Z: torch.Tensor,
    W: dict[str, torch.Tensor],
    A: dict[str, torch.Tensor],
    Y: dict[str, torch.Tensor],
) -> torch.Tensor:
    """
    Physics reconstruction loss: penalizes how poorly each subspace's
    decoded projection matches its pooled physical target.

    #### USAGE:
        loss = alignment_loss(Z, W, A, Y)
        loss.backward()

    #### HOW IT WORKS:
        For each quantity q in W:
            S_q   = Z_c @ W_q.T      # (B, k) latent coordinates in subspace q
            y_hat = S_q @ A_q        # (B, n^2) decoded prediction
            loss += MSE(y_hat, Y_q_centered)
        The per-quantity losses are averaged across all quantities in `W`.

    #### ARGS:
        - Z: (B, D) latent vectors.
        - W: {name: (k, D)} subspace bases.
        - A: {name: (k, n^2)} linear decoders (e.g. from `fit_decoders`).
          Must contain every name in `W`.
        - Y: {name: (B,), (B, n^2) or (B, n, n) ...} pooled physical
          targets. Must contain every name in `W`. Targets with more than
          2 dims are flattened to (B, n^2).
        W, A and Y entries are moved to Z's device and dtype before use.

    #### RETURNS:
        - A scalar torch.Tensor loss. Perfect reconstruction -> 0.
          Averaged across quantities in `W`; 0 if `W` is empty.
    """
    Z_c = Z - Z.mean(0, keepdim=True)
    names = list(W.keys())

    loss = torch.tensor(0.0, device=Z.device)
    for name in names:
        W_q = W[name].to(device=Z.device, dtype=Z.dtype)
        A_q = A[name].to(device=Z.device, dtype=Z.dtype)
        S_q = Z_c @ W_q.T  # (B, k)

        # Casting to Z's dtype also makes integer targets usable with mean().
        target = Y[name].to(device=Z.device, dtype=Z.dtype)
        if target.dim() > 2:
            target = target.flatten(1)
        target_c = target - target.mean(0, keepdim=True)

        y_hat = S_q @ A_q
        if target_c.shape != y_hat.shape and target_c.numel() == y_hat.numel():
            # e.g. a (B,) target with a (k, 1) decoder: compare element-wise
            # instead of letting mse_loss broadcast (B, 1) vs (B,) to (B, B).
            target_c = target_c.reshape(y_hat.shape)
        loss = loss + F.mse_loss(y_hat, target_c)

    return loss / max(len(names), 1)
@@ -0,0 +1,111 @@
1
+ """
2
+ Copyright (c) 2026 Tobias Karusseit
3
+ This source code is licensed under the MIT license found in the
4
+ LICENSE file in the root directory of this source tree.
5
+ """
6
+
7
+ import torch
8
+
9
def disentanglement_loss_original(
    Z: torch.Tensor,
    W: dict[str, torch.Tensor]
) -> torch.Tensor:
    """
    [LEGACY] Penalizes correlation between latent vectors projected into
    different subspaces, using mean squared correlation per pair.

    We recommend `disentanglement_loss` instead — see that function's
    docstring for the difference.

    #### USAGE:
        loss = disentanglement_loss_original(Z, W)
        loss.backward()

    #### HOW IT WORKS:
        For every pair of subspaces (i, j):
            S_i, S_j = Z_c @ W_i.T, Z_c @ W_j.T   # project into each subspace
            scale S_i, S_j to unit (unbiased) std  # no extra centering: Z_c is
                                                   # centered, so S_i, S_j are too
            corr = S_i_n.T @ S_j_n / B             # (k_i, k_j) correlation matrix
            total += mean(corr ** 2)               # squared correlation
        The sum is averaged over the number of subspace pairs.

    #### ARGS:
        - Z: (B, D) latent vectors.
        - W: {name: (k, D)} subspace bases.

    #### RETURNS:
        - A scalar torch.Tensor loss. Perfect disentanglement -> 0.
          Because the mean runs over all k_i * k_j entries, two identical
          subspaces with k uncorrelated coordinates give roughly 1/k (not 1).
          The unbiased std combined with `/ B` also scales correlations by
          (B - 1) / B. 0 when `W` has fewer than two entries.
    """
    Z_c = Z - Z.mean(0, keepdim=True)
    names = list(W.keys())

    # Project and rescale each subspace once rather than once per pair;
    # the arithmetic per pair is unchanged.
    scaled = []
    for name in names:
        S = Z_c @ W[name].to(Z.device).T  # (B, k)
        scaled.append(S / (S.std(dim=0, keepdim=True) + 1e-8))

    total = torch.tensor(0.0, device=Z.device)
    for i in range(len(scaled)):
        for j in range(i + 1, len(scaled)):
            corr = scaled[i].T @ scaled[j] / Z.shape[0]  # (k_i, k_j)
            total = total + corr.pow(2).mean()  # squares emphasize large correlations

    n_pairs = len(names) * (len(names) - 1) / 2
    return total / max(n_pairs, 1)
55
+
56
def disentanglement_loss(
    Z: torch.Tensor,
    W: dict[str, torch.Tensor]
) -> torch.Tensor:
    """
    Penalizes correlation between latent vectors projected into different
    subspaces, using the normalized squared Frobenius norm (trace of
    Corr.T @ Corr) of the cross-correlation matrix per pair.

    #### USAGE:
        loss = disentanglement_loss(Z, W)
        loss.backward()

    #### HOW IT WORKS:
        For every pair of subspaces (i, j):
            S_i, S_j = Z_c @ W_i.T, Z_c @ W_j.T   # project into each subspace
            standardize S_i, S_j (center, divide by population std)
            Corr = S_i_std.T @ S_j_std / B        # (k_i, k_j) Pearson correlations
            total += trace(Corr.T @ Corr) / min(k_i, k_j)
        The sum is averaged over the number of subspace pairs.

        The population std (divide by B) matches the `/ B` in Corr, so Corr
        entries are exact Pearson correlations. Normalizing by
        min(k_i, k_j) keeps the result independent of the order of `W`.

    #### ARGS:
        - Z: (B, D) latent vectors.
        - W: {name: (k, D)} subspace bases.

    #### RETURNS:
        - A scalar torch.Tensor loss. Perfect disentanglement -> 0. When one
          subspace spans the other and its coordinates are mutually
          uncorrelated (e.g. identical subspaces) -> 1. 0 when `W` has fewer
          than two entries.
    """
    Z_c = Z - Z.mean(0, keepdim=True)
    B = Z.shape[0]

    # Standardize each subspace's projection once; it is reused in every pair.
    standardized = []
    for name in W:
        S = Z_c @ W[name].to(Z.device).T  # (B, k)
        S_c = S - S.mean(0, keepdim=True)
        std = S_c.std(0, unbiased=False, keepdim=True)  # population std, matches / B
        standardized.append(S_c / (std + 1e-8))

    total = torch.tensor(0.0, device=Z.device)
    pairs = 0
    for i in range(len(standardized)):
        for j in range(i + 1, len(standardized)):
            Si, Sj = standardized[i], standardized[j]
            Corr = Si.T @ Sj / B  # (k_i, k_j)
            # ||Corr||_F^2 is at most min(k_i, k_j) when the coordinates within
            # each subspace are mutually uncorrelated, so this is 1 at full overlap.
            total = total + torch.trace(Corr.T @ Corr) / min(Si.shape[1], Sj.shape[1])
            pairs += 1

    return total / max(pairs, 1)
@@ -0,0 +1,56 @@
1
+ """
2
+ Copyright (c) 2026 Tobias Karusseit
3
+ This source code is licensed under the MIT license found in the
4
+ LICENSE file in the root directory of this source tree.
5
+ """
6
+
7
+ import torch
8
+ from ..util import singular_values
9
+
10
def isotropy_loss(
    Z: torch.Tensor,
    W: dict[str, torch.Tensor],
    n_dir: int = 32
) -> torch.Tensor:
    """
    Enforces isotropy in the latent residual space by penalizing variance
    among its leading singular values.

    #### USAGE:
        loss = isotropy_loss(Z, W, n_dir=32)
        loss.backward()

    #### HOW IT WORKS:
        Z_c = Z - Z.mean(0)                     # center latents
        W_all = concat(W.values())              # (sum_k, D) stacked bases
        Z_resid = Z_c - Z_c @ W_all.T @ W_all   # residual after projecting out
                                                # all known subspaces (exact only
                                                # if the rows of W_all are
                                                # orthonormal)
        S = singular_values(Z_resid)            # singular values of residual
        S_top = S[:n_dir]                       # first n_dir values (assumes
                                                # descending order — TODO confirm
                                                # for `singular_values`)
        S_norm = S_top / S_top.mean()           # normalize to mean 1
        loss = mean((S_norm - 1) ** 2)          # penalize deviation from a
                                                # uniform spread

    #### ARGS:
        - Z: (B, D) latent vectors.
        - W: {name: (k, D)} subspace bases already accounted for (e.g. by
          alignment/disentanglement losses). These are projected out of
          `Z` before computing isotropy on what remains. If empty, the
          centered latents are used directly.
        - n_dir: Number of leading singular values of the residual to
          consider. Must be >= 1. Defaults to 32.

    #### RETURNS:
        - A scalar torch.Tensor loss. Perfectly isotropic residual
          directions (uniform singular value spectrum) -> 0.

    #### RAISES:
        - ValueError: if n_dir < 1 (the mean over no values would be NaN).
    """
    if n_dir < 1:
        raise ValueError(f"n_dir must be >= 1, got {n_dir}")

    z_c = Z - Z.mean(0, keepdim=True)  # centered for variance calc
    if W:
        W_all = torch.cat(list(W.values()), dim=0).to(device=Z.device, dtype=Z.dtype)
        # Approximate residual: W_all.T @ W_all is a projector only when the
        # stacked subspace bases are orthonormal.
        z_resid = z_c - z_c @ W_all.T @ W_all
    else:
        # torch.cat([]) would raise; with no subspaces the residual is Z_c itself.
        z_resid = z_c

    S_sv = singular_values(z_resid)
    S_top = S_sv[:n_dir]
    S_norm = S_top / (S_top.mean() + 1e-8)
    return ((S_norm - 1.0) ** 2).mean()  # punish uneven singular values
@@ -0,0 +1,44 @@
1
+ """
2
+ Copyright (c) 2026 Tobias Karusseit
3
+ This source code is licensed under the MIT license found in the
4
+ LICENSE file in the root directory of this source tree.
5
+ """
6
+
7
+ import torch
8
+
9
+ def stability_loss(
10
+ W: dict[str, torch.Tensor],
11
+ W_prev: dict[str, torch.Tensor]
12
+ ) -> torch.Tensor:
13
+ """
14
+ Punishes large or rapid changes in subspace geometry between training
15
+ steps, encouraging the model to evolve its latent subspace structure
16
+ gradually rather than jumping between geometries.
17
+
18
+ #### USAGE:
19
+ loss = stability_loss(W, W_prev)
20
+ loss.backward()
21
+
22
+ #### HOW IT WORKS:
23
+ If W_prev is provided:
24
+ for each quantity name in W:
25
+ loss += sum((W[name] - W_prev[name]) ** 2) # squared drift
26
+ loss /= number of quantities
27
+ Otherwise, the loss is 0 (e.g. on the first step, when there is no
28
+ previous subspace to compare against).
29
+
30
+ #### ARGS:
31
+ - W: {name: (D, k)} current subspace bases.
32
+ - W_prev: {name: (D, k)} subspace bases from the previous step, or
33
+ None if there is no previous step to compare against.
34
+
35
+ #### RETURNS:
36
+ - A scalar torch.Tensor loss. No drift -> 0; larger or more
37
+ rapid changes in subspace geometry -> larger values.
38
+ """
39
+ loss_stab = torch.tensor(0.0, device=W[next(iter(W.keys()))].device)
40
+ if W_prev is not None:
41
+ for name in W.keys():
42
+ loss_stab = loss_stab + (W[name] - W_prev[name]).pow(2).sum()
43
+ loss_stab = loss_stab / len(W.keys())
44
+ return loss_stab
@@ -0,0 +1,19 @@
1
+ """
2
+ Copyright (c) 2026 Tobias Karusseit
3
+ This source code is licensed under the MIT license found in the
4
+ LICENSE file in the root directory of this source tree.
5
+ """
6
+
7
+ from ._optimize_targets import optimal_bins
8
+ from ._optimize_latent import minimal_latent_dim
9
+ from ._fit_decoder import fit_decoders
10
+ from ._subspace_change import subspace_change
11
+ from ._optimize_subspaces import k_per_q
12
+
13
# Public API of this subpackage. Each name is re-exported from one of the
# private modules imported above; a star-import of this package exposes
# exactly these names.
__all__ = [
    "optimal_bins",
    "minimal_latent_dim",
    "fit_decoders",
    "subspace_change",
    "k_per_q",
]
@@ -0,0 +1,56 @@
1
+ """
2
+ Copyright (c) 2026 Tobias Karusseit
3
+ This source code is licensed under the MIT license found in the
4
+ LICENSE file in the root directory of this source tree.
5
+ """
6
+
7
+ import torch
8
+
9
def fit_decoders(W: dict, Z: torch.Tensor, pooled_global: dict[str, torch.Tensor]) -> dict:
    """
    Learns a linear mapping (decoder) from each latent subspace's
    coordinates to its corresponding physical target space, via
    least-squares regression.

    #### USAGE:
        A = fit_decoders(W, Z, pooled_global)
        # A can then be passed directly to alignment_loss as the
        # 'A' (decoder) argument

    #### HOW IT WORKS:
        Z_c = Z - Z.mean(0)                    # center latents
        for each quantity name in W:
            S_q = Z_c @ W_q.T                  # project latents into subspace q
            Y_q = pooled_global[name]          # physical target, flattened if needed
            Y_q_c = Y_q - Y_q.mean(0)          # center target
            A_q = lstsq(S_q, Y_q_c).solution   # minimizes ||Y_q_c - S_q @ A_q||_F
        Everything runs under torch.no_grad() and each A_q is detached,
        since it is fit via least squares rather than learned through
        backpropagation.

    #### ARGS:
        - W: {name: (k, D)} subspace basis matrices.
        - Z: (B, D) latent vectors.
        - pooled_global: {name: (B, n^2) or (B, n, n) ...} pooled
          physical targets; must contain every name in W. Targets with
          more than 2 dims are flattened to (B, n^2) before fitting.
        W and target entries are moved to Z's device and dtype.

    #### RETURNS:
        - A dict {name: (k, n^2)} of optimal linear decoder matrices, one
          per quantity in W, each detached from the autograd graph.
    """
    A = {}
    # The result is detached anyway, so don't record a graph for the solve.
    with torch.no_grad():
        Z_c = Z - Z.mean(0, keepdim=True)  # center
        for name, W_q in W.items():
            S_q = Z_c @ W_q.to(device=Z.device, dtype=Z.dtype).T  # (B, k) subspace coords
            # Casting also lets integer targets go through mean() below.
            y = pooled_global[name].to(device=Z.device, dtype=Z.dtype)
            if y.dim() > 2:
                y = y.flatten(1)
            y_c = y - y.mean(0, keepdim=True)  # center
            # Solve for A such that S_q @ A approximates y_c in the least-squares sense.
            A[name] = torch.linalg.lstsq(S_q, y_c).solution.detach()
    return A
@@ -0,0 +1,92 @@
1
+ """
2
+ Copyright (c) 2026 Tobias Karusseit
3
+ This source code is licensed under the MIT license found in the
4
+ LICENSE file in the root directory of this source tree.
5
+ """
6
+
7
+ import math
8
+
9
+
10
+ def _align_latent_dim(dim: int, num_heads: int) -> int:
11
+ """Round dim up to the next multiple of num_heads (unchanged if already divisible)."""
12
+ if dim < 1:
13
+ raise ValueError(f"dim must be >= 1, got {dim}")
14
+ if num_heads < 1:
15
+ raise ValueError(f"num_heads must be >= 1, got {num_heads}")
16
+ if num_heads == 1:
17
+ return dim
18
+ remainder = dim % num_heads
19
+ if remainder == 0:
20
+ return dim
21
+ return dim + (num_heads - remainder)
22
+
23
+
24
+ def minimal_latent_dim(
25
+ k: int | dict[str, int],
26
+ *,
27
+ Q: int | None = None,
28
+ free_frac: float = 0.30,
29
+ num_heads: int | None = None,
30
+ ) -> int:
31
+ """
32
+ Computes the minimum latent dimension needed to fit a set of physics
33
+ subspaces plus a fraction of free (unconstrained) capacity, optionally
34
+ aligned to a multiple of num_heads.
35
+
36
+ #### USAGE:
37
+ # Example 1: shared k across all quantities
38
+ D = minimal_latent_dim(k=4, Q=6, free_frac=0.30)
39
+
40
+ # Example 2: per-quantity k
41
+ D = minimal_latent_dim(k={"scaled": 4, "grammian": 6}, free_frac=0.30)
42
+
43
+ # Example 3: align result to attention head count
44
+ D = minimal_latent_dim(k=4, Q=6, num_heads=8)
45
+
46
+ #### HOW IT WORKS:
47
+ physics_dims = Q * k # if k is a shared scalar
48
+ physics_dims = sum(k.values()) # if k is per-quantity
49
+ D_min = ceil(physics_dims / (1 - free_frac))
50
+ if num_heads is set:
51
+ D_min = next multiple of num_heads >= D_min
52
+
53
+ #### ARGS:
54
+ - k: Subspace dimensionality. Either a single int shared across all
55
+ `Q` quantities, or a dict mapping each quantity name to its
56
+ own k.
57
+ - Q: Number of quantities. Required (and only used) when `k` is a
58
+ scalar; ignored when `k` is a dict, since the count is
59
+ inferred from its keys.
60
+ - free_frac: Fraction of the latent dimension to reserve as free,
61
+ unconstrained capacity beyond the physics subspaces.
62
+ Must be in [0, 1). Defaults to 0.30.
63
+ - num_heads: If set, the result is rounded up to the next multiple
64
+ of num_heads (e.g. to match an attention head count).
65
+ Defaults to None (no alignment).
66
+
67
+ #### RETURNS:
68
+ - The minimum latent dimension D_min satisfying the physics
69
+ subspace requirement, the free-capacity fraction, and (if
70
+ given) the num_heads alignment constraint.
71
+ """
72
+ if not 0.0 <= free_frac < 1.0:
73
+ raise ValueError(f"free_frac must be in [0, 1), got {free_frac}")
74
+
75
+ if isinstance(k, dict):
76
+ physics_dims = sum(k.values())
77
+ detail = ", ".join(f"{name}={k_q}" for name, k_q in sorted(k.items()))
78
+ # print(f"[latent-sizing] k per quantity: {detail} (sum={physics_dims})")
79
+ else:
80
+ if Q is None:
81
+ raise ValueError("Q is required when k is a scalar")
82
+ physics_dims = Q * k
83
+ # print(f"[latent-sizing] Q={Q}, k={k}")
84
+
85
+ D_min = math.ceil(physics_dims / (1.0 - free_frac))
86
+ print(f"[latent-sizing] free={free_frac * 100:.0f}% -> D_min={D_min}")
87
+ if num_heads is not None:
88
+ aligned = _align_latent_dim(D_min, num_heads)
89
+ # if aligned != D_min:
90
+ # print(f"[latent-sizing] aligned D {D_min} -> {aligned} (num_heads={num_heads})")
91
+ D_min = aligned
92
+ return D_min
@@ -0,0 +1,109 @@
1
+ """
2
+ Copyright (c) 2026 Tobias Karusseit
3
+ This source code is licensed under the MIT license found in the
4
+ LICENSE file in the root directory of this source tree.
5
+ """
6
+
7
+ import torch
8
+ from ..util import singular_values
9
+
10
def _k_from_target(target: torch.Tensor, threshold: float) -> int:
    """
    Return how many leading principal components of ``target`` are needed
    for their cumulative explained variance to reach ``threshold``.

    1D targets are treated as (B, 1); targets with more than 2 dims are
    flattened to (B, -1). Integer targets are converted to float. The
    result is clamped to [1, number of singular values], and a target
    with zero total variance yields 1.

    Assumes ``singular_values`` returns a 1D tensor in descending order
    (TODO confirm against ..util).
    """
    if target.dim() == 1:
        target = target.unsqueeze(1)
    elif target.dim() > 2:
        target = target.flatten(1)
    if not target.is_floating_point():
        target = target.float()  # mean() is undefined for integer tensors
    y_c = target - target.mean(0, keepdim=True)

    var = singular_values(y_c).pow(2)
    total = var.sum()
    if var.numel() == 0 or total.item() <= 0.0:
        # Degenerate (constant) target: one component is as good as any.
        return 1

    # Divide by the exact total (no epsilon) so small-scale targets are not biased.
    cumvar = var.cumsum(0) / total
    k = int((cumvar < threshold).sum().item()) + 1
    # cumvar[-1] can round slightly below 1 (or threshold may be >= 1), which
    # would otherwise ask for more components than exist.
    return min(k, var.numel())
21
+
22
+
23
+ def _resolve_threshold(
24
+ name: str,
25
+ threshold: float | dict[str, float],
26
+ default: float,
27
+ ) -> float:
28
+ if isinstance(threshold, dict):
29
+ return float(threshold.get(name, default))
30
+ return float(threshold)
31
+
32
+
33
def k_per_q(
    pooled_global: dict[str, torch.Tensor],
    threshold: float | dict[str, float] = 0.90,
    *,
    default_threshold: float | None = None,
    per_quantity: bool = True,
) -> int | dict[str, int]:
    """
    Estimate, for each quantity, the smallest subspace dimension k whose
    leading principal components retain a target fraction of the variance
    of that quantity's pooled target (PCA-style spectral analysis).

    #### USAGE:
        # Example 1: shared threshold, one k per quantity
        k = k_per_q(pooled_global, threshold=0.90)
        # k = {"scaled": 3, "grammian": 5}

        # Example 2: per-quantity thresholds, with a fallback for missing names
        k = k_per_q(
            pooled_global,
            threshold={"scaled": 0.95},
            default_threshold=0.90,
        )

        # Example 3: collapse to a single global k (the max across quantities)
        k = k_per_q(pooled_global, threshold=0.90, per_quantity=False)
        # k = 5

    #### HOW IT WORKS:
        For each (name, target) in pooled_global:
            target_c = target - target.mean(0)   # centered (flattened if > 2D)
            S = singular_values(target_c)        # spectrum of the target
            spec = S^2 / sum(S^2)                # variance share per component
            k_q = number of leading components whose cumulative share
                  reaches the quantity's threshold
        Returns {name: k_q} if per_quantity, else max over all k_q.

    #### ARGS:
        - pooled_global: {name: (B, n^2) or (B, n, n) ...} pooled physical
          targets, one per quantity. Targets with more than 2 dims are
          flattened to (B, n^2); 1D targets are treated as (B, 1).
        - threshold: A float shared by all quantities, or a {name: threshold}
          dict. Each value is the cumulative-variance fraction to retain
          (e.g. 0.90 keeps enough components for 90% of the variance).
          Defaults to 0.90.
        - default_threshold: Fallback for names missing from a `threshold`
          dict. If None, the scalar `threshold` is used, or 0.90 when
          `threshold` is a dict. Keyword-only; defaults to None.
        - per_quantity: If True, return a dict of k per quantity; if False,
          return the largest k across quantities. Keyword-only; defaults
          to True.

    #### RETURNS:
        - per_quantity=True: {name: int} estimated subspace dimension per
          quantity.
        - per_quantity=False: a single int, the maximum k (raises ValueError
          from max() if pooled_global is empty).
    """
    if default_threshold is not None:
        fallback = default_threshold
    elif isinstance(threshold, dict):
        fallback = 0.90  # no scalar threshold available to fall back to
    else:
        fallback = float(threshold)

    k_map: dict[str, int] = {}
    for name, target in pooled_global.items():
        eta = _resolve_threshold(name, threshold, fallback)
        k_map[name] = _k_from_target(target, eta)

    return k_map if per_quantity else max(k_map.values())
@@ -0,0 +1,119 @@
1
+ """
2
+ Copyright (c) 2026 Tobias Karusseit
3
+ This source code is licensed under the MIT license found in the
4
+ LICENSE file in the root directory of this source tree.
5
+ """
6
+
7
+ import math
8
+ import torch
9
+ from ..util import singular_values
10
+
11
+ def _resolve_threshold(
12
+ name: str,
13
+ threshold: float | dict[str, float],
14
+ default: float,
15
+ ) -> float:
16
+ if isinstance(threshold, dict):
17
+ return float(threshold.get(name, default))
18
+ return float(threshold)
19
+
20
+
21
def optimal_bins(
    pooled_targets_coarse: dict[str, torch.Tensor],
    max_bins: int = 16,
    variance_threshold: float | dict[str, float] = 0.95,
    *,
    default_threshold: float | None = None,
    per_quantity: bool = False,
    verbose: bool = True,
) -> int | dict[str, int]:
    """
    Finds the smallest n such that an n x n spatial pooling grid captures
    a target fraction of the spatial variance in each physical target,
    via PCA-style spectral analysis on a coarse pooling of that target.

    Inspired by Nyquist sampling: this finds the resolution at which
    adding more bins gives diminishing returns on spatial information
    captured, rather than picking a pooling resolution by hand.

    #### USAGE:
        # Example 1: shared threshold, single global n
        n = optimal_bins(pooled_targets_coarse, max_bins=16, variance_threshold=0.95)

        # Example 2: per-quantity thresholds and per-quantity result
        n = optimal_bins(
            pooled_targets_coarse,
            variance_threshold={"scaled": 0.99},
            default_threshold=0.95,
            per_quantity=True,
        )
        # n = {"scaled": 6, "grammian": 4}

    #### HOW IT WORKS:
        For each quantity name, Y in pooled_targets_coarse:
            eta = resolved variance_threshold for this quantity
            Y_c = Y - Y.mean(0)                  # center (flattened if > 2D)
            S = singular_values(Y_c)             # spectrum of Y
            cumvar = cumsum(S^2) / sum(S^2)      # cumulative variance explained
            n_modes = number of leading modes needed for cumvar >= eta,
                      clamped to [1, len(S)] (1 for a zero-variance target)
            n_bins_q = ceil(sqrt(n_modes))       # modes -> grid side length,
                                                 # clamped to [2, max_bins]
        If per_quantity is True, returns {name: n_bins_q} for all quantities.
        Otherwise, returns max(n_bins_q for all quantities), so every
        quantity is adequately sampled by a single shared resolution.

    #### ARGS:
        - pooled_targets_coarse: {name: (N, n_coarse^2)} pooled physical
          targets at some coarse resolution, used to estimate the spatial
          spectrum. Targets with more than 2 dims are flattened to
          (N, n_coarse^2); 1D targets are treated as (N, 1).
        - max_bins: Maximum grid side length n to consider. Defaults to
          16 (i.e. up to a 16x16 = 256 bin grid). Results are never below
          2, even if max_bins < 2.
        - variance_threshold: Either a single float shared across all
          quantities, or a {name: threshold} dict for per-quantity
          thresholds. Each threshold is the target fraction of spatial
          variance to capture. Defaults to 0.95.
        - default_threshold: Fallback threshold for any quantity name
          missing from a `variance_threshold` dict. If None, falls back
          to the scalar `variance_threshold` value (or 0.95 if
          `variance_threshold` is itself a dict). Defaults to None
          (keyword-only).
        - per_quantity: If True, returns a dict of n_optimal per
          quantity. If False, returns the single largest n_optimal found
          across all quantities. Defaults to False (keyword-only).
        - verbose: If True, print one summary line per quantity.
          Defaults to True (keyword-only).

    #### RETURNS:
        - If per_quantity is True: a dict {name: int} mapping each
          quantity to its optimal grid side length n.
        - If per_quantity is False: a single int, the largest n across
          all quantities, ensuring every quantity is adequately sampled.

    #### RAISES:
        - ValueError: if per_quantity is False and pooled_targets_coarse
          is empty (there is no maximum to return).
    """
    fallback = default_threshold if default_threshold is not None else (
        float(variance_threshold) if not isinstance(variance_threshold, dict) else 0.95
    )
    n_optimal_per_quantity: dict[str, int] = {}

    for name, Y in pooled_targets_coarse.items():
        eta = _resolve_threshold(name, variance_threshold, fallback)
        if Y.dim() == 1:
            Y = Y.unsqueeze(1)
        elif Y.dim() > 2:
            Y = Y.flatten(1)
        if not Y.is_floating_point():
            Y = Y.float()  # mean() is undefined for integer tensors
        Y_c = Y - Y.mean(0, keepdim=True)

        var = singular_values(Y_c).pow(2)
        total = var.sum()
        if var.numel() == 0 or total.item() <= 0.0:
            n_modes = 1  # constant target: no spatial structure to resolve
        else:
            # Exact normalization (no epsilon) so small-scale targets are not biased;
            # clamp because rounding can leave cumvar[-1] just below eta.
            cumvar = var.cumsum(0) / total
            n_modes = min(int((cumvar < eta).sum().item()) + 1, var.numel())

        n_bins_q = math.ceil(math.sqrt(n_modes))
        n_bins_q = max(2, min(n_bins_q, max_bins))

        n_optimal_per_quantity[name] = n_bins_q
        if verbose:
            print(f" [{name}] {n_modes} modes (η={eta:.2f}) -> n_bins={n_bins_q}")

    if per_quantity:
        return n_optimal_per_quantity

    if not n_optimal_per_quantity:
        raise ValueError("pooled_targets_coarse is empty; cannot pick a shared n_bins")
    # take maximum so every quantity is adequately sampled
    return max(n_optimal_per_quantity.values())