vespaembed 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
- vespaembed/__init__.py +1 -1
- vespaembed/cli/__init__.py +17 -0
- vespaembed/cli/commands/__init__.py +7 -0
- vespaembed/cli/commands/evaluate.py +85 -0
- vespaembed/cli/commands/export.py +86 -0
- vespaembed/cli/commands/info.py +52 -0
- vespaembed/cli/commands/serve.py +49 -0
- vespaembed/cli/commands/train.py +267 -0
- vespaembed/cli/vespaembed.py +55 -0
- vespaembed/core/__init__.py +2 -0
- vespaembed/core/config.py +164 -0
- vespaembed/core/registry.py +158 -0
- vespaembed/core/trainer.py +573 -0
- vespaembed/datasets/__init__.py +3 -0
- vespaembed/datasets/formats/__init__.py +5 -0
- vespaembed/datasets/formats/csv.py +15 -0
- vespaembed/datasets/formats/huggingface.py +34 -0
- vespaembed/datasets/formats/jsonl.py +26 -0
- vespaembed/datasets/loader.py +80 -0
- vespaembed/db.py +176 -0
- vespaembed/enums.py +58 -0
- vespaembed/evaluation/__init__.py +3 -0
- vespaembed/evaluation/factory.py +86 -0
- vespaembed/models/__init__.py +4 -0
- vespaembed/models/export.py +89 -0
- vespaembed/models/loader.py +25 -0
- vespaembed/static/css/styles.css +1800 -0
- vespaembed/static/js/app.js +1485 -0
- vespaembed/tasks/__init__.py +23 -0
- vespaembed/tasks/base.py +144 -0
- vespaembed/tasks/pairs.py +91 -0
- vespaembed/tasks/similarity.py +84 -0
- vespaembed/tasks/triplets.py +90 -0
- vespaembed/tasks/tsdae.py +102 -0
- vespaembed/templates/index.html +544 -0
- vespaembed/utils/__init__.py +3 -0
- vespaembed/utils/logging.py +69 -0
- vespaembed/web/__init__.py +1 -0
- vespaembed/web/api/__init__.py +1 -0
- vespaembed/web/app.py +605 -0
- vespaembed/worker.py +313 -0
- vespaembed-0.0.3.dist-info/METADATA +325 -0
- vespaembed-0.0.3.dist-info/RECORD +47 -0
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/WHEEL +1 -1
- vespaembed-0.0.1.dist-info/METADATA +0 -20
- vespaembed-0.0.1.dist-info/RECORD +0 -7
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/entry_points.txt +0 -0
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/licenses/LICENSE +0 -0
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/top_level.txt +0 -0
vespaembed/core/trainer.py
@@ -0,0 +1,573 @@
+import json
+import os
+from pathlib import Path
+from typing import Callable, Optional
+
+from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
+from sentence_transformers.training_args import SentenceTransformerTrainingArguments
+from transformers import TrainerCallback
+
+from vespaembed.core.config import TrainingConfig
+from vespaembed.core.registry import Registry
+from vespaembed.datasets.loader import load_dataset
+from vespaembed.utils.logging import logger
+
+# Get HuggingFace token from environment (fallback for loading private models)
+HF_TOKEN_ENV = os.environ.get("HF_TOKEN")
+
+
+class ProgressCallback(TrainerCallback):
+    """Callback to report training progress like tqdm."""
+
+    def __init__(self, callback: Callable[[dict], None]):
+        self.callback = callback
+        self.start_time = None
+        self.total_steps = 0
+        self.total_epochs = 0
+
+    def on_train_begin(self, args, state, control, **kwargs):
+        """Called at the beginning of training."""
+        import time
+
+        self.start_time = time.time()
+        self.total_steps = state.max_steps
+        self.total_epochs = args.num_train_epochs
+
+        if self.callback:
+            self.callback(
+                {
+                    "type": "train_start",
+                    "total_steps": self.total_steps,
+                    "total_epochs": self.total_epochs,
+                }
+            )
+
+    def on_log(self, args, state, control, logs=None, **kwargs):
+        """Called when the trainer logs metrics."""
+        import time
+
+        if logs and self.callback:
+            current_step = state.global_step
+            elapsed = time.time() - self.start_time if self.start_time else 0
+
+            # Calculate progress
+            progress_pct = (current_step / self.total_steps * 100) if self.total_steps > 0 else 0
+            steps_per_sec = current_step / elapsed if elapsed > 0 else 0
+            remaining_steps = self.total_steps - current_step
+            eta_seconds = remaining_steps / steps_per_sec if steps_per_sec > 0 else 0
+
+            progress = {
+                "type": "progress",
+                "epoch": state.epoch,
+                "total_epochs": self.total_epochs,
+                "step": current_step,
+                "total_steps": self.total_steps,
+                "progress_pct": progress_pct,
+                "loss": logs.get("loss"),
+                "learning_rate": logs.get("learning_rate"),
+                "steps_per_sec": steps_per_sec,
+                "elapsed_seconds": elapsed,
+                "eta_seconds": eta_seconds,
+            }
+            self.callback(progress)
+
+    def on_train_end(self, args, state, control, **kwargs):
+        """Called at the end of training."""
+        import time
+
+        if self.callback:
+            elapsed = time.time() - self.start_time if self.start_time else 0
+            self.callback(
+                {
+                    "type": "train_end",
+                    "total_steps": state.global_step,
+                    "elapsed_seconds": elapsed,
+                }
+            )
+
+
+class VespaEmbedTrainer:
+    """High-level trainer that wraps SentenceTransformerTrainer."""
+
+    def __init__(
+        self,
+        config: TrainingConfig,
+        progress_callback: Optional[Callable[[dict], None]] = None,
+    ):
+        """Initialize the trainer.
+
+        Args:
+            config: Training configuration
+            progress_callback: Optional callback for progress updates
+        """
+        self.config = config
+        self.model = None
+        self.task = None
+        self.progress_callback = progress_callback
+
+    def _load_model(self) -> SentenceTransformer:
+        """Load the model with optional LoRA and Unsloth support.
+
+        Supports four modes:
+        1. Standard: SentenceTransformer only
+        2. Standard + LoRA: SentenceTransformer with PEFT adapter
+        3. Unsloth: FastSentenceTransformer (faster training)
+        4. Unsloth + LoRA: FastSentenceTransformer with LoRA via get_peft_model
+        """
+        use_unsloth = self.config.unsloth.enabled
+
+        if use_unsloth:
+            return self._load_unsloth_model()
+        else:
+            return self._load_standard_model()
+
+    def _load_standard_model(self) -> SentenceTransformer:
+        """Load model with standard SentenceTransformer, optionally with LoRA."""
+        logger.info(f"Loading model: {self.config.base_model}")
+        model = SentenceTransformer(self.config.base_model, token=HF_TOKEN_ENV)
+
+        # Set max_seq_length if specified
+        if self.config.max_seq_length:
+            model.max_seq_length = self.config.max_seq_length
+            logger.info(f"Set max_seq_length: {self.config.max_seq_length}")
+        else:
+            logger.info(f"Using model's default max_seq_length: {model.max_seq_length}")
+
+        # Note: gradient_checkpointing is handled by SentenceTransformerTrainingArguments
+        # The Trainer will automatically enable it on the model during training
+
+        # Add LoRA adapter if enabled
+        if self.config.lora.enabled:
+            try:
+                from peft import LoraConfig, TaskType
+
+                logger.info(
+                    f"Adding LoRA adapter: r={self.config.lora.r}, "
+                    f"alpha={self.config.lora.alpha}, dropout={self.config.lora.dropout}, "
+                    f"target_modules={self.config.lora.target_modules}"
+                )
+
+                peft_config = LoraConfig(
+                    task_type=TaskType.FEATURE_EXTRACTION,
+                    r=self.config.lora.r,
+                    lora_alpha=self.config.lora.alpha,
+                    lora_dropout=self.config.lora.dropout,
+                    target_modules=self.config.lora.target_modules,
+                )
+                model.add_adapter(peft_config)
+
+            except ImportError:
+                raise ImportError("PEFT not installed. Install with: pip install peft")
+
+        return model
+
+    def _load_unsloth_model(self) -> SentenceTransformer:
+        """Load model with Unsloth for faster training."""
+        import torch
+
+        try:
+            from unsloth import FastSentenceTransformer
+        except ImportError:
+            raise ImportError("Unsloth not installed. Install with: pip install unsloth")
+
+        # Auto-detect max_seq_length from model if not specified
+        max_seq_length = self.config.max_seq_length
+        if max_seq_length is None:
+            # Load model temporarily to get max_seq_length
+            temp_model = SentenceTransformer(self.config.base_model, token=HF_TOKEN_ENV)
+            max_seq_length = temp_model.max_seq_length
+            del temp_model
+            logger.info(f"Auto-detected max_seq_length: {max_seq_length}")
+        else:
+            logger.info(f"Using specified max_seq_length: {max_seq_length}")
+
+        # Determine dtype based on precision settings
+        # BF16 is preferred for Unsloth when available (better for training)
+        if self.config.training.bf16:
+            dtype = torch.bfloat16
+            logger.info("Using BF16 precision for Unsloth")
+        elif self.config.training.fp16:
+            dtype = torch.float16
+            logger.info("Using FP16 precision for Unsloth")
+        else:
+            dtype = None  # Let Unsloth auto-detect
+            logger.info("Using auto-detected precision for Unsloth")
+
+        # Full finetuning when Unsloth is enabled but LoRA is not
+        full_finetuning = not self.config.lora.enabled
+
+        # Use Unsloth's optimized gradient checkpointing ("unsloth") when enabled
+        # This is faster and uses less memory than standard PyTorch GC
+        # Must be passed to from_pretrained for full finetuning, or get_peft_model for LoRA
+        use_gc = "unsloth" if self.config.gradient_checkpointing else False
+
+        logger.info(f"Loading model with Unsloth: {self.config.base_model}")
+        if full_finetuning:
+            logger.info(f"Full finetuning mode (gradient_checkpointing={use_gc})")
+        model = FastSentenceTransformer.from_pretrained(
+            model_name=self.config.base_model,
+            max_seq_length=max_seq_length,
+            full_finetuning=full_finetuning,
+            dtype=dtype,
+            use_gradient_checkpointing=use_gc if full_finetuning else False,
+            token=HF_TOKEN_ENV,
+        )
+
+        # Add LoRA if enabled
+        if self.config.lora.enabled:
+            logger.info(
+                f"Applying Unsloth LoRA: r={self.config.lora.r}, "
+                f"alpha={self.config.lora.alpha}, target_modules={self.config.lora.target_modules}"
+            )
+
+            model = FastSentenceTransformer.get_peft_model(
+                model,
+                r=self.config.lora.r,
+                lora_alpha=self.config.lora.alpha,
+                lora_dropout=self.config.lora.dropout,
+                target_modules=self.config.lora.target_modules,
+                bias="none",
+                use_gradient_checkpointing=use_gc,
+                random_state=3407,
+                use_rslora=False,
+                loftq_config=None,
+                task_type="FEATURE_EXTRACTION",
+            )
+
+        return model
+
+    def _log_training_config(self):
+        """Log training configuration parameters."""
+        logger.info("=" * 60)
+        logger.info("Training Configuration")
+        logger.info("=" * 60)
+
+        # Model & Task
+        logger.info(f" Base Model: {self.config.base_model}")
+        logger.info(f" Task: {self.config.task}")
+        if self.config.loss_variant:
+            logger.info(f" Loss Variant: {self.config.loss_variant}")
+
+        # Data
+        logger.info(f" Training Data: {self.config.data.train}")
+        if self.config.data.eval:
+            logger.info(f" Eval Data: {self.config.data.eval}")
+
+        # Training hyperparameters
+        t = self.config.training
+        logger.info(f" Epochs: {t.epochs}")
+        logger.info(f" Batch Size: {t.batch_size}")
+        logger.info(f" Learning Rate: {t.learning_rate}")
+        logger.info(f" Optimizer: {t.optimizer}")
+        logger.info(f" Scheduler: {t.scheduler}")
+        logger.info(f" Warmup Ratio: {t.warmup_ratio}")
+        logger.info(f" Weight Decay: {t.weight_decay}")
+
+        # Precision
+        if t.bf16:
+            logger.info(" Precision: BF16")
+        elif t.fp16:
+            logger.info(" Precision: FP16")
+        else:
+            logger.info(" Precision: FP32")
+
+        # Optional features
+        if self.config.max_seq_length:
+            logger.info(f" Max Seq Length: {self.config.max_seq_length}")
+        if self.config.gradient_checkpointing:
+            logger.info(" Grad Checkpoint: Enabled")
+        if t.gradient_accumulation_steps > 1:
+            logger.info(f" Grad Accum: {t.gradient_accumulation_steps}")
+
+        # LoRA
+        if self.config.lora.enabled:
+            logger.info(f" LoRA: r={self.config.lora.r}, alpha={self.config.lora.alpha}")
+
+        # Unsloth
+        if self.config.unsloth.enabled:
+            logger.info(f" Unsloth: Enabled (save: {self.config.unsloth.save_method})")
+
+        # Matryoshka
+        if self.config.matryoshka_dims:
+            logger.info(f" Matryoshka: {self.config.matryoshka_dims}")
+
+        # Output
+        logger.info(f" Output Dir: {self.config.output.dir}")
+        if self.config.output.push_to_hub:
+            logger.info(f" Push to Hub: {self.config.output.hf_username}")
+
+        logger.info("=" * 60)
+
+    def train(self) -> SentenceTransformer:
+        """Run the training process.
+
+        Returns:
+            Trained SentenceTransformer model
+        """
+        # Log configuration
+        self._log_training_config()
+
+        # 1. Load model
+        self.model = self._load_model()
+
+        # 2. Get task (pass task-specific params if applicable)
+        task_cls = Registry.get_task(self.config.task)
+        # Handle loss_variant as either enum or string
+        loss_variant = self.config.loss_variant
+        if loss_variant is not None:
+            loss_variant = loss_variant.value if hasattr(loss_variant, "value") else loss_variant
+
+        if loss_variant:
+            self.task = task_cls(loss_variant=loss_variant)
+        else:
+            self.task = task_cls()
+
+        loss_info = f" (loss: {self.task.loss_variant})" if self.task.loss_variant else ""
+        logger.info(f"Using task: {self.task.name} - {self.task.description}{loss_info}")
+
+        # 3. Load and prepare training data
+        logger.info(f"Loading training data: {self.config.data.train}")
+        train_data = load_dataset(
+            self.config.data.train,
+            subset=self.config.data.subset,
+            split=self.config.data.split or "train",
+        )
+        train_data = self.task.prepare_dataset(train_data)
+        logger.info(f"Training samples: {len(train_data)}")
+
+        # 4. Load and prepare evaluation data (optional)
+        eval_data = None
+        evaluator = None
+        if self.config.data.eval:
+            logger.info(f"Loading evaluation data: {self.config.data.eval}")
+            # Use eval_split if specified (for HF datasets), otherwise use default split detection
+            eval_split = self.config.data.eval_split
+            eval_data = load_dataset(
+                self.config.data.eval,
+                subset=self.config.data.subset,  # Use same subset as training
+                split=eval_split,
+            )
+            eval_data = self.task.prepare_dataset(eval_data)
+            evaluator = self.task.get_evaluator(eval_data)
+            logger.info(f"Evaluation samples: {len(eval_data)}")
+
+        # 5. Create loss function
+        loss = self.task.get_loss(self.model)
+
+        # Wrap with MatryoshkaLoss if dimensions specified (not supported for TSDAE)
+        if self.config.matryoshka_dims:
+            if self.config.task == "tsdae":
+                raise ValueError("Matryoshka is not supported with TSDAE (uses decoder architecture)")
+            from sentence_transformers.losses import MatryoshkaLoss
+
+            logger.info(f"Wrapping with MatryoshkaLoss: {self.config.matryoshka_dims}")
+            loss = MatryoshkaLoss(self.model, loss, matryoshka_dims=self.config.matryoshka_dims)
+
+        # 6. Create output directory
+        output_dir = Path(self.config.output.dir)
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        # 7. Training arguments
+        logger.info(f"Optimizer: {self.config.training.optimizer}, Scheduler: {self.config.training.scheduler}")
+        args = SentenceTransformerTrainingArguments(
+            output_dir=str(output_dir),
+            num_train_epochs=self.config.training.epochs,
+            per_device_train_batch_size=self.config.training.batch_size,
+            per_device_eval_batch_size=self.config.training.batch_size,
+            learning_rate=self.config.training.learning_rate,
+            warmup_ratio=self.config.training.warmup_ratio,
+            weight_decay=self.config.training.weight_decay,
+            fp16=self.config.training.fp16,
+            bf16=self.config.training.bf16,
+            optim=self.config.training.optimizer,
+            lr_scheduler_type=self.config.training.scheduler,
+            gradient_checkpointing=self.config.gradient_checkpointing,
+            batch_sampler=self.task.batch_sampler,
+            eval_strategy="steps" if evaluator else "no",
+            eval_steps=self.config.training.eval_steps if evaluator else None,
+            save_strategy="steps",
+            save_steps=self.config.training.save_steps,
+            save_total_limit=self.config.output.save_total_limit,
+            logging_steps=self.config.training.logging_steps,
+            gradient_accumulation_steps=self.config.training.gradient_accumulation_steps,
+            load_best_model_at_end=True if evaluator else False,
+            report_to="tensorboard",
+            logging_dir=str(output_dir / "logs"),
+        )
+
+        # 8. Create trainer
+        callbacks = []
+        if self.progress_callback:
+            callbacks.append(ProgressCallback(self.progress_callback))
+
+        trainer = SentenceTransformerTrainer(
+            model=self.model,
+            args=args,
+            train_dataset=train_data,
+            eval_dataset=eval_data,
+            loss=loss,
+            evaluator=evaluator,
+            callbacks=callbacks if callbacks else None,
+        )
+
+        # 9. Train
+        logger.info("Starting training...")
+        trainer.train()
+
+        # 10. Save final model
+        final_path = output_dir / "final"
+        logger.info(f"Saving model to: {final_path}")
+        self._save_model(final_path)
+
+        # 11. Add label mappings to config.json if task has labels (HuggingFace convention)
+        label_config = self.task.get_label_config()
+        if label_config:
+            config_path = final_path / "config.json"
+            if config_path.exists():
+                with open(config_path) as f:
+                    config = json.load(f)
+                config.update(label_config)
+                with open(config_path, "w") as f:
+                    json.dump(config, f, indent=2)
+                logger.info(f"Added label mappings to config.json ({label_config['num_labels']} labels)")
+
+        # 12. Push to hub if configured (always private)
+        if self.config.output.push_to_hub and self.config.output.hf_username:
+            if not HF_TOKEN_ENV:
+                logger.warning("HF_TOKEN environment variable not set, skipping push to hub")
+            else:
+                # Construct repo name from username and project directory name
+                project_name = Path(self.config.output.dir).name
+                repo_id = f"{self.config.output.hf_username}/{project_name}"
+                logger.info(f"Pushing to HuggingFace Hub (private): {repo_id}")
+                self._push_to_hub(repo_id)
+
+        logger.success("Training completed!")
+        return self.model
+
+    def _save_model(self, path: Path) -> None:
+        """Save the model based on configuration.
+
+        Handles different save methods for standard, LoRA, and Unsloth models.
+        """
+        path_str = str(path)
+
+        # Add vespaembed tag to model card metadata
+        if hasattr(self.model, "model_card_data") and self.model.model_card_data is not None:
+            self.model.model_card_data.add_tags("vespaembed")
+
+        if self.config.unsloth.enabled:
+            save_method = self.config.unsloth.save_method
+
+            if save_method == "lora":
+                # Save only LoRA adapters
+                logger.info("Saving LoRA adapters only")
+                self.model.save_pretrained(path_str)
+            elif save_method == "merged_16bit":
+                # Merge and save as FP16
+                logger.info("Saving merged model (FP16)")
+                self.model.save_pretrained_merged(
+                    path_str,
+                    tokenizer=self.model.tokenizer,
+                    save_method="merged_16bit",
+                )
+            elif save_method == "merged_4bit":
+                # Merge and save as 4-bit
+                logger.info("Saving merged model (4-bit)")
+                self.model.save_pretrained_merged(
+                    path_str,
+                    tokenizer=self.model.tokenizer,
+                    save_method="merged_4bit",
+                )
+        else:
+            # Standard or LoRA (PEFT) save
+            self.model.save_pretrained(path_str)
+
+        # Add vespaembed mention to README.md
+        self._add_vespaembed_to_readme(path)
+
+    def _add_vespaembed_to_readme(self, path: Path) -> None:
+        """Add vespaembed mention to the README.md file.
+
+        This method is idempotent - it will not add duplicate mentions if called multiple times.
+        """
+        readme_path = path / "README.md"
+        if not readme_path.exists():
+            return
+
+        content = readme_path.read_text(encoding="utf-8")
+
+        # Check if vespaembed mention already exists (idempotency)
+        if "github.com/vespaai-playground/vespaembed" in content:
+            return
+
+        # Insert vespaembed mention after the first heading
+        vespaembed_mention = (
+            "\n> This model was trained using [vespaembed](https://github.com/vespaai-playground/vespaembed).\n"
+        )
+
+        # Find the first heading and insert after it
+        lines = content.split("\n")
+        new_lines = []
+        inserted = False
+
+        for line in lines:
+            new_lines.append(line)
+            # Insert after the first markdown heading (starting with #)
+            if not inserted and line.startswith("# ") and not line.startswith("# For reference"):
+                new_lines.append(vespaembed_mention)
+                inserted = True
+
+        if inserted:
+            readme_path.write_text("\n".join(new_lines), encoding="utf-8")
+
+    def _push_to_hub(self, repo_id: str) -> None:
+        """Push the model to HuggingFace Hub.
+
+        Handles different push methods for standard, LoRA, and Unsloth models.
+        """
+        if self.config.unsloth.enabled:
+            save_method = self.config.unsloth.save_method
+
+            if save_method == "lora":
+                # Push only LoRA adapters
+                self.model.push_to_hub(repo_id, token=HF_TOKEN_ENV, private=True)
+            else:
+                # Push merged model
+                self.model.push_to_hub_merged(
+                    repo_id,
+                    tokenizer=self.model.tokenizer,
+                    save_method=save_method,
+                    token=HF_TOKEN_ENV,
+                    private=True,
+                )
+        else:
+            # Standard or LoRA (PEFT) push
+            self.model.push_to_hub(repo_id, token=HF_TOKEN_ENV, private=True)
+
+        # Upload the modified README.md with vespaembed mention
+        self._upload_readme_to_hub(repo_id)
+
+    def _upload_readme_to_hub(self, repo_id: str) -> None:
+        """Upload the modified README.md to HuggingFace Hub.
+
+        This uploads the local README.md (which contains the vespaembed mention)
+        to overwrite the auto-generated one on the hub.
+        """
+        from huggingface_hub import HfApi
+
+        # Get the local README.md path from the final output directory
+        output_dir = Path(self.config.output.dir)
+        readme_path = output_dir / "final" / "README.md"
+
+        if not readme_path.exists():
+            logger.warning(f"README.md not found at {readme_path}, skipping upload")
+            return
+
+        api = HfApi()
+        api.upload_file(
+            path_or_fileobj=str(readme_path),
+            path_in_repo="README.md",
+            repo_id=repo_id,
+            token=HF_TOKEN_ENV,
+        )
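The three `ProgressCallback` event shapes above (`train_start`, `progress`, `train_end`) are plain dicts, so a UI or worker can consume them without touching `transformers` internals. Below is a minimal sketch of a consumer; the event keys are taken verbatim from the diff, while the printer itself and the commented-out trainer call are illustrative:

```python
# Sketch of a progress consumer for the events ProgressCallback emits.
# The keys below match the dicts built in trainer.py; this function is
# illustrative and not part of the package.
def print_progress(event: dict) -> None:
    if event["type"] == "train_start":
        print(f"starting: {event['total_steps']} steps / {event['total_epochs']} epochs")
    elif event["type"] == "progress":
        print(
            f"step {event['step']}/{event['total_steps']} "
            f"({event['progress_pct']:.1f}%) loss={event['loss']} "
            f"eta={event['eta_seconds']:.0f}s"
        )
    elif event["type"] == "train_end":
        print(f"finished {event['total_steps']} steps in {event['elapsed_seconds']:.0f}s")


# Assuming a TrainingConfig built per vespaembed/core/config.py (not shown in
# this hunk), wiring it up would look like:
#   VespaEmbedTrainer(config, progress_callback=print_progress).train()
```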
vespaembed/datasets/formats/csv.py
@@ -0,0 +1,15 @@
+import pandas as pd
+from datasets import Dataset
+
+
+def load_csv(path: str) -> Dataset:
+    """Load a CSV file as a HuggingFace Dataset.
+
+    Args:
+        path: Path to CSV file
+
+    Returns:
+        HuggingFace Dataset
+    """
+    df = pd.read_csv(path)
+    return Dataset.from_pandas(df)
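`load_csv` is a thin pandas round-trip: any CSV with a header row becomes a `Dataset` whose columns are the CSV columns. A quick sketch (the file name and columns are made up):

```python
from vespaembed.datasets.formats.csv import load_csv

# pairs.csv is a hypothetical file with a header row, e.g. "anchor,positive"
ds = load_csv("pairs.csv")
print(len(ds), ds.column_names)  # column names come straight from the header
```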
vespaembed/datasets/formats/huggingface.py
@@ -0,0 +1,34 @@
+import os
+from typing import Optional
+
+from datasets import Dataset
+from datasets import load_dataset as hf_load_dataset
+
+# Get HuggingFace token from environment
+HF_TOKEN = os.environ.get("HF_TOKEN")
+
+
+def load_hf_dataset(
+    name: str,
+    subset: Optional[str] = None,
+    split: str = "train",
+) -> Dataset:
+    """Load a dataset from the HuggingFace Hub.
+
+    Args:
+        name: Dataset name (e.g., "sentence-transformers/all-nli")
+        subset: Dataset subset/configuration (optional)
+        split: Dataset split (default: "train")
+
+    Returns:
+        HuggingFace Dataset
+
+    Note:
+        Uses HF_TOKEN environment variable for authentication if set.
+    """
+    if subset:
+        dataset = hf_load_dataset(name, subset, split=split, token=HF_TOKEN)
+    else:
+        dataset = hf_load_dataset(name, split=split, token=HF_TOKEN)
+
+    return dataset
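`load_hf_dataset` simply forwards to `datasets.load_dataset`, adding the `HF_TOKEN` pass-through so private datasets work without extra wiring. A sketch using the docstring's example dataset (the `"triplet"` subset name is an assumption about that dataset's configurations):

```python
from vespaembed.datasets.formats.huggingface import load_hf_dataset

# Public dataset taken from the docstring; the subset/config name is assumed.
ds = load_hf_dataset("sentence-transformers/all-nli", subset="triplet", split="train")
print(len(ds), ds.column_names)
```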
vespaembed/datasets/formats/jsonl.py
@@ -0,0 +1,26 @@
+import json
+
+from datasets import Dataset
+
+
+def load_jsonl(path: str) -> Dataset:
+    """Load a JSONL file as a HuggingFace Dataset.
+
+    Args:
+        path: Path to JSONL file
+
+    Returns:
+        HuggingFace Dataset
+    """
+    records = []
+
+    with open(path) as f:
+        for line in f:
+            line = line.strip()
+            if line:
+                records.append(json.loads(line))
+
+    if not records:
+        raise ValueError(f"No records found in {path}")
+
+    return Dataset.from_list(records)
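`load_jsonl` skips blank lines and raises on an empty file, so callers get a clear error instead of a zero-row `Dataset`. A self-contained sketch (the record fields and file name are arbitrary):

```python
import json

from vespaembed.datasets.formats.jsonl import load_jsonl

rows = [
    {"anchor": "how do I reset my password?", "positive": "Resetting your password"},
    {"anchor": "pizza dough recipe", "positive": "A simple pizza dough"},
]
with open("pairs.jsonl", "w") as f:
    for row in rows:
        f.write(json.dumps(row) + "\n")

ds = load_jsonl("pairs.jsonl")
print(len(ds))  # 2; an empty or all-blank file would raise ValueError
```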