vlora_dev-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vlora/__init__.py +73 -0
- vlora/_validate.py +82 -0
- vlora/analysis.py +191 -0
- vlora/cli.py +430 -0
- vlora/integrations/__init__.py +1 -0
- vlora/integrations/huggingface.py +163 -0
- vlora/io.py +191 -0
- vlora/merge.py +229 -0
- vlora/model.py +148 -0
- vlora/ops.py +229 -0
- vlora/pipeline.py +70 -0
- vlora/router.py +173 -0
- vlora/subspace.py +651 -0
- vlora/training.py +149 -0
- vlora_dev-0.2.0.dist-info/METADATA +409 -0
- vlora_dev-0.2.0.dist-info/RECORD +19 -0
- vlora_dev-0.2.0.dist-info/WHEEL +4 -0
- vlora_dev-0.2.0.dist-info/entry_points.txt +2 -0
- vlora_dev-0.2.0.dist-info/licenses/LICENSE +190 -0
vlora/subspace.py
ADDED
@@ -0,0 +1,651 @@
"""SharedSubspace — core state container and 3-step algorithm.

Step 1: from_adapters — build shared basis via SVD
Step 2: project — project new adapter onto basis
Step 3: absorb — incorporate new adapter, recompute basis
"""

from __future__ import annotations

import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Literal

import torch
from safetensors.torch import load_file, save_file
from torch import Tensor

from vlora._validate import (
    check_adapter_matches_subspace,
    check_adapters_compatible,
    check_task_exists,
    check_tensor_health,
)
from vlora.io import LoRAWeights, stack_lora_weights
from vlora.ops import (
    compute_svd,
    explained_variance_ratio,
    gram_schmidt,
    incremental_svd_update,
    project_onto_subspace,
    reconstruct_from_subspace,
    select_num_components,
)

logger = logging.getLogger("vlora")


@dataclass
class TaskProjection:
    """A single task's representation in the shared subspace."""

    task_id: str
    loadings_a: dict[str, Tensor]  # layer_name -> (k,)
    loadings_b: dict[str, Tensor]  # layer_name -> (k,)
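
    # Editor's note (illustrative sketch, not part of the original module): for a
    # subspace with k = 4 components per layer, a TaskProjection holds one (4,)
    # coefficient vector per layer and side, e.g.:
    #
    #   proj = TaskProjection(
    #       task_id="sst2",
    #       loadings_a={"layer0.q_proj": torch.zeros(4)},
    #       loadings_b={"layer0.q_proj": torch.zeros(4)},
    #   )
    #
    # The task id "sst2" and layer name "layer0.q_proj" are hypothetical.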


class SharedSubspace:
    """Shared low-rank subspace for LoRA adapters.

    Maintains per-layer orthonormal basis vectors (components) and
    per-task coefficient vectors (loadings).
    """

    def __init__(
        self,
        layer_names: list[str],
        components_a: dict[str, Tensor],
        components_b: dict[str, Tensor],
        singular_values_a: dict[str, Tensor],
        singular_values_b: dict[str, Tensor],
        means_a: dict[str, Tensor],
        means_b: dict[str, Tensor],
        tasks: dict[str, TaskProjection],
        rank: int,
        num_components: int,
    ):
        self.layer_names = layer_names
        self.components_a = components_a
        self.components_b = components_b
        self.singular_values_a = singular_values_a
        self.singular_values_b = singular_values_b
        self.means_a = means_a
        self.means_b = means_b
        self.tasks = tasks
        self.rank = rank
        self.num_components = num_components
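
    # Editor's note (illustrative, inferred from how these fields are used below
    # and assuming row-wise components from vlora.ops.compute_svd):
    #   components_a[layer]: (k, D_a) orthonormal rows, D_a = rank * in_features flattened
    #   means_a[layer]:      (D_a,)   mean of the stacked flattened A matrices
    #   loadings_a[layer]:   (k,)     per-task coefficients
    # so a task's flattened A matrix is recovered roughly as
    #   loadings_a[layer] @ components_a[layer] + means_a[layer]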

    @classmethod
    def from_adapters(
        cls,
        adapters: list[LoRAWeights],
        task_ids: list[str] | None = None,
        variance_threshold: float = 0.6,
        num_components: int | None = None,
        adaptive_k: bool = False,
    ) -> SharedSubspace:
        """Step 1: Build shared subspace from existing adapters.

        Stacks each adapter's flattened weights, runs SVD per layer,
        and projects all adapters onto the resulting basis.

        Args:
            adapters: List of LoRA adapters to initialize from.
            task_ids: Names for each adapter. Defaults to "task_0", "task_1", etc.
            variance_threshold: Minimum cumulative variance to explain (used if
                num_components is None).
            num_components: Explicit number of basis vectors per layer.
                Overrides variance_threshold if set.
            adaptive_k: If True, select k independently per layer based on
                variance_threshold. Each layer gets the minimal k that explains
                the threshold. Overrides num_components.
        """
        check_adapters_compatible(adapters)
        logger.info("Building subspace from %d adapters", len(adapters))

        if task_ids is None:
            task_ids = [f"task_{i}" for i in range(len(adapters))]
        if len(task_ids) != len(adapters):
            raise ValueError("task_ids length must match adapters length")

        # Intersect layer names across all adapters for safety
        layer_set = set(adapters[0].layer_names)
        for adapter in adapters[1:]:
            layer_set &= set(adapter.layer_names)
        layer_names = sorted(layer_set)

        if not layer_names:
            raise ValueError("Adapters share no common layers")

        if len(layer_names) < len(adapters[0].layer_names):
            import warnings

            dropped = set(adapters[0].layer_names) - layer_set
            warnings.warn(
                f"Adapters have different layer sets. Using {len(layer_names)} "
                f"common layers (dropped {len(dropped)}: {sorted(dropped)[:3]}...)"
            )

        rank = adapters[0].rank

        # Stack weights into data matrices
        stacked_a = stack_lora_weights(adapters, side="A")
        stacked_b = stack_lora_weights(adapters, side="B")

        components_a: dict[str, Tensor] = {}
        components_b: dict[str, Tensor] = {}
        sv_a: dict[str, Tensor] = {}
        sv_b: dict[str, Tensor] = {}
        means_a: dict[str, Tensor] = {}
        means_b: dict[str, Tensor] = {}

        resolved_k: int | None = None

        for layer in layer_names:
            for side, stacked, comp_dict, sv_dict, mean_dict in [
                ("A", stacked_a, components_a, sv_a, means_a),
                ("B", stacked_b, components_b, sv_b, means_b),
            ]:
                data = stacked[layer]
                comps, svals, mean = compute_svd(data, num_components=None, center=True)

                if adaptive_k:
                    # Per-layer: each layer/side gets its own k
                    k = select_num_components(svals, variance_threshold)
                elif num_components is not None:
                    k = min(num_components, len(svals))
                else:
                    k = select_num_components(svals, variance_threshold)

                if not adaptive_k:
                    if resolved_k is None:
                        resolved_k = k
                    # Use consistent k across layers for simplicity
                    k = resolved_k

                comp_dict[layer] = comps[:k]
                sv_dict[layer] = svals[:k]
                mean_dict[layer] = mean

        # For adaptive_k, use the max per-layer k as the reported num_components
        if adaptive_k:
            resolved_k = max(
                max(components_a[l].shape[0], components_b[l].shape[0])
                for l in layer_names
            )
        resolved_k = resolved_k or 1

        # Project all input adapters onto the basis
        tasks: dict[str, TaskProjection] = {}
        for adapter, tid in zip(adapters, task_ids):
            loadings_a: dict[str, Tensor] = {}
            loadings_b: dict[str, Tensor] = {}
            for layer in layer_names:
                wa = adapter.lora_a[layer].flatten() - means_a[layer]
                wb = adapter.lora_b[layer].flatten() - means_b[layer]
                loadings_a[layer] = project_onto_subspace(wa, components_a[layer])
                loadings_b[layer] = project_onto_subspace(wb, components_b[layer])
            tasks[tid] = TaskProjection(
                task_id=tid, loadings_a=loadings_a, loadings_b=loadings_b
            )

        logger.info(
            "Subspace built: k=%d, layers=%d, tasks=%d, rank=%d",
            resolved_k, len(layer_names), len(tasks), rank,
        )

        return cls(
            layer_names=layer_names,
            components_a=components_a,
            components_b=components_b,
            singular_values_a=sv_a,
            singular_values_b=sv_b,
            means_a=means_a,
            means_b=means_b,
            tasks=tasks,
            rank=rank,
            num_components=resolved_k,
        )
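
    # Editor's usage sketch for Step 1 (illustrative, not part of the original
    # module; the layer name and shapes are hypothetical):
    #
    #   import torch
    #   from vlora.io import LoRAWeights
    #   from vlora.subspace import SharedSubspace
    #
    #   def toy_adapter(seed: int) -> LoRAWeights:
    #       torch.manual_seed(seed)
    #       return LoRAWeights(
    #           layer_names=["layer0.q_proj"],
    #           lora_a={"layer0.q_proj": torch.randn(8, 64)},   # (rank, in_features)
    #           lora_b={"layer0.q_proj": torch.randn(64, 8)},   # (out_features, rank)
    #           rank=8,
    #       )
    #
    #   adapters = [toy_adapter(s) for s in range(4)]
    #   sub = SharedSubspace.from_adapters(adapters, num_components=2)
    #   print(sub.num_components, sub.layer_names, list(sub.tasks))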

    def project(self, adapter: LoRAWeights, task_id: str) -> TaskProjection:
        """Step 2a: Project a new adapter onto the existing basis."""
        check_adapter_matches_subspace(adapter, self, "project")
        loadings_a: dict[str, Tensor] = {}
        loadings_b: dict[str, Tensor] = {}

        for layer in self.layer_names:
            wa = adapter.lora_a[layer].flatten() - self.means_a[layer]
            wb = adapter.lora_b[layer].flatten() - self.means_b[layer]
            loadings_a[layer] = project_onto_subspace(wa, self.components_a[layer])
            loadings_b[layer] = project_onto_subspace(wb, self.components_b[layer])

        return TaskProjection(
            task_id=task_id, loadings_a=loadings_a, loadings_b=loadings_b
        )

    def add_task(self, projection: TaskProjection) -> None:
        """Register a projected task in the subspace."""
        self.tasks[projection.task_id] = projection

    def reconstruct(self, task_id: str) -> LoRAWeights:
        """Reconstruct full LoRA weights for a task from its loadings."""
        check_task_exists(self, task_id)

        proj = self.tasks[task_id]
        lora_a: dict[str, Tensor] = {}
        lora_b: dict[str, Tensor] = {}

        for layer in self.layer_names:
            flat_a = reconstruct_from_subspace(
                self.components_a[layer], proj.loadings_a[layer]
            ) + self.means_a[layer]
            flat_b = reconstruct_from_subspace(
                self.components_b[layer], proj.loadings_b[layer]
            ) + self.means_b[layer]

            # Recover original matrix shapes from the adapter's rank
            # A: (rank, in_features), B: (out_features, rank)
            ref_a_shape = (self.rank, flat_a.numel() // self.rank)
            ref_b_shape = (flat_b.numel() // self.rank, self.rank)
            lora_a[layer] = flat_a.reshape(ref_a_shape)
            lora_b[layer] = flat_b.reshape(ref_b_shape)

        return LoRAWeights(
            layer_names=self.layer_names,
            lora_a=lora_a,
            lora_b=lora_b,
            rank=self.rank,
        )
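
    # Editor's usage sketch for Step 2 (illustrative, not part of the original
    # module; continues the toy_adapter example above):
    #
    #   new_adapter = toy_adapter(99)
    #   proj = sub.project(new_adapter, task_id="new_task")   # loadings only
    #   sub.add_task(proj)                                    # register it
    #   approx = sub.reconstruct("new_task")                  # back to full LoRA
    #   err = (approx.lora_a["layer0.q_proj"]
    #          - new_adapter.lora_a["layer0.q_proj"]).norm()
    #   print(f"reconstruction error: {err:.4f}")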

    def absorb(self, new_adapter: LoRAWeights, new_task_id: str) -> None:
        """Step 3: Absorb a new adapter, recomputing the shared basis.

        Reconstructs all existing tasks, adds the new adapter, then
        reruns SVD to produce an updated basis.
        """
        check_adapter_matches_subspace(new_adapter, self, "absorb")
        logger.info(
            "Absorbing adapter '%s' (full SVD recompute, %d existing tasks)",
            new_task_id, len(self.tasks),
        )
        # Reconstruct all existing tasks as full adapters
        all_adapters = []
        all_ids = []
        for tid in self.tasks:
            all_adapters.append(self.reconstruct(tid))
            all_ids.append(tid)

        all_adapters.append(new_adapter)
        all_ids.append(new_task_id)

        # Rebuild subspace from scratch
        new_sub = SharedSubspace.from_adapters(
            all_adapters,
            task_ids=all_ids,
            num_components=self.num_components,
        )

        # Update self in-place
        self.layer_names = new_sub.layer_names
        self.components_a = new_sub.components_a
        self.components_b = new_sub.components_b
        self.singular_values_a = new_sub.singular_values_a
        self.singular_values_b = new_sub.singular_values_b
        self.means_a = new_sub.means_a
        self.means_b = new_sub.means_b
        self.tasks = new_sub.tasks
        self.num_components = new_sub.num_components
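
    # Editor's usage sketch for Step 3 (illustrative, not part of the original
    # module): absorbing folds the new adapter into the basis itself, unlike
    # project()/add_task(), which leave the basis fixed.
    #
    #   before = sub.compression_stats()["num_tasks"]
    #   sub.absorb(toy_adapter(123), new_task_id="absorbed_task")
    #   after = sub.compression_stats()["num_tasks"]
    #   assert after == before + 1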

    def absorb_incremental(self, new_adapter: LoRAWeights, new_task_id: str) -> None:
        """Absorb a new adapter incrementally without full SVD recompute.

        Instead of reconstructing all tasks and re-running SVD, this projects
        the new adapter onto the existing basis, measures the residual, and
        expands the basis with any significant new directions.

        Much faster than absorb() for large collections, with a small
        approximation trade-off.
        """
        check_adapter_matches_subspace(new_adapter, self, "absorb_incremental")
        logger.debug("Absorbing adapter '%s' incrementally", new_task_id)
        loadings_a: dict[str, Tensor] = {}
        loadings_b: dict[str, Tensor] = {}

        for layer in self.layer_names:
            for side, weights_dict, comp_attr, sv_attr, mean_attr, load_dict in [
                ("a", new_adapter.lora_a, "components_a", "singular_values_a", "means_a", loadings_a),
                ("b", new_adapter.lora_b, "components_b", "singular_values_b", "means_b", loadings_b),
            ]:
                components = getattr(self, comp_attr)[layer]
                svals = getattr(self, sv_attr)[layer]
                mean = getattr(self, mean_attr)[layer]
                flat = weights_dict[layer].flatten().unsqueeze(0)  # (1, D)

                new_comps, new_svals, new_mean, _ = incremental_svd_update(
                    components, svals, mean,
                    n_seen=len(self.tasks),
                    new_data=flat,
                    max_components=self.num_components,
                )

                getattr(self, comp_attr)[layer] = new_comps
                getattr(self, sv_attr)[layer] = new_svals
                getattr(self, mean_attr)[layer] = new_mean

                # Project with updated basis
                centered = flat.squeeze(0) - new_mean
                load_dict[layer] = project_onto_subspace(centered, new_comps)

        # The basis may have changed size; pad or truncate existing tasks'
        # loadings so they stay consistent with the updated components.
        for proj in self.tasks.values():
            for layer in self.layer_names:
                for comp_attr, old_loads in [
                    ("components_a", proj.loadings_a),
                    ("components_b", proj.loadings_b),
                ]:
                    new_comps = getattr(self, comp_attr)[layer]
                    old = old_loads[layer]
                    if old.shape[0] < new_comps.shape[0]:
                        # Pad with zeros along the new directions
                        old = torch.cat([
                            old,
                            torch.zeros(
                                new_comps.shape[0] - old.shape[0],
                                dtype=old.dtype, device=old.device,
                            ),
                        ])
                    elif old.shape[0] > new_comps.shape[0]:
                        old = old[:new_comps.shape[0]]
                    old_loads[layer] = old

        self.tasks[new_task_id] = TaskProjection(
            task_id=new_task_id, loadings_a=loadings_a, loadings_b=loadings_b
        )
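
    # Editor's usage sketch (illustrative, not part of the original module):
    # absorb_incremental() trades a small amount of accuracy for speed, so it
    # is the natural choice when adapters arrive one at a time.
    #
    #   for seed, name in [(7, "task_x"), (8, "task_y")]:
    #       sub.absorb_incremental(toy_adapter(seed), new_task_id=name)
    #   print(sub.compression_stats()["num_tasks"])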

    @classmethod
    def from_adapters_streaming(
        cls,
        adapter_paths: list[str | Path],
        task_ids: list[str] | None = None,
        num_components: int = 4,
    ) -> SharedSubspace:
        """Build a subspace by streaming adapters one at a time from disk.

        Loads only one or two adapters into memory at a time, unlike
        from_adapters, which loads all simultaneously. Uses incremental
        SVD updates.

        Args:
            adapter_paths: Paths to adapter directories on disk.
            task_ids: Names for each adapter.
            num_components: Number of basis components.
        """
        from vlora.io import load_adapter

        if not adapter_paths:
            raise ValueError("Need at least one adapter path")

        paths = [Path(p) for p in adapter_paths]
        if task_ids is None:
            task_ids = [p.name for p in paths]

        # Initialize from the first two adapters if available, so the initial
        # SVD has enough samples to find more than one component
        if len(paths) >= 2:
            init_adapters = [load_adapter(paths[0]), load_adapter(paths[1])]
            init_ids = task_ids[:2]
            remaining = list(zip(paths[2:], task_ids[2:]))
        else:
            init_adapters = [load_adapter(paths[0])]
            init_ids = [task_ids[0]]
            remaining = []

        sub = cls.from_adapters(init_adapters, task_ids=init_ids, num_components=num_components)
        # Ensure the target num_components is preserved even if the initial SVD
        # had fewer samples than requested components
        sub.num_components = num_components

        # Stream remaining adapters
        for path, tid in remaining:
            adapter = load_adapter(path)
            sub.absorb_incremental(adapter, tid)

        return sub
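
    # Editor's usage sketch (illustrative; the directory names are hypothetical
    # and must contain adapters readable by vlora.io.load_adapter):
    #
    #   sub = SharedSubspace.from_adapters_streaming(
    #       ["adapters/sst2", "adapters/qnli", "adapters/mrpc"],
    #       num_components=4,
    #   )
    #   sub.save("subspace_out")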

    def to(self, device: str | torch.device | None = None, dtype: torch.dtype | None = None) -> SharedSubspace:
        """Move all tensors to a device and/or dtype. Returns self."""
        for layer in self.layer_names:
            for attr in ["components_a", "components_b", "singular_values_a",
                         "singular_values_b", "means_a", "means_b"]:
                d = getattr(self, attr)
                t = d[layer]
                if device is not None:
                    t = t.to(device=device)
                if dtype is not None:
                    t = t.to(dtype=dtype)
                d[layer] = t

        for proj in self.tasks.values():
            for layer in self.layer_names:
                for loads in [proj.loadings_a, proj.loadings_b]:
                    t = loads[layer]
                    if device is not None:
                        t = t.to(device=device)
                    if dtype is not None:
                        t = t.to(dtype=dtype)
                    loads[layer] = t

        return self
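
    # Editor's usage sketch (illustrative, not part of the original module):
    #
    #   if torch.cuda.is_available():
    #       sub.to(device="cuda", dtype=torch.float16)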

    def quantize(self, bits: int = 8) -> SharedSubspace:
        """Quantize components to reduce memory footprint.

        Applies symmetric per-tensor quantization to the component matrices.
        Loadings and means are kept in float32 for accuracy. This is a
        lossy operation — quantized components introduce small reconstruction
        errors but can reduce memory by 2-4x.

        Args:
            bits: Quantization bit width (8 or 4). Default 8.

        Returns:
            self (modified in-place).
        """
        if bits not in (4, 8):
            raise ValueError(f"bits must be 4 or 8, got {bits}")

        qmax = (1 << (bits - 1)) - 1  # 127 for int8, 7 for int4

        for layer in self.layer_names:
            for attr in ["components_a", "components_b"]:
                d = getattr(self, attr)
                orig_dtype = d[layer].dtype
                t = d[layer].float()
                # Symmetric quantization: scale = max(abs(t)) / qmax
                scale = t.abs().max() / qmax
                if scale == 0:
                    continue
                # Quantize, round, dequantize, and restore the original dtype
                quantized = (t / scale).round().clamp(-qmax, qmax)
                d[layer] = (quantized * scale).to(orig_dtype)

        return self
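
    # Editor's usage sketch (illustrative, not part of the original module):
    # quantization is lossy, so it is worth checking reconstruction quality.
    #
    #   before = sub.reconstruct("task_0").lora_a["layer0.q_proj"]
    #   sub.quantize(bits=8)
    #   after = sub.reconstruct("task_0").lora_a["layer0.q_proj"]
    #   print((before - after).abs().max())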

    def compression_stats(self) -> dict:
        """Compute compression statistics for the current subspace.

        Returns a dict with per-layer and aggregate stats including:
        - components_per_layer: dict of layer -> (k_a, k_b)
        - total_params_compressed: total parameters in the compressed representation
        - total_params_original: estimated parameters in the N original adapters
        - compression_ratio: original / compressed
        """
        n_tasks = len(self.tasks)
        total_compressed = 0
        total_original = 0
        per_layer = {}

        for layer in self.layer_names:
            k_a = self.components_a[layer].shape[0]
            k_b = self.components_b[layer].shape[0]
            dim_a = self.components_a[layer].shape[1]
            dim_b = self.components_b[layer].shape[1]

            # Compressed: components + means + per-task loadings
            layer_compressed = (
                k_a * dim_a + k_b * dim_b  # components
                + dim_a + dim_b  # means
                + n_tasks * (k_a + k_b)  # loadings
            )
            # Original: N full adapter matrices
            layer_original = n_tasks * (dim_a + dim_b)

            per_layer[layer] = {
                "k_a": k_a, "k_b": k_b,
                "compressed": layer_compressed,
                "original": layer_original,
            }
            total_compressed += layer_compressed
            total_original += layer_original

        return {
            "components_per_layer": {l: (d["k_a"], d["k_b"]) for l, d in per_layer.items()},
            "total_params_compressed": total_compressed,
            "total_params_original": total_original,
            "compression_ratio": total_original / total_compressed if total_compressed > 0 else 0,
            "num_tasks": n_tasks,
            "num_layers": len(self.layer_names),
        }
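
    # Editor's worked example (illustrative, not part of the original module):
    # with N = 100 tasks, one layer, flattened sizes dim_a = dim_b = 512 and
    # k_a = k_b = 4, the formulas above give
    #   compressed = 4*512 + 4*512 + 512 + 512 + 100*(4 + 4) = 5,920
    #   original   = 100 * (512 + 512)                       = 102,400
    # for a compression ratio of roughly 17x.
    #
    #   stats = sub.compression_stats()
    #   print(f"{stats['compression_ratio']:.1f}x over {stats['num_tasks']} tasks")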

    def get_trainable_params(
        self, task_id: str, num_expand: int = 0
    ) -> dict[str, Tensor]:
        """Get trainable loading parameters for a task.

        Useful for integrating with a training loop: freeze the components,
        train only the loadings.

        Args:
            task_id: Task whose loadings to return.
            num_expand: Number of extra orthogonal directions to add via
                Gram-Schmidt (gives the optimizer room to escape the subspace).

        Returns:
            Dict of parameter name -> tensor (with requires_grad=True).
        """
        # Validate the task before any in-place modification of the basis
        if task_id not in self.tasks:
            raise KeyError(f"Unknown task: {task_id}")

        if num_expand > 0:
            import warnings

            warnings.warn(
                f"get_trainable_params(num_expand={num_expand}) will permanently "
                "expand the subspace basis via Gram-Schmidt. This modifies the "
                "subspace in-place and cannot be undone.",
                stacklevel=2,
            )
            for layer in self.layer_names:
                random_a = torch.randn(num_expand, self.components_a[layer].shape[1])
                random_b = torch.randn(num_expand, self.components_b[layer].shape[1])
                self.components_a[layer] = gram_schmidt(self.components_a[layer], random_a)
                self.components_b[layer] = gram_schmidt(self.components_b[layer], random_b)

            # Zero-pad the task's loadings to match the expanded basis
            proj = self.tasks[task_id]
            for layer in self.layer_names:
                old_k_a = proj.loadings_a[layer].shape[0]
                new_k_a = self.components_a[layer].shape[0]
                if new_k_a > old_k_a:
                    proj.loadings_a[layer] = torch.cat([
                        proj.loadings_a[layer],
                        torch.zeros(new_k_a - old_k_a),
                    ])
                old_k_b = proj.loadings_b[layer].shape[0]
                new_k_b = self.components_b[layer].shape[0]
                if new_k_b > old_k_b:
                    proj.loadings_b[layer] = torch.cat([
                        proj.loadings_b[layer],
                        torch.zeros(new_k_b - old_k_b),
                    ])

        params = {}
        proj = self.tasks[task_id]
        for layer in self.layer_names:
            la = proj.loadings_a[layer].clone().detach().requires_grad_(True)
            lb = proj.loadings_b[layer].clone().detach().requires_grad_(True)
            params[f"{layer}.loadings_a"] = la
            params[f"{layer}.loadings_b"] = lb

        return params
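
    # Editor's usage sketch (illustrative, not part of the original module): the
    # loss function and data are placeholders; only the loadings receive
    # gradients, while the shared components stay frozen.
    #
    #   params = sub.get_trainable_params("task_0")
    #   opt = torch.optim.Adam(params.values(), lr=1e-3)
    #   for _ in range(100):
    #       opt.zero_grad()
    #       loss = my_loss_fn(params)   # hypothetical: maps loadings to a scalar
    #       loss.backward()
    #       opt.step()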

    def save(self, path: str | Path) -> None:
        """Serialize the subspace to disk."""
        path = Path(path)
        path.mkdir(parents=True, exist_ok=True)

        # Save components and means (contiguous() needed for safetensors)
        tensors = {}
        for layer in self.layer_names:
            tensors[f"{layer}.components_a"] = self.components_a[layer].contiguous()
            tensors[f"{layer}.components_b"] = self.components_b[layer].contiguous()
            tensors[f"{layer}.sv_a"] = self.singular_values_a[layer].contiguous()
            tensors[f"{layer}.sv_b"] = self.singular_values_b[layer].contiguous()
            tensors[f"{layer}.mean_a"] = self.means_a[layer].contiguous()
            tensors[f"{layer}.mean_b"] = self.means_b[layer].contiguous()

        save_file(tensors, str(path / "subspace.safetensors"))

        # Save per-task loadings
        for tid, proj in self.tasks.items():
            task_tensors = {}
            for layer in self.layer_names:
                task_tensors[f"{layer}.loadings_a"] = proj.loadings_a[layer].contiguous()
                task_tensors[f"{layer}.loadings_b"] = proj.loadings_b[layer].contiguous()
            save_file(task_tensors, str(path / f"task_{tid}.safetensors"))

        # Save metadata
        import json

        meta = {
            "layer_names": self.layer_names,
            "task_ids": list(self.tasks.keys()),
            "rank": self.rank,
            "num_components": self.num_components,
        }
        with open(path / "subspace_meta.json", "w") as f:
            json.dump(meta, f, indent=2)

    @classmethod
    def load(cls, path: str | Path) -> SharedSubspace:
        """Deserialize a subspace from disk."""
        import json

        path = Path(path)

        with open(path / "subspace_meta.json") as f:
            meta = json.load(f)

        layer_names = meta["layer_names"]
        task_ids = meta["task_ids"]
        rank = meta["rank"]
        num_components = meta["num_components"]

        tensors = load_file(str(path / "subspace.safetensors"))
        components_a = {l: tensors[f"{l}.components_a"] for l in layer_names}
        components_b = {l: tensors[f"{l}.components_b"] for l in layer_names}
        sv_a = {l: tensors[f"{l}.sv_a"] for l in layer_names}
        sv_b = {l: tensors[f"{l}.sv_b"] for l in layer_names}
        means_a = {l: tensors[f"{l}.mean_a"] for l in layer_names}
        means_b = {l: tensors[f"{l}.mean_b"] for l in layer_names}

        tasks = {}
        for tid in task_ids:
            task_tensors = load_file(str(path / f"task_{tid}.safetensors"))
            loadings_a = {l: task_tensors[f"{l}.loadings_a"] for l in layer_names}
            loadings_b = {l: task_tensors[f"{l}.loadings_b"] for l in layer_names}
            tasks[tid] = TaskProjection(
                task_id=tid, loadings_a=loadings_a, loadings_b=loadings_b
            )

        return cls(
            layer_names=layer_names,
            components_a=components_a,
            components_b=components_b,
            singular_values_a=sv_a,
            singular_values_b=sv_b,
            means_a=means_a,
            means_b=means_b,
            tasks=tasks,
            rank=rank,
            num_components=num_components,
        )
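
    # Editor's usage sketch (illustrative, not part of the original module):
    # a save/load round trip should preserve every task's loadings.
    #
    #   sub.save("/tmp/vlora_subspace")
    #   restored = SharedSubspace.load("/tmp/vlora_subspace")
    #   assert set(restored.tasks) == set(sub.tasks)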