embed-train 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,37 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import TYPE_CHECKING, Any
3
+
4
+ import yaml
5
+
6
+ from embed_train.exceptions import InvalidRunnerClassError
7
+ from embed_train.utils import load_class
8
+ from retrievalbase.mixins import FromConfigMixin
9
+
10
+ if TYPE_CHECKING:
11
+ from embed_train.settings import RunnerSettings
12
+
13
+
14
+ class Runner[TCRunner: "RunnerSettings"](FromConfigMixin[TCRunner], ABC):
15
+ def __init__(self, config: TCRunner):
16
+ self.config = config
17
+
18
+ @abstractmethod
19
+ def run(self) -> None:
20
+ raise NotImplementedError()
21
+
22
+ @classmethod
23
+ def _validate_klass(self, symbol: Any) -> type["Runner[TCRunner]"]:
24
+ if not isinstance(symbol, type):
25
+ raise InvalidRunnerClassError(f"Loaded object is not a class: {symbol!r}")
26
+ # 2️⃣ Ensure it is a Runner subclass
27
+ if not issubclass(symbol, Runner):
28
+ raise InvalidRunnerClassError(f"Class {symbol.__module__}.{symbol.__name__} is not a subclass of Runner")
29
+ return symbol
30
+
31
+ @classmethod
32
+ def get_runner(cls, config_path: str) -> "Runner[TCRunner]":
33
+ with open(config_path) as f:
34
+ yaml_config = yaml.safe_load(f)
35
+ symbol = load_class(yaml_config["module_path"])
36
+ klass: type[Runner[TCRunner]] = cls._validate_klass(symbol)
37
+ return klass.from_settings()
@@ -0,0 +1,3 @@
1
# Package-wide default constants for embed_train.

# Seed used wherever deterministic sampling/shuffling is needed.
RANDOM_SEED = 42
# Default location of the runtime YAML configuration (container path).
CONFIG_PATH = "/config/config.yaml"
# Forwarded to Hugging Face loaders; permits executing code shipped in a
# model repository. NOTE(review): security-sensitive — only appropriate for
# trusted repos.
TRUST_REMOTE_CODE = True
@@ -0,0 +1,31 @@
1
class EmbedTrainError(Exception):
    """Root of the embed_train exception hierarchy; catch this to handle any package error."""


# -------------------------
# Runtime errors
# -------------------------
class EmbedTrainRuntimeError(RuntimeError, EmbedTrainError):
    """Errors raised while executing a run (training, loading, etc.)."""


class MissingContextError(EmbedTrainRuntimeError):
    """A component needed context that the caller did not supply."""


# -------------------------
# Type errors
# -------------------------
class EmbedTrainTypeError(TypeError, EmbedTrainError):
    """An object of the wrong type was provided."""


class InvalidRunnerClassError(EmbedTrainTypeError):
    """A dynamically loaded runner class failed validation against the Runner base."""


# -------------------------
# Value errors
# -------------------------
class EmbedTrainValueError(ValueError, EmbedTrainError):
    """A configuration or runtime value was invalid."""
@@ -0,0 +1,61 @@
1
+ from abc import ABC, abstractmethod
2
+ from pathlib import Path
3
+ from typing import TYPE_CHECKING, Any, Self, cast
4
+
5
+ from torch import nn
6
+ from transformers import PreTrainedModel
7
+
8
+ from embed_train.utils import load_checkpoint_state_dict
9
+ from retrievalbase.mixins import FromConfigMixin
10
+
11
+ if TYPE_CHECKING:
12
+ from embed_train.settings import ModelSettings
13
+
14
+
15
+ class Model[TCModel: "ModelSettings"](FromConfigMixin[TCModel], ABC):
16
+ def __init__(self, config: TCModel):
17
+ self.config = config
18
+
19
+ @abstractmethod
20
+ def to_hf_model(self) -> PreTrainedModel: # inplace
21
+ raise NotImplementedError()
22
+
23
+ def save(
24
+ self,
25
+ repo_dir: str | Path,
26
+ push: bool = False,
27
+ **push_kwargs: Any,
28
+ ) -> None:
29
+ repo_dir_path = Path(repo_dir) if isinstance(repo_dir, str) else repo_dir
30
+ self.to_hf_model().save_pretrained(repo_dir_path)
31
+ if push:
32
+ self.to_hf_model().push_to_hub(
33
+ repo_id=repo_dir_path.name,
34
+ **push_kwargs,
35
+ )
36
+
37
+ def to(self, device: str) -> Self:
38
+ cast(nn.Module, self.to_hf_model()).to(device)
39
+ return self
40
+
41
+ @classmethod
42
+ def from_checkpoint(
43
+ cls,
44
+ config: TCModel,
45
+ checkpoint: str,
46
+ device: str,
47
+ strict: bool = True,
48
+ ) -> "Model[TCModel]":
49
+ model = cls(config)
50
+ hf_model = model.to_hf_model()
51
+ checkpoint_path = Path(checkpoint)
52
+ if checkpoint_path.is_dir():
53
+ state_dict_path = checkpoint_path / "model.safetensors"
54
+ if not state_dict_path.exists():
55
+ raise FileNotFoundError(f"No model.safetensors found in {checkpoint_path}")
56
+ state_dict = load_checkpoint_state_dict(state_dict_path)
57
+ else:
58
+ state_dict = load_checkpoint_state_dict(checkpoint_path)
59
+ hf_model.load_state_dict(state_dict, strict=strict)
60
+ hf_model.to(device) # ty: ignore[invalid-argument-type]
61
+ return model
@@ -0,0 +1,131 @@
1
+ # finetune_embedder/src/runners/push_to_hf.py
2
+ import inspect
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import TYPE_CHECKING, Any
6
+
7
+ from huggingface_hub import HfApi
8
+
9
+ from embed_train import Runner
10
+ from embed_train.models import Model
11
+ from embed_train.utils import load_class
12
+
13
+ _logger = logging.getLogger(__name__)
14
+
15
+ if TYPE_CHECKING:
16
+ from embed_train.settings import PushToHFRunnerSettings
17
+
18
+
19
+ class PushToHFRunner[TCPushToHFRunner: "PushToHFRunnerSettings[Any]"](Runner[TCPushToHFRunner]):
20
+ def __init__(self, config: TCPushToHFRunner) -> None:
21
+ super().__init__(config)
22
+
23
+ def _infer_model_name(self) -> str:
24
+ name: str = self.wrapper_cls.__name__
25
+ if not name.endswith("ModelWrapper"):
26
+ raise ValueError(f"Model wrapper class must end with 'ModelWrapper', got {name}")
27
+ return name.removesuffix("ModelWrapper").lower()
28
+
29
+ def _find_model_source_dir(self) -> Path:
30
+ module_file = Path(inspect.getfile(self.wrapper_cls)).resolve()
31
+ models_dir = module_file.parent
32
+ return models_dir
33
+
34
+ def _collect_model_files(self) -> list[Path]:
35
+ model_name = self._infer_model_name()
36
+ models_dir = self._find_model_source_dir()
37
+
38
+ model_dir = models_dir / model_name
39
+ if not model_dir.exists():
40
+ raise FileNotFoundError(f"Model directory not found: {model_dir}")
41
+
42
+ files: list[Path] = []
43
+
44
+ for pattern in (
45
+ "modeling_*.py",
46
+ "configuration_*.py",
47
+ "vllm_modeling_*.py",
48
+ "vllm_configuration_*.py",
49
+ ):
50
+ files.extend(model_dir.glob(pattern))
51
+
52
+ if not files:
53
+ raise RuntimeError(f"No model source files found in {model_dir}")
54
+
55
+ return sorted(set(files))
56
+
57
+ def _save_model_source_files(self) -> None:
58
+ files = self._collect_model_files()
59
+ for src in files:
60
+ dst = Path(self.config.hf.repo) / src.name
61
+ dst.write_text(src.read_text(encoding="utf-8"), encoding="utf-8")
62
+
63
+ def _load_model(self) -> Model[Any]:
64
+ _logger.info(f"Loading model class from {self.config.model.module_path}")
65
+ model_cls: Model[Any] = load_class(self.config.model.module_path)
66
+ return model_cls.from_checkpoint(
67
+ config=self.config.model,
68
+ checkpoint=self.config.checkpoint_path,
69
+ device=self.config.device,
70
+ )
71
+
72
+ def _load_wrapper(self) -> None:
73
+ _logger.info(f"Loading model wrapper → {self.config.model.module_path}")
74
+ self.wrapper_cls = load_class(self.config.model.module_path)
75
+
76
+ def _load_model_from_checkpoint(self) -> None:
77
+ _logger.info(f"Loading model from checkpoint → {self.config.checkpoint_path}")
78
+ self.model: Model[Any] = self.wrapper_cls.from_checkpoint(
79
+ config=self.config.model,
80
+ checkpoint=self.config.checkpoint_path,
81
+ device=self.config.device,
82
+ )
83
+
84
+ def _save_model_locally(self) -> None:
85
+ _logger.info("Saving model locally → %s", self.config.hf.repo)
86
+ self.model.save(
87
+ repo_dir=self.config.hf.repo,
88
+ push=False, # we control pushing
89
+ create_repo=self.config.create_repo,
90
+ revision=self.config.hf.revision,
91
+ private=self.config.hf.private,
92
+ commit_message=self.config.hf.commit_message,
93
+ )
94
+
95
+ def _ensure_hf_repo_exists(self) -> None:
96
+ if not self.config.create_repo:
97
+ return
98
+ _logger.info(
99
+ f"Ensuring Hugging Face repo exists | repo={self.config.hf.repo} | private={self.config.hf.private}"
100
+ )
101
+ api = HfApi()
102
+ api.create_repo(
103
+ repo_id=self.config.hf.repo,
104
+ repo_type="model",
105
+ private=self.config.hf.private,
106
+ exist_ok=True, # safe for retries / CI
107
+ )
108
+
109
+ def _push_repo_to_hf(self) -> None:
110
+ _logger.info(
111
+ f"Pushing full repository folder to Hugging Face | repo={self.config.hf.repo} | revision={self.config.hf.revision}"
112
+ )
113
+ api = HfApi()
114
+ api.upload_folder(
115
+ folder_path=self.config.hf.repo,
116
+ repo_id=self.config.hf.repo,
117
+ repo_type="model",
118
+ revision=self.config.hf.revision,
119
+ commit_message=self.config.hf.commit_message,
120
+ )
121
+ _logger.info("Repository successfully pushed to Hugging Face")
122
+
123
+ def run(self) -> None:
124
+ self._load_wrapper()
125
+ self._load_model_from_checkpoint()
126
+ self._save_model_locally()
127
+ self._save_model_source_files()
128
+ if self.config.push:
129
+ self._ensure_hf_repo_exists()
130
+ self._push_repo_to_hf()
131
+ _logger.info("PushToHFRunner completed successfully")
embed_train/py.typed ADDED
File without changes
@@ -0,0 +1,202 @@
1
+ from pathlib import Path
2
+ from typing import TYPE_CHECKING, Any, Literal, Self
3
+
4
+ from pydantic import Field, model_validator
5
+ from pydantic_settings import BaseSettings
6
+
7
+ from embed_train.constants import TRUST_REMOTE_CODE
8
+ from embed_train.exceptions import EmbedTrainValueError
9
+ from retrievalbase.dataset.settings import HuggingFaceDatasetAdaptaterSettings
10
+ from retrievalbase.settings import FromConfigMixinSettings
11
+
12
+ if TYPE_CHECKING:
13
+ from retrievalbase.connector.settings import DatasetConnectorSettings
14
+ from retrievalbase.evaluation.settings import ProcessorSettings
15
+
16
+
17
class TorchDatasetSettings[TCDatasetConnector: "DatasetConnectorSettings"](FromConfigMixinSettings):
    """Settings for a torch dataset built on top of a dataset connector."""

    dataset_connector: TCDatasetConnector


class TokenizerSettings(BaseSettings):
    """Hugging Face tokenizer configuration."""

    name: str  # identifier passed to AutoTokenizer.from_pretrained
    padding: str | bool  # forwarded to the tokenizer call
    truncation: bool
    max_length: int
    # NOTE(review): security-sensitive default — True allows executing code
    # shipped in the model repository.
    trust_remote_code: bool = TRUST_REMOTE_CODE


class SamplerSettings(BaseSettings):
    """Base settings for samplers loaded dynamically via module_path."""

    module_path: str  # dotted path of the sampler class to load
    seed: int | None  # RNG seed; None means unseeded


class DistanceSettings(FromConfigMixinSettings):
    """Base settings for distance implementations."""


class StepRangeDistanceSettings(DistanceSettings):
    """Settings for the step-range interval distance (no extra fields)."""


class PositiveSamplerSettings[TCDistance: "DistanceSettings"](SamplerSettings):
    """Settings for sampling positives within a step-distance window."""

    max_step_distance: int
    distance: TCDistance  # distance implementation settings
    k: int  # number of positives to sample


class CollateFnSettings[TCProcessor: "ProcessorSettings"](BaseSettings):
    """Base settings for collate functions."""

    module_path: str  # dotted path of the collate-fn class to load
    tokenizer: TokenizerSettings
    processor: TCProcessor


class InBatchPositiveCollateFnSettings[TCProcessor: "ProcessorSettings"](CollateFnSettings[TCProcessor]):
    """Settings for the in-batch single-positive collate function."""


class MultiPositiveInBatchCollateFnSettings[TCProcessor: "ProcessorSettings"](CollateFnSettings[TCProcessor]):
    """Settings for the in-batch multi-positive collate function."""

    n_pos: int  # positives sampled per query


class ModelSettings(FromConfigMixinSettings):
    """Base settings for model wrappers."""


class LossSettings(FromConfigMixinSettings):
    """Base settings for loss functions."""


class TrainerSettings(FromConfigMixinSettings):
    """Base settings shared by all trainers."""

    data_dir: Path  # output/working directory for training artifacts


class PyTorchTrainerSettings[
    TCModel: "ModelSettings",
    TCLoss: "LossSettings",
    TCTorchDataset: "TorchDatasetSettings[Any]",
    TCCollateFn: "CollateFnSettings[Any]",
](TrainerSettings):
    """Settings for the plain-PyTorch training loop."""

    model: TCModel
    torch_dataset: TCTorchDataset
    collate_fn: TCCollateFn
    train_frac: float  # fraction of the dataset used for training
    num_epochs: int
    batch_size: int
    shuffle: bool
    lr: float
    device: str
    save_every: int  # checkpoint frequency — units (steps vs epochs) defined by the trainer
    drop_last: bool  # drop the final incomplete batch
    loss: TCLoss
    resume_from: str | None = None  # checkpoint path to resume from


class ContrastiveLossSettings(LossSettings):
    """Settings for temperature-scaled contrastive losses."""

    temperature: float


class QueryMultiPositiveDatasetSettings[TCDatasetConnector: "DatasetConnectorSettings"](
    TorchDatasetSettings[TCDatasetConnector]
):
    """Settings for the query/multi-positive torch dataset."""


class InBatchNegativeContrastiveLossSettings(ContrastiveLossSettings):
    """Settings for the in-batch-negative contrastive loss."""


class MultiPositiveContrastiveLossSettings(ContrastiveLossSettings):
    """Settings for the multi-positive contrastive loss."""

    n_pos: int  # positives per query; should match the collate fn's n_pos


class RunnerSettings(FromConfigMixinSettings):
    """Base settings for runners."""


class HFSettings(BaseSettings):
    """Hugging Face Hub target configuration."""

    repo: str  # also used as the local output folder by PushToHFRunner
    revision: str | None
    private: bool
    commit_message: str | None


class PushToHFRunnerSettings[TCModel: "ModelSettings"](RunnerSettings):
    """Settings for pushing a trained model to the Hugging Face Hub."""

    # NOTE(review): every field uses Field(init=False); presumably these are
    # populated from config sources by pydantic-settings rather than the
    # constructor — confirm the intended loading mechanism.
    checkpoint_path: str = Field(init=False)
    device: str = Field(init=False)
    push: bool = Field(init=False)
    create_repo: bool = Field(init=False)
    hf: HFSettings = Field(init=False)
    model: TCModel = Field(init=False)


class TrainRunnerSettings[TCTrainer: "TrainerSettings"](RunnerSettings):
    """Settings for the training runner."""

    trainer: TCTrainer


class HardNegativesSettings(BaseSettings):
    """Settings for hard-negative mining."""

    range_min: int
    range_max: int
    max_score: float
    relative_margin: float
    num_negatives: int
    sampling_strategy: Literal["random", "top"]
    batch_size: int
    use_faiss: bool  # use FAISS for similarity search


class EvalutationSettings(BaseSettings):
    # NOTE(review): class name is a typo for "EvaluationSettings"; renaming
    # would break existing imports/configs, so it is kept as-is.
    """Retrieval evaluation configuration."""

    query_column: str
    document_column: str
    precision_recall_at_k: list[int]
    mrr_at_k: list[int]
    ndcg_at_k: list[int]
    batch_size: int


class SentenceTransformerLoss(LossSettings):
    """Settings for a sentence-transformers loss; kwargs are passed through."""

    kwargs: dict[str, Any]


class SentenceTransformersTrainerSettings[
    TCDatasetConnector: "DatasetConnectorSettings",
    TCModel: "ModelSettings",
](TrainerSettings):
    """Settings for training via the sentence-transformers trainer."""

    hf_dataset: HuggingFaceDatasetAdaptaterSettings[TCDatasetConnector]
    model: TCModel
    tokenizer: TokenizerSettings
    loss: SentenceTransformerLoss
    pooling: Literal["cls", "mean_tokens", "max_tokens"]
    batch_size: int
    num_epochs: int
    lr: float
    warmup_ratio: float
    eval_steps: int
    save_steps: int
    logging_steps: int
    fp16: bool
    train_frac: float
    resume_from: str | None
    evaluation: EvalutationSettings
    trust_remote_code: bool = TRUST_REMOTE_CODE

    # -------------------------
    # CROSS-FIELD VALIDATION
    # -------------------------
    @model_validator(mode="after")
    def _validate_steps_consistency(self) -> Self:
        """Ensure step intervals are ordered and aligned.

        Raises:
            EmbedTrainValueError: if logging/eval/save ordering is violated
                or save_steps is not a multiple of eval_steps.
        """
        # 1️⃣ ordering
        if not (self.logging_steps <= self.eval_steps <= self.save_steps):
            raise EmbedTrainValueError(
                "Expected logging_steps ≤ eval_steps ≤ save_steps, got "
                f"{self.logging_steps} ≤ {self.eval_steps} ≤ {self.save_steps}"
            )

        # 2️⃣ alignment (CRITICAL)
        # NOTE(review): raises ZeroDivisionError if eval_steps == 0 — confirm
        # upstream validation guarantees eval_steps > 0.
        if self.save_steps % self.eval_steps != 0:
            raise EmbedTrainValueError(
                "save_steps must be a multiple of eval_steps "
                f"(got save_steps={self.save_steps}, eval_steps={self.eval_steps})"
            )

        return self
@@ -0,0 +1,17 @@
1
+ from typing import TYPE_CHECKING, Any
2
+
3
+ from embed_train import Runner
4
+ from embed_train.train.trainers import Trainer
5
+ from embed_train.utils import load_class
6
+
7
+ if TYPE_CHECKING:
8
+ from embed_train.settings import TrainRunnerSettings
9
+
10
+
11
+ class TrainRunner[TCTrainRunner: "TrainRunnerSettings[Any]"](Runner[TCTrainRunner]):
12
+ def __init__(self, config: TCTrainRunner):
13
+ super().__init__(config)
14
+ self.trainer: Trainer[Any] = load_class(self.config.trainer.module_path).from_config(self.config.trainer)
15
+
16
+ def run(self) -> None:
17
+ self.trainer.train()
@@ -0,0 +1,109 @@
1
+ import logging
2
+ from abc import ABC, abstractmethod
3
+ from typing import TYPE_CHECKING, Any, cast
4
+
5
+ import torch
6
+ from datasets import Dataset as HFDataset # type: ignore[import-untyped]
7
+ from torch.utils.data import Dataset as TorchDatasetBase
8
+ from transformers import AutoTokenizer, PreTrainedTokenizer
9
+
10
+ from embed_train.utils import load_class
11
+ from retrievalbase.dataset import TextDataset
12
+ from retrievalbase.enums import EmbeddingPurposeEnum
13
+ from retrievalbase.evaluation import Processor
14
+ from retrievalbase.mixins import FromConfigMixin
15
+
16
+ _logger = logging.getLogger(__name__)
17
+
18
+ if TYPE_CHECKING:
19
+ from embed_train.settings import CollateFnSettings, TorchDatasetSettings
20
+
21
+
22
class CollateFn[TCCollateFn: "CollateFnSettings[Any]", T: dict[str, Any]](ABC):
    """Base collate function: turns raw batch items into tokenized query/candidate pairs.

    Subclasses implement `_process_batch` to extract parallel lists of query
    and candidate strings from a batch of dataset items.
    """

    def __init__(self, config: TCCollateFn, context: dict[str, Any] | None) -> None:
        # context is stored but not read here; presumably for subclasses — confirm.
        self.config = config
        self.context = context
        self.tokenizer: PreTrainedTokenizer = self._load_tokenizer()
        self.processor: Processor[Any] = self._load_processor()

    @abstractmethod
    def _process_batch(self, batch: list[T]) -> tuple[list[str], list[str]]:
        """Return (queries, candidates) string lists extracted from `batch`."""
        raise NotImplementedError()

    def _load_tokenizer(self) -> PreTrainedTokenizer:
        """Instantiate the tokenizer named in the settings."""
        tokenizer = AutoTokenizer.from_pretrained(  # nosec CWE-494
            self.config.tokenizer.name,
            trust_remote_code=self.config.tokenizer.trust_remote_code,
        )
        return cast(PreTrainedTokenizer, tokenizer)

    def _load_processor(self) -> Processor[Any]:
        """Dynamically load and build the text processor from its settings."""
        processor: Processor[Any] = load_class(self.config.processor.module_path).from_config(self.config.processor)
        return processor

    def __call__(self, batch: list[T]) -> tuple[torch.Tensor, torch.Tensor]:
        # NOTE(review): the tokenizer calls below return BatchEncoding objects,
        # not bare tensors — the tuple[Tensor, Tensor] annotations here and in
        # the helpers are looser than reality; confirm against callers.
        queries, candidates = self._process_batch(batch)
        return self._post_process(queries, candidates)

    def _post_process(self, queries: list[str], candidates: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
        """Apply the purpose-specific processor to each string, then tokenize."""
        queries = [
            self.processor(
                text=query,
                purpose=cast(EmbeddingPurposeEnum, EmbeddingPurposeEnum.QUERY),
            )
            for query in queries
        ]
        candidates = [
            self.processor(
                text=candidate,
                purpose=cast(EmbeddingPurposeEnum, EmbeddingPurposeEnum.DOCUMENT),
            )
            for candidate in candidates
        ]
        return self._tokenizer_qc(queries, candidates)

    def _tokenizer_qc(self, queries: list[str], candidates: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
        """Tokenize query and candidate lists with shared tokenizer settings."""
        q_tok = self.tokenizer(
            queries,
            padding=self.config.tokenizer.padding,
            truncation=self.config.tokenizer.truncation,
            max_length=self.config.tokenizer.max_length,
            return_tensors="pt",
        )
        c_tok = self.tokenizer(
            candidates,
            padding=self.config.tokenizer.padding,
            truncation=self.config.tokenizer.truncation,
            max_length=self.config.tokenizer.max_length,
            return_tensors="pt",
        )
        return q_tok, c_tok
81
+
82
+
83
class TorchDataset[TCTorchDataset: "TorchDatasetSettings[Any]", T: dict[str, Any]](
    FromConfigMixin[TCTorchDataset],
    TorchDatasetBase[T],
    ABC,
):
    """Abstract torch dataset backed by a dynamically loaded dataset connector."""

    def __init__(self, config: TCTorchDataset) -> None:
        # NOTE(review): super().__init__() is not called — confirm neither
        # base class requires initialization.
        self.config = config
        self.dataset = self._load_dataset()  # eagerly loaded at construction

    @abstractmethod
    def __len__(self) -> int:
        """Number of items; defined by subclasses."""
        raise NotImplementedError()

    def _load_dataset(self) -> TextDataset[Any]:
        """Build the connector from settings and load its text dataset."""
        dataset: TextDataset[Any] = (
            load_class(self.config.dataset_connector.module_path).from_config(self.config.dataset_connector).load_text()
        )
        _logger.info(f"Instantiating dataset | class={dataset.__class__.__name__} |")
        _logger.info(f"Dataset loaded | type={type(dataset).__name__} | size={len(dataset)}")
        return dataset

    def to_hf_dataset(self) -> HFDataset:
        """
        Convert *this* TorchDataset (via __getitem__) to a HuggingFace Dataset.
        """
        # Materializes every row eagerly; memory scales with dataset size.
        rows: list[dict[str, Any]] = [self[i] for i in range(len(self))]
        return HFDataset.from_list(rows)
@@ -0,0 +1,52 @@
1
+ import random
2
+ from typing import Any
3
+
4
+ from embed_train.settings import (
5
+ InBatchPositiveCollateFnSettings,
6
+ MultiPositiveInBatchCollateFnSettings,
7
+ )
8
+ from embed_train.train.dataset import CollateFn
9
+
10
+
11
class InBatchPositiveCollateFn(CollateFn[InBatchPositiveCollateFnSettings, dict[str, Any]]):
    """Collate function that picks one random positive per query.

    Fixed: the item type parameter is ``dict[str, Any]`` (was
    ``dict[str, str]``) — ``item["positives"]`` is a sequence that
    ``random.choice`` samples from, not a plain string.
    """

    def __init__(
        self,
        config: InBatchPositiveCollateFnSettings,
        context: dict[str, Any] | None,
    ) -> None:
        super().__init__(config, context)

    def _process_batch(
        self,
        batch: list[dict[str, Any]],
    ) -> tuple[list[str], list[str]]:
        """Return parallel (queries, sampled positives) lists for the batch."""
        queries: list[str] = []
        positives: list[str] = []

        for item in batch:
            queries.append(item["query"])
            # One uniformly sampled positive per query (non-crypto RNG is fine).
            positives.append(random.choice(item["positives"]))  # nosec B311
        return queries, positives
30
+
31
+
32
class MultiPositiveInBatchCollateFn(CollateFn[MultiPositiveInBatchCollateFnSettings, dict[str, Any]]):
    """Collate function that samples ``n_pos`` positives per query.

    Fixed: the item type parameter is ``dict[str, Any]`` (was
    ``dict[str, str]``) — ``item["positives"]`` is a sequence that
    ``random.choice`` samples from, not a plain string.
    """

    def __init__(
        self,
        config: MultiPositiveInBatchCollateFnSettings,
        context: dict[str, Any] | None,
    ) -> None:
        super().__init__(config, context)

    def _process_batch(
        self,
        batch: list[dict[str, Any]],
    ) -> tuple[list[str], list[str]]:
        """Return (queries, passages) where passages holds n_pos samples per query.

        Note: positives are sampled with replacement, so duplicates can occur.
        """
        queries: list[str] = []
        passages: list[str] = []
        for item in batch:
            query = item["query"]
            positives = item["positives"]
            sampled_positives = [random.choice(positives) for _ in range(self.config.n_pos)]  # nosec B311
            queries.append(query)
            passages.extend(sampled_positives)
        return queries, passages
@@ -0,0 +1,46 @@
1
+ import random
2
+ from abc import ABC, abstractmethod
3
+ from typing import TYPE_CHECKING, Any
4
+
5
+ from embed_train.settings import StepRangeDistanceSettings
6
+ from retrievalbase.mixins import FromConfigMixin
7
+
8
+ if TYPE_CHECKING:
9
+ from embed_train.settings import DistanceSettings, SamplerSettings
10
+
11
+
12
+ class Sampler[TCSampler: "SamplerSettings", T: dict[str, Any]](ABC):
13
+ RANDOM_SEED = 42
14
+
15
+ def __init__(self, config: TCSampler, context: dict[str, Any] | None) -> None:
16
+ self.config = config
17
+ self.context = context
18
+ self.rng = random.Random(self.config.seed) # nosec B311
19
+
20
+ @abstractmethod
21
+ def sample(self, item: T, **kwargs: Any) -> list[str]:
22
+ pass
23
+
24
+ def _init_components(self) -> None:
25
+ return None
26
+
27
+
28
+ class Distance[TCDistance: "DistanceSettings"](FromConfigMixin[TCDistance], ABC):
29
+ def __init__(self, config: TCDistance) -> None:
30
+ self.config = config
31
+
32
+ @abstractmethod
33
+ def __call__(self, a: tuple[int, int], b: tuple[int, int]) -> int:
34
+ raise NotImplementedError
35
+
36
+
37
class StepRangeDistance(Distance[StepRangeDistanceSettings]):
    """Gap distance between two step ranges: 0 when they overlap or touch,
    otherwise the number of steps separating them."""

    def __init__(self, config: StepRangeDistanceSettings) -> None:
        super().__init__(config)

    def __call__(self, a: tuple[int, int], b: tuple[int, int]) -> int:
        # Exactly one of (b.start - a.end) and (a.start - b.end) can be
        # positive; when the ranges overlap both are <= 0 and the gap is 0.
        return max(b[0] - a[1], a[0] - b[1], 0)