vlora-dev 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vlora/__init__.py +73 -0
- vlora/_validate.py +82 -0
- vlora/analysis.py +191 -0
- vlora/cli.py +430 -0
- vlora/integrations/__init__.py +1 -0
- vlora/integrations/huggingface.py +163 -0
- vlora/io.py +191 -0
- vlora/merge.py +229 -0
- vlora/model.py +148 -0
- vlora/ops.py +229 -0
- vlora/pipeline.py +70 -0
- vlora/router.py +173 -0
- vlora/subspace.py +651 -0
- vlora/training.py +149 -0
- vlora_dev-0.2.0.dist-info/METADATA +409 -0
- vlora_dev-0.2.0.dist-info/RECORD +19 -0
- vlora_dev-0.2.0.dist-info/WHEEL +4 -0
- vlora_dev-0.2.0.dist-info/entry_points.txt +2 -0
- vlora_dev-0.2.0.dist-info/licenses/LICENSE +190 -0
vlora/ops.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
"""Pure math operations for shared low-rank subspace computation.
|
|
2
|
+
|
|
3
|
+
All functions are stateless and operate on raw tensors — no adapter
|
|
4
|
+
or subspace concepts leak in here.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
from torch import Tensor
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def compute_svd(
    data_matrix: Tensor,
    num_components: int | None = None,
    center: bool = True,
) -> tuple[Tensor, Tensor, Tensor]:
    """Run an economy SVD on a data matrix and keep the top-k components.

    Args:
        data_matrix: (N, D) matrix where rows are observations.
        num_components: Number of singular vectors to keep. If None, or
            larger than the number available, all of them are kept.
        center: Whether to subtract the mean over rows before the SVD.

    Returns:
        components: (k, D) right singular vectors (shared basis).
        singular_values: (k,) corresponding singular values.
        mean: (D,) mean over rows (zeros if center=False).

    Raises:
        ValueError: If num_components is negative, or if the SVD
            produces NaN singular values.
    """
    # A negative count would otherwise reach the slices below and silently
    # drop trailing components (Vh[:-1]) instead of failing.
    if num_components is not None and num_components < 0:
        raise ValueError(
            f"num_components must be non-negative, got {num_components}"
        )

    if center:
        mean = data_matrix.mean(dim=0)
        centered = data_matrix - mean
    else:
        mean = torch.zeros(
            data_matrix.shape[1],
            dtype=data_matrix.dtype,
            device=data_matrix.device,
        )
        centered = data_matrix

    # full_matrices=False gives the economy SVD; the left singular vectors
    # are not needed here.
    _, S, Vh = torch.linalg.svd(centered, full_matrices=False)

    if torch.isnan(S).any():
        raise ValueError(
            "SVD produced NaN singular values. Check input data for "
            "NaN/Inf or try using float64 precision."
        )

    available = len(S)
    k = available if num_components is None else min(num_components, available)

    return Vh[:k], S[:k], mean
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def project_onto_subspace(weights: Tensor, components: Tensor) -> Tensor:
    """Express weight vectors as coefficients in the subspace basis.

    Args:
        weights: (D,) or (N, D) weight vectors to project.
        components: (k, D) orthonormal basis vectors, one per row.

    Returns:
        loadings: (k,) or (N, k) projection coefficients.
    """
    # Rows of `components` are the basis vectors, so multiplying by the
    # transpose yields one coefficient per basis vector.
    basis_columns = components.T
    return torch.matmul(weights, basis_columns)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def reconstruct_from_subspace(components: Tensor, loadings: Tensor) -> Tensor:
    """Map subspace loadings back to full-size weight vectors.

    Args:
        components: (k, D) orthonormal basis vectors, one per row.
        loadings: (k,) or (N, k) projection coefficients.

    Returns:
        reconstructed: (D,) or (N, D) reconstructed weight vectors.
    """
    # Each output is the loading-weighted sum of the basis rows.
    reconstructed = torch.matmul(loadings, components)
    return reconstructed
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def gram_schmidt(basis: Tensor, new_vectors: Tensor, tol: float = 1e-6) -> Tensor:
    """Orthogonalize new_vectors against an existing orthonormal basis.

    Each candidate has its components along every vector accepted so far
    (the original basis plus earlier accepted candidates) removed one at a
    time. It is appended, normalized, only if the remaining norm exceeds
    `tol`.

    Args:
        basis: (k, D) existing orthonormal basis. May have zero rows.
        new_vectors: (m, D) candidate vectors.
        tol: Minimum residual norm for a candidate to count as a new
            direction (default 1e-6).

    Returns:
        expanded_basis: (k + n, D) where n <= m new orthogonal directions
            were found. An empty (0, D) tensor is returned when the basis
            is empty and no candidate survives.
    """
    vectors = list(basis)

    for candidate in new_vectors:
        # Subtraction allocates a fresh tensor, so the caller's rows are
        # never modified in place.
        residual = candidate
        for b in vectors:
            residual = residual - (residual @ b) * b
        norm = residual.norm()
        if norm > tol:
            vectors.append(residual / norm)

    if not vectors:
        # torch.stack rejects an empty list; return an explicit empty basis.
        reference = basis if basis.dim() == 2 else new_vectors
        return reference.new_zeros((0, reference.shape[-1]))

    return torch.stack(vectors)
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
def incremental_svd_update(
    components: Tensor,
    singular_values: Tensor,
    mean: Tensor,
    n_seen: int,
    new_data: Tensor,
    max_components: int | None = None,
) -> tuple[Tensor, Tensor, Tensor, int]:
    """Incrementally update an approximate SVD with new data points.

    Uses a projection-residual approach: the new data is centered with the
    updated mean and projected onto the existing basis. Rows whose residual
    norm exceeds 1e-6 are orthogonalized with `gram_schmidt` and appended
    as new basis directions.

    Singular values are approximated per direction by adding squared
    contributions: the previous singular value (zero for a new direction),
    the norm of the new data's projection on that direction, and the
    effect of the mean shift on the previously seen rows. No rotation of
    the basis is performed.

    Args:
        components: (k, D) current orthonormal basis.
        singular_values: (k,) current singular values.
        mean: (D,) current data mean.
        n_seen: Number of data points seen so far.
        new_data: (m, D) new data points to incorporate.
        max_components: Cap on number of components. If None, allows growth.

    Returns:
        updated_components: (k', D) updated basis, sorted by singular value.
        updated_singular_values: (k',) updated singular values, descending.
        updated_mean: (D,) updated mean.
        new_n_seen: Updated count.
    """
    m = new_data.shape[0]
    n_total = n_seen + m

    # Running mean over every row seen so far. With no rows at all there is
    # nothing to average, so keep the incoming mean rather than divide by 0.
    if n_total > 0:
        new_mean = (mean * n_seen + new_data.sum(dim=0)) / n_total
    else:
        new_mean = mean.clone()

    centered_new = new_data - new_mean

    # Old rows were centered on the old mean; relative to the new mean each
    # of them is displaced by (old_mean - new_mean).
    mean_shift = mean - new_mean

    k = components.shape[0]
    projections = centered_new @ components.T  # (m, k)
    residuals = centered_new - projections @ components  # (m, D)

    # Rows that are not already explained by the basis seed new directions.
    significant = residuals.norm(dim=1) > 1e-6
    all_components = components
    if significant.any():
        expanded = gram_schmidt(components, residuals[significant])
        if expanded.shape[0] > k:
            all_components = torch.cat([components, expanded[k:]], dim=0)

    # Every direction, old or new, receives the new data's projection
    # energy. New directions start from a prior singular value of zero.
    n_extra = all_components.shape[0] - k
    prior_svals = torch.cat([singular_values, singular_values.new_zeros(n_extra)])
    new_contributions = (centered_new @ all_components.T).norm(dim=0)
    all_svals = torch.sqrt(prior_svals ** 2 + new_contributions ** 2)

    # Account for the mean shift applied to the n_seen previously seen rows.
    shift_proj = mean_shift @ all_components.T
    shift_contribution = shift_proj * (n_seen ** 0.5)
    all_svals = torch.sqrt(all_svals ** 2 + shift_contribution ** 2)

    # Sort by singular value magnitude (descending).
    order = all_svals.argsort(descending=True)
    all_components = all_components[order]
    all_svals = all_svals[order]

    # Keep only the strongest directions when a cap is requested.
    if max_components is not None and all_components.shape[0] > max_components:
        all_components = all_components[:max_components]
        all_svals = all_svals[:max_components]

    return all_components, all_svals, new_mean, n_total
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def explained_variance_ratio(singular_values: Tensor) -> Tensor:
    """Return the running fraction of total variance captured by each prefix.

    Variance per component is taken as the squared singular value.

    Args:
        singular_values: (k,) singular values in descending order.

    Returns:
        cumulative_ratio: (k,) cumulative fraction of total variance
            explained. All zeros when the total variance is zero.
    """
    energy = singular_values.pow(2)
    total_energy = energy.sum()
    # Nothing to normalize by: report zero explained variance everywhere.
    if total_energy == 0:
        return torch.zeros_like(energy)
    return energy.cumsum(dim=0) / total_energy
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def select_num_components(singular_values: Tensor, threshold: float = 0.6) -> int:
    """Pick the smallest component count reaching `threshold` variance.

    Args:
        singular_values: (k,) singular values in descending order.
        threshold: Minimum cumulative variance ratio (default 0.6 per paper).

    Returns:
        Number of components needed. Falls back to all components when the
        cumulative ratio never reaches the threshold.
    """
    ratios = explained_variance_ratio(singular_values)
    # Positions whose cumulative ratio already meets the threshold.
    hits = torch.nonzero(ratios >= threshold, as_tuple=True)[0]
    if hits.numel() == 0:
        return len(singular_values)
    first_hit = hits[0].item()
    # Convert the zero-based position into a count.
    return first_hit + 1
|
vlora/pipeline.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
"""High-level convenience wrappers for the 3-step pipeline."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
|
|
7
|
+
from vlora.io import LoRAWeights, load_adapter, save_adapter
|
|
8
|
+
from vlora.subspace import SharedSubspace
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def init_subspace(
    adapter_paths: list[str | Path],
    task_ids: list[str] | None = None,
    variance_threshold: float = 0.6,
    num_components: int | None = None,
) -> SharedSubspace:
    """Read every adapter from disk and build a shared subspace from them.

    Args:
        adapter_paths: Directories containing PEFT adapters.
        task_ids: Names for each adapter.
        variance_threshold: Variance threshold for auto component selection.
        num_components: Explicit number of components (overrides threshold).

    Returns:
        Initialized SharedSubspace.
    """
    # Load in the order given so adapters line up with task_ids.
    adapters = list(map(load_adapter, adapter_paths))
    build_options = {
        "task_ids": task_ids,
        "variance_threshold": variance_threshold,
        "num_components": num_components,
    }
    return SharedSubspace.from_adapters(adapters, **build_options)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def absorb_task(
    subspace: SharedSubspace,
    adapter_path: str | Path,
    task_id: str,
) -> None:
    """Read one adapter from disk and fold it into an existing subspace.

    Args:
        subspace: Existing shared subspace (modified in-place).
        adapter_path: Directory containing the new PEFT adapter.
        task_id: Name for the new task.
    """
    subspace.absorb(load_adapter(adapter_path), task_id)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def extract_adapter(
    subspace: SharedSubspace,
    task_id: str,
    output_path: str | Path,
) -> LoRAWeights:
    """Rebuild one task's adapter from the subspace and write it to disk.

    Args:
        subspace: Shared subspace containing the task.
        task_id: Task to reconstruct.
        output_path: Directory to save the PEFT adapter to.

    Returns:
        The reconstructed LoRAWeights.
    """
    reconstructed = subspace.reconstruct(task_id)
    # Persist first, then hand the same object back to the caller.
    save_adapter(reconstructed, output_path)
    return reconstructed
|
vlora/router.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
1
|
+
"""TaskRouter — lightweight routing over adapter loadings.
|
|
2
|
+
|
|
3
|
+
Routes inputs to a soft blend of task adapters by producing per-task
|
|
4
|
+
weights. Since adapters are represented as small loading vectors in the
|
|
5
|
+
shared subspace, blending is a cheap linear combination rather than
|
|
6
|
+
reconstructing and merging full LoRA matrices.
|
|
7
|
+
|
|
8
|
+
Usage:
|
|
9
|
+
subspace = SharedSubspace.load("shared_subspace/")
|
|
10
|
+
router = TaskRouter.from_subspace(subspace, input_dim=embed_dim, hidden_dim=64)
|
|
11
|
+
|
|
12
|
+
# During inference:
|
|
13
|
+
model = VLoRAModel(base_model, subspace)
|
|
14
|
+
x = get_input_embedding(batch) # (B, embed_dim)
|
|
15
|
+
blended = router.blend_loadings(x, subspace) # batch-averaged blended loadings
|
|
16
|
+
"""
|
|
17
|
+
|
|
18
|
+
from __future__ import annotations
|
|
19
|
+
|
|
20
|
+
from typing import Literal
|
|
21
|
+
|
|
22
|
+
import torch
|
|
23
|
+
import torch.nn as nn
|
|
24
|
+
import torch.nn.functional as F
|
|
25
|
+
from torch import Tensor
|
|
26
|
+
|
|
27
|
+
from vlora.subspace import SharedSubspace, TaskProjection
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class TaskRouter(nn.Module):
    """Small MLP that produces soft task-blend weights from input features.

    Given input embeddings (B, input_dim), outputs (B, num_tasks) blend
    weights that sum to 1 (via softmax). These weights define a mixture
    of task loadings in the shared subspace.
    """

    def __init__(
        self,
        input_dim: int,
        num_tasks: int,
        hidden_dim: int = 64,
        task_ids: list[str] | None = None,
        temperature: float = 1.0,
    ):
        """Build the two-layer routing MLP.

        Args:
            input_dim: Dimension of the input features.
            num_tasks: Number of tasks to route between.
            hidden_dim: Width of the hidden layer.
            task_ids: Names for the output slots, in output order. Defaults
                to "task_0", "task_1", ... when omitted or empty.
            temperature: Softmax temperature; must be positive.

        Raises:
            ValueError: If temperature is not positive, or if task_ids is
                given with a length different from num_tasks.
        """
        super().__init__()

        # Logits are divided by the temperature, so zero would yield
        # inf/NaN weights and a negative value would invert the ranking.
        if temperature <= 0:
            raise ValueError(f"temperature must be positive, got {temperature}")

        # Copy the caller's list so later outside mutation cannot desync
        # the names from the output layer.
        if task_ids:
            resolved_ids = list(task_ids)
        else:
            resolved_ids = [f"task_{i}" for i in range(num_tasks)]
        if len(resolved_ids) != num_tasks:
            raise ValueError(
                f"task_ids has {len(resolved_ids)} entries but num_tasks is "
                f"{num_tasks}"
            )

        self.input_dim = input_dim
        self.num_tasks = num_tasks
        self.hidden_dim = hidden_dim
        self.task_ids = resolved_ids
        self.temperature = temperature

        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_tasks),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Compute task blend weights.

        Args:
            x: (B, input_dim) input features.

        Returns:
            weights: (B, num_tasks) softmax blend weights.
        """
        logits = self.net(x) / self.temperature
        return F.softmax(logits, dim=-1)

    def blend_loadings(
        self,
        x: Tensor,
        subspace: SharedSubspace,
    ) -> TaskProjection:
        """Produce one set of blended loadings for a batch.

        The router's weights are averaged over the batch, and that single
        weight vector is used to mix the loadings of every task in
        `self.task_ids`, layer by layer.

        Args:
            x: (B, input_dim) input features.
            subspace: SharedSubspace containing the tasks to blend. Every
                id in `self.task_ids` must be present in `subspace.tasks`.

        Returns:
            TaskProjection with blended loadings (one set for the batch),
            using the task id "__routed__".
        """
        weights = self.forward(x)  # (B, num_tasks)
        # Average across batch for a single blended adapter.
        avg_weights = weights.mean(dim=0)  # (num_tasks,)

        blended_a: dict[str, Tensor] = {}
        blended_b: dict[str, Tensor] = {}

        for layer in subspace.layer_names:
            # One row of loadings per task, in router output order.
            stack_a = torch.stack([
                subspace.tasks[tid].loadings_a[layer]
                for tid in self.task_ids
            ])
            stack_b = torch.stack([
                subspace.tasks[tid].loadings_b[layer]
                for tid in self.task_ids
            ])

            # Match the loadings' dtype/device so a router living on a
            # different device or precision can still be blended.
            blended_a[layer] = avg_weights.to(stack_a) @ stack_a
            blended_b[layer] = avg_weights.to(stack_b) @ stack_b

        return TaskProjection(
            task_id="__routed__",
            loadings_a=blended_a,
            loadings_b=blended_b,
        )

    @classmethod
    def from_subspace(
        cls,
        subspace: SharedSubspace,
        input_dim: int,
        hidden_dim: int = 64,
        temperature: float = 1.0,
        init_from_loadings: bool = True,
    ) -> TaskRouter:
        """Create a router matched to a subspace's task structure.

        Task ids are taken from `subspace.tasks` in sorted order. With
        `init_from_loadings`, the output layer's bias is set from each
        task's summed loading norms, scaled by the largest one, so tasks
        with larger loadings start with slightly higher routing weight.

        Args:
            subspace: SharedSubspace with registered tasks.
            input_dim: Dimension of input features the router will see.
            hidden_dim: Router hidden layer size.
            temperature: Softmax temperature (higher = softer blending).
            init_from_loadings: If True, and there is more than one task
                and at least one layer, initialize the output bias from
                the task loading norms.

        Returns:
            A new TaskRouter with one output per task in the subspace.
        """
        task_ids = sorted(subspace.tasks.keys())
        num_tasks = len(task_ids)

        router = cls(
            input_dim=input_dim,
            num_tasks=num_tasks,
            hidden_dim=hidden_dim,
            task_ids=task_ids,
            temperature=temperature,
        )

        # Without layers there are no norms to stack, so skip the warm start.
        layer_names = list(subspace.layer_names)
        if init_from_loadings and num_tasks > 1 and layer_names:
            with torch.no_grad():
                biases = []
                for tid in task_ids:
                    proj = subspace.tasks[tid]
                    total_norm = sum(
                        proj.loadings_a[layer].norm() + proj.loadings_b[layer].norm()
                        for layer in layer_names
                    )
                    biases.append(total_norm)
                bias_tensor = torch.stack(biases)
                # Scale so the largest bias is ~1; epsilon guards all-zero
                # loadings.
                bias_tensor = bias_tensor / (bias_tensor.max() + 1e-8)
                router.net[-1].bias.copy_(bias_tensor)

        return router

    @property
    def num_params(self) -> int:
        """Total number of router parameters."""
        return sum(p.numel() for p in self.parameters())
|