vlora-dev 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vlora/io.py ADDED
@@ -0,0 +1,191 @@
1
+ """Load and parse PEFT LoRA adapters from disk or HuggingFace Hub.
2
+
3
+ Reads safetensors + adapter_config.json without requiring the PEFT
4
+ library at runtime.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import json
10
+ import logging
11
+ import re
12
+ from dataclasses import dataclass, field
13
+ from pathlib import Path
14
+ from typing import Literal
15
+
16
+ import torch
17
+
18
# Package-wide logger, looked up by its fixed name "vlora" (not __name__).
logger = logging.getLogger(name="vlora")
19
+ from safetensors.torch import load_file, save_file
20
+ from torch import Tensor
21
+
22
+
23
@dataclass
class LoRAWeights:
    """Container for one LoRA adapter's factor matrices, keyed by layer name.

    All fields are plain dataclass fields; ``metadata`` defaults to a fresh,
    per-instance empty dict when not supplied.
    """

    # Names of the adapted layers, e.g. "model.layers.0.self_attn.q_proj".
    layer_names: list[str]
    # layer name -> A factor, shaped (rank, in_features).
    lora_a: dict[str, Tensor]
    # layer name -> B factor, shaped (out_features, rank).
    lora_b: dict[str, Tensor]
    # LoRA rank shared by the factors.
    rank: int
    # Free-form extra information (load_adapter stores adapter_config.json here).
    metadata: dict = field(default_factory=dict)
32
+
33
+
34
# Pattern to extract layer name + side from PEFT state dict keys.
# Handles: base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight
#      or: model.layers.0.self_attn.q_proj.lora_A.weight
#      or: model.layers.0.self_attn.q_proj.lora_A.default.weight
# Group 1 is the layer name (with a leading "base_model.model." stripped when
# present); group 2 is "lora_A" or "lora_B".
# The trailing "$" makes the weight suffix end the key, so keys that merely start
# with ".lora_A.weight" (e.g. "...lora_A.weight_scale") are not mistaken for the
# factor itself and cannot overwrite it during parsing.
_LORA_KEY_RE = re.compile(
    r"(?:base_model\.model\.)?(.+)\.(lora_[AB])\.(?:weight|default\.weight)$"
)
40
+
41
+
42
def parse_state_dict(
    state_dict: dict[str, Tensor],
) -> tuple[dict[str, Tensor], dict[str, Tensor], list[str]]:
    """Split a PEFT state dict into per-layer A and B factor dicts.

    Keys that ``_LORA_KEY_RE`` does not recognise are ignored, and a layer is
    kept only when both its A and its B factor were found.

    Returns:
        lora_a: {layer_name: tensor} for the kept layers
        lora_b: {layer_name: tensor} for the kept layers
        layer_names: the kept layer names, sorted
    """
    # Route each recognised key to the bucket named by its side ("lora_A"/"lora_B").
    by_side: dict[str, dict[str, Tensor]] = {"lora_A": {}, "lora_B": {}}
    for key, tensor in state_dict.items():
        hit = _LORA_KEY_RE.search(key)
        if hit is not None:
            layer, side = hit.group(1, 2)
            by_side[side][layer] = tensor

    found_a = by_side["lora_A"]
    found_b = by_side["lora_B"]

    # Drop layers that are missing either factor.
    layer_names = sorted(found_a.keys() & found_b.keys())
    lora_a = {name: found_a[name] for name in layer_names}
    lora_b = {name: found_b[name] for name in layer_names}
    return lora_a, lora_b, layer_names
72
+
73
+
74
def load_adapter(path: str | Path) -> LoRAWeights:
    """Load a PEFT LoRA adapter from a local directory.

    Expects ``adapter_model.safetensors`` in ``path``. ``adapter_config.json``
    is optional; when present, its parsed contents become ``metadata``.

    Args:
        path: Directory containing the adapter files.

    Returns:
        LoRAWeights holding the layers that have both an A and a B factor.
        ``rank`` is the config's ``"r"`` entry; when the config is missing or
        ``"r"`` is absent/null/0, it is inferred from the first layer's A
        factor (its first dimension).

    Raises:
        FileNotFoundError: if ``adapter_model.safetensors`` does not exist.
    """
    path = Path(path)

    # Load safetensors weights
    safetensors_path = path / "adapter_model.safetensors"
    if not safetensors_path.exists():
        raise FileNotFoundError(f"No adapter_model.safetensors in {path}")
    state_dict = load_file(str(safetensors_path))

    # Load config for metadata
    config_path = path / "adapter_config.json"
    metadata: dict = {}
    rank = 0
    if config_path.exists():
        # JSON is UTF-8; don't depend on the platform's locale encoding.
        with open(config_path, encoding="utf-8") as f:
            metadata = json.load(f)
        # Treat a null "r" like a missing one so the shape-based fallback applies.
        rank = metadata.get("r") or 0

    lora_a, lora_b, layer_names = parse_state_dict(state_dict)
    logger.debug("Loaded adapter from %s: %d layers", path, len(layer_names))
    if not layer_names:
        logger.warning("No LoRA A/B weight pairs found in %s", safetensors_path)

    # Infer rank from weight shapes if not in config
    if rank == 0 and layer_names:
        rank = lora_a[layer_names[0]].shape[0]

    return LoRAWeights(
        layer_names=layer_names,
        lora_a=lora_a,
        lora_b=lora_b,
        rank=rank,
        metadata=metadata,
    )
112
+
113
+
114
def load_adapter_from_hub(repo_id: str, revision: str | None = None) -> LoRAWeights:
    """Load a PEFT LoRA adapter from HuggingFace Hub.

    Downloads only ``adapter_model.safetensors`` and ``adapter_config.json``
    from ``repo_id`` (at ``revision`` when given) with ``snapshot_download``,
    then parses the downloaded directory with ``load_adapter``.

    Requires the `huggingface-hub` package (install with `pip install vlora-dev[hub]`).

    Raises:
        ImportError: if ``huggingface-hub`` is not installed.
    """
    try:
        from huggingface_hub import snapshot_download
    except ImportError as err:
        # Chain explicitly so the traceback shows the original import failure as
        # the cause rather than as a second error raised while handling it.
        raise ImportError(
            "huggingface-hub is required to load from Hub. "
            "Install with: pip install vlora-dev[hub]"
        ) from err

    local_dir = snapshot_download(
        repo_id,
        revision=revision,
        allow_patterns=["adapter_model.safetensors", "adapter_config.json"],
    )
    return load_adapter(local_dir)
133
+
134
+
135
def stack_lora_weights(
    adapters: list[LoRAWeights],
    side: Literal["A", "B"],
) -> dict[str, Tensor]:
    """Stack one LoRA factor from multiple adapters into a matrix per layer.

    For side="A" each adapter's A matrix (rank, in_features) is flattened to a
    row vector; for side="B" the B matrix (out_features, rank) is used instead.
    The N rows are stacked into (N, flattened_dim).

    This produces the "factor data matrix" the paper feeds into SVD.

    Only layers present in every adapter are included.

    Args:
        adapters: Adapters to stack (at least one). For a given layer, the chosen
            factor must have the same number of elements in every adapter.
        side: "A" or "B".

    Returns:
        {layer_name: (N, flattened_dim)} stacked matrix, in sorted layer order.

    Raises:
        ValueError: if ``adapters`` is empty or ``side`` is not "A" or "B".
    """
    if not adapters:
        raise ValueError("Need at least one adapter to stack")
    # Reject anything else explicitly instead of silently treating it as "B".
    if side not in ("A", "B"):
        raise ValueError(f"side must be 'A' or 'B', got {side!r}")

    attr = "lora_a" if side == "A" else "lora_b"

    # Use intersection of all adapters' layer names
    layer_names = sorted(set.intersection(*(set(a.layer_names) for a in adapters)))

    return {
        layer: torch.stack([getattr(a, attr)[layer].flatten() for a in adapters])
        for layer in layer_names
    }
167
+
168
+
169
def save_adapter(weights: LoRAWeights, path: str | Path) -> None:
    """Save LoRA weights back to PEFT-compatible format.

    Creates ``path`` (and parents) if needed, then writes:

    * ``adapter_model.safetensors`` with keys
      ``base_model.model.<layer>.lora_A.weight`` and
      ``base_model.model.<layer>.lora_B.weight``;
    * ``adapter_config.json``, a copy of ``weights.metadata`` with defaults
      filled in for ``r``, ``lora_alpha`` (= rank, i.e. scaling 1.0),
      ``peft_type``, ``task_type``, ``bias`` and ``lora_dropout``. Values
      already present in the metadata are never overwritten.
    """
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)

    # Rebuild state dict with PEFT key format. safetensors refuses to serialize
    # non-contiguous tensors (e.g. transposed or sliced views), so materialize a
    # contiguous layout; .contiguous() is a no-op for already-contiguous tensors.
    state_dict: dict[str, Tensor] = {}
    for layer_name in weights.layer_names:
        prefix = f"base_model.model.{layer_name}"
        state_dict[f"{prefix}.lora_A.weight"] = weights.lora_a[layer_name].contiguous()
        state_dict[f"{prefix}.lora_B.weight"] = weights.lora_b[layer_name].contiguous()

    save_file(state_dict, str(path / "adapter_model.safetensors"))

    # Save config — include defaults needed by vLLM/TGI
    config = dict(weights.metadata) if weights.metadata else {}
    config.setdefault("r", weights.rank)
    config.setdefault("lora_alpha", weights.rank)  # alpha=rank → scaling=1.0
    config.setdefault("peft_type", "LORA")
    config.setdefault("task_type", "CAUSAL_LM")
    config.setdefault("bias", "none")
    config.setdefault("lora_dropout", 0.0)
    # Write UTF-8 explicitly rather than relying on the platform's locale encoding.
    with open(path / "adapter_config.json", "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
vlora/merge.py ADDED
@@ -0,0 +1,229 @@
1
+ """Adapter merging — task arithmetic, TIES, and DARE.
2
+
3
+ These techniques operate on LoRA weight matrices directly, producing
4
+ a single merged adapter from multiple inputs. All three methods work
5
+ on a per-layer basis and return a new LoRAWeights object.
6
+
7
+ References:
8
+ - Task Arithmetic: Ilharco et al., "Editing Models with Task Arithmetic" (2023)
9
+ - TIES: Yadav et al., "TIES-Merging: Resolving Interference When Merging Models" (2023)
10
+ - DARE: Yu et al., "Language Models are Super Mario" (2024)
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import logging
16
+ from typing import Literal
17
+
18
+ import torch
19
+ from torch import Tensor
20
+
21
+ from vlora._validate import check_adapters_compatible
22
+ from vlora.io import LoRAWeights
23
+
24
# Package-wide logger, looked up by its fixed name "vlora" (not __name__).
logger = logging.getLogger(name="vlora")
25
+
26
+
27
def task_arithmetic(
    adapters: list[LoRAWeights],
    weights: list[float] | None = None,
) -> LoRAWeights:
    """Merge adapters via a weighted sum of their LoRA factor matrices.

    For every layer shared by all adapters, the A factors are combined as
    ``sum_i weights[i] * A_i`` and, independently, the B factors as
    ``sum_i weights[i] * B_i``. With the default uniform weights (1/N each)
    this is the element-wise mean.

    Note: because A and B are combined separately, the merged delta
    ``B_merged @ A_merged`` is generally not the weighted average of the
    per-adapter deltas ``B_i @ A_i``.

    Args:
        adapters: Adapters to merge; they must pass ``check_adapters_compatible``.
        weights: Per-adapter coefficients. Defaults to uniform (1/N each).

    Returns:
        Merged LoRAWeights over the shared layers, with ``rank`` taken from the
        first adapter and empty ``metadata``.

    Raises:
        ValueError: if ``adapters`` is empty, ``weights`` has the wrong length,
            or the adapters share no common layers.
    """
    n = len(adapters)
    # Checked before check_adapters_compatible so an empty list always gets this
    # message, whatever the compatibility check does with it.
    if n == 0:
        raise ValueError("Need at least one adapter to merge.")
    check_adapters_compatible(adapters)

    if weights is None:
        weights = [1.0 / n] * n
    if len(weights) != n:
        raise ValueError(f"weights length ({len(weights)}) must match adapters ({n})")

    # Only layers present in every adapter can be merged.
    layer_names = sorted(set.intersection(*(set(a.layer_names) for a in adapters)))
    if not layer_names:
        raise ValueError("Adapters share no common layers.")

    logger.info("Task arithmetic merge: %d adapters, %d layers", n, len(layer_names))

    lora_a: dict[str, Tensor] = {
        layer: sum(w * a.lora_a[layer] for a, w in zip(adapters, weights))
        for layer in layer_names
    }
    lora_b: dict[str, Tensor] = {
        layer: sum(w * a.lora_b[layer] for a, w in zip(adapters, weights))
        for layer in layer_names
    }

    return LoRAWeights(
        layer_names=layer_names,
        lora_a=lora_a,
        lora_b=lora_b,
        rank=adapters[0].rank,
    )
75
+
76
+
77
def _ties_merge_tensors(
    tensors: list[Tensor],
    density: float,
    weights: list[float],
) -> Tensor:
    """Run TIES trim / sign election / disjoint weighted merge on same-shape tensors."""
    # Step 1: Trim — per tensor, zero elements below the k-th largest magnitude.
    trimmed = []
    for original in tensors:
        t = original.clone()
        if t.numel() > 0:  # topk would fail on an empty tensor
            magnitudes = t.abs()
            k = max(1, int(density * t.numel()))
            threshold = magnitudes.flatten().topk(k).values[-1]
            t[magnitudes < threshold] = 0.0
        trimmed.append(t)

    stacked = torch.stack(trimmed)  # (N, *shape)

    # Step 2: Elect sign — majority vote of non-zero signs; ties go positive.
    sign_votes = (stacked > 0).float().sum(dim=0) - (stacked < 0).float().sum(dim=0)
    elected_sign = sign_votes.sign()
    elected_sign[elected_sign == 0] = 1.0
    agree = stacked.sign() == elected_sign.unsqueeze(0)

    # Step 3: Merge — weighted average over the agreeing entries only. The weight
    # tensor is created on the inputs' device so GPU-resident adapters work.
    w = torch.tensor(weights, dtype=stacked.dtype, device=stacked.device)
    w = w.view(-1, *([1] * (stacked.dim() - 1)))
    agree_w = w * agree.to(stacked.dtype)  # weight where the sign agrees, else 0
    numerator = (stacked * agree_w).sum(dim=0)
    denominator = agree_w.sum(dim=0)
    # Positions with no agreeing entries have numerator 0; avoid dividing by 0.
    denominator = torch.where(denominator == 0, torch.ones_like(denominator), denominator)
    # Scale by sum(weights) so, like task_arithmetic, the weights' overall
    # magnitude scales the result (1.0 for the default uniform 1/N weights).
    return numerator / denominator * sum(weights)


def ties_merge(
    adapters: list[LoRAWeights],
    density: float = 0.5,
    weights: list[float] | None = None,
) -> LoRAWeights:
    """Merge adapters using TIES: Trim, Elect sign, and merge.

    Applied independently to every A and every B factor of each layer shared
    by all adapters:

    1. Trim: per adapter, zero the elements whose magnitude is below the k-th
       largest magnitude, k = max(1, int(density * numel)); elements tied with
       that threshold are kept.
    2. Elect sign: at each position, pick the sign held by more adapters
       (counting non-zero entries); ties go positive. Note that the TIES paper
       elects the sign of the summed trimmed values instead.
    3. Merge: weighted average of the entries whose sign matches the elected
       sign, multiplied by ``sum(weights)``. With the default uniform 1/N
       weights this is the plain mean of the agreeing entries.

    Args:
        adapters: Adapters to merge; they must pass ``check_adapters_compatible``.
        density: Fraction of elements to keep per adapter, in (0, 1]. Default 0.5.
        weights: Per-adapter weights. Their relative sizes set each adapter's
            share of the average; their sum scales the result. Defaults to
            uniform (1/N each).

    Returns:
        Merged LoRAWeights over the shared layers, with ``rank`` taken from the
        first adapter and empty ``metadata``.

    Raises:
        ValueError: if ``adapters`` is empty, ``density`` is outside (0, 1],
            ``weights`` has the wrong length, or the adapters share no layers.
    """
    n = len(adapters)
    # Checked before check_adapters_compatible so an empty list always gets this message.
    if n == 0:
        raise ValueError("Need at least one adapter to merge.")
    check_adapters_compatible(adapters)
    if not 0 < density <= 1:
        raise ValueError(f"density must be in (0, 1], got {density}")

    if weights is None:
        weights = [1.0 / n] * n
    if len(weights) != n:
        raise ValueError(f"weights length ({len(weights)}) must match adapters ({n})")

    layer_names = sorted(set.intersection(*(set(a.layer_names) for a in adapters)))
    if not layer_names:
        raise ValueError("Adapters share no common layers.")

    logger.info("TIES merge: %d adapters, density=%.2f, %d layers", n, density, len(layer_names))

    lora_a: dict[str, Tensor] = {
        layer: _ties_merge_tensors([a.lora_a[layer] for a in adapters], density, weights)
        for layer in layer_names
    }
    lora_b: dict[str, Tensor] = {
        layer: _ties_merge_tensors([a.lora_b[layer] for a in adapters], density, weights)
        for layer in layer_names
    }

    return LoRAWeights(
        layer_names=layer_names,
        lora_a=lora_a,
        lora_b=lora_b,
        rank=adapters[0].rank,
    )
155
+ )
156
+
157
+
158
def dare_merge(
    adapters: list[LoRAWeights],
    drop_rate: float = 0.5,
    weights: list[float] | None = None,
    seed: int | None = None,
) -> LoRAWeights:
    """Merge adapters using DARE: Drop And REscale.

    For every A and every B factor of each layer shared by all adapters, each
    adapter's elements are independently zeroed with probability ``drop_rate``
    and the survivors are multiplied by ``1 / (1 - drop_rate)``. The results
    are then combined as ``sum_i weights[i] * factor_i`` (the mean for the
    default uniform weights).

    Args:
        adapters: Adapters to merge; they must pass ``check_adapters_compatible``.
        drop_rate: Probability of dropping each element, in [0, 1). 0 disables
            dropping. Default 0.5.
        weights: Per-adapter coefficients. Defaults to uniform (1/N each).
        seed: If given, drop masks are drawn from dedicated ``torch.Generator``
            objects (one per tensor device) seeded with this value. The result
            is reproducible, and torch's global RNG is neither reseeded nor
            advanced. If None, torch's global RNG is used.

    Returns:
        Merged LoRAWeights over the shared layers, with ``rank`` taken from the
        first adapter and empty ``metadata``.

    Raises:
        ValueError: if ``adapters`` is empty, ``drop_rate`` is outside [0, 1),
            ``weights`` has the wrong length, or the adapters share no layers.
    """
    n = len(adapters)
    # Checked before check_adapters_compatible so an empty list always gets this message.
    if n == 0:
        raise ValueError("Need at least one adapter to merge.")
    check_adapters_compatible(adapters)
    if not 0 <= drop_rate < 1:
        raise ValueError(f"drop_rate must be in [0, 1), got {drop_rate}")

    if weights is None:
        weights = [1.0 / n] * n
    if len(weights) != n:
        raise ValueError(f"weights length ({len(weights)}) must match adapters ({n})")

    layer_names = sorted(set.intersection(*(set(a.layer_names) for a in adapters)))
    if not layer_names:
        raise ValueError("Adapters share no common layers.")

    logger.info("DARE merge: %d adapters, drop_rate=%.2f, %d layers", n, drop_rate, len(layer_names))

    keep_prob = 1.0 - drop_rate
    rescale = 1.0 / keep_prob if drop_rate > 0 else 1.0

    # Seeded generators are created lazily, one per device (a generator must live
    # on the same device as the tensor it samples for).
    generators: dict[torch.device, torch.Generator] = {}

    def _generator_for(device: torch.device) -> torch.Generator | None:
        """Return the seeded generator for ``device``, or None to use the global RNG."""
        if seed is None:
            return None
        gen = generators.get(device)
        if gen is None:
            gen = torch.Generator(device=device)
            gen.manual_seed(seed)
            generators[device] = gen
        return gen

    lora_a: dict[str, Tensor] = {}
    lora_b: dict[str, Tensor] = {}

    for layer in layer_names:
        for out_dict, attr in ((lora_a, "lora_a"), (lora_b, "lora_b")):
            merged = torch.zeros_like(getattr(adapters[0], attr)[layer])

            for w, adapter in zip(weights, adapters):
                t = getattr(adapter, attr)[layer]
                if drop_rate > 0:
                    keep_mask = torch.bernoulli(
                        torch.full_like(t, keep_prob),
                        generator=_generator_for(t.device),
                    )
                    t = t * keep_mask * rescale
                merged = merged + w * t

            out_dict[layer] = merged

    return LoRAWeights(
        layer_names=layer_names,
        lora_a=lora_a,
        lora_b=lora_b,
        rank=adapters[0].rank,
    )
223
+
224
+
225
# Lookup table from merge-method name to the function implementing it.
MERGE_METHODS = dict(
    average=task_arithmetic,
    ties=ties_merge,
    dare=dare_merge,
)
vlora/model.py ADDED
@@ -0,0 +1,148 @@
1
+ """VLoRAModel — inference wrapper that applies reconstructed LoRA deltas."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch import Tensor
10
+
11
+ from vlora._validate import check_task_exists
12
+ from vlora.subspace import SharedSubspace
13
+
14
+
15
class VLoRAModel(nn.Module):
    """Wraps a base model with a shared subspace for multi-task LoRA inference.

    On ``set_task`` the task's LoRA factors are reconstructed from the subspace
    and turned into dense deltas ``delta_W = B @ A``. They are applied through
    forward hooks on the base model's ``nn.Linear`` layers whose module names
    match the adapter's layer names. Each hooked layer returns
    ``output + scaling * (x @ delta_W.T)``.

    Usage:
        subspace = SharedSubspace.load("shared_subspace/")
        base_model = AutoModelForCausalLM.from_pretrained("model-name")
        model = VLoRAModel(base_model, subspace)

        model.set_task("task_0")
        output = model(input_ids)

        model.set_task("task_1")  # switches adapter, cached if same task
        output = model(input_ids)
    """

    def __init__(
        self,
        base_model: nn.Module,
        subspace: SharedSubspace,
        scaling: float | None = None,
        lora_alpha: float | None = None,
    ):
        """
        Args:
            base_model: Model whose matching ``nn.Linear`` submodules get hooks.
            subspace: Source of per-task LoRA weights (its ``reconstruct``,
                ``tasks`` and ``rank`` are used).
            scaling: Explicit multiplier for the LoRA delta; takes precedence.
            lora_alpha: Used only when ``scaling`` is None, giving
                ``scaling = lora_alpha / subspace.rank``. When both are None,
                scaling is 1.0.
        """
        super().__init__()
        self.base_model = base_model
        self.subspace = subspace

        # Resolve scaling: explicit scaling > lora_alpha/rank > 1.0
        if scaling is not None:
            self.scaling = scaling
        elif lora_alpha is not None:
            self.scaling = lora_alpha / subspace.rank
        else:
            self.scaling = 1.0
        self._active_task: str | None = None
        self._cached_deltas: dict[str, Tensor] | None = None
        self._hooks: list[torch.utils.hooks.RemovableHandle] = []

    def set_task(self, task_id: str) -> None:
        """Activate ``task_id``'s adapter; a no-op if it is already active.

        Validates the task with ``check_task_exists`` (propagating whatever it
        raises), caches the reconstructed deltas, and re-registers the hooks.
        """
        if task_id == self._active_task:
            return

        check_task_exists(self.subspace, task_id)

        self._cached_deltas = self._compute_deltas(task_id)
        self._active_task = task_id
        self._apply_hooks()

    def clear_task(self) -> None:
        """Remove the active task adapter and drop its cached deltas."""
        self._remove_hooks()
        self._active_task = None
        self._cached_deltas = None

    def _compute_deltas(self, task_id: str) -> dict[str, Tensor]:
        """Reconstruct ``task_id`` and return ``{layer_name: B @ A}`` for each layer."""
        weights = self.subspace.reconstruct(task_id)
        return {
            name: weights.lora_b[name] @ weights.lora_a[name]
            for name in weights.layer_names
        }

    def _hook_target(self) -> nn.Module:
        """Return the module whose ``named_modules()`` names are matched to layers.

        ``compile()`` replaces ``base_model`` with torch.compile's wrapper
        (``OptimizedModule``), which keeps the original module as ``_orig_mod``.
        Its submodule names would therefore carry an ``_orig_mod.`` prefix and
        never match. Looking through the wrapper keeps the names unprefixed.
        """
        return getattr(self.base_model, "_orig_mod", self.base_model)

    def _apply_hooks(self) -> None:
        """Replace any existing hooks with hooks for the currently cached deltas."""
        self._remove_hooks()

        if self._cached_deltas is None:
            return

        for name, module in self._hook_target().named_modules():
            delta = self._cached_deltas.get(name)
            if delta is not None and isinstance(module, nn.Linear):
                hook = module.register_forward_hook(self._make_lora_hook(delta))
                self._hooks.append(hook)

    def _remove_hooks(self) -> None:
        """Remove all registered forward hooks."""
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()

    def _make_lora_hook(self, delta: Tensor):
        """Create a forward hook that adds ``scaling * (x @ delta.T)`` to the output.

        ``scaling`` is captured when the hook is built, so a later change to
        ``self.scaling`` takes effect only once the hooks are rebuilt.
        """
        scaling = self.scaling

        def hook(module: nn.Module, args: Any, output: Tensor) -> Tensor:
            # args holds the layer's positional inputs; the first is the activation.
            x = args[0] if isinstance(args, tuple) else args
            lora_out = x @ delta.T.to(x.device, x.dtype)
            return output + scaling * lora_out

        return hook

    def forward(self, *args, **kwargs):
        """Forward pass through the base model with active LoRA adapter."""
        return self.base_model(*args, **kwargs)

    @property
    def active_task(self) -> str | None:
        """Currently active task ID, or None."""
        return self._active_task

    @property
    def available_tasks(self) -> list[str]:
        """Sorted list of the task IDs in the subspace."""
        return sorted(self.subspace.tasks.keys())

    def reconstruct_state_dict(self, task_id: str) -> dict[str, Tensor]:
        """Get the LoRA delta weight dict for a task without applying hooks.

        Returns dict of {layer_name: delta_W} where delta_W = B @ A.
        Useful for manual integration with custom model architectures.
        """
        return self._compute_deltas(task_id)

    def compile(self, **kwargs) -> VLoRAModel:
        """Compile the base model with ``torch.compile`` for faster inference.

        All kwargs are forwarded to ``torch.compile()``, and ``base_model`` is
        replaced by the compiled wrapper. ``_hook_target`` looks through that
        wrapper, so ``set_task`` still finds the linear layers afterwards.

        NOTE(review): TorchDynamo may not re-check module hooks on later calls
        (see ``torch._dynamo.config.skip_nnmodule_hook_guards``). Confirm that
        switching tasks after the first compiled forward actually takes effect.

        Returns self for chaining.
        """
        self.base_model = torch.compile(self.base_model, **kwargs)
        return self