vespaembed 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vespaembed/__init__.py +1 -1
- vespaembed/cli/__init__.py +17 -0
- vespaembed/cli/commands/__init__.py +7 -0
- vespaembed/cli/commands/evaluate.py +85 -0
- vespaembed/cli/commands/export.py +86 -0
- vespaembed/cli/commands/info.py +52 -0
- vespaembed/cli/commands/serve.py +49 -0
- vespaembed/cli/commands/train.py +267 -0
- vespaembed/cli/vespaembed.py +55 -0
- vespaembed/core/__init__.py +2 -0
- vespaembed/core/config.py +164 -0
- vespaembed/core/registry.py +158 -0
- vespaembed/core/trainer.py +573 -0
- vespaembed/datasets/__init__.py +3 -0
- vespaembed/datasets/formats/__init__.py +5 -0
- vespaembed/datasets/formats/csv.py +15 -0
- vespaembed/datasets/formats/huggingface.py +34 -0
- vespaembed/datasets/formats/jsonl.py +26 -0
- vespaembed/datasets/loader.py +80 -0
- vespaembed/db.py +176 -0
- vespaembed/enums.py +58 -0
- vespaembed/evaluation/__init__.py +3 -0
- vespaembed/evaluation/factory.py +86 -0
- vespaembed/models/__init__.py +4 -0
- vespaembed/models/export.py +89 -0
- vespaembed/models/loader.py +25 -0
- vespaembed/static/css/styles.css +1800 -0
- vespaembed/static/js/app.js +1485 -0
- vespaembed/tasks/__init__.py +23 -0
- vespaembed/tasks/base.py +144 -0
- vespaembed/tasks/pairs.py +91 -0
- vespaembed/tasks/similarity.py +84 -0
- vespaembed/tasks/triplets.py +90 -0
- vespaembed/tasks/tsdae.py +102 -0
- vespaembed/templates/index.html +544 -0
- vespaembed/utils/__init__.py +3 -0
- vespaembed/utils/logging.py +69 -0
- vespaembed/web/__init__.py +1 -0
- vespaembed/web/api/__init__.py +1 -0
- vespaembed/web/app.py +605 -0
- vespaembed/worker.py +313 -0
- vespaembed-0.0.2.dist-info/METADATA +325 -0
- vespaembed-0.0.2.dist-info/RECORD +47 -0
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.2.dist-info}/WHEEL +1 -1
- vespaembed-0.0.1.dist-info/METADATA +0 -20
- vespaembed-0.0.1.dist-info/RECORD +0 -7
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.2.dist-info}/entry_points.txt +0 -0
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.2.dist-info}/licenses/LICENSE +0 -0
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.2.dist-info}/top_level.txt +0 -0
vespaembed/worker.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
1
|
+
"""Background worker for training runs."""
|
|
2
|
+
|
|
3
|
+
import argparse
|
|
4
|
+
import json
|
|
5
|
+
import signal
|
|
6
|
+
import sys
|
|
7
|
+
import traceback
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
|
|
10
|
+
# Import tasks to register them
|
|
11
|
+
import vespaembed.tasks # noqa: F401
|
|
12
|
+
from vespaembed.core.config import (
|
|
13
|
+
DataConfig,
|
|
14
|
+
LoraConfig,
|
|
15
|
+
OutputConfig,
|
|
16
|
+
TrainingConfig,
|
|
17
|
+
TrainingHyperparameters,
|
|
18
|
+
UnslothConfig,
|
|
19
|
+
)
|
|
20
|
+
from vespaembed.core.trainer import VespaEmbedTrainer
|
|
21
|
+
from vespaembed.db import update_run_status
|
|
22
|
+
from vespaembed.enums import RunStatus
|
|
23
|
+
from vespaembed.utils.logging import logger
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class TrainingWorker:
    """Worker that executes a single training run in a subprocess.

    Launched by the web server with a run id and a flat config dict (as
    produced by the UI form). Responsibilities:

    * translate the flat dict into the nested ``TrainingConfig`` structure,
    * run ``VespaEmbedTrainer`` with a progress callback,
    * stream status/log/progress updates back to the server by appending
      JSONL records to ``~/.vespaembed/updates/run_<run_id>.jsonl``,
    * record the terminal run status in the database.
    """

    def __init__(self, run_id: int, config: dict):
        self.run_id = run_id
        self.config = config
        self.trainer = None
        self.stopped = False

        # Exit cleanly when the web server terminates us, or on Ctrl-C.
        signal.signal(signal.SIGTERM, self._handle_signal)
        signal.signal(signal.SIGINT, self._handle_signal)

    def _handle_signal(self, signum, frame):
        """Handle termination signals by flagging the stop and exiting."""
        logger.info(f"Received signal {signum}, stopping training...")
        self.stopped = True
        # NOTE(review): we exit immediately, so the trainer never observes
        # `self.stopped`; a truly graceful stop would require the trainer to
        # poll this flag between steps.
        sys.exit(0)

    def _send_update(self, update_type: str, data: dict) -> None:
        """Append one JSONL update record for the web server to pick up."""
        update_dir = Path.home() / ".vespaembed" / "updates"
        update_dir.mkdir(parents=True, exist_ok=True)

        update_file = update_dir / f"run_{self.run_id}.jsonl"
        update = {"type": update_type, "run_id": self.run_id, **data}

        # Append-only JSONL; explicit encoding so output is platform-independent.
        with open(update_file, "a", encoding="utf-8") as f:
            f.write(json.dumps(update) + "\n")

    def _log_config(self) -> None:
        """Log a human-readable summary of the training configuration to the UI."""
        c = self.config
        lines = [
            "=" * 50,
            "Training Configuration",
            "=" * 50,
            f" Base Model: {c.get('base_model', '--')}",
            f" Task: {c.get('task', '--')}",
        ]

        if c.get("loss_variant"):
            lines.append(f" Loss Variant: {c['loss_variant']}")

        # Data source: either an uploaded file or a HuggingFace dataset.
        if c.get("train_filename"):
            lines.append(f" Training Data: {c['train_filename'].split('/')[-1]}")
        elif c.get("hf_dataset"):
            lines.append(f" Training Data: {c['hf_dataset']}")

        # Core hyperparameters (defaults mirror the trainer's defaults).
        lines.extend(
            [
                f" Epochs: {c.get('epochs', 3)}",
                f" Batch Size: {c.get('batch_size', 32)}",
                f" Learning Rate: {c.get('learning_rate', 2e-5)}",
                f" Optimizer: {c.get('optimizer', 'adamw_torch')}",
                f" Scheduler: {c.get('scheduler', 'linear')}",
                f" Warmup Ratio: {c.get('warmup_ratio', 0.1)}",
                f" Weight Decay: {c.get('weight_decay', 0.01)}",
            ]
        )

        # Precision: BF16 takes precedence over FP16 if both flags are set.
        if c.get("bf16"):
            lines.append(" Precision: BF16")
        elif c.get("fp16"):
            lines.append(" Precision: FP16")
        else:
            lines.append(" Precision: FP32")

        # Optional features.
        if c.get("gradient_checkpointing"):
            lines.append(" Grad Checkpoint: Enabled")
        if c.get("gradient_accumulation_steps", 1) > 1:
            lines.append(f" Grad Accum: {c['gradient_accumulation_steps']}")

        # LoRA
        if c.get("lora_enabled"):
            lines.append(f" LoRA: r={c.get('lora_r', 64)}, alpha={c.get('lora_alpha', 128)}")

        # Unsloth
        if c.get("unsloth_enabled"):
            lines.append(f" Unsloth: Enabled ({c.get('unsloth_save_method', 'merged_16bit')})")

        # Matryoshka
        if c.get("matryoshka_dims"):
            lines.append(f" Matryoshka: {c['matryoshka_dims']}")

        lines.append("=" * 50)

        # Emit each line as a separate log update.
        for line in lines:
            self._send_update("log", {"message": line})

    def _build_data_config(self) -> "DataConfig":
        """Build the DataConfig from either a file upload or a HF dataset."""
        train_source = self.config.get("train_filename") or self.config.get("hf_dataset")
        eval_source = self.config.get("eval_filename")
        eval_split = None

        # For HuggingFace datasets, eval can come from a different split of
        # the same dataset instead of a separate uploaded file.
        hf_eval_split = self.config.get("hf_eval_split")
        if self.config.get("hf_dataset") and hf_eval_split:
            eval_source = self.config.get("hf_dataset")
            eval_split = hf_eval_split

        return DataConfig(
            train=train_source,
            eval=eval_source,
            subset=self.config.get("hf_subset"),
            split=self.config.get("hf_train_split", "train"),
            eval_split=eval_split,
        )

    def _build_training_config(self) -> "TrainingConfig":
        """Translate the flat UI config dict into a nested TrainingConfig."""
        # matryoshka_dims arrives as a comma-separated string from the UI,
        # but may already be a list when supplied programmatically.
        matryoshka_dims = None
        if self.config.get("matryoshka_dims"):
            dims_str = self.config["matryoshka_dims"]
            if isinstance(dims_str, str):
                matryoshka_dims = [int(d.strip()) for d in dims_str.split(",") if d.strip()]
            elif isinstance(dims_str, list):
                matryoshka_dims = dims_str

        # lora_target_modules also arrives as a comma-separated string.
        modules_str = self.config.get("lora_target_modules", "query, key, value, dense")
        if isinstance(modules_str, str):
            lora_target_modules = [m.strip() for m in modules_str.split(",") if m.strip()]
        else:
            lora_target_modules = modules_str  # Already a list

        lora_config = LoraConfig(
            enabled=self.config.get("lora_enabled", False),
            r=self.config.get("lora_r", 64),
            alpha=self.config.get("lora_alpha", 128),
            dropout=self.config.get("lora_dropout", 0.1),
            target_modules=lora_target_modules,
        )

        unsloth_config = UnslothConfig(
            enabled=self.config.get("unsloth_enabled", False),
            save_method=self.config.get("unsloth_save_method", "merged_16bit"),
        )

        return TrainingConfig(
            task=self.config["task"],
            base_model=self.config["base_model"],
            data=self._build_data_config(),
            training=TrainingHyperparameters(
                epochs=self.config.get("epochs", 3),
                batch_size=self.config.get("batch_size", 32),
                learning_rate=self.config.get("learning_rate", 2e-5),
                warmup_ratio=self.config.get("warmup_ratio", 0.1),
                weight_decay=self.config.get("weight_decay", 0.01),
                fp16=self.config.get("fp16", False),
                bf16=self.config.get("bf16", False),
                eval_steps=self.config.get("eval_steps", 500),
                save_steps=self.config.get("save_steps", 500),
                logging_steps=self.config.get("logging_steps", 100),
                gradient_accumulation_steps=self.config.get("gradient_accumulation_steps", 1),
                optimizer=self.config.get("optimizer", "adamw_torch"),
                scheduler=self.config.get("scheduler", "linear"),
            ),
            output=OutputConfig(
                dir=self.config["output_dir"],  # Required - set by web app
                push_to_hub=self.config.get("push_to_hub", False),
                hf_username=self.config.get("hf_username"),
            ),
            lora=lora_config,
            unsloth=unsloth_config,
            max_seq_length=self.config.get("max_seq_length"),  # None = auto-detect
            gradient_checkpointing=self.config.get("gradient_checkpointing", False),
            matryoshka_dims=matryoshka_dims,
            loss_variant=self.config.get("loss_variant"),
        )

    def run(self) -> None:
        """Execute the training run end to end.

        Streams progress to the web server, records the terminal run status
        in the database, and exits the process with status 1 on failure.
        """
        try:
            logger.info(f"Starting training run {self.run_id}")
            self._send_update("status", {"status": "running"})

            # Log configuration to UI
            self._log_config()

            training_config = self._build_training_config()

            # Create trainer with progress callback
            self.trainer = VespaEmbedTrainer(
                config=training_config,
                progress_callback=self._progress_callback,
            )

            # Run training
            self.trainer.train()

            # Update status on completion
            update_run_status(self.run_id, RunStatus.COMPLETED)
            self._send_update("complete", {"output_dir": training_config.output.dir})
            logger.info(f"Training run {self.run_id} completed successfully")

        except Exception as e:
            logger.error(f"Training run {self.run_id} failed: {e}")
            logger.error(traceback.format_exc())
            update_run_status(self.run_id, RunStatus.ERROR, error_message=str(e))
            self._send_update("error", {"message": str(e)})
            sys.exit(1)

    def _progress_callback(self, progress: dict) -> None:
        """Forward trainer progress events to the web server as updates."""
        progress_type = progress.get("type", "progress")

        if progress_type == "train_start":
            self._send_update("progress", progress)
            self._send_update(
                "log",
                {"message": f"Training started: {progress['total_steps']} steps, {progress['total_epochs']} epochs"},
            )
        elif progress_type == "train_end":
            self._send_update("progress", progress)
            elapsed = self._format_time(progress.get("elapsed_seconds", 0))
            self._send_update("log", {"message": f"Training completed in {elapsed}"})
        else:
            # Regular step update: forward raw progress plus a formatted log line.
            self._send_update("progress", progress)
            self._send_update("log", {"message": self._format_progress(progress)})

    def _format_time(self, seconds: float) -> str:
        """Format a duration in seconds into a short human-readable string."""
        if seconds < 60:
            return f"{seconds:.0f}s"
        elif seconds < 3600:
            mins = int(seconds // 60)
            secs = int(seconds % 60)
            return f"{mins}m {secs}s"
        else:
            hours = int(seconds // 3600)
            mins = int((seconds % 3600) // 60)
            return f"{hours}h {mins}m"

    def _format_progress(self, progress: dict) -> str:
        """Format a progress dict as a tqdm-like single-line log message.

        Example output:
        [Step 100/1000 (10%)] Epoch 1/3 | Loss: 0.5234 | LR: 2.00e-05 | ETA: 5m 30s
        """
        step = progress.get("step", 0)
        total = progress.get("total_steps", 0)
        pct = progress.get("progress_pct", 0)
        epoch = progress.get("epoch", 0)
        total_epochs = progress.get("total_epochs", 0)
        loss = progress.get("loss")
        lr = progress.get("learning_rate")
        eta = progress.get("eta_seconds", 0)

        parts = [f"[Step {step}/{total} ({pct:.0f}%)]"]

        if total_epochs:
            parts.append(f"Epoch {epoch:.1f}/{total_epochs}")

        if loss is not None:
            parts.append(f"Loss: {loss:.4f}")

        if lr is not None:
            parts.append(f"LR: {lr:.2e}")

        if eta > 0:
            parts.append(f"ETA: {self._format_time(eta)}")

        return " | ".join(parts)
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def main():
    """Parse CLI arguments, build a worker, and run one training job."""
    arg_parser = argparse.ArgumentParser(description="VespaEmbed Training Worker")
    arg_parser.add_argument("--run-id", type=int, required=True, help="Run ID")
    arg_parser.add_argument("--config", type=str, required=True, help="Training config JSON")
    parsed = arg_parser.parse_args()

    # The config arrives serialized as a JSON string on the command line.
    TrainingWorker(run_id=parsed.run_id, config=json.loads(parsed.config)).run()


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,325 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: vespaembed
|
|
3
|
+
Version: 0.0.2
|
|
4
|
+
Summary: vespaembed: no-code training for embedding models
|
|
5
|
+
Author: Abhishek Thakur
|
|
6
|
+
License: Apache 2.0
|
|
7
|
+
Project-URL: Homepage, https://github.com/vespaai-playground/vespaembed
|
|
8
|
+
Requires-Python: >=3.11
|
|
9
|
+
Description-Content-Type: text/markdown
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
Requires-Dist: sentence-transformers>=3.0.0
|
|
12
|
+
Requires-Dist: transformers>=4.40.0
|
|
13
|
+
Requires-Dist: accelerate>=0.26.0
|
|
14
|
+
Requires-Dist: torch>=2.0.0
|
|
15
|
+
Requires-Dist: datasets>=2.18.0
|
|
16
|
+
Requires-Dist: pandas>=2.0.0
|
|
17
|
+
Requires-Dist: pydantic>=2.0.0
|
|
18
|
+
Requires-Dist: rich>=13.0.0
|
|
19
|
+
Requires-Dist: fastapi>=0.111.0
|
|
20
|
+
Requires-Dist: uvicorn>=0.30.0
|
|
21
|
+
Requires-Dist: python-multipart>=0.0.9
|
|
22
|
+
Requires-Dist: websockets>=12.0
|
|
23
|
+
Requires-Dist: jinja2>=3.1.0
|
|
24
|
+
Requires-Dist: pyyaml>=6.0.0
|
|
25
|
+
Requires-Dist: tensorboard>=2.15.0
|
|
26
|
+
Requires-Dist: peft>=0.18.1
|
|
27
|
+
Requires-Dist: unsloth>=2026.1.4
|
|
28
|
+
Provides-Extra: unsloth
|
|
29
|
+
Requires-Dist: unsloth; extra == "unsloth"
|
|
30
|
+
Provides-Extra: onnx
|
|
31
|
+
Requires-Dist: onnx>=1.14.0; extra == "onnx"
|
|
32
|
+
Requires-Dist: onnxruntime>=1.23.2; extra == "onnx"
|
|
33
|
+
Provides-Extra: dev
|
|
34
|
+
Requires-Dist: black==26.1.0; extra == "dev"
|
|
35
|
+
Requires-Dist: isort==7.0.0; extra == "dev"
|
|
36
|
+
Requires-Dist: flake8==7.3.0; extra == "dev"
|
|
37
|
+
Requires-Dist: pytest>=9.0.2; extra == "dev"
|
|
38
|
+
Requires-Dist: pytest-cov>=7.0.0; extra == "dev"
|
|
39
|
+
Dynamic: license-file
|
|
40
|
+
|
|
41
|
+
# VespaEmbed
|
|
42
|
+
|
|
43
|
+
No-code training for embedding models. Train custom embedding models with a web UI or CLI.
|
|
44
|
+
|
|
45
|
+
## Features
|
|
46
|
+
|
|
47
|
+
- **Web UI** - Visual interface for configuring and monitoring training
|
|
48
|
+
- **CLI** - Command-line interface for scripting and automation
|
|
49
|
+
- **Multiple Tasks** - Support for pairs, triplets, similarity scoring, and unsupervised learning
|
|
50
|
+
- **Loss Variants** - Choose from multiple loss functions per task
|
|
51
|
+
- **Matryoshka Embeddings** - Train multi-dimensional embeddings for flexible retrieval
|
|
52
|
+
- **LoRA Support** - Parameter-efficient fine-tuning with LoRA adapters
|
|
53
|
+
- **Unsloth Integration** - Faster training with Unsloth optimizations
|
|
54
|
+
- **HuggingFace Integration** - Load datasets and models from the HuggingFace Hub, and push trained models back to the Hub
|
|
55
|
+
|
|
56
|
+
## Installation
|
|
57
|
+
|
|
58
|
+
> **Note:** VespaEmbed is in an experimental phase. Install from source.
|
|
59
|
+
|
|
60
|
+
```bash
|
|
61
|
+
git clone https://github.com/vespaai-playground/vespaembed.git
|
|
62
|
+
cd vespaembed
|
|
63
|
+
uv sync
|
|
64
|
+
```
|
|
65
|
+
|
|
66
|
+
### Optional Dependencies
|
|
67
|
+
|
|
68
|
+
```bash
|
|
69
|
+
# For Unsloth acceleration (requires NVIDIA/AMD GPU)
|
|
70
|
+
uv sync --extra unsloth
|
|
71
|
+
|
|
72
|
+
# For ONNX export
|
|
73
|
+
uv sync --extra onnx
|
|
74
|
+
|
|
75
|
+
# For development
|
|
76
|
+
uv sync --extra dev
|
|
77
|
+
```
|
|
78
|
+
|
|
79
|
+
## Quick Start
|
|
80
|
+
|
|
81
|
+
### Web UI
|
|
82
|
+
|
|
83
|
+
Launch the web interface:
|
|
84
|
+
|
|
85
|
+
```bash
|
|
86
|
+
vespaembed
|
|
87
|
+
```
|
|
88
|
+
|
|
89
|
+
Open http://localhost:8000 in your browser. The UI lets you:
|
|
90
|
+
- Upload training data (CSV or JSONL)
|
|
91
|
+
- Select task type and base model
|
|
92
|
+
- Configure hyperparameters
|
|
93
|
+
- Monitor training progress
|
|
94
|
+
- Download trained models
|
|
95
|
+
|
|
96
|
+
### CLI
|
|
97
|
+
|
|
98
|
+
Train a model from the command line:
|
|
99
|
+
|
|
100
|
+
```bash
|
|
101
|
+
vespaembed train \
|
|
102
|
+
--data examples/data/pairs.csv \
|
|
103
|
+
--task pairs \
|
|
104
|
+
--base-model sentence-transformers/all-MiniLM-L6-v2 \
|
|
105
|
+
--epochs 3
|
|
106
|
+
```
|
|
107
|
+
|
|
108
|
+
Or use a YAML config file:
|
|
109
|
+
|
|
110
|
+
```bash
|
|
111
|
+
vespaembed train --config config.yaml
|
|
112
|
+
```
|
|
113
|
+
|
|
114
|
+
## Tasks
|
|
115
|
+
|
|
116
|
+
VespaEmbed supports 4 training tasks based on your data format:
|
|
117
|
+
|
|
118
|
+
### Pairs
|
|
119
|
+
|
|
120
|
+
Text pairs for semantic search. Use when you have query-document pairs without explicit negatives.
|
|
121
|
+
|
|
122
|
+
**Data format:**
|
|
123
|
+
```csv
|
|
124
|
+
anchor,positive
|
|
125
|
+
What is machine learning?,Machine learning is a subset of AI...
|
|
126
|
+
How does photosynthesis work?,Photosynthesis converts sunlight...
|
|
127
|
+
```
|
|
128
|
+
|
|
129
|
+
**Loss variants:** `mnr` (default), `mnr_symmetric`, `gist`, `cached_mnr`, `cached_gist`
|
|
130
|
+
|
|
131
|
+
### Triplets
|
|
132
|
+
|
|
133
|
+
Text triplets with hard negatives. Use when you have explicit negative examples.
|
|
134
|
+
|
|
135
|
+
**Data format:**
|
|
136
|
+
```csv
|
|
137
|
+
anchor,positive,negative
|
|
138
|
+
What is Python?,Python is a programming language...,A python is a large snake...
|
|
139
|
+
```
|
|
140
|
+
|
|
141
|
+
**Loss variants:** `mnr` (default), `mnr_symmetric`, `gist`, `cached_mnr`, `cached_gist`
|
|
142
|
+
|
|
143
|
+
### Similarity
|
|
144
|
+
|
|
145
|
+
Text pairs with similarity scores (STS-style). Use when you have continuous similarity labels.
|
|
146
|
+
|
|
147
|
+
**Data format:**
|
|
148
|
+
```csv
|
|
149
|
+
sentence1,sentence2,score
|
|
150
|
+
A man is playing guitar,A person plays music,0.85
|
|
151
|
+
The cat is sleeping,A dog is running,0.12
|
|
152
|
+
```
|
|
153
|
+
|
|
154
|
+
**Loss variants:** `cosine` (default), `cosent`, `angle`
|
|
155
|
+
|
|
156
|
+
### TSDAE
|
|
157
|
+
|
|
158
|
+
Unsupervised learning with denoising auto-encoder. Use when you only have unlabeled text for domain adaptation.
|
|
159
|
+
|
|
160
|
+
**Data format:**
|
|
161
|
+
```csv
|
|
162
|
+
text
|
|
163
|
+
Machine learning is transforming how we analyze data.
|
|
164
|
+
Natural language processing enables computers to understand human language.
|
|
165
|
+
```
|
|
166
|
+
|
|
167
|
+
## Configuration
|
|
168
|
+
|
|
169
|
+
### CLI Arguments
|
|
170
|
+
|
|
171
|
+
```bash
|
|
172
|
+
vespaembed train \
|
|
173
|
+
--data <path> # Training data (CSV, JSONL, or HF dataset)
|
|
174
|
+
--task <task> # Task type: pairs, triplets, similarity, tsdae
|
|
175
|
+
--base-model <model> # Base model name or path
|
|
176
|
+
--project <name> # Project name (optional)
|
|
177
|
+
--eval-data <path> # Evaluation data (optional)
|
|
178
|
+
--epochs <n> # Number of epochs (default: 3)
|
|
179
|
+
--batch-size <n> # Batch size (default: 32)
|
|
180
|
+
--learning-rate <lr> # Learning rate (default: 2e-5)
|
|
181
|
+
--optimizer <opt> # Optimizer (default: adamw_torch)
|
|
182
|
+
--scheduler <sched> # LR scheduler (default: linear)
|
|
183
|
+
--matryoshka # Enable Matryoshka embeddings
|
|
184
|
+
--matryoshka-dims <dims> # Dimensions (default: 768,512,256,128,64)
|
|
185
|
+
--unsloth # Use Unsloth for faster training
|
|
186
|
+
--subset <name> # HuggingFace dataset subset
|
|
187
|
+
--split <name> # HuggingFace dataset split
|
|
188
|
+
```
|
|
189
|
+
|
|
190
|
+
### Optimizers
|
|
191
|
+
|
|
192
|
+
| Option | Description |
|
|
193
|
+
|--------|-------------|
|
|
194
|
+
| `adamw_torch` | AdamW (default) |
|
|
195
|
+
| `adamw_torch_fused` | Fused AdamW (faster on CUDA) |
|
|
196
|
+
| `adamw_8bit` | 8-bit AdamW (memory efficient) |
|
|
197
|
+
| `adafactor` | Adafactor (memory efficient, no momentum) |
|
|
198
|
+
| `sgd` | SGD with momentum |
|
|
199
|
+
|
|
200
|
+
### Schedulers
|
|
201
|
+
|
|
202
|
+
| Option | Description |
|
|
203
|
+
|--------|-------------|
|
|
204
|
+
| `linear` | Linear decay (default) |
|
|
205
|
+
| `cosine` | Cosine annealing |
|
|
206
|
+
| `cosine_with_restarts` | Cosine with warm restarts |
|
|
207
|
+
| `constant` | Constant learning rate |
|
|
208
|
+
| `constant_with_warmup` | Constant after warmup |
|
|
209
|
+
| `polynomial` | Polynomial decay |
|
|
210
|
+
|
|
211
|
+
### YAML Configuration
|
|
212
|
+
|
|
213
|
+
```yaml
|
|
214
|
+
task: pairs
|
|
215
|
+
base_model: sentence-transformers/all-MiniLM-L6-v2
|
|
216
|
+
|
|
217
|
+
data:
|
|
218
|
+
train: train.csv
|
|
219
|
+
eval: eval.csv # optional
|
|
220
|
+
|
|
221
|
+
training:
|
|
222
|
+
epochs: 3
|
|
223
|
+
batch_size: 32
|
|
224
|
+
learning_rate: 2e-5
|
|
225
|
+
warmup_ratio: 0.1
|
|
226
|
+
weight_decay: 0.01
|
|
227
|
+
fp16: true
|
|
228
|
+
eval_steps: 500
|
|
229
|
+
save_steps: 500
|
|
230
|
+
logging_steps: 100
|
|
231
|
+
optimizer: adamw_torch # adamw_torch, adamw_8bit, adafactor, sgd
|
|
232
|
+
scheduler: linear # linear, cosine, constant, polynomial
|
|
233
|
+
|
|
234
|
+
output:
|
|
235
|
+
dir: ./output
|
|
236
|
+
push_to_hub: false
|
|
237
|
+
hf_username: null
|
|
238
|
+
|
|
239
|
+
# Optional: LoRA configuration
|
|
240
|
+
lora:
|
|
241
|
+
enabled: false
|
|
242
|
+
r: 64
|
|
243
|
+
alpha: 128
|
|
244
|
+
dropout: 0.1
|
|
245
|
+
target_modules: [query, key, value, dense]
|
|
246
|
+
|
|
247
|
+
# Optional: Matryoshka dimensions
|
|
248
|
+
matryoshka_dims: [768, 512, 256, 128, 64]
|
|
249
|
+
|
|
250
|
+
# Optional: Loss variant (uses task default if not specified)
|
|
251
|
+
loss_variant: mnr
|
|
252
|
+
```
|
|
253
|
+
|
|
254
|
+
### HuggingFace Datasets
|
|
255
|
+
|
|
256
|
+
Load datasets directly from HuggingFace Hub:
|
|
257
|
+
|
|
258
|
+
```bash
|
|
259
|
+
vespaembed train \
|
|
260
|
+
--data sentence-transformers/all-nli \
|
|
261
|
+
--subset triplet \
|
|
262
|
+
--split train \
|
|
263
|
+
--task triplets \
|
|
264
|
+
--base-model sentence-transformers/all-MiniLM-L6-v2
|
|
265
|
+
```
|
|
266
|
+
|
|
267
|
+
## CLI Commands
|
|
268
|
+
|
|
269
|
+
| Command | Description |
|
|
270
|
+
|---------|-------------|
|
|
271
|
+
| `vespaembed` | Launch web UI (default) |
|
|
272
|
+
| `vespaembed serve` | Launch web UI |
|
|
273
|
+
| `vespaembed train` | Train a model |
|
|
274
|
+
| `vespaembed evaluate` | Evaluate a model |
|
|
275
|
+
| `vespaembed export` | Export model to ONNX |
|
|
276
|
+
| `vespaembed info` | Show task information |
|
|
277
|
+
|
|
278
|
+
## Output
|
|
279
|
+
|
|
280
|
+
Trained models are saved to `~/.vespaembed/projects/<project-name>/`:
|
|
281
|
+
|
|
282
|
+
```
|
|
283
|
+
~/.vespaembed/projects/my-project/
|
|
284
|
+
├── final/ # Final trained model
|
|
285
|
+
├── checkpoint-500/ # Training checkpoints
|
|
286
|
+
├── checkpoint-1000/
|
|
287
|
+
└── logs/ # TensorBoard logs
|
|
288
|
+
```
|
|
289
|
+
|
|
290
|
+
## Column Aliases
|
|
291
|
+
|
|
292
|
+
VespaEmbed automatically recognizes common column name variations:
|
|
293
|
+
|
|
294
|
+
| Task | Expected | Also Accepts |
|
|
295
|
+
|------|----------|--------------|
|
|
296
|
+
| pairs | `anchor` | `query`, `question`, `sent1`, `sentence1`, `text1` |
|
|
297
|
+
| pairs | `positive` | `document`, `answer`, `pos`, `sent2`, `sentence2`, `text2` |
|
|
298
|
+
| triplets | `negative` | `neg`, `hard_negative`, `sent3`, `sentence3`, `text3` |
|
|
299
|
+
| similarity | `sentence1` | `sent1`, `text1`, `anchor`, `query` |
|
|
300
|
+
| similarity | `sentence2` | `sent2`, `text2`, `positive`, `document` |
|
|
301
|
+
| similarity | `score` | `similarity`, `label`, `sim_score` |
|
|
302
|
+
| tsdae | `text` | `sentence`, `sentences`, `content`, `input` |
|
|
303
|
+
|
|
304
|
+
## Development
|
|
305
|
+
|
|
306
|
+
```bash
|
|
307
|
+
# Install dev dependencies
|
|
308
|
+
uv sync --extra dev
|
|
309
|
+
|
|
310
|
+
# Run tests
|
|
311
|
+
uv run pytest tests/
|
|
312
|
+
|
|
313
|
+
# Run tests with coverage
|
|
314
|
+
uv run pytest tests/ --cov=vespaembed
|
|
315
|
+
|
|
316
|
+
# Format code
|
|
317
|
+
make format
|
|
318
|
+
|
|
319
|
+
# Lint
|
|
320
|
+
make lint
|
|
321
|
+
```
|
|
322
|
+
|
|
323
|
+
## License
|
|
324
|
+
|
|
325
|
+
Apache 2.0
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
vespaembed/__init__.py,sha256=QvlVh4JTl3JL7jQAja76yKtT-IvF4631ASjWY1wS6AQ,22
|
|
2
|
+
vespaembed/db.py,sha256=9m46Ljc8e2s8NnhRJRkf3pj-U72PGM3o94pEec8IFoo,4272
|
|
3
|
+
vespaembed/enums.py,sha256=1ovPAJW-5f4TrI8dnHMJYitau30s5TeMk0s2Ptkppaw,1453
|
|
4
|
+
vespaembed/worker.py,sha256=2vI292BWRKKK-hBtTQj3fxj5sKmuuXODubYk6yCH5hU,12230
|
|
5
|
+
vespaembed/cli/__init__.py,sha256=rWKimiyk7YndJfRCEAlZdsDv5l96Cixwl_UX6Nq2vf0,433
|
|
6
|
+
vespaembed/cli/vespaembed.py,sha256=LcGh9k5fQPjHoDOzf-1dUhcACDxQLXIXpVGmjCylbdI,1616
|
|
7
|
+
vespaembed/cli/commands/__init__.py,sha256=4lIRTQRXe_QQbtFJDNgeRJ3BZf4bD7FmhEScVWeI6S8,376
|
|
8
|
+
vespaembed/cli/commands/evaluate.py,sha256=bx74DrSdCqzP6KWXBx_6N1VcFE-vOt9FkFYKkvkTpU8,2595
|
|
9
|
+
vespaembed/cli/commands/export.py,sha256=SB8cIC1FqHXLKdrqvL2ca_mzJDGe-W9ECdYV9T5z--k,2614
|
|
10
|
+
vespaembed/cli/commands/info.py,sha256=3n0hz7TmUdKX7ZvgmBnOFtpNtMu3VA_3zujGidGY1G8,1615
|
|
11
|
+
vespaembed/cli/commands/serve.py,sha256=Qw_MxtsrIwfk45vs9l3cMvCn9_H8MKxUXYOg-Nqf940,1403
|
|
12
|
+
vespaembed/cli/commands/train.py,sha256=IEYMWr5hrZGAshVljHAJ10Z7Iw8QFiGb54fe39PIRKI,9015
|
|
13
|
+
vespaembed/core/__init__.py,sha256=NcgBvvomhTlLim8WG84cFnnGNDFKSJolmUrBsbvafvM,145
|
|
14
|
+
vespaembed/core/config.py,sha256=eERuTr4j-Gpp0TovrRmE7GGECuIWhbUTGzbJ1Kuk8ik,6041
|
|
15
|
+
vespaembed/core/registry.py,sha256=CrSIQxx6a9QQF7po-YhseMfcHs5IBDd_3LwHpxUACWE,5179
|
|
16
|
+
vespaembed/core/trainer.py,sha256=9PPMEpH-hOqqO7ZaMqrYlVzHKffAbUIvo3makBXB-0U,22802
|
|
17
|
+
vespaembed/datasets/__init__.py,sha256=F1QDj036DWAjRQO38u4dK04tzzRtPwGLk_AMbtfjIp8,80
|
|
18
|
+
vespaembed/datasets/loader.py,sha256=aRasiQxqIpaH7sfbZrp_7h-xc8zHFfgojCQFWVHvsWE,2252
|
|
19
|
+
vespaembed/datasets/formats/__init__.py,sha256=UDhvPk9FHq8RR_V-uphciNDYH3Ut--DvT7j8DmbQSTQ,235
|
|
20
|
+
vespaembed/datasets/formats/csv.py,sha256=jcYgaLtNbuE5CFtg9jHha4VekKay67y6LKTXCWPasGA,290
|
|
21
|
+
vespaembed/datasets/formats/huggingface.py,sha256=v6zY4wTNLOe-Qx-TxJBPwF5QM6WMkcXPDlmd2Dar_9k,865
|
|
22
|
+
vespaembed/datasets/formats/jsonl.py,sha256=8VsfzZHWU_16z7npBETSyDVqDvzosprW_U9wNdt3exE,512
|
|
23
|
+
vespaembed/evaluation/__init__.py,sha256=hE8xZ0xG-qPTP1_At5dcZG1p1wwUATFtJZlxhUindmc,91
|
|
24
|
+
vespaembed/evaluation/factory.py,sha256=_6b_b5Sj279owCWSLp94KPRKcZeUq-tnF9IyxNRBFWM,2715
|
|
25
|
+
vespaembed/models/__init__.py,sha256=EVHS2Xuvr12JLIm9BA8SXvCsJcSPRcIKGBmuNHFYNrA,140
|
|
26
|
+
vespaembed/models/export.py,sha256=fdX21e6UEZpHhhKU9gwpUBgYJWRmX9_S6EbWP6-QWwA,2436
|
|
27
|
+
vespaembed/models/loader.py,sha256=7VWWTpPXeVp9eJq4EsmNdA1zBVYf9BQQRRF1v9ObnJ8,813
|
|
28
|
+
vespaembed/static/css/styles.css,sha256=yOXm9UW41s4nZpG2CBQA5tw6bHzEBFeUMofb6LdqNWk,37350
|
|
29
|
+
vespaembed/static/js/app.js,sha256=kri7JVLSNGWludIYj_ZKQLrJ6fzziQlUv5dd0L2NxBI,52969
|
|
30
|
+
vespaembed/tasks/__init__.py,sha256=cMSdguS9t2mElHfiVpZJ5BcnXtyBQmHnT2I30INZ-gs,883
|
|
31
|
+
vespaembed/tasks/base.py,sha256=ER2H91HYsKE25tJc6J3TkYkgwVhsZkICRDb5MSXLaek,5105
|
|
32
|
+
vespaembed/tasks/pairs.py,sha256=VRmY7bu3fNQTPtopq-GvfsHH4g2yVdHL-Nr2KdU4-28,3520
|
|
33
|
+
vespaembed/tasks/similarity.py,sha256=xk5hGoALNJnelM0n5lRys1gF30KiVdlUplgcxBI0lz8,3037
|
|
34
|
+
vespaembed/tasks/triplets.py,sha256=ZTfF6yFQS_368LJ-rV4qVQGqrxxGzS0eiGO1f4tqICU,3581
|
|
35
|
+
vespaembed/tasks/tsdae.py,sha256=VcJbyABee9yfv2leNe_yVMDa48XfronAFcEDDIVrZKo,3434
|
|
36
|
+
vespaembed/templates/index.html,sha256=z4vIejLAgSOL0ISTJpmyMqSegCew-X3Aj4wv1xV4knY,30185
|
|
37
|
+
vespaembed/utils/__init__.py,sha256=ZvKPCgX-ISxoj8LEwkMZ0BgVK3MZcE3TExkNBOkRhfk,74
|
|
38
|
+
vespaembed/utils/logging.py,sha256=xxbMAohRsmfB5TZX-43iX488_2p-KBmttdBRuCyA4dE,1736
|
|
39
|
+
vespaembed/web/__init__.py,sha256=rFAxCq3VJ7UDoLoYVZPshpAoetYV9JPlKpJdHREC_PM,37
|
|
40
|
+
vespaembed/web/app.py,sha256=0GdPAWs5y2bhJv6XCmZMdf6-LMnZrLA5RHxtylbxu_4,19380
|
|
41
|
+
vespaembed/web/api/__init__.py,sha256=GLiedM-kLCaQTZCa_UmkqIrjEGjwvGdDdfKpZhUFgN0,16
|
|
42
|
+
vespaembed-0.0.2.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
|
|
43
|
+
vespaembed-0.0.2.dist-info/METADATA,sha256=COTsjigAH1BnJZiuUgOVPi698eiIJ6g7oJ4Ud6SBeGw,8655
|
|
44
|
+
vespaembed-0.0.2.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
45
|
+
vespaembed-0.0.2.dist-info/entry_points.txt,sha256=_yO09x95rXw-0RhJkMmXAt5bjBOysVy10jbe4KJuMIg,62
|
|
46
|
+
vespaembed-0.0.2.dist-info/top_level.txt,sha256=GhqAsJ29dVFnUrvwec8lRqDCC1BqlWntZZFibsgCEdY,11
|
|
47
|
+
vespaembed-0.0.2.dist-info/RECORD,,
|