vespaembed 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49):
  1. vespaembed/__init__.py +1 -1
  2. vespaembed/cli/__init__.py +17 -0
  3. vespaembed/cli/commands/__init__.py +7 -0
  4. vespaembed/cli/commands/evaluate.py +85 -0
  5. vespaembed/cli/commands/export.py +86 -0
  6. vespaembed/cli/commands/info.py +52 -0
  7. vespaembed/cli/commands/serve.py +49 -0
  8. vespaembed/cli/commands/train.py +267 -0
  9. vespaembed/cli/vespaembed.py +55 -0
  10. vespaembed/core/__init__.py +2 -0
  11. vespaembed/core/config.py +164 -0
  12. vespaembed/core/registry.py +158 -0
  13. vespaembed/core/trainer.py +573 -0
  14. vespaembed/datasets/__init__.py +3 -0
  15. vespaembed/datasets/formats/__init__.py +5 -0
  16. vespaembed/datasets/formats/csv.py +15 -0
  17. vespaembed/datasets/formats/huggingface.py +34 -0
  18. vespaembed/datasets/formats/jsonl.py +26 -0
  19. vespaembed/datasets/loader.py +80 -0
  20. vespaembed/db.py +176 -0
  21. vespaembed/enums.py +58 -0
  22. vespaembed/evaluation/__init__.py +3 -0
  23. vespaembed/evaluation/factory.py +86 -0
  24. vespaembed/models/__init__.py +4 -0
  25. vespaembed/models/export.py +89 -0
  26. vespaembed/models/loader.py +25 -0
  27. vespaembed/static/css/styles.css +1800 -0
  28. vespaembed/static/js/app.js +1485 -0
  29. vespaembed/tasks/__init__.py +23 -0
  30. vespaembed/tasks/base.py +144 -0
  31. vespaembed/tasks/pairs.py +91 -0
  32. vespaembed/tasks/similarity.py +84 -0
  33. vespaembed/tasks/triplets.py +90 -0
  34. vespaembed/tasks/tsdae.py +102 -0
  35. vespaembed/templates/index.html +544 -0
  36. vespaembed/utils/__init__.py +3 -0
  37. vespaembed/utils/logging.py +69 -0
  38. vespaembed/web/__init__.py +1 -0
  39. vespaembed/web/api/__init__.py +1 -0
  40. vespaembed/web/app.py +605 -0
  41. vespaembed/worker.py +313 -0
  42. vespaembed-0.0.2.dist-info/METADATA +325 -0
  43. vespaembed-0.0.2.dist-info/RECORD +47 -0
  44. {vespaembed-0.0.1.dist-info → vespaembed-0.0.2.dist-info}/WHEEL +1 -1
  45. vespaembed-0.0.1.dist-info/METADATA +0 -20
  46. vespaembed-0.0.1.dist-info/RECORD +0 -7
  47. {vespaembed-0.0.1.dist-info → vespaembed-0.0.2.dist-info}/entry_points.txt +0 -0
  48. {vespaembed-0.0.1.dist-info → vespaembed-0.0.2.dist-info}/licenses/LICENSE +0 -0
  49. {vespaembed-0.0.1.dist-info → vespaembed-0.0.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,23 @@
1
# Tasks module - importing this module registers all tasks with the Registry
# (each task module applies the @Registry.register_task decorator at import time).
#
# Available tasks (by data format):
#   1. pairs      - Text pairs for semantic search (anchor, positive)
#   2. triplets   - Text triplets with hard negatives (anchor, positive, negative)
#   3. similarity - Text pairs with similarity scores (sentence1, sentence2, score)
#   4. tsdae      - Unlabeled text for unsupervised learning (text)
#
# Matryoshka is a training option (--matryoshka flag) that can be enabled for any task except TSDAE.
from vespaembed.tasks.base import BaseTask
from vespaembed.tasks.pairs import PairsTask
from vespaembed.tasks.similarity import SimilarityTask
from vespaembed.tasks.triplets import TripletsTask
from vespaembed.tasks.tsdae import TSDAETask

__all__ = [
    "BaseTask",
    "PairsTask",
    "SimilarityTask",
    "TripletsTask",
    "TSDAETask",
]
@@ -0,0 +1,144 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any
3
+
4
+ from datasets import Dataset
5
+ from sentence_transformers import SentenceTransformer
6
+ from sentence_transformers.training_args import BatchSamplers
7
+
8
+
9
class BaseTask(ABC):
    """Base class for all training tasks.

    Subclasses declare their data contract through the class attributes
    below (required/optional columns, aliases, loss variants, batch
    sampler) and implement ``get_loss()`` and ``get_evaluator()``.
    ``prepare_dataset()`` normalizes an input dataset to the task's
    canonical column layout.
    """

    # Task metadata
    name: str = ""  # Registry key / task identifier used in error messages
    description: str = ""  # One-line human-readable description

    # Column configuration
    expected_columns: list[str] = []  # Required columns
    optional_columns: list[str] = []  # Optional columns (included if present)
    # Maps canonical column name -> accepted alias names in input data
    column_aliases: dict[str, list[str]] = {}

    # Loss configuration
    loss_options: list[str] = []  # Available loss variants
    default_loss: str = ""  # Default loss variant

    # Batch sampler (NO_DUPLICATES for in-batch negative losses)
    batch_sampler: BatchSamplers = BatchSamplers.BATCH_SAMPLER

    def __init__(self, loss_variant: str | None = None):
        """Initialize task with optional loss variant.

        Args:
            loss_variant: Which loss variant to use (if task supports multiple).
                When omitted (or falsy), ``default_loss`` is used. Validated
                against ``loss_options`` only when that list is non-empty.

        Raises:
            ValueError: If ``loss_variant`` is given but not in ``loss_options``.
        """
        # Label encoding mappings (set by prepare_dataset if task has labels)
        self._label_to_idx: dict[str, int] | None = None
        self._idx_to_label: dict[int, str] | None = None

        # Loss variant selection
        if loss_variant:
            # Only validate when the task actually declares variants.
            if self.loss_options and loss_variant not in self.loss_options:
                raise ValueError(f"Unknown loss variant '{loss_variant}'. Options: {self.loss_options}")
            self._loss_variant = loss_variant
        else:
            self._loss_variant = self.default_loss

        # Track which optional columns were found
        self._found_optional_columns: list[str] = []

    @property
    def loss_variant(self) -> str:
        """Return the selected loss variant."""
        return self._loss_variant

    @property
    def found_optional_columns(self) -> list[str]:
        """Return list of optional columns that were found in the dataset.

        Populated by ``prepare_dataset``; empty before it has been called.
        """
        return self._found_optional_columns

    @property
    def label_to_idx(self) -> dict[str, int] | None:
        """Return label to index mapping, or None if task has no labels."""
        return self._label_to_idx

    @property
    def idx_to_label(self) -> dict[int, str] | None:
        """Return index to label mapping, or None if task has no labels."""
        return self._idx_to_label

    @property
    def num_labels(self) -> int | None:
        """Return number of labels, or None if task has no labels."""
        if self._label_to_idx is None:
            return None
        return len(self._label_to_idx)

    def get_label_config(self) -> dict | None:
        """Return label configuration in HuggingFace format.

        Returns:
            Dict with id2label, label2id, and num_labels, or None if no labels
        """
        if self._label_to_idx is None:
            return None
        return {
            # Keys stringified to match the serialized HF config convention.
            "id2label": {str(k): v for k, v in self._idx_to_label.items()},
            "label2id": self._label_to_idx,
            "num_labels": len(self._label_to_idx),
        }

    @abstractmethod
    def get_loss(self, model: SentenceTransformer, **kwargs) -> Any:
        """Return configured loss function from sentence-transformers.

        Args:
            model: The SentenceTransformer model
            **kwargs: Additional loss configuration

        Returns:
            Loss function instance
        """
        raise NotImplementedError

    def prepare_dataset(self, dataset: Dataset) -> Dataset:
        """Normalize column names and reorder for the loss function.

        Args:
            dataset: Input dataset

        Returns:
            Prepared dataset with normalized columns

        Raises:
            ValueError: If a required column (after alias resolution) is missing.
        """
        # 1. Rename aliased columns to canonical names
        for canonical, aliases in self.column_aliases.items():
            for alias in aliases:
                # Skip renaming when the canonical column already exists.
                if alias in dataset.column_names and canonical not in dataset.column_names:
                    dataset = dataset.rename_column(alias, canonical)
                    # First matching alias wins; stop scanning the rest.
                    break

        # 2. Validate required columns exist
        missing = set(self.expected_columns) - set(dataset.column_names)
        if missing:
            available = ", ".join(sorted(dataset.column_names))
            raise ValueError(
                f"Missing required columns for task '{self.name}': {missing}. " f"Available columns: {available}"
            )

        # 3. Find which optional columns are present
        self._found_optional_columns = [col for col in self.optional_columns if col in dataset.column_names]

        # 4. Select and reorder columns (required + found optional)
        columns_to_select = self.expected_columns + self._found_optional_columns
        return dataset.select_columns(columns_to_select)

    @abstractmethod
    def get_evaluator(self, eval_dataset: Dataset) -> Any:
        """Return appropriate sentence-transformers evaluator for this task.

        Args:
            eval_dataset: Evaluation dataset (already prepared)

        Returns:
            Evaluator instance
        """
        raise NotImplementedError
@@ -0,0 +1,91 @@
1
+ from datasets import Dataset
2
+ from sentence_transformers import SentenceTransformer
3
+ from sentence_transformers.evaluation import InformationRetrievalEvaluator
4
+ from sentence_transformers.losses import (
5
+ CachedGISTEmbedLoss,
6
+ CachedMultipleNegativesRankingLoss,
7
+ GISTEmbedLoss,
8
+ MultipleNegativesRankingLoss,
9
+ MultipleNegativesSymmetricRankingLoss,
10
+ )
11
+ from sentence_transformers.training_args import BatchSamplers
12
+
13
+ from vespaembed.core.registry import Registry
14
+ from vespaembed.tasks.base import BaseTask
15
+
16
+
17
@Registry.register_task("pairs")
class PairsTask(BaseTask):
    """Text pairs task for semantic search and retrieval.

    Use this when you have query-document pairs without explicit negatives.
    The loss uses in-batch negatives for contrastive learning.

    Data format:
        - anchor: Query/question text
        - positive: Relevant document/answer

    Loss variants:
        - mnr: MultipleNegativesRankingLoss (default) - standard in-batch negatives
        - mnr_symmetric: Bidirectional ranking - use if you need "given answer, find query"
        - gist: GISTEmbedLoss - uses guide model to filter false negatives
        - cached_mnr: Cached version - allows larger effective batch sizes
        - cached_gist: Cached GIST - combines both benefits

    Tip: Use mine_hard_negatives() to create triplets for better results.
    """

    name = "pairs"
    description = "Text pairs for semantic search (anchor, positive)"

    expected_columns = ["anchor", "positive"]
    column_aliases = {
        "anchor": ["query", "question", "sent1", "sentence1", "text1"],
        "positive": ["document", "answer", "pos", "sent2", "sentence2", "text2"],
    }

    # Loss variants
    loss_options = ["mnr", "mnr_symmetric", "gist", "cached_mnr", "cached_gist"]
    default_loss = "mnr"

    # In-batch negative losses degrade if a batch contains duplicate texts.
    batch_sampler = BatchSamplers.NO_DUPLICATES

    def get_loss(
        self, model: SentenceTransformer, **kwargs
    ) -> (
        MultipleNegativesRankingLoss
        | MultipleNegativesSymmetricRankingLoss
        | GISTEmbedLoss
        | CachedMultipleNegativesRankingLoss
        | CachedGISTEmbedLoss
    ):
        """Return the selected loss variant.

        The return annotation covers every variant this method can produce
        (previously it claimed only MultipleNegativesRankingLoss, which was
        wrong for four of the five branches).

        Args:
            model: The SentenceTransformer model being trained.
            **kwargs: Extra loss configuration. Recognized keys:
                guide_model: Guide model for the gist/cached_gist variants;
                    defaults to ``model`` itself when not provided.
                mini_batch_size: Sub-batch size for the cached variants
                    (default 32).

        Returns:
            Configured loss instance for the selected variant.
        """
        # Pop variant-specific options so they never reach losses that do
        # not accept them (e.g. plain MNR rejects a guide= kwarg).
        guide_model = kwargs.pop("guide_model", None)
        mini_batch_size = kwargs.pop("mini_batch_size", 32)

        if self._loss_variant == "mnr":
            return MultipleNegativesRankingLoss(model, **kwargs)

        elif self._loss_variant == "mnr_symmetric":
            return MultipleNegativesSymmetricRankingLoss(model, **kwargs)

        elif self._loss_variant == "gist":
            if guide_model is None:
                # The model guides itself when no separate guide is supplied.
                guide_model = model
            return GISTEmbedLoss(model, guide=guide_model, **kwargs)

        elif self._loss_variant == "cached_mnr":
            return CachedMultipleNegativesRankingLoss(model, mini_batch_size=mini_batch_size, **kwargs)

        elif self._loss_variant == "cached_gist":
            if guide_model is None:
                guide_model = model
            return CachedGISTEmbedLoss(model, guide=guide_model, mini_batch_size=mini_batch_size, **kwargs)

        else:
            # Unknown variant: silently fall back to the default loss.
            return MultipleNegativesRankingLoss(model, **kwargs)

    def get_evaluator(self, eval_dataset: Dataset):
        """Return InformationRetrievalEvaluator for pair data.

        Each row index serves as both the query id and the corpus id, so the
        only relevant document for row i's anchor is row i's positive.

        Args:
            eval_dataset: Prepared dataset with 'anchor' and 'positive' columns.

        Returns:
            InformationRetrievalEvaluator over the anchor/positive pairs.
        """
        queries = {str(i): text for i, text in enumerate(eval_dataset["anchor"])}
        corpus = {str(i): text for i, text in enumerate(eval_dataset["positive"])}
        relevant_docs = {str(i): {str(i)} for i in range(len(eval_dataset))}

        return InformationRetrievalEvaluator(
            queries=queries,
            corpus=corpus,
            relevant_docs=relevant_docs,
            name="pairs-eval",
        )
@@ -0,0 +1,84 @@
1
+ from datasets import Dataset
2
+ from sentence_transformers import SentenceTransformer
3
+ from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
4
+ from sentence_transformers.losses import AnglELoss, CoSENTLoss, CosineSimilarityLoss
5
+ from sentence_transformers.training_args import BatchSamplers
6
+
7
+ from vespaembed.core.registry import Registry
8
+ from vespaembed.tasks.base import BaseTask
9
+
10
+
11
@Registry.register_task("similarity")
class SimilarityTask(BaseTask):
    """Text pairs with similarity scores task.

    Use this when you have pairs of sentences with continuous similarity
    scores (e.g., 0.0 to 1.0 or 0 to 5). Common for STS (Semantic Textual
    Similarity) benchmarks.

    Data format:
        - sentence1: First sentence
        - sentence2: Second sentence
        - score: Similarity score (will be normalized to 0-1 range)

    Loss variants (similar performance, pick one):
        - cosine: CosineSimilarityLoss (default) - simple and effective
        - cosent: CoSENTLoss - ranking-based, from CoSENT paper
        - angle: AnglELoss - angle-optimized, from AnglE paper

    According to papers: AnglE >= CoSENT >= Cosine, but results are often similar.
    """

    name = "similarity"
    description = "Text pairs with similarity scores (STS-style)"

    expected_columns = ["sentence1", "sentence2", "score"]
    column_aliases = {
        "sentence1": ["sent1", "text1", "anchor", "query"],
        "sentence2": ["sent2", "text2", "positive", "document"],
        "score": ["similarity", "label", "sim_score"],
    }

    # Loss variants
    loss_options = ["cosine", "cosent", "angle"]
    default_loss = "cosine"

    batch_sampler = BatchSamplers.BATCH_SAMPLER

    def prepare_dataset(self, dataset: Dataset) -> Dataset:
        """Prepare dataset and normalize scores to 0-1 range if needed.

        Args:
            dataset: Input dataset (columns are normalized by the base class).

        Returns:
            Dataset whose 'score' column lies in [0, 1] when the input
            scores exceeded 1.0.
        """
        dataset = super().prepare_dataset(dataset)

        # Check if scores need normalization (e.g., 0-5 scale to 0-1)
        scores = dataset["score"]
        if len(scores) == 0:
            # Fix: max() on an empty sequence raises ValueError; an empty
            # dataset has nothing to normalize.
            return dataset

        # NOTE(review): this divides by the *observed* maximum, so a 0-5
        # scale whose sample max is e.g. 4.2 gets slightly stretched, and
        # train/eval splits may be scaled differently — confirm this
        # approximation is acceptable for the intended datasets.
        max_score = max(scores)

        if max_score > 1.0:
            # Normalize to 0-1 range
            dataset = dataset.map(lambda x: {"score": x["score"] / max_score})

        return dataset

    def get_loss(self, model: SentenceTransformer, **kwargs):
        """Return the selected loss variant.

        Args:
            model: The SentenceTransformer model being trained.
            **kwargs: Passed through to the chosen loss constructor.

        Returns:
            CosineSimilarityLoss, CoSENTLoss, or AnglELoss instance.
        """
        if self._loss_variant == "cosine":
            return CosineSimilarityLoss(model, **kwargs)

        elif self._loss_variant == "cosent":
            return CoSENTLoss(model, **kwargs)

        elif self._loss_variant == "angle":
            return AnglELoss(model, **kwargs)

        else:
            # Fallback to default
            return CosineSimilarityLoss(model, **kwargs)

    def get_evaluator(self, eval_dataset: Dataset) -> EmbeddingSimilarityEvaluator:
        """Return EmbeddingSimilarityEvaluator.

        Args:
            eval_dataset: Prepared dataset with sentence1/sentence2/score columns.

        Returns:
            EmbeddingSimilarityEvaluator correlating cosine similarity with
            the gold scores.
        """
        return EmbeddingSimilarityEvaluator(
            sentences1=eval_dataset["sentence1"],
            sentences2=eval_dataset["sentence2"],
            scores=eval_dataset["score"],
            name="similarity-eval",
        )
@@ -0,0 +1,90 @@
1
+ from datasets import Dataset
2
+ from sentence_transformers import SentenceTransformer
3
+ from sentence_transformers.evaluation import TripletEvaluator
4
+ from sentence_transformers.losses import (
5
+ CachedGISTEmbedLoss,
6
+ CachedMultipleNegativesRankingLoss,
7
+ GISTEmbedLoss,
8
+ MultipleNegativesRankingLoss,
9
+ MultipleNegativesSymmetricRankingLoss,
10
+ )
11
+ from sentence_transformers.training_args import BatchSamplers
12
+
13
+ from vespaembed.core.registry import Registry
14
+ from vespaembed.tasks.base import BaseTask
15
+
16
+
17
@Registry.register_task("triplets")
class TripletsTask(BaseTask):
    """Text triplets task for semantic search and retrieval.

    Use this when you have query-document pairs WITH explicit hard negatives.
    The loss uses both explicit negatives and in-batch negatives.

    Data format:
        - anchor: Query/question text
        - positive: Relevant document/answer
        - negative: Hard negative (irrelevant but similar document)

    Loss variants:
        - mnr: MultipleNegativesRankingLoss (default) - explicit + in-batch negatives
        - mnr_symmetric: Bidirectional ranking - use if you need "given answer, find query"
        - gist: GISTEmbedLoss - uses guide model to filter false negatives
        - cached_mnr: Cached version - allows larger effective batch sizes
        - cached_gist: Cached GIST - combines both benefits

    Tip: Hard negatives significantly improve model quality. Use mine_hard_negatives()
    to generate them from pairs data.
    """

    name = "triplets"
    description = "Text triplets for semantic search (anchor, positive, negative)"

    expected_columns = ["anchor", "positive", "negative"]
    column_aliases = {
        "anchor": ["query", "question", "sent1", "sentence1", "text1", "premise"],
        "positive": ["document", "answer", "pos", "sent2", "sentence2", "text2", "entailment"],
        "negative": ["neg", "hard_negative", "sent3", "sentence3", "text3", "contradiction"],
    }

    # Loss variants (same as pairs)
    loss_options = ["mnr", "mnr_symmetric", "gist", "cached_mnr", "cached_gist"]
    default_loss = "mnr"

    # In-batch negative losses degrade if a batch contains duplicate texts.
    batch_sampler = BatchSamplers.NO_DUPLICATES

    def get_loss(
        self, model: SentenceTransformer, **kwargs
    ) -> (
        MultipleNegativesRankingLoss
        | MultipleNegativesSymmetricRankingLoss
        | GISTEmbedLoss
        | CachedMultipleNegativesRankingLoss
        | CachedGISTEmbedLoss
    ):
        """Return the selected loss variant.

        The return annotation covers every variant this method can produce
        (previously it claimed only MultipleNegativesRankingLoss, which was
        wrong for four of the five branches).

        Args:
            model: The SentenceTransformer model being trained.
            **kwargs: Extra loss configuration. Recognized keys:
                guide_model: Guide model for the gist/cached_gist variants;
                    defaults to ``model`` itself when not provided.
                mini_batch_size: Sub-batch size for the cached variants
                    (default 32).

        Returns:
            Configured loss instance for the selected variant.
        """
        # Pop variant-specific options so they never reach losses that do
        # not accept them (e.g. plain MNR rejects a guide= kwarg).
        guide_model = kwargs.pop("guide_model", None)
        mini_batch_size = kwargs.pop("mini_batch_size", 32)

        if self._loss_variant == "mnr":
            return MultipleNegativesRankingLoss(model, **kwargs)

        elif self._loss_variant == "mnr_symmetric":
            return MultipleNegativesSymmetricRankingLoss(model, **kwargs)

        elif self._loss_variant == "gist":
            if guide_model is None:
                # The model guides itself when no separate guide is supplied.
                guide_model = model
            return GISTEmbedLoss(model, guide=guide_model, **kwargs)

        elif self._loss_variant == "cached_mnr":
            return CachedMultipleNegativesRankingLoss(model, mini_batch_size=mini_batch_size, **kwargs)

        elif self._loss_variant == "cached_gist":
            if guide_model is None:
                guide_model = model
            return CachedGISTEmbedLoss(model, guide=guide_model, mini_batch_size=mini_batch_size, **kwargs)

        else:
            # Unknown variant: silently fall back to the default loss.
            return MultipleNegativesRankingLoss(model, **kwargs)

    def get_evaluator(self, eval_dataset: Dataset):
        """Return TripletEvaluator for triplet data.

        Args:
            eval_dataset: Prepared dataset with anchor/positive/negative columns.

        Returns:
            TripletEvaluator checking that each anchor is closer to its
            positive than to its negative.
        """
        return TripletEvaluator(
            anchors=eval_dataset["anchor"],
            positives=eval_dataset["positive"],
            negatives=eval_dataset["negative"],
            name="triplets-eval",
        )
@@ -0,0 +1,102 @@
1
+ import random
2
+
3
+ from datasets import Dataset
4
+ from sentence_transformers import SentenceTransformer
5
+ from sentence_transformers.losses import DenoisingAutoEncoderLoss
6
+ from sentence_transformers.training_args import BatchSamplers
7
+
8
+ from vespaembed.core.registry import Registry
9
+ from vespaembed.tasks.base import BaseTask
10
+
11
+
12
+ def _add_noise(text: str, del_ratio: float = 0.6) -> str:
13
+ """Add noise to text by randomly deleting words.
14
+
15
+ Args:
16
+ text: Input text string
17
+ del_ratio: Probability of keeping each word (default 0.6 = 60% kept)
18
+
19
+ Returns:
20
+ Noisy version of the text with some words randomly deleted
21
+ """
22
+ words = text.split()
23
+ if not words:
24
+ return text
25
+ # Keep words with probability del_ratio
26
+ kept_words = [word for word in words if random.random() < del_ratio]
27
+ if len(kept_words) == 0:
28
+ # Keep at least one random word
29
+ return random.choice(words)
30
+ return " ".join(kept_words)
31
+
32
+
33
@Registry.register_task("tsdae")
class TSDAETask(BaseTask):
    """TSDAE (Transformer-based Sequential Denoising Auto-Encoder) task.

    Unsupervised objective: the model reconstructs original sentences from
    noisy (word-deleted) versions, which is useful for domain adaptation
    when only unlabeled text is available.

    Data format:
        - text/sentence: Raw text sentences (no labels needed)

    Noise is injected automatically by randomly deleting ~40% of the words
    of each input sentence; the original sentence is the reconstruction
    target.

    Reference: https://arxiv.org/abs/2104.06979
    """

    name = "tsdae"
    description = "TSDAE - unsupervised domain adaptation with denoising auto-encoder"

    expected_columns = ["text"]
    column_aliases = {
        "text": ["sentence", "sentences", "content", "input"],
    }

    batch_sampler = BatchSamplers.BATCH_SAMPLER

    def prepare_dataset(self, dataset: Dataset) -> Dataset:
        """Build (noisy, original) training pairs for TSDAE.

        Runs the base-class normalization (column aliases, validation),
        then derives two columns from each row's text: 'anchor' holds the
        word-deleted noisy version and 'positive' holds the untouched
        original. The raw 'text' column is dropped afterwards.

        Args:
            dataset: Input dataset with a 'text' column (or an alias).

        Returns:
            Dataset with 'anchor' (noisy) and 'positive' (original) columns.
        """
        normalized = super().prepare_dataset(dataset)

        paired = normalized.map(
            lambda row: {"anchor": _add_noise(row["text"]), "positive": row["text"]},
            desc="Adding noise to text",
        )

        # 'text' is now redundant; only anchor/positive feed the loss.
        return paired.remove_columns(["text"])

    def get_loss(self, model: SentenceTransformer, **kwargs) -> DenoisingAutoEncoderLoss:
        """Return DenoisingAutoEncoderLoss for unsupervised training.

        Encoder and decoder weights are tied (tie_encoder_decoder=True).
        """
        return DenoisingAutoEncoderLoss(model=model, tie_encoder_decoder=True, **kwargs)

    def get_evaluator(self, eval_dataset: Dataset):
        """Return None - TSDAE has no intrinsic evaluator.

        Evaluate on a downstream task (e.g. STS or retrieval) that matches
        the target use case instead.
        """
        return None