vespaembed 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

Files changed (49)
  1. vespaembed/__init__.py +1 -1
  2. vespaembed/cli/__init__.py +17 -0
  3. vespaembed/cli/commands/__init__.py +7 -0
  4. vespaembed/cli/commands/evaluate.py +85 -0
  5. vespaembed/cli/commands/export.py +86 -0
  6. vespaembed/cli/commands/info.py +52 -0
  7. vespaembed/cli/commands/serve.py +49 -0
  8. vespaembed/cli/commands/train.py +267 -0
  9. vespaembed/cli/vespaembed.py +55 -0
  10. vespaembed/core/__init__.py +2 -0
  11. vespaembed/core/config.py +164 -0
  12. vespaembed/core/registry.py +158 -0
  13. vespaembed/core/trainer.py +573 -0
  14. vespaembed/datasets/__init__.py +3 -0
  15. vespaembed/datasets/formats/__init__.py +5 -0
  16. vespaembed/datasets/formats/csv.py +15 -0
  17. vespaembed/datasets/formats/huggingface.py +34 -0
  18. vespaembed/datasets/formats/jsonl.py +26 -0
  19. vespaembed/datasets/loader.py +80 -0
  20. vespaembed/db.py +176 -0
  21. vespaembed/enums.py +58 -0
  22. vespaembed/evaluation/__init__.py +3 -0
  23. vespaembed/evaluation/factory.py +86 -0
  24. vespaembed/models/__init__.py +4 -0
  25. vespaembed/models/export.py +89 -0
  26. vespaembed/models/loader.py +25 -0
  27. vespaembed/static/css/styles.css +1800 -0
  28. vespaembed/static/js/app.js +1485 -0
  29. vespaembed/tasks/__init__.py +23 -0
  30. vespaembed/tasks/base.py +144 -0
  31. vespaembed/tasks/pairs.py +91 -0
  32. vespaembed/tasks/similarity.py +84 -0
  33. vespaembed/tasks/triplets.py +90 -0
  34. vespaembed/tasks/tsdae.py +102 -0
  35. vespaembed/templates/index.html +544 -0
  36. vespaembed/utils/__init__.py +3 -0
  37. vespaembed/utils/logging.py +69 -0
  38. vespaembed/web/__init__.py +1 -0
  39. vespaembed/web/api/__init__.py +1 -0
  40. vespaembed/web/app.py +605 -0
  41. vespaembed/worker.py +313 -0
  42. vespaembed-0.0.3.dist-info/METADATA +325 -0
  43. vespaembed-0.0.3.dist-info/RECORD +47 -0
  44. {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/WHEEL +1 -1
  45. vespaembed-0.0.1.dist-info/METADATA +0 -20
  46. vespaembed-0.0.1.dist-info/RECORD +0 -7
  47. {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/entry_points.txt +0 -0
  48. {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/licenses/LICENSE +0 -0
  49. {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/top_level.txt +0 -0
vespaembed/core/trainer.py
@@ -0,0 +1,573 @@
+ import json
+ import os
+ from pathlib import Path
+ from typing import Callable, Optional
+
+ from sentence_transformers import SentenceTransformer, SentenceTransformerTrainer
+ from sentence_transformers.training_args import SentenceTransformerTrainingArguments
+ from transformers import TrainerCallback
+
+ from vespaembed.core.config import TrainingConfig
+ from vespaembed.core.registry import Registry
+ from vespaembed.datasets.loader import load_dataset
+ from vespaembed.utils.logging import logger
+
+ # Get HuggingFace token from environment (fallback for loading private models)
+ HF_TOKEN_ENV = os.environ.get("HF_TOKEN")
+
+
+ class ProgressCallback(TrainerCallback):
+     """Callback to report training progress like tqdm."""
+
+     def __init__(self, callback: Callable[[dict], None]):
+         self.callback = callback
+         self.start_time = None
+         self.total_steps = 0
+         self.total_epochs = 0
+
+     def on_train_begin(self, args, state, control, **kwargs):
+         """Called at the beginning of training."""
+         import time
+
+         self.start_time = time.time()
+         self.total_steps = state.max_steps
+         self.total_epochs = args.num_train_epochs
+
+         if self.callback:
+             self.callback(
+                 {
+                     "type": "train_start",
+                     "total_steps": self.total_steps,
+                     "total_epochs": self.total_epochs,
+                 }
+             )
+
+     def on_log(self, args, state, control, logs=None, **kwargs):
+         """Called when the trainer logs metrics."""
+         import time
+
+         if logs and self.callback:
+             current_step = state.global_step
+             elapsed = time.time() - self.start_time if self.start_time else 0
+
+             # Calculate progress
+             progress_pct = (current_step / self.total_steps * 100) if self.total_steps > 0 else 0
+             steps_per_sec = current_step / elapsed if elapsed > 0 else 0
+             remaining_steps = self.total_steps - current_step
+             eta_seconds = remaining_steps / steps_per_sec if steps_per_sec > 0 else 0
+
+             progress = {
+                 "type": "progress",
+                 "epoch": state.epoch,
+                 "total_epochs": self.total_epochs,
+                 "step": current_step,
+                 "total_steps": self.total_steps,
+                 "progress_pct": progress_pct,
+                 "loss": logs.get("loss"),
+                 "learning_rate": logs.get("learning_rate"),
+                 "steps_per_sec": steps_per_sec,
+                 "elapsed_seconds": elapsed,
+                 "eta_seconds": eta_seconds,
+             }
+             self.callback(progress)
+
+     def on_train_end(self, args, state, control, **kwargs):
+         """Called at the end of training."""
+         import time
+
+         if self.callback:
+             elapsed = time.time() - self.start_time if self.start_time else 0
+             self.callback(
+                 {
+                     "type": "train_end",
+                     "total_steps": state.global_step,
+                     "elapsed_seconds": elapsed,
+                 }
+             )
+
+
+ class VespaEmbedTrainer:
+     """High-level trainer that wraps SentenceTransformerTrainer."""
+
+     def __init__(
+         self,
+         config: TrainingConfig,
+         progress_callback: Optional[Callable[[dict], None]] = None,
+     ):
+         """Initialize the trainer.
+
+         Args:
+             config: Training configuration
+             progress_callback: Optional callback for progress updates
+         """
+         self.config = config
+         self.model = None
+         self.task = None
+         self.progress_callback = progress_callback
+
+     def _load_model(self) -> SentenceTransformer:
+         """Load the model with optional LoRA and Unsloth support.
+
+         Supports four modes:
+         1. Standard: SentenceTransformer only
+         2. Standard + LoRA: SentenceTransformer with PEFT adapter
+         3. Unsloth: FastSentenceTransformer (faster training)
+         4. Unsloth + LoRA: FastSentenceTransformer with LoRA via get_peft_model
+         """
+         use_unsloth = self.config.unsloth.enabled
+
+         if use_unsloth:
+             return self._load_unsloth_model()
+         else:
+             return self._load_standard_model()
+
+     def _load_standard_model(self) -> SentenceTransformer:
+         """Load model with standard SentenceTransformer, optionally with LoRA."""
+         logger.info(f"Loading model: {self.config.base_model}")
+         model = SentenceTransformer(self.config.base_model, token=HF_TOKEN_ENV)
+
+         # Set max_seq_length if specified
+         if self.config.max_seq_length:
+             model.max_seq_length = self.config.max_seq_length
+             logger.info(f"Set max_seq_length: {self.config.max_seq_length}")
+         else:
+             logger.info(f"Using model's default max_seq_length: {model.max_seq_length}")
+
+         # Note: gradient_checkpointing is handled by SentenceTransformerTrainingArguments
+         # The Trainer will automatically enable it on the model during training
+
+         # Add LoRA adapter if enabled
+         if self.config.lora.enabled:
+             try:
+                 from peft import LoraConfig, TaskType
+
+                 logger.info(
+                     f"Adding LoRA adapter: r={self.config.lora.r}, "
+                     f"alpha={self.config.lora.alpha}, dropout={self.config.lora.dropout}, "
+                     f"target_modules={self.config.lora.target_modules}"
+                 )
+
+                 peft_config = LoraConfig(
+                     task_type=TaskType.FEATURE_EXTRACTION,
+                     r=self.config.lora.r,
+                     lora_alpha=self.config.lora.alpha,
+                     lora_dropout=self.config.lora.dropout,
+                     target_modules=self.config.lora.target_modules,
+                 )
+                 model.add_adapter(peft_config)
+
+             except ImportError:
+                 raise ImportError("PEFT not installed. Install with: pip install peft")
+
+         return model
+
+     def _load_unsloth_model(self) -> SentenceTransformer:
+         """Load model with Unsloth for faster training."""
+         import torch
+
+         try:
+             from unsloth import FastSentenceTransformer
+         except ImportError:
+             raise ImportError("Unsloth not installed. Install with: pip install unsloth")
+
+         # Auto-detect max_seq_length from model if not specified
+         max_seq_length = self.config.max_seq_length
+         if max_seq_length is None:
+             # Load model temporarily to get max_seq_length
+             temp_model = SentenceTransformer(self.config.base_model, token=HF_TOKEN_ENV)
+             max_seq_length = temp_model.max_seq_length
+             del temp_model
+             logger.info(f"Auto-detected max_seq_length: {max_seq_length}")
+         else:
+             logger.info(f"Using specified max_seq_length: {max_seq_length}")
+
+         # Determine dtype based on precision settings
+         # BF16 is preferred for Unsloth when available (better for training)
+         if self.config.training.bf16:
+             dtype = torch.bfloat16
+             logger.info("Using BF16 precision for Unsloth")
+         elif self.config.training.fp16:
+             dtype = torch.float16
+             logger.info("Using FP16 precision for Unsloth")
+         else:
+             dtype = None  # Let Unsloth auto-detect
+             logger.info("Using auto-detected precision for Unsloth")
+
+         # Full finetuning when Unsloth is enabled but LoRA is not
+         full_finetuning = not self.config.lora.enabled
+
+         # Use Unsloth's optimized gradient checkpointing ("unsloth") when enabled
+         # This is faster and uses less memory than standard PyTorch GC
+         # Must be passed to from_pretrained for full finetuning, or get_peft_model for LoRA
+         use_gc = "unsloth" if self.config.gradient_checkpointing else False
+
+         logger.info(f"Loading model with Unsloth: {self.config.base_model}")
+         if full_finetuning:
+             logger.info(f"Full finetuning mode (gradient_checkpointing={use_gc})")
+         model = FastSentenceTransformer.from_pretrained(
+             model_name=self.config.base_model,
+             max_seq_length=max_seq_length,
+             full_finetuning=full_finetuning,
+             dtype=dtype,
+             use_gradient_checkpointing=use_gc if full_finetuning else False,
+             token=HF_TOKEN_ENV,
+         )
+
+         # Add LoRA if enabled
+         if self.config.lora.enabled:
+             logger.info(
+                 f"Applying Unsloth LoRA: r={self.config.lora.r}, "
+                 f"alpha={self.config.lora.alpha}, target_modules={self.config.lora.target_modules}"
+             )
+
+             model = FastSentenceTransformer.get_peft_model(
+                 model,
+                 r=self.config.lora.r,
+                 lora_alpha=self.config.lora.alpha,
+                 lora_dropout=self.config.lora.dropout,
+                 target_modules=self.config.lora.target_modules,
+                 bias="none",
+                 use_gradient_checkpointing=use_gc,
+                 random_state=3407,
+                 use_rslora=False,
+                 loftq_config=None,
+                 task_type="FEATURE_EXTRACTION",
+             )
+
+         return model
+
+     def _log_training_config(self):
+         """Log training configuration parameters."""
+         logger.info("=" * 60)
+         logger.info("Training Configuration")
+         logger.info("=" * 60)
+
+         # Model & Task
+         logger.info(f" Base Model: {self.config.base_model}")
+         logger.info(f" Task: {self.config.task}")
+         if self.config.loss_variant:
+             logger.info(f" Loss Variant: {self.config.loss_variant}")
+
+         # Data
+         logger.info(f" Training Data: {self.config.data.train}")
+         if self.config.data.eval:
+             logger.info(f" Eval Data: {self.config.data.eval}")
+
+         # Training hyperparameters
+         t = self.config.training
+         logger.info(f" Epochs: {t.epochs}")
+         logger.info(f" Batch Size: {t.batch_size}")
+         logger.info(f" Learning Rate: {t.learning_rate}")
+         logger.info(f" Optimizer: {t.optimizer}")
+         logger.info(f" Scheduler: {t.scheduler}")
+         logger.info(f" Warmup Ratio: {t.warmup_ratio}")
+         logger.info(f" Weight Decay: {t.weight_decay}")
+
+         # Precision
+         if t.bf16:
+             logger.info(" Precision: BF16")
+         elif t.fp16:
+             logger.info(" Precision: FP16")
+         else:
+             logger.info(" Precision: FP32")
+
+         # Optional features
+         if self.config.max_seq_length:
+             logger.info(f" Max Seq Length: {self.config.max_seq_length}")
+         if self.config.gradient_checkpointing:
+             logger.info(" Grad Checkpoint: Enabled")
+         if t.gradient_accumulation_steps > 1:
+             logger.info(f" Grad Accum: {t.gradient_accumulation_steps}")
+
+         # LoRA
+         if self.config.lora.enabled:
+             logger.info(f" LoRA: r={self.config.lora.r}, alpha={self.config.lora.alpha}")
+
+         # Unsloth
+         if self.config.unsloth.enabled:
+             logger.info(f" Unsloth: Enabled (save: {self.config.unsloth.save_method})")
+
+         # Matryoshka
+         if self.config.matryoshka_dims:
+             logger.info(f" Matryoshka: {self.config.matryoshka_dims}")
+
+         # Output
+         logger.info(f" Output Dir: {self.config.output.dir}")
+         if self.config.output.push_to_hub:
+             logger.info(f" Push to Hub: {self.config.output.hf_username}")
+
+         logger.info("=" * 60)
+
+     def train(self) -> SentenceTransformer:
+         """Run the training process.
+
+         Returns:
+             Trained SentenceTransformer model
+         """
+         # Log configuration
+         self._log_training_config()
+
+         # 1. Load model
+         self.model = self._load_model()
+
+         # 2. Get task (pass task-specific params if applicable)
+         task_cls = Registry.get_task(self.config.task)
+         # Handle loss_variant as either enum or string
+         loss_variant = self.config.loss_variant
+         if loss_variant is not None:
+             loss_variant = loss_variant.value if hasattr(loss_variant, "value") else loss_variant
+
+         if loss_variant:
+             self.task = task_cls(loss_variant=loss_variant)
+         else:
+             self.task = task_cls()
+
+         loss_info = f" (loss: {self.task.loss_variant})" if self.task.loss_variant else ""
+         logger.info(f"Using task: {self.task.name} - {self.task.description}{loss_info}")
+
+         # 3. Load and prepare training data
+         logger.info(f"Loading training data: {self.config.data.train}")
+         train_data = load_dataset(
+             self.config.data.train,
+             subset=self.config.data.subset,
+             split=self.config.data.split or "train",
+         )
+         train_data = self.task.prepare_dataset(train_data)
+         logger.info(f"Training samples: {len(train_data)}")
+
+         # 4. Load and prepare evaluation data (optional)
+         eval_data = None
+         evaluator = None
+         if self.config.data.eval:
+             logger.info(f"Loading evaluation data: {self.config.data.eval}")
+             # Use eval_split if specified (for HF datasets), otherwise use default split detection
+             eval_split = self.config.data.eval_split
+             eval_data = load_dataset(
+                 self.config.data.eval,
+                 subset=self.config.data.subset,  # Use same subset as training
+                 split=eval_split,
+             )
+             eval_data = self.task.prepare_dataset(eval_data)
+             evaluator = self.task.get_evaluator(eval_data)
+             logger.info(f"Evaluation samples: {len(eval_data)}")
+
+         # 5. Create loss function
+         loss = self.task.get_loss(self.model)
+
+         # Wrap with MatryoshkaLoss if dimensions specified (not supported for TSDAE)
+         if self.config.matryoshka_dims:
+             if self.config.task == "tsdae":
+                 raise ValueError("Matryoshka is not supported with TSDAE (uses decoder architecture)")
+             from sentence_transformers.losses import MatryoshkaLoss
+
+             logger.info(f"Wrapping with MatryoshkaLoss: {self.config.matryoshka_dims}")
+             loss = MatryoshkaLoss(self.model, loss, matryoshka_dims=self.config.matryoshka_dims)
+
+         # 6. Create output directory
+         output_dir = Path(self.config.output.dir)
+         output_dir.mkdir(parents=True, exist_ok=True)
+
+         # 7. Training arguments
+         logger.info(f"Optimizer: {self.config.training.optimizer}, Scheduler: {self.config.training.scheduler}")
+         args = SentenceTransformerTrainingArguments(
+             output_dir=str(output_dir),
+             num_train_epochs=self.config.training.epochs,
+             per_device_train_batch_size=self.config.training.batch_size,
+             per_device_eval_batch_size=self.config.training.batch_size,
+             learning_rate=self.config.training.learning_rate,
+             warmup_ratio=self.config.training.warmup_ratio,
+             weight_decay=self.config.training.weight_decay,
+             fp16=self.config.training.fp16,
+             bf16=self.config.training.bf16,
+             optim=self.config.training.optimizer,
+             lr_scheduler_type=self.config.training.scheduler,
+             gradient_checkpointing=self.config.gradient_checkpointing,
+             batch_sampler=self.task.batch_sampler,
+             eval_strategy="steps" if evaluator else "no",
+             eval_steps=self.config.training.eval_steps if evaluator else None,
+             save_strategy="steps",
+             save_steps=self.config.training.save_steps,
+             save_total_limit=self.config.output.save_total_limit,
+             logging_steps=self.config.training.logging_steps,
+             gradient_accumulation_steps=self.config.training.gradient_accumulation_steps,
+             load_best_model_at_end=True if evaluator else False,
+             report_to="tensorboard",
+             logging_dir=str(output_dir / "logs"),
+         )
+
+         # 8. Create trainer
+         callbacks = []
+         if self.progress_callback:
+             callbacks.append(ProgressCallback(self.progress_callback))
+
+         trainer = SentenceTransformerTrainer(
+             model=self.model,
+             args=args,
+             train_dataset=train_data,
+             eval_dataset=eval_data,
+             loss=loss,
+             evaluator=evaluator,
+             callbacks=callbacks if callbacks else None,
+         )
+
+         # 9. Train
+         logger.info("Starting training...")
+         trainer.train()
+
+         # 10. Save final model
+         final_path = output_dir / "final"
+         logger.info(f"Saving model to: {final_path}")
+         self._save_model(final_path)
+
+         # 11. Add label mappings to config.json if task has labels (HuggingFace convention)
+         label_config = self.task.get_label_config()
+         if label_config:
+             config_path = final_path / "config.json"
+             if config_path.exists():
+                 with open(config_path) as f:
+                     config = json.load(f)
+                 config.update(label_config)
+                 with open(config_path, "w") as f:
+                     json.dump(config, f, indent=2)
+                 logger.info(f"Added label mappings to config.json ({label_config['num_labels']} labels)")
+
+         # 12. Push to hub if configured (always private)
+         if self.config.output.push_to_hub and self.config.output.hf_username:
+             if not HF_TOKEN_ENV:
+                 logger.warning("HF_TOKEN environment variable not set, skipping push to hub")
+             else:
+                 # Construct repo name from username and project directory name
+                 project_name = Path(self.config.output.dir).name
+                 repo_id = f"{self.config.output.hf_username}/{project_name}"
+                 logger.info(f"Pushing to HuggingFace Hub (private): {repo_id}")
+                 self._push_to_hub(repo_id)
+
+         logger.success("Training completed!")
+         return self.model
+
+     def _save_model(self, path: Path) -> None:
+         """Save the model based on configuration.
+
+         Handles different save methods for standard, LoRA, and Unsloth models.
+         """
+         path_str = str(path)
+
+         # Add vespaembed tag to model card metadata
+         if hasattr(self.model, "model_card_data") and self.model.model_card_data is not None:
+             self.model.model_card_data.add_tags("vespaembed")
+
+         if self.config.unsloth.enabled:
+             save_method = self.config.unsloth.save_method
+
+             if save_method == "lora":
+                 # Save only LoRA adapters
+                 logger.info("Saving LoRA adapters only")
+                 self.model.save_pretrained(path_str)
+             elif save_method == "merged_16bit":
+                 # Merge and save as FP16
+                 logger.info("Saving merged model (FP16)")
+                 self.model.save_pretrained_merged(
+                     path_str,
+                     tokenizer=self.model.tokenizer,
+                     save_method="merged_16bit",
+                 )
+             elif save_method == "merged_4bit":
+                 # Merge and save as 4-bit
+                 logger.info("Saving merged model (4-bit)")
+                 self.model.save_pretrained_merged(
+                     path_str,
+                     tokenizer=self.model.tokenizer,
+                     save_method="merged_4bit",
+                 )
+         else:
+             # Standard or LoRA (PEFT) save
+             self.model.save_pretrained(path_str)
+
+         # Add vespaembed mention to README.md
+         self._add_vespaembed_to_readme(path)
+
+     def _add_vespaembed_to_readme(self, path: Path) -> None:
+         """Add vespaembed mention to the README.md file.
+
+         This method is idempotent - it will not add duplicate mentions if called multiple times.
+         """
+         readme_path = path / "README.md"
+         if not readme_path.exists():
+             return
+
+         content = readme_path.read_text(encoding="utf-8")
+
+         # Check if vespaembed mention already exists (idempotency)
+         if "github.com/vespaai-playground/vespaembed" in content:
+             return
+
+         # Insert vespaembed mention after the first heading
+         vespaembed_mention = (
+             "\n> This model was trained using " "[vespaembed](https://github.com/vespaai-playground/vespaembed).\n"
+         )
+
+         # Find the first heading and insert after it
+         lines = content.split("\n")
+         new_lines = []
+         inserted = False
+
+         for line in lines:
+             new_lines.append(line)
+             # Insert after the first markdown heading (starting with #)
+             if not inserted and line.startswith("# ") and not line.startswith("# For reference"):
+                 new_lines.append(vespaembed_mention)
+                 inserted = True
+
+         if inserted:
+             readme_path.write_text("\n".join(new_lines), encoding="utf-8")
+
+     def _push_to_hub(self, repo_id: str) -> None:
+         """Push the model to HuggingFace Hub.
+
+         Handles different push methods for standard, LoRA, and Unsloth models.
+         """
+         if self.config.unsloth.enabled:
+             save_method = self.config.unsloth.save_method
+
+             if save_method == "lora":
+                 # Push only LoRA adapters
+                 self.model.push_to_hub(repo_id, token=HF_TOKEN_ENV, private=True)
+             else:
+                 # Push merged model
+                 self.model.push_to_hub_merged(
+                     repo_id,
+                     tokenizer=self.model.tokenizer,
+                     save_method=save_method,
+                     token=HF_TOKEN_ENV,
+                     private=True,
+                 )
+         else:
+             # Standard or LoRA (PEFT) save
+             self.model.push_to_hub(repo_id, token=HF_TOKEN_ENV, private=True)
+
+         # Upload the modified README.md with vespaembed mention
+         self._upload_readme_to_hub(repo_id)
+
+     def _upload_readme_to_hub(self, repo_id: str) -> None:
+         """Upload the modified README.md to HuggingFace Hub.
+
+         This uploads the local README.md (which contains the vespaembed mention)
+         to overwrite the auto-generated one on the hub.
+         """
+         from huggingface_hub import HfApi
+
+         # Get the local README.md path from the final output directory
+         output_dir = Path(self.config.output.dir)
+         readme_path = output_dir / "final" / "README.md"
+
+         if not readme_path.exists():
+             logger.warning(f"README.md not found at {readme_path}, skipping upload")
+             return
+
+         api = HfApi()
+         api.upload_file(
+             path_or_fileobj=str(readme_path),
+             path_in_repo="README.md",
+             repo_id=repo_id,
+             token=HF_TOKEN_ENV,
+         )
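For context, a minimal sketch of how the new trainer and its progress events might be consumed, based only on the signatures shown above. How a TrainingConfig is constructed lives in vespaembed/core/config.py and is not part of this hunk, so the `config` object here is assumed to already exist.

from vespaembed.core.trainer import VespaEmbedTrainer

def on_progress(event: dict) -> None:
    # Event types emitted by ProgressCallback: "train_start", "progress", "train_end"
    if event["type"] == "progress":
        print(
            f"step {event['step']}/{event['total_steps']} "
            f"({event['progress_pct']:.1f}%) loss={event['loss']} "
            f"eta={event['eta_seconds']:.0f}s"
        )
    elif event["type"] == "train_end":
        print(f"finished {event['total_steps']} steps in {event['elapsed_seconds']:.0f}s")

# `config` is a vespaembed.core.config.TrainingConfig instance (construction not shown in this diff)
trainer = VespaEmbedTrainer(config, progress_callback=on_progress)
model = trainer.train()  # returns the trained SentenceTransformer, saved under <output.dir>/final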
vespaembed/datasets/__init__.py
@@ -0,0 +1,3 @@
+ from vespaembed.datasets.loader import load_dataset
+
+ __all__ = ["load_dataset"]
vespaembed/datasets/formats/__init__.py
@@ -0,0 +1,5 @@
+ from vespaembed.datasets.formats.csv import load_csv
+ from vespaembed.datasets.formats.huggingface import load_hf_dataset
+ from vespaembed.datasets.formats.jsonl import load_jsonl
+
+ __all__ = ["load_csv", "load_hf_dataset", "load_jsonl"]
vespaembed/datasets/formats/csv.py
@@ -0,0 +1,15 @@
+ import pandas as pd
+ from datasets import Dataset
+
+
+ def load_csv(path: str) -> Dataset:
+     """Load a CSV file as a HuggingFace Dataset.
+
+     Args:
+         path: Path to CSV file
+
+     Returns:
+         HuggingFace Dataset
+     """
+     df = pd.read_csv(path)
+     return Dataset.from_pandas(df)
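A usage sketch (not part of the diff): load_csv simply wraps pandas.read_csv and Dataset.from_pandas, so any CSV whose columns match what the selected task expects will work. The file name and column names below are illustrative only.

from vespaembed.datasets.formats.csv import load_csv

# pairs.csv (illustrative columns, e.g. for a pairs-style task):
#   anchor,positive
#   "how do I reset my password","Visit the account settings page ..."
ds = load_csv("pairs.csv")  # returns a datasets.Dataset
print(len(ds), ds.column_names)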
vespaembed/datasets/formats/huggingface.py
@@ -0,0 +1,34 @@
+ import os
+ from typing import Optional
+
+ from datasets import Dataset
+ from datasets import load_dataset as hf_load_dataset
+
+ # Get HuggingFace token from environment
+ HF_TOKEN = os.environ.get("HF_TOKEN")
+
+
+ def load_hf_dataset(
+     name: str,
+     subset: Optional[str] = None,
+     split: str = "train",
+ ) -> Dataset:
+     """Load a dataset from the HuggingFace Hub.
+
+     Args:
+         name: Dataset name (e.g., "sentence-transformers/all-nli")
+         subset: Dataset subset/configuration (optional)
+         split: Dataset split (default: "train")
+
+     Returns:
+         HuggingFace Dataset
+
+     Note:
+         Uses HF_TOKEN environment variable for authentication if set.
+     """
+     if subset:
+         dataset = hf_load_dataset(name, subset, split=split, token=HF_TOKEN)
+     else:
+         dataset = hf_load_dataset(name, split=split, token=HF_TOKEN)
+
+     return dataset
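A usage sketch based on the signature above. The dataset name comes from the docstring's own example; the subset name is illustrative, and HF_TOKEN is only needed for private or gated datasets.

from vespaembed.datasets.formats.huggingface import load_hf_dataset

# Public dataset with a named configuration (subset) and an explicit split.
ds = load_hf_dataset("sentence-transformers/all-nli", subset="triplet", split="train")
print(len(ds), ds.column_names)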
vespaembed/datasets/formats/jsonl.py
@@ -0,0 +1,26 @@
+ import json
+
+ from datasets import Dataset
+
+
+ def load_jsonl(path: str) -> Dataset:
+     """Load a JSONL file as a HuggingFace Dataset.
+
+     Args:
+         path: Path to JSONL file
+
+     Returns:
+         HuggingFace Dataset
+     """
+     records = []
+
+     with open(path) as f:
+         for line in f:
+             line = line.strip()
+             if line:
+                 records.append(json.loads(line))
+
+     if not records:
+         raise ValueError(f"No records found in {path}")
+
+     return Dataset.from_list(records)
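A usage sketch for the JSONL loader; the field names in the example records are illustrative and depend on the columns the chosen task expects.

from vespaembed.datasets.formats.jsonl import load_jsonl

# data.jsonl, one JSON object per line (field names are illustrative):
#   {"anchor": "what is vespa?", "positive": "Vespa is a platform for ..."}
#   {"anchor": "reset password", "positive": "Go to account settings ..."}
ds = load_jsonl("data.jsonl")  # raises ValueError if the file contains no records
print(len(ds), ds.column_names)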