embed-train 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- embed_train/__init__.py +37 -0
- embed_train/constants.py +3 -0
- embed_train/exceptions.py +31 -0
- embed_train/models/__init__.py +61 -0
- embed_train/push_to_hf/__init__.py +131 -0
- embed_train/py.typed +0 -0
- embed_train/settings.py +202 -0
- embed_train/train/__init__.py +17 -0
- embed_train/train/dataset/__init__.py +109 -0
- embed_train/train/dataset/collate.py +52 -0
- embed_train/train/dataset/sampling/__init__.py +46 -0
- embed_train/train/dataset/sampling/samplers.py +36 -0
- embed_train/train/dataset/torch_datasets.py +71 -0
- embed_train/train/trainers/__init__.py +22 -0
- embed_train/train/trainers/hf/__init__.py +158 -0
- embed_train/train/trainers/torch/__init__.py +226 -0
- embed_train/train/trainers/torch/loss.py +99 -0
- embed_train/utils.py +80 -0
- embed_train-1.0.0.dist-info/METADATA +283 -0
- embed_train-1.0.0.dist-info/RECORD +21 -0
- embed_train-1.0.0.dist-info/WHEEL +4 -0
embed_train/__init__.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import TYPE_CHECKING, Any
|
|
3
|
+
|
|
4
|
+
import yaml
|
|
5
|
+
|
|
6
|
+
from embed_train.exceptions import InvalidRunnerClassError
|
|
7
|
+
from embed_train.utils import load_class
|
|
8
|
+
from retrievalbase.mixins import FromConfigMixin
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from embed_train.settings import RunnerSettings
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class Runner[TCRunner: "RunnerSettings"](FromConfigMixin[TCRunner], ABC):
|
|
15
|
+
def __init__(self, config: TCRunner):
|
|
16
|
+
self.config = config
|
|
17
|
+
|
|
18
|
+
@abstractmethod
|
|
19
|
+
def run(self) -> None:
|
|
20
|
+
raise NotImplementedError()
|
|
21
|
+
|
|
22
|
+
@classmethod
|
|
23
|
+
def _validate_klass(self, symbol: Any) -> type["Runner[TCRunner]"]:
|
|
24
|
+
if not isinstance(symbol, type):
|
|
25
|
+
raise InvalidRunnerClassError(f"Loaded object is not a class: {symbol!r}")
|
|
26
|
+
# 2️⃣ Ensure it is a Runner subclass
|
|
27
|
+
if not issubclass(symbol, Runner):
|
|
28
|
+
raise InvalidRunnerClassError(f"Class {symbol.__module__}.{symbol.__name__} is not a subclass of Runner")
|
|
29
|
+
return symbol
|
|
30
|
+
|
|
31
|
+
@classmethod
|
|
32
|
+
def get_runner(cls, config_path: str) -> "Runner[TCRunner]":
|
|
33
|
+
with open(config_path) as f:
|
|
34
|
+
yaml_config = yaml.safe_load(f)
|
|
35
|
+
symbol = load_class(yaml_config["module_path"])
|
|
36
|
+
klass: type[Runner[TCRunner]] = cls._validate_klass(symbol)
|
|
37
|
+
return klass.from_settings()
|
embed_train/exceptions.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
# Exception hierarchy for embed_train.
#
# Every error raised by the package derives from EmbedTrainError, and each
# concrete error also derives from the matching builtin (RuntimeError,
# TypeError, ValueError) so callers may catch either side of the hierarchy.
class EmbedTrainError(Exception):
    """Root of the embed_train exception hierarchy."""


class EmbedTrainRuntimeError(RuntimeError, EmbedTrainError):
    """Failure while executing package logic (training, loading, etc.)."""


class MissingContextError(EmbedTrainRuntimeError):
    """A required piece of runtime context was not supplied."""


class EmbedTrainTypeError(TypeError, EmbedTrainError):
    """A value of the wrong type was provided."""


class InvalidRunnerClassError(EmbedTrainTypeError):
    """A dynamically loaded runner class is invalid or mis-derived."""


class EmbedTrainValueError(ValueError, EmbedTrainError):
    """An invalid value was supplied via configuration or at runtime."""
|
|
embed_train/models/__init__.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Self, cast
|
|
4
|
+
|
|
5
|
+
from torch import nn
|
|
6
|
+
from transformers import PreTrainedModel
|
|
7
|
+
|
|
8
|
+
from embed_train.utils import load_checkpoint_state_dict
|
|
9
|
+
from retrievalbase.mixins import FromConfigMixin
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from embed_train.settings import ModelSettings
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Model[TCModel: "ModelSettings"](FromConfigMixin[TCModel], ABC):
|
|
16
|
+
def __init__(self, config: TCModel):
|
|
17
|
+
self.config = config
|
|
18
|
+
|
|
19
|
+
@abstractmethod
|
|
20
|
+
def to_hf_model(self) -> PreTrainedModel: # inplace
|
|
21
|
+
raise NotImplementedError()
|
|
22
|
+
|
|
23
|
+
def save(
|
|
24
|
+
self,
|
|
25
|
+
repo_dir: str | Path,
|
|
26
|
+
push: bool = False,
|
|
27
|
+
**push_kwargs: Any,
|
|
28
|
+
) -> None:
|
|
29
|
+
repo_dir_path = Path(repo_dir) if isinstance(repo_dir, str) else repo_dir
|
|
30
|
+
self.to_hf_model().save_pretrained(repo_dir_path)
|
|
31
|
+
if push:
|
|
32
|
+
self.to_hf_model().push_to_hub(
|
|
33
|
+
repo_id=repo_dir_path.name,
|
|
34
|
+
**push_kwargs,
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
def to(self, device: str) -> Self:
|
|
38
|
+
cast(nn.Module, self.to_hf_model()).to(device)
|
|
39
|
+
return self
|
|
40
|
+
|
|
41
|
+
@classmethod
|
|
42
|
+
def from_checkpoint(
|
|
43
|
+
cls,
|
|
44
|
+
config: TCModel,
|
|
45
|
+
checkpoint: str,
|
|
46
|
+
device: str,
|
|
47
|
+
strict: bool = True,
|
|
48
|
+
) -> "Model[TCModel]":
|
|
49
|
+
model = cls(config)
|
|
50
|
+
hf_model = model.to_hf_model()
|
|
51
|
+
checkpoint_path = Path(checkpoint)
|
|
52
|
+
if checkpoint_path.is_dir():
|
|
53
|
+
state_dict_path = checkpoint_path / "model.safetensors"
|
|
54
|
+
if not state_dict_path.exists():
|
|
55
|
+
raise FileNotFoundError(f"No model.safetensors found in {checkpoint_path}")
|
|
56
|
+
state_dict = load_checkpoint_state_dict(state_dict_path)
|
|
57
|
+
else:
|
|
58
|
+
state_dict = load_checkpoint_state_dict(checkpoint_path)
|
|
59
|
+
hf_model.load_state_dict(state_dict, strict=strict)
|
|
60
|
+
hf_model.to(device) # ty: ignore[invalid-argument-type]
|
|
61
|
+
return model
|
|
embed_train/push_to_hf/__init__.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
1
|
+
# finetune_embedder/src/runners/push_to_hf.py
|
|
2
|
+
import inspect
|
|
3
|
+
import logging
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
|
|
7
|
+
from huggingface_hub import HfApi
|
|
8
|
+
|
|
9
|
+
from embed_train import Runner
|
|
10
|
+
from embed_train.models import Model
|
|
11
|
+
from embed_train.utils import load_class
|
|
12
|
+
|
|
13
|
+
_logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
if TYPE_CHECKING:
|
|
16
|
+
from embed_train.settings import PushToHFRunnerSettings
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class PushToHFRunner[TCPushToHFRunner: "PushToHFRunnerSettings[Any]"](Runner[TCPushToHFRunner]):
|
|
20
|
+
def __init__(self, config: TCPushToHFRunner) -> None:
|
|
21
|
+
super().__init__(config)
|
|
22
|
+
|
|
23
|
+
def _infer_model_name(self) -> str:
|
|
24
|
+
name: str = self.wrapper_cls.__name__
|
|
25
|
+
if not name.endswith("ModelWrapper"):
|
|
26
|
+
raise ValueError(f"Model wrapper class must end with 'ModelWrapper', got {name}")
|
|
27
|
+
return name.removesuffix("ModelWrapper").lower()
|
|
28
|
+
|
|
29
|
+
def _find_model_source_dir(self) -> Path:
|
|
30
|
+
module_file = Path(inspect.getfile(self.wrapper_cls)).resolve()
|
|
31
|
+
models_dir = module_file.parent
|
|
32
|
+
return models_dir
|
|
33
|
+
|
|
34
|
+
def _collect_model_files(self) -> list[Path]:
|
|
35
|
+
model_name = self._infer_model_name()
|
|
36
|
+
models_dir = self._find_model_source_dir()
|
|
37
|
+
|
|
38
|
+
model_dir = models_dir / model_name
|
|
39
|
+
if not model_dir.exists():
|
|
40
|
+
raise FileNotFoundError(f"Model directory not found: {model_dir}")
|
|
41
|
+
|
|
42
|
+
files: list[Path] = []
|
|
43
|
+
|
|
44
|
+
for pattern in (
|
|
45
|
+
"modeling_*.py",
|
|
46
|
+
"configuration_*.py",
|
|
47
|
+
"vllm_modeling_*.py",
|
|
48
|
+
"vllm_configuration_*.py",
|
|
49
|
+
):
|
|
50
|
+
files.extend(model_dir.glob(pattern))
|
|
51
|
+
|
|
52
|
+
if not files:
|
|
53
|
+
raise RuntimeError(f"No model source files found in {model_dir}")
|
|
54
|
+
|
|
55
|
+
return sorted(set(files))
|
|
56
|
+
|
|
57
|
+
def _save_model_source_files(self) -> None:
|
|
58
|
+
files = self._collect_model_files()
|
|
59
|
+
for src in files:
|
|
60
|
+
dst = Path(self.config.hf.repo) / src.name
|
|
61
|
+
dst.write_text(src.read_text(encoding="utf-8"), encoding="utf-8")
|
|
62
|
+
|
|
63
|
+
def _load_model(self) -> Model[Any]:
|
|
64
|
+
_logger.info(f"Loading model class from {self.config.model.module_path}")
|
|
65
|
+
model_cls: Model[Any] = load_class(self.config.model.module_path)
|
|
66
|
+
return model_cls.from_checkpoint(
|
|
67
|
+
config=self.config.model,
|
|
68
|
+
checkpoint=self.config.checkpoint_path,
|
|
69
|
+
device=self.config.device,
|
|
70
|
+
)
|
|
71
|
+
|
|
72
|
+
def _load_wrapper(self) -> None:
|
|
73
|
+
_logger.info(f"Loading model wrapper → {self.config.model.module_path}")
|
|
74
|
+
self.wrapper_cls = load_class(self.config.model.module_path)
|
|
75
|
+
|
|
76
|
+
def _load_model_from_checkpoint(self) -> None:
|
|
77
|
+
_logger.info(f"Loading model from checkpoint → {self.config.checkpoint_path}")
|
|
78
|
+
self.model: Model[Any] = self.wrapper_cls.from_checkpoint(
|
|
79
|
+
config=self.config.model,
|
|
80
|
+
checkpoint=self.config.checkpoint_path,
|
|
81
|
+
device=self.config.device,
|
|
82
|
+
)
|
|
83
|
+
|
|
84
|
+
def _save_model_locally(self) -> None:
|
|
85
|
+
_logger.info("Saving model locally → %s", self.config.hf.repo)
|
|
86
|
+
self.model.save(
|
|
87
|
+
repo_dir=self.config.hf.repo,
|
|
88
|
+
push=False, # we control pushing
|
|
89
|
+
create_repo=self.config.create_repo,
|
|
90
|
+
revision=self.config.hf.revision,
|
|
91
|
+
private=self.config.hf.private,
|
|
92
|
+
commit_message=self.config.hf.commit_message,
|
|
93
|
+
)
|
|
94
|
+
|
|
95
|
+
def _ensure_hf_repo_exists(self) -> None:
|
|
96
|
+
if not self.config.create_repo:
|
|
97
|
+
return
|
|
98
|
+
_logger.info(
|
|
99
|
+
f"Ensuring Hugging Face repo exists | repo={self.config.hf.repo} | private={self.config.hf.private}"
|
|
100
|
+
)
|
|
101
|
+
api = HfApi()
|
|
102
|
+
api.create_repo(
|
|
103
|
+
repo_id=self.config.hf.repo,
|
|
104
|
+
repo_type="model",
|
|
105
|
+
private=self.config.hf.private,
|
|
106
|
+
exist_ok=True, # safe for retries / CI
|
|
107
|
+
)
|
|
108
|
+
|
|
109
|
+
def _push_repo_to_hf(self) -> None:
|
|
110
|
+
_logger.info(
|
|
111
|
+
f"Pushing full repository folder to Hugging Face | repo={self.config.hf.repo} | revision={self.config.hf.revision}"
|
|
112
|
+
)
|
|
113
|
+
api = HfApi()
|
|
114
|
+
api.upload_folder(
|
|
115
|
+
folder_path=self.config.hf.repo,
|
|
116
|
+
repo_id=self.config.hf.repo,
|
|
117
|
+
repo_type="model",
|
|
118
|
+
revision=self.config.hf.revision,
|
|
119
|
+
commit_message=self.config.hf.commit_message,
|
|
120
|
+
)
|
|
121
|
+
_logger.info("Repository successfully pushed to Hugging Face")
|
|
122
|
+
|
|
123
|
+
def run(self) -> None:
|
|
124
|
+
self._load_wrapper()
|
|
125
|
+
self._load_model_from_checkpoint()
|
|
126
|
+
self._save_model_locally()
|
|
127
|
+
self._save_model_source_files()
|
|
128
|
+
if self.config.push:
|
|
129
|
+
self._ensure_hf_repo_exists()
|
|
130
|
+
self._push_repo_to_hf()
|
|
131
|
+
_logger.info("PushToHFRunner completed successfully")
|
embed_train/py.typed
ADDED
|
File without changes
|
embed_train/settings.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from typing import TYPE_CHECKING, Any, Literal, Self
|
|
3
|
+
|
|
4
|
+
from pydantic import Field, model_validator
|
|
5
|
+
from pydantic_settings import BaseSettings
|
|
6
|
+
|
|
7
|
+
from embed_train.constants import TRUST_REMOTE_CODE
|
|
8
|
+
from embed_train.exceptions import EmbedTrainValueError
|
|
9
|
+
from retrievalbase.dataset.settings import HuggingFaceDatasetAdaptaterSettings
|
|
10
|
+
from retrievalbase.settings import FromConfigMixinSettings
|
|
11
|
+
|
|
12
|
+
if TYPE_CHECKING:
|
|
13
|
+
from retrievalbase.connector.settings import DatasetConnectorSettings
|
|
14
|
+
from retrievalbase.evaluation.settings import ProcessorSettings
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TorchDatasetSettings[TCDatasetConnector: "DatasetConnectorSettings"](FromConfigMixinSettings):
    """Settings for a torch-style dataset backed by a dataset connector."""

    dataset_connector: TCDatasetConnector


class TokenizerSettings(BaseSettings):
    """Configuration forwarded to a HuggingFace tokenizer."""

    # Model/tokenizer identifier (e.g. a Hub repo name).
    name: str
    # Forwarded verbatim to the tokenizer call.
    padding: str | bool
    truncation: bool
    max_length: int
    # Package-wide default for opting into remote code execution.
    trust_remote_code: bool = TRUST_REMOTE_CODE


class SamplerSettings(BaseSettings):
    """Base sampler settings; ``module_path`` names the sampler class."""

    module_path: str
    # None means an unseeded (non-reproducible) RNG.
    seed: int | None


class DistanceSettings(FromConfigMixinSettings):
    """Base settings for distance functions (no fields of its own)."""

    pass


class StepRangeDistanceSettings(DistanceSettings):
    """Settings for StepRangeDistance (no extra fields)."""

    pass


class PositiveSamplerSettings[TCDistance: "DistanceSettings"](SamplerSettings):
    """Sampler settings for positive-example selection."""

    max_step_distance: int
    distance: TCDistance
    # Number of positives to draw — presumably per query; confirm in sampler.
    k: int


class CollateFnSettings[TCProcessor: "ProcessorSettings"](BaseSettings):
    """Base settings for collate functions; ``module_path`` names the class."""

    module_path: str
    tokenizer: TokenizerSettings
    processor: TCProcessor
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class InBatchPositiveCollateFnSettings[TCProcessor: "ProcessorSettings"](CollateFnSettings[TCProcessor]):
    """Settings for InBatchPositiveCollateFn (no extra fields)."""

    pass


class MultiPositiveInBatchCollateFnSettings[TCProcessor: "ProcessorSettings"](CollateFnSettings[TCProcessor]):
    """Settings for MultiPositiveInBatchCollateFn."""

    # Number of positives sampled per query by the collate fn.
    n_pos: int


class ModelSettings(FromConfigMixinSettings):
    """Base settings for model wrappers."""

    pass


class LossSettings(FromConfigMixinSettings):
    """Base settings for loss functions."""

    pass


class TrainerSettings(FromConfigMixinSettings):
    """Base trainer settings."""

    # Directory where training data/artifacts live.
    data_dir: Path


class PyTorchTrainerSettings[
    TCModel: "ModelSettings",
    TCLoss: "LossSettings",
    TCTorchDataset: "TorchDatasetSettings[Any]",
    TCCollateFn: "CollateFnSettings[Any]",
](TrainerSettings):
    """Settings for the plain-PyTorch training loop."""

    model: TCModel
    torch_dataset: TCTorchDataset
    collate_fn: TCCollateFn
    # Presumably the fraction of the dataset used for training — confirm
    # against the trainer's split logic.
    train_frac: float
    num_epochs: int
    batch_size: int
    shuffle: bool
    lr: float
    device: str
    # Checkpoint frequency; unit (steps vs epochs) not determinable here —
    # see the trainer implementation.
    save_every: int
    drop_last: bool
    loss: TCLoss
    # Optional checkpoint path to resume training from.
    resume_from: str | None = None
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class ContrastiveLossSettings(LossSettings):
    """Base settings for contrastive losses."""

    temperature: float


class QueryMultiPositiveDatasetSettings[TCDatasetConnector: "DatasetConnectorSettings"](
    TorchDatasetSettings[TCDatasetConnector]
):
    """Settings for the query/multi-positive torch dataset (no extra fields)."""

    pass


class InBatchNegativeContrastiveLossSettings(ContrastiveLossSettings):
    """Settings for in-batch-negative contrastive loss (no extra fields)."""

    pass


class MultiPositiveContrastiveLossSettings(ContrastiveLossSettings):
    """Settings for multi-positive contrastive loss."""

    # Number of positives per query — presumably must match the collate fn's
    # n_pos; confirm where loss and collate fn are wired together.
    n_pos: int


class RunnerSettings(FromConfigMixinSettings):
    """Base settings for runners."""

    pass


class HFSettings(BaseSettings):
    """Hugging Face Hub target description."""

    repo: str
    revision: str | None
    private: bool
    commit_message: str | None


class PushToHFRunnerSettings[TCModel: "ModelSettings"](RunnerSettings):
    """Settings for PushToHFRunner.

    NOTE(review): every field is declared ``Field(init=False)`` with no
    default — presumably populated from the settings source (env/config)
    rather than the constructor; confirm this validates as intended.
    """

    checkpoint_path: str = Field(init=False)
    device: str = Field(init=False)
    push: bool = Field(init=False)
    create_repo: bool = Field(init=False)
    hf: HFSettings = Field(init=False)
    model: TCModel = Field(init=False)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
class TrainRunnerSettings[TCTrainer: "TrainerSettings"](RunnerSettings):
    """Settings for TrainRunner; wraps a trainer settings object."""

    trainer: TCTrainer


class HardNegativesSettings(BaseSettings):
    """Settings for hard-negative mining."""

    range_min: int
    range_max: int
    max_score: float
    relative_margin: float
    num_negatives: int
    sampling_strategy: Literal["random", "top"]
    batch_size: int
    use_faiss: bool


class EvalutationSettings(BaseSettings):
    """Retrieval-evaluation settings.

    NOTE(review): class name is misspelled ("Evalutation"); kept as-is for
    backward compatibility with existing imports and configs.
    """

    query_column: str
    document_column: str
    precision_recall_at_k: list[int]
    mrr_at_k: list[int]
    ndcg_at_k: list[int]
    batch_size: int


class SentenceTransformerLoss(LossSettings):
    """Loss settings for the sentence-transformers trainer."""

    # Passed through to the loss — presumably its constructor; confirm in
    # the trainer implementation.
    kwargs: dict[str, Any]


class SentenceTransformersTrainerSettings[
    TCDatasetConnector: "DatasetConnectorSettings",
    TCModel: "ModelSettings",
](TrainerSettings):
    """Settings for the sentence-transformers based trainer."""

    hf_dataset: HuggingFaceDatasetAdaptaterSettings[TCDatasetConnector]
    model: TCModel
    tokenizer: TokenizerSettings
    loss: SentenceTransformerLoss
    pooling: Literal["cls", "mean_tokens", "max_tokens"]
    batch_size: int
    num_epochs: int
    lr: float
    warmup_ratio: float
    eval_steps: int
    save_steps: int
    logging_steps: int
    fp16: bool
    train_frac: float
    resume_from: str | None
    evaluation: EvalutationSettings
    trust_remote_code: bool = TRUST_REMOTE_CODE

    # -------------------------
    # CROSS-FIELD VALIDATION
    # -------------------------
    @model_validator(mode="after")
    def _validate_steps_consistency(self) -> Self:
        """Ensure logging/eval/save step settings are mutually consistent.

        Raises:
            EmbedTrainValueError: when ordering or alignment is violated.
        """
        # Ordering: log at least as often as eval, eval at least as often
        # as save.
        if not (self.logging_steps <= self.eval_steps <= self.save_steps):
            raise EmbedTrainValueError(
                "Expected logging_steps ≤ eval_steps ≤ save_steps, got "
                f"{self.logging_steps} ≤ {self.eval_steps} ≤ {self.save_steps}"
            )

        # Alignment: every save step must coincide with an eval step.
        # NOTE(review): eval_steps == 0 would raise ZeroDivisionError here.
        if self.save_steps % self.eval_steps != 0:
            raise EmbedTrainValueError(
                "save_steps must be a multiple of eval_steps "
                f"(got save_steps={self.save_steps}, eval_steps={self.eval_steps})"
            )

        return self
|
|
embed_train/train/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from typing import TYPE_CHECKING, Any
|
|
2
|
+
|
|
3
|
+
from embed_train import Runner
|
|
4
|
+
from embed_train.train.trainers import Trainer
|
|
5
|
+
from embed_train.utils import load_class
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from embed_train.settings import TrainRunnerSettings
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class TrainRunner[TCTrainRunner: "TrainRunnerSettings[Any]"](Runner[TCTrainRunner]):
|
|
12
|
+
def __init__(self, config: TCTrainRunner):
|
|
13
|
+
super().__init__(config)
|
|
14
|
+
self.trainer: Trainer[Any] = load_class(self.config.trainer.module_path).from_config(self.config.trainer)
|
|
15
|
+
|
|
16
|
+
def run(self) -> None:
|
|
17
|
+
self.trainer.train()
|
|
embed_train/train/dataset/__init__.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from typing import TYPE_CHECKING, Any, cast
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
from datasets import Dataset as HFDataset # type: ignore[import-untyped]
|
|
7
|
+
from torch.utils.data import Dataset as TorchDatasetBase
|
|
8
|
+
from transformers import AutoTokenizer, PreTrainedTokenizer
|
|
9
|
+
|
|
10
|
+
from embed_train.utils import load_class
|
|
11
|
+
from retrievalbase.dataset import TextDataset
|
|
12
|
+
from retrievalbase.enums import EmbeddingPurposeEnum
|
|
13
|
+
from retrievalbase.evaluation import Processor
|
|
14
|
+
from retrievalbase.mixins import FromConfigMixin
|
|
15
|
+
|
|
16
|
+
_logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from embed_train.settings import CollateFnSettings, TorchDatasetSettings
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class CollateFn[TCCollateFn: "CollateFnSettings[Any]", T: dict[str, Any]](ABC):
|
|
23
|
+
def __init__(self, config: TCCollateFn, context: dict[str, Any] | None) -> None:
|
|
24
|
+
self.config = config
|
|
25
|
+
self.context = context
|
|
26
|
+
self.tokenizer: PreTrainedTokenizer = self._load_tokenizer()
|
|
27
|
+
self.processor: Processor[Any] = self._load_processor()
|
|
28
|
+
|
|
29
|
+
@abstractmethod
|
|
30
|
+
def _process_batch(self, batch: list[T]) -> tuple[list[str], list[str]]:
|
|
31
|
+
raise NotImplementedError()
|
|
32
|
+
|
|
33
|
+
def _load_tokenizer(self) -> PreTrainedTokenizer:
|
|
34
|
+
tokenizer = AutoTokenizer.from_pretrained( # nosec CWE-494
|
|
35
|
+
self.config.tokenizer.name,
|
|
36
|
+
trust_remote_code=self.config.tokenizer.trust_remote_code,
|
|
37
|
+
)
|
|
38
|
+
return cast(PreTrainedTokenizer, tokenizer)
|
|
39
|
+
|
|
40
|
+
def _load_processor(self) -> Processor[Any]:
|
|
41
|
+
processor: Processor[Any] = load_class(self.config.processor.module_path).from_config(self.config.processor)
|
|
42
|
+
return processor
|
|
43
|
+
|
|
44
|
+
def __call__(self, batch: list[T]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
45
|
+
queries, candidates = self._process_batch(batch)
|
|
46
|
+
return self._post_process(queries, candidates)
|
|
47
|
+
|
|
48
|
+
def _post_process(self, queries: list[str], candidates: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
49
|
+
queries = [
|
|
50
|
+
self.processor(
|
|
51
|
+
text=query,
|
|
52
|
+
purpose=cast(EmbeddingPurposeEnum, EmbeddingPurposeEnum.QUERY),
|
|
53
|
+
)
|
|
54
|
+
for query in queries
|
|
55
|
+
]
|
|
56
|
+
candidates = [
|
|
57
|
+
self.processor(
|
|
58
|
+
text=candidate,
|
|
59
|
+
purpose=cast(EmbeddingPurposeEnum, EmbeddingPurposeEnum.DOCUMENT),
|
|
60
|
+
)
|
|
61
|
+
for candidate in candidates
|
|
62
|
+
]
|
|
63
|
+
return self._tokenizer_qc(queries, candidates)
|
|
64
|
+
|
|
65
|
+
def _tokenizer_qc(self, queries: list[str], candidates: list[str]) -> tuple[torch.Tensor, torch.Tensor]:
|
|
66
|
+
q_tok = self.tokenizer(
|
|
67
|
+
queries,
|
|
68
|
+
padding=self.config.tokenizer.padding,
|
|
69
|
+
truncation=self.config.tokenizer.truncation,
|
|
70
|
+
max_length=self.config.tokenizer.max_length,
|
|
71
|
+
return_tensors="pt",
|
|
72
|
+
)
|
|
73
|
+
c_tok = self.tokenizer(
|
|
74
|
+
candidates,
|
|
75
|
+
padding=self.config.tokenizer.padding,
|
|
76
|
+
truncation=self.config.tokenizer.truncation,
|
|
77
|
+
max_length=self.config.tokenizer.max_length,
|
|
78
|
+
return_tensors="pt",
|
|
79
|
+
)
|
|
80
|
+
return q_tok, c_tok
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class TorchDataset[TCTorchDataset: "TorchDatasetSettings[Any]", T: dict[str, Any]](
|
|
84
|
+
FromConfigMixin[TCTorchDataset],
|
|
85
|
+
TorchDatasetBase[T],
|
|
86
|
+
ABC,
|
|
87
|
+
):
|
|
88
|
+
def __init__(self, config: TCTorchDataset) -> None:
|
|
89
|
+
self.config = config
|
|
90
|
+
self.dataset = self._load_dataset()
|
|
91
|
+
|
|
92
|
+
@abstractmethod
|
|
93
|
+
def __len__(self) -> int:
|
|
94
|
+
raise NotImplementedError()
|
|
95
|
+
|
|
96
|
+
def _load_dataset(self) -> TextDataset[Any]:
|
|
97
|
+
dataset: TextDataset[Any] = (
|
|
98
|
+
load_class(self.config.dataset_connector.module_path).from_config(self.config.dataset_connector).load_text()
|
|
99
|
+
)
|
|
100
|
+
_logger.info(f"Instantiating dataset | class={dataset.__class__.__name__} |")
|
|
101
|
+
_logger.info(f"Dataset loaded | type={type(dataset).__name__} | size={len(dataset)}")
|
|
102
|
+
return dataset
|
|
103
|
+
|
|
104
|
+
def to_hf_dataset(self) -> HFDataset:
|
|
105
|
+
"""
|
|
106
|
+
Convert *this* TorchDataset (via __getitem__) to a HuggingFace Dataset.
|
|
107
|
+
"""
|
|
108
|
+
rows: list[dict[str, Any]] = [self[i] for i in range(len(self))]
|
|
109
|
+
return HFDataset.from_list(rows)
|
|
embed_train/train/dataset/collate.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
import random
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
from embed_train.settings import (
|
|
5
|
+
InBatchPositiveCollateFnSettings,
|
|
6
|
+
MultiPositiveInBatchCollateFnSettings,
|
|
7
|
+
)
|
|
8
|
+
from embed_train.train.dataset import CollateFn
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class InBatchPositiveCollateFn(CollateFn[InBatchPositiveCollateFnSettings, dict[str, Any]]):
    """Collate fn pairing each query with ONE randomly chosen positive.

    Fix: the item type was annotated ``dict[str, str]``, but each item's
    "positives" value is a sequence (it is indexed by ``random.choice``),
    so the honest annotation is ``dict[str, Any]``.
    """

    def __init__(
        self,
        config: InBatchPositiveCollateFnSettings,
        context: dict[str, Any] | None,
    ) -> None:
        super().__init__(config, context)

    def _process_batch(
        self,
        batch: list[dict[str, Any]],
    ) -> tuple[list[str], list[str]]:
        """Extract queries and one sampled positive per item.

        Each item maps "query" to a string and "positives" to a non-empty
        sequence (``random.choice`` raises IndexError when it is empty).
        """
        queries: list[str] = []
        positives: list[str] = []

        for item in batch:
            queries.append(item["query"])
            # Non-cryptographic RNG is fine: sampling training examples.
            positives.append(random.choice(item["positives"]))  # nosec B311
        return queries, positives
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class MultiPositiveInBatchCollateFn(CollateFn[MultiPositiveInBatchCollateFnSettings, dict[str, Any]]):
    """Collate fn sampling ``n_pos`` positives (with replacement) per query.

    Fix: the item type was annotated ``dict[str, str]``, but each item's
    "positives" value is a sequence (it is indexed by ``random.choice``),
    so the honest annotation is ``dict[str, Any]``.
    """

    def __init__(
        self,
        config: MultiPositiveInBatchCollateFnSettings,
        context: dict[str, Any] | None,
    ) -> None:
        super().__init__(config, context)

    def _process_batch(
        self,
        batch: list[dict[str, Any]],
    ) -> tuple[list[str], list[str]]:
        """Extract queries and ``n_pos`` sampled positives per item.

        Returns parallel lists where ``passages`` holds ``n_pos`` entries
        per query, in query order. Sampling is with replacement, so
        duplicates are possible when an item has few positives.
        """
        queries: list[str] = []
        passages: list[str] = []
        for item in batch:
            query = item["query"]
            positives = item["positives"]
            # With-replacement sampling keeps the count fixed at n_pos even
            # when fewer distinct positives exist.
            sampled_positives = [random.choice(positives) for _ in range(self.config.n_pos)]  # nosec B311
            queries.append(query)
            passages.extend(sampled_positives)
        return queries, passages
|
|
embed_train/train/dataset/sampling/__init__.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
import random
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from typing import TYPE_CHECKING, Any
|
|
4
|
+
|
|
5
|
+
from embed_train.settings import StepRangeDistanceSettings
|
|
6
|
+
from retrievalbase.mixins import FromConfigMixin
|
|
7
|
+
|
|
8
|
+
if TYPE_CHECKING:
|
|
9
|
+
from embed_train.settings import DistanceSettings, SamplerSettings
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class Sampler[TCSampler: "SamplerSettings", T: dict[str, Any]](ABC):
|
|
13
|
+
RANDOM_SEED = 42
|
|
14
|
+
|
|
15
|
+
def __init__(self, config: TCSampler, context: dict[str, Any] | None) -> None:
|
|
16
|
+
self.config = config
|
|
17
|
+
self.context = context
|
|
18
|
+
self.rng = random.Random(self.config.seed) # nosec B311
|
|
19
|
+
|
|
20
|
+
@abstractmethod
|
|
21
|
+
def sample(self, item: T, **kwargs: Any) -> list[str]:
|
|
22
|
+
pass
|
|
23
|
+
|
|
24
|
+
def _init_components(self) -> None:
|
|
25
|
+
return None
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class Distance[TCDistance: "DistanceSettings"](FromConfigMixin[TCDistance], ABC):
    """Abstract distance between two ``(start, end)`` integer intervals.

    Subclasses define the metric (see e.g. StepRangeDistance).
    """

    def __init__(self, config: TCDistance) -> None:
        self.config = config

    @abstractmethod
    def __call__(self, a: tuple[int, int], b: tuple[int, int]) -> int:
        """Return the distance between intervals ``a`` and ``b``."""
        raise NotImplementedError
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class StepRangeDistance(Distance[StepRangeDistanceSettings]):
    """Gap distance between two step ranges: 0 when they overlap or touch,
    otherwise the number of steps separating the closer endpoints."""

    def __init__(self, config: StepRangeDistanceSettings) -> None:
        super().__init__(config)

    def __call__(self, a: tuple[int, int], b: tuple[int, int]) -> int:
        a_start, a_end = a
        b_start, b_end = b
        # `a` lies entirely before `b`: gap from a's end to b's start.
        if a_end < b_start:
            return b_start - a_end
        # `b` lies entirely before `a`: gap from b's end to a's start.
        if b_end < a_start:
            return a_start - b_end
        # Overlapping (or touching) ranges have zero distance.
        return 0
|