vespaembed-0.0.1-py3-none-any.whl → vespaembed-0.0.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. vespaembed/__init__.py +1 -1
  2. vespaembed/cli/__init__.py +17 -0
  3. vespaembed/cli/commands/__init__.py +7 -0
  4. vespaembed/cli/commands/evaluate.py +85 -0
  5. vespaembed/cli/commands/export.py +86 -0
  6. vespaembed/cli/commands/info.py +52 -0
  7. vespaembed/cli/commands/serve.py +49 -0
  8. vespaembed/cli/commands/train.py +267 -0
  9. vespaembed/cli/vespaembed.py +55 -0
  10. vespaembed/core/__init__.py +2 -0
  11. vespaembed/core/config.py +164 -0
  12. vespaembed/core/registry.py +158 -0
  13. vespaembed/core/trainer.py +573 -0
  14. vespaembed/datasets/__init__.py +3 -0
  15. vespaembed/datasets/formats/__init__.py +5 -0
  16. vespaembed/datasets/formats/csv.py +15 -0
  17. vespaembed/datasets/formats/huggingface.py +34 -0
  18. vespaembed/datasets/formats/jsonl.py +26 -0
  19. vespaembed/datasets/loader.py +80 -0
  20. vespaembed/db.py +176 -0
  21. vespaembed/enums.py +58 -0
  22. vespaembed/evaluation/__init__.py +3 -0
  23. vespaembed/evaluation/factory.py +86 -0
  24. vespaembed/models/__init__.py +4 -0
  25. vespaembed/models/export.py +89 -0
  26. vespaembed/models/loader.py +25 -0
  27. vespaembed/static/css/styles.css +1800 -0
  28. vespaembed/static/js/app.js +1485 -0
  29. vespaembed/tasks/__init__.py +23 -0
  30. vespaembed/tasks/base.py +144 -0
  31. vespaembed/tasks/pairs.py +91 -0
  32. vespaembed/tasks/similarity.py +84 -0
  33. vespaembed/tasks/triplets.py +90 -0
  34. vespaembed/tasks/tsdae.py +102 -0
  35. vespaembed/templates/index.html +544 -0
  36. vespaembed/utils/__init__.py +3 -0
  37. vespaembed/utils/logging.py +69 -0
  38. vespaembed/web/__init__.py +1 -0
  39. vespaembed/web/api/__init__.py +1 -0
  40. vespaembed/web/app.py +605 -0
  41. vespaembed/worker.py +313 -0
  42. vespaembed-0.0.2.dist-info/METADATA +325 -0
  43. vespaembed-0.0.2.dist-info/RECORD +47 -0
  44. {vespaembed-0.0.1.dist-info → vespaembed-0.0.2.dist-info}/WHEEL +1 -1
  45. vespaembed-0.0.1.dist-info/METADATA +0 -20
  46. vespaembed-0.0.1.dist-info/RECORD +0 -7
  47. {vespaembed-0.0.1.dist-info → vespaembed-0.0.2.dist-info}/entry_points.txt +0 -0
  48. {vespaembed-0.0.1.dist-info → vespaembed-0.0.2.dist-info}/licenses/LICENSE +0 -0
  49. {vespaembed-0.0.1.dist-info → vespaembed-0.0.2.dist-info}/top_level.txt +0 -0
vespaembed/datasets/loader.py ADDED
@@ -0,0 +1,80 @@
+ from pathlib import Path
+ from typing import Optional
+
+ from datasets import Dataset
+
+ from vespaembed.datasets.formats.csv import load_csv
+ from vespaembed.datasets.formats.huggingface import load_hf_dataset
+ from vespaembed.datasets.formats.jsonl import load_jsonl
+
+
+ def load_dataset(
+     path: str,
+     subset: Optional[str] = None,
+     split: Optional[str] = None,
+ ) -> Dataset:
+     """Load a dataset from various sources.
+
+     Supports:
+     - CSV files (.csv)
+     - JSONL files (.jsonl)
+     - HuggingFace datasets (org/dataset-name)
+
+     Args:
+         path: Path to file or HuggingFace dataset name
+         subset: HuggingFace dataset subset (optional)
+         split: HuggingFace dataset split (optional, defaults to "train")
+
+     Returns:
+         HuggingFace Dataset object
+
+     Raises:
+         ValueError: If file format is not supported
+         FileNotFoundError: If file does not exist
+     """
+     path_obj = Path(path)
+
+     # Check if it's a local file
+     if path_obj.exists():
+         suffix = path_obj.suffix.lower()
+
+         if suffix == ".csv":
+             return load_csv(path)
+         elif suffix in (".jsonl", ".json"):
+             return load_jsonl(path)
+         else:
+             raise ValueError(f"Unsupported file format: {suffix}. Supported formats: .csv, .jsonl")
+
+     # Check if it looks like a HuggingFace dataset
+     if "/" in path or not path_obj.suffix:
+         return load_hf_dataset(path, subset=subset, split=split or "train")
+
+     # File doesn't exist and doesn't look like HF dataset
+     raise FileNotFoundError(
+         f"File not found: {path}. Provide a valid file path or HuggingFace dataset name (e.g., 'org/dataset-name')."
+     )
+
+
+ def preview_dataset(dataset: Dataset, num_samples: int = 5) -> list[dict]:
+     """Preview a dataset by returning the first N samples.
+
+     Args:
+         dataset: Dataset to preview
+         num_samples: Number of samples to return
+
+     Returns:
+         List of sample dictionaries
+     """
+     return [dataset[i] for i in range(min(num_samples, len(dataset)))]
+
+
+ def get_columns(dataset: Dataset) -> list[str]:
+     """Get column names from a dataset.
+
+     Args:
+         dataset: Dataset to inspect
+
+     Returns:
+         List of column names
+     """
+     return dataset.column_names
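
For orientation, a minimal usage sketch of the new loader (not part of the package diff; the CSV path is hypothetical, and any public Hub dataset with matching columns would do):

    from vespaembed.datasets.loader import get_columns, load_dataset, preview_dataset

    # Local file: dispatch is purely by suffix (.csv, or .jsonl/.json)
    ds = load_dataset("data/train.csv")  # hypothetical path

    # Hub dataset: anything containing "/" (or with no suffix) is treated as a
    # HuggingFace dataset name and routed through load_hf_dataset
    ds = load_dataset("sentence-transformers/all-nli", subset="pair", split="train")

    print(get_columns(ds))         # e.g. ["anchor", "positive"]
    print(preview_dataset(ds, 3))  # first three rows as plain dicts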
vespaembed/db.py ADDED
@@ -0,0 +1,176 @@
+ import json
+ import shutil
+ import sqlite3
+ from datetime import datetime
+ from pathlib import Path
+ from typing import Optional
+
+ from vespaembed.enums import RunStatus
+
+ # Default database location
+ DEFAULT_DB_DIR = Path.home() / ".vespaembed"
+ DEFAULT_DB_PATH = DEFAULT_DB_DIR / "vespaembed.db"
+
+
+ def get_db_path() -> Path:
+     """Get the database path, creating directory if needed."""
+     DEFAULT_DB_DIR.mkdir(parents=True, exist_ok=True)
+     return DEFAULT_DB_PATH
+
+
+ def get_connection() -> sqlite3.Connection:
+     """Get a database connection."""
+     conn = sqlite3.connect(get_db_path(), check_same_thread=False)
+     conn.row_factory = sqlite3.Row
+     return conn
+
+
+ def init_db():
+     """Initialize the database schema."""
+     conn = get_connection()
+     cursor = conn.cursor()
+
+     cursor.execute("""
+         CREATE TABLE IF NOT EXISTS runs (
+             id INTEGER PRIMARY KEY AUTOINCREMENT,
+             status TEXT NOT NULL DEFAULT 'pending',
+             pid INTEGER,
+             config TEXT,
+             project_name TEXT,
+             output_dir TEXT,
+             created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+             updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
+             error_message TEXT
+         )
+     """)
+
+     conn.commit()
+     conn.close()
+
+
+ def create_run(config: dict, project_name: str, output_dir: str) -> int:
+     """Create a new training run."""
+     conn = get_connection()
+     cursor = conn.cursor()
+
+     cursor.execute(
+         """
+         INSERT INTO runs (status, config, project_name, output_dir)
+         VALUES (?, ?, ?, ?)
+         """,
+         (RunStatus.PENDING.value, json.dumps(config), project_name, output_dir),
+     )
+
+     run_id = cursor.lastrowid
+     conn.commit()
+     conn.close()
+
+     return run_id
+
+
+ def update_run_status(
+     run_id: int,
+     status: RunStatus,
+     pid: Optional[int] = None,
+     error_message: Optional[str] = None,
+ ):
+     """Update a run's status."""
+     conn = get_connection()
+     cursor = conn.cursor()
+
+     if pid is not None:
+         cursor.execute(
+             """
+             UPDATE runs SET status = ?, pid = ?, updated_at = ?
+             WHERE id = ?
+             """,
+             (status.value, pid, datetime.now(), run_id),
+         )
+     elif error_message is not None:
+         cursor.execute(
+             """
+             UPDATE runs SET status = ?, error_message = ?, updated_at = ?
+             WHERE id = ?
+             """,
+             (status.value, error_message, datetime.now(), run_id),
+         )
+     else:
+         cursor.execute(
+             """
+             UPDATE runs SET status = ?, updated_at = ?
+             WHERE id = ?
+             """,
+             (status.value, datetime.now(), run_id),
+         )
+
+     conn.commit()
+     conn.close()
+
+
+ def get_run(run_id: int) -> Optional[dict]:
+     """Get a run by ID."""
+     conn = get_connection()
+     cursor = conn.cursor()
+
+     cursor.execute("SELECT * FROM runs WHERE id = ?", (run_id,))
+     row = cursor.fetchone()
+     conn.close()
+
+     if row:
+         return dict(row)
+     return None
+
+
+ def get_all_runs() -> list[dict]:
+     """Get all runs."""
+     conn = get_connection()
+     cursor = conn.cursor()
+
+     cursor.execute("SELECT * FROM runs ORDER BY created_at DESC")
+     rows = cursor.fetchall()
+     conn.close()
+
+     return [dict(row) for row in rows]
+
+
+ def get_active_run() -> Optional[dict]:
+     """Get the currently active (running or pending) run."""
+     conn = get_connection()
+     cursor = conn.cursor()
+
+     cursor.execute(
+         """
+         SELECT * FROM runs
+         WHERE status IN (?, ?)
+         ORDER BY created_at DESC
+         LIMIT 1
+         """,
+         (RunStatus.PENDING.value, RunStatus.RUNNING.value),
+     )
+
+     row = cursor.fetchone()
+     conn.close()
+
+     if row:
+         return dict(row)
+     return None
+
+
+ def delete_run(run_id: int, delete_files: bool = True):
+     """Delete a run and optionally its output files."""
+     run = get_run(run_id)
+
+     if run and delete_files and run.get("output_dir"):
+         output_path = Path(run["output_dir"])
+         if output_path.exists():
+             shutil.rmtree(output_path, ignore_errors=True)
+
+     conn = get_connection()
+     cursor = conn.cursor()
+     cursor.execute("DELETE FROM runs WHERE id = ?", (run_id,))
+     conn.commit()
+     conn.close()
+
+
+ # Initialize database on import
+ init_db()
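
The helpers above imply a simple run lifecycle. A sketch of that flow (illustrative only, not part of the diff; the config dict and PID are hypothetical placeholders):

    from vespaembed.db import create_run, get_active_run, get_run, update_run_status
    from vespaembed.enums import RunStatus

    # A new run starts in 'pending' (see the INSERT in create_run)
    run_id = create_run(
        config={"model": "BAAI/bge-small-en-v1.5"},  # hypothetical config
        project_name="demo",
        output_dir="outputs/demo",
    )

    # The worker records its PID once training actually starts
    update_run_status(run_id, RunStatus.RUNNING, pid=12345)
    assert get_active_run()["id"] == run_id

    update_run_status(run_id, RunStatus.COMPLETED)
    print(get_run(run_id)["status"])  # "completed"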
vespaembed/enums.py ADDED
@@ -0,0 +1,58 @@
+ from enum import Enum
+
+
+ class TaskType(str, Enum):
+     """Supported training tasks.
+
+     Tasks are organized by data format:
+     - pairs: Text pairs for semantic search (anchor, positive)
+     - triplets: Text triplets with hard negatives (anchor, positive, negative)
+     - similarity: Text pairs with similarity scores
+     - tsdae: Unlabeled text for unsupervised learning
+
+     Note: Matryoshka is a training option (--matryoshka flag), not a separate task.
+     """
+
+     PAIRS = "pairs"
+     TRIPLETS = "triplets"
+     SIMILARITY = "similarity"
+     TSDAE = "tsdae"
+
+
+ class LossVariant(str, Enum):
+     """Available loss function variants.
+
+     For pairs task:
+     - mnr: MultipleNegativesRankingLoss (default, recommended)
+     - mnr_symmetric: Bidirectional ranking
+     - gist: GISTEmbedLoss with guide model
+     - cached_mnr: Cached version for larger batches
+     - cached_gist: Cached GIST
+
+     For similarity task:
+     - cosine: CosineSimilarityLoss (default)
+     - cosent: CoSENTLoss
+     - angle: AnglELoss
+     """
+
+     # Pairs task variants
+     MNR = "mnr"
+     MNR_SYMMETRIC = "mnr_symmetric"
+     GIST = "gist"
+     CACHED_MNR = "cached_mnr"
+     CACHED_GIST = "cached_gist"
+
+     # Similarity task variants
+     COSINE = "cosine"
+     COSENT = "cosent"
+     ANGLE = "angle"
+
+
+ class RunStatus(str, Enum):
+     """Training run status."""
+
+     PENDING = "pending"
+     RUNNING = "running"
+     COMPLETED = "completed"
+     STOPPED = "stopped"
+     ERROR = "error"
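
Because all three enums subclass str, values parsed from CLI flags or stored in the SQLite status column round-trip directly; a quick sketch (not part of the diff):

    from vespaembed.enums import LossVariant, RunStatus, TaskType

    task = TaskType("pairs")               # parse a CLI argument
    assert task == "pairs"                 # str subclass: equal to the raw string
    assert LossVariant.MNR.value == "mnr"
    assert RunStatus("error") is RunStatus.ERROR  # reverse lookup from a DB row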
vespaembed/evaluation/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from vespaembed.evaluation.factory import create_evaluator
+
+ __all__ = ["create_evaluator"]
vespaembed/evaluation/factory.py ADDED
@@ -0,0 +1,86 @@
+ from typing import Any, Optional
+
+ from datasets import Dataset
+ from sentence_transformers.evaluation import (
+     BinaryClassificationEvaluator,
+     EmbeddingSimilarityEvaluator,
+     InformationRetrievalEvaluator,
+     TripletEvaluator,
+ )
+
+
+ def create_evaluator(
+     task: str,
+     eval_dataset: Dataset,
+     name: str = "eval",
+ ) -> Optional[Any]:
+     """Create an appropriate evaluator based on task type.
+
+     Args:
+         task: Task name (mnr, triplet, contrastive, sts, nli, tsdae, matryoshka)
+         eval_dataset: Prepared evaluation dataset
+         name: Evaluator name
+
+     Returns:
+         Evaluator instance or None
+     """
+     if task == "mnr":
+         return _create_ir_evaluator(eval_dataset, name)
+     elif task == "triplet":
+         return _create_triplet_evaluator(eval_dataset, name)
+     elif task == "contrastive":
+         return _create_binary_evaluator(eval_dataset, name)
+     elif task == "sts":
+         return _create_similarity_evaluator(eval_dataset, name)
+     elif task == "nli":
+         return _create_similarity_evaluator(eval_dataset, name)
+     elif task == "tsdae":
+         return None  # TSDAE has no intrinsic evaluator
+     elif task == "matryoshka":
+         return _create_ir_evaluator(eval_dataset, name)
+     else:
+         return None
+
+
+ def _create_ir_evaluator(dataset: Dataset, name: str) -> InformationRetrievalEvaluator:
+     """Create Information Retrieval evaluator."""
+     queries = {str(i): text for i, text in enumerate(dataset["anchor"])}
+     corpus = {str(i): text for i, text in enumerate(dataset["positive"])}
+     relevant_docs = {str(i): {str(i)} for i in range(len(dataset))}
+
+     return InformationRetrievalEvaluator(
+         queries=queries,
+         corpus=corpus,
+         relevant_docs=relevant_docs,
+         name=name,
+     )
+
+
+ def _create_triplet_evaluator(dataset: Dataset, name: str) -> TripletEvaluator:
+     """Create Triplet evaluator."""
+     return TripletEvaluator(
+         anchors=dataset["anchor"],
+         positives=dataset["positive"],
+         negatives=dataset["negative"],
+         name=name,
+     )
+
+
+ def _create_binary_evaluator(dataset: Dataset, name: str) -> BinaryClassificationEvaluator:
+     """Create Binary Classification evaluator."""
+     return BinaryClassificationEvaluator(
+         sentences1=dataset["sentence1"],
+         sentences2=dataset["sentence2"],
+         labels=dataset["label"],
+         name=name,
+     )
+
+
+ def _create_similarity_evaluator(dataset: Dataset, name: str) -> EmbeddingSimilarityEvaluator:
+     """Create Embedding Similarity evaluator."""
+     return EmbeddingSimilarityEvaluator(
+         sentences1=dataset["sentence1"],
+         sentences2=dataset["sentence2"],
+         scores=dataset["score"],
+         name=name,
+     )
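
A sketch of wiring a factory-built evaluator to a model (illustrative, not part of the diff; the two-row dataset and model choice are placeholders, and depending on the installed sentence-transformers version calling the evaluator returns either the primary metric or a dict of metrics):

    from datasets import Dataset
    from sentence_transformers import SentenceTransformer
    from vespaembed.evaluation import create_evaluator

    eval_ds = Dataset.from_dict({
        "anchor": ["what is rust?", "capital of france"],
        "positive": [
            "Rust is a systems programming language.",
            "Paris is the capital of France.",
        ],
    })
    evaluator = create_evaluator("mnr", eval_ds, name="dev")  # IR evaluator

    model = SentenceTransformer("all-MiniLM-L6-v2")
    results = evaluator(model)  # e.g. accuracy@k / ndcg@k under the "dev" prefix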
vespaembed/models/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from vespaembed.models.export import export_model
+ from vespaembed.models.loader import load_model
+
+ __all__ = ["export_model", "load_model"]
vespaembed/models/export.py ADDED
@@ -0,0 +1,89 @@
+ from pathlib import Path
+ from typing import Optional
+
+ from sentence_transformers import SentenceTransformer
+
+
+ def export_model(
+     model: SentenceTransformer,
+     output_path: str,
+     format: str = "onnx",
+ ) -> str:
+     """Export a model to a different format.
+
+     Args:
+         model: SentenceTransformer model
+         output_path: Output directory or file path
+         format: Export format ("onnx")
+
+     Returns:
+         Path to exported model
+
+     Raises:
+         ValueError: If format is not supported
+     """
+     output_path = Path(output_path)
+
+     if format.lower() == "onnx":
+         return _export_onnx(model, output_path)
+     else:
+         raise ValueError(f"Unsupported export format: {format}. Supported: onnx")
+
+
+ def _export_onnx(model: SentenceTransformer, output_path: Path) -> str:
+     """Export model to ONNX format.
+
+     Args:
+         model: SentenceTransformer model
+         output_path: Output directory
+
+     Returns:
+         Path to ONNX model
+     """
+     try:
+         import onnx  # noqa: F401
+     except ImportError:
+         raise ImportError("ONNX not installed. Install with: pip install vespaembed[onnx]")
+
+     output_path.mkdir(parents=True, exist_ok=True)
+     onnx_path = output_path / "model.onnx"
+
+     # Use sentence-transformers built-in ONNX export if available
+     # Otherwise fall back to manual export
+     try:
+         model.save(str(output_path), model_name_or_path="model.onnx", create_model_card=False)
+     except Exception:
+         # Manual export via transformers
+         from transformers import AutoTokenizer
+
+         tokenizer = AutoTokenizer.from_pretrained(model[0].auto_model.config._name_or_path)
+
+         # Export the transformer part
+         model[0].auto_model.save_pretrained(output_path)
+         tokenizer.save_pretrained(output_path)
+
+     return str(onnx_path)
+
+
+ def push_to_hub(
+     model: SentenceTransformer,
+     repo_id: str,
+     commit_message: Optional[str] = None,
+     private: bool = False,
+ ) -> str:
+     """Push model to HuggingFace Hub.
+
+     Args:
+         model: SentenceTransformer model
+         repo_id: Repository ID (e.g., "username/model-name")
+         commit_message: Commit message
+         private: Whether to create a private repository
+
+     Returns:
+         URL of the model on HuggingFace Hub
+     """
+     return model.push_to_hub(
+         repo_id=repo_id,
+         commit_message=commit_message or "Upload model via vespaembed",
+         private=private,
+     )
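
Usage sketch for the exporter (not part of the diff; paths and repo ID are hypothetical, and the onnx extra must be installed for _export_onnx to get past its import check):

    from sentence_transformers import SentenceTransformer
    from vespaembed.models.export import export_model, push_to_hub

    model = SentenceTransformer("all-MiniLM-L6-v2")
    onnx_path = export_model(model, "exports/minilm", format="onnx")

    # Publishing requires a configured HuggingFace token
    # url = push_to_hub(model, "username/my-embedder", private=True)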
vespaembed/models/loader.py ADDED
@@ -0,0 +1,25 @@
+ from sentence_transformers import SentenceTransformer
+
+
+ def load_model(model_name_or_path: str, use_unsloth: bool = False) -> SentenceTransformer:
+     """Load a sentence transformer model.
+
+     Args:
+         model_name_or_path: Model name from HuggingFace Hub or local path
+         use_unsloth: Whether to use Unsloth for faster inference
+
+     Returns:
+         SentenceTransformer model
+     """
+     if use_unsloth:
+         try:
+             from unsloth import FastSentenceTransformer
+
+             return FastSentenceTransformer.from_pretrained(
+                 model_name_or_path,
+                 for_inference=True,
+             )
+         except ImportError:
+             raise ImportError("Unsloth not installed. Install with: pip install vespaembed[unsloth]")
+
+     return SentenceTransformer(model_name_or_path)
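
And a final sketch of the model loader in use (not part of the diff; the model name is a common public checkpoint, not something the package pins):

    from vespaembed.models.loader import load_model

    model = load_model("all-MiniLM-L6-v2")      # plain sentence-transformers path
    embeddings = model.encode(["hello world"])  # shape (1, 384) for this checkpoint

    # load_model(..., use_unsloth=True) needs the vespaembed[unsloth] extra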