spacelearn 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61) hide show
  1. spacelearn-0.1.0/.github/workflows/spacelearn.yml +24 -0
  2. spacelearn-0.1.0/.gitignore +8 -0
  3. spacelearn-0.1.0/LICENSE +21 -0
  4. spacelearn-0.1.0/PKG-INFO +231 -0
  5. spacelearn-0.1.0/README.md +169 -0
  6. spacelearn-0.1.0/SPACE.md +190 -0
  7. spacelearn-0.1.0/docs/SETTINGS.md +121 -0
  8. spacelearn-0.1.0/docs/SPACELEARN.md +227 -0
  9. spacelearn-0.1.0/docs/data/DATA.md +31 -0
  10. spacelearn-0.1.0/docs/data/POOL.md +107 -0
  11. spacelearn-0.1.0/docs/data/QUANTITY.md +151 -0
  12. spacelearn-0.1.0/docs/data/UTILS.md +124 -0
  13. spacelearn-0.1.0/docs/graphics/Threshold_analysis_2.png +0 -0
  14. spacelearn-0.1.0/docs/loss/ALIGNMENT.md +69 -0
  15. spacelearn-0.1.0/docs/loss/DISENTANGLEMENT.md +88 -0
  16. spacelearn-0.1.0/docs/loss/ISOTROPY.md +61 -0
  17. spacelearn-0.1.0/docs/loss/LOSS.md +31 -0
  18. spacelearn-0.1.0/docs/loss/STABILITY.md +64 -0
  19. spacelearn-0.1.0/docs/optim/FIT_DECODER.md +69 -0
  20. spacelearn-0.1.0/docs/optim/OPTIM.md +39 -0
  21. spacelearn-0.1.0/docs/optim/OPTIMIZE_SUBSPACES.md +77 -0
  22. spacelearn-0.1.0/docs/optim/OPTIMIZE_TAREGTS.md +81 -0
  23. spacelearn-0.1.0/docs/optim/OPTIMZE_LATENT.md +62 -0
  24. spacelearn-0.1.0/docs/optim/SUBSPACE_CHANGE.md +90 -0
  25. spacelearn-0.1.0/pyproject.toml +40 -0
  26. spacelearn-0.1.0/src/spacelearn/__init__.py +16 -0
  27. spacelearn-0.1.0/src/spacelearn/_combined.py +261 -0
  28. spacelearn-0.1.0/src/spacelearn/_find_subspaces.py +101 -0
  29. spacelearn-0.1.0/src/spacelearn/data/__init__.py +11 -0
  30. spacelearn-0.1.0/src/spacelearn/data/_compute_Q.py +70 -0
  31. spacelearn-0.1.0/src/spacelearn/data/_pipeline.py +55 -0
  32. spacelearn-0.1.0/src/spacelearn/data/_pool.py +35 -0
  33. spacelearn-0.1.0/src/spacelearn/data/_utils.py +115 -0
  34. spacelearn-0.1.0/src/spacelearn/loss/__init__.py +17 -0
  35. spacelearn-0.1.0/src/spacelearn/loss/_alignment_loss.py +59 -0
  36. spacelearn-0.1.0/src/spacelearn/loss/_disentanglement_loss.py +111 -0
  37. spacelearn-0.1.0/src/spacelearn/loss/_isotropy_loss.py +56 -0
  38. spacelearn-0.1.0/src/spacelearn/loss/_stability_loss.py +44 -0
  39. spacelearn-0.1.0/src/spacelearn/optim/__init__.py +19 -0
  40. spacelearn-0.1.0/src/spacelearn/optim/_fit_decoder.py +56 -0
  41. spacelearn-0.1.0/src/spacelearn/optim/_optimize_latent.py +92 -0
  42. spacelearn-0.1.0/src/spacelearn/optim/_optimize_subspaces.py +109 -0
  43. spacelearn-0.1.0/src/spacelearn/optim/_optimize_targets.py +119 -0
  44. spacelearn-0.1.0/src/spacelearn/optim/_subspace_change.py +52 -0
  45. spacelearn-0.1.0/src/spacelearn/settings.py +20 -0
  46. spacelearn-0.1.0/src/spacelearn/util/__init__.py +9 -0
  47. spacelearn-0.1.0/src/spacelearn/util/_eigen_fallback.py +126 -0
  48. spacelearn-0.1.0/src/spacelearn/util/_lowrank_fallback.py +25 -0
  49. spacelearn-0.1.0/src/spacelearn/util/_svd_fallback.py +24 -0
  50. spacelearn-0.1.0/src/spacelearn/util/svd_util.py +60 -0
  51. spacelearn-0.1.0/tests/test_compute_Q.py +52 -0
  52. spacelearn-0.1.0/tests/test_data_utils.py +49 -0
  53. spacelearn-0.1.0/tests/test_input_to_q.py +54 -0
  54. spacelearn-0.1.0/tests/test_losses.py +329 -0
  55. spacelearn-0.1.0/tests/test_minimal_latent_init.py +105 -0
  56. spacelearn-0.1.0/tests/test_optim.py +248 -0
  57. spacelearn-0.1.0/tests/test_subspace_change.py +132 -0
  58. spacelearn-0.1.0/tests/test_subspace_solver.py +219 -0
  59. spacelearn-0.1.0/tests/test_subspaces.py +294 -0
  60. spacelearn-0.1.0/tests/test_svd_backend.py +352 -0
  61. spacelearn-0.1.0/tox.ini +10 -0
@@ -0,0 +1,24 @@
1
+ name: spacelearn
2
+
3
+ on: [push, pull_request]
4
+
5
+ jobs:
6
+ build:
7
+ runs-on: ubuntu-latest
8
+ strategy:
9
+ matrix:
10
+ python-version: ["3.10", "3.11", "3.12"]
11
+
12
+ steps:
13
+ - uses: actions/checkout@v4
14
+ - name: Set up Python ${{ matrix.python-version }}
15
+ uses: actions/setup-python@v4
16
+ with:
17
+ python-version: ${{ matrix.python-version }}
18
+ - name: Install tox
19
+ run: |
20
+ python -m pip install --upgrade pip
21
+ pip install tox
22
+
23
+ - name: Run tests with tox
24
+ run: tox
@@ -0,0 +1,8 @@
1
+ __pycache__/
2
+ venv/
3
+ .pytest_cache/
4
+ dist/
5
+ build/
6
+ *.egg-info/
7
+ .pytest_cache/
8
+ .tox/
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 Tobias Karusseit
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,231 @@
1
+ Metadata-Version: 2.4
2
+ Name: spacelearn
3
+ Version: 0.1.0
4
+ Summary: Guided latent subspace learning for deep neural networks
5
+ Project-URL: Homepage, https://github.com/K-T0BIAS/SPACE
6
+ Project-URL: Issues, https://github.com/K-T0BIAS/SPACE/issues
7
+ Author-email: Tobias Karusseit <karusseittobi@gmail.com>
8
+ License: MIT License
9
+
10
+ Copyright (c) [2026] [Tobias Karusseit]
11
+
12
+ Permission is hereby granted, free of charge, to any person obtaining a copy
13
+ of this software and associated documentation files (the "Software"), to deal
14
+ in the Software without restriction, including without limitation the rights
15
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16
+ copies of the Software, and to permit persons to whom the Software is
17
+ furnished to do so, subject to the following conditions:
18
+
19
+ The above copyright notice and this permission notice shall be included in all
20
+ copies or substantial portions of the Software.
21
+
22
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28
+ SOFTWARE.
29
+ License-File: LICENSE
30
+ Keywords: latent-space,physics-informed,pytorch,representation-learning,subspaces
31
+ Classifier: License :: OSI Approved :: MIT License
32
+ Classifier: Operating System :: OS Independent
33
+ Classifier: Programming Language :: Python :: 3
34
+ Requires-Python: >=3.8
35
+ Requires-Dist: numpy>=2.0.0
36
+ Requires-Dist: torch>=2.0.0
37
+ Provides-Extra: dev
38
+ Requires-Dist: pytest; extra == 'dev'
39
+ Requires-Dist: tox; extra == 'dev'
40
+ Description-Content-Type: text/markdown
41
+
42
+ [README](./README.md)
43
+
44
+ # SPACE (Subspace Partitioning for Accessible Controlled Encodings)
45
+
46
+ `SPACE-Learn` provides a utility library built on `pytorch` that offers functions to train your deep-learning models in the `Latent-SPACE` regime. The SPACE method helps create structured and easy-to-decode embedding spaces, or explicitly guide a model's internal latent space, using side information available during training. This method encodes features that are known at training time, or intermediate results, directly into a model's latent space.
47
+
48
+ ## Example Setup
49
+
50
+ ### 1. Install
51
+
52
+ ```bash
53
+ pip install spacelearn
54
+ ```
55
+
56
+ ### 2. Example Training Setup
57
+
58
+ ```python
59
+ import torch
60
+ from torch.optim import Adam
61
+ from collections import defaultdict
62
+
63
+ from spacelearn import (
64
+ solve_dims,
65
+ solve_subspaces,
66
+ combined_loss,
67
+ )
68
+ from spacelearn.data import input_to_quantity
69
+ from spacelearn.optim import minimal_latent_dim
70
+
71
+ from my_model import Model
72
+ from my_data import (
73
+ DataLoader,
74
+ quantity_helper_a,
75
+ quantity_helper_b,
76
+ )
77
+
78
+ EPOCHS = 5
79
+ LR = 1e-3
80
+
81
+ MAX_BINS = 10
82
+
83
+ # target variance retained by pooling
84
+ BIN_THRESHOLDS = 0.80
85
+
86
+ # target variance retained by subspaces
87
+ K_THRESHOLDS = {
88
+ "A": 0.90,
89
+ "B": 0.85,
90
+ }
91
+
92
+ REEVAL_EVERY = 2
93
+ INIT_SAMPLES = 100
94
+
95
+ dataloader = DataLoader()
96
+
97
+ # ------------------------------------------------------------------
98
+ # 1. Collect data for dimension estimation
99
+ # ------------------------------------------------------------------
100
+
101
+ test_data = defaultdict(list)
102
+ samples = []
103
+
104
+ for i in range(INIT_SAMPLES):
105
+ sample = dataloader[i]
106
+
107
+ test_data["A"].append(
108
+ quantity_helper_a(samples=sample)
109
+ )
110
+ test_data["B"].append(
111
+ quantity_helper_b(samples=sample)
112
+ )
113
+
114
+ samples.append(sample)
115
+
116
+ test_data["A"] = torch.stack(test_data["A"])
117
+ test_data["B"] = torch.stack(test_data["B"])
118
+
119
+ samples = torch.stack(samples)
120
+
121
+ # ------------------------------------------------------------------
122
+ # 2. Estimate n and k
123
+ # ------------------------------------------------------------------
124
+
125
+ dims = solve_dims(
126
+ test_data,
127
+ max_bins=MAX_BINS,
128
+ bin_thresholds=BIN_THRESHOLDS,
129
+ k_thresholds=K_THRESHOLDS,
130
+ k_per_quantity=True,
131
+ )
132
+
133
+ # shared pooling resolution
134
+ N = next(iter(dims.values()))[0]
135
+
136
+ # per-quantity subspace dimensions
137
+ K = {q: k for q, (_, k) in dims.items()}
138
+
139
+ # latent dimension
140
+ D = minimal_latent_dim(
141
+ K,
142
+ free_frac=0.10,
143
+ )
144
+
145
+ # ------------------------------------------------------------------
146
+ # 3. Build quantity extraction helper
147
+ # ------------------------------------------------------------------
148
+
149
+ pool_helper = input_to_quantity(
150
+ N,
151
+ "avg",
152
+ A=quantity_helper_a,
153
+ B=quantity_helper_b,
154
+ )
155
+
156
+ # ------------------------------------------------------------------
157
+ # 4. Initialize model
158
+ # ------------------------------------------------------------------
159
+
160
+ model = Model(D)
161
+ optimizer = Adam(model.parameters(), lr=LR)
162
+
163
+ # ------------------------------------------------------------------
164
+ # 5. Training loop
165
+ # ------------------------------------------------------------------
166
+
167
+ W_prev = None
168
+
169
+ for epoch in range(EPOCHS):
170
+
171
+ # periodically recompute subspaces
172
+ if epoch % REEVAL_EVERY == 0:
173
+
174
+ with torch.no_grad():
175
+ Z_ref = model(samples)
176
+
177
+ Y_ref = pool_helper(samples=samples)
178
+
179
+ WAV = solve_subspaces(
180
+ Z_ref,
181
+ Y_ref,
182
+ k=K,
183
+ )
184
+
185
+ W = WAV.W
186
+ A = WAV.A
187
+
188
+ if W_prev is None:
189
+ W_prev = {
190
+ q: w.clone()
191
+ for q, w in W.items()
192
+ }
193
+
194
+ for batch in dataloader:
195
+
196
+ Y = pool_helper(samples=batch)
197
+ Z = model(batch)
198
+
199
+ space_loss = combined_loss(
200
+ Z,
201
+ W,
202
+ A,
203
+ Y,
204
+ W_prev,
205
+ )
206
+
207
+ # task_loss = ...
208
+ # loss = task_loss + space_loss
209
+
210
+ optimizer.zero_grad()
211
+ space_loss.backward()
212
+ optimizer.step()
213
+
214
+ W_prev = {
215
+ q: w.clone()
216
+ for q, w in W.items()
217
+ }
218
+ ```
219
+
220
+
221
+ ## DOCS:
222
+
223
+ [spacelearn](./docs/SPACELEARN.md)
224
+
225
+ [spacelearn.settings](./docs/SETTINGS.md)
226
+
227
+ [spacelearn.data](./docs/data/DATA.md)
228
+
229
+ [spacelearn.loss](./docs/loss/LOSS.md)
230
+
231
+ [spacelearn.optim](./docs/optim/OPTIM.md)
@@ -0,0 +1,169 @@
1
+ # SPACE (Subspace Partitioning for Accessible Controlled Encodings)
2
+
3
+ `spacelearn` is a utility library built on top of `PyTorch` that provides tools for training deep-learning models in the **Latent-SPACE** regime. The SPACE method helps create structured, interpretable, and easily decodable embedding spaces by explicitly guiding a model's internal latent representation using side information available during training.
4
+
5
+ Rather than treating the latent space as an unconstrained representation, SPACE partitions it into dedicated subspaces that encode specific physical quantities, intermediate computations, or other known properties of the data. This allows these quantities to be decoded directly from the latent representation while still preserving unconstrained latent capacity for the model to learn additional features.
6
+
7
+ #### [Documentation](./SPACE.md)
8
+
9
+ ## The Method
10
+
11
+ ### 1. Embedding Quantities
12
+
13
+ Since SPACE is used to embed specific information into the latent space, the first step is defining the target quantities. These quantities may have any number of dimensions; however, the method generally works best when targets contain between roughly 10 and a few thousand elements.
14
+
15
+ *Soft guidelines for target quantities:*
16
+
17
+ * Quantities should contain enough information that they cannot easily be represented by only a few latent dimensions.
18
+
19
+ * **Bad example:** a 2D rotation matrix
20
+
21
+ \[
22
+ \mathcal{R}(\theta)=
23
+ \left(
24
+ \begin{matrix}
25
+ \cos(\theta) & -\sin(\theta) \\
26
+ \sin(\theta) & \cos(\theta)
27
+ \end{matrix}
28
+ \right)
29
+ \]
30
+
31
+ * **Good example:** the spatial gradient of a 64 × 64 image
32
+
33
+ \[
34
+ G_x=
35
+ \left(
36
+ \begin{matrix}
37
+ -1 & 0 & 1 \\
38
+ -2 & 0 & 2 \\
39
+ -1 & 0 & 1
40
+ \end{matrix}
41
+ \right),
42
+ \qquad
43
+ G_y=
44
+ \left(
45
+ \begin{matrix}
46
+ -1 & -2 & -1 \\
47
+ 0 & 0 & 0 \\
48
+ 1 & 2 & 1
49
+ \end{matrix}
50
+ \right)
51
+ \]
52
+
53
+ \[
54
+ G_{A,x}=G_x*A,
55
+ \qquad
56
+ G_{A,y}=G_y*A
57
+ \]
58
+
59
+ * Quantities with only a few elements can often be encoded directly into a small number of latent dimensions instead of first projecting them into learned subspaces.
60
+ * Quantities should exhibit predictable variance across samples or be analyzed using relatively high variance-retention thresholds. For quantities with highly variable sample-to-sample structure, it is often beneficial to estimate subspaces and decoders using large and diverse datasets.
61
+ * Since SPACE relies heavily on singular value decompositions and eigendecompositions, the underlying data distributions should be reasonably well-conditioned.
62
+
63
+ ### 2. Setup / Hyperparameter Tuning
64
+
65
+ Once the target quantities have been defined, SPACE provides utilities to estimate all major SPACE-specific hyperparameters:
66
+
67
+ * `optimal_bins()` / `solve_dims()` for determining the pooling resolution `n`
68
+ * `k_per_q()` / `solve_dims()` for determining subspace dimensions `k`
69
+ * `minimal_latent_dim()` for determining the latent dimension `D`
70
+
71
+ #### Step 1: Choose Pooling Resolution
72
+
73
+ Use `optimal_bins()` (or `solve_dims()`) to estimate the minimum pooling resolution required to retain a desired fraction of spatial variance.
74
+
75
+ You must choose a target variance-retention threshold for the pooled quantities. Depending on the structure of the quantity, this threshold may have a dramatic or only minor effect on the resulting pooling resolution.
76
+
77
+ #### Step 2: Choose Subspace Dimensions
78
+
79
+ After determining the pooling resolution, use `k_per_q()` (or `solve_dims()`) to estimate the minimum subspace dimension required for each quantity.
80
+
81
+ The resulting value of `k` depends on:
82
+
83
+ * The variance-retention threshold chosen for the subspace
84
+ * The pooling resolution `n`
85
+ * The structure and complexity of the quantity itself
86
+
87
+ ![image showing the relationship of n and k on 2D thermodynamic quantities](./docs/graphics/Threshold_analysis_2.png)
88
+
89
+ *Relationship between pooling resolution (`n`) and subspace dimensionality (`k`) for 2D thermodynamic quantities (original quantity shape: 64 × 64).*
90
+
91
+ #### Step 3: Choose the Latent Dimension
92
+
93
+ Once `n` and `k` are known, use `minimal_latent_dim()` to estimate the required latent dimension `D`.
94
+
95
+ This computation reserves enough latent capacity for all learned subspaces while optionally leaving a configurable fraction of the latent space unconstrained.
96
+
97
+ We recommend reserving free latent capacity because:
98
+
99
+ * Some quantities require additional latent freedom to find stable subspace geometries
100
+ * Joint training setups often benefit from latent dimensions that are not explicitly assigned to any known quantity
101
+ * Models frequently learn useful hidden representations that are not captured by the selected physical targets
102
+
103
+ ### 3. Training
104
+
105
+ #### Step 1: Initialize Subspaces and Decoders
106
+
107
+ Before beginning the main training loop, initialize the subspaces and decoders.
108
+
109
+ We recommend collecting a dataset similar in size to the one used when estimating `n` and `k`.
110
+
111
+ Subspaces and decoders can then be initialized using:
112
+
113
+ ```python
114
+ WVA = solve_subspaces(Z, pooled_targets, k=k)
115
+ ```
116
+
117
+ or individually through:
118
+
119
+ * `principal_directions()`
120
+ * `subspaces()`
121
+ * `fit_decoders()`
122
+
123
+ #### Step 2: Train Using the SPACE Objective
124
+
125
+ Add the SPACE objective to your training loss in the same way as any other regularization term.
126
+
127
+ The simplest approach is to use:
128
+
129
+ ```python
130
+ loss = combined_loss(
131
+ Z,
132
+ W,
133
+ A,
134
+ pooled_targets,
135
+ W_prev
136
+ )
137
+ ```
138
+
139
+ Internally this combines:
140
+
141
+ * `alignment_loss()`
142
+ * `disentanglement_loss()`
143
+ * `isotropy_loss()`
144
+ * `stability_loss()`
145
+
146
+ #### Step 3: Periodically Recompute Subspaces
147
+
148
+ We recommend periodically recomputing:
149
+
150
+ * `W` (subspace projectors)
151
+ * `A` (subspace decoders)
152
+
153
+ using `solve_subspaces()`.
154
+
155
+ Experiments have shown that an exponential decay schedule often produces the best subspace geometries. In practice, this means recomputing relatively frequently during the early stages of training and less frequently later on.
156
+
157
+ We generally do not recommend recomputing subspaces within an epoch unless extremely large datasets are being used.
158
+
159
+ ## Architectures and Setups That Work Well
160
+
161
+ ### 1. Separate Encoder / Decoder U-Net
162
+
163
+ Our initial experiments used a U-Net-style architecture in which the encoder (including the bottleneck) was trained under the SPACE regime, while the decoder was trained separately on the final inference objective.
164
+
165
+ This setup showed promising results for physics-guided inference tasks where conventional U-Nets often struggle to learn physically meaningful latent representations.
166
+
167
+ ### 2. Additional Architectures
168
+
169
+ This section is currently a work in progress and may be published separately in the future.
@@ -0,0 +1,190 @@
1
+ [README](./README.md)
2
+
3
+ # SPACE (Subspace Partitioning for Accessible Controlled Encodings)
4
+
5
+ `SPACE-Learn` provides a utility library built on `pytorch` that offers functions to train your deep-learning models in the `Latent-SPACE` regime. The SPACE method helps create structured and easy-to-decode embedding spaces, or explicitly guide a model's internal latent space, using side information available during training. This method encodes features that are known at training time, or intermediate results, directly into a model's latent space.
6
+
7
+ ## Example Setup
8
+
9
+ ### 1. Install
10
+
11
+ ```bash
12
+ pip install spacelearn
13
+ ```
14
+
15
+ ### 2. Example Training Setup
16
+
17
+ ```python
18
+ import torch
19
+ from torch.optim import Adam
20
+ from collections import defaultdict
21
+
22
+ from spacelearn import (
23
+ solve_dims,
24
+ solve_subspaces,
25
+ combined_loss,
26
+ )
27
+ from spacelearn.data import input_to_quantity
28
+ from spacelearn.optim import minimal_latent_dim
29
+
30
+ from my_model import Model
31
+ from my_data import (
32
+ DataLoader,
33
+ quantity_helper_a,
34
+ quantity_helper_b,
35
+ )
36
+
37
+ EPOCHS = 5
38
+ LR = 1e-3
39
+
40
+ MAX_BINS = 10
41
+
42
+ # target variance retained by pooling
43
+ BIN_THRESHOLDS = 0.80
44
+
45
+ # target variance retained by subspaces
46
+ K_THRESHOLDS = {
47
+ "A": 0.90,
48
+ "B": 0.85,
49
+ }
50
+
51
+ REEVAL_EVERY = 2
52
+ INIT_SAMPLES = 100
53
+
54
+ dataloader = DataLoader()
55
+
56
+ # ------------------------------------------------------------------
57
+ # 1. Collect data for dimension estimation
58
+ # ------------------------------------------------------------------
59
+
60
+ test_data = defaultdict(list)
61
+ samples = []
62
+
63
+ for i in range(INIT_SAMPLES):
64
+ sample = dataloader[i]
65
+
66
+ test_data["A"].append(
67
+ quantity_helper_a(samples=sample)
68
+ )
69
+ test_data["B"].append(
70
+ quantity_helper_b(samples=sample)
71
+ )
72
+
73
+ samples.append(sample)
74
+
75
+ test_data["A"] = torch.stack(test_data["A"])
76
+ test_data["B"] = torch.stack(test_data["B"])
77
+
78
+ samples = torch.stack(samples)
79
+
80
+ # ------------------------------------------------------------------
81
+ # 2. Estimate n and k
82
+ # ------------------------------------------------------------------
83
+
84
+ dims = solve_dims(
85
+ test_data,
86
+ max_bins=MAX_BINS,
87
+ bin_thresholds=BIN_THRESHOLDS,
88
+ k_thresholds=K_THRESHOLDS,
89
+ k_per_quantity=True,
90
+ )
91
+
92
+ # shared pooling resolution
93
+ N = next(iter(dims.values()))[0]
94
+
95
+ # per-quantity subspace dimensions
96
+ K = {q: k for q, (_, k) in dims.items()}
97
+
98
+ # latent dimension
99
+ D = minimal_latent_dim(
100
+ K,
101
+ free_frac=0.10,
102
+ )
103
+
104
+ # ------------------------------------------------------------------
105
+ # 3. Build quantity extraction helper
106
+ # ------------------------------------------------------------------
107
+
108
+ pool_helper = input_to_quantity(
109
+ N,
110
+ "avg",
111
+ A=quantity_helper_a,
112
+ B=quantity_helper_b,
113
+ )
114
+
115
+ # ------------------------------------------------------------------
116
+ # 4. Initialize model
117
+ # ------------------------------------------------------------------
118
+
119
+ model = Model(D)
120
+ optimizer = Adam(model.parameters(), lr=LR)
121
+
122
+ # ------------------------------------------------------------------
123
+ # 5. Training loop
124
+ # ------------------------------------------------------------------
125
+
126
+ W_prev = None
127
+
128
+ for epoch in range(EPOCHS):
129
+
130
+ # periodically recompute subspaces
131
+ if epoch % REEVAL_EVERY == 0:
132
+
133
+ with torch.no_grad():
134
+ Z_ref = model(samples)
135
+
136
+ Y_ref = pool_helper(samples=samples)
137
+
138
+ WAV = solve_subspaces(
139
+ Z_ref,
140
+ Y_ref,
141
+ k=K,
142
+ )
143
+
144
+ W = WAV.W
145
+ A = WAV.A
146
+
147
+ if W_prev is None:
148
+ W_prev = {
149
+ q: w.clone()
150
+ for q, w in W.items()
151
+ }
152
+
153
+ for batch in dataloader:
154
+
155
+ Y = pool_helper(samples=batch)
156
+ Z = model(batch)
157
+
158
+ space_loss = combined_loss(
159
+ Z,
160
+ W,
161
+ A,
162
+ Y,
163
+ W_prev,
164
+ )
165
+
166
+ # task_loss = ...
167
+ # loss = task_loss + space_loss
168
+
169
+ optimizer.zero_grad()
170
+ space_loss.backward()
171
+ optimizer.step()
172
+
173
+ W_prev = {
174
+ q: w.clone()
175
+ for q, w in W.items()
176
+ }
177
+ ```
178
+
179
+
180
+ ## DOCS:
181
+
182
+ [spacelearn](./docs/SPACELEARN.md)
183
+
184
+ [spacelearn.settings](./docs/SETTINGS.md)
185
+
186
+ [spacelearn.data](./docs/data/DATA.md)
187
+
188
+ [spacelearn.loss](./docs/loss/LOSS.md)
189
+
190
+ [spacelearn.optim](./docs/optim/OPTIM.md)