vlora-dev 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vlora/__init__.py +73 -0
- vlora/_validate.py +82 -0
- vlora/analysis.py +191 -0
- vlora/cli.py +430 -0
- vlora/integrations/__init__.py +1 -0
- vlora/integrations/huggingface.py +163 -0
- vlora/io.py +191 -0
- vlora/merge.py +229 -0
- vlora/model.py +148 -0
- vlora/ops.py +229 -0
- vlora/pipeline.py +70 -0
- vlora/router.py +173 -0
- vlora/subspace.py +651 -0
- vlora/training.py +149 -0
- vlora_dev-0.2.0.dist-info/METADATA +409 -0
- vlora_dev-0.2.0.dist-info/RECORD +19 -0
- vlora_dev-0.2.0.dist-info/WHEEL +4 -0
- vlora_dev-0.2.0.dist-info/entry_points.txt +2 -0
- vlora_dev-0.2.0.dist-info/licenses/LICENSE +190 -0
vlora/cli.py
ADDED
|
@@ -0,0 +1,430 @@
|
|
|
1
|
+
"""vlora CLI — command-line interface for adapter management."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json as json_mod
|
|
6
|
+
import logging
|
|
7
|
+
import time
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
import click
|
|
11
|
+
|
|
12
|
+
from vlora.io import LoRAWeights, load_adapter, save_adapter
|
|
13
|
+
from vlora.ops import explained_variance_ratio
|
|
14
|
+
from vlora.subspace import SharedSubspace
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@click.group()
# The version is looked up from installed distribution metadata. This wheel is
# published as "vlora-dev" (the import package is "vlora"), so the lookup must
# use the distribution name or `--version` cannot find the package.
@click.version_option(package_name="vlora-dev")
@click.option("-v", "--verbose", is_flag=True, help="Enable verbose logging.")
def cli(verbose: bool):
    """vLoRA — Shared low-rank subspaces for LoRA adapter management."""
    # Root command group: configures the logging level once for every
    # subcommand. DEBUG when -v/--verbose is given, WARNING otherwise.
    level = logging.DEBUG if verbose else logging.WARNING
    logging.basicConfig(
        level=level,
        format="%(name)s %(levelname)s: %(message)s",
    )
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
@cli.command()
@click.argument("subspace_path", type=click.Path(exists=True))
@click.option("--json", "as_json", is_flag=True, help="Output as JSON.")
def info(subspace_path: str, as_json: bool):
    """Show subspace stats: tasks, layers, compression ratios."""
    # Loads the subspace stored at `subspace_path` and prints either a JSON
    # document (--json) or a human-readable summary.
    sub = SharedSubspace.load(subspace_path)

    stats = sub.compression_stats()
    k = sub.num_components

    # Variance explained (first layer). The values stay None when they cannot
    # be determined, which both output modes already handle.
    var_a_val = None
    var_b_val = None
    if sub.layer_names:  # a subspace without layers has nothing to report
        first_layer = sub.layer_names[0]
        var_a = explained_variance_ratio(sub.singular_values_a[first_layer])
        var_b = explained_variance_ratio(sub.singular_values_b[first_layer])
        # Require k >= 1: with k == 0 the index k - 1 is -1, which would
        # silently report the *last* entry instead of "not available".
        if 1 <= k <= len(var_a):
            var_a_val = var_a[k - 1].item()
        if 1 <= k <= len(var_b):
            var_b_val = var_b[k - 1].item()

    if as_json:
        output = {
            "path": subspace_path,
            "num_components": sub.num_components,
            "rank": sub.rank,
            "num_layers": len(sub.layer_names),
            "num_tasks": len(sub.tasks),
            "task_ids": sorted(sub.tasks.keys()),
            "variance_explained_a": var_a_val,
            "variance_explained_b": var_b_val,
            **stats,
        }
        # default=str stringifies any value json cannot serialize natively.
        click.echo(json_mod.dumps(output, indent=2, default=str))
        return

    click.echo(f"\n Subspace: {subspace_path}")
    click.echo(f" Components (k): {sub.num_components}")
    click.echo(f" LoRA rank: {sub.rank}")
    click.echo(f" Layers: {len(sub.layer_names)}")
    click.echo(f" Tasks: {len(sub.tasks)}")

    if sub.tasks:
        click.echo("\n Task IDs:")
        for tid in sorted(sub.tasks.keys()):
            click.echo(f" - {tid}")

    click.echo(f"\n Variance explained (first layer, k={k}):")
    if var_a_val is not None:
        click.echo(f" A: {var_a_val:.1%}")
    if var_b_val is not None:
        click.echo(f" B: {var_b_val:.1%}")

    # Compression estimate (only meaningful once at least one task is stored)
    n = len(sub.tasks)
    if n > 0:
        ratio = stats["compression_ratio"]
        click.echo(f"\n Compression ratio: {ratio:.1f}x ({n} adapters)")

    click.echo()
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
@cli.command()
@click.argument("adapter_dirs", nargs=-1, required=True, type=click.Path(exists=True))
@click.option("-o", "--output", required=True, type=click.Path(), help="Output directory for shared subspace.")
@click.option("-k", "--num-components", type=int, default=None, help="Number of basis components (auto if not set).")
@click.option("--variance-threshold", type=float, default=0.6, help="Variance threshold for auto k selection.")
@click.option("--adaptive-k", is_flag=True, help="Use per-layer adaptive k selection.")
def compress(adapter_dirs: tuple[str, ...], output: str, num_components: int | None, variance_threshold: float, adaptive_k: bool):
    """Build shared subspace from adapter directories."""
    # Each adapter directory is loaded in the order given; the directory's
    # base name becomes the task ID under which it is stored in the subspace.
    click.echo(f"\n Loading {len(adapter_dirs)} adapters...")

    loaded = []
    names = []
    for directory in map(Path, adapter_dirs):
        loaded.append(load_adapter(directory))
        names.append(directory.name)
        click.echo(f" Loaded: {directory.name}")

    click.echo(" Building subspace...")
    build_options = {
        "task_ids": names,
        "num_components": num_components,
        "variance_threshold": variance_threshold,
        "adaptive_k": adaptive_k,
    }
    sub = SharedSubspace.from_adapters(loaded, **build_options)

    # Persist the result, then summarize what was written.
    sub.save(output)
    click.echo(f" Saved to: {output}")
    summary = f" Components: {sub.num_components}, Layers: {len(sub.layer_names)}, Tasks: {len(sub.tasks)}"
    click.echo(summary)
    click.echo()
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@cli.command("export")
@click.argument("subspace_path", type=click.Path(exists=True))
@click.argument("task_id")
@click.option("-o", "--output", required=True, type=click.Path(), help="Output directory for PEFT adapter.")
@click.option("--alpha", type=float, default=None, help="LoRA alpha for adapter_config.json (default: same as rank).")
@click.option("--base-model", type=str, default=None, help="Base model name for adapter_config.json.")
@click.option("--target-modules", type=str, default=None, help="Comma-separated target modules for adapter_config.json.")
def export_cmd(subspace_path: str, task_id: str, output: str, alpha: float | None, base_model: str | None, target_modules: str | None):
    """Reconstruct a task adapter to PEFT format."""
    sub = SharedSubspace.load(subspace_path)

    # Reject unknown task IDs up front and list the valid ones.
    if task_id not in sub.tasks:
        available = ", ".join(sorted(sub.tasks.keys()))
        raise click.ClickException(f"Unknown task '{task_id}'. Available: {available}")

    click.echo(f"\n Reconstructing '{task_id}'...")
    weights = sub.reconstruct(task_id)

    # Enrich metadata for serving compatibility. Only options the user
    # actually supplied are written; the rest of the metadata is untouched.
    if target_modules is None:
        module_list = None
    else:
        module_list = [part.strip() for part in target_modules.split(",")]
    overrides = (
        ("lora_alpha", alpha),
        ("base_model_name_or_path", base_model),
        ("target_modules", module_list),
    )
    for key, value in overrides:
        if value is not None:
            weights.metadata[key] = value

    save_adapter(weights, output)
    click.echo(f" Exported to: {output}")
    click.echo()
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
@cli.command()
@click.argument("subspace_path", type=click.Path(exists=True))
@click.argument("adapter_dir", type=click.Path(exists=True))
@click.option("--task-id", required=True, help="ID for the new task.")
@click.option("--incremental", is_flag=True, help="Use fast incremental absorb (approximate).")
def add(subspace_path: str, adapter_dir: str, task_id: str, incremental: bool):
    """Absorb a new adapter into an existing subspace."""
    sub = SharedSubspace.load(subspace_path)

    click.echo(f"\n Loading adapter from {adapter_dir}...")
    new_adapter = load_adapter(adapter_dir)

    # --incremental selects absorb_incremental(); the default is absorb().
    if incremental:
        label = "incremental"
    else:
        label = "full SVD"
    click.echo(f" Absorbing as '{task_id}' ({label})...")
    absorb_fn = sub.absorb_incremental if incremental else sub.absorb
    absorb_fn(new_adapter, task_id)

    # The updated subspace is written back to the path it was loaded from.
    sub.save(subspace_path)
    click.echo(f" Subspace updated. Tasks: {len(sub.tasks)}")
    click.echo()
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
@cli.command()
@click.argument("adapter_dirs", nargs=-1, required=True, type=click.Path(exists=True))
@click.option("--threshold", type=float, default=0.9, help="Similarity threshold for clustering.")
@click.option("--json", "as_json", is_flag=True, help="Output as JSON.")
def analyze(adapter_dirs: tuple[str, ...], threshold: float, as_json: bool):
    """Analyze adapter similarity and find redundant clusters."""
    from vlora.analysis import compute_similarity_matrix, find_clusters

    # Fail fast: one adapter is loaded per directory, so the count is known
    # before any (potentially slow) loading happens.
    if len(adapter_dirs) < 2:
        raise click.ClickException("Need at least 2 adapters for analysis.")

    adapters = []
    names = []
    for d in adapter_dirs:
        path = Path(d)
        adapters.append(load_adapter(path))
        names.append(path.name)

    # Similarity matrix, clusters and redundancy count are computed once and
    # shared by both output modes.
    sim = compute_similarity_matrix(adapters)
    clusters = find_clusters(sim, threshold=threshold)
    # Every member of a multi-member cluster beyond the first counts as redundant.
    redundant = sum(len(c) - 1 for c in clusters if len(c) > 1)

    if as_json:
        sim_dict = {
            name_i: {name_j: sim[i, j].item() for j, name_j in enumerate(names)}
            for i, name_i in enumerate(names)
        }
        output = {
            "similarity_matrix": sim_dict,
            "clusters": [[names[i] for i in c] for c in clusters],
            "threshold": threshold,
            "redundant_count": redundant,
        }
        click.echo(json_mod.dumps(output, indent=2))
        return

    # Progress lines are printed only in text mode (after loading) so that
    # --json output stays a single parseable document.
    click.echo(f"\n Loading {len(adapter_dirs)} adapters...")
    for n in names:
        click.echo(f" Loaded: {n}")

    click.echo("\n Pairwise Cosine Similarity:")
    # Column headers are truncated to 8 characters, row labels to 20.
    header = " " + " " * 20 + " ".join(f"{n[:8]:>8}" for n in names)
    click.echo(header)
    for i, name in enumerate(names):
        row = f" {name[:20]:<20}"
        for j, _ in enumerate(names):
            val = sim[i, j].item()
            row += f" {val:8.3f}"
        click.echo(row)

    click.echo(f"\n Clusters (threshold={threshold}):")
    for ci, cluster in enumerate(clusters):
        members = ", ".join(names[i] for i in cluster)
        click.echo(f" Cluster {ci + 1}: {members}")

    if redundant > 0:
        click.echo(f"\n Potentially redundant adapters: {redundant}")
    else:
        click.echo(f"\n No redundant adapters found at threshold={threshold}")

    click.echo()
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
@cli.command()
@click.argument("subspace_path", type=click.Path(exists=True))
def validate(subspace_path: str):
    """Run health checks on a subspace."""
    import torch

    sub = SharedSubspace.load(subspace_path)
    # Errors are hard failures (non-finite values, shape mismatches);
    # warnings are numerical-quality findings (loss of orthonormality).
    issues = {"errors": [], "warnings": []}

    click.echo(f"\n Validating: {subspace_path}")
    click.echo(f" Tasks: {len(sub.tasks)}, Layers: {len(sub.layer_names)}, k={sub.num_components}")

    # Check for NaN/Inf in components and means
    for layer in sub.layer_names:
        for name, tensor in [
            (f"{layer}.components_a", sub.components_a[layer]),
            (f"{layer}.components_b", sub.components_b[layer]),
            (f"{layer}.means_a", sub.means_a[layer]),
            (f"{layer}.means_b", sub.means_b[layer]),
        ]:
            if torch.isnan(tensor).any():
                issues["errors"].append(f"NaN in {name}")
            if torch.isinf(tensor).any():
                issues["errors"].append(f"Inf in {name}")

    # Check component orthonormality: the Gram matrix of the component rows
    # is compared against the identity.
    for layer in sub.layer_names:
        for side, comps in [("A", sub.components_a[layer]), ("B", sub.components_b[layer])]:
            if comps.shape[0] > 0:
                gram = comps @ comps.T
                # Create the identity on the same device as the components;
                # a default-device identity cannot be subtracted from a
                # tensor that lives on another device.
                eye = torch.eye(comps.shape[0], device=comps.device)
                err = (gram - eye).abs().max().item()
                if err > 0.01:
                    issues["warnings"].append(
                        f"{layer}.{side} components not orthonormal (max error: {err:.4f})"
                    )

    # Check task loadings consistency: each task must carry one loading per
    # component, for every layer and for both sides.
    for tid, proj in sub.tasks.items():
        for layer in sub.layer_names:
            k_a = sub.components_a[layer].shape[0]
            k_b = sub.components_b[layer].shape[0]
            if proj.loadings_a[layer].shape[0] != k_a:
                issues["errors"].append(
                    f"Task '{tid}' loadings_a mismatch at {layer}: "
                    f"expected {k_a}, got {proj.loadings_a[layer].shape[0]}"
                )
            if proj.loadings_b[layer].shape[0] != k_b:
                issues["errors"].append(
                    f"Task '{tid}' loadings_b mismatch at {layer}: "
                    f"expected {k_b}, got {proj.loadings_b[layer].shape[0]}"
                )

    # Report
    if issues["errors"]:
        click.echo(f"\n ERRORS ({len(issues['errors'])}):")
        for message in issues["errors"]:
            click.echo(f" [ERROR] {message}")
    if issues["warnings"]:
        click.echo(f"\n WARNINGS ({len(issues['warnings'])}):")
        for message in issues["warnings"]:
            click.echo(f" [WARN] {message}")
    if not issues["errors"] and not issues["warnings"]:
        click.echo("\n All checks passed.")

    click.echo()
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
@cli.command()
@click.argument("subspace_path", type=click.Path(exists=True))
@click.argument("task_a")
@click.argument("task_b")
def diff(subspace_path: str, task_a: str, task_b: str):
    """Compare two tasks within a subspace."""
    import torch

    sub = SharedSubspace.load(subspace_path)

    # Both task IDs must exist; the first unknown one aborts the command.
    for requested in (task_a, task_b):
        if requested in sub.tasks:
            continue
        available = ", ".join(sorted(sub.tasks.keys()))
        raise click.ClickException(f"Unknown task '{requested}'. Available: {available}")

    click.echo(f"\n Comparing '{task_a}' vs '{task_b}'")

    first = sub.reconstruct(task_a)
    second = sub.reconstruct(task_b)

    click.echo(f"\n {'Layer':<30} {'L2 Dist (A)':>12} {'L2 Dist (B)':>12} {'Cosine (A)':>12} {'Cosine (B)':>12}")
    click.echo(f" {'─' * 78}")

    cosine = torch.nn.functional.cosine_similarity
    for layer in sub.layer_names:
        # Flatten each reconstructed matrix so the two tasks can be compared
        # as plain vectors, separately for the A and the B matrices.
        pairs = (
            (first.lora_a[layer].flatten(), second.lora_a[layer].flatten()),
            (first.lora_b[layer].flatten(), second.lora_b[layer].flatten()),
        )
        distances = [(x - y).norm().item() for x, y in pairs]
        cosines = [cosine(x.unsqueeze(0), y.unsqueeze(0)).item() for x, y in pairs]

        l2_a, l2_b = distances
        cos_a, cos_b = cosines
        name = layer[:30]  # keep the table aligned for long layer names
        click.echo(f" {name:<30} {l2_a:>12.4f} {l2_b:>12.4f} {cos_a:>12.4f} {cos_b:>12.4f}")

    click.echo()
|
|
344
|
+
|
|
345
|
+
|
|
346
|
+
@cli.command()
@click.argument("subspace_path", type=click.Path(exists=True))
def benchmark(subspace_path: str):
    """Benchmark subspace operations: reconstruct, project, compression stats."""
    sub = SharedSubspace.load(subspace_path)
    task_ids = sorted(sub.tasks.keys())

    if not task_ids:
        raise click.ClickException("Subspace has no tasks to benchmark.")

    click.echo(f"\n Benchmarking: {subspace_path}")
    click.echo(f" Tasks: {len(task_ids)}, Layers: {len(sub.layer_names)}, k={sub.num_components}")

    # Number of timed repetitions per operation; the reported figure is the
    # arithmetic mean of the wall-clock samples.
    repeats = 10

    # Benchmark reconstruct (on the alphabetically first task)
    tid = task_ids[0]
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        sub.reconstruct(tid)
        times.append(time.perf_counter() - start)
    avg_recon = sum(times) / len(times)
    click.echo(f"\n reconstruct('{tid}'): {avg_recon * 1000:.2f} ms (avg of {repeats})")

    # Benchmark project, feeding back the reconstruction of the same task
    # NOTE(review): confirm project() does not register "bench_proj" in
    # sub.tasks; the subspace is never saved here, so nothing is persisted.
    recon = sub.reconstruct(tid)
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        sub.project(recon, "bench_proj")
        times.append(time.perf_counter() - start)
    avg_proj = sum(times) / len(times)
    click.echo(f" project(): {avg_proj * 1000:.2f} ms (avg of {repeats})")

    # Benchmark compression stats (single run)
    start = time.perf_counter()
    stats = sub.compression_stats()
    stats_time = time.perf_counter() - start
    click.echo(f" compression_stats(): {stats_time * 1000:.2f} ms")
    click.echo(f"\n Compression ratio: {stats['compression_ratio']:.1f}x")

    click.echo()
|
|
389
|
+
|
|
390
|
+
|
|
391
|
+
@cli.command()
@click.argument("adapter_dirs", nargs=-1, required=True, type=click.Path(exists=True))
@click.option("-o", "--output", required=True, type=click.Path(), help="Output directory for merged adapter.")
@click.option("--method", type=click.Choice(["average", "ties", "dare"]), default="average", help="Merge method.")
@click.option("--weights", type=str, default=None, help="Comma-separated per-adapter weights (e.g. '0.7,0.3').")
@click.option("--density", type=float, default=0.5, help="TIES density: fraction of values to keep.")
@click.option("--drop-rate", type=float, default=0.5, help="DARE drop rate: probability of dropping each element.")
@click.option("--seed", type=int, default=None, help="Random seed for DARE reproducibility.")
def merge(adapter_dirs: tuple[str, ...], output: str, method: str, weights: str | None, density: float, drop_rate: float, seed: int | None):
    """Merge multiple adapters into one using task arithmetic, TIES, or DARE."""
    from vlora.merge import MERGE_METHODS

    # Validate the cheap inputs before loading anything: one adapter is
    # loaded per directory, so the count is already known here.
    if len(adapter_dirs) < 2:
        raise click.ClickException("Need at least 2 adapters to merge.")

    parsed_weights = None
    if weights is not None:
        try:
            parsed_weights = [float(w.strip()) for w in weights.split(",")]
        except ValueError as exc:
            # Report a usage error instead of a raw ValueError traceback.
            raise click.BadParameter(
                f"expected comma-separated numbers, got {weights!r}",
                param_hint="--weights",
            ) from exc
        # The option is documented as per-adapter, so the counts must match.
        if len(parsed_weights) != len(adapter_dirs):
            raise click.BadParameter(
                f"got {len(parsed_weights)} weights for {len(adapter_dirs)} adapters",
                param_hint="--weights",
            )

    click.echo(f"\n Loading {len(adapter_dirs)} adapters...")
    adapters = []
    for d in adapter_dirs:
        path = Path(d)
        adapters.append(load_adapter(path))
        click.echo(f" Loaded: {path.name}")

    click.echo(f" Merging with method={method}...")

    # Only forward the tuning knobs that belong to the selected method.
    fn = MERGE_METHODS[method]
    method_options = {}
    if method == "ties":
        method_options["density"] = density
    elif method == "dare":
        method_options["drop_rate"] = drop_rate
        method_options["seed"] = seed
    merged = fn(adapters, weights=parsed_weights, **method_options)

    save_adapter(merged, output)
    click.echo(f" Merged adapter saved to: {output}")
    click.echo(f" Layers: {len(merged.layer_names)}, Rank: {merged.rank}")
    click.echo()
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""vlora integrations with external training and serving frameworks."""
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
"""HuggingFace Trainer integration — VLoRACallback for training-in-subspace.
|
|
2
|
+
|
|
3
|
+
Usage with HuggingFace Trainer:
|
|
4
|
+
from vlora import SharedSubspace, orthogonal_init
|
|
5
|
+
from vlora.integrations.huggingface import VLoRACallback
|
|
6
|
+
|
|
7
|
+
subspace = SharedSubspace.load("shared_subspace/")
|
|
8
|
+
orthogonal_init(subspace, "new_task")
|
|
9
|
+
|
|
10
|
+
callback = VLoRACallback(subspace, "new_task", lr=1e-3)
|
|
11
|
+
trainer = Trainer(
|
|
12
|
+
model=base_model,
|
|
13
|
+
args=training_args,
|
|
14
|
+
train_dataset=dataset,
|
|
15
|
+
callbacks=[callback],
|
|
16
|
+
)
|
|
17
|
+
trainer.train()
|
|
18
|
+
subspace.save("updated_subspace/")
|
|
19
|
+
"""
|
|
20
|
+
|
|
21
|
+
from __future__ import annotations
|
|
22
|
+
|
|
23
|
+
import logging
|
|
24
|
+
from typing import Any
|
|
25
|
+
|
|
26
|
+
import torch
|
|
27
|
+
from torch import Tensor
|
|
28
|
+
|
|
29
|
+
from vlora.subspace import SharedSubspace
|
|
30
|
+
from vlora.training import SubspaceTrainer
|
|
31
|
+
|
|
32
|
+
logger = logging.getLogger("vlora")
|
|
33
|
+
|
|
34
|
+
try:
|
|
35
|
+
from transformers import TrainerCallback, TrainerControl, TrainerState
|
|
36
|
+
from transformers.training_args import TrainingArguments
|
|
37
|
+
|
|
38
|
+
HAS_TRANSFORMERS = True
|
|
39
|
+
except ImportError:
|
|
40
|
+
HAS_TRANSFORMERS = False
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _require_transformers():
    """Raise ImportError with install instructions if transformers is unavailable."""
    if HAS_TRANSFORMERS:
        return
    message = (
        "transformers is required for HuggingFace integration. "
        "Install with: pip install vlora-dev[hf]"
    )
    raise ImportError(message)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
if HAS_TRANSFORMERS:

    class VLoRACallback(TrainerCallback):
        """HuggingFace Trainer callback for training-in-subspace.

        Creates a SubspaceTrainer for the task when training begins,
        periodically appends adapter metrics (loadings norm, trainable
        parameter count, trainer step count) to the Trainer's log history,
        and optionally writes the loadings back when training ends.

        NOTE(review): this callback never steps the SubspaceTrainer itself;
        presumably callers drive it through the ``trainer`` property —
        confirm against the intended training flow.

        Args:
            subspace: Shared subspace (task must already exist).
            task_id: Task whose loadings to train.
            lr: Learning rate for loadings optimizer.
            num_expand: Extra orthogonal directions for the optimizer.
            log_every: Log adapter metrics every N steps; 0 disables logging.
            save_on_end: Whether to call write_back() when training ends.
        """

        def __init__(
            self,
            subspace: SharedSubspace,
            task_id: str,
            lr: float = 1e-3,
            num_expand: int = 0,
            log_every: int = 50,
            save_on_end: bool = True,
        ):
            _require_transformers()
            self.subspace = subspace
            self.task_id = task_id
            self.lr = lr
            self.num_expand = num_expand
            self.log_every = log_every
            self.save_on_end = save_on_end
            # Created in on_train_begin; None until training starts.
            self._trainer: SubspaceTrainer | None = None

        def on_train_begin(
            self,
            args: TrainingArguments,
            state: TrainerState,
            control: TrainerControl,
            **kwargs: Any,
        ):
            """Build the SubspaceTrainer for the configured task."""
            self._trainer = SubspaceTrainer(
                self.subspace,
                self.task_id,
                lr=self.lr,
                num_expand=self.num_expand,
            )
            logger.info(
                "VLoRACallback: training '%s' with %d params (lr=%.1e)",
                self.task_id,
                self._trainer.num_trainable_params,
                self.lr,
            )

        def on_step_end(
            self,
            args: TrainingArguments,
            state: TrainerState,
            control: TrainerControl,
            **kwargs: Any,
        ):
            """Append adapter metrics to the log history every ``log_every`` steps."""
            if self._trainer is None:
                return

            step = state.global_step
            # A zero interval disables logging; without this guard the modulo
            # below would raise ZeroDivisionError on the first step.
            if not self.log_every or step <= 0 or step % self.log_every != 0:
                return

            # Log loadings norm as a proxy for adapter magnitude: the L2 norm
            # over all trainable parameters combined.
            squared = sum(
                (p.data.norm().item() ** 2 for p in self._trainer.params.values()),
                0.0,
            )
            total_norm = squared ** 0.5

            metrics = {
                "vlora/loadings_norm": total_norm,
                "vlora/trainable_params": self._trainer.num_trainable_params,
                "vlora/step": self._trainer.step_count,
            }
            state.log_history.append(
                {"step": step, **metrics}
            )
            logger.debug(
                "VLoRACallback step %d: loadings_norm=%.4f",
                step,
                total_norm,
            )

        def on_train_end(
            self,
            args: TrainingArguments,
            state: TrainerState,
            control: TrainerControl,
            **kwargs: Any,
        ):
            """Write the loadings back via the trainer when ``save_on_end`` is set."""
            if self._trainer is not None and self.save_on_end:
                self._trainer.write_back()
                logger.info(
                    "VLoRACallback: wrote back loadings for '%s' after %d steps",
                    self.task_id,
                    self._trainer.step_count,
                )

        @property
        def trainer(self) -> SubspaceTrainer | None:
            """Access the underlying SubspaceTrainer (available after on_train_begin)."""
            return self._trainer

else:
    # Stub class when transformers is not installed
    class VLoRACallback:  # type: ignore[no-redef]
        """Placeholder whose constructor raises ImportError with install instructions."""

        def __init__(self, *args, **kwargs):
            _require_transformers()
|