vlora-dev 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vlora/ops.py ADDED
@@ -0,0 +1,229 @@
1
+ """Pure math operations for shared low-rank subspace computation.
2
+
3
+ All functions are stateless and operate on raw tensors — no adapter
4
+ or subspace concepts leak in here.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import torch
10
+ from torch import Tensor
11
+
12
+
13
def compute_svd(
    data_matrix: Tensor,
    num_components: int | None = None,
    center: bool = True,
) -> tuple[Tensor, Tensor, Tensor]:
    """SVD on a data matrix, returning the top-k right singular vectors.

    Args:
        data_matrix: (N, D) matrix where rows are observations.
        num_components: Number of singular vectors to keep. If None, keep
            all. Values larger than min(N, D) are clamped.
        center: Whether to subtract the mean row (per-column mean over the
            N observations) before the SVD.

    Returns:
        components: (k, D) right singular vectors (shared basis).
        singular_values: (k,) corresponding singular values, descending.
        mean: (D,) mean row that was subtracted (zeros if center=False).

    Raises:
        ValueError: If ``num_components`` is negative, if ``data_matrix``
            contains NaN/Inf, or if the SVD yields NaN singular values.
    """
    if num_components is not None and num_components < 0:
        # A negative k would slice from the end (Vh[:-1]) instead of failing.
        raise ValueError(
            f"num_components must be non-negative, got {num_components}"
        )
    if not torch.isfinite(data_matrix).all():
        # Fail with a clear message before handing NaN/Inf to LAPACK.
        raise ValueError(
            "data_matrix contains NaN/Inf values; SVD is undefined for "
            "non-finite input."
        )

    if center:
        mean = data_matrix.mean(dim=0)
        centered = data_matrix - mean
    else:
        mean = torch.zeros(
            data_matrix.shape[1], dtype=data_matrix.dtype, device=data_matrix.device
        )
        centered = data_matrix

    # full_matrices=False gives the economy SVD; only S and Vh are needed.
    _, S, Vh = torch.linalg.svd(centered, full_matrices=False)

    if torch.isnan(S).any():
        # Finite input can still overflow in low precision.
        raise ValueError(
            "SVD produced NaN singular values. Check input data for "
            "NaN/Inf or try using float64 precision."
        )

    k = len(S) if num_components is None else min(num_components, len(S))
    return Vh[:k], S[:k], mean
50
+
51
+
52
def project_onto_subspace(weights: Tensor, components: Tensor) -> Tensor:
    """Express weight vectors as coordinates in the subspace basis.

    Args:
        weights: (D,) or (N, D) weight vectors.
        components: (k, D) orthonormal basis, one basis vector per row.

    Returns:
        (k,) or (N, k) coefficients: the dot product of each weight vector
        with every basis row.
    """
    # Rows of `components` are basis vectors, so its transpose holds them as
    # columns and a single matmul yields every coordinate at once.
    basis_columns = components.T
    return torch.matmul(weights, basis_columns)
64
+
65
+
66
def reconstruct_from_subspace(components: Tensor, loadings: Tensor) -> Tensor:
    """Map subspace coefficients back to full weight vectors.

    Args:
        components: (k, D) orthonormal basis, one basis vector per row.
        loadings: (k,) or (N, k) coefficients in that basis.

    Returns:
        (D,) or (N, D) linear combination of the basis rows weighted by
        ``loadings``.
    """
    return torch.matmul(loadings, components)
77
+
78
+
79
def gram_schmidt(basis: Tensor, new_vectors: Tensor, tol: float = 1e-6) -> Tensor:
    """Orthogonalize new_vectors against an existing orthonormal basis.

    Each candidate has its projection onto every basis vector removed in turn
    (modified Gram-Schmidt, using the already-updated residual each step).
    Directions accepted earlier in the same call join the basis, so later
    candidates are also orthogonalized against them. A candidate is appended,
    normalized to unit length, only if its residual norm exceeds ``tol``.

    Args:
        basis: (k, D) existing orthonormal basis (k may be 0).
        new_vectors: (m, D) candidate vectors.
        tol: Minimum residual norm for a candidate to count as a new
            direction (default 1e-6).

    Returns:
        expanded_basis: (k + n, D), where n <= m new orthonormal directions
        were found. The first k rows are the input basis rows. When both the
        basis and the accepted set are empty, a (0, D) tensor shaped like
        ``basis`` is returned instead of failing.
    """
    vectors = list(basis)

    for v in new_vectors:
        # `v - ...` allocates a new tensor, so new_vectors is never mutated.
        for b in vectors:
            v = v - (v @ b) * b
        norm = v.norm()
        if norm > tol:
            vectors.append(v / norm)

    if not vectors:
        # torch.stack([]) raises; return an empty basis with the right width.
        return basis.new_zeros((0, basis.shape[-1]))
    return torch.stack(vectors)
105
+
106
+
107
def incremental_svd_update(
    components: Tensor,
    singular_values: Tensor,
    mean: Tensor,
    n_seen: int,
    new_data: Tensor,
    max_components: int | None = None,
) -> tuple[Tensor, Tensor, Tensor, int]:
    """Incrementally update an SVD basis with new data points.

    Uses a projection-residual approach:

    1. New data is centered on the updated (combined) mean and projected
       onto the existing basis.
    2. Residual rows with norm above 1e-6 are orthonormalized against the
       basis with ``gram_schmidt`` and appended as new directions.

    Singular values use a diagonal approximation. Each component's squared
    singular value grows by two terms:

    - the squared norm of the new data's projection onto it;
    - the squared projection of the mean shift seen by the ``n_seen`` old
      points.

    Cross terms are ignored and existing basis vectors are not rotated. The
    result therefore approximates, rather than equals, an SVD recomputed
    from scratch.

    Args:
        components: (k, D) current orthonormal basis.
        singular_values: (k,) current singular values.
        mean: (D,) current data mean.
        n_seen: Number of data points seen so far.
        new_data: (m, D) new data points to incorporate (m may be 0).
        max_components: Cap on number of components. If None, allows growth.

    Returns:
        updated_components: (k', D) basis, sorted by singular value
            (descending).
        updated_singular_values: (k',) matching singular values.
        updated_mean: (D,) updated mean (unchanged when no data seen at all).
        new_n_seen: Updated count, n_seen + m.
    """
    m = new_data.shape[0]
    n_total = n_seen + m

    # Update mean incrementally; avoid 0/0 when nothing has been seen.
    if n_total == 0:
        new_mean = mean.clone()
    else:
        new_mean = (mean * n_seen + new_data.sum(dim=0)) / n_total

    centered_new = new_data - new_mean

    # Re-centering the old data on new_mean adds n_seen * d d^T to the
    # scatter, where d = old_mean - new_mean.
    mean_shift = mean - new_mean

    projections = centered_new @ components.T  # (m, k)
    residuals = centered_new - projections @ components  # (m, D)

    # Every existing direction absorbs the new data's energy along it,
    # whether or not the basis is expanded below.
    all_svals = torch.sqrt(singular_values ** 2 + projections.norm(dim=0) ** 2)
    all_components = components

    significant = residuals.norm(dim=1) > 1e-6
    if significant.any():
        # Orthogonalize residuals against the basis and each other.
        expanded = gram_schmidt(components, residuals[significant])
        extra = expanded[components.shape[0]:]
        if extra.shape[0] > 0:
            all_components = torch.cat([components, extra], dim=0)
            # New directions only carry energy from the new data.
            extra_svals = (centered_new @ extra.T).norm(dim=0)  # (n_extra,)
            all_svals = torch.cat([all_svals, extra_svals])

    # Account for the mean-shift term (only its in-basis part is kept).
    shift_proj = mean_shift @ all_components.T
    shift_contribution = shift_proj * (n_seen ** 0.5)
    all_svals = torch.sqrt(all_svals ** 2 + shift_contribution ** 2)

    # Sort by singular value magnitude (descending).
    order = all_svals.argsort(descending=True)
    all_components = all_components[order]
    all_svals = all_svals[order]

    if max_components is not None and all_components.shape[0] > max_components:
        all_components = all_components[:max_components]
        all_svals = all_svals[:max_components]

    return all_components, all_svals, new_mean, n_total
196
+
197
+
198
def explained_variance_ratio(singular_values: Tensor) -> Tensor:
    """Return the running share of squared singular values.

    Entry i is ``sum(s[:i+1]**2) / sum(s**2)``. The denominator covers only
    the singular values passed in, so it is relative to those components
    rather than to the full data variance.

    Args:
        singular_values: (k,) singular values in descending order.

    Returns:
        (k,) cumulative ratios, or all zeros when every value is zero.
    """
    energy = singular_values.pow(2)
    denom = energy.sum()
    if denom == 0:
        # Nothing to normalize by; report zero explained variance.
        return torch.zeros_like(energy)
    return energy.cumsum(dim=0) / denom
212
+
213
+
214
def select_num_components(singular_values: Tensor, threshold: float = 0.6) -> int:
    """Return the smallest k whose cumulative variance ratio reaches `threshold`.

    Args:
        singular_values: (k,) singular values in descending order.
        threshold: Minimum cumulative variance ratio (default 0.6 per paper).

    Returns:
        1-based count of leading components needed. Returns
        ``len(singular_values)`` when the threshold is never reached,
        including the all-zero case.
    """
    reached = explained_variance_ratio(singular_values) >= threshold
    if not bool(reached.any()):
        return len(singular_values)
    # argmax returns the first True position; +1 converts index to count.
    return int(reached.int().argmax()) + 1
vlora/pipeline.py ADDED
@@ -0,0 +1,70 @@
1
+ """High-level convenience wrappers for the 3-step pipeline."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+
7
+ from vlora.io import LoRAWeights, load_adapter, save_adapter
8
+ from vlora.subspace import SharedSubspace
9
+
10
+
11
def init_subspace(
    adapter_paths: list[str | Path],
    task_ids: list[str] | None = None,
    variance_threshold: float = 0.6,
    num_components: int | None = None,
) -> SharedSubspace:
    """Build a SharedSubspace directly from adapter directories on disk.

    Every path is loaded with ``load_adapter``, in order. The loaded adapters
    and the remaining options are then passed to
    ``SharedSubspace.from_adapters``.

    Args:
        adapter_paths: Directories containing PEFT adapters.
        task_ids: Names for each adapter.
        variance_threshold: Variance threshold for auto component selection.
        num_components: Explicit number of components (documented as
            overriding the threshold).

    Returns:
        The SharedSubspace produced by ``SharedSubspace.from_adapters``.
    """
    loaded = list(map(load_adapter, adapter_paths))
    return SharedSubspace.from_adapters(
        loaded,
        num_components=num_components,
        variance_threshold=variance_threshold,
        task_ids=task_ids,
    )
35
+
36
+
37
def absorb_task(
    subspace: SharedSubspace,
    adapter_path: str | Path,
    task_id: str,
) -> None:
    """Read one adapter from disk and hand it to ``subspace.absorb``.

    Args:
        subspace: Existing shared subspace; ``absorb`` is called on it, and
            this wrapper returns nothing.
        adapter_path: Directory containing the new PEFT adapter.
        task_id: Name under which the new task is absorbed.
    """
    subspace.absorb(load_adapter(adapter_path), task_id)
51
+
52
+
53
def extract_adapter(
    subspace: SharedSubspace,
    task_id: str,
    output_path: str | Path,
) -> LoRAWeights:
    """Rebuild one task's LoRA weights from the subspace and write them out.

    Args:
        subspace: Shared subspace containing the task.
        task_id: Task to reconstruct via ``subspace.reconstruct``.
        output_path: Directory passed to ``save_adapter``.

    Returns:
        The same LoRAWeights object that was saved.
    """
    reconstructed = subspace.reconstruct(task_id)
    save_adapter(reconstructed, output_path)
    return reconstructed
vlora/router.py ADDED
@@ -0,0 +1,173 @@
1
+ """TaskRouter — lightweight routing over adapter loadings.
2
+
3
+ Routes inputs to a soft blend of task adapters by producing per-task
4
+ weights. Since adapters are represented as small loading vectors in the
5
+ shared subspace, blending is a cheap linear combination rather than
6
+ reconstructing and merging full LoRA matrices.
7
+
8
+ Usage:
9
+ subspace = SharedSubspace.load("shared_subspace/")
10
+ router = TaskRouter.from_subspace(subspace, input_dim=embed_dim, hidden_dim=64)
11
+
12
+ # During inference:
13
+ model = VLoRAModel(base_model, subspace)
14
+ x = get_input_embedding(batch) # (B, embed_dim)
15
+ blended = router.blend_loadings(x, subspace) # one blend, averaged over the batch
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ from typing import Literal
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ import torch.nn.functional as F
25
+ from torch import Tensor
26
+
27
+ from vlora.subspace import SharedSubspace, TaskProjection
28
+
29
+
30
class TaskRouter(nn.Module):
    """Small MLP that produces soft task-blend weights from input features.

    Given input features of shape (B, input_dim), ``forward`` returns
    (B, num_tasks) weights that sum to 1 along the last dim (softmax).
    ``blend_loadings`` averages those weights over the batch and uses them to
    mix the per-task loadings of a shared subspace into one TaskProjection.
    """

    def __init__(
        self,
        input_dim: int,
        num_tasks: int,
        hidden_dim: int = 64,
        task_ids: list[str] | None = None,
        temperature: float = 1.0,
    ):
        """Build the router MLP (Linear -> ReLU -> Linear).

        Args:
            input_dim: Size of the input feature vectors.
            num_tasks: Number of tasks, i.e. softmax outputs.
            hidden_dim: Width of the single hidden layer.
            task_ids: Task names in output order. When None or empty, names
                ``task_0 .. task_{num_tasks-1}`` are generated.
            temperature: Softmax temperature (higher = softer); must be > 0.

        Raises:
            ValueError: If ``temperature`` is not positive, or if the number
                of task ids differs from ``num_tasks``.
        """
        super().__init__()
        if not temperature > 0:
            # Zero gives inf/NaN logits; negative flips the routing preference.
            raise ValueError(f"temperature must be positive, got {temperature}")
        # Copy so later mutation of the caller's list cannot desync the router.
        resolved_ids = (
            list(task_ids) if task_ids else [f"task_{i}" for i in range(num_tasks)]
        )
        if len(resolved_ids) != num_tasks:
            raise ValueError(
                f"got {len(resolved_ids)} task_ids for num_tasks={num_tasks}"
            )

        self.input_dim = input_dim
        self.num_tasks = num_tasks
        self.hidden_dim = hidden_dim
        self.task_ids = resolved_ids
        self.temperature = temperature

        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_tasks),
        )

    def forward(self, x: Tensor) -> Tensor:
        """Compute task blend weights.

        Args:
            x: (B, input_dim) input features.

        Returns:
            weights: (B, num_tasks) softmax blend weights.
        """
        logits = self.net(x) / self.temperature
        return F.softmax(logits, dim=-1)

    def blend_loadings(
        self,
        x: Tensor,
        subspace: SharedSubspace,
    ) -> TaskProjection:
        """Produce one blended TaskProjection by mixing all task loadings.

        Router weights are averaged over every leading (batch) dimension of
        ``x``. For each layer, the tasks' ``loadings_a`` / ``loadings_b`` are
        stacked in ``self.task_ids`` order and contracted with those weights.

        Args:
            x: (B, input_dim) input features; an unbatched (input_dim,)
                vector is also accepted.
            subspace: SharedSubspace containing every task in
                ``self.task_ids``.

        Returns:
            TaskProjection with task_id ``"__routed__"`` and blended loadings
            (one set for the whole batch).
        """
        weights = self.forward(x)
        if weights.dim() > 1:
            # Collapse all batch dims into a single (num_tasks,) weight vector.
            weights = weights.reshape(-1, self.num_tasks).mean(dim=0)

        blended_a: dict[str, Tensor] = {}
        blended_b: dict[str, Tensor] = {}

        for layer in subspace.layer_names:
            stack_a = torch.stack(
                [subspace.tasks[tid].loadings_a[layer] for tid in self.task_ids]
            )
            stack_b = torch.stack(
                [subspace.tasks[tid].loadings_b[layer] for tid in self.task_ids]
            )
            blended_a[layer] = self._mix(weights, stack_a)
            blended_b[layer] = self._mix(weights, stack_b)

        return TaskProjection(
            task_id="__routed__",
            loadings_a=blended_a,
            loadings_b=blended_b,
        )

    @staticmethod
    def _mix(weights: Tensor, stacked: Tensor) -> Tensor:
        """Contract (num_tasks,) weights with the leading dim of ``stacked``."""
        # Match the loadings' dtype/device (e.g. float64 subspace vs a
        # float32 router); a plain `@` would raise on the dtype mismatch.
        w = weights.to(dtype=stacked.dtype, device=stacked.device)
        # tensordot keeps any trailing loading shape; for (T, k) it equals w @ stacked.
        return torch.tensordot(w, stacked, dims=1)

    @classmethod
    def from_subspace(
        cls,
        subspace: SharedSubspace,
        input_dim: int,
        hidden_dim: int = 64,
        temperature: float = 1.0,
        init_from_loadings: bool = True,
    ) -> TaskRouter:
        """Create a router whose outputs follow the subspace's tasks.

        Task ids are taken from ``subspace.tasks`` in sorted order. When
        ``init_from_loadings`` is True and there is more than one task, the
        output layer's bias is set from each task's total loading norm. That
        norm is the sum over ``subspace.layer_names`` of
        ``||loadings_a|| + ||loadings_b||``, scaled so the largest is ~1.
        Tasks with larger loadings thus start with a slightly higher routing
        weight. The layer weights keep their default initialization.

        Args:
            subspace: SharedSubspace with registered tasks.
            input_dim: Dimension of input features the router will see.
            hidden_dim: Router hidden layer size.
            temperature: Softmax temperature (higher = softer blending).
            init_from_loadings: If True, initialize the output bias from the
                task loading norms as described above.

        Returns:
            A new TaskRouter with ``task_ids`` in sorted order.
        """
        task_ids = sorted(subspace.tasks.keys())
        num_tasks = len(task_ids)

        router = cls(
            input_dim=input_dim,
            num_tasks=num_tasks,
            hidden_dim=hidden_dim,
            task_ids=task_ids,
            temperature=temperature,
        )

        if init_from_loadings and num_tasks > 1:
            out_layer = router.net[-1]
            with torch.no_grad():
                norms: list[float] = []
                for tid in task_ids:
                    proj = subspace.tasks[tid]
                    # Start from a float so an empty layer list yields 0.0.
                    total = 0.0
                    for layer in subspace.layer_names:
                        total += float(proj.loadings_a[layer].norm())
                        total += float(proj.loadings_b[layer].norm())
                    norms.append(total)
                bias = torch.tensor(
                    norms, dtype=out_layer.bias.dtype, device=out_layer.bias.device
                )
                # Norms are non-negative; epsilon avoids 0/0 when all are zero.
                bias = bias / (bias.max() + 1e-8)
                out_layer.bias.copy_(bias)

        return router

    @property
    def num_params(self) -> int:
        """Total number of router parameters."""
        return sum(p.numel() for p in self.parameters())