vespaembed 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vespaembed/__init__.py +1 -1
- vespaembed/cli/__init__.py +17 -0
- vespaembed/cli/commands/__init__.py +7 -0
- vespaembed/cli/commands/evaluate.py +85 -0
- vespaembed/cli/commands/export.py +86 -0
- vespaembed/cli/commands/info.py +52 -0
- vespaembed/cli/commands/serve.py +49 -0
- vespaembed/cli/commands/train.py +267 -0
- vespaembed/cli/vespaembed.py +55 -0
- vespaembed/core/__init__.py +2 -0
- vespaembed/core/config.py +164 -0
- vespaembed/core/registry.py +158 -0
- vespaembed/core/trainer.py +573 -0
- vespaembed/datasets/__init__.py +3 -0
- vespaembed/datasets/formats/__init__.py +5 -0
- vespaembed/datasets/formats/csv.py +15 -0
- vespaembed/datasets/formats/huggingface.py +34 -0
- vespaembed/datasets/formats/jsonl.py +26 -0
- vespaembed/datasets/loader.py +80 -0
- vespaembed/db.py +176 -0
- vespaembed/enums.py +58 -0
- vespaembed/evaluation/__init__.py +3 -0
- vespaembed/evaluation/factory.py +86 -0
- vespaembed/models/__init__.py +4 -0
- vespaembed/models/export.py +89 -0
- vespaembed/models/loader.py +25 -0
- vespaembed/static/css/styles.css +1800 -0
- vespaembed/static/js/app.js +1485 -0
- vespaembed/tasks/__init__.py +23 -0
- vespaembed/tasks/base.py +144 -0
- vespaembed/tasks/pairs.py +91 -0
- vespaembed/tasks/similarity.py +84 -0
- vespaembed/tasks/triplets.py +90 -0
- vespaembed/tasks/tsdae.py +102 -0
- vespaembed/templates/index.html +544 -0
- vespaembed/utils/__init__.py +3 -0
- vespaembed/utils/logging.py +69 -0
- vespaembed/web/__init__.py +1 -0
- vespaembed/web/api/__init__.py +1 -0
- vespaembed/web/app.py +605 -0
- vespaembed/worker.py +313 -0
- vespaembed-0.0.2.dist-info/METADATA +325 -0
- vespaembed-0.0.2.dist-info/RECORD +47 -0
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.2.dist-info}/WHEEL +1 -1
- vespaembed-0.0.1.dist-info/METADATA +0 -20
- vespaembed-0.0.1.dist-info/RECORD +0 -7
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.2.dist-info}/entry_points.txt +0 -0
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.2.dist-info}/licenses/LICENSE +0 -0
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.2.dist-info}/top_level.txt +0 -0
vespaembed/__init__.py
CHANGED
@@ -1 +1 @@
-__version__ = "0.0.1"
+__version__ = "0.0.2"

vespaembed/cli/__init__.py
ADDED
@@ -0,0 +1,17 @@
+from abc import ABC, abstractmethod
+from argparse import ArgumentParser
+
+
+class BaseCommand(ABC):
+    """Base class for all CLI commands."""
+
+    @staticmethod
+    @abstractmethod
+    def register_subcommand(parser: ArgumentParser):
+        """Register the subcommand with argparse."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def execute(self):
+        """Execute the command."""
+        raise NotImplementedError
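
Every command added in this release follows the BaseCommand contract above: a static register_subcommand that wires up an argparse subparser and binds a factory via set_defaults(func=...), plus an execute method. A minimal sketch of a hypothetical command written against that contract (the HelloCommand name and its --name flag are illustrative only, not part of the package):

from argparse import ArgumentParser, Namespace

from vespaembed.cli import BaseCommand


def hello_command_factory(args: Namespace) -> "HelloCommand":
    # Bound via set_defaults(func=...), mirroring the shipped commands.
    return HelloCommand(name=args.name)


class HelloCommand(BaseCommand):
    """Hypothetical command illustrating the BaseCommand contract."""

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        hello_parser = parser.add_parser("hello", help="Print a greeting")
        hello_parser.add_argument("--name", type=str, default="world")
        hello_parser.set_defaults(func=hello_command_factory)

    def __init__(self, name: str):
        self.name = name

    def execute(self):
        print(f"Hello, {self.name}!")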

vespaembed/cli/commands/__init__.py
ADDED
@@ -0,0 +1,7 @@
+from vespaembed.cli.commands.evaluate import EvaluateCommand
+from vespaembed.cli.commands.export import ExportCommand
+from vespaembed.cli.commands.info import InfoCommand
+from vespaembed.cli.commands.serve import ServeCommand
+from vespaembed.cli.commands.train import TrainCommand
+
+__all__ = ["EvaluateCommand", "ExportCommand", "InfoCommand", "ServeCommand", "TrainCommand"]

vespaembed/cli/commands/evaluate.py
ADDED
@@ -0,0 +1,85 @@
+from argparse import ArgumentParser, Namespace
+
+from vespaembed.cli import BaseCommand
+from vespaembed.core.registry import Registry
+from vespaembed.datasets.loader import load_dataset
+from vespaembed.models.loader import load_model
+from vespaembed.utils.logging import logger
+
+
+def evaluate_command_factory(args: Namespace) -> "EvaluateCommand":
+    """Factory function for EvaluateCommand."""
+    return EvaluateCommand(
+        model_path=args.model,
+        data_path=args.data,
+        task=args.task,
+    )
+
+
+class EvaluateCommand(BaseCommand):
+    """Evaluate a trained embedding model."""
+
+    @staticmethod
+    def register_subcommand(parser: ArgumentParser):
+        """Register the evaluate subcommand."""
+        eval_parser = parser.add_parser(
+            "evaluate",
+            help="Evaluate a trained embedding model",
+        )
+
+        eval_parser.add_argument(
+            "--model",
+            type=str,
+            required=True,
+            help="Path to trained model",
+        )
+        eval_parser.add_argument(
+            "--data",
+            type=str,
+            required=True,
+            help="Path to evaluation data",
+        )
+        eval_parser.add_argument(
+            "--task",
+            type=str,
+            required=True,
+            choices=["mnr", "triplet", "contrastive", "sts", "nli", "tsdae", "matryoshka"],
+            help="Task type (determines evaluator)",
+        )
+
+        eval_parser.set_defaults(func=evaluate_command_factory)
+
+    def __init__(self, model_path: str, data_path: str, task: str):
+        self.model_path = model_path
+        self.data_path = data_path
+        self.task = task
+
+    def execute(self):
+        """Execute the evaluate command."""
+        # Load model
+        logger.info(f"Loading model: {self.model_path}")
+        model = load_model(self.model_path)
+
+        # Load and prepare data
+        logger.info(f"Loading evaluation data: {self.data_path}")
+        task_cls = Registry.get_task(self.task)
+        task = task_cls()
+
+        eval_data = load_dataset(self.data_path)
+        eval_data = task.prepare_dataset(eval_data)
+
+        # Create evaluator
+        evaluator = task.get_evaluator(eval_data)
+
+        if evaluator is None:
+            logger.warning(f"No evaluator available for task: {self.task}")
+            return
+
+        # Run evaluation
+        logger.info("Running evaluation...")
+        results = evaluator(model)
+
+        # Print results
+        logger.success("Evaluation Results:")
+        for key, value in results.items():
+            logger.print(f" {key}: {value:.4f}")

vespaembed/cli/commands/export.py
ADDED
@@ -0,0 +1,86 @@
+from argparse import ArgumentParser, Namespace
+
+from vespaembed.cli import BaseCommand
+from vespaembed.models.export import export_model, push_to_hub
+from vespaembed.models.loader import load_model
+from vespaembed.utils.logging import logger
+
+
+def export_command_factory(args: Namespace) -> "ExportCommand":
+    """Factory function for ExportCommand."""
+    return ExportCommand(
+        model_path=args.model,
+        output_path=args.output,
+        format=args.format,
+        hub_id=args.hub_id,
+    )
+
+
+class ExportCommand(BaseCommand):
+    """Export a trained model to different formats."""
+
+    @staticmethod
+    def register_subcommand(parser: ArgumentParser):
+        """Register the export subcommand."""
+        export_parser = parser.add_parser(
+            "export",
+            help="Export a trained model to different formats",
+        )
+
+        export_parser.add_argument(
+            "--model",
+            type=str,
+            required=True,
+            help="Path to trained model",
+        )
+        export_parser.add_argument(
+            "--output",
+            type=str,
+            default=None,
+            help="Output path for exported model",
+        )
+        export_parser.add_argument(
+            "--format",
+            type=str,
+            default="onnx",
+            choices=["onnx"],
+            help="Export format (default: onnx)",
+        )
+        export_parser.add_argument(
+            "--hub-id",
+            type=str,
+            default=None,
+            help="Push to HuggingFace Hub with this repo ID",
+        )
+
+        export_parser.set_defaults(func=export_command_factory)
+
+    def __init__(
+        self,
+        model_path: str,
+        output_path: str = None,
+        format: str = "onnx",
+        hub_id: str = None,
+    ):
+        self.model_path = model_path
+        self.output_path = output_path or f"{model_path}_exported"
+        self.format = format
+        self.hub_id = hub_id
+
+    def execute(self):
+        """Execute the export command."""
+        # Load model
+        logger.info(f"Loading model: {self.model_path}")
+        model = load_model(self.model_path)
+
+        # Export
+        if self.format:
+            logger.info(f"Exporting to {self.format}: {self.output_path}")
+            export_path = export_model(model, self.output_path, self.format)
+            logger.success(f"Model exported to: {export_path}")
+
+        # Push to Hub
+        if self.hub_id:
+            logger.info(f"Pushing to HuggingFace Hub: {self.hub_id}")
+            url = push_to_hub(model, self.hub_id)
+            logger.success(f"Model pushed to: {url}")

vespaembed/cli/commands/info.py
ADDED
@@ -0,0 +1,52 @@
+from argparse import ArgumentParser, Namespace
+
+# Import tasks to register them
+import vespaembed.tasks  # noqa: F401
+from vespaembed.cli import BaseCommand
+from vespaembed.core.registry import Registry
+from vespaembed.utils.logging import logger
+
+
+def info_command_factory(args: Namespace) -> "InfoCommand":
+    """Factory function for InfoCommand."""
+    return InfoCommand(show_tasks=args.tasks)
+
+
+class InfoCommand(BaseCommand):
+    """Show information about available tasks."""
+
+    @staticmethod
+    def register_subcommand(parser: ArgumentParser):
+        """Register the info subcommand."""
+        info_parser = parser.add_parser(
+            "info",
+            help="Show information about available tasks",
+        )
+
+        info_parser.add_argument(
+            "--tasks",
+            action="store_true",
+            help="List available tasks",
+        )
+
+        info_parser.set_defaults(func=info_command_factory)
+
+    def __init__(self, show_tasks: bool = True):
+        self.show_tasks = show_tasks
+
+    def execute(self):
+        """Execute the info command."""
+        if self.show_tasks or True:  # Default to showing tasks
+            self._show_tasks()
+
+    def _show_tasks(self):
+        """Display available tasks."""
+        logger.print("\n[bold]Available Tasks:[/bold]\n")
+
+        tasks_info = Registry.get_task_info()
+
+        for task in tasks_info:
+            logger.print(f" [cyan]{task['name']}[/cyan]")
+            logger.print(f" {task['description']}")
+            logger.print(f" Expected columns: {', '.join(task['expected_columns'])}")
+            logger.print("")

vespaembed/cli/commands/serve.py
ADDED
@@ -0,0 +1,49 @@
+from argparse import ArgumentParser, Namespace
+
+from vespaembed.cli import BaseCommand
+from vespaembed.utils.logging import logger
+
+
+def serve_command_factory(args: Namespace) -> "ServeCommand":
+    """Factory function for ServeCommand."""
+    return ServeCommand(host=args.host, port=args.port)
+
+
+class ServeCommand(BaseCommand):
+    """Start the web UI server."""
+
+    @staticmethod
+    def register_subcommand(parser: ArgumentParser):
+        """Register the serve subcommand."""
+        serve_parser = parser.add_parser(
+            "serve",
+            help="Start the web UI server",
+        )
+
+        serve_parser.add_argument(
+            "--host",
+            type=str,
+            default="127.0.0.1",
+            help="Host to bind to (default: 127.0.0.1)",
+        )
+        serve_parser.add_argument(
+            "--port",
+            type=int,
+            default=8000,
+            help="Port to bind to (default: 8000)",
+        )
+
+        serve_parser.set_defaults(func=serve_command_factory)
+
+    def __init__(self, host: str = "127.0.0.1", port: int = 8000):
+        self.host = host
+        self.port = port
+
+    def execute(self):
+        """Execute the serve command."""
+        import uvicorn
+
+        from vespaembed.web.app import app
+
+        logger.info(f"Starting VespaEmbed web UI at http://{self.host}:{self.port}")
+        uvicorn.run(app, host=self.host, port=self.port)

vespaembed/cli/commands/train.py
ADDED
@@ -0,0 +1,267 @@
+import time
+from argparse import ArgumentParser, Namespace
+from pathlib import Path
+from typing import Optional
+
+# Import tasks to register them
+import vespaembed.tasks  # noqa: F401
+from vespaembed.cli import BaseCommand
+from vespaembed.core.config import TrainingConfig, load_config_from_yaml
+from vespaembed.core.trainer import VespaEmbedTrainer
+from vespaembed.utils.logging import logger
+
+# Projects directory
+PROJECTS_DIR = Path.home() / ".vespaembed" / "projects"
+
+
+def train_command_factory(args: Namespace) -> "TrainCommand":
+    """Factory function for TrainCommand."""
+    return TrainCommand(
+        config_path=args.config,
+        data=args.data,
+        task=args.task,
+        base_model=args.base_model,
+        project=args.project,
+        eval_data=args.eval_data,
+        epochs=args.epochs,
+        batch_size=args.batch_size,
+        learning_rate=args.learning_rate,
+        optimizer=args.optimizer,
+        scheduler=args.scheduler,
+        unsloth=args.unsloth,
+        matryoshka=args.matryoshka,
+        matryoshka_dims=args.matryoshka_dims,
+        subset=args.subset,
+        split=args.split,
+    )
+
+
+class TrainCommand(BaseCommand):
+    """Train or fine-tune an embedding model."""
+
+    @staticmethod
+    def register_subcommand(parser: ArgumentParser):
+        """Register the train subcommand."""
+        train_parser = parser.add_parser(
+            "train",
+            help="Train or fine-tune an embedding model",
+        )
+
+        # Config file (alternative to CLI args)
+        train_parser.add_argument(
+            "--config",
+            type=str,
+            default=None,
+            help="Path to YAML configuration file",
+        )
+
+        # Required arguments (unless using config file)
+        train_parser.add_argument(
+            "--data",
+            type=str,
+            default=None,
+            help="Path to training data (CSV, JSONL, or HF dataset)",
+        )
+        train_parser.add_argument(
+            "--task",
+            type=str,
+            default=None,
+            choices=["pairs", "triplets", "similarity", "tsdae"],
+            help="Training task type",
+        )
+        train_parser.add_argument(
+            "--base-model",
+            type=str,
+            default=None,
+            help="Base model name or path",
+        )
+
+        # Optional arguments
+        train_parser.add_argument(
+            "--project",
+            type=str,
+            default=None,
+            help="Project name. Output saved to ~/.vespaembed/projects/<name>/",
+        )
+        train_parser.add_argument(
+            "--eval-data",
+            type=str,
+            default=None,
+            help="Path to evaluation data",
+        )
+        train_parser.add_argument(
+            "--epochs",
+            type=int,
+            default=3,
+            help="Number of training epochs (default: 3)",
+        )
+        train_parser.add_argument(
+            "--batch-size",
+            type=int,
+            default=32,
+            help="Batch size (default: 32)",
+        )
+        train_parser.add_argument(
+            "--learning-rate",
+            type=float,
+            default=2e-5,
+            help="Learning rate (default: 2e-5)",
+        )
+        train_parser.add_argument(
+            "--optimizer",
+            type=str,
+            default="adamw_torch",
+            choices=["adamw_torch", "adamw_torch_fused", "adamw_8bit", "adafactor", "sgd"],
+            help="Optimizer type (default: adamw_torch)",
+        )
+        train_parser.add_argument(
+            "--scheduler",
+            type=str,
+            default="linear",
+            choices=["linear", "cosine", "cosine_with_restarts", "constant", "constant_with_warmup", "polynomial"],
+            help="Learning rate scheduler (default: linear)",
+        )
+        train_parser.add_argument(
+            "--unsloth",
+            action="store_true",
+            help="Use Unsloth for faster training",
+        )
+        train_parser.add_argument(
+            "--matryoshka",
+            action="store_true",
+            help="Enable Matryoshka embeddings (multi-dimensional)",
+        )
+        train_parser.add_argument(
+            "--matryoshka-dims",
+            type=str,
+            default="768,512,256,128,64",
+            help="Matryoshka dimensions, comma-separated (default: 768,512,256,128,64)",
+        )
+
+        # HuggingFace dataset options
+        train_parser.add_argument(
+            "--subset",
+            type=str,
+            default=None,
+            help="HuggingFace dataset subset",
+        )
+        train_parser.add_argument(
+            "--split",
+            type=str,
+            default=None,
+            help="HuggingFace dataset split",
+        )
+
+        train_parser.set_defaults(func=train_command_factory)
+
+    def __init__(
+        self,
+        config_path: Optional[str] = None,
+        data: Optional[str] = None,
+        task: Optional[str] = None,
+        base_model: Optional[str] = None,
+        project: Optional[str] = None,
+        eval_data: Optional[str] = None,
+        epochs: int = 3,
+        batch_size: int = 32,
+        learning_rate: float = 2e-5,
+        optimizer: str = "adamw_torch",
+        scheduler: str = "linear",
+        unsloth: bool = False,
+        matryoshka: bool = False,
+        matryoshka_dims: str = "768,512,256,128,64",
+        subset: Optional[str] = None,
+        split: Optional[str] = None,
+    ):
+        self.config_path = config_path
+        self.data = data
+        self.task = task
+        self.base_model = base_model
+        self.project = project
+        self.eval_data = eval_data
+        self.epochs = epochs
+        self.batch_size = batch_size
+        self.learning_rate = learning_rate
+        self.optimizer = optimizer
+        self.scheduler = scheduler
+        self.unsloth = unsloth
+        self.matryoshka = matryoshka
+        self.matryoshka_dims = matryoshka_dims
+        self.subset = subset
+        self.split = split
+
+    def _generate_project_name(self) -> str:
+        """Generate a random project name."""
+        import random
+        import string
+
+        return "".join(random.choices(string.ascii_lowercase + string.digits, k=8))
+
+    def _resolve_output_dir(self, project_name: str) -> Path:
+        """Resolve output directory from project name."""
+        PROJECTS_DIR.mkdir(parents=True, exist_ok=True)
+        output_dir = PROJECTS_DIR / project_name
+        if output_dir.exists():
+            # Append timestamp to make unique
+            timestamp = int(time.time())
+            output_dir = PROJECTS_DIR / f"{project_name}-{timestamp}"
+        output_dir.mkdir(parents=True, exist_ok=True)
+        return output_dir
+
+    def execute(self):
+        """Execute the train command."""
+        # Load config from file or build from CLI args
+        if self.config_path:
+            logger.info(f"Loading config from: {self.config_path}")
+            config = load_config_from_yaml(self.config_path)
+        else:
+            # Validate required arguments
+            if not self.data:
+                raise ValueError("--data is required (or use --config)")
+            if not self.task:
+                raise ValueError("--task is required (or use --config)")
+            if not self.base_model:
+                raise ValueError("--base-model is required (or use --config)")
+
+            # Generate project name if not provided
+            project_name = self.project or self._generate_project_name()
+            output_dir = self._resolve_output_dir(project_name)
+
+            logger.info(f"Project: {project_name}")
+            logger.info(f"Output: {output_dir}")
+
+            # Parse matryoshka dimensions if enabled
+            matryoshka_dims = None
+            if self.matryoshka:
+                if self.task == "tsdae":
+                    raise ValueError("Matryoshka is not supported with TSDAE task")
+                matryoshka_dims = [int(d.strip()) for d in self.matryoshka_dims.split(",") if d.strip()]
+                logger.info(f"Matryoshka enabled with dimensions: {matryoshka_dims}")
+
+            # Build config from CLI args
+            config = TrainingConfig(
+                task=self.task,
+                base_model=self.base_model,
+                data={
+                    "train": self.data,
+                    "eval": self.eval_data,
+                    "subset": self.subset,
+                    "split": self.split,
+                },
+                training={
+                    "epochs": self.epochs,
+                    "batch_size": self.batch_size,
+                    "learning_rate": self.learning_rate,
+                    "optimizer": self.optimizer,
+                    "scheduler": self.scheduler,
+                },
+                output={
+                    "dir": str(output_dir),
+                },
+                unsloth=self.unsloth,
+                matryoshka_dims=matryoshka_dims,
+            )
+
+        # Create and run trainer
+        trainer = VespaEmbedTrainer(config)
+        trainer.train()
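
TrainCommand can also be constructed directly in Python, bypassing argparse; its parameters mirror the CLI flags above. A rough sketch (the data file, project name, and base model below are placeholders, not values shipped with the package):

from vespaembed.cli.commands.train import TrainCommand

# Placeholder inputs; any pairs-format file and sentence-transformers-compatible
# base model could be substituted here.
command = TrainCommand(
    data="train_pairs.jsonl",
    task="pairs",
    base_model="sentence-transformers/all-MiniLM-L6-v2",
    project="demo",
    epochs=1,
    batch_size=16,
)
command.execute()  # builds a TrainingConfig and runs VespaEmbedTrainer

With these arguments the output would land under ~/.vespaembed/projects/demo/ (or a timestamped sibling if that directory already exists), matching _resolve_output_dir above.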

vespaembed/cli/vespaembed.py
ADDED
@@ -0,0 +1,55 @@
+import argparse
+
+from vespaembed.cli.commands.evaluate import EvaluateCommand
+from vespaembed.cli.commands.export import ExportCommand
+from vespaembed.cli.commands.info import InfoCommand
+from vespaembed.cli.commands.serve import ServeCommand
+from vespaembed.cli.commands.train import TrainCommand
+
+
+def main():
+    """Main entry point for vespaembed CLI."""
+    parser = argparse.ArgumentParser(
+        prog="vespaembed",
+        description="VespaEmbed - No-code training for embedding models",
+        usage="vespaembed [<command>] [<args>]",
+    )
+
+    # Global arguments
+    parser.add_argument(
+        "--host",
+        type=str,
+        default="127.0.0.1",
+        help="Host for web UI (default: 127.0.0.1)",
+    )
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=8000,
+        help="Port for web UI (default: 8000)",
+    )
+
+    # Subcommands
+    commands_parser = parser.add_subparsers(dest="command")
+
+    # Register all commands
+    TrainCommand.register_subcommand(commands_parser)
+    EvaluateCommand.register_subcommand(commands_parser)
+    ExportCommand.register_subcommand(commands_parser)
+    ServeCommand.register_subcommand(commands_parser)
+    InfoCommand.register_subcommand(commands_parser)
+
+    args = parser.parse_args()
+
+    # If no command specified, launch web UI
+    if not hasattr(args, "func") or args.func is None:
+        command = ServeCommand(host=args.host, port=args.port)
+        command.execute()
+    else:
+        # Execute the specified command
+        command = args.func(args)
+        command.execute()
+
+
+if __name__ == "__main__":
+    main()
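
The main() entry point above can also be driven programmatically; with no subcommand it falls through to ServeCommand and starts the web UI. A small sketch under that assumption (the host and port values are arbitrary examples):

import sys

from vespaembed.cli.vespaembed import main

# Simulate `vespaembed --host 0.0.0.0 --port 8080` with no subcommand,
# which launches the web UI via ServeCommand per the dispatch logic above.
sys.argv = ["vespaembed", "--host", "0.0.0.0", "--port", "8080"]
main()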