vespaembed 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vespaembed/__init__.py +1 -1
- vespaembed/cli/__init__.py +17 -0
- vespaembed/cli/commands/__init__.py +7 -0
- vespaembed/cli/commands/evaluate.py +85 -0
- vespaembed/cli/commands/export.py +86 -0
- vespaembed/cli/commands/info.py +52 -0
- vespaembed/cli/commands/serve.py +49 -0
- vespaembed/cli/commands/train.py +267 -0
- vespaembed/cli/vespaembed.py +55 -0
- vespaembed/core/__init__.py +2 -0
- vespaembed/core/config.py +164 -0
- vespaembed/core/registry.py +158 -0
- vespaembed/core/trainer.py +573 -0
- vespaembed/datasets/__init__.py +3 -0
- vespaembed/datasets/formats/__init__.py +5 -0
- vespaembed/datasets/formats/csv.py +15 -0
- vespaembed/datasets/formats/huggingface.py +34 -0
- vespaembed/datasets/formats/jsonl.py +26 -0
- vespaembed/datasets/loader.py +80 -0
- vespaembed/db.py +176 -0
- vespaembed/enums.py +58 -0
- vespaembed/evaluation/__init__.py +3 -0
- vespaembed/evaluation/factory.py +86 -0
- vespaembed/models/__init__.py +4 -0
- vespaembed/models/export.py +89 -0
- vespaembed/models/loader.py +25 -0
- vespaembed/static/css/styles.css +1800 -0
- vespaembed/static/js/app.js +1485 -0
- vespaembed/tasks/__init__.py +23 -0
- vespaembed/tasks/base.py +144 -0
- vespaembed/tasks/pairs.py +91 -0
- vespaembed/tasks/similarity.py +84 -0
- vespaembed/tasks/triplets.py +90 -0
- vespaembed/tasks/tsdae.py +102 -0
- vespaembed/templates/index.html +544 -0
- vespaembed/utils/__init__.py +3 -0
- vespaembed/utils/logging.py +69 -0
- vespaembed/web/__init__.py +1 -0
- vespaembed/web/api/__init__.py +1 -0
- vespaembed/web/app.py +605 -0
- vespaembed/worker.py +313 -0
- vespaembed-0.0.3.dist-info/METADATA +325 -0
- vespaembed-0.0.3.dist-info/RECORD +47 -0
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/WHEEL +1 -1
- vespaembed-0.0.1.dist-info/METADATA +0 -20
- vespaembed-0.0.1.dist-info/RECORD +0 -7
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/entry_points.txt +0 -0
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/licenses/LICENSE +0 -0
- {vespaembed-0.0.1.dist-info → vespaembed-0.0.3.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# Tasks module - importing this module loads every concrete task class,
# which registers each of them with the Registry via their decorators.
#
# Available tasks (keyed by data format):
#   1. pairs      - Text pairs for semantic search (anchor, positive)
#   2. triplets   - Text triplets with hard negatives (anchor, positive, negative)
#   3. similarity - Text pairs with similarity scores (sentence1, sentence2, score)
#   4. tsdae      - Unlabeled text for unsupervised learning (text)
#
# Matryoshka is a training option (--matryoshka flag) that can be enabled
# for any task except TSDAE.
from vespaembed.tasks.base import BaseTask
from vespaembed.tasks.pairs import PairsTask
from vespaembed.tasks.similarity import SimilarityTask
from vespaembed.tasks.triplets import TripletsTask
from vespaembed.tasks.tsdae import TSDAETask

__all__ = [
    "BaseTask",
    "PairsTask",
    "SimilarityTask",
    "TripletsTask",
    "TSDAETask",
]
|
vespaembed/tasks/base.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from datasets import Dataset
|
|
5
|
+
from sentence_transformers import SentenceTransformer
|
|
6
|
+
from sentence_transformers.training_args import BatchSamplers
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class BaseTask(ABC):
    """Abstract base class shared by every training task.

    Subclasses declare their data contract (required/optional columns plus
    accepted aliases), their loss variants, and implement ``get_loss`` and
    ``get_evaluator``.
    """

    # Task metadata
    name: str = ""
    description: str = ""

    # Column configuration
    expected_columns: list[str] = []  # Required columns
    optional_columns: list[str] = []  # Optional columns (included if present)
    column_aliases: dict[str, list[str]] = {}

    # Loss configuration
    loss_options: list[str] = []  # Available loss variants
    default_loss: str = ""  # Default loss variant

    # Batch sampler (NO_DUPLICATES for in-batch negative losses)
    batch_sampler: BatchSamplers = BatchSamplers.BATCH_SAMPLER

    def __init__(self, loss_variant: str | None = None):
        """Initialize the task.

        Args:
            loss_variant: Which loss variant to use (if task supports multiple)

        Raises:
            ValueError: If ``loss_variant`` is not one of ``loss_options``.
        """
        # Label encoding mappings; populated by prepare_dataset only for
        # tasks that carry labels, otherwise they stay None.
        self._label_to_idx: dict[str, int] | None = None
        self._idx_to_label: dict[int, str] | None = None

        # Resolve the loss variant: fall back to the default when none is
        # given, and reject names outside the declared options.
        if not loss_variant:
            self._loss_variant = self.default_loss
        elif self.loss_options and loss_variant not in self.loss_options:
            raise ValueError(f"Unknown loss variant '{loss_variant}'. Options: {self.loss_options}")
        else:
            self._loss_variant = loss_variant

        # Optional columns actually present in the data (set by prepare_dataset)
        self._found_optional_columns: list[str] = []

    @property
    def loss_variant(self) -> str:
        """The loss variant this task instance was configured with."""
        return self._loss_variant

    @property
    def found_optional_columns(self) -> list[str]:
        """Optional columns that prepare_dataset found in the dataset."""
        return self._found_optional_columns

    @property
    def label_to_idx(self) -> dict[str, int] | None:
        """Mapping from label name to index (None when the task is unlabeled)."""
        return self._label_to_idx

    @property
    def idx_to_label(self) -> dict[int, str] | None:
        """Mapping from index to label name (None when the task is unlabeled)."""
        return self._idx_to_label

    @property
    def num_labels(self) -> int | None:
        """Number of distinct labels (None when the task is unlabeled)."""
        mapping = self._label_to_idx
        return None if mapping is None else len(mapping)

    def get_label_config(self) -> dict | None:
        """Return label configuration in HuggingFace format.

        Returns:
            Dict with id2label, label2id, and num_labels, or None if no labels
        """
        if self._label_to_idx is None:
            return None
        # HuggingFace configs serialize id2label with string keys.
        return {
            "id2label": {str(idx): label for idx, label in self._idx_to_label.items()},
            "label2id": self._label_to_idx,
            "num_labels": len(self._label_to_idx),
        }

    @abstractmethod
    def get_loss(self, model: SentenceTransformer, **kwargs) -> Any:
        """Return a configured loss function from sentence-transformers.

        Args:
            model: The SentenceTransformer model
            **kwargs: Additional loss configuration

        Returns:
            Loss function instance
        """
        raise NotImplementedError

    def prepare_dataset(self, dataset: Dataset) -> Dataset:
        """Normalize column names and reorder columns for the loss function.

        Args:
            dataset: Input dataset

        Returns:
            Prepared dataset with normalized columns

        Raises:
            ValueError: If any required column is missing after alias renaming.
        """
        # 1. Map aliased column names onto canonical names (first match wins,
        #    and a column that already exists is never overwritten)
        for canonical, aliases in self.column_aliases.items():
            if canonical in dataset.column_names:
                continue
            for alias in aliases:
                if alias in dataset.column_names:
                    dataset = dataset.rename_column(alias, canonical)
                    break

        # 2. Fail fast when required columns are still absent
        missing = set(self.expected_columns) - set(dataset.column_names)
        if missing:
            available = ", ".join(sorted(dataset.column_names))
            raise ValueError(
                f"Missing required columns for task '{self.name}': {missing}. Available columns: {available}"
            )

        # 3. Record which optional columns the data provides
        self._found_optional_columns = [col for col in self.optional_columns if col in dataset.column_names]

        # 4. Keep only the relevant columns: required first, then found optional
        return dataset.select_columns(self.expected_columns + self._found_optional_columns)

    @abstractmethod
    def get_evaluator(self, eval_dataset: Dataset) -> Any:
        """Return the sentence-transformers evaluator appropriate for this task.

        Args:
            eval_dataset: Evaluation dataset (already prepared)

        Returns:
            Evaluator instance
        """
        raise NotImplementedError
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
from datasets import Dataset
|
|
2
|
+
from sentence_transformers import SentenceTransformer
|
|
3
|
+
from sentence_transformers.evaluation import InformationRetrievalEvaluator
|
|
4
|
+
from sentence_transformers.losses import (
|
|
5
|
+
CachedGISTEmbedLoss,
|
|
6
|
+
CachedMultipleNegativesRankingLoss,
|
|
7
|
+
GISTEmbedLoss,
|
|
8
|
+
MultipleNegativesRankingLoss,
|
|
9
|
+
MultipleNegativesSymmetricRankingLoss,
|
|
10
|
+
)
|
|
11
|
+
from sentence_transformers.training_args import BatchSamplers
|
|
12
|
+
|
|
13
|
+
from vespaembed.core.registry import Registry
|
|
14
|
+
from vespaembed.tasks.base import BaseTask
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@Registry.register_task("pairs")
class PairsTask(BaseTask):
    """Text pairs task for semantic search and retrieval.

    Use this when you have query-document pairs without explicit negatives.
    The loss uses in-batch negatives for contrastive learning.

    Data format:
        - anchor: Query/question text
        - positive: Relevant document/answer

    Loss variants:
        - mnr: MultipleNegativesRankingLoss (default) - standard in-batch negatives
        - mnr_symmetric: Bidirectional ranking - use if you need "given answer, find query"
        - gist: GISTEmbedLoss - uses guide model to filter false negatives
        - cached_mnr: Cached version - allows larger effective batch sizes
        - cached_gist: Cached GIST - combines both benefits

    Tip: Use mine_hard_negatives() to create triplets for better results.
    """

    name = "pairs"
    description = "Text pairs for semantic search (anchor, positive)"

    # Canonical columns required by every loss variant
    expected_columns = ["anchor", "positive"]
    # Alternate column names that prepare_dataset renames to the canonical ones
    column_aliases = {
        "anchor": ["query", "question", "sent1", "sentence1", "text1"],
        "positive": ["document", "answer", "pos", "sent2", "sentence2", "text2"],
    }

    # Loss variants
    loss_options = ["mnr", "mnr_symmetric", "gist", "cached_mnr", "cached_gist"]
    default_loss = "mnr"

    # In-batch negatives require that no batch contains duplicate texts
    batch_sampler = BatchSamplers.NO_DUPLICATES

    def get_loss(
        self, model: SentenceTransformer, **kwargs
    ) -> (
        MultipleNegativesRankingLoss
        | MultipleNegativesSymmetricRankingLoss
        | GISTEmbedLoss
        | CachedMultipleNegativesRankingLoss
        | CachedGISTEmbedLoss
    ):
        """Return the selected loss variant.

        Args:
            model: The SentenceTransformer model being trained.
            **kwargs: May include ``guide_model`` (for gist variants) and
                ``mini_batch_size`` (for cached variants); the rest is
                forwarded to the loss constructor.
        """
        # Pop variant-specific options so they are not forwarded to losses
        # that do not accept them.
        guide_model = kwargs.pop("guide_model", None)
        mini_batch_size = kwargs.pop("mini_batch_size", 32)

        if self._loss_variant == "mnr":
            return MultipleNegativesRankingLoss(model, **kwargs)

        elif self._loss_variant == "mnr_symmetric":
            return MultipleNegativesSymmetricRankingLoss(model, **kwargs)

        elif self._loss_variant == "gist":
            # Without an explicit guide, the trained model guides itself
            if guide_model is None:
                guide_model = model
            return GISTEmbedLoss(model, guide=guide_model, **kwargs)

        elif self._loss_variant == "cached_mnr":
            return CachedMultipleNegativesRankingLoss(model, mini_batch_size=mini_batch_size, **kwargs)

        elif self._loss_variant == "cached_gist":
            if guide_model is None:
                guide_model = model
            return CachedGISTEmbedLoss(model, guide=guide_model, mini_batch_size=mini_batch_size, **kwargs)

        else:
            # Defensive fallback; __init__ validates the variant, so this
            # branch should be unreachable in practice.
            return MultipleNegativesRankingLoss(model, **kwargs)

    def get_evaluator(self, eval_dataset: Dataset) -> InformationRetrievalEvaluator:
        """Return InformationRetrievalEvaluator for pair data.

        Each row index doubles as query id and corpus doc id, so the
        relevant document for query i is exactly document i.
        """
        queries = {str(i): text for i, text in enumerate(eval_dataset["anchor"])}
        corpus = {str(i): text for i, text in enumerate(eval_dataset["positive"])}
        relevant_docs = {str(i): {str(i)} for i in range(len(eval_dataset))}

        return InformationRetrievalEvaluator(
            queries=queries,
            corpus=corpus,
            relevant_docs=relevant_docs,
            name="pairs-eval",
        )
|
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
from datasets import Dataset
|
|
2
|
+
from sentence_transformers import SentenceTransformer
|
|
3
|
+
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
|
|
4
|
+
from sentence_transformers.losses import AnglELoss, CoSENTLoss, CosineSimilarityLoss
|
|
5
|
+
from sentence_transformers.training_args import BatchSamplers
|
|
6
|
+
|
|
7
|
+
from vespaembed.core.registry import Registry
|
|
8
|
+
from vespaembed.tasks.base import BaseTask
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@Registry.register_task("similarity")
class SimilarityTask(BaseTask):
    """Text pairs with similarity scores task.

    Use this when you have pairs of sentences with continuous similarity
    scores (e.g., 0.0 to 1.0 or 0 to 5). Common for STS (Semantic Textual
    Similarity) benchmarks.

    Data format:
        - sentence1: First sentence
        - sentence2: Second sentence
        - score: Similarity score (will be normalized to 0-1 range)

    Loss variants (similar performance, pick one):
        - cosine: CosineSimilarityLoss (default) - simple and effective
        - cosent: CoSENTLoss - ranking-based, from CoSENT paper
        - angle: AnglELoss - angle-optimized, from AnglE paper

    According to papers: AnglE >= CoSENT >= Cosine, but results are often similar.
    """

    name = "similarity"
    description = "Text pairs with similarity scores (STS-style)"

    expected_columns = ["sentence1", "sentence2", "score"]
    column_aliases = {
        "sentence1": ["sent1", "text1", "anchor", "query"],
        "sentence2": ["sent2", "text2", "positive", "document"],
        "score": ["similarity", "label", "sim_score"],
    }

    # Loss variants
    loss_options = ["cosine", "cosent", "angle"]
    default_loss = "cosine"

    batch_sampler = BatchSamplers.BATCH_SAMPLER

    def prepare_dataset(self, dataset: Dataset) -> Dataset:
        """Prepare dataset and normalize scores to 0-1 range if needed.

        Args:
            dataset: Input dataset with sentence1/sentence2/score columns
                (aliases are handled by the base class).

        Returns:
            Prepared dataset with scores scaled into [0, 1] when the
            observed maximum exceeds 1.0.
        """
        dataset = super().prepare_dataset(dataset)

        scores = dataset["score"]
        # Guard: an empty (but otherwise valid) dataset has no scores and
        # max() on an empty sequence would raise ValueError.
        if not scores:
            return dataset

        # Check if scores need normalization (e.g., 0-5 scale to 0-1)
        max_score = max(scores)
        if max_score > 1.0:
            # NOTE(review): normalization divides by the max observed in THIS
            # split; separately-prepared train/eval splits whose maxima differ
            # end up on slightly different scales — confirm this is acceptable.
            dataset = dataset.map(lambda x: {"score": x["score"] / max_score})

        return dataset

    def get_loss(self, model: SentenceTransformer, **kwargs):
        """Return the selected loss variant.

        Args:
            model: The SentenceTransformer model being trained.
            **kwargs: Forwarded to the loss constructor.
        """
        if self._loss_variant == "cosine":
            return CosineSimilarityLoss(model, **kwargs)

        elif self._loss_variant == "cosent":
            return CoSENTLoss(model, **kwargs)

        elif self._loss_variant == "angle":
            return AnglELoss(model, **kwargs)

        else:
            # Fallback to default (unreachable in practice: __init__ validates
            # the variant against loss_options)
            return CosineSimilarityLoss(model, **kwargs)

    def get_evaluator(self, eval_dataset: Dataset) -> EmbeddingSimilarityEvaluator:
        """Return EmbeddingSimilarityEvaluator over the prepared columns."""
        return EmbeddingSimilarityEvaluator(
            sentences1=eval_dataset["sentence1"],
            sentences2=eval_dataset["sentence2"],
            scores=eval_dataset["score"],
            name="similarity-eval",
        )
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
from datasets import Dataset
|
|
2
|
+
from sentence_transformers import SentenceTransformer
|
|
3
|
+
from sentence_transformers.evaluation import TripletEvaluator
|
|
4
|
+
from sentence_transformers.losses import (
|
|
5
|
+
CachedGISTEmbedLoss,
|
|
6
|
+
CachedMultipleNegativesRankingLoss,
|
|
7
|
+
GISTEmbedLoss,
|
|
8
|
+
MultipleNegativesRankingLoss,
|
|
9
|
+
MultipleNegativesSymmetricRankingLoss,
|
|
10
|
+
)
|
|
11
|
+
from sentence_transformers.training_args import BatchSamplers
|
|
12
|
+
|
|
13
|
+
from vespaembed.core.registry import Registry
|
|
14
|
+
from vespaembed.tasks.base import BaseTask
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
@Registry.register_task("triplets")
class TripletsTask(BaseTask):
    """Text triplets task for semantic search and retrieval.

    Use this when you have query-document pairs WITH explicit hard negatives.
    The loss uses both explicit negatives and in-batch negatives.

    Data format:
        - anchor: Query/question text
        - positive: Relevant document/answer
        - negative: Hard negative (irrelevant but similar document)

    Loss variants:
        - mnr: MultipleNegativesRankingLoss (default) - explicit + in-batch negatives
        - mnr_symmetric: Bidirectional ranking - use if you need "given answer, find query"
        - gist: GISTEmbedLoss - uses guide model to filter false negatives
        - cached_mnr: Cached version - allows larger effective batch sizes
        - cached_gist: Cached GIST - combines both benefits

    Tip: Hard negatives significantly improve model quality. Use
    mine_hard_negatives() to generate them from pairs data.
    """

    name = "triplets"
    description = "Text triplets for semantic search (anchor, positive, negative)"

    expected_columns = ["anchor", "positive", "negative"]
    column_aliases = {
        "anchor": ["query", "question", "sent1", "sentence1", "text1", "premise"],
        "positive": ["document", "answer", "pos", "sent2", "sentence2", "text2", "entailment"],
        "negative": ["neg", "hard_negative", "sent3", "sentence3", "text3", "contradiction"],
    }

    # Loss variants (same as pairs)
    loss_options = ["mnr", "mnr_symmetric", "gist", "cached_mnr", "cached_gist"]
    default_loss = "mnr"

    # In-batch negatives require batches without duplicate texts
    batch_sampler = BatchSamplers.NO_DUPLICATES

    def get_loss(self, model: SentenceTransformer, **kwargs) -> MultipleNegativesRankingLoss:
        """Return the selected loss variant.

        Args:
            model: The SentenceTransformer model being trained.
            **kwargs: May include ``guide_model`` (gist variants) and
                ``mini_batch_size`` (cached variants); the rest is forwarded
                to the loss constructor.
        """
        # Extract variant-specific options so plain losses never see them
        guide = kwargs.pop("guide_model", None)
        mini_batch = kwargs.pop("mini_batch_size", 32)
        variant = self._loss_variant

        if variant == "mnr_symmetric":
            return MultipleNegativesSymmetricRankingLoss(model, **kwargs)

        if variant == "gist":
            # Without an explicit guide, the trained model guides itself
            return GISTEmbedLoss(model, guide=guide if guide is not None else model, **kwargs)

        if variant == "cached_mnr":
            return CachedMultipleNegativesRankingLoss(model, mini_batch_size=mini_batch, **kwargs)

        if variant == "cached_gist":
            return CachedGISTEmbedLoss(
                model, guide=guide if guide is not None else model, mini_batch_size=mini_batch, **kwargs
            )

        # "mnr" and (defensively) any unrecognized variant use the default loss
        return MultipleNegativesRankingLoss(model, **kwargs)

    def get_evaluator(self, eval_dataset: Dataset):
        """Return TripletEvaluator over the prepared triplet columns."""
        anchors = eval_dataset["anchor"]
        positives = eval_dataset["positive"]
        negatives = eval_dataset["negative"]
        return TripletEvaluator(
            anchors=anchors,
            positives=positives,
            negatives=negatives,
            name="triplets-eval",
        )
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
import random
|
|
2
|
+
|
|
3
|
+
from datasets import Dataset
|
|
4
|
+
from sentence_transformers import SentenceTransformer
|
|
5
|
+
from sentence_transformers.losses import DenoisingAutoEncoderLoss
|
|
6
|
+
from sentence_transformers.training_args import BatchSamplers
|
|
7
|
+
|
|
8
|
+
from vespaembed.core.registry import Registry
|
|
9
|
+
from vespaembed.tasks.base import BaseTask
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def _add_noise(text: str, del_ratio: float = 0.6) -> str:
|
|
13
|
+
"""Add noise to text by randomly deleting words.
|
|
14
|
+
|
|
15
|
+
Args:
|
|
16
|
+
text: Input text string
|
|
17
|
+
del_ratio: Probability of keeping each word (default 0.6 = 60% kept)
|
|
18
|
+
|
|
19
|
+
Returns:
|
|
20
|
+
Noisy version of the text with some words randomly deleted
|
|
21
|
+
"""
|
|
22
|
+
words = text.split()
|
|
23
|
+
if not words:
|
|
24
|
+
return text
|
|
25
|
+
# Keep words with probability del_ratio
|
|
26
|
+
kept_words = [word for word in words if random.random() < del_ratio]
|
|
27
|
+
if len(kept_words) == 0:
|
|
28
|
+
# Keep at least one random word
|
|
29
|
+
return random.choice(words)
|
|
30
|
+
return " ".join(kept_words)
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@Registry.register_task("tsdae")
class TSDAETask(BaseTask):
    """TSDAE (Transformer-based Sequential Denoising Auto-Encoder) task.

    Unsupervised objective: the model learns embeddings by reconstructing
    an original sentence from a corrupted copy. Useful for domain
    adaptation when only unlabeled text is available.

    Data format:
        - text/sentence: Raw text sentences (no labels needed)

    Noise is injected automatically by randomly deleting ~40% of the words
    in each input; the model is trained to recover the original sentence
    from the corrupted one.

    Reference: https://arxiv.org/abs/2104.06979
    """

    name = "tsdae"
    description = "TSDAE - unsupervised domain adaptation with denoising auto-encoder"

    expected_columns = ["text"]
    column_aliases = {
        "text": ["sentence", "sentences", "content", "input"],
    }

    batch_sampler = BatchSamplers.BATCH_SAMPLER

    def prepare_dataset(self, dataset: Dataset) -> Dataset:
        """Build (noisy, original) sentence pairs for denoising training.

        Args:
            dataset: Input dataset with 'text' column

        Returns:
            Dataset with 'anchor' (noisy) and 'positive' (original) columns
        """
        # Base class resolves aliases and validates the 'text' column
        dataset = super().prepare_dataset(dataset)

        # anchor = corrupted input, positive = reconstruction target
        dataset = dataset.map(
            lambda x: {"anchor": _add_noise(x["text"]), "positive": x["text"]},
            desc="Adding noise to text",
        )

        # The raw text column is superseded by anchor/positive
        return dataset.remove_columns(["text"])

    def get_loss(self, model: SentenceTransformer, **kwargs) -> DenoisingAutoEncoderLoss:
        """Return DenoisingAutoEncoderLoss for unsupervised training."""
        # Tying encoder and decoder weights keeps the model compact
        return DenoisingAutoEncoderLoss(
            model=model,
            tie_encoder_decoder=True,
            **kwargs,
        )

    def get_evaluator(self, eval_dataset: Dataset):
        """Return None - TSDAE has no intrinsic evaluator.

        Evaluate on a downstream task (e.g. STS or retrieval) matching your
        target use case instead.
        """
        return None
|