vlora-dev 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vlora/__init__.py ADDED
@@ -0,0 +1,73 @@
1
"""vLoRA — Shared low-rank subspaces for LoRA adapter management.

Based on the Share paper (arXiv:2602.06043): LoRA adapters across tasks
share a common low-rank subspace. Instead of storing N separate adapters,
maintain one shared basis and per-task coefficient vectors.
"""

# NOTE(review): distribution metadata says 0.2.0; this was stale at "0.1.0".
# Keep in sync with the package version in pyproject/setup metadata.
__version__ = "0.2.0"

from vlora.analysis import (
    adapter_diff,
    compute_similarity_matrix,
    find_clusters,
    find_outliers,
    subspace_coverage,
)
from vlora.io import LoRAWeights, load_adapter, load_adapter_from_hub, save_adapter
from vlora.merge import dare_merge, task_arithmetic, ties_merge
from vlora.model import VLoRAModel

# All ops imports consolidated into one statement (was split in two).
from vlora.ops import (
    compute_svd,
    explained_variance_ratio,
    gram_schmidt,
    incremental_svd_update,
    project_onto_subspace,
    reconstruct_from_subspace,
    select_num_components,
)
from vlora.pipeline import absorb_task, extract_adapter, init_subspace
from vlora.router import TaskRouter
from vlora.subspace import SharedSubspace, TaskProjection
from vlora.training import SubspaceTrainer, orthogonal_init

__all__ = [
    # Core
    "SharedSubspace",
    "TaskProjection",
    "LoRAWeights",
    # I/O
    "load_adapter",
    "load_adapter_from_hub",
    "save_adapter",
    # Pipeline
    "init_subspace",
    "absorb_task",
    "extract_adapter",
    # Ops
    "compute_svd",
    "project_onto_subspace",
    "reconstruct_from_subspace",
    "gram_schmidt",
    "explained_variance_ratio",
    "select_num_components",
    # Analysis
    "compute_similarity_matrix",
    "find_clusters",
    "adapter_diff",
    "subspace_coverage",
    "find_outliers",
    # Model
    "VLoRAModel",
    # Router
    "TaskRouter",
    # Training
    "SubspaceTrainer",
    "orthogonal_init",
    # Incremental
    "incremental_svd_update",
    # Merging
    "task_arithmetic",
    "ties_merge",
    "dare_merge",
]
vlora/_validate.py ADDED
@@ -0,0 +1,82 @@
1
+ """Input validation helpers for vlora public APIs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import warnings
7
+ from typing import TYPE_CHECKING
8
+
9
+ import torch
10
+ from torch import Tensor
11
+
12
+ if TYPE_CHECKING:
13
+ from vlora.io import LoRAWeights
14
+ from vlora.subspace import SharedSubspace
15
+
16
+ logger = logging.getLogger("vlora")
17
+
18
+
19
def check_adapters_compatible(adapters: list[LoRAWeights]) -> None:
    """Ensure a group of adapters can share one subspace.

    Raises:
        ValueError: if the list is empty, or if the adapters do not all
            carry the same LoRA rank.  (Layer overlap is NOT checked here;
            see ``check_adapter_matches_subspace`` for per-layer checks.)
    """
    if not adapters:
        raise ValueError("Need at least one adapter.")

    distinct_ranks = sorted({adapter.rank for adapter in adapters})
    if len(distinct_ranks) != 1:
        raise ValueError(
            f"Adapters have inconsistent ranks: {distinct_ranks}. "
            "All adapters must have the same LoRA rank to share a subspace."
        )
30
+
31
+
32
def check_adapter_matches_subspace(
    adapter: LoRAWeights, subspace: SharedSubspace, operation: str = "operation"
) -> None:
    """Verify an adapter fits a subspace: equal rank and at least one shared layer.

    Raises ValueError on rank mismatch or zero layer overlap; merely warns
    (rather than raising) when the adapter lacks some subspace layers.
    """
    if adapter.rank != subspace.rank:
        raise ValueError(
            f"Cannot {operation}: adapter rank ({adapter.rank}) does not match "
            f"subspace rank ({subspace.rank})."
        )

    adapter_layers = set(adapter.layer_names)
    subspace_layers = set(subspace.layer_names)

    if not (adapter_layers & subspace_layers):
        raise ValueError(
            f"Cannot {operation}: adapter and subspace share no common layers. "
            f"Adapter layers: {adapter.layer_names[:3]}... "
            f"Subspace layers: {subspace.layer_names[:3]}..."
        )

    absent = subspace_layers - adapter_layers
    if absent:
        # stacklevel=3: attribute the warning to the public-API caller,
        # not to this helper.
        warnings.warn(
            f"Adapter is missing {len(absent)} subspace layers "
            f"(e.g. {sorted(absent)[:2]}). These will use mean values.",
            stacklevel=3,
        )
57
+
58
+
59
def check_task_exists(subspace: SharedSubspace, task_id: str) -> None:
    """Raise a descriptive KeyError when *task_id* is not registered on *subspace*."""
    if task_id in subspace.tasks:
        return
    # Sorted listing keeps the error message deterministic and scannable.
    available = ", ".join(sorted(subspace.tasks.keys()))
    raise KeyError(
        f"Unknown task '{task_id}'. "
        f"Available tasks: [{available}]. "
        f"Use subspace.tasks.keys() to list all tasks."
    )
68
+
69
+
70
def check_tensor_health(tensor: Tensor, name: str = "tensor") -> None:
    """Raise ValueError if *tensor* contains any NaN or Inf entries."""
    if bool(torch.isnan(tensor).any()):
        raise ValueError(
            f"NaN detected in {name}. This usually indicates numerical "
            "instability during SVD. Try using fewer components or "
            "checking your input adapters for degenerate weights."
        )
    if bool(torch.isinf(tensor).any()):
        raise ValueError(
            f"Inf detected in {name}. This usually indicates overflow "
            "during computation. Try using float32 precision."
        )
vlora/analysis.py ADDED
@@ -0,0 +1,191 @@
1
+ """Adapter analysis — similarity, clustering, diffing, and coverage."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import torch
6
+ from torch import Tensor
7
+
8
+ from vlora.io import LoRAWeights
9
+ from vlora.ops import project_onto_subspace
10
+
11
+ from typing import TYPE_CHECKING
12
+
13
+ if TYPE_CHECKING:
14
+ from vlora.subspace import SharedSubspace
15
+
16
+
17
+ def _flatten_adapter(adapter: LoRAWeights) -> Tensor:
18
+ """Flatten all A and B weights into a single 1D vector."""
19
+ parts = []
20
+ for layer in adapter.layer_names:
21
+ parts.append(adapter.lora_a[layer].flatten())
22
+ parts.append(adapter.lora_b[layer].flatten())
23
+ return torch.cat(parts)
24
+
25
+
26
def compute_similarity_matrix(adapters: list[LoRAWeights]) -> Tensor:
    """Compute pairwise cosine similarity between adapters.

    Returns:
        (N, N) similarity matrix where entry [i,j] is the cosine
        similarity between adapter i and adapter j.
    """
    if len(adapters) < 2:
        raise ValueError("Need at least 2 adapters for similarity")

    stacked = torch.stack([_flatten_adapter(adapter) for adapter in adapters])
    # Clamp guards against division by zero for all-zero adapters.
    row_norms = stacked.norm(dim=1, keepdim=True).clamp(min=1e-8)
    unit_rows = stacked / row_norms
    return unit_rows @ unit_rows.T
41
+
42
+
43
def find_clusters(
    similarity_matrix: Tensor,
    threshold: float = 0.9,
) -> list[list[int]]:
    """Group adapters into clusters based on similarity threshold.

    Uses simple greedy clustering: iterate through adapters, assign each
    to the first cluster where similarity to all members >= threshold,
    or create a new cluster.

    Returns:
        List of clusters, where each cluster is a list of adapter indices.
    """
    clusters: list[list[int]] = []

    for idx in range(similarity_matrix.shape[0]):
        # First existing cluster whose every member is similar enough, if any.
        home = next(
            (
                cluster
                for cluster in clusters
                if all(
                    similarity_matrix[idx, member].item() >= threshold
                    for member in cluster
                )
            ),
            None,
        )
        if home is None:
            clusters.append([idx])
        else:
            home.append(idx)

    return clusters
71
+
72
+
73
def adapter_diff(
    adapter_a: LoRAWeights,
    adapter_b: LoRAWeights,
) -> dict[str, dict[str, float]]:
    """Per-layer comparison of two adapters.

    Returns:
        Dict mapping layer_name -> {"l2_distance": float, "cosine_sim": float}
        for each layer present in both adapters.
    """
    report: dict[str, dict[str, float]] = {}
    shared_layers = set(adapter_a.layer_names) & set(adapter_b.layer_names)

    for layer in sorted(shared_layers):
        flat_a = torch.cat(
            [adapter_a.lora_a[layer].flatten(), adapter_a.lora_b[layer].flatten()]
        )
        flat_b = torch.cat(
            [adapter_b.lora_a[layer].flatten(), adapter_b.lora_b[layer].flatten()]
        )
        report[layer] = {
            "l2_distance": (flat_a - flat_b).norm().item(),
            "cosine_sim": torch.nn.functional.cosine_similarity(
                flat_a.unsqueeze(0), flat_b.unsqueeze(0)
            ).item(),
        }

    return report
104
+
105
+
106
def subspace_coverage(
    subspace: SharedSubspace,
    adapter: LoRAWeights,
) -> dict[str, dict[str, float]]:
    """Measure how well a subspace represents a given adapter.

    For each layer and side (A/B), projects the mean-centered adapter
    weights onto the subspace and reports the fraction of the adapter's
    norm that the reconstruction captures (1.0 = perfectly represented).

    Returns:
        Dict mapping layer_name -> {"coverage_a": float, "coverage_b": float,
        "coverage_mean": float}
    """
    report: dict[str, dict[str, float]] = {}

    sides = (
        ("a", lambda s: (adapter.lora_a, s.components_a, s.means_a)),
        ("b", lambda s: (adapter.lora_b, s.components_b, s.means_b)),
    )

    for layer in subspace.layer_names:
        if layer not in adapter.lora_a:
            # Layer absent from the adapter; nothing to measure.
            continue

        per_side: dict[str, float] = {}
        for side, pick in sides:
            weights_dict, components, means = pick(subspace)
            centered = weights_dict[layer].flatten() - means[layer]
            total_norm = centered.norm().item()
            if total_norm < 1e-8:
                # Degenerate (all-mean) layer: trivially fully covered.
                per_side[f"coverage_{side}"] = 1.0
                continue

            coeffs = project_onto_subspace(centered, components[layer])
            residual_norm = (centered - coeffs @ components[layer]).norm().item()
            per_side[f"coverage_{side}"] = 1.0 - (residual_norm / total_norm)

        per_side["coverage_mean"] = (
            per_side["coverage_a"] + per_side["coverage_b"]
        ) / 2.0
        report[layer] = per_side

    return report
147
+
148
+
149
def find_outliers(
    adapters: list[LoRAWeights],
    threshold: float = 2.0,
) -> list[dict]:
    """Detect adapter outliers based on distance from the group mean.

    Computes each adapter's flattened weight vector, measures the L2
    distance from the group centroid, and flags adapters whose distance
    exceeds `threshold` standard deviations above the mean distance.

    Args:
        adapters: List of adapters to analyze.
        threshold: Number of standard deviations above mean distance
            to consider an outlier. Default 2.0.

    Returns:
        List of dicts with keys: {"index", "distance", "z_score"} for
        each outlier adapter.
    """
    if len(adapters) < 3:
        # Too few points for a meaningful mean/std estimate.
        return []

    matrix = torch.stack([_flatten_adapter(adapter) for adapter in adapters])
    deltas = (matrix - matrix.mean(dim=0)).norm(dim=1)

    center = deltas.mean().item()
    spread = deltas.std().item()
    if spread < 1e-8:
        # All adapters are equidistant from the centroid; no outliers.
        return []

    flagged = []
    for idx, dist in enumerate(deltas.tolist()):
        z_score = (dist - center) / spread
        if z_score > threshold:
            flagged.append({
                "index": idx,
                "distance": dist,
                "z_score": z_score,
            })

    return flagged