vespaembed 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (49)
  1. vespaembed/__init__.py +1 -1
  2. vespaembed/cli/__init__.py +17 -0
  3. vespaembed/cli/commands/__init__.py +7 -0
  4. vespaembed/cli/commands/evaluate.py +85 -0
  5. vespaembed/cli/commands/export.py +86 -0
  6. vespaembed/cli/commands/info.py +52 -0
  7. vespaembed/cli/commands/serve.py +49 -0
  8. vespaembed/cli/commands/train.py +267 -0
  9. vespaembed/cli/vespaembed.py +55 -0
  10. vespaembed/core/__init__.py +2 -0
  11. vespaembed/core/config.py +164 -0
  12. vespaembed/core/registry.py +158 -0
  13. vespaembed/core/trainer.py +573 -0
  14. vespaembed/datasets/__init__.py +3 -0
  15. vespaembed/datasets/formats/__init__.py +5 -0
  16. vespaembed/datasets/formats/csv.py +15 -0
  17. vespaembed/datasets/formats/huggingface.py +34 -0
  18. vespaembed/datasets/formats/jsonl.py +26 -0
  19. vespaembed/datasets/loader.py +80 -0
  20. vespaembed/db.py +176 -0
  21. vespaembed/enums.py +58 -0
  22. vespaembed/evaluation/__init__.py +3 -0
  23. vespaembed/evaluation/factory.py +86 -0
  24. vespaembed/models/__init__.py +4 -0
  25. vespaembed/models/export.py +89 -0
  26. vespaembed/models/loader.py +25 -0
  27. vespaembed/static/css/styles.css +1800 -0
  28. vespaembed/static/js/app.js +1485 -0
  29. vespaembed/tasks/__init__.py +23 -0
  30. vespaembed/tasks/base.py +144 -0
  31. vespaembed/tasks/pairs.py +91 -0
  32. vespaembed/tasks/similarity.py +84 -0
  33. vespaembed/tasks/triplets.py +90 -0
  34. vespaembed/tasks/tsdae.py +102 -0
  35. vespaembed/templates/index.html +544 -0
  36. vespaembed/utils/__init__.py +3 -0
  37. vespaembed/utils/logging.py +69 -0
  38. vespaembed/web/__init__.py +1 -0
  39. vespaembed/web/api/__init__.py +1 -0
  40. vespaembed/web/app.py +605 -0
  41. vespaembed/worker.py +313 -0
  42. vespaembed-0.0.3.dist-info/METADATA +325 -0
  43. vespaembed-0.0.3.dist-info/RECORD +47 -0
  44. {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/WHEEL +1 -1
  45. vespaembed-0.0.1.dist-info/METADATA +0 -20
  46. vespaembed-0.0.1.dist-info/RECORD +0 -7
  47. {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/entry_points.txt +0 -0
  48. {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/licenses/LICENSE +0 -0
  49. {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/top_level.txt +0 -0
vespaembed/worker.py ADDED
@@ -0,0 +1,313 @@
+ """Background worker for training runs."""
+
+ import argparse
+ import json
+ import signal
+ import sys
+ import traceback
+ from pathlib import Path
+
+ # Import tasks to register them
+ import vespaembed.tasks  # noqa: F401
+ from vespaembed.core.config import (
+     DataConfig,
+     LoraConfig,
+     OutputConfig,
+     TrainingConfig,
+     TrainingHyperparameters,
+     UnslothConfig,
+ )
+ from vespaembed.core.trainer import VespaEmbedTrainer
+ from vespaembed.db import update_run_status
+ from vespaembed.enums import RunStatus
+ from vespaembed.utils.logging import logger
+
+
+ class TrainingWorker:
+     """Worker that executes training in a subprocess."""
+
+     def __init__(self, run_id: int, config: dict):
+         self.run_id = run_id
+         self.config = config
+         self.trainer = None
+         self.stopped = False
+
+         # Set up signal handlers
+         signal.signal(signal.SIGTERM, self._handle_signal)
+         signal.signal(signal.SIGINT, self._handle_signal)
+
+     def _handle_signal(self, signum, frame):
+         """Handle termination signals."""
+         logger.info(f"Received signal {signum}, stopping training...")
+         self.stopped = True
+         if self.trainer:
+             # Trainer will check for stop flag
+             pass
+         sys.exit(0)
+
+     def _send_update(self, update_type: str, data: dict):
+         """Send update to the web server via file."""
+         update_dir = Path.home() / ".vespaembed" / "updates"
+         update_dir.mkdir(parents=True, exist_ok=True)
+
+         update_file = update_dir / f"run_{self.run_id}.jsonl"
+         update = {"type": update_type, "run_id": self.run_id, **data}
+
+         with open(update_file, "a") as f:
+             f.write(json.dumps(update) + "\n")
+
+     def _log_config(self):
+         """Log training configuration to UI."""
+         c = self.config
+         lines = [
+             "=" * 50,
+             "Training Configuration",
+             "=" * 50,
+             f" Base Model: {c.get('base_model', '--')}",
+             f" Task: {c.get('task', '--')}",
+         ]
+
+         if c.get("loss_variant"):
+             lines.append(f" Loss Variant: {c['loss_variant']}")
+
+         # Data source
+         if c.get("train_filename"):
+             lines.append(f" Training Data: {c['train_filename'].split('/')[-1]}")
+         elif c.get("hf_dataset"):
+             lines.append(f" Training Data: {c['hf_dataset']}")
+
+         # Training params
+         lines.extend(
+             [
+                 f" Epochs: {c.get('epochs', 3)}",
+                 f" Batch Size: {c.get('batch_size', 32)}",
+                 f" Learning Rate: {c.get('learning_rate', 2e-5)}",
+                 f" Optimizer: {c.get('optimizer', 'adamw_torch')}",
+                 f" Scheduler: {c.get('scheduler', 'linear')}",
+                 f" Warmup Ratio: {c.get('warmup_ratio', 0.1)}",
+                 f" Weight Decay: {c.get('weight_decay', 0.01)}",
+             ]
+         )
+
+         # Precision
+         if c.get("bf16"):
+             lines.append(" Precision: BF16")
+         elif c.get("fp16"):
+             lines.append(" Precision: FP16")
+         else:
+             lines.append(" Precision: FP32")
+
+         # Optional features
+         if c.get("gradient_checkpointing"):
+             lines.append(" Grad Checkpoint: Enabled")
+         if c.get("gradient_accumulation_steps", 1) > 1:
+             lines.append(f" Grad Accum: {c['gradient_accumulation_steps']}")
+
+         # LoRA
+         if c.get("lora_enabled"):
+             lines.append(f" LoRA: r={c.get('lora_r', 64)}, alpha={c.get('lora_alpha', 128)}")
+
+         # Unsloth
+         if c.get("unsloth_enabled"):
+             lines.append(f" Unsloth: Enabled ({c.get('unsloth_save_method', 'merged_16bit')})")
+
+         # Matryoshka
+         if c.get("matryoshka_dims"):
+             lines.append(f" Matryoshka: {c['matryoshka_dims']}")
+
+         lines.append("=" * 50)
+
+         # Send each line as a log message
+         for line in lines:
+             self._send_update("log", {"message": line})
+
+     def run(self):
+         """Execute the training run."""
+         try:
+             logger.info(f"Starting training run {self.run_id}")
+             self._send_update("status", {"status": "running"})
+
+             # Log configuration to UI
+             self._log_config()
+
+             # Determine data source - file upload or HuggingFace dataset
+             train_source = self.config.get("train_filename") or self.config.get("hf_dataset")
+             eval_source = self.config.get("eval_filename")
+             eval_split = None
+
+             # For HuggingFace datasets, eval can come from a different split
+             hf_eval_split = self.config.get("hf_eval_split")
+             if self.config.get("hf_dataset") and hf_eval_split:
+                 # Use the same dataset but different split for eval
+                 eval_source = self.config.get("hf_dataset")
+                 eval_split = hf_eval_split
+
+             # Build data config
+             data_config = DataConfig(
+                 train=train_source,
+                 eval=eval_source,
+                 subset=self.config.get("hf_subset"),
+                 split=self.config.get("hf_train_split", "train"),
+                 eval_split=eval_split,
+             )
+
+             # Parse matryoshka_dims if present (comes as comma-separated string from UI)
+             matryoshka_dims = None
+             if self.config.get("matryoshka_dims"):
+                 dims_str = self.config["matryoshka_dims"]
+                 if isinstance(dims_str, str):
+                     matryoshka_dims = [int(d.strip()) for d in dims_str.split(",") if d.strip()]
+                 elif isinstance(dims_str, list):
+                     matryoshka_dims = dims_str
+
+             # Parse lora_target_modules (comes as comma-separated string from UI)
+             modules_str = self.config.get("lora_target_modules", "query, key, value, dense")
+             if isinstance(modules_str, str):
+                 lora_target_modules = [m.strip() for m in modules_str.split(",") if m.strip()]
+             else:
+                 lora_target_modules = modules_str  # Already a list
+
+             # Build LoRA config
+             lora_config = LoraConfig(
+                 enabled=self.config.get("lora_enabled", False),
+                 r=self.config.get("lora_r", 64),
+                 alpha=self.config.get("lora_alpha", 128),
+                 dropout=self.config.get("lora_dropout", 0.1),
+                 target_modules=lora_target_modules,
+             )
+
+             # Build Unsloth config
+             unsloth_config = UnslothConfig(
+                 enabled=self.config.get("unsloth_enabled", False),
+                 save_method=self.config.get("unsloth_save_method", "merged_16bit"),
+             )
+
+             # Build training config with nested structure
+             training_config = TrainingConfig(
+                 task=self.config["task"],
+                 base_model=self.config["base_model"],
+                 data=data_config,
+                 training=TrainingHyperparameters(
+                     epochs=self.config.get("epochs", 3),
+                     batch_size=self.config.get("batch_size", 32),
+                     learning_rate=self.config.get("learning_rate", 2e-5),
+                     warmup_ratio=self.config.get("warmup_ratio", 0.1),
+                     weight_decay=self.config.get("weight_decay", 0.01),
+                     fp16=self.config.get("fp16", False),
+                     bf16=self.config.get("bf16", False),
+                     eval_steps=self.config.get("eval_steps", 500),
+                     save_steps=self.config.get("save_steps", 500),
+                     logging_steps=self.config.get("logging_steps", 100),
+                     gradient_accumulation_steps=self.config.get("gradient_accumulation_steps", 1),
+                     optimizer=self.config.get("optimizer", "adamw_torch"),
+                     scheduler=self.config.get("scheduler", "linear"),
+                 ),
+                 output=OutputConfig(
+                     dir=self.config["output_dir"],  # Required - set by web app
+                     push_to_hub=self.config.get("push_to_hub", False),
+                     hf_username=self.config.get("hf_username"),
+                 ),
+                 lora=lora_config,
+                 unsloth=unsloth_config,
+                 max_seq_length=self.config.get("max_seq_length"),  # None = auto-detect
+                 gradient_checkpointing=self.config.get("gradient_checkpointing", False),
+                 matryoshka_dims=matryoshka_dims,
+                 loss_variant=self.config.get("loss_variant"),
+             )
+
+             # Create trainer with progress callback
+             self.trainer = VespaEmbedTrainer(
+                 config=training_config,
+                 progress_callback=self._progress_callback,
+             )
+
+             # Run training
+             self.trainer.train()
+
+             # Update status on completion
+             update_run_status(self.run_id, RunStatus.COMPLETED)
+             self._send_update("complete", {"output_dir": training_config.output.dir})
+             logger.info(f"Training run {self.run_id} completed successfully")
+
+         except Exception as e:
+             logger.error(f"Training run {self.run_id} failed: {e}")
+             logger.error(traceback.format_exc())
+             update_run_status(self.run_id, RunStatus.ERROR, error_message=str(e))
+             self._send_update("error", {"message": str(e)})
+             sys.exit(1)
+
+     def _progress_callback(self, progress: dict):
+         """Callback for training progress updates."""
+         progress_type = progress.get("type", "progress")
+
+         if progress_type == "train_start":
+             self._send_update("progress", progress)
+             self._send_update(
+                 "log",
+                 {"message": f"Training started: {progress['total_steps']} steps, {progress['total_epochs']} epochs"},
+             )
+         elif progress_type == "train_end":
+             self._send_update("progress", progress)
+             elapsed = self._format_time(progress.get("elapsed_seconds", 0))
+             self._send_update("log", {"message": f"Training completed in {elapsed}"})
+         else:
+             self._send_update("progress", progress)
+             self._send_update("log", {"message": self._format_progress(progress)})
+
+     def _format_time(self, seconds: float) -> str:
+         """Format seconds into human readable time."""
+         if seconds < 60:
+             return f"{seconds:.0f}s"
+         elif seconds < 3600:
+             mins = int(seconds // 60)
+             secs = int(seconds % 60)
+             return f"{mins}m {secs}s"
+         else:
+             hours = int(seconds // 3600)
+             mins = int((seconds % 3600) // 60)
+             return f"{hours}h {mins}m"
+
+     def _format_progress(self, progress: dict) -> str:
+         """Format progress as a log message like tqdm."""
+         step = progress.get("step", 0)
+         total = progress.get("total_steps", 0)
+         pct = progress.get("progress_pct", 0)
+         epoch = progress.get("epoch", 0)
+         total_epochs = progress.get("total_epochs", 0)
+         loss = progress.get("loss")
+         lr = progress.get("learning_rate")
+         eta = progress.get("eta_seconds", 0)
+
+         # Format: [Step 100/1000 (10%)] Epoch 1/3 | Loss: 0.5234 | LR: 2.00e-05 | ETA: 5m 30s
+         parts = [f"[Step {step}/{total} ({pct:.0f}%)]"]
+
+         if total_epochs:
+             parts.append(f"Epoch {epoch:.1f}/{total_epochs}")
+
+         if loss is not None:
+             parts.append(f"Loss: {loss:.4f}")
+
+         if lr is not None:
+             parts.append(f"LR: {lr:.2e}")
+
+         if eta > 0:
+             parts.append(f"ETA: {self._format_time(eta)}")
+
+         return " | ".join(parts)
+
+
+ def main():
+     """Main entry point for the worker."""
+     parser = argparse.ArgumentParser(description="VespaEmbed Training Worker")
+     parser.add_argument("--run-id", type=int, required=True, help="Run ID")
+     parser.add_argument("--config", type=str, required=True, help="Training config JSON")
+
+     args = parser.parse_args()
+
+     config = json.loads(args.config)
+     worker = TrainingWorker(run_id=args.run_id, config=config)
+     worker.run()
+
+
+ if __name__ == "__main__":
+     main()
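
Judging from the argparse definition in `main()` and the class docstring, the web app launches this worker as a separate process with the run configuration serialized as JSON. A hypothetical manual invocation (the config keys mirror the `self.config` lookups in `run()`; all values are examples):

```bash
python -m vespaembed.worker \
  --run-id 1 \
  --config '{"task": "pairs", "base_model": "sentence-transformers/all-MiniLM-L6-v2", "train_filename": "pairs.csv", "output_dir": "./output"}'
```

Progress updates would then be appended to `~/.vespaembed/updates/run_1.jsonl`, one JSON object per line, as written by `_send_update`.
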
vespaembed-0.0.3.dist-info/METADATA ADDED
@@ -0,0 +1,325 @@
+ Metadata-Version: 2.4
+ Name: vespaembed
+ Version: 0.0.3
+ Summary: vespaembed: no-code training for embedding models
+ Author: Abhishek Thakur
+ License: Apache 2.0
+ Project-URL: Homepage, https://github.com/vespaai-playground/vespaembed
+ Requires-Python: >=3.11
+ Description-Content-Type: text/markdown
+ License-File: LICENSE
+ Requires-Dist: sentence-transformers>=3.0.0
+ Requires-Dist: transformers>=4.40.0
+ Requires-Dist: accelerate>=0.26.0
+ Requires-Dist: torch>=2.0.0
+ Requires-Dist: datasets>=2.18.0
+ Requires-Dist: pandas>=2.0.0
+ Requires-Dist: pydantic>=2.0.0
+ Requires-Dist: rich>=13.0.0
+ Requires-Dist: fastapi>=0.111.0
+ Requires-Dist: uvicorn>=0.30.0
+ Requires-Dist: python-multipart>=0.0.9
+ Requires-Dist: websockets>=12.0
+ Requires-Dist: jinja2>=3.1.0
+ Requires-Dist: pyyaml>=6.0.0
+ Requires-Dist: tensorboard>=2.15.0
+ Requires-Dist: peft>=0.18.1
+ Requires-Dist: unsloth>=2026.1.4
+ Provides-Extra: unsloth
+ Requires-Dist: unsloth; extra == "unsloth"
+ Provides-Extra: onnx
+ Requires-Dist: onnx>=1.14.0; extra == "onnx"
+ Requires-Dist: onnxruntime>=1.23.2; extra == "onnx"
+ Provides-Extra: dev
+ Requires-Dist: black==26.1.0; extra == "dev"
+ Requires-Dist: isort==7.0.0; extra == "dev"
+ Requires-Dist: flake8==7.3.0; extra == "dev"
+ Requires-Dist: pytest>=9.0.2; extra == "dev"
+ Requires-Dist: pytest-cov>=7.0.0; extra == "dev"
+ Dynamic: license-file
+
+ # VespaEmbed
+
+ No-code training for embedding models. Train custom embedding models with a web UI or CLI.
+
+ ## Features
+
+ - **Web UI** - Visual interface for configuring and monitoring training
+ - **CLI** - Command-line interface for scripting and automation
+ - **Multiple Tasks** - Support for pairs, triplets, similarity scoring, and unsupervised learning
+ - **Loss Variants** - Choose from multiple loss functions per task
+ - **Matryoshka Embeddings** - Train multi-dimensional embeddings for flexible retrieval
+ - **LoRA Support** - Parameter-efficient fine-tuning with LoRA adapters
+ - **Unsloth Integration** - Faster training with Unsloth optimizations
+ - **HuggingFace Integration** - Load datasets and models from the HuggingFace Hub, and push trained models back to the Hub
+
+ ## Installation
+
+ > **Note:** VespaEmbed is in an experimental phase. Install from source.
+
+ ```bash
+ git clone https://github.com/vespaai-playground/vespaembed.git
+ cd vespaembed
+ uv sync
+ ```
+
+ ### Optional Dependencies
+
+ ```bash
+ # For Unsloth acceleration (requires NVIDIA/AMD GPU)
+ uv sync --extra unsloth
+
+ # For ONNX export
+ uv sync --extra onnx
+
+ # For development
+ uv sync --extra dev
+ ```
+
+ ## Quick Start
+
+ ### Web UI
+
+ Launch the web interface:
+
+ ```bash
+ vespaembed
+ ```
+
+ Open http://localhost:8000 in your browser. The UI lets you:
+ - Upload training data (CSV or JSONL)
+ - Select task type and base model
+ - Configure hyperparameters
+ - Monitor training progress
+ - Download trained models
+
+ ### CLI
+
+ Train a model from the command line:
+
+ ```bash
+ vespaembed train \
+     --data examples/data/pairs.csv \
+     --task pairs \
+     --base-model sentence-transformers/all-MiniLM-L6-v2 \
+     --epochs 3
+ ```
+
+ Or use a YAML config file:
+
+ ```bash
+ vespaembed train --config config.yaml
+ ```
+
+ ## Tasks
+
+ VespaEmbed supports four training tasks; choose one based on your data format:
+
+ ### Pairs
+
+ Text pairs for semantic search. Use when you have query-document pairs without explicit negatives.
+
+ **Data format:**
+ ```csv
+ anchor,positive
+ What is machine learning?,Machine learning is a subset of AI...
+ How does photosynthesis work?,Photosynthesis converts sunlight...
+ ```
+
+ **Loss variants:** `mnr` (default), `mnr_symmetric`, `gist`, `cached_mnr`, `cached_gist`
+
131
+ ### Triplets
132
+
133
+ Text triplets with hard negatives. Use when you have explicit negative examples.
134
+
135
+ **Data format:**
136
+ ```csv
137
+ anchor,positive,negative
138
+ What is Python?,Python is a programming language...,A python is a large snake...
139
+ ```
140
+
141
+ **Loss variants:** `mnr` (default), `mnr_symmetric`, `gist`, `cached_mnr`, `cached_gist`
142
+
143
+ ### Similarity
144
+
145
+ Text pairs with similarity scores (STS-style). Use when you have continuous similarity labels.
146
+
147
+ **Data format:**
148
+ ```csv
149
+ sentence1,sentence2,score
150
+ A man is playing guitar,A person plays music,0.85
151
+ The cat is sleeping,A dog is running,0.12
152
+ ```
153
+
154
+ **Loss variants:** `cosine` (default), `cosent`, `angle`
155
+
156
+ ### TSDAE
157
+
158
+ Unsupervised learning with denoising auto-encoder. Use when you only have unlabeled text for domain adaptation.
159
+
160
+ **Data format:**
161
+ ```csv
162
+ text
163
+ Machine learning is transforming how we analyze data.
164
+ Natural language processing enables computers to understand human language.
165
+ ```
166
+
167
+ ## Configuration
168
+
169
+ ### CLI Arguments
170
+
171
+ ```bash
172
+ vespaembed train \
173
+ --data <path> # Training data (CSV, JSONL, or HF dataset)
174
+ --task <task> # Task type: pairs, triplets, similarity, tsdae
175
+ --base-model <model> # Base model name or path
176
+ --project <name> # Project name (optional)
177
+ --eval-data <path> # Evaluation data (optional)
178
+ --epochs <n> # Number of epochs (default: 3)
179
+ --batch-size <n> # Batch size (default: 32)
180
+ --learning-rate <lr> # Learning rate (default: 2e-5)
181
+ --optimizer <opt> # Optimizer (default: adamw_torch)
182
+ --scheduler <sched> # LR scheduler (default: linear)
183
+ --matryoshka # Enable Matryoshka embeddings
184
+ --matryoshka-dims <dims> # Dimensions (default: 768,512,256,128,64)
185
+ --unsloth # Use Unsloth for faster training
186
+ --subset <name> # HuggingFace dataset subset
187
+ --split <name> # HuggingFace dataset split
188
+ ```
189
+
190
+ ### Optimizers
191
+
192
+ | Option | Description |
193
+ |--------|-------------|
194
+ | `adamw_torch` | AdamW (default) |
195
+ | `adamw_torch_fused` | Fused AdamW (faster on CUDA) |
196
+ | `adamw_8bit` | 8-bit AdamW (memory efficient) |
197
+ | `adafactor` | Adafactor (memory efficient, no momentum) |
198
+ | `sgd` | SGD with momentum |
199
+
200
+ ### Schedulers
201
+
202
+ | Option | Description |
203
+ |--------|-------------|
204
+ | `linear` | Linear decay (default) |
205
+ | `cosine` | Cosine annealing |
206
+ | `cosine_with_restarts` | Cosine with warm restarts |
207
+ | `constant` | Constant learning rate |
208
+ | `constant_with_warmup` | Constant after warmup |
209
+ | `polynomial` | Polynomial decay |
210
+
211
+ ### YAML Configuration
212
+
213
+ ```yaml
214
+ task: pairs
215
+ base_model: sentence-transformers/all-MiniLM-L6-v2
216
+
217
+ data:
218
+ train: train.csv
219
+ eval: eval.csv # optional
220
+
221
+ training:
222
+ epochs: 3
223
+ batch_size: 32
224
+ learning_rate: 2e-5
225
+ warmup_ratio: 0.1
226
+ weight_decay: 0.01
227
+ fp16: true
228
+ eval_steps: 500
229
+ save_steps: 500
230
+ logging_steps: 100
231
+ optimizer: adamw_torch # adamw_torch, adamw_8bit, adafactor, sgd
232
+ scheduler: linear # linear, cosine, constant, polynomial
233
+
234
+ output:
235
+ dir: ./output
236
+ push_to_hub: false
237
+ hf_username: null
238
+
239
+ # Optional: LoRA configuration
240
+ lora:
241
+ enabled: false
242
+ r: 64
243
+ alpha: 128
244
+ dropout: 0.1
245
+ target_modules: [query, key, value, dense]
246
+
247
+ # Optional: Matryoshka dimensions
248
+ matryoshka_dims: [768, 512, 256, 128, 64]
249
+
250
+ # Optional: Loss variant (uses task default if not specified)
251
+ loss_variant: mnr
252
+ ```
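
When `matryoshka_dims` is set, the leading components of each embedding are trained to be usable on their own, so vectors can be truncated at query time for cheaper storage and search. A minimal numpy sketch of the idea, independent of any vespaembed API (the re-normalization assumes cosine similarity):

```python
import numpy as np

def truncate_embedding(emb: np.ndarray, dim: int) -> np.ndarray:
    """Keep the first `dim` components and re-normalize the vector."""
    truncated = emb[..., :dim]
    return truncated / np.linalg.norm(truncated, axis=-1, keepdims=True)

# A 768-dim embedding reduced to its 256-dim Matryoshka prefix
full = np.random.rand(768)
small = truncate_embedding(full, 256)
```
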
+
+ ### HuggingFace Datasets
+
+ Load datasets directly from the HuggingFace Hub:
+
+ ```bash
+ vespaembed train \
+     --data sentence-transformers/all-nli \
+     --subset triplet \
+     --split train \
+     --task triplets \
+     --base-model sentence-transformers/all-MiniLM-L6-v2
+ ```
+
+ ## CLI Commands
+
+ | Command | Description |
+ |---------|-------------|
+ | `vespaembed` | Launch web UI (default) |
+ | `vespaembed serve` | Launch web UI |
+ | `vespaembed train` | Train a model |
+ | `vespaembed evaluate` | Evaluate a model |
+ | `vespaembed export` | Export model to ONNX |
+ | `vespaembed info` | Show task information |
+
+ ## Output
+
+ Trained models are saved to `~/.vespaembed/projects/<project-name>/`:
+
+ ```
+ ~/.vespaembed/projects/my-project/
+ ├── final/              # Final trained model
+ ├── checkpoint-500/     # Training checkpoints
+ ├── checkpoint-1000/
+ └── logs/               # TensorBoard logs
+ ```
+ ```
289
+
290
+ ## Column Aliases
291
+
292
+ VespaEmbed automatically recognizes common column name variations:
293
+
294
+ | Task | Expected | Also Accepts |
295
+ |------|----------|--------------|
296
+ | pairs | `anchor` | `query`, `question`, `sent1`, `sentence1`, `text1` |
297
+ | pairs | `positive` | `document`, `answer`, `pos`, `sent2`, `sentence2`, `text2` |
298
+ | triplets | `negative` | `neg`, `hard_negative`, `sent3`, `sentence3`, `text3` |
299
+ | similarity | `sentence1` | `sent1`, `text1`, `anchor`, `query` |
300
+ | similarity | `sentence2` | `sent2`, `text2`, `positive`, `document` |
301
+ | similarity | `score` | `similarity`, `label`, `sim_score` |
302
+ | tsdae | `text` | `sentence`, `sentences`, `content`, `input` |
303
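
Alias handling can be as simple as a rename pass over each row. The following is an illustrative sketch, not the package's actual loader code (`PAIR_ALIASES` and `normalize_columns` are hypothetical names):

```python
# Hypothetical sketch of column-alias resolution for the pairs task.
PAIR_ALIASES = {
    "anchor": ["query", "question", "sent1", "sentence1", "text1"],
    "positive": ["document", "answer", "pos", "sent2", "sentence2", "text2"],
}

def normalize_columns(row: dict) -> dict:
    """Rename accepted alias columns to the expected column names."""
    out = dict(row)
    for expected, aliases in PAIR_ALIASES.items():
        if expected in out:
            continue  # Expected name already present
        for alias in aliases:
            if alias in out:
                out[expected] = out.pop(alias)
                break
    return out

print(normalize_columns({"query": "What is Python?", "document": "Python is a language."}))
```
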
+
+ ## Development
+
+ ```bash
+ # Install dev dependencies
+ uv sync --extra dev
+
+ # Run tests
+ uv run pytest tests/
+
+ # Run tests with coverage
+ uv run pytest tests/ --cov=vespaembed
+
+ # Format code
+ make format
+
+ # Lint
+ make lint
+ ```
+
+ ## License
+
+ Apache 2.0
vespaembed-0.0.3.dist-info/RECORD ADDED
@@ -0,0 +1,47 @@
+ vespaembed/__init__.py,sha256=4GZKi13lDTD25YBkGakhZyEQZWTER_OWQMNPoH_UM2c,22
+ vespaembed/db.py,sha256=9m46Ljc8e2s8NnhRJRkf3pj-U72PGM3o94pEec8IFoo,4272
+ vespaembed/enums.py,sha256=1ovPAJW-5f4TrI8dnHMJYitau30s5TeMk0s2Ptkppaw,1453
+ vespaembed/worker.py,sha256=2vI292BWRKKK-hBtTQj3fxj5sKmuuXODubYk6yCH5hU,12230
+ vespaembed/cli/__init__.py,sha256=rWKimiyk7YndJfRCEAlZdsDv5l96Cixwl_UX6Nq2vf0,433
+ vespaembed/cli/vespaembed.py,sha256=LcGh9k5fQPjHoDOzf-1dUhcACDxQLXIXpVGmjCylbdI,1616
+ vespaembed/cli/commands/__init__.py,sha256=4lIRTQRXe_QQbtFJDNgeRJ3BZf4bD7FmhEScVWeI6S8,376
+ vespaembed/cli/commands/evaluate.py,sha256=bx74DrSdCqzP6KWXBx_6N1VcFE-vOt9FkFYKkvkTpU8,2595
+ vespaembed/cli/commands/export.py,sha256=SB8cIC1FqHXLKdrqvL2ca_mzJDGe-W9ECdYV9T5z--k,2614
+ vespaembed/cli/commands/info.py,sha256=3n0hz7TmUdKX7ZvgmBnOFtpNtMu3VA_3zujGidGY1G8,1615
+ vespaembed/cli/commands/serve.py,sha256=Qw_MxtsrIwfk45vs9l3cMvCn9_H8MKxUXYOg-Nqf940,1403
+ vespaembed/cli/commands/train.py,sha256=IEYMWr5hrZGAshVljHAJ10Z7Iw8QFiGb54fe39PIRKI,9015
+ vespaembed/core/__init__.py,sha256=NcgBvvomhTlLim8WG84cFnnGNDFKSJolmUrBsbvafvM,145
+ vespaembed/core/config.py,sha256=eERuTr4j-Gpp0TovrRmE7GGECuIWhbUTGzbJ1Kuk8ik,6041
+ vespaembed/core/registry.py,sha256=CrSIQxx6a9QQF7po-YhseMfcHs5IBDd_3LwHpxUACWE,5179
+ vespaembed/core/trainer.py,sha256=9PPMEpH-hOqqO7ZaMqrYlVzHKffAbUIvo3makBXB-0U,22802
+ vespaembed/datasets/__init__.py,sha256=F1QDj036DWAjRQO38u4dK04tzzRtPwGLk_AMbtfjIp8,80
+ vespaembed/datasets/loader.py,sha256=aRasiQxqIpaH7sfbZrp_7h-xc8zHFfgojCQFWVHvsWE,2252
+ vespaembed/datasets/formats/__init__.py,sha256=UDhvPk9FHq8RR_V-uphciNDYH3Ut--DvT7j8DmbQSTQ,235
+ vespaembed/datasets/formats/csv.py,sha256=jcYgaLtNbuE5CFtg9jHha4VekKay67y6LKTXCWPasGA,290
+ vespaembed/datasets/formats/huggingface.py,sha256=v6zY4wTNLOe-Qx-TxJBPwF5QM6WMkcXPDlmd2Dar_9k,865
+ vespaembed/datasets/formats/jsonl.py,sha256=8VsfzZHWU_16z7npBETSyDVqDvzosprW_U9wNdt3exE,512
+ vespaembed/evaluation/__init__.py,sha256=hE8xZ0xG-qPTP1_At5dcZG1p1wwUATFtJZlxhUindmc,91
+ vespaembed/evaluation/factory.py,sha256=_6b_b5Sj279owCWSLp94KPRKcZeUq-tnF9IyxNRBFWM,2715
+ vespaembed/models/__init__.py,sha256=EVHS2Xuvr12JLIm9BA8SXvCsJcSPRcIKGBmuNHFYNrA,140
+ vespaembed/models/export.py,sha256=fdX21e6UEZpHhhKU9gwpUBgYJWRmX9_S6EbWP6-QWwA,2436
+ vespaembed/models/loader.py,sha256=7VWWTpPXeVp9eJq4EsmNdA1zBVYf9BQQRRF1v9ObnJ8,813
+ vespaembed/static/css/styles.css,sha256=yOXm9UW41s4nZpG2CBQA5tw6bHzEBFeUMofb6LdqNWk,37350
+ vespaembed/static/js/app.js,sha256=kri7JVLSNGWludIYj_ZKQLrJ6fzziQlUv5dd0L2NxBI,52969
+ vespaembed/tasks/__init__.py,sha256=cMSdguS9t2mElHfiVpZJ5BcnXtyBQmHnT2I30INZ-gs,883
+ vespaembed/tasks/base.py,sha256=ER2H91HYsKE25tJc6J3TkYkgwVhsZkICRDb5MSXLaek,5105
+ vespaembed/tasks/pairs.py,sha256=VRmY7bu3fNQTPtopq-GvfsHH4g2yVdHL-Nr2KdU4-28,3520
+ vespaembed/tasks/similarity.py,sha256=xk5hGoALNJnelM0n5lRys1gF30KiVdlUplgcxBI0lz8,3037
+ vespaembed/tasks/triplets.py,sha256=ZTfF6yFQS_368LJ-rV4qVQGqrxxGzS0eiGO1f4tqICU,3581
+ vespaembed/tasks/tsdae.py,sha256=VcJbyABee9yfv2leNe_yVMDa48XfronAFcEDDIVrZKo,3434
+ vespaembed/templates/index.html,sha256=z4vIejLAgSOL0ISTJpmyMqSegCew-X3Aj4wv1xV4knY,30185
+ vespaembed/utils/__init__.py,sha256=ZvKPCgX-ISxoj8LEwkMZ0BgVK3MZcE3TExkNBOkRhfk,74
+ vespaembed/utils/logging.py,sha256=xxbMAohRsmfB5TZX-43iX488_2p-KBmttdBRuCyA4dE,1736
+ vespaembed/web/__init__.py,sha256=rFAxCq3VJ7UDoLoYVZPshpAoetYV9JPlKpJdHREC_PM,37
+ vespaembed/web/app.py,sha256=0GdPAWs5y2bhJv6XCmZMdf6-LMnZrLA5RHxtylbxu_4,19380
+ vespaembed/web/api/__init__.py,sha256=GLiedM-kLCaQTZCa_UmkqIrjEGjwvGdDdfKpZhUFgN0,16
+ vespaembed-0.0.3.dist-info/licenses/LICENSE,sha256=QwcOLU5TJoTeUhuIXzhdCEEDDvorGiC6-3YTOl4TecE,11356
+ vespaembed-0.0.3.dist-info/METADATA,sha256=uJW0Jy-2FKMgZHfjjG-tR7fsVtyvvTeaFc1avthbK7c,8655
+ vespaembed-0.0.3.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
+ vespaembed-0.0.3.dist-info/entry_points.txt,sha256=_yO09x95rXw-0RhJkMmXAt5bjBOysVy10jbe4KJuMIg,62
+ vespaembed-0.0.3.dist-info/top_level.txt,sha256=GhqAsJ29dVFnUrvwec8lRqDCC1BqlWntZZFibsgCEdY,11
+ vespaembed-0.0.3.dist-info/RECORD,,
{vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/WHEEL CHANGED
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.10.1)
+ Generator: setuptools (80.10.2)
  Root-Is-Purelib: true
  Tag: py3-none-any
