spacelearn 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacelearn/__init__.py +16 -0
- spacelearn/_combined.py +261 -0
- spacelearn/_find_subspaces.py +101 -0
- spacelearn/data/__init__.py +11 -0
- spacelearn/data/_compute_Q.py +70 -0
- spacelearn/data/_pipeline.py +55 -0
- spacelearn/data/_pool.py +35 -0
- spacelearn/data/_utils.py +115 -0
- spacelearn/loss/__init__.py +17 -0
- spacelearn/loss/_alignment_loss.py +59 -0
- spacelearn/loss/_disentanglement_loss.py +111 -0
- spacelearn/loss/_isotropy_loss.py +56 -0
- spacelearn/loss/_stability_loss.py +44 -0
- spacelearn/optim/__init__.py +19 -0
- spacelearn/optim/_fit_decoder.py +56 -0
- spacelearn/optim/_optimize_latent.py +92 -0
- spacelearn/optim/_optimize_subspaces.py +109 -0
- spacelearn/optim/_optimize_targets.py +119 -0
- spacelearn/optim/_subspace_change.py +52 -0
- spacelearn/settings.py +20 -0
- spacelearn/util/__init__.py +9 -0
- spacelearn/util/_eigen_fallback.py +126 -0
- spacelearn/util/_lowrank_fallback.py +25 -0
- spacelearn/util/_svd_fallback.py +24 -0
- spacelearn/util/svd_util.py +60 -0
- spacelearn-0.1.0.dist-info/METADATA +231 -0
- spacelearn-0.1.0.dist-info/RECORD +29 -0
- spacelearn-0.1.0.dist-info/WHEEL +4 -0
- spacelearn-0.1.0.dist-info/licenses/LICENSE +21 -0
spacelearn/__init__.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2026 Tobias Karusseit
|
|
3
|
+
This source code is licensed under the MIT license found in the
|
|
4
|
+
LICENSE file in the root directory of this source tree.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from ._find_subspaces import subspaces, principal_directions
|
|
8
|
+
from ._combined import combined_loss, solve_subspaces, solve_dims
|
|
9
|
+
|
|
10
|
+
__all__ = [
|
|
11
|
+
"principal_directions",
|
|
12
|
+
"subspaces",
|
|
13
|
+
"combined_loss",
|
|
14
|
+
"solve_subspaces",
|
|
15
|
+
"solve_dims",
|
|
16
|
+
]
|
spacelearn/_combined.py
ADDED
|
@@ -0,0 +1,261 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2026 Tobias Karusseit
|
|
3
|
+
This source code is licensed under the MIT license found in the
|
|
4
|
+
LICENSE file in the root directory of this source tree.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from ._find_subspaces import principal_directions, subspaces
|
|
10
|
+
from .loss import alignment_loss, disentanglement_loss, isotropy_loss, stability_loss
|
|
11
|
+
from .optim import fit_decoders, optimal_bins, k_per_q
|
|
12
|
+
from .data import pool
|
|
13
|
+
|
|
14
|
+
@dataclass
class Space:
    """
    Read-only, mapping-like bundle of three dictionaries that share one key set.

    Indexing with a key returns the triple ``(W[key], A[key], V[key])``.
    Iteration, ``keys()``, ``values()`` and ``items()`` follow the key order
    of ``W`` captured at construction time.

    #### RAISES:
    - ValueError: if W, A and V do not have identical key sets.
    """

    # In this file, solve_subspaces() fills W with the output of subspaces(),
    # A with the output of fit_decoders(), and V with the second element of
    # each principal_directions() pair.
    W: dict[str, torch.Tensor]
    A: dict[str, torch.Tensor]
    V: dict[str, torch.Tensor]
    # snapshot of the key order, filled in by __post_init__
    _names: list[str] = field(init=False)

    def __post_init__(self):
        w_keys, a_keys, v_keys = self.W.keys(), self.A.keys(), self.V.keys()
        keys_match = w_keys == a_keys and a_keys == v_keys
        if not keys_match:
            raise ValueError("W, A, and V must contain identical keys")

        self._names = [*self.W.keys()]

    def __getitem__(self, key: str) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # order of the triple is (W, A, V)
        w_q = self.W[key]
        a_q = self.A[key]
        v_q = self.V[key]
        return w_q, a_q, v_q

    def keys(self):
        # returns the internal list itself (not a copy, not a dict view)
        return self._names

    def values(self):
        triples = []
        for name in self._names:
            triples.append((self.W[name], self.A[name], self.V[name]))
        return triples

    def __len__(self):
        return len(self._names)

    def items(self):
        pairs = []
        for name in self._names:
            pairs.append((name, (self.W[name], self.A[name], self.V[name])))
        return pairs

    def __iter__(self):
        return iter(self._names)

    def __contains__(self, key):
        return key in self._names
|
|
51
|
+
|
|
52
|
+
def combined_loss(
    Z: torch.Tensor,
    W: dict[str, torch.Tensor],
    A: dict[str, torch.Tensor],
    Y: dict[str, torch.Tensor],
    W_prev: dict[str, torch.Tensor],
    *,
    n_dir: int = 32
) -> torch.Tensor:
    """
    Sums the four SPACE loss terms (alignment, disentanglement, isotropy,
    stability) into one training objective.

    #### USAGE:
    loss = combined_loss(Z, W, A, Y, W_prev)

    # forward a different direction count to isotropy_loss
    loss = combined_loss(Z, W, A, Y, W_prev, n_dir=64)

    #### HOW IT WORKS:
    The terms are evaluated in this order and added without weighting:
        alignment_loss(Z, W, A, Y)
        disentanglement_loss(Z, W)
        isotropy_loss(Z, W, n_dir=n_dir)
        stability_loss(W, W_prev)

    #### ARGS:
    - Z: Latent representations; forwarded to the alignment,
        disentanglement and isotropy terms.
    - W: Dictionary of per-quantity subspace bases; forwarded to all
        four terms.
    - A: Dictionary of per-quantity decoder matrices; used by
        alignment_loss only.
    - Y: Dictionary of per-quantity targets; used by alignment_loss only.
    - W_prev: Dictionary of bases from the previous optimization step;
        used by stability_loss only.
    - n_dir: Forwarded to isotropy_loss as `n_dir`. Defaults to 32.

    #### RETURNS:
    - The sum of the four loss terms.
    """
    # Each term is evaluated in the original left-to-right order.
    align_term = alignment_loss(Z, W, A, Y)
    disentangle_term = disentanglement_loss(Z, W)
    isotropy_term = isotropy_loss(Z, W, n_dir=n_dir)
    stability_term = stability_loss(W, W_prev)

    # left-associated, out-of-place additions
    total = align_term + disentangle_term
    total = total + isotropy_term
    total = total + stability_term
    return total
|
|
102
|
+
|
|
103
|
+
def solve_subspaces(
    Z: torch.Tensor,
    Y: dict[str, torch.Tensor],
    *,
    k: int | dict[str, int] = 32,
) -> "Space":
    """
    Runs the complete subspace-discovery pipeline for a collection of
    physical quantities and bundles the results in a `Space`.

    #### USAGE:
    space = solve_subspaces(Z, Y, k=32)

    # NOTE: the per-quantity triple is ordered (W, A, V)
    W_q, A_q, V_q = space["scaled"]

    #### HOW IT WORKS:
    1. principal_directions(Z, Y) computes a rank-1 (w, v) pair for each
        quantity
    2. subspaces(Z, Y, k=k) builds one basis matrix per quantity
    3. fit_decoders(W, Z, Y) fits one decoder per quantity
    4. Results are combined into a `Space`, whose entries are
        {name: (W_q, A_q, V_q)}

    #### ARGS:
    - Z: Latent representations with shape (B, D).
    - Y: Dictionary of pooled physical targets, keyed by quantity name.
    - k: Number of basis directions per subspace; either one int for all
        quantities or a {name: k_q} dictionary. Defaults to 32.

    #### RETURNS:
    - A `Space` (dict-like) mapping each quantity name to:
        (W_q, A_q, V_q)

        where:
        W_q = subspace basis returned by subspaces()
        A_q = decoder returned by fit_decoders()
        V_q = second element of the principal_directions() pair, i.e. the
              direction `v` that function documents as living in target
              space (n^2,)
    """
    directions = principal_directions(Z, Y)
    W = subspaces(Z, Y, k=k)
    A = fit_decoders(W, Z, Y)

    # keep only the second element of each (w, v) pair; distinct loop names
    # avoid shadowing the `k` parameter
    V = {name: pair[1] for name, pair in directions.items()}
    return Space(W, A, V=V)
|
|
151
|
+
|
|
152
|
+
def solve_dims(
    Y: dict[str, torch.Tensor],
    max_bins: int,
    bin_thresholds: float | dict[str, float],
    k_thresholds: float | dict[str, float],
    *,
    bin_default_threshold: float | None = None,
    k_default_threshold: float | None = None,
    pooling_mode: str = "avg",
    pool_per_quantity: bool = False,
    k_per_quantity: bool = False
) -> dict[str, tuple[int, int]]:
    """
    Computes recommended pooling resolutions and subspace dimensions for
    a collection of physical targets.

    This combines optimal_bins() and k_per_q() into a single sizing
    pipeline.

    #### USAGE:
    dims = solve_dims(
        Y,
        max_bins=16,
        bin_thresholds=0.95,
        k_thresholds=0.90
    )

    # per-quantity sizing
    dims = solve_dims(
        Y,
        max_bins=16,
        bin_thresholds=0.95,
        k_thresholds=0.90,
        pool_per_quantity=True,
        k_per_quantity=True
    )

    #### HOW IT WORKS:
    1. optimal_bins() selects the pooling resolution(s)
    2. Every target is pooled with pool() at its selected resolution
    3. k_per_q() selects the subspace dimension(s) from the pooled targets
    4. Both results are combined into per-quantity (n_bins, k) pairs

    #### ARGS:
    - Y: Dictionary of physical target tensors.
    - max_bins: Forwarded to optimal_bins() as the upper bound on the
        pooling resolution.
    - bin_thresholds: Threshold(s) forwarded to optimal_bins().
    - k_thresholds: Threshold(s) forwarded to k_per_q().
    - bin_default_threshold: Forwarded to optimal_bins() as
        `default_threshold`.
    - k_default_threshold: Forwarded to k_per_q() as `default_threshold`.
    - pooling_mode: Pooling mode passed to pool(). Defaults to "avg".
    - pool_per_quantity: Whether the pooling resolution is selected
        independently for each quantity.
    - k_per_quantity: Whether the subspace dimensionality is selected
        independently for each quantity.

    #### RETURNS:
    - Always a dict:
        {name: (n_bins, k)}

        with one entry per key of Y. When a value is selected globally
        (pool_per_quantity / k_per_quantity is False), the same global
        value is repeated for every quantity.
    """
    # optimal bin resolution: indexed as a {name: n_q} dict when
    # pool_per_quantity is True, otherwise used as one global value
    n = optimal_bins(
        Y, max_bins, bin_thresholds,
        default_threshold=bin_default_threshold,
        per_quantity=pool_per_quantity
    )

    def bins_for(q):
        # resolve the resolution of one quantity for either return shape of n
        return n[q] if pool_per_quantity else n

    # pool based on optimal bin resolution
    pooled = {q: pool(y, bins_for(q), pooling_mode) for q, y in Y.items()}

    # optimal subspace sizes: indexed as a {name: k_q} dict when
    # k_per_quantity is True, otherwise used as one global value
    k = k_per_q(
        pooled, k_thresholds,
        default_threshold=k_default_threshold,
        per_quantity=k_per_quantity
    )

    # aggregate n and k into one (n_bins, k) pair per quantity
    return {q: (bins_for(q), k[q] if k_per_quantity else k) for q in Y}
|
|
@@ -0,0 +1,101 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2026 Tobias Karusseit
|
|
3
|
+
This source code is licensed under the MIT license found in the
|
|
4
|
+
LICENSE file in the root directory of this source tree.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from .util import full, left
|
|
9
|
+
|
|
10
|
+
def principal_directions(
    z: torch.Tensor,                          # (B, D)
    pooled_targets: dict[str, torch.Tensor],  # {name: (B, n^2)}
) -> dict[str, tuple[torch.Tensor, torch.Tensor]]:
    """
    Finds, for every quantity, the rank-1 pair (w, v) such that
        torch.outer(z @ w, v) ~ y
    using the SVD of the cross-covariance matrix. Quantities are processed
    one after another in a loop over the dictionary.

    w : (D,)   - direction in latent space
    v : (n^2,) - direction in target space

    The first left/right singular vectors of C = z.T @ y (both centered)
    are the unit directions that maximise the covariance w.T @ C @ v
    between the projected latents and the projected targets.

    HOW THIS WORKS IN DETAIL:

    1. Compute the cross-covariance matrix C = z.T @ y of the centered data.
    2. We look for a latent direction w and a target direction v with
        y ~ (z @ w) v.T
    3. There is no guaranteed exact solution, so the error is measured with
        the Frobenius norm |y - (z @ w) v.T|^2_F, where
        ||A||^2_F = Tr(A.T @ A).
        Expanding the trace, the term that couples w and v to the data is
        the cross term 2 v.T (z.T @ y).T w, which leads to
            max_{w,v} w.T C v     with C = z.T @ y
    4. With the SVD C = U E V.T this becomes
            w.T (U E V.T) v = (U.T w).T E (V.T v)
        U and V are orthogonal, so U.T w and V.T v are again unit vectors
        and the objective is Sum_i sigma_i * (U.T w)_i * (V.T v)_i.
    5. The sum is largest when all weight sits on the largest singular
        value, i.e. when w and v are the first left and right singular
        vectors.
    6. In short: C measures how each latent dimension co-varies with each
        target dimension, and the SVD finds the basis in which the largest
        covariance lies along a single direction. Discarding all other
        singular vectors leaves a rank-1 solution.

    #### RETURNS:
    - {name: (w, v)} with both vectors rescaled to unit length and detached
        from the autograd graph.
    """
    eps = 1e-8  # guards the normalisation against division by zero

    # the latent batch is centered once and shared by all quantities
    latent_centered = z - z.mean(0, keepdim=True)  # (B, D)

    directions = {}
    for name, target in pooled_targets.items():
        # collapse everything after the batch axis when the target is not flat
        flat = target.flatten(1) if target.dim() > 2 else target

        target_centered = flat - flat.mean(0, keepdim=True)  # (B, n^2)

        # cross-covariance between latent and target dimensions
        cross_cov = latent_centered.T @ target_centered  # (D, n^2)

        # `full` is a project helper, unpacked here as (U, S, Vh);
        # only the leading singular pair is kept
        left_vectors, _, right_vectors_h = full(cross_cov)
        w = left_vectors[:, 0]      # (D,)   latent direction
        v = right_vectors_h[0, :]   # (n^2,) target direction

        # normalise to unit length
        w_unit = w / (w.norm() + eps)
        v_unit = v / (v.norm() + eps)

        directions[name] = (w_unit.detach(), v_unit.detach())

    return directions
|
|
81
|
+
|
|
82
|
+
def subspaces(Z, pooled_global, k: int | dict[str, int]) -> dict[str, torch.Tensor]:
    """
    Builds, for every quantity, an orthonormal basis in latent space from
    the leading left singular vectors of the latent/target cross-covariance.

    #### ARGS:
    - Z: Latent representations; centered over dim 0 before use.
    - pooled_global: Dictionary {name: target}. Targets with more than two
        dimensions are flattened from dim 1 onwards.
    - k: Number of basis directions to keep. Either a single int applied to
        every quantity, or a dictionary {name: k_q} indexed by quantity name.

    #### RETURNS:
    - {name: W_q}, where W_q is the transposed Q factor of a reduced QR of
        the kept singular vectors, i.e. its rows are the basis vectors.
    """
    print("Estimating subspaces...")
    Z_c = Z - Z.mean(0, keepdim=True)  # centered for variance calc
    W = {}
    for name, target in pooled_global.items():
        # if not flat, flatten
        if target.dim() > 2:
            target = target.flatten(1)
        y_c = target - target.mean(0, keepdim=True)  # center
        C = Z_c.T @ y_c  # cross-covariance matrix

        # `left` is a project helper; its first return value is used as the
        # matrix whose columns are the left singular vectors
        U, _ = left(C)

        # k is either global (int) or given per quantity (dict)
        _k = k if isinstance(k, int) else k[name]

        # keep the top-k columns and orthonormalise them; the columns are
        # passed to QR directly instead of transposing twice
        Q, _ = torch.linalg.qr(U[:, :_k], mode="reduced")
        W[name] = Q.T  # stored with basis vectors as rows
    return W
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2026 Tobias Karusseit
|
|
3
|
+
This source code is licensed under the MIT license found in the
|
|
4
|
+
LICENSE file in the root directory of this source tree.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from ._pipeline import input_to_quantity
|
|
8
|
+
from ._pool import pool
|
|
9
|
+
from ._utils import ensure_dim, to_device
|
|
10
|
+
|
|
11
|
+
__all__ = ["input_to_quantity", "pool", "ensure_dim", "to_device"]
|
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2026 Tobias Karusseit
|
|
3
|
+
This source code is licensed under the MIT license found in the
|
|
4
|
+
LICENSE file in the root directory of this source tree.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import inspect
|
|
8
|
+
from typing import Callable, Any
|
|
9
|
+
|
|
10
|
+
def compute_Q(*arg_functions: Callable, **kwarg_functions: Callable) -> Callable:
    """
    Creates a dependency-injecting wrapper that runs several functions one
    after another, passing each only the keyword arguments it declares.

    #### USAGE:
    # Define solvers
    def solver_a(A: torch.Tensor, q: int): ...
    def solver_b(A: torch.Tensor): ...

    # Create the combined executor
    executor = compute_Q(solver_a, solver_b=solver_b)

    # Run both, injecting only the arguments they need
    results = executor(A=my_matrix, q=5)
    # solver_a receives {'A': my_matrix, 'q': 5}
    # solver_b receives {'A': my_matrix}
    # results == {'solver_a': ..., 'solver_b': ...}

    #### VALIDATION:
    Raises a TypeError if two registered functions annotate the same
    parameter name with different types. Unannotated parameters take no
    part in this check: they neither conflict with an annotation nor erase
    one recorded earlier.

    #### ARGS:
    - *arg_functions: Unnamed callables. Each is registered under its
        `__name__` (or `str(id(f))` if it has none); a later callable with
        the same name replaces an earlier one.
    - **kwarg_functions: Callables registered under the given keyword;
        these override positional callables of the same name.

    #### RETURNS:
    - A callable `wrapped(**kwargs)` returning {name: result} for every
        registered function.

    #### RAISES:
    - TypeError: on conflicting annotations for a shared parameter name.
    """
    registry = {}
    for f in arg_functions:
        registry[getattr(f, "__name__", str(id(f)))] = f
    registry.update(kwarg_functions)

    # collect all fns and check that duplicated params all use the same type
    unannotated = inspect.Parameter.empty
    type_hints = {}
    registry_meta = {}

    for name, f in registry.items():
        params = inspect.signature(f).parameters
        registry_meta[name] = list(params)

        # validate types for params with the same name across all fns
        for p_name, param in params.items():
            annotation = param.annotation
            if annotation is unannotated:
                # nothing to compare, and nothing that should overwrite a
                # hint recorded from another function
                continue
            expected = type_hints.setdefault(p_name, annotation)
            if expected != annotation:
                raise TypeError(
                    f"Type mismatch for argument '{p_name}': "
                    f"Expected {expected}, but {name} defined it as {annotation}"
                )

    def wrapped(**kwargs) -> dict[str, Any]:
        results = {}
        for name, f in registry.items():
            # pass only the kwargs this function names in its signature
            relevant_args = {
                k: kwargs[k] for k in registry_meta[name] if k in kwargs
            }
            results[name] = f(**relevant_args)
        return results

    return wrapped
|
|
70
|
+
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2026 Tobias Karusseit
|
|
3
|
+
This source code is licensed under the MIT license found in the
|
|
4
|
+
LICENSE file in the root directory of this source tree.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
from typing import Callable
|
|
9
|
+
from ._compute_Q import compute_Q
|
|
10
|
+
from ._pool import pool
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def input_to_quantity(
    n_bins: int,
    mode: str,
    *arg_functions: Callable,
    **kwarg_functions: Callable
) -> Callable:
    """
    Builds a callable that computes one or more quantities from injected
    keyword arguments (via `compute_Q`) and then passes every result
    through `pool`.

    #### USAGE:
    def solver_a(A: torch.Tensor, q: int): ...
    def solver_b(A: torch.Tensor): ...

    # run solver_a and solver_b, then pool each output with 10 bins
    # using average pooling
    executor = input_to_quantity(10, "avg", solver_a, solver_b=solver_b)

    results = executor(A=my_matrix, q=5)
    # results = {
    #     "solver_a": pool(solver_a(A=my_matrix, q=5), 10, "avg"),
    #     "solver_b": pool(solver_b(A=my_matrix), 10, "avg"),
    # }

    #### ARGS:
    - n_bins: Forwarded to `pool` as the bin count for every output.
    - mode: Pooling mode forwarded to `pool`. The pooling helpers in this
        package accept "avg" and "max" and raise ValueError otherwise.
    - *arg_functions: Unnamed callables to register (names derived from
        `__name__`); see `compute_Q`.
    - **kwarg_functions: Named callables to register; see `compute_Q`.

    #### RETURNS:
    A callable that accepts keyword arguments, injects the relevant
    subset into each registered function (see `compute_Q`), and returns
    a dict mapping each function's name to its pooled output.
    """
    compute = compute_Q(*arg_functions, **kwarg_functions)

    def wrapped(**kwargs) -> dict[str, torch.Tensor]:
        pooled = {}
        # pool every computed quantity with the same resolution and mode
        for name, chunks in compute(**kwargs).items():
            pooled[name] = pool(chunks, n_bins, mode)
        return pooled

    return wrapped
|
spacelearn/data/_pool.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2026 Tobias Karusseit
|
|
3
|
+
This source code is licensed under the MIT license found in the
|
|
4
|
+
LICENSE file in the root directory of this source tree.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn.functional as F
|
|
9
|
+
|
|
10
|
+
def spatial_pool_targets(chunks: torch.Tensor, n_bins: int, mode: str = "avg") -> torch.Tensor:
    """
    Adaptive 2-D pooling of the last two dimensions down to an
    (n_bins, n_bins) grid, e.g. (B, C, H, W) -> (B, C, n_bins, n_bins).

    mode is "avg" (adaptive average pooling) or "max" (adaptive max
    pooling); any other value raises ValueError.
    """
    # reject unknown modes up front
    if mode not in ("avg", "max"):
        raise ValueError(f"Invalid mode: {mode}")

    pool_fn = F.adaptive_avg_pool2d if mode == "avg" else F.adaptive_max_pool2d
    grid = (n_bins, n_bins)
    return pool_fn(chunks, grid)
|
|
18
|
+
|
|
19
|
+
def pool_targets(chunks: torch.Tensor, n_bins: int, mode: str = "avg") -> torch.Tensor:
    """
    Adaptive 1-D pooling of the last dimension down to n_bins entries,
    e.g. (B, S) -> (B, n_bins).

    mode is "avg" (adaptive average pooling) or "max" (adaptive max
    pooling); any other value raises ValueError.
    """
    # reject unknown modes up front
    if mode not in ("avg", "max"):
        raise ValueError(f"Invalid mode: {mode}")

    pool_fn = F.adaptive_avg_pool1d if mode == "avg" else F.adaptive_max_pool1d
    return pool_fn(chunks, n_bins)
|
|
27
|
+
|
|
28
|
+
def pool(chunks: torch.Tensor, n_bins: int, mode: str = "avg") -> torch.Tensor:
    """
    Dispatches to the pooling helper that matches the rank of `chunks`.

    - 4 dimensions (B, C, H, W): spatial_pool_targets (2-D pooling)
    - 2 dimensions (B, S): pool_targets (1-D pooling)

    Any other rank raises ValueError. `n_bins` and `mode` are forwarded
    unchanged.
    """
    ndims = chunks.dim()

    # only flat (B, S) and spatial (B, C, H, W) inputs are supported
    if ndims not in (2, 4):
        raise ValueError(f"Invalid number of dimensions: {ndims}. Expected 2 (B, S) or 4 (B, C, H, W)")

    pool_fn = spatial_pool_targets if ndims == 4 else pool_targets
    return pool_fn(chunks, n_bins, mode)
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2026 Tobias Karusseit
|
|
3
|
+
This source code is licensed under the MIT license found in the
|
|
4
|
+
LICENSE file in the root directory of this source tree.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from typing import Callable
|
|
8
|
+
import torch
|
|
9
|
+
import functools
|
|
10
|
+
|
|
11
|
+
def ensure_dim(func: Callable = None, *, dims: int, targets: list[str] | None = None) -> Callable:
|
|
12
|
+
"""
|
|
13
|
+
Wraps a function to force input tensors to have a minimum number of dimensions
|
|
14
|
+
by prepending unit dimensions (1s) to their shape.
|
|
15
|
+
|
|
16
|
+
#### USAGE:
|
|
17
|
+
# Example 1: Force all inputs to be at least 4D (B, C, H, W)
|
|
18
|
+
@ensure_dim(dims=4)
|
|
19
|
+
def process_spatial(grid: torch.Tensor): ...
|
|
20
|
+
|
|
21
|
+
# Example 2: Only force specific inputs to be 4D
|
|
22
|
+
@ensure_dim(dims=4, targets=['pressure', 'velocity'])
|
|
23
|
+
def solve_physics(pressure: torch.Tensor, density: torch.Tensor):
|
|
24
|
+
# 'pressure' will be forced to 4D
|
|
25
|
+
# 'density' will remain untouched
|
|
26
|
+
...
|
|
27
|
+
|
|
28
|
+
#### ARGS:
|
|
29
|
+
- function: The callable to wrap.
|
|
30
|
+
- dims: The target number of dimensions (e.g., 4 for spatial grids).
|
|
31
|
+
- targets: An optional list of argument names to target for expansion.
|
|
32
|
+
If None or empty, all positional and keyword tensors are expanded.
|
|
33
|
+
|
|
34
|
+
#### RETURNS:
|
|
35
|
+
- A wrapped callable that expands tensors to the requested dimensionality
|
|
36
|
+
using `view()` before execution.
|
|
37
|
+
"""
|
|
38
|
+
if func is None:
|
|
39
|
+
return functools.partial(ensure_dim, dims=dims, targets=targets)
|
|
40
|
+
# helper to expand a tensor using view
|
|
41
|
+
def expand(t):
|
|
42
|
+
if isinstance(t, torch.Tensor) and t.dim() < dims:
|
|
43
|
+
# adds dimensions to the front (e.g., to make (H, W) -> (1, 1, H, W))
|
|
44
|
+
new_shape = (1,) * (dims - t.dim()) + t.shape
|
|
45
|
+
return t.view(new_shape)
|
|
46
|
+
return t
|
|
47
|
+
|
|
48
|
+
@functools.wraps(func)
|
|
49
|
+
def wrapped(*args, **kwargs):
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
# positional arguments
|
|
53
|
+
new_args = tuple(expand(a) for a in args)
|
|
54
|
+
|
|
55
|
+
# keyword arguments
|
|
56
|
+
new_kwargs = {}
|
|
57
|
+
for k, v in kwargs.items():
|
|
58
|
+
if targets is None or not targets or k in targets:
|
|
59
|
+
new_kwargs[k] = expand(v)
|
|
60
|
+
else:
|
|
61
|
+
new_kwargs[k] = v
|
|
62
|
+
|
|
63
|
+
return func(*new_args, **new_kwargs)
|
|
64
|
+
|
|
65
|
+
return wrapped
|
|
66
|
+
|
|
67
|
+
def to_device(func: Callable = None, *, device: str = "cpu", targets: list[str] = None):
|
|
68
|
+
"""
|
|
69
|
+
Wraps a function to move input tensors to a target device before execution.
|
|
70
|
+
|
|
71
|
+
#### USAGE:
|
|
72
|
+
# Example 1: Move all tensor inputs to a device
|
|
73
|
+
@to_device(device="cuda")
|
|
74
|
+
def process(grid: torch.Tensor): ...
|
|
75
|
+
|
|
76
|
+
# Example 2: Used bare, defaults to "cpu"
|
|
77
|
+
@to_device
|
|
78
|
+
def process(grid: torch.Tensor): ...
|
|
79
|
+
|
|
80
|
+
# Example 3: Only move specific inputs
|
|
81
|
+
@to_device(device="cuda", targets=["pressure"])
|
|
82
|
+
def solve_physics(pressure: torch.Tensor, density: torch.Tensor):
|
|
83
|
+
# 'pressure' will be moved to "cuda"
|
|
84
|
+
# 'density' will remain untouched
|
|
85
|
+
...
|
|
86
|
+
|
|
87
|
+
#### ARGS:
|
|
88
|
+
- func: The callable to wrap.
|
|
89
|
+
- device: The target device to move tensors to (e.g., "cpu", "cuda").
|
|
90
|
+
- targets: An optional list of argument names to target for moving.
|
|
91
|
+
If None, all positional and keyword tensors are moved.
|
|
92
|
+
|
|
93
|
+
#### RETURNS:
|
|
94
|
+
- A wrapped callable that moves tensors to the requested device
|
|
95
|
+
using `.to()` before execution.
|
|
96
|
+
"""
|
|
97
|
+
# This allows the decorator to be used as @to_device OR @to_device(device='cuda')
|
|
98
|
+
if func is None:
|
|
99
|
+
return functools.partial(to_device, device=device, targets=targets)
|
|
100
|
+
|
|
101
|
+
@functools.wraps(func)
|
|
102
|
+
def wrapped(*args, **kwargs):
|
|
103
|
+
def move(t):
|
|
104
|
+
if isinstance(t, torch.Tensor):
|
|
105
|
+
return t.to(device)
|
|
106
|
+
return t
|
|
107
|
+
|
|
108
|
+
new_args = tuple(move(a) for a in args)
|
|
109
|
+
new_kwargs = {
|
|
110
|
+
k: move(v) if (targets is None or k in targets) else v
|
|
111
|
+
for k, v in kwargs.items()
|
|
112
|
+
}
|
|
113
|
+
return func(*new_args, **new_kwargs)
|
|
114
|
+
|
|
115
|
+
return wrapped
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright (c) 2026 Tobias Karusseit
|
|
3
|
+
This source code is licensed under the MIT license found in the
|
|
4
|
+
LICENSE file in the root directory of this source tree.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from ._alignment_loss import alignment_loss
|
|
8
|
+
from ._disentanglement_loss import disentanglement_loss
|
|
9
|
+
from ._isotropy_loss import isotropy_loss
|
|
10
|
+
from ._stability_loss import stability_loss
|
|
11
|
+
|
|
12
|
+
__all__ = [
|
|
13
|
+
"alignment_loss",
|
|
14
|
+
"disentanglement_loss",
|
|
15
|
+
"isotropy_loss",
|
|
16
|
+
"stability_loss",
|
|
17
|
+
]
|