vlora-dev 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vlora/__init__.py +73 -0
- vlora/_validate.py +82 -0
- vlora/analysis.py +191 -0
- vlora/cli.py +430 -0
- vlora/integrations/__init__.py +1 -0
- vlora/integrations/huggingface.py +163 -0
- vlora/io.py +191 -0
- vlora/merge.py +229 -0
- vlora/model.py +148 -0
- vlora/ops.py +229 -0
- vlora/pipeline.py +70 -0
- vlora/router.py +173 -0
- vlora/subspace.py +651 -0
- vlora/training.py +149 -0
- vlora_dev-0.2.0.dist-info/METADATA +409 -0
- vlora_dev-0.2.0.dist-info/RECORD +19 -0
- vlora_dev-0.2.0.dist-info/WHEEL +4 -0
- vlora_dev-0.2.0.dist-info/entry_points.txt +2 -0
- vlora_dev-0.2.0.dist-info/licenses/LICENSE +190 -0
vlora/io.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
"""Load and parse PEFT LoRA adapters from disk or HuggingFace Hub.
|
|
2
|
+
|
|
3
|
+
Reads safetensors + adapter_config.json without requiring the PEFT
|
|
4
|
+
library at runtime.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import json
|
|
10
|
+
import logging
|
|
11
|
+
import re
|
|
12
|
+
from dataclasses import dataclass, field
|
|
13
|
+
from pathlib import Path
|
|
14
|
+
from typing import Literal
|
|
15
|
+
|
|
16
|
+
import torch
|
|
17
|
+
|
|
18
|
+
logger = logging.getLogger("vlora")
|
|
19
|
+
from safetensors.torch import load_file, save_file
|
|
20
|
+
from torch import Tensor
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass
class LoRAWeights:
    """Container for one LoRA adapter's factor matrices, keyed by layer name.

    ``lora_a`` and ``lora_b`` are indexed by the entries of ``layer_names``.
    ``metadata`` defaults to a fresh empty dict for every instance.
    """

    # Names of the layers this adapter covers.
    layer_names: list[str]
    # layer_name -> A matrix, shaped (rank, in_features).
    lora_a: dict[str, Tensor]
    # layer_name -> B matrix, shaped (out_features, rank).
    lora_b: dict[str, Tensor]
    # LoRA rank of the adapter.
    rank: int
    # Free-form adapter configuration / extra information.
    metadata: dict = field(default_factory=dict)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# Splits a PEFT state-dict key into (layer name, "lora_A" | "lora_B").
# An optional leading "base_model.model." is dropped from the layer name, so
# both of these resolve to "model.layers.0.self_attn.q_proj":
#   base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight
#   model.layers.0.self_attn.q_proj.lora_A.weight
# The "<side>.default.weight" spelling is accepted as well.
_LORA_KEY_RE = re.compile(
    r"(?:base_model\.model\.)?(.+)\.(lora_[AB])\.(?:weight|default\.weight)"
)


def parse_state_dict(
    state_dict: dict[str, Tensor],
) -> tuple[dict[str, Tensor], dict[str, Tensor], list[str]]:
    """Split a PEFT state dict into per-layer A and B weight dicts.

    Keys that do not look like LoRA A/B weights are ignored, and a layer is
    kept only when both its A and its B matrix are present.

    Returns:
        lora_a: {layer_name: tensor}
        lora_b: {layer_name: tensor}
        layer_names: sorted unique layer names
    """
    by_side: dict[str, dict[str, Tensor]] = {"lora_A": {}, "lora_B": {}}

    for key, tensor in state_dict.items():
        found = _LORA_KEY_RE.search(key)
        if found:
            name, which = found.groups()
            by_side[which][name] = tensor

    a_side, b_side = by_side["lora_A"], by_side["lora_B"]

    # Only layers that have both halves survive.
    layer_names = sorted(a_side.keys() & b_side.keys())
    paired_a = {name: a_side[name] for name in layer_names}
    paired_b = {name: b_side[name] for name in layer_names}

    return paired_a, paired_b, layer_names
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
def load_adapter(path: str | Path) -> LoRAWeights:
    """Load a PEFT LoRA adapter from a local directory.

    Expects ``adapter_model.safetensors`` in the given directory and reads
    ``adapter_config.json`` from the same directory when it is present.

    Args:
        path: Directory containing the adapter files.

    Returns:
        LoRAWeights with the parsed A/B matrices. ``metadata`` holds the
        parsed config (empty if there is no config file). ``rank`` comes from
        the config's ``"r"`` entry, falling back to the first layer's A-matrix
        row count when the config gives no usable value.

    Raises:
        FileNotFoundError: If ``adapter_model.safetensors`` is missing.
    """
    path = Path(path)

    # Load safetensors weights
    safetensors_path = path / "adapter_model.safetensors"
    if not safetensors_path.exists():
        raise FileNotFoundError(f"No adapter_model.safetensors in {path}")
    state_dict = load_file(str(safetensors_path))

    # Load config for metadata
    config_path = path / "adapter_config.json"
    metadata: dict = {}
    rank = 0
    if config_path.exists():
        # JSON is UTF-8 by spec; do not depend on the platform default encoding.
        with open(config_path, encoding="utf-8") as f:
            metadata = json.load(f)
        # `or 0` also covers an explicit `"r": null`, so the shape-based
        # fallback below still runs instead of returning rank=None.
        rank = metadata.get("r") or 0

    lora_a, lora_b, layer_names = parse_state_dict(state_dict)
    logger.debug("Loaded adapter from %s: %d layers", path, len(layer_names))

    # Infer rank from weight shapes if not in config
    if rank == 0 and layer_names:
        rank = lora_a[layer_names[0]].shape[0]

    return LoRAWeights(
        layer_names=layer_names,
        lora_a=lora_a,
        lora_b=lora_b,
        rank=rank,
        metadata=metadata,
    )
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def load_adapter_from_hub(repo_id: str, revision: str | None = None) -> LoRAWeights:
    """Load a PEFT LoRA adapter from HuggingFace Hub.

    Downloads only ``adapter_model.safetensors`` and ``adapter_config.json``
    from the repository, then loads them with ``load_adapter``.

    Requires the `huggingface-hub` package (install with `pip install vlora-dev[hub]`).

    Args:
        repo_id: Hub repository identifier.
        revision: Optional revision passed through to ``snapshot_download``.

    Returns:
        The adapter parsed from the downloaded snapshot.

    Raises:
        ImportError: If ``huggingface_hub`` cannot be imported.
    """
    try:
        from huggingface_hub import snapshot_download
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "huggingface-hub is required to load from Hub. "
            "Install with: pip install vlora-dev[hub]"
        ) from err

    local_dir = snapshot_download(
        repo_id,
        revision=revision,
        allow_patterns=["adapter_model.safetensors", "adapter_config.json"],
    )
    return load_adapter(local_dir)
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def stack_lora_weights(
    adapters: list[LoRAWeights],
    side: Literal["A", "B"],
) -> dict[str, Tensor]:
    """Stack LoRA weight matrices from multiple adapters per layer.

    For side="A": each adapter's A matrix is (rank, in_features).
    We flatten each to a row vector and stack N adapters into (N, rank*in_features).
    side="B" does the same with the B matrices.

    This produces the "factor data matrix" the paper feeds into SVD.

    Only layers present in every adapter are stacked.

    Args:
        adapters: Adapters to stack; row i of each result comes from adapters[i].
        side: Which factor to stack, "A" or "B".

    Returns:
        {layer_name: (N, flattened_dim)} stacked matrix.

    Raises:
        ValueError: If ``adapters`` is empty or ``side`` is not "A" or "B".
    """
    if not adapters:
        raise ValueError("Need at least one adapter to stack")
    # Without this check any other value (e.g. "a") silently selects B.
    if side not in ("A", "B"):
        raise ValueError(f"side must be 'A' or 'B', got {side!r}")

    # Use intersection of all adapters' layer names
    layer_names = sorted(set.intersection(*(set(a.layer_names) for a in adapters)))
    attr = "lora_a" if side == "A" else "lora_b"

    return {
        layer: torch.stack(
            [getattr(adapter, attr)[layer].flatten() for adapter in adapters]
        )
        for layer in layer_names
    }
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def save_adapter(weights: LoRAWeights, path: str | Path) -> None:
    """Save LoRA weights back to PEFT-compatible format.

    Writes ``adapter_model.safetensors`` and ``adapter_config.json`` into
    ``path``, creating the directory (and parents) if needed.

    Args:
        weights: Adapter to save.
        path: Destination directory.
    """
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)

    # Rebuild state dict with PEFT key format
    state_dict = {}
    for layer_name in weights.layer_names:
        prefix = f"base_model.model.{layer_name}"
        # safetensors refuses non-contiguous tensors (e.g. transposed views),
        # so pack them first; this is a no-op for already-contiguous tensors.
        state_dict[f"{prefix}.lora_A.weight"] = weights.lora_a[layer_name].contiguous()
        state_dict[f"{prefix}.lora_B.weight"] = weights.lora_b[layer_name].contiguous()

    save_file(state_dict, str(path / "adapter_model.safetensors"))

    # Save config — include defaults needed by vLLM/TGI.
    # Values already present in the adapter's metadata win over these defaults.
    config = dict(weights.metadata) if weights.metadata else {}
    config.setdefault("r", weights.rank)
    config.setdefault("lora_alpha", weights.rank)  # alpha=rank → scaling=1.0
    config.setdefault("peft_type", "LORA")
    config.setdefault("task_type", "CAUSAL_LM")
    config.setdefault("bias", "none")
    config.setdefault("lora_dropout", 0.0)
    # Explicit UTF-8 so the output does not depend on the platform default.
    with open(path / "adapter_config.json", "w", encoding="utf-8") as f:
        json.dump(config, f, indent=2)
|
vlora/merge.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
"""Adapter merging — task arithmetic, TIES, and DARE.
|
|
2
|
+
|
|
3
|
+
These techniques operate on LoRA weight matrices directly, producing
|
|
4
|
+
a single merged adapter from multiple inputs. All three methods work
|
|
5
|
+
on a per-layer basis and return a new LoRAWeights object.
|
|
6
|
+
|
|
7
|
+
References:
|
|
8
|
+
- Task Arithmetic: Ilharco et al., "Editing Models with Task Arithmetic" (2023)
|
|
9
|
+
- TIES: Yadav et al., "TIES-Merging: Resolving Interference When Merging Models" (2023)
|
|
10
|
+
- DARE: Yu et al., "Language Models are Super Mario" (2024)
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from __future__ import annotations
|
|
14
|
+
|
|
15
|
+
import logging
|
|
16
|
+
from typing import Literal
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from torch import Tensor
|
|
20
|
+
|
|
21
|
+
from vlora._validate import check_adapters_compatible
|
|
22
|
+
from vlora.io import LoRAWeights
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger("vlora")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def task_arithmetic(
    adapters: list[LoRAWeights],
    weights: list[float] | None = None,
) -> LoRAWeights:
    """Merge adapters via weighted average of their weight matrices.

    Each adapter's LoRA A and B matrices are combined separately, per layer,
    as ``sum(weights[i] * matrix_i)``. Only layers present in every adapter
    are merged.

    Args:
        adapters: List of adapters to merge (must share layers and rank).
        weights: Per-adapter weights. Defaults to uniform (1/N each).

    Returns:
        Merged LoRAWeights. ``rank`` is taken from the first adapter; metadata
        is not carried over.

    Raises:
        ValueError: If ``adapters`` is empty, ``weights`` has the wrong length,
            or the adapters share no common layers.
    """
    # Reject empty input before anything else consumes the list, so callers
    # always get this error for it.
    n = len(adapters)
    if n == 0:
        raise ValueError("Need at least one adapter to merge.")
    check_adapters_compatible(adapters)

    if weights is None:
        weights = [1.0 / n] * n
    if len(weights) != n:
        raise ValueError(f"weights length ({len(weights)}) must match adapters ({n})")

    # Use intersection of layers
    layer_names = sorted(set.intersection(*(set(a.layer_names) for a in adapters)))
    if not layer_names:
        raise ValueError("Adapters share no common layers.")

    logger.info("Task arithmetic merge: %d adapters, %d layers", n, len(layer_names))

    lora_a: dict[str, Tensor] = {}
    lora_b: dict[str, Tensor] = {}

    for layer in layer_names:
        lora_a[layer] = sum(w * a.lora_a[layer] for w, a in zip(weights, adapters))
        lora_b[layer] = sum(w * a.lora_b[layer] for w, a in zip(weights, adapters))

    return LoRAWeights(
        layer_names=layer_names,
        lora_a=lora_a,
        lora_b=lora_b,
        rank=adapters[0].rank,
    )
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def ties_merge(
    adapters: list[LoRAWeights],
    density: float = 0.5,
    weights: list[float] | None = None,
) -> LoRAWeights:
    """Merge adapters using TIES: Trim, Elect sign, and merge.

    1. Trim: zero out the smallest elements per adapter (keep top `density` fraction)
    2. Elect sign: for each position, choose the majority sign across adapters
    3. Merge: average only the values that agree with the elected sign

    A and B matrices are processed independently, layer by layer, and only
    layers present in every adapter are merged.

    Args:
        adapters: List of adapters to merge.
        density: Fraction of elements to keep per adapter (0, 1]. Default 0.5.
        weights: Per-adapter weights for the final average. Defaults to
            uniform (1/N each).

    Returns:
        Merged LoRAWeights. Each merged tensor keeps the dtype of the stacked
        inputs; ``rank`` is taken from the first adapter.

    Raises:
        ValueError: If ``adapters`` is empty, ``density`` is outside (0, 1],
            ``weights`` has the wrong length, or no layers are shared.
    """
    # Reject empty input before anything else consumes the list.
    n = len(adapters)
    if n == 0:
        raise ValueError("Need at least one adapter to merge.")
    check_adapters_compatible(adapters)
    if not 0 < density <= 1:
        raise ValueError(f"density must be in (0, 1], got {density}")

    if weights is None:
        weights = [1.0 / n] * n
    if len(weights) != n:
        raise ValueError(f"weights length ({len(weights)}) must match adapters ({n})")

    layer_names = sorted(set.intersection(*(set(a.layer_names) for a in adapters)))
    if not layer_names:
        raise ValueError("Adapters share no common layers.")

    logger.info("TIES merge: %d adapters, density=%.2f, %d layers", n, density, len(layer_names))

    lora_a: dict[str, Tensor] = {}
    lora_b: dict[str, Tensor] = {}

    for layer in layer_names:
        for attr, out_dict in (("lora_a", lora_a), ("lora_b", lora_b)):
            # Clone so the in-place trim below never touches the inputs.
            tensors = [getattr(adapter, attr)[layer].clone() for adapter in adapters]

            # Step 1: Trim — zero out smallest elements per adapter.
            # Entries whose magnitude equals the k-th largest are kept.
            for t in tensors:
                flat = t.flatten().abs()
                k = max(1, int(density * flat.numel()))
                threshold = flat.topk(k).values[-1]
                t[t.abs() < threshold] = 0.0

            stacked = torch.stack(tensors)  # (N, *shape)

            # Step 2: Elect sign — majority vote at each position
            sign_votes = (stacked > 0).float().sum(dim=0) - (stacked < 0).float().sum(dim=0)
            elected_sign = sign_votes.sign()
            # Ties go positive (convention)
            elected_sign[elected_sign == 0] = 1.0

            # Step 3: Merge — weighted average of values matching elected sign.
            # Trimmed (zero) entries have sign 0 and therefore never match.
            mask = stacked.sign() == elected_sign.unsqueeze(0)
            # Build the weights on the adapters' device; a CPU tensor here
            # would fail against GPU-resident adapters.
            w = torch.tensor(
                weights, dtype=stacked.dtype, device=stacked.device
            ).view(-1, *([1] * (stacked.dim() - 1)))
            weighted = stacked * w * mask.float()
            # Sum and normalize by number of contributors (avoid division by zero)
            contributor_count = mask.float().sum(dim=0).clamp(min=1)
            merged = weighted.sum(dim=0) * (n / contributor_count)

            # The float32 mask/count promote half/bfloat16 inputs to float32;
            # cast back so the merged adapter keeps the input dtype.
            out_dict[layer] = merged.to(stacked.dtype)

    return LoRAWeights(
        layer_names=layer_names,
        lora_a=lora_a,
        lora_b=lora_b,
        rank=adapters[0].rank,
    )
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def dare_merge(
    adapters: list[LoRAWeights],
    drop_rate: float = 0.5,
    weights: list[float] | None = None,
    seed: int | None = None,
) -> LoRAWeights:
    """Merge adapters using DARE: Drop And REscale.

    For each adapter, randomly drop elements with probability `drop_rate`
    and rescale survivors by 1/(1-drop_rate). Then combine the results as a
    weighted sum. A and B matrices are processed independently, and only
    layers present in every adapter are merged.

    Args:
        adapters: List of adapters to merge.
        drop_rate: Probability of dropping each element. Default 0.5.
        weights: Per-adapter weights for the final average. Defaults to
            uniform (1/N each).
        seed: Random seed for reproducibility. When given, it is applied with
            ``torch.manual_seed``, which reseeds torch's global RNG.

    Returns:
        Merged LoRAWeights. ``rank`` is taken from the first adapter.

    Raises:
        ValueError: If ``adapters`` is empty, ``drop_rate`` is outside [0, 1),
            ``weights`` has the wrong length, or no layers are shared.
    """
    # Reject empty input before anything else consumes the list.
    n = len(adapters)
    if n == 0:
        raise ValueError("Need at least one adapter to merge.")
    check_adapters_compatible(adapters)
    if not 0 <= drop_rate < 1:
        raise ValueError(f"drop_rate must be in [0, 1), got {drop_rate}")

    if weights is None:
        weights = [1.0 / n] * n
    if len(weights) != n:
        raise ValueError(f"weights length ({len(weights)}) must match adapters ({n})")

    layer_names = sorted(set.intersection(*(set(a.layer_names) for a in adapters)))
    if not layer_names:
        raise ValueError("Adapters share no common layers.")

    logger.info("DARE merge: %d adapters, drop_rate=%.2f, %d layers", n, drop_rate, len(layer_names))

    if seed is not None:
        torch.manual_seed(seed)

    rescale = 1.0 / (1.0 - drop_rate) if drop_rate > 0 else 1.0

    lora_a: dict[str, Tensor] = {}
    lora_b: dict[str, Tensor] = {}

    for layer in layer_names:
        for attr, out_dict in (("lora_a", lora_a), ("lora_b", lora_b)):
            merged = torch.zeros_like(getattr(adapters[0], attr)[layer])

            for weight, adapter in zip(weights, adapters):
                t = getattr(adapter, attr)[layer].clone()
                if drop_rate > 0:
                    # Keep each element with probability (1 - drop_rate).
                    mask = torch.bernoulli(torch.full_like(t, 1.0 - drop_rate))
                    t = t * mask * rescale
                merged = merged + weight * t

            out_dict[layer] = merged

    return LoRAWeights(
        layer_names=layer_names,
        lora_a=lora_a,
        lora_b=lora_b,
        rank=adapters[0].rank,
    )
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
# Lookup table from merge-method name to the function implementing it.
MERGE_METHODS = dict(
    average=task_arithmetic,
    ties=ties_merge,
    dare=dare_merge,
)
|
vlora/model.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
1
|
+
"""VLoRAModel — inference wrapper that applies reconstructed LoRA deltas."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn as nn
|
|
9
|
+
from torch import Tensor
|
|
10
|
+
|
|
11
|
+
from vlora._validate import check_task_exists
|
|
12
|
+
from vlora.subspace import SharedSubspace
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class VLoRAModel(nn.Module):
    """Wraps a base model with a shared subspace for multi-task LoRA inference.

    Reconstructs task-specific LoRA deltas on demand and applies them to
    the base model's linear layers during forward pass, using forward hooks
    on every ``nn.Linear`` whose module name matches a reconstructed layer.

    Usage:
        subspace = SharedSubspace.load("shared_subspace/")
        base_model = AutoModelForCausalLM.from_pretrained("model-name")
        model = VLoRAModel(base_model, subspace)

        model.set_task("task_0")
        output = model(input_ids)

        model.set_task("task_1")  # switches adapter, cached if same task
        output = model(input_ids)
    """

    def __init__(
        self,
        base_model: nn.Module,
        subspace: SharedSubspace,
        scaling: float | None = None,
        lora_alpha: float | None = None,
    ):
        """Store the base model and subspace and resolve the LoRA scaling.

        Args:
            base_model: Model whose linear layers receive the LoRA deltas.
            subspace: Shared subspace used to reconstruct per-task weights.
            scaling: Explicit multiplier for the LoRA contribution.
            lora_alpha: Used as ``lora_alpha / subspace.rank`` when
                ``scaling`` is not given.
        """
        super().__init__()
        self.base_model = base_model
        self.subspace = subspace

        # Resolve scaling: explicit scaling > lora_alpha/rank > 1.0
        if scaling is not None:
            self.scaling = scaling
        elif lora_alpha is not None:
            self.scaling = lora_alpha / subspace.rank
        else:
            self.scaling = 1.0
        self._active_task: str | None = None
        self._cached_deltas: dict[str, Tensor] | None = None
        # Handles returned by register_forward_hook, kept for later removal.
        self._hooks: list[torch.utils.hooks.RemovableHandle] = []

    def set_task(self, task_id: str) -> None:
        """Set the active task adapter. Reconstructs and caches if changed."""
        if task_id == self._active_task:
            return

        check_task_exists(self.subspace, task_id)

        # Reconstruct and cache the LoRA deltas (delta_W = B @ A per layer)
        self._cached_deltas = self.reconstruct_state_dict(task_id)

        self._active_task = task_id
        self._apply_hooks()

    def clear_task(self) -> None:
        """Remove the active task adapter."""
        self._remove_hooks()
        self._active_task = None
        self._cached_deltas = None

    def _apply_hooks(self) -> None:
        """Register forward hooks on matching linear layers."""
        self._remove_hooks()

        if self._cached_deltas is None:
            return

        # torch.compile() wraps the model and exposes the original one as
        # `_orig_mod`; walking the wrapper would prefix every module name with
        # "_orig_mod." and no layer would match. Walk the original instead.
        root = getattr(self.base_model, "_orig_mod", self.base_model)

        for name, module in root.named_modules():
            if name in self._cached_deltas and isinstance(module, nn.Linear):
                handle = module.register_forward_hook(
                    self._make_lora_hook(self._cached_deltas[name])
                )
                self._hooks.append(handle)

    def _remove_hooks(self) -> None:
        """Remove all registered forward hooks."""
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()

    def _make_lora_hook(self, delta: Tensor):
        """Create a forward hook that adds LoRA delta to the output."""
        # Captured now: later changes to self.scaling do not affect this hook.
        scaling = self.scaling

        def hook(module: nn.Module, input: Any, output: Tensor) -> Tensor:
            # input[0] is the input tensor to the linear layer
            x = input[0] if isinstance(input, tuple) else input
            # Move/cast the delta to match the activation before the matmul.
            lora_out = x @ delta.T.to(x.device, x.dtype)
            return output + scaling * lora_out

        return hook

    def forward(self, *args, **kwargs):
        """Forward pass through the base model with active LoRA adapter."""
        return self.base_model(*args, **kwargs)

    @property
    def active_task(self) -> str | None:
        """Currently active task ID, or None."""
        return self._active_task

    @property
    def available_tasks(self) -> list[str]:
        """List of available task IDs."""
        return sorted(self.subspace.tasks.keys())

    def reconstruct_state_dict(self, task_id: str) -> dict[str, Tensor]:
        """Get the LoRA delta weight dict for a task without applying hooks.

        Returns dict of {layer_name: delta_W} where delta_W = B @ A.
        Useful for manual integration with custom model architectures.
        """
        weights = self.subspace.reconstruct(task_id)
        return {
            layer_name: weights.lora_b[layer_name] @ weights.lora_a[layer_name]
            for layer_name in weights.layer_names
        }

    def compile(self, **kwargs) -> VLoRAModel:
        """Compile the base model with torch.compile for faster inference.

        Passes all kwargs to torch.compile() and replaces ``base_model`` with
        the compiled wrapper. Hooks registered afterwards are attached to the
        original (unwrapped) modules.

        Returns self for chaining.
        """
        self.base_model = torch.compile(self.base_model, **kwargs)
        return self
|