spacelearn 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacelearn/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ """
2
+ Copyright (c) 2026 Tobias Karusseit
3
+ This source code is licensed under the MIT license found in the
4
+ LICENSE file in the root directory of this source tree.
5
+ """
6
+
7
+ from ._find_subspaces import subspaces, principal_directions
8
+ from ._combined import combined_loss, solve_subspaces, solve_dims
9
+
10
+ __all__ = [
11
+ "principal_directions",
12
+ "subspaces",
13
+ "combined_loss",
14
+ "solve_subspaces",
15
+ "solve_dims",
16
+ ]
@@ -0,0 +1,261 @@
1
+ """
2
+ Copyright (c) 2026 Tobias Karusseit
3
+ This source code is licensed under the MIT license found in the
4
+ LICENSE file in the root directory of this source tree.
5
+ """
6
+
7
+ import torch
8
+ from dataclasses import dataclass, field
9
+ from ._find_subspaces import principal_directions, subspaces
10
+ from .loss import alignment_loss, disentanglement_loss, isotropy_loss, stability_loss
11
+ from .optim import fit_decoders, optimal_bins, k_per_q
12
+ from .data import pool
13
+
14
@dataclass
class Space:
    """
    Container bundling, per physical quantity, a subspace basis ``W``, a
    fitted decoder ``A`` and a principal direction ``V``.

    Behaves like a read-only mapping ``name -> (W_q, A_q, V_q)``. It supports
    ``space[name]``, ``len(space)``, ``name in space``, iteration over the
    names, and ``keys()`` / ``values()`` / ``items()``. Names are ordered as
    in ``W``.

    #### ARGS:
    - W: Dictionary of subspace basis matrices, keyed by quantity name.
    - A: Dictionary of decoder matrices, keyed by quantity name.
    - V: Dictionary of principal direction vectors, keyed by quantity name.

    #### RAISES:
    - ValueError: if W, A and V do not contain exactly the same keys.
    """

    W: dict[str, torch.Tensor]
    A: dict[str, torch.Tensor]
    V: dict[str, torch.Tensor]
    # Quantity names in W's insertion order. This is derived state, filled in
    # __post_init__, so it is kept out of the generated __init__, __repr__ and
    # __eq__.
    _names: list[str] = field(init=False, repr=False, compare=False)

    def __post_init__(self):
        # Dict key views compare like sets: membership must match, order need not.
        if not (
            self.W.keys() == self.A.keys() == self.V.keys()
        ):
            raise ValueError(
                "W, A, and V must contain identical keys"
            )

        self._names = list(self.W.keys())

    def __getitem__(self, key: str) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(W[key], A[key], V[key])``; raises KeyError for unknown names."""
        return self.W[key], self.A[key], self.V[key]

    def keys(self) -> list[str]:
        """Return the quantity names as a fresh list (mutating it does not affect the Space)."""
        return list(self._names)

    def values(self) -> list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Return ``(W_q, A_q, V_q)`` for every quantity, in name order."""
        return [(self.W[k], self.A[k], self.V[k]) for k in self._names]

    def __len__(self):
        return len(self._names)

    def items(self) -> list[tuple[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
        """Return ``(name, (W_q, A_q, V_q))`` pairs, in name order."""
        return [(k, (self.W[k], self.A[k], self.V[k])) for k in self._names]

    def __iter__(self):
        return iter(self._names)

    def __contains__(self, key):
        return key in self._names
51
+
52
def combined_loss(
    Z: torch.Tensor,
    W: dict[str, torch.Tensor],
    A: dict[str, torch.Tensor],
    Y: dict[str, torch.Tensor],
    W_prev: dict[str, torch.Tensor],
    *,
    n_dir: int = 32
) -> torch.Tensor:
    """
    Returns the full SPACE training objective: the unweighted sum of the
    alignment, disentanglement, isotropy and stability losses.

    #### USAGE:
        loss = combined_loss(Z, W, A, Y, W_prev)

        # more directions for the isotropy estimate
        loss = combined_loss(Z, W, A, Y, W_prev, n_dir=64)

    #### HOW IT WORKS:
        loss = alignment_loss(Z, W, A, Y)
             + disentanglement_loss(Z, W)
             + isotropy_loss(Z, W, n_dir=n_dir)
             + stability_loss(W, W_prev)

    #### ARGS:
    - Z: Latent representations with shape (B, D).
    - W: Dictionary of per-quantity subspace bases, each shaped (k, D).
    - A: Dictionary of fitted decoders from subspace coordinates to targets.
    - Y: Dictionary of pooled physical targets.
    - W_prev: Dictionary of subspace bases from the previous optimisation
      step, used by the stability term.
    - n_dir: Number of random directions isotropy_loss uses to estimate
      residual isotropy. Defaults to 32.

    #### RETURNS:
    - Scalar tensor holding the summed objective.
    """
    # The terms are evaluated in a fixed order. isotropy_loss samples random
    # directions, so keeping this order keeps any RNG consumption unchanged.
    align_term = alignment_loss(Z, W, A, Y)
    disent_term = disentanglement_loss(Z, W)
    iso_term = isotropy_loss(Z, W, n_dir=n_dir)
    stab_term = stability_loss(W, W_prev)

    total = align_term + disent_term
    total = total + iso_term
    return total + stab_term
102
+
103
def solve_subspaces(
    Z: torch.Tensor,
    Y: dict[str, torch.Tensor],
    *,
    k: int | dict[str, int] = 32,
) -> "Space":
    """
    Solves the complete subspace-discovery pipeline for a collection of
    physical quantities.

    For each quantity this finds its principal cross-covariance direction
    pair, builds a k-dimensional latent subspace basis, and fits a decoder
    from subspace coordinates back to physical space.

    #### USAGE:
        space = solve_subspaces(Z, Y, k=32)

        W_q, A_q, V_q = space["scaled"]

    #### HOW IT WORKS:
    1. principal_directions() computes, per quantity, the leading pair
       (w, v) of the latent/target cross-covariance. Only the
       target-space direction v is kept.
    2. subspaces() builds a k-dimensional orthonormal latent basis per
       quantity.
    3. fit_decoders() solves least-squares decoders from subspace
       coordinates back to the physical targets.
    4. Everything is bundled into a Space mapping
       name -> (W_q, A_q, V_q).

    #### ARGS:
    - Z: Latent representations with shape (B, D).
    - Y: Dictionary of pooled physical targets, keyed by quantity name.
    - k: Number of basis directions per subspace. Either a single int for
      every quantity or a dict {name: k_q}. Defaults to 32.

    #### RETURNS:
    - A Space whose entries are (W_q, A_q, V_q), where:
        W_q = subspace basis matrix (k, D)
        A_q = decoder returned by fit_decoders()
        V_q = principal target-space direction (n^2,) from
              principal_directions()
    """
    # NOTE(review): an earlier docstring described V_q as the latent
    # direction (D,). The code has always stored the target-space vector v;
    # confirm which one downstream consumers expect.
    directions = principal_directions(Z, Y)
    W = subspaces(Z, Y, k=k)
    A = fit_decoders(W, Z, Y)
    V = {name: v for name, (_w, v) in directions.items()}
    return Space(W, A, V=V)
151
+
152
def solve_dims(
    Y: dict[str, torch.Tensor],
    max_bins: int,
    bin_thresholds: float | dict[str, float],
    k_thresholds: float | dict[str, float],
    *,
    bin_default_threshold: float | None = None,
    k_default_threshold: float | None = None,
    pooling_mode: str = "avg",
    pool_per_quantity: bool = False,
    k_per_quantity: bool = False
) -> dict[str, tuple[int, int]]:
    """
    Computes recommended pooling resolutions and subspace dimensions for
    a collection of physical targets.

    This combines optimal_bins() and k_per_q() into a single sizing
    pipeline.

    #### USAGE:
        dims = solve_dims(
            Y,
            max_bins=16,
            bin_thresholds=0.95,
            k_thresholds=0.90
        )

        # per-quantity sizing
        dims = solve_dims(
            Y,
            max_bins=16,
            bin_thresholds=0.95,
            k_thresholds=0.90,
            pool_per_quantity=True,
            k_per_quantity=True
        )

    #### HOW IT WORKS:
    1. optimal_bins() determines the minimum pooling resolution
       needed to retain the desired fraction of spatial variance.
    2. Targets are pooled at the selected resolution(s).
    3. k_per_q() estimates the minimum subspace dimensionality
       needed to retain the desired fraction of variance.
    4. Resolutions and dimensions are combined into per-quantity
       (n_bins, k) recommendations.

    #### ARGS:
    - Y: Dictionary of physical target tensors.
    - max_bins: Maximum pooling grid side length considered.
    - bin_thresholds: Variance-retention threshold(s) used by
      optimal_bins().
    - k_thresholds: Variance-retention threshold(s) used by k_per_q().
    - bin_default_threshold: Fallback threshold for quantities missing
      from bin_thresholds.
    - k_default_threshold: Fallback threshold for quantities missing
      from k_thresholds.
    - pooling_mode: Pooling mode passed to pool(). Defaults to "avg".
    - pool_per_quantity: Whether the pooling resolution is selected
      independently for each quantity.
    - k_per_quantity: Whether the subspace dimensionality is selected
      independently for each quantity.

    #### RETURNS:
    - Always a dict {name: (n_bins, k)} with one entry per key of Y.
      When a value is chosen globally (pool_per_quantity / k_per_quantity
      False), every quantity receives the same value.
    """
    # n is {name: n_q} when pool_per_quantity, otherwise one global int
    n = optimal_bins(
        Y, max_bins, bin_thresholds,
        default_threshold=bin_default_threshold,
        per_quantity=pool_per_quantity,
    )

    # pool each target at its (per-quantity or global) resolution
    pooled = {
        q: pool(y, n[q] if pool_per_quantity else n, pooling_mode)
        for q, y in Y.items()
    }

    # k is {name: k_q} when k_per_quantity, otherwise one global int
    k = k_per_q(
        pooled, k_thresholds,
        default_threshold=k_default_threshold,
        per_quantity=k_per_quantity,
    )

    return {
        q: (
            n[q] if pool_per_quantity else n,
            k[q] if k_per_quantity else k,
        )
        for q in Y
    }
@@ -0,0 +1,101 @@
1
+ """
2
+ Copyright (c) 2026 Tobias Karusseit
3
+ This source code is licensed under the MIT license found in the
4
+ LICENSE file in the root directory of this source tree.
5
+ """
6
+
7
+ import torch
8
+ from .util import full, left
9
+
10
def principal_directions(
    z: torch.Tensor,  # (B, D)
    pooled_targets: dict[str, torch.Tensor],  # {name: (B, n^2)}
) -> dict[str, tuple[torch.Tensor, torch.Tensor]]:
    """
    For every target quantity, finds the pair of unit vectors (w, v) that
    maximises the cross-covariance between the latents and the target:

        (w, v) = argmax_{||w|| = ||v|| = 1}  w^T C v,    C = z_c^T y_c

    Here z_c and y_c are the batch-centred latents and targets. The
    maximiser is the leading left/right singular vector pair of C, and the
    attained maximum is the largest singular value sigma_1.

        w : (D,)   direction in latent space
        v : (n^2,) direction in target space

    #### USAGE:
        dirs = principal_directions(Z, {"pressure": p_pooled})
        w, v = dirs["pressure"]

    #### HOW IT WORKS:
    1. Centre z and each target over the batch dimension.
    2. Form the cross-covariance matrix C = z_c^T y_c, shape (D, n^2).
       C[i, j] measures how strongly latent dimension i co-varies with
       target dimension j.
    3. Decompose C = U diag(sigma) V^T. For unit w and v,
           w^T C v = sum_i sigma_i (U^T w)_i (V^T v)_i.
       U^T w and V^T v have norm at most 1, so by Cauchy-Schwarz the sum
       is at most sigma_1. The bound is attained by w = U[:, 0] and
       v = V[:, 0].
    4. Keeping only this first pair gives the rank-1 solution. Up to scale,
       (z_c @ w)[:, None] * v[None, :] is the matching rank-1 approximation
       of y_c. It is the exact least-squares optimum only when z_c is
       whitened (z_c^T z_c proportional to the identity).

    #### ARGS:
    - z: Latent representations with shape (B, D).
    - pooled_targets: Dictionary {name: target}. Targets are cast to z's
      dtype and device. A (B,) target is treated as (B, 1), and a target
      with more than two dimensions is flattened to (B, -1).

    #### RETURNS:
    - {name: (w, v)}, both unit-normalised and detached from the graph.
      Singular vectors are only defined up to sign.
    """
    # centre z once; every quantity reuses it
    z_c = z - z.mean(0, keepdim=True)  # (B, D)

    results = {}

    for name, y in pooled_targets.items():
        # match z's dtype/device so the matmul below is well-defined
        y = y.to(device=z.device, dtype=z.dtype)
        if y.dim() == 1:
            y = y.unsqueeze(1)  # (B,) -> (B, 1)
        elif y.dim() > 2:
            y = y.flatten(1)  # (B, ...) -> (B, n^2)

        # centre target
        y_c = y - y.mean(0, keepdim=True)  # (B, n^2)

        # cross-covariance between latent and target dimensions
        C = z_c.T @ y_c  # (D, n^2)

        # full SVD; only the leading singular pair is used
        U, _, Vh = full(C)

        w = U[:, 0]  # (D,)   latent direction
        v = Vh[0, :]  # (n^2,) target direction

        # guard against numerical drift in the helper's normalisation
        w = w / (w.norm() + 1e-8)
        v = v / (v.norm() + 1e-8)

        results[name] = (w.detach(), v.detach())

    return results
81
+
82
def subspaces(Z, pooled_global, k: int | dict[str, int]) -> dict[str, torch.Tensor]:
    """
    Finds, per quantity, an orthonormal k-dimensional latent basis spanned
    by the top-k left singular vectors of the latent/target
    cross-covariance. These are the latent directions that capture the
    most covariance with the pooled physics target.

    #### ARGS:
    - Z: Latent representations with shape (B, D).
    - pooled_global: Dictionary {name: target}. Targets are cast to Z's
      dtype and device. A (B,) target is treated as (B, 1), and a target
      with more than two dimensions is flattened to (B, -1).
    - k: Subspace size. Either a single int for every quantity or a dict
      {name: k_q}.

    #### RETURNS:
    - {name: W_q} with W_q of shape (k_q, D) and orthonormal rows.
    """
    print("Estimating subspaces...")
    Z_c = Z - Z.mean(0, keepdim=True)  # centred for covariance
    W = {}
    for name, target in pooled_global.items():
        # match Z's dtype/device so the matmul below is well-defined
        target = target.to(device=Z.device, dtype=Z.dtype)
        if target.dim() == 1:
            target = target.unsqueeze(1)  # (B,) -> (B, 1)
        elif target.dim() > 2:
            target = target.flatten(1)  # (B, ...) -> (B, n^2)
        y_c = target - target.mean(0, keepdim=True)  # centre
        C = Z_c.T @ y_c  # cross-covariance, (D, n^2)

        # left singular vectors = projection directions from latent space
        U, _ = left(C)

        # k may be global (int) or per quantity (dict)
        _k = k if isinstance(k, int) else k[name]

        # NOTE(review): if `left` returns fewer than _k columns (e.g. a reduced
        # basis when n^2 < k), the slice silently yields a smaller subspace — confirm.
        basis = U[:, :_k]  # (D, k) top-k directions as columns
        basis, _ = torch.linalg.qr(basis, mode="reduced")  # orthonormalise
        W[name] = basis.T  # (k, D)
    return W
@@ -0,0 +1,11 @@
1
+ """
2
+ Copyright (c) 2026 Tobias Karusseit
3
+ This source code is licensed under the MIT license found in the
4
+ LICENSE file in the root directory of this source tree.
5
+ """
6
+
7
+ from ._pipeline import input_to_quantity
8
+ from ._pool import pool
9
+ from ._utils import ensure_dim, to_device
10
+
11
+ __all__ = ["input_to_quantity", "pool", "ensure_dim", "to_device"]
@@ -0,0 +1,70 @@
1
+ """
2
+ Copyright (c) 2026 Tobias Karusseit
3
+ This source code is licensed under the MIT license found in the
4
+ LICENSE file in the root directory of this source tree.
5
+ """
6
+
7
+ import inspect
8
+ from typing import Callable, Any
9
+
10
def compute_Q(*arg_functions: Callable, **kwarg_functions: Callable) -> Callable:
    """
    Creates a dependency-injecting wrapper that calls several functions,
    one after another, with a shared pool of keyword arguments. Each
    function receives only the arguments its signature names.

    #### USAGE:
        # Define solvers
        def solver_a(A: torch.Tensor, q: int): ...
        def solver_b(A: torch.Tensor): ...

        # Create the combined executor
        executor = compute_Q(solver_a, solver_b=solver_b)

        # Run both, injecting only the arguments they need
        results = executor(A=my_matrix, q=5)
        # solver_a receives {'A': my_matrix, 'q': 5}
        # solver_b receives {'A': my_matrix}

    #### VALIDATION:
    Raises TypeError at construction time if two registered functions
    annotate a parameter of the same name with different types.
    Unannotated parameters are not checked and never cause a mismatch.

    #### ARGS:
    - *arg_functions: Unnamed callables, registered under their __name__
      (or str(id(f)) if they have none).
    - **kwarg_functions: Named callables, registered under the keyword. A
      keyword equal to an unnamed callable's name replaces that callable.

    #### RETURNS:
    - A callable wrapped(**kwargs) -> {registered_name: result}. Functions
      that declare **kwargs receive every keyword argument. Other functions
      receive only those named parameters present in kwargs.
    """
    registry: dict[str, Callable] = {}
    for f in arg_functions:
        registry[getattr(f, "__name__", str(id(f)))] = f
    registry.update(kwarg_functions)

    variadic_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)

    # parameter name -> first non-empty annotation seen for it
    type_hints: dict[str, Any] = {}
    # registered name -> (named parameters, accepts **kwargs)
    registry_meta: dict[str, tuple[list[str], bool]] = {}

    for name, f in registry.items():
        params = list(inspect.signature(f).parameters.values())
        named = [p for p in params if p.kind not in variadic_kinds]
        takes_var_kw = any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params)
        registry_meta[name] = ([p.name for p in named], takes_var_kw)

        # validate types for same-named params across all fns; skip
        # unannotated params so they neither trigger nor mask a mismatch
        for param in named:
            annotation = param.annotation
            if annotation is inspect.Parameter.empty:
                continue
            expected = type_hints.setdefault(param.name, annotation)
            if expected != annotation:
                raise TypeError(
                    f"Type mismatch for argument '{param.name}': "
                    f"Expected {expected}, but {name} defined it as {annotation}"
                )

    def wrapped(**kwargs) -> dict[str, Any]:
        results = {}
        for name, f in registry.items():
            names, takes_var_kw = registry_meta[name]
            if takes_var_kw:
                relevant_args = dict(kwargs)
            else:
                relevant_args = {k: kwargs[k] for k in names if k in kwargs}
            results[name] = f(**relevant_args)
        return results

    return wrapped
70
+
@@ -0,0 +1,55 @@
1
+ """
2
+ Copyright (c) 2026 Tobias Karusseit
3
+ This source code is licensed under the MIT license found in the
4
+ LICENSE file in the root directory of this source tree.
5
+ """
6
+
7
+ import torch
8
+ from typing import Callable
9
+ from ._compute_Q import compute_Q
10
+ from ._pool import pool
11
+
12
+
13
def input_to_quantity(
    n_bins: int,
    mode: str,
    *arg_functions: Callable,
    **kwarg_functions: Callable
) -> Callable:
    """
    Builds a callable that computes one or more quantities from injected
    arguments (via `compute_Q`) and pools each resulting tensor to a fixed
    resolution with `pool`.

    #### USAGE:
        def solver_a(A: torch.Tensor, q: int): ...
        def solver_b(A: torch.Tensor): ...

        # Run solver_a and solver_b, then average-pool each output to
        # resolution 10
        executor = input_to_quantity(10, "avg", solver_a, solver_b=solver_b)

        results = executor(A=my_matrix, q=5)
        # results = {
        #     "solver_a": pool(solver_a(A=my_matrix, q=5), 10, "avg"),
        #     "solver_b": pool(solver_b(A=my_matrix), 10, "avg"),
        # }

    #### ARGS:
    - n_bins: Target resolution. A 2-D (B, S) output is pooled to
      (B, n_bins), and a 4-D (B, C, H, W) output to
      (B, C, n_bins, n_bins).
    - mode: Pooling strategy passed to `pool`: "avg" or "max". Any other
      value makes the returned callable raise ValueError.
    - *arg_functions: Unnamed callables to register (names derived from
      `__name__`); see `compute_Q`.
    - **kwarg_functions: Named callables to register; see `compute_Q`.

    #### RETURNS:
    A callable that accepts keyword arguments and injects the relevant
    subset into each registered function (see `compute_Q`). It returns a
    dict mapping each function's name to its pooled `torch.Tensor` output.
    """
    executor = compute_Q(*arg_functions, **kwarg_functions)

    def wrapped(**kwargs) -> dict[str, torch.Tensor]:
        outputs = executor(**kwargs)
        return {name: pool(tensor, n_bins, mode) for name, tensor in outputs.items()}

    return wrapped
@@ -0,0 +1,35 @@
1
+ """
2
+ Copyright (c) 2026 Tobias Karusseit
3
+ This source code is licensed under the MIT license found in the
4
+ LICENSE file in the root directory of this source tree.
5
+ """
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
def spatial_pool_targets(chunks: torch.Tensor, n_bins: int, mode: str = "avg") -> torch.Tensor:
    """
    Adaptively pools a spatial batch onto an n_bins x n_bins grid.

    chunks: (B, C, H, W) -> (B, C, n_bins, n_bins)

    #### ARGS:
    - chunks: Batched spatial tensor.
    - n_bins: Side length of the output grid.
    - mode: "avg" for adaptive average pooling, "max" for adaptive max
      pooling.

    #### RAISES:
    - ValueError: if mode is neither "avg" nor "max".
    """
    grid = (n_bins, n_bins)
    if mode == "max":
        return F.adaptive_max_pool2d(chunks, grid)
    if mode != "avg":
        raise ValueError(f"Invalid mode: {mode}")
    return F.adaptive_avg_pool2d(chunks, grid)
18
+
19
def pool_targets(chunks: torch.Tensor, n_bins: int, mode: str = "avg") -> torch.Tensor:
    """
    Adaptively pools flat per-sample targets along their last axis.

    chunks: (B, S) -> (B, n_bins)

    A 2-D input is treated by the 1-D adaptive pooling ops as an
    unbatched (channels, length) tensor. Each row is therefore pooled
    independently along S.

    #### ARGS:
    - chunks: Flat targets, one row per sample.
    - n_bins: Number of output bins per row.
    - mode: "avg" for adaptive average pooling, "max" for adaptive max
      pooling.

    #### RAISES:
    - ValueError: if mode is neither "avg" nor "max".
    """
    if mode == "avg":
        return F.adaptive_avg_pool1d(chunks, n_bins)
    elif mode == "max":
        return F.adaptive_max_pool1d(chunks, n_bins)
    else:
        raise ValueError(f"Invalid mode: {mode}")
27
+
28
def pool(chunks: torch.Tensor, n_bins: int, mode: str = "avg") -> torch.Tensor:
    """
    Adaptively pools a batch of targets to a fixed resolution, choosing the
    pooling routine from the input rank:

    - 4-D (B, C, H, W) -> spatial_pool_targets(), giving (B, C, n_bins, n_bins)
    - 2-D (B, S)       -> pool_targets(), giving (B, n_bins)

    #### RAISES:
    - ValueError: for any other rank. The helpers also raise ValueError
      for an unknown mode.
    """
    handlers = {4: spatial_pool_targets, 2: pool_targets}
    rank = chunks.dim()
    handler = handlers.get(rank)
    if handler is None:
        raise ValueError(f"Invalid number of dimensions: {rank}. Expected 2 (B, S) or 4 (B, C, H, W)")
    return handler(chunks, n_bins, mode)
@@ -0,0 +1,115 @@
1
+ """
2
+ Copyright (c) 2026 Tobias Karusseit
3
+ This source code is licensed under the MIT license found in the
4
+ LICENSE file in the root directory of this source tree.
5
+ """
6
+
7
+ from typing import Callable
8
+ import torch
9
+ import functools
10
+
11
+ def ensure_dim(func: Callable = None, *, dims: int, targets: list[str] | None = None) -> Callable:
12
+ """
13
+ Wraps a function to force input tensors to have a minimum number of dimensions
14
+ by prepending unit dimensions (1s) to their shape.
15
+
16
+ #### USAGE:
17
+ # Example 1: Force all inputs to be at least 4D (B, C, H, W)
18
+ @ensure_dim(dims=4)
19
+ def process_spatial(grid: torch.Tensor): ...
20
+
21
+ # Example 2: Only force specific inputs to be 4D
22
+ @ensure_dim(dims=4, targets=['pressure', 'velocity'])
23
+ def solve_physics(pressure: torch.Tensor, density: torch.Tensor):
24
+ # 'pressure' will be forced to 4D
25
+ # 'density' will remain untouched
26
+ ...
27
+
28
+ #### ARGS:
29
+ - function: The callable to wrap.
30
+ - dims: The target number of dimensions (e.g., 4 for spatial grids).
31
+ - targets: An optional list of argument names to target for expansion.
32
+ If None or empty, all positional and keyword tensors are expanded.
33
+
34
+ #### RETURNS:
35
+ - A wrapped callable that expands tensors to the requested dimensionality
36
+ using `view()` before execution.
37
+ """
38
+ if func is None:
39
+ return functools.partial(ensure_dim, dims=dims, targets=targets)
40
+ # helper to expand a tensor using view
41
+ def expand(t):
42
+ if isinstance(t, torch.Tensor) and t.dim() < dims:
43
+ # adds dimensions to the front (e.g., to make (H, W) -> (1, 1, H, W))
44
+ new_shape = (1,) * (dims - t.dim()) + t.shape
45
+ return t.view(new_shape)
46
+ return t
47
+
48
+ @functools.wraps(func)
49
+ def wrapped(*args, **kwargs):
50
+
51
+
52
+ # positional arguments
53
+ new_args = tuple(expand(a) for a in args)
54
+
55
+ # keyword arguments
56
+ new_kwargs = {}
57
+ for k, v in kwargs.items():
58
+ if targets is None or not targets or k in targets:
59
+ new_kwargs[k] = expand(v)
60
+ else:
61
+ new_kwargs[k] = v
62
+
63
+ return func(*new_args, **new_kwargs)
64
+
65
+ return wrapped
66
+
67
+ def to_device(func: Callable = None, *, device: str = "cpu", targets: list[str] = None):
68
+ """
69
+ Wraps a function to move input tensors to a target device before execution.
70
+
71
+ #### USAGE:
72
+ # Example 1: Move all tensor inputs to a device
73
+ @to_device(device="cuda")
74
+ def process(grid: torch.Tensor): ...
75
+
76
+ # Example 2: Used bare, defaults to "cpu"
77
+ @to_device
78
+ def process(grid: torch.Tensor): ...
79
+
80
+ # Example 3: Only move specific inputs
81
+ @to_device(device="cuda", targets=["pressure"])
82
+ def solve_physics(pressure: torch.Tensor, density: torch.Tensor):
83
+ # 'pressure' will be moved to "cuda"
84
+ # 'density' will remain untouched
85
+ ...
86
+
87
+ #### ARGS:
88
+ - func: The callable to wrap.
89
+ - device: The target device to move tensors to (e.g., "cpu", "cuda").
90
+ - targets: An optional list of argument names to target for moving.
91
+ If None, all positional and keyword tensors are moved.
92
+
93
+ #### RETURNS:
94
+ - A wrapped callable that moves tensors to the requested device
95
+ using `.to()` before execution.
96
+ """
97
+ # This allows the decorator to be used as @to_device OR @to_device(device='cuda')
98
+ if func is None:
99
+ return functools.partial(to_device, device=device, targets=targets)
100
+
101
+ @functools.wraps(func)
102
+ def wrapped(*args, **kwargs):
103
+ def move(t):
104
+ if isinstance(t, torch.Tensor):
105
+ return t.to(device)
106
+ return t
107
+
108
+ new_args = tuple(move(a) for a in args)
109
+ new_kwargs = {
110
+ k: move(v) if (targets is None or k in targets) else v
111
+ for k, v in kwargs.items()
112
+ }
113
+ return func(*new_args, **new_kwargs)
114
+
115
+ return wrapped
@@ -0,0 +1,17 @@
1
+ """
2
+ Copyright (c) 2026 Tobias Karusseit
3
+ This source code is licensed under the MIT license found in the
4
+ LICENSE file in the root directory of this source tree.
5
+ """
6
+
7
+ from ._alignment_loss import alignment_loss
8
+ from ._disentanglement_loss import disentanglement_loss
9
+ from ._isotropy_loss import isotropy_loss
10
+ from ._stability_loss import stability_loss
11
+
12
+ __all__ = [
13
+ "alignment_loss",
14
+ "disentanglement_loss",
15
+ "isotropy_loss",
16
+ "stability_loss",
17
+ ]