vlora-dev 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vlora/subspace.py ADDED
@@ -0,0 +1,651 @@
+ """SharedSubspace — core state container and 3-step algorithm.
+
+ Step 1: from_adapters — build shared basis via SVD
+ Step 2: project — project new adapter onto basis
+ Step 3: absorb — incorporate new adapter, recompute basis
+ """
+
+ from __future__ import annotations
+
+ import logging
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Literal
+
+ import torch
+ from safetensors.torch import load_file, save_file
+ from torch import Tensor
+
+ from vlora._validate import (
+     check_adapter_matches_subspace,
+     check_adapters_compatible,
+     check_task_exists,
+     check_tensor_health,
+ )
+ from vlora.io import LoRAWeights, stack_lora_weights
+ from vlora.ops import (
+     compute_svd,
+     explained_variance_ratio,
+     gram_schmidt,
+     incremental_svd_update,
+     project_onto_subspace,
+     reconstruct_from_subspace,
+     select_num_components,
+ )
+
+ logger = logging.getLogger("vlora")
+
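+ # Typical end-to-end flow (an illustrative sketch, not executed on import; it assumes
+ # adapters loaded elsewhere, e.g. with vlora.io.load_adapter):
+ #
+ #     subspace = SharedSubspace.from_adapters([adapter_0, adapter_1], task_ids=["t0", "t1"])
+ #     proj = subspace.project(new_adapter, task_id="t2")  # Step 2: project onto the basis
+ #     subspace.add_task(proj)                             # register without changing the basis
+ #     subspace.absorb(other_adapter, new_task_id="t3")    # Step 3: fold in, recompute the basis
+ #     restored = subspace.reconstruct("t0")               # approximate LoRAWeights for a task
+
+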
+ @dataclass
+ class TaskProjection:
+     """A single task's representation in the shared subspace."""
+
+     task_id: str
+     loadings_a: dict[str, Tensor]  # layer_name -> (k,)
+     loadings_b: dict[str, Tensor]  # layer_name -> (k,)
+
+
+ class SharedSubspace:
+     """Shared low-rank subspace for LoRA adapters.
+
+     Maintains per-layer orthonormal basis vectors (components) and
+     per-task coefficient vectors (loadings).
+     """
+
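+     # Per-layer state layout (shapes implied by the code below; D_a = rank * in_features,
+     # D_b = out_features * rank for that layer's flattened A/B matrices):
+     #   components_a[layer]: (k, D_a) rows forming an orthonormal basis
+     #   means_a[layer]:      (D_a,)
+     #   loadings_a[layer]:   (k,) per task (stored on TaskProjection)
+     # Conceptually, a task's flattened A matrix is recovered as roughly
+     # means_a[layer] + loadings_a[layer] @ components_a[layer], and likewise for B.
+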
+     def __init__(
+         self,
+         layer_names: list[str],
+         components_a: dict[str, Tensor],
+         components_b: dict[str, Tensor],
+         singular_values_a: dict[str, Tensor],
+         singular_values_b: dict[str, Tensor],
+         means_a: dict[str, Tensor],
+         means_b: dict[str, Tensor],
+         tasks: dict[str, TaskProjection],
+         rank: int,
+         num_components: int,
+     ):
+         self.layer_names = layer_names
+         self.components_a = components_a
+         self.components_b = components_b
+         self.singular_values_a = singular_values_a
+         self.singular_values_b = singular_values_b
+         self.means_a = means_a
+         self.means_b = means_b
+         self.tasks = tasks
+         self.rank = rank
+         self.num_components = num_components
+
+     @classmethod
+     def from_adapters(
+         cls,
+         adapters: list[LoRAWeights],
+         task_ids: list[str] | None = None,
+         variance_threshold: float = 0.6,
+         num_components: int | None = None,
+         adaptive_k: bool = False,
+     ) -> SharedSubspace:
+         """Step 1: Build shared subspace from existing adapters.
+
+         Stacks each adapter's flattened weights, runs SVD per layer,
+         and projects all adapters onto the resulting basis.
+
+         Args:
+             adapters: List of LoRA adapters to initialize from.
+             task_ids: Names for each adapter. Defaults to "task_0", "task_1", etc.
+             variance_threshold: Minimum cumulative variance to explain (used if
+                 num_components is None).
+             num_components: Explicit number of basis vectors per layer.
+                 Overrides variance_threshold if set.
+             adaptive_k: If True, select k independently per layer based on
+                 variance_threshold. Each layer gets the minimal k that explains
+                 the threshold. Overrides num_components.
+         """
+         check_adapters_compatible(adapters)
+         logger.info("Building subspace from %d adapters", len(adapters))
+
+         if task_ids is None:
+             task_ids = [f"task_{i}" for i in range(len(adapters))]
+         if len(task_ids) != len(adapters):
+             raise ValueError("task_ids length must match adapters length")
+
+         # Intersect layer names across all adapters for safety
+         layer_set = set(adapters[0].layer_names)
+         for adapter in adapters[1:]:
+             layer_set &= set(adapter.layer_names)
+         layer_names = sorted(layer_set)
+
+         if not layer_names:
+             raise ValueError("Adapters share no common layers")
+
+         if len(layer_names) < len(adapters[0].layer_names):
+             import warnings
+
+             dropped = set(adapters[0].layer_names) - layer_set
+             warnings.warn(
+                 f"Adapters have different layer sets. Using {len(layer_names)} "
+                 f"common layers (dropped {len(dropped)}: {sorted(dropped)[:3]}...)"
+             )
+
+         rank = adapters[0].rank
+
+         # Stack weights into data matrices
+         stacked_a = stack_lora_weights(adapters, side="A")
+         stacked_b = stack_lora_weights(adapters, side="B")
+
+         components_a: dict[str, Tensor] = {}
+         components_b: dict[str, Tensor] = {}
+         sv_a: dict[str, Tensor] = {}
+         sv_b: dict[str, Tensor] = {}
+         means_a: dict[str, Tensor] = {}
+         means_b: dict[str, Tensor] = {}
+
+         resolved_k: int | None = None
+
+         for layer in layer_names:
+             for side, stacked, comp_dict, sv_dict, mean_dict in [
+                 ("A", stacked_a, components_a, sv_a, means_a),
+                 ("B", stacked_b, components_b, sv_b, means_b),
+             ]:
+                 data = stacked[layer]
+                 comps, svals, mean = compute_svd(data, num_components=None, center=True)
+
+                 if adaptive_k:
+                     # Per-layer: each layer/side gets its own k
+                     k = select_num_components(svals, variance_threshold)
+                 elif num_components is not None:
+                     k = min(num_components, len(svals))
+                 else:
+                     k = select_num_components(svals, variance_threshold)
+
+                 if not adaptive_k:
+                     if resolved_k is None:
+                         resolved_k = k
+                     # Use consistent k across layers for simplicity
+                     k = resolved_k
+
+                 comp_dict[layer] = comps[:k]
+                 sv_dict[layer] = svals[:k]
+                 mean_dict[layer] = mean
+
+         # For adaptive_k, use the max per-layer k as the reported num_components
+         if adaptive_k:
+             resolved_k = max(
+                 max(components_a[l].shape[0], components_b[l].shape[0])
+                 for l in layer_names
+             )
+         resolved_k = resolved_k or 1
+
+         # Project all input adapters onto the basis
+         tasks: dict[str, TaskProjection] = {}
+         for adapter, tid in zip(adapters, task_ids):
+             loadings_a: dict[str, Tensor] = {}
+             loadings_b: dict[str, Tensor] = {}
+             for layer in layer_names:
+                 wa = adapter.lora_a[layer].flatten() - means_a[layer]
+                 wb = adapter.lora_b[layer].flatten() - means_b[layer]
+                 loadings_a[layer] = project_onto_subspace(wa, components_a[layer])
+                 loadings_b[layer] = project_onto_subspace(wb, components_b[layer])
+             tasks[tid] = TaskProjection(
+                 task_id=tid, loadings_a=loadings_a, loadings_b=loadings_b
+             )
+
+         logger.info(
+             "Subspace built: k=%d, layers=%d, tasks=%d, rank=%d",
+             resolved_k, len(layer_names), len(tasks), rank,
+         )
+
+         return cls(
+             layer_names=layer_names,
+             components_a=components_a,
+             components_b=components_b,
+             singular_values_a=sv_a,
+             singular_values_b=sv_b,
+             means_a=means_a,
+             means_b=means_b,
+             tasks=tasks,
+             rank=rank,
+             num_components=resolved_k,
+         )
+
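+     # Three ways to pick the number of components k (a sketch of the options above;
+     # the numbers are placeholders, not recommendations):
+     #
+     #     SharedSubspace.from_adapters(adapters)                    # k from variance_threshold=0.6,
+     #                                                               # shared across layers (first layer decides)
+     #     SharedSubspace.from_adapters(adapters, num_components=4)  # fixed k everywhere
+     #     SharedSubspace.from_adapters(adapters, adaptive_k=True, variance_threshold=0.9)
+     #                                                               # independent k per layer and side
+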
+     def project(self, adapter: LoRAWeights, task_id: str) -> TaskProjection:
+         """Step 2a: Project a new adapter onto the existing basis."""
+         check_adapter_matches_subspace(adapter, self, "project")
+         loadings_a: dict[str, Tensor] = {}
+         loadings_b: dict[str, Tensor] = {}
+
+         for layer in self.layer_names:
+             wa = adapter.lora_a[layer].flatten() - self.means_a[layer]
+             wb = adapter.lora_b[layer].flatten() - self.means_b[layer]
+             loadings_a[layer] = project_onto_subspace(wa, self.components_a[layer])
+             loadings_b[layer] = project_onto_subspace(wb, self.components_b[layer])
+
+         return TaskProjection(
+             task_id=task_id, loadings_a=loadings_a, loadings_b=loadings_b
+         )
+
+     def add_task(self, projection: TaskProjection) -> None:
+         """Register a projected task in the subspace."""
+         self.tasks[projection.task_id] = projection
+
+     def reconstruct(self, task_id: str) -> LoRAWeights:
+         """Reconstruct full LoRA weights for a task from its loadings."""
+         check_task_exists(self, task_id)
+
+         proj = self.tasks[task_id]
+         lora_a: dict[str, Tensor] = {}
+         lora_b: dict[str, Tensor] = {}
+
+         for layer in self.layer_names:
+             flat_a = reconstruct_from_subspace(
+                 self.components_a[layer], proj.loadings_a[layer]
+             ) + self.means_a[layer]
+             flat_b = reconstruct_from_subspace(
+                 self.components_b[layer], proj.loadings_b[layer]
+             ) + self.means_b[layer]
+
+             # Recover original matrix shapes from the adapter's rank
+             # A: (rank, in_features), B: (out_features, rank)
+             ref_a_shape = (self.rank, flat_a.numel() // self.rank)
+             ref_b_shape = (flat_b.numel() // self.rank, self.rank)
+             lora_a[layer] = flat_a.reshape(ref_a_shape)
+             lora_b[layer] = flat_b.reshape(ref_b_shape)
+
+         return LoRAWeights(
+             layer_names=self.layer_names,
+             lora_a=lora_a,
+             lora_b=lora_b,
+             rank=self.rank,
+         )
+
+     def absorb(self, new_adapter: LoRAWeights, new_task_id: str) -> None:
+         """Step 3: Absorb a new adapter, recomputing the shared basis.
+
+         Reconstructs all existing tasks, adds the new adapter, then
+         reruns SVD to produce an updated basis.
+         """
+         check_adapter_matches_subspace(new_adapter, self, "absorb")
+         logger.info(
+             "Absorbing adapter '%s' (full SVD recompute, %d existing tasks)",
+             new_task_id, len(self.tasks),
+         )
+
+         # Reconstruct all existing tasks as full adapters
+         all_adapters = []
+         all_ids = []
+         for tid in self.tasks:
+             all_adapters.append(self.reconstruct(tid))
+             all_ids.append(tid)
+
+         all_adapters.append(new_adapter)
+         all_ids.append(new_task_id)
+
+         # Rebuild subspace from scratch
+         new_sub = SharedSubspace.from_adapters(
+             all_adapters,
+             task_ids=all_ids,
+             num_components=self.num_components,
+         )
+
+         # Update self in-place
+         self.layer_names = new_sub.layer_names
+         self.components_a = new_sub.components_a
+         self.components_b = new_sub.components_b
+         self.singular_values_a = new_sub.singular_values_a
+         self.singular_values_b = new_sub.singular_values_b
+         self.means_a = new_sub.means_a
+         self.means_b = new_sub.means_b
+         self.tasks = new_sub.tasks
+         self.num_components = new_sub.num_components
+
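+     # absorb() recomputes the basis from scratch: every stored task is reconstructed,
+     # the new adapter is appended, and from_adapters runs over all N + 1 of them, e.g.
+     #
+     #     subspace.absorb(new_adapter, new_task_id="task_new")
+     #
+     # For large collections the incremental variant below avoids the full recompute.
+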
+     def absorb_incremental(self, new_adapter: LoRAWeights, new_task_id: str) -> None:
+         """Absorb a new adapter incrementally without full SVD recompute.
+
+         Instead of reconstructing all tasks and re-running SVD, this projects
+         the new adapter onto the existing basis, measures the residual, and
+         expands the basis with any significant new directions.
+
+         Much faster than absorb() for large collections, with a small
+         approximation trade-off.
+         """
+         check_adapter_matches_subspace(new_adapter, self, "absorb_incremental")
+         logger.debug("Absorbing adapter '%s' incrementally", new_task_id)
+         loadings_a: dict[str, Tensor] = {}
+         loadings_b: dict[str, Tensor] = {}
+
+         for layer in self.layer_names:
+             for side, weights_dict, comp_attr, sv_attr, mean_attr, load_dict in [
+                 ("a", new_adapter.lora_a, "components_a", "singular_values_a", "means_a", loadings_a),
+                 ("b", new_adapter.lora_b, "components_b", "singular_values_b", "means_b", loadings_b),
+             ]:
+                 components = getattr(self, comp_attr)[layer]
+                 svals = getattr(self, sv_attr)[layer]
+                 mean = getattr(self, mean_attr)[layer]
+                 flat = weights_dict[layer].flatten().unsqueeze(0)  # (1, D)
+
+                 new_comps, new_svals, new_mean, _ = incremental_svd_update(
+                     components, svals, mean,
+                     n_seen=len(self.tasks),
+                     new_data=flat,
+                     max_components=self.num_components,
+                 )
+
+                 getattr(self, comp_attr)[layer] = new_comps
+                 getattr(self, sv_attr)[layer] = new_svals
+                 getattr(self, mean_attr)[layer] = new_mean
+
+                 # Project the new adapter with the updated basis
+                 centered = flat.squeeze(0) - new_mean
+                 load_dict[layer] = project_onto_subspace(centered, new_comps)
+
+         # Resize existing tasks' loadings to match the updated basis
+         for proj in self.tasks.values():
+             for layer in self.layer_names:
+                 # Pad with zeros if the basis grew, truncate if it shrank
+                 for side, comp_attr, mean_attr, old_loads, new_loads_attr in [
+                     ("a", "components_a", "means_a", proj.loadings_a, "loadings_a"),
+                     ("b", "components_b", "means_b", proj.loadings_b, "loadings_b"),
+                 ]:
+                     new_comps = getattr(self, comp_attr)[layer]
+                     old = old_loads[layer]
+                     if old.shape[0] < new_comps.shape[0]:
+                         old = torch.cat([old, torch.zeros(new_comps.shape[0] - old.shape[0])])
+                     elif old.shape[0] > new_comps.shape[0]:
+                         old = old[:new_comps.shape[0]]
+                     old_loads[layer] = old
+
+         self.tasks[new_task_id] = TaskProjection(
+             task_id=new_task_id, loadings_a=loadings_a, loadings_b=loadings_b
+         )
+
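+     # The incremental path keeps the per-call cost roughly independent of the number of
+     # stored tasks, at some approximation cost: existing loadings are only padded or
+     # truncated to the new basis size, not re-fit. Sketch for streaming many adapters
+     # into an existing subspace (the adapter objects here are placeholders):
+     #
+     #     for tid, adapter in new_adapters.items():
+     #         subspace.absorb_incremental(adapter, tid)
+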
+     @classmethod
+     def from_adapters_streaming(
+         cls,
+         adapter_paths: list[str | Path],
+         task_ids: list[str] | None = None,
+         num_components: int = 4,
+     ) -> SharedSubspace:
+         """Build a subspace by streaming adapters one at a time from disk.
+
+         Only loads one adapter into memory at a time, unlike from_adapters,
+         which loads all of them simultaneously. Uses incremental SVD updates.
+
+         Args:
+             adapter_paths: Paths to adapter directories on disk.
+             task_ids: Names for each adapter.
+             num_components: Number of basis components.
+         """
+         from vlora.io import load_adapter
+
+         if not adapter_paths:
+             raise ValueError("Need at least one adapter path")
+
+         paths = [Path(p) for p in adapter_paths]
+         if task_ids is None:
+             task_ids = [p.name for p in paths]
+
+         # Initialize from the first two adapters if available, so the initial
+         # SVD has enough samples to find more than one component
+         if len(paths) >= 2:
+             init_adapters = [load_adapter(paths[0]), load_adapter(paths[1])]
+             init_ids = task_ids[:2]
+             remaining = list(zip(paths[2:], task_ids[2:]))
+         else:
+             init_adapters = [load_adapter(paths[0])]
+             init_ids = [task_ids[0]]
+             remaining = []
+
+         sub = cls.from_adapters(init_adapters, task_ids=init_ids, num_components=num_components)
+         # Ensure the target num_components is preserved even if the initial SVD
+         # had fewer samples than requested components
+         sub.num_components = num_components
+
+         # Stream the remaining adapters
+         for path, tid in remaining:
+             adapter = load_adapter(path)
+             sub.absorb_incremental(adapter, tid)
+
+         return sub
+
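+     # Streaming construction, e.g. for a directory of adapters too large to load at
+     # once (the path below is a placeholder):
+     #
+     #     paths = sorted(p for p in Path("adapters/").iterdir() if p.is_dir())
+     #     sub = SharedSubspace.from_adapters_streaming(paths, num_components=4)
+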
+     def to(self, device: str | torch.device | None = None, dtype: torch.dtype | None = None) -> SharedSubspace:
+         """Move all tensors to a device and/or dtype. Returns self."""
+         for layer in self.layer_names:
+             for attr in ["components_a", "components_b", "singular_values_a",
+                          "singular_values_b", "means_a", "means_b"]:
+                 d = getattr(self, attr)
+                 t = d[layer]
+                 if device is not None:
+                     t = t.to(device=device)
+                 if dtype is not None:
+                     t = t.to(dtype=dtype)
+                 d[layer] = t
+
+         for proj in self.tasks.values():
+             for layer in self.layer_names:
+                 for loads in [proj.loadings_a, proj.loadings_b]:
+                     t = loads[layer]
+                     if device is not None:
+                         t = t.to(device=device)
+                     if dtype is not None:
+                         t = t.to(dtype=dtype)
+                     loads[layer] = t
+
+         return self
+
+     def quantize(self, bits: int = 8) -> SharedSubspace:
+         """Quantize components to reduce memory footprint.
+
+         Applies symmetric per-tensor quantization to the component matrices.
+         Loadings and means are kept in float32 for accuracy. This is a
+         lossy operation — quantized components introduce small reconstruction
+         errors but can reduce memory by 2-4x.
+
+         Args:
+             bits: Quantization bit width (8 or 4). Default 8.
+
+         Returns:
+             self (modified in-place).
+         """
+         if bits not in (4, 8):
+             raise ValueError(f"bits must be 4 or 8, got {bits}")
+
+         qmax = (1 << (bits - 1)) - 1  # 127 for int8, 7 for int4
+
+         for layer in self.layer_names:
+             for attr in ["components_a", "components_b"]:
+                 d = getattr(self, attr)
+                 t = d[layer].float()
+                 # Symmetric quantization: scale = max(abs(t)) / qmax
+                 scale = t.abs().max() / qmax
+                 if scale == 0:
+                     continue
+                 # Quantize, round, dequantize
+                 quantized = (t / scale).round().clamp(-qmax, qmax)
+                 d[layer] = (quantized * scale).to(t.dtype)
+
+         return self
+
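+     # Worked example of the symmetric scheme above (illustrative numbers): with bits=8,
+     # qmax = 127; if a component tensor's largest magnitude is 0.5, then
+     # scale = 0.5 / 127 ≈ 0.003937, an entry of 0.25 quantizes to round(0.25 / scale) = 64
+     # and dequantizes to 64 * scale ≈ 0.2520, an error of about 0.002.
+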
+     def compression_stats(self) -> dict:
+         """Compute compression statistics for the current subspace.
+
+         Returns a dict with per-layer and aggregate stats including:
+         - components_per_layer: dict of layer -> (k_a, k_b)
+         - total_params_compressed: parameters in the compressed representation
+         - total_params_original: estimated parameters of the N original adapters
+         - compression_ratio: original / compressed
+         """
+         n_tasks = len(self.tasks)
+         total_compressed = 0
+         total_original = 0
+         per_layer = {}
+
+         for layer in self.layer_names:
+             k_a = self.components_a[layer].shape[0]
+             k_b = self.components_b[layer].shape[0]
+             dim_a = self.components_a[layer].shape[1]
+             dim_b = self.components_b[layer].shape[1]
+
+             # Compressed: components + means + per-task loadings
+             layer_compressed = (
+                 k_a * dim_a + k_b * dim_b  # components
+                 + dim_a + dim_b  # means
+                 + n_tasks * (k_a + k_b)  # loadings
+             )
+             # Original: N full adapter matrices
+             layer_original = n_tasks * (dim_a + dim_b)
+
+             per_layer[layer] = {
+                 "k_a": k_a, "k_b": k_b,
+                 "compressed": layer_compressed,
+                 "original": layer_original,
+             }
+             total_compressed += layer_compressed
+             total_original += layer_original
+
+         return {
+             "components_per_layer": {l: (d["k_a"], d["k_b"]) for l, d in per_layer.items()},
+             "total_params_compressed": total_compressed,
+             "total_params_original": total_original,
+             "compression_ratio": total_original / total_compressed if total_compressed > 0 else 0,
+             "num_tasks": n_tasks,
+             "num_layers": len(self.layer_names),
+         }
+
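+     # Rough arithmetic for a single layer (hypothetical sizes, not measurements): with
+     # rank 8 and 4096-dim projections, D_a = D_b = 8 * 4096 = 32768; for k_a = k_b = 4 and
+     # 100 tasks, compressed = 4*32768*2 + 32768*2 + 100*8 = 328,480 parameters versus
+     # original = 100 * 65,536 = 6,553,600, a ratio of roughly 20x.
+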
+     def get_trainable_params(
+         self, task_id: str, num_expand: int = 0
+     ) -> dict[str, Tensor]:
+         """Get trainable loading parameters for a task.
+
+         Useful for integrating with a training loop: freeze the components,
+         train only the loadings.
+
+         Args:
+             task_id: Task whose loadings to return.
+             num_expand: Number of extra orthogonal directions to add via
+                 Gram-Schmidt (gives the optimizer room to escape the subspace).
+
+         Returns:
+             Dict of parameter name -> tensor (with requires_grad=True).
+         """
+         if task_id not in self.tasks:
+             raise KeyError(f"Unknown task: {task_id}")
+
+         if num_expand > 0:
+             import warnings
+
+             warnings.warn(
+                 f"get_trainable_params(num_expand={num_expand}) will permanently "
+                 "expand the subspace basis via Gram-Schmidt. This modifies the "
+                 "subspace in-place and cannot be undone.",
+                 stacklevel=2,
+             )
+             for layer in self.layer_names:
+                 random_a = torch.randn(num_expand, self.components_a[layer].shape[1])
+                 random_b = torch.randn(num_expand, self.components_b[layer].shape[1])
+                 self.components_a[layer] = gram_schmidt(self.components_a[layer], random_a)
+                 self.components_b[layer] = gram_schmidt(self.components_b[layer], random_b)
+
+             # Extend the task's loadings with zeros for the new directions
+             proj = self.tasks[task_id]
+             for layer in self.layer_names:
+                 old_k_a = proj.loadings_a[layer].shape[0]
+                 new_k_a = self.components_a[layer].shape[0]
+                 if new_k_a > old_k_a:
+                     proj.loadings_a[layer] = torch.cat([
+                         proj.loadings_a[layer],
+                         torch.zeros(new_k_a - old_k_a),
+                     ])
+                 old_k_b = proj.loadings_b[layer].shape[0]
+                 new_k_b = self.components_b[layer].shape[0]
+                 if new_k_b > old_k_b:
+                     proj.loadings_b[layer] = torch.cat([
+                         proj.loadings_b[layer],
+                         torch.zeros(new_k_b - old_k_b),
+                     ])
+
+         params = {}
+         proj = self.tasks[task_id]
+         for layer in self.layer_names:
+             la = proj.loadings_a[layer].clone().detach().requires_grad_(True)
+             lb = proj.loadings_b[layer].clone().detach().requires_grad_(True)
+             params[f"{layer}.loadings_a"] = la
+             params[f"{layer}.loadings_b"] = lb
+
+         return params
+
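+     # Training-loop integration sketch (hypothetical; the loss and data are placeholders,
+     # and the caller is responsible for writing trained values back):
+     #
+     #     params = subspace.get_trainable_params("task_new")
+     #     opt = torch.optim.Adam(params.values(), lr=1e-3)
+     #     ...forward with reconstructed weights, backprop the task loss, opt.step()...
+     #     proj = subspace.tasks["task_new"]
+     #     for layer in subspace.layer_names:
+     #         proj.loadings_a[layer] = params[f"{layer}.loadings_a"].detach()
+     #         proj.loadings_b[layer] = params[f"{layer}.loadings_b"].detach()
+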
+     def save(self, path: str | Path) -> None:
+         """Serialize the subspace to disk."""
+         path = Path(path)
+         path.mkdir(parents=True, exist_ok=True)
+
+         # Save components and means (contiguous() needed for safetensors)
+         tensors = {}
+         for layer in self.layer_names:
+             tensors[f"{layer}.components_a"] = self.components_a[layer].contiguous()
+             tensors[f"{layer}.components_b"] = self.components_b[layer].contiguous()
+             tensors[f"{layer}.sv_a"] = self.singular_values_a[layer].contiguous()
+             tensors[f"{layer}.sv_b"] = self.singular_values_b[layer].contiguous()
+             tensors[f"{layer}.mean_a"] = self.means_a[layer].contiguous()
+             tensors[f"{layer}.mean_b"] = self.means_b[layer].contiguous()
+
+         save_file(tensors, str(path / "subspace.safetensors"))
+
+         # Save per-task loadings
+         for tid, proj in self.tasks.items():
+             task_tensors = {}
+             for layer in self.layer_names:
+                 task_tensors[f"{layer}.loadings_a"] = proj.loadings_a[layer].contiguous()
+                 task_tensors[f"{layer}.loadings_b"] = proj.loadings_b[layer].contiguous()
+             save_file(task_tensors, str(path / f"task_{tid}.safetensors"))
+
+         # Save metadata
+         import json
+
+         meta = {
+             "layer_names": self.layer_names,
+             "task_ids": list(self.tasks.keys()),
+             "rank": self.rank,
+             "num_components": self.num_components,
+         }
+         with open(path / "subspace_meta.json", "w") as f:
+             json.dump(meta, f, indent=2)
+
+     @classmethod
+     def load(cls, path: str | Path) -> SharedSubspace:
+         """Deserialize a subspace from disk."""
+         import json
+
+         path = Path(path)
+
+         with open(path / "subspace_meta.json") as f:
+             meta = json.load(f)
+
+         layer_names = meta["layer_names"]
+         task_ids = meta["task_ids"]
+         rank = meta["rank"]
+         num_components = meta["num_components"]
+
+         tensors = load_file(str(path / "subspace.safetensors"))
+         components_a = {l: tensors[f"{l}.components_a"] for l in layer_names}
+         components_b = {l: tensors[f"{l}.components_b"] for l in layer_names}
+         sv_a = {l: tensors[f"{l}.sv_a"] for l in layer_names}
+         sv_b = {l: tensors[f"{l}.sv_b"] for l in layer_names}
+         means_a = {l: tensors[f"{l}.mean_a"] for l in layer_names}
+         means_b = {l: tensors[f"{l}.mean_b"] for l in layer_names}
+
+         tasks = {}
+         for tid in task_ids:
+             task_tensors = load_file(str(path / f"task_{tid}.safetensors"))
+             loadings_a = {l: task_tensors[f"{l}.loadings_a"] for l in layer_names}
+             loadings_b = {l: task_tensors[f"{l}.loadings_b"] for l in layer_names}
+             tasks[tid] = TaskProjection(
+                 task_id=tid, loadings_a=loadings_a, loadings_b=loadings_b
+             )
+
+         return cls(
+             layer_names=layer_names,
+             components_a=components_a,
+             components_b=components_b,
+             singular_values_a=sv_a,
+             singular_values_b=sv_b,
+             means_a=means_a,
+             means_b=means_b,
+             tasks=tasks,
+             rank=rank,
+             num_components=num_components,
+         )