vespaembed 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vespaembed/__init__.py +1 -1
- vespaembed/cli/__init__.py +17 -0
- vespaembed/cli/commands/__init__.py +7 -0
- vespaembed/cli/commands/evaluate.py +85 -0
- vespaembed/cli/commands/export.py +86 -0
- vespaembed/cli/commands/info.py +52 -0
- vespaembed/cli/commands/serve.py +49 -0
- vespaembed/cli/commands/train.py +267 -0
- vespaembed/cli/vespaembed.py +55 -0
- vespaembed/core/__init__.py +2 -0
- vespaembed/core/config.py +164 -0
- vespaembed/core/registry.py +158 -0
- vespaembed/core/trainer.py +573 -0
- vespaembed/datasets/__init__.py +3 -0
- vespaembed/datasets/formats/__init__.py +5 -0
- vespaembed/datasets/formats/csv.py +15 -0
- vespaembed/datasets/formats/huggingface.py +34 -0
- vespaembed/datasets/formats/jsonl.py +26 -0
- vespaembed/datasets/loader.py +80 -0
- vespaembed/db.py +176 -0
- vespaembed/enums.py +58 -0
- vespaembed/evaluation/__init__.py +3 -0
- vespaembed/evaluation/factory.py +86 -0
- vespaembed/models/__init__.py +4 -0
- vespaembed/models/export.py +89 -0
- vespaembed/models/loader.py +25 -0
- vespaembed/static/css/styles.css +1800 -0
- vespaembed/static/js/app.js +1485 -0
- vespaembed/tasks/__init__.py +23 -0
- vespaembed/tasks/base.py +144 -0
- vespaembed/tasks/pairs.py +91 -0
- vespaembed/tasks/similarity.py +84 -0
- vespaembed/tasks/triplets.py +90 -0
- vespaembed/tasks/tsdae.py +102 -0
- vespaembed/templates/index.html +544 -0
- vespaembed/utils/__init__.py +3 -0
- vespaembed/utils/logging.py +69 -0
- vespaembed/web/__init__.py +1 -0
- vespaembed/web/api/__init__.py +1 -0
- vespaembed/web/app.py +605 -0
- vespaembed/worker.py +313 -0
- vespaembed-0.0.3.dist-info/METADATA +325 -0
- vespaembed-0.0.3.dist-info/RECORD +47 -0
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/WHEEL +1 -1
- vespaembed-0.0.1.dist-info/METADATA +0 -20
- vespaembed-0.0.1.dist-info/RECORD +0 -7
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/entry_points.txt +0 -0
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/licenses/LICENSE +0 -0
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from datasets import Dataset
|
|
5
|
+
|
|
6
|
+
from vespaembed.datasets.formats.csv import load_csv
|
|
7
|
+
from vespaembed.datasets.formats.huggingface import load_hf_dataset
|
|
8
|
+
from vespaembed.datasets.formats.jsonl import load_jsonl
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def load_dataset(
    path: str,
    subset: Optional[str] = None,
    split: Optional[str] = None,
) -> Dataset:
    """Load a dataset from a local file or the HuggingFace Hub.

    Supported sources:
    - CSV files (.csv)
    - JSONL files (.jsonl / .json)
    - HuggingFace datasets (org/dataset-name)

    Args:
        path: Path to a local file, or a HuggingFace dataset name.
        subset: Optional HuggingFace dataset subset.
        split: Optional HuggingFace dataset split (defaults to "train").

    Returns:
        A HuggingFace Dataset object.

    Raises:
        ValueError: If an existing local file has an unsupported extension.
        FileNotFoundError: If the path is neither an existing file nor a
            plausible HuggingFace dataset name.
    """
    local = Path(path)

    # An existing local file always wins over a Hub lookup.
    if local.exists():
        suffix = local.suffix.lower()
        if suffix == ".csv":
            return load_csv(path)
        if suffix in (".jsonl", ".json"):
            return load_jsonl(path)
        raise ValueError(f"Unsupported file format: {suffix}. " "Supported formats: .csv, .jsonl")

    # A "/" (or the absence of a file extension) suggests a Hub dataset name.
    if "/" in path or not local.suffix:
        return load_hf_dataset(path, subset=subset, split=split or "train")

    # Neither an existing file nor something that looks like a Hub dataset.
    raise FileNotFoundError(
        f"File not found: {path}. " "Provide a valid file path or HuggingFace dataset name (e.g., 'org/dataset-name')."
    )
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def preview_dataset(dataset: Dataset, num_samples: int = 5) -> list[dict]:
    """Return the first ``num_samples`` rows of a dataset as dictionaries.

    Args:
        dataset: Dataset to preview.
        num_samples: Maximum number of samples to return; capped at the
            dataset length, so short datasets are handled gracefully.

    Returns:
        List of sample dictionaries.
    """
    limit = min(num_samples, len(dataset))
    return [dataset[index] for index in range(limit)]
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def get_columns(dataset: Dataset) -> list[str]:
    """Return the column names of a dataset.

    Args:
        dataset: Dataset to inspect.

    Returns:
        List of column names, as reported by the dataset itself.
    """
    return dataset.column_names
|
vespaembed/db.py
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import shutil
|
|
3
|
+
import sqlite3
|
|
4
|
+
from datetime import datetime
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
from vespaembed.enums import RunStatus
|
|
9
|
+
|
|
10
|
+
# Default database location: a per-user SQLite file under ~/.vespaembed.
# The directory is created lazily by get_db_path(), not here.
DEFAULT_DB_DIR = Path.home() / ".vespaembed"
DEFAULT_DB_PATH = DEFAULT_DB_DIR / "vespaembed.db"
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_db_path() -> Path:
    """Return the SQLite database path, creating its directory if needed."""
    # Ensure ~/.vespaembed exists before anyone tries to open the DB file.
    DEFAULT_DB_DIR.mkdir(parents=True, exist_ok=True)
    return DEFAULT_DB_PATH
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def get_connection() -> sqlite3.Connection:
    """Open a connection to the vespaembed database.

    Rows come back as sqlite3.Row so callers can convert them to dicts;
    check_same_thread=False permits use from worker threads.
    """
    connection = sqlite3.connect(get_db_path(), check_same_thread=False)
    connection.row_factory = sqlite3.Row
    return connection
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def init_db():
    """Create the ``runs`` table if it does not already exist."""
    conn = get_connection()
    # Connection.execute is a shortcut that creates the cursor internally.
    conn.execute("""
        CREATE TABLE IF NOT EXISTS runs (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            status TEXT NOT NULL DEFAULT 'pending',
            pid INTEGER,
            config TEXT,
            project_name TEXT,
            output_dir TEXT,
            created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            error_message TEXT
        )
    """)
    conn.commit()
    conn.close()
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def create_run(config: dict, project_name: str, output_dir: str) -> int:
    """Insert a new training run in the ``pending`` state.

    Args:
        config: Training configuration; stored as a JSON string.
        project_name: Human-readable project name.
        output_dir: Directory where run artifacts will be written.

    Returns:
        The row ID of the newly created run.
    """
    conn = get_connection()
    cursor = conn.execute(
        """
        INSERT INTO runs (status, config, project_name, output_dir)
        VALUES (?, ?, ?, ?)
        """,
        (RunStatus.PENDING.value, json.dumps(config), project_name, output_dir),
    )
    new_id = cursor.lastrowid
    conn.commit()
    conn.close()
    return new_id
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def update_run_status(
    run_id: int,
    status: RunStatus,
    pid: Optional[int] = None,
    error_message: Optional[str] = None,
):
    """Update a run's status and, optionally, its PID and/or error message.

    Fix over the previous version: ``pid`` and ``error_message`` can now be
    recorded in the same call. Previously the if/elif ladder silently
    dropped ``error_message`` whenever ``pid`` was also provided.

    Args:
        run_id: ID of the run to update.
        status: New status value.
        pid: Worker process ID, if known.
        error_message: Failure details, if the run errored.
    """
    # Build the SET clause dynamically so every provided field is persisted.
    # Column names are hard-coded here; only values go through placeholders.
    assignments = ["status = ?"]
    params: list = [status.value]
    if pid is not None:
        assignments.append("pid = ?")
        params.append(pid)
    if error_message is not None:
        assignments.append("error_message = ?")
        params.append(error_message)
    assignments.append("updated_at = ?")
    params.append(datetime.now())
    params.append(run_id)

    conn = get_connection()
    cursor = conn.cursor()
    cursor.execute(
        f"UPDATE runs SET {', '.join(assignments)} WHERE id = ?",
        params,
    )
    conn.commit()
    conn.close()
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def get_run(run_id: int) -> Optional[dict]:
    """Fetch a single run by ID, or None if no such run exists."""
    conn = get_connection()
    row = conn.execute("SELECT * FROM runs WHERE id = ?", (run_id,)).fetchone()
    conn.close()
    return dict(row) if row else None
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def get_all_runs() -> list[dict]:
    """Return every run as a dict, newest first."""
    conn = get_connection()
    rows = conn.execute("SELECT * FROM runs ORDER BY created_at DESC").fetchall()
    conn.close()
    return [dict(record) for record in rows]
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def get_active_run() -> Optional[dict]:
    """Return the most recent run that is still pending or running, if any."""
    conn = get_connection()
    row = conn.execute(
        """
        SELECT * FROM runs
        WHERE status IN (?, ?)
        ORDER BY created_at DESC
        LIMIT 1
        """,
        (RunStatus.PENDING.value, RunStatus.RUNNING.value),
    ).fetchone()
    conn.close()
    return dict(row) if row else None
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
def delete_run(run_id: int, delete_files: bool = True):
    """Delete a run row and, optionally, its on-disk output directory.

    Args:
        run_id: ID of the run to delete.
        delete_files: When True, also remove the run's output directory.
    """
    run = get_run(run_id)

    # Remove artifacts first; ignore_errors keeps half-deleted trees from
    # aborting the row deletion below.
    if run and delete_files and run.get("output_dir"):
        artifacts = Path(run["output_dir"])
        if artifacts.exists():
            shutil.rmtree(artifacts, ignore_errors=True)

    conn = get_connection()
    conn.execute("DELETE FROM runs WHERE id = ?", (run_id,))
    conn.commit()
    conn.close()
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
# Initialize database on import so callers never see a missing schema.
# NOTE(review): import-time side effect — touches ~/.vespaembed on first
# import of this module; confirm this is intended for library consumers.
init_db()
|
vespaembed/enums.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class TaskType(str, Enum):
    """Supported training tasks, keyed by the shape of the training data.

    - pairs: (anchor, positive) text pairs for semantic search
    - triplets: (anchor, positive, negative) with hard negatives
    - similarity: text pairs annotated with similarity scores
    - tsdae: unlabeled text for unsupervised learning

    Note: Matryoshka is a training option (--matryoshka flag), not a
    separate task.
    """

    PAIRS = "pairs"
    TRIPLETS = "triplets"
    SIMILARITY = "similarity"
    TSDAE = "tsdae"
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class LossVariant(str, Enum):
    """Loss function variants, grouped by the task they apply to.

    Pairs task:
    - mnr: MultipleNegativesRankingLoss (default, recommended)
    - mnr_symmetric: bidirectional ranking
    - gist: GISTEmbedLoss with guide model
    - cached_mnr: cached variant for larger batches
    - cached_gist: cached GIST

    Similarity task:
    - cosine: CosineSimilarityLoss (default)
    - cosent: CoSENTLoss
    - angle: AnglELoss
    """

    # Pairs task variants
    MNR = "mnr"
    MNR_SYMMETRIC = "mnr_symmetric"
    GIST = "gist"
    CACHED_MNR = "cached_mnr"
    CACHED_GIST = "cached_gist"

    # Similarity task variants
    COSINE = "cosine"
    COSENT = "cosent"
    ANGLE = "angle"
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class RunStatus(str, Enum):
    """Lifecycle states of a training run."""

    PENDING = "pending"
    RUNNING = "running"
    COMPLETED = "completed"
    STOPPED = "stopped"
    ERROR = "error"
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
from typing import Any, Optional
|
|
2
|
+
|
|
3
|
+
from datasets import Dataset
|
|
4
|
+
from sentence_transformers.evaluation import (
|
|
5
|
+
BinaryClassificationEvaluator,
|
|
6
|
+
EmbeddingSimilarityEvaluator,
|
|
7
|
+
InformationRetrievalEvaluator,
|
|
8
|
+
TripletEvaluator,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def create_evaluator(
    task: str,
    eval_dataset: Dataset,
    name: str = "eval",
) -> Optional[Any]:
    """Create an appropriate evaluator based on task type.

    Args:
        task: Task name (mnr, triplet, contrastive, sts, nli, tsdae, matryoshka)
        eval_dataset: Prepared evaluation dataset
        name: Evaluator name

    Returns:
        Evaluator instance, or None for unknown tasks and for tsdae
        (TSDAE has no intrinsic evaluator).
    """
    # Dispatch table instead of an if/elif chain; absent keys (including
    # "tsdae") map to None, matching the original fall-through behavior.
    factories = {
        "mnr": _create_ir_evaluator,
        "matryoshka": _create_ir_evaluator,
        "triplet": _create_triplet_evaluator,
        "contrastive": _create_binary_evaluator,
        "sts": _create_similarity_evaluator,
        "nli": _create_similarity_evaluator,
    }
    factory = factories.get(task)
    if factory is None:
        return None
    return factory(eval_dataset, name)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _create_ir_evaluator(dataset: Dataset, name: str) -> InformationRetrievalEvaluator:
    """Build an IR evaluator where each anchor's only relevant doc is its positive."""
    ids = [str(i) for i in range(len(dataset))]
    queries = dict(zip(ids, dataset["anchor"]))
    corpus = dict(zip(ids, dataset["positive"]))
    relevant_docs = {doc_id: {doc_id} for doc_id in ids}

    return InformationRetrievalEvaluator(
        queries=queries,
        corpus=corpus,
        relevant_docs=relevant_docs,
        name=name,
    )
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _create_triplet_evaluator(dataset: Dataset, name: str) -> TripletEvaluator:
    """Build a triplet evaluator from the anchor/positive/negative columns."""
    return TripletEvaluator(
        anchors=dataset["anchor"],
        positives=dataset["positive"],
        negatives=dataset["negative"],
        name=name,
    )
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def _create_binary_evaluator(dataset: Dataset, name: str) -> BinaryClassificationEvaluator:
    """Build a binary-classification evaluator from sentence pairs and labels."""
    return BinaryClassificationEvaluator(
        sentences1=dataset["sentence1"],
        sentences2=dataset["sentence2"],
        labels=dataset["label"],
        name=name,
    )
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def _create_similarity_evaluator(dataset: Dataset, name: str) -> EmbeddingSimilarityEvaluator:
    """Build an embedding-similarity evaluator from sentence pairs and scores."""
    return EmbeddingSimilarityEvaluator(
        sentences1=dataset["sentence1"],
        sentences2=dataset["sentence2"],
        scores=dataset["score"],
        name=name,
    )
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from sentence_transformers import SentenceTransformer
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def export_model(
    model: SentenceTransformer,
    output_path: str,
    format: str = "onnx",
) -> str:
    """Export a model to a different format.

    Args:
        model: SentenceTransformer model.
        output_path: Output directory or file path.
        format: Export format; only "onnx" is currently supported.

    Returns:
        Path to the exported model.

    Raises:
        ValueError: If the requested format is not supported.
    """
    destination = Path(output_path)

    # Guard clause: reject unknown formats up front.
    if format.lower() != "onnx":
        raise ValueError(f"Unsupported export format: {format}. Supported: onnx")
    return _export_onnx(model, destination)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _export_onnx(model: SentenceTransformer, output_path: Path) -> str:
    """Export model to ONNX format.

    Args:
        model: SentenceTransformer model
        output_path: Output directory

    Returns:
        Path to ONNX model

    NOTE(review): the primary path below passes ``model_name_or_path`` to
    ``SentenceTransformer.save``, which does not appear to be a parameter of
    that method — the call likely raises TypeError and always falls through
    to the fallback. Confirm against the pinned sentence-transformers version.
    NOTE(review): the fallback saves PyTorch weights + tokenizer, not ONNX,
    so ``model.onnx`` is never written and the returned path may not exist.
    """
    # Probe for the onnx package purely as a dependency check.
    try:
        import onnx  # noqa: F401
    except ImportError:
        raise ImportError("ONNX not installed. Install with: pip install vespaembed[onnx]")

    output_path.mkdir(parents=True, exist_ok=True)
    onnx_path = output_path / "model.onnx"

    # Use sentence-transformers built-in ONNX export if available
    # Otherwise fall back to manual export
    try:
        model.save(str(output_path), model_name_or_path="model.onnx", create_model_card=False)
    except Exception:
        # Manual export via transformers
        from transformers import AutoTokenizer

        # Reload the tokenizer from the underlying HF model's original name/path.
        tokenizer = AutoTokenizer.from_pretrained(model[0].auto_model.config._name_or_path)

        # Export the transformer part
        model[0].auto_model.save_pretrained(output_path)
        tokenizer.save_pretrained(output_path)

    return str(onnx_path)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def push_to_hub(
    model: SentenceTransformer,
    repo_id: str,
    commit_message: Optional[str] = None,
    private: bool = False,
) -> str:
    """Push a model to the HuggingFace Hub.

    Args:
        model: SentenceTransformer model.
        repo_id: Repository ID (e.g., "username/model-name").
        commit_message: Commit message; a default is used when falsy.
        private: Whether to create a private repository.

    Returns:
        URL of the model on HuggingFace Hub.
    """
    message = commit_message if commit_message else "Upload model via vespaembed"
    return model.push_to_hub(
        repo_id=repo_id,
        commit_message=message,
        private=private,
    )
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
from sentence_transformers import SentenceTransformer
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def load_model(model_name_or_path: str, use_unsloth: bool = False) -> SentenceTransformer:
    """Load a sentence transformer model.

    Args:
        model_name_or_path: Model name from HuggingFace Hub or local path
        use_unsloth: Whether to use Unsloth for faster inference

    Returns:
        SentenceTransformer model

    Raises:
        ImportError: If use_unsloth is True but unsloth is not installed.
    """
    if use_unsloth:
        # Keep the try body to the import alone: previously from_pretrained
        # was inside the try, so ImportErrors raised within unsloth itself
        # were misreported as "Unsloth not installed". Chain the cause so a
        # broken (vs. missing) install stays diagnosable from the traceback.
        try:
            from unsloth import FastSentenceTransformer
        except ImportError as err:
            raise ImportError(
                "Unsloth not installed. Install with: pip install vespaembed[unsloth]"
            ) from err

        return FastSentenceTransformer.from_pretrained(
            model_name_or_path,
            for_inference=True,
        )

    return SentenceTransformer(model_name_or_path)
|