embed-train 3.0.0__tar.gz → 3.2.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {embed_train-3.0.0 → embed_train-3.2.0}/AGENTS.md +9 -4
- {embed_train-3.0.0 → embed_train-3.2.0}/CHANGELOG.md +14 -0
- embed_train-3.0.0/README.md → embed_train-3.2.0/PKG-INFO +40 -5
- embed_train-3.0.0/PKG-INFO → embed_train-3.2.0/README.md +26 -18
- {embed_train-3.0.0 → embed_train-3.2.0}/pyproject.toml +8 -7
- {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/settings.py +50 -11
- {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/dataset/__init__.py +10 -1
- {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/dataset/collate.py +34 -0
- embed_train-3.2.0/src/embed_train/train/dataset/hard_negatives.py +136 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/dataset/torch_datasets.py +41 -1
- {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/trainers/hf/__init__.py +5 -5
- {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/trainers/torch/__init__.py +1 -1
- {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/trainers/torch/loss.py +24 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/fixtures/components.py +54 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_push_to_hf.py +23 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_settings.py +4 -14
- embed_train-3.2.0/tests/unit/test_train/test_collate.py +83 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_train/test_dataset.py +15 -1
- embed_train-3.2.0/tests/unit/test_train/test_hard_negatives.py +117 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_train/test_hf_trainer.py +1 -1
- embed_train-3.2.0/tests/unit/test_train/test_loss.py +83 -0
- embed_train-3.2.0/tests/unit/test_train/test_torch_datasets.py +83 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/uv.lock +22 -65
- embed_train-3.0.0/tests/unit/test_train/test_collate.py +0 -42
- embed_train-3.0.0/tests/unit/test_train/test_loss.py +0 -39
- embed_train-3.0.0/tests/unit/test_train/test_torch_datasets.py +0 -33
- {embed_train-3.0.0 → embed_train-3.2.0}/.gitignore +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/.gitlab-ci.yml +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/.pre-commit-config.yaml +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/.releaserc.json +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/Makefile +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/codecov.yml +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/commitlint.config.cjs +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/__init__.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/constants.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/exceptions.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/models/__init__.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/push_to_hf/__init__.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/py.typed +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/__init__.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/dataset/sampling/__init__.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/dataset/sampling/samplers.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/trainers/__init__.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/utils.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/__init__.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/conftest.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/fixtures/__init__.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/fixtures/data.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/integration/__init__.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/integration/test_dataset/__init__.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/integration/test_dataset/test_to_hf_dataset.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/integration/test_train_runner/__init__.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/integration/test_train_runner/test_train_runner_flow.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/__init__.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_abstract_guards.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_embed_train.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_exceptions.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_models.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_train/__init__.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_train/test_runner.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_train/test_samplers.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_train/test_sampling.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_train/test_torch_trainer.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_train/test_trainers.py +0 -0
- {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_utils.py +0 -0
```diff
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -36,13 +36,15 @@ Key modules:
 - `src/embed_train/train/trainers/hf/__init__.py`
   SentenceTransformers-based training path with `InformationRetrievalEvaluator`.
 - `src/embed_train/train/dataset/__init__.py`
-  Base `TorchDataset` and `CollateFn` abstractions.
+  Base `TorchDataset`, `HardNegativeMiner`, and `CollateFn` abstractions.
 - `src/embed_train/train/dataset/collate.py`
-  Built-in collate functions for in-batch positive training.
+  Built-in collate functions for in-batch positive and hard-negative training.
+- `src/embed_train/train/dataset/hard_negatives.py`
+  SentenceTransformers-backed hard-negative mining implementation.
 - `src/embed_train/train/dataset/torch_datasets.py`
-  Built-in grouped
+  Built-in grouped, flattened query/positive, and hard-negative dataset views.
 - `src/embed_train/train/trainers/torch/loss.py`
-  Built-in contrastive losses.
+  Built-in contrastive losses, including hard-negative candidate ranking.
 - `src/embed_train/push_to_hf/__init__.py`
   `PushToHFRunner` for checkpoint restore, local repo export, and HF upload.
 - `src/embed_train/utils.py`
@@ -102,6 +104,9 @@ For most changes, inspect the matching implementation and tests together:
 - Preserve the row contracts expected by built-in collate functions:
   - grouped format: `{"query": str, "positives": list[str]}`
   - flattened format: `{"query": str, "positive": str}`
+  - hard-negative format: `{"query": str, "positive": str, "negative": str}` or `negative_<n>` columns
+- Keep hard-negative mining isolated behind `HardNegativeMiner`/`HardNegativeDataset` unless a task explicitly changes the abstraction boundary.
+- Preserve the hard-negative candidate order expected by `HardNegativeContrastiveLoss`: one positive first, followed by one or more negatives for the same query.
 - When changing row formats, update the collate functions and tests together.
 - Keep tokenizer behavior explicit through settings rather than hidden defaults.
 
```
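The three row contracts and the candidate order in the AGENTS.md guidance above are easiest to see side by side. A minimal sketch with made-up strings (the dict literals below are illustrative, not fixtures from the repository):

```python
# Illustrative rows only; the strings are invented for this sketch.
grouped_row = {"query": "what is rust?", "positives": ["Rust is a systems language.", "Rust focuses on memory safety."]}
flattened_row = {"query": "what is rust?", "positive": "Rust is a systems language."}
hard_negative_row = {
    "query": "what is rust?",
    "positive": "Rust is a systems language.",
    "negative_1": "Rust is iron oxide.",
    "negative_2": "Go is a garbage-collected language.",
}

# Candidate order expected by HardNegativeContrastiveLoss for one query:
# the positive first, then that query's negatives.
candidates = [
    hard_negative_row["positive"],
    hard_negative_row["negative_1"],
    hard_negative_row["negative_2"],
]
```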
```diff
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,17 @@
+# [3.2.0](https://gitlab.com/efysent/agentic-core/embed-train/compare/v3.1.0...v3.2.0) (2026-05-12)
+
+
+### Features
+
+* remove sentence transformers deprecations ([581032b](https://gitlab.com/efysent/agentic-core/embed-train/commit/581032b97a0bfc50236cc0fa64a857bc65f20314))
+
+# [3.1.0](https://gitlab.com/efysent/agentic-core/embed-train/compare/v3.0.0...v3.1.0) (2026-05-12)
+
+
+### Features
+
+* add hard negative training ([4d79e25](https://gitlab.com/efysent/agentic-core/embed-train/commit/4d79e25877f7483dd7c6f13a97b1dd846c998c9a))
+
 # [3.0.0](https://gitlab.com/efysent/agentic-core/embed-train/compare/v2.0.0...v3.0.0) (2026-05-09)
 
 
```
```diff
--- a/README.md
+++ b/PKG-INFO
@@ -1,3 +1,17 @@
+Metadata-Version: 2.4
+Name: embed-train
+Version: 3.2.0
+Author-email: jalal <jalalkhaldi3@gmail.com>
+Requires-Python: <3.13,>=3.11
+Requires-Dist: accelerate==1.13.0
+Requires-Dist: datasets==4.8.4
+Requires-Dist: retrievalbase<3.0.0,>=2.1.0
+Requires-Dist: sentence-transformers==5.4.1
+Requires-Dist: tensorboard==2.20.0
+Requires-Dist: torch==2.11.0
+Requires-Dist: transformers==4.57.6
+Description-Content-Type: text/markdown
+
 # embed-train
 
 `embed-train` is a config-driven library for training, evaluating, checkpointing, and publishing embedding models.
@@ -61,9 +75,10 @@ src/embed_train/
 ├── train/
 │   ├── __init__.py                  # TrainRunner
 │   ├── dataset/
-│   │   ├── __init__.py              # Base TorchDataset and CollateFn abstractions
-│   │   ├── collate.py               # Built-in in-batch
-│   │   ├──
+│   │   ├── __init__.py              # Base TorchDataset, HardNegativeMiner, and CollateFn abstractions
+│   │   ├── collate.py               # Built-in in-batch and hard-negative collate functions
+│   │   ├── hard_negatives.py        # SentenceTransformers hard-negative miner
+│   │   ├── torch_datasets.py        # Built-in query/positive and hard-negative dataset views
 │   │   └── sampling/
 │   │       └── samplers.py          # Built-in positive sampler(s)
 │   └── trainers/
@@ -107,7 +122,8 @@ The base class already provides:
 
 The PyTorch flow separates concerns clearly:
 
-- `TorchDataset`: how rows are loaded and exposed
+- `TorchDataset`: how rows are loaded, optionally mined, and exposed
+- `HardNegativeMiner`: how query/positive rows can be expanded with mined negatives
 - `CollateFn`: how rows become query/document text pairs and tokenized tensors
 - `Processor`: text normalization or preprocessing, provided by `retrievalbase`
 
@@ -124,8 +140,12 @@ The current repository includes these reusable implementations:
 
 - `embed_train.train.dataset.torch_datasets.QueryMultiPositiveDataset`
 - `embed_train.train.dataset.torch_datasets.QueryPositiveDataset`
+- `embed_train.train.dataset.torch_datasets.HardNegativeDataset`
+- `embed_train.train.dataset.hard_negatives.SentenceTransformerHardNegativeMiner`
+- `embed_train.train.dataset.collate.HardNegativeCollateFn`
 - `embed_train.train.dataset.collate.InBatchNegativeCollateFn`
 - `embed_train.train.dataset.collate.MultiPositiveInBatchCollateFn`
+- `embed_train.train.trainers.torch.loss.HardNegativeContrastiveLoss`
 - `embed_train.train.trainers.torch.loss.InBatchNegativeContrastiveLoss`
 - `embed_train.train.trainers.torch.loss.MultiPositiveContrastiveLoss`
 - `embed_train.train.trainers.hf.SentenceTransformersTrainer`
@@ -147,6 +167,20 @@ What the built-in trainer does:
 - logs to TensorBoard
 - saves checkpoints to `data_dir/checkpoints/...`
 
+For hard-negative training, configure the trainer with:
+
+- `torch_dataset`: `embed_train.train.dataset.torch_datasets.HardNegativeDataset`
+- `torch_dataset.hard_negative_miner`: `embed_train.train.dataset.hard_negatives.SentenceTransformerHardNegativeMiner`
+- `collate_fn`: `embed_train.train.dataset.collate.HardNegativeCollateFn`
+- `loss`: `embed_train.train.trainers.torch.loss.HardNegativeContrastiveLoss`
+
+`HardNegativeDataset` converts rows with `metadata.query` and `page_content` into a Hugging Face
+`Dataset`, mines negatives with `sentence_transformers.util.mine_hard_negatives`, and then exposes
+the mined rows through the normal `TorchDataset` interface. `HardNegativeCollateFn` expects each row
+to contain `query`, `positive`, and either `negative` or numbered `negative_<n>` fields. It emits one
+positive followed by that row's negatives, which is the candidate layout required by
+`HardNegativeContrastiveLoss`.
+
 ### SentenceTransformers Training
 
 Use this when your data is naturally represented as a Hugging Face `Dataset` and you want a standard SentenceTransformers training path with IR evaluation.
@@ -204,10 +238,11 @@ This is the right place to define pooling, projection heads, shared or separate
 
 Subclass `TorchDataset` when you need a different row shape or data-loading strategy.
 
-The built-in datasets show
+The built-in datasets show these common patterns:
 
 - grouped query -> many positives
 - flattened query -> single positive
+- flattened query -> single positive plus mined hard negatives
 
 ### Add a Custom Collate Function
 
```
```diff
--- a/PKG-INFO
+++ b/README.md
@@ -1,16 +1,3 @@
-Metadata-Version: 2.4
-Name: embed-train
-Version: 3.0.0
-Author-email: jalal <jalalkhaldi3@gmail.com>
-Requires-Python: <3.13,>=3.11
-Requires-Dist: accelerate<2.0.0,>=1.13.0
-Requires-Dist: datasets<5.0.0,>=4.5.0
-Requires-Dist: retrievalbase<2.0.0,>=1.0.0
-Requires-Dist: sentence-transformers<6.0.0,>=5.1.2
-Requires-Dist: tensorboard<3.0.0,>=2.20.0
-Requires-Dist: torch<3.0.0,>=2.9.0
-Description-Content-Type: text/markdown
-
 # embed-train
 
 `embed-train` is a config-driven library for training, evaluating, checkpointing, and publishing embedding models.
@@ -74,9 +61,10 @@ src/embed_train/
 ├── train/
 │   ├── __init__.py                  # TrainRunner
 │   ├── dataset/
-│   │   ├── __init__.py              # Base TorchDataset and CollateFn abstractions
-│   │   ├── collate.py               # Built-in in-batch
-│   │   ├──
+│   │   ├── __init__.py              # Base TorchDataset, HardNegativeMiner, and CollateFn abstractions
+│   │   ├── collate.py               # Built-in in-batch and hard-negative collate functions
+│   │   ├── hard_negatives.py        # SentenceTransformers hard-negative miner
+│   │   ├── torch_datasets.py        # Built-in query/positive and hard-negative dataset views
 │   │   └── sampling/
 │   │       └── samplers.py          # Built-in positive sampler(s)
 │   └── trainers/
@@ -120,7 +108,8 @@ The base class already provides:
 
 The PyTorch flow separates concerns clearly:
 
-- `TorchDataset`: how rows are loaded and exposed
+- `TorchDataset`: how rows are loaded, optionally mined, and exposed
+- `HardNegativeMiner`: how query/positive rows can be expanded with mined negatives
 - `CollateFn`: how rows become query/document text pairs and tokenized tensors
 - `Processor`: text normalization or preprocessing, provided by `retrievalbase`
 
@@ -137,8 +126,12 @@ The current repository includes these reusable implementations:
 
 - `embed_train.train.dataset.torch_datasets.QueryMultiPositiveDataset`
 - `embed_train.train.dataset.torch_datasets.QueryPositiveDataset`
+- `embed_train.train.dataset.torch_datasets.HardNegativeDataset`
+- `embed_train.train.dataset.hard_negatives.SentenceTransformerHardNegativeMiner`
+- `embed_train.train.dataset.collate.HardNegativeCollateFn`
 - `embed_train.train.dataset.collate.InBatchNegativeCollateFn`
 - `embed_train.train.dataset.collate.MultiPositiveInBatchCollateFn`
+- `embed_train.train.trainers.torch.loss.HardNegativeContrastiveLoss`
 - `embed_train.train.trainers.torch.loss.InBatchNegativeContrastiveLoss`
 - `embed_train.train.trainers.torch.loss.MultiPositiveContrastiveLoss`
 - `embed_train.train.trainers.hf.SentenceTransformersTrainer`
@@ -160,6 +153,20 @@ What the built-in trainer does:
 - logs to TensorBoard
 - saves checkpoints to `data_dir/checkpoints/...`
 
+For hard-negative training, configure the trainer with:
+
+- `torch_dataset`: `embed_train.train.dataset.torch_datasets.HardNegativeDataset`
+- `torch_dataset.hard_negative_miner`: `embed_train.train.dataset.hard_negatives.SentenceTransformerHardNegativeMiner`
+- `collate_fn`: `embed_train.train.dataset.collate.HardNegativeCollateFn`
+- `loss`: `embed_train.train.trainers.torch.loss.HardNegativeContrastiveLoss`
+
+`HardNegativeDataset` converts rows with `metadata.query` and `page_content` into a Hugging Face
+`Dataset`, mines negatives with `sentence_transformers.util.mine_hard_negatives`, and then exposes
+the mined rows through the normal `TorchDataset` interface. `HardNegativeCollateFn` expects each row
+to contain `query`, `positive`, and either `negative` or numbered `negative_<n>` fields. It emits one
+positive followed by that row's negatives, which is the candidate layout required by
+`HardNegativeContrastiveLoss`.
+
 ### SentenceTransformers Training
 
 Use this when your data is naturally represented as a Hugging Face `Dataset` and you want a standard SentenceTransformers training path with IR evaluation.
@@ -217,10 +224,11 @@ This is the right place to define pooling, projection heads, shared or separate
 
 Subclass `TorchDataset` when you need a different row shape or data-loading strategy.
 
-The built-in datasets show
+The built-in datasets show these common patterns:
 
 - grouped query -> many positives
 - flattened query -> single positive
+- flattened query -> single positive plus mined hard negatives
 
 ### Add a Custom Collate Function
 
```
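Neither document shows the hard-negative wiring in code, so the following sketch is an assumption about how the four pieces fit together at the settings level. Only the class paths and the `hard_negative_miner` nesting come from the diff itself; the field values are made up, and `tokenizer_settings` stands in for a `TokenizerSettings` instance whose fields this diff does not show:

```python
# Hypothetical wiring sketch; values are illustrative, not project defaults.
from embed_train.settings import (
    HardNegativeDatasetSettings,
    SentenceTransformerHardNegativeMinerSettings,
)

miner_settings = SentenceTransformerHardNegativeMinerSettings(
    model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",  # assumed example model
    tokenizer=tokenizer_settings,  # TokenizerSettings instance, fields not shown in this diff
    pooling="mean",
    num_negatives=3,          # declared default
    sampling_strategy="top",  # declared default
)

dataset_settings = HardNegativeDatasetSettings(
    hard_negative_miner=miner_settings,
    # ...plus the inherited TorchDatasetSettings fields (dataset connector, etc.),
    # which are outside this diff.
)
```

From there the trainer pairs `HardNegativeDataset(dataset_settings)` with `HardNegativeCollateFn` and `HardNegativeContrastiveLoss`, as the configuration list in the README prescribes.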
```diff
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "embed-train"
-version = "3.0.0"
+version = "3.2.0"
 description = ""
 authors = [
     { name = "jalal", email = "jalalkhaldi3@gmail.com" }
@@ -9,12 +9,13 @@ readme = "README.md"
 requires-python = ">=3.11,<3.13"
 
 dependencies = [
-    "torch
-    "sentence-transformers
-    "
-    "
-    "
-    "
+    "torch==2.11.0",
+    "sentence-transformers==5.4.1",
+    "transformers==4.57.6",
+    "datasets==4.8.4",
+    "tensorboard==2.20.0",
+    "accelerate==1.13.0",
+    "retrievalbase>=2.1.0,<3.0.0",
 ]
 
 [build-system]
```
```diff
--- a/src/embed_train/settings.py
+++ b/src/embed_train/settings.py
@@ -57,6 +57,10 @@ class MultiPositiveInBatchCollateFnSettings[TCProcessor: "ProcessorSettings"](Co
     n_pos: int
 
 
+class HardNegativeCollateFnSettings[TCProcessor: "ProcessorSettings"](CollateFnSettings[TCProcessor]):
+    pass
+
+
 class ModelSettings(FromConfigMixinSettings):
     pass
 
@@ -69,6 +73,41 @@ class TrainerSettings(FromConfigMixinSettings):
     data_dir: Path
 
 
+class HardNegativeMinerSettings(FromConfigMixinSettings):
+    pass
+
+
+class SentenceTransformerHardNegativeMinerSettings(HardNegativeMinerSettings):
+    model_name_or_path: str
+    cross_encoder_model_name_or_path: str | None = None
+    tokenizer: TokenizerSettings
+    pooling: Literal["cls", "max", "mean", "mean_sqrt_len_tokens", "weightedmean", "lasttoken"]
+    anchor_column_name: str = "query"
+    positive_column_name: str = "positive"
+    range_min: int = 0
+    range_max: int | None = None
+    max_score: float | None = None
+    min_score: float | None = None
+    absolute_margin: float | None = None
+    relative_margin: float | None = None
+    num_negatives: int = 3
+    sampling_strategy: Literal["random", "top"] = "top"
+    query_prompt_name: str | None = None
+    query_prompt: str | None = None
+    corpus_prompt_name: str | None = None
+    corpus_prompt: str | None = None
+    include_positives: bool = False
+    output_format: Literal["triplet", "n-tuple", "labeled-pair", "labeled-list"] = "n-tuple"
+    output_scores: bool = False
+    batch_size: int = 32
+    faiss_batch_size: int = 16384
+    use_faiss: bool = False
+    use_multi_process: bool | list[str] = False
+    verbose: bool = True
+    cache_folder: str | None = None
+    trust_remote_code: bool = TRUST_REMOTE_CODE
+
+
 class PyTorchTrainerSettings[
     TCModel: "ModelSettings",
     TCLoss: "LossSettings",
@@ -107,6 +146,13 @@ class QueryPositiveDatasetSettings[TCDatasetConnector: "DatasetConnectorSettings
     pass
 
 
+class HardNegativeDatasetSettings[
+    TCDatasetConnector: "DatasetConnectorSettings",
+    TCHardNegativeMiner: "HardNegativeMinerSettings",
+](TorchDatasetSettings[TCDatasetConnector]):
+    hard_negative_miner: TCHardNegativeMiner
+
+
 class InBatchNegativeContrastiveLossSettings(ContrastiveLossSettings):
     pass
 
@@ -115,6 +161,10 @@ class MultiPositiveContrastiveLossSettings(ContrastiveLossSettings):
     n_pos: int
 
 
+class HardNegativeContrastiveLossSettings(ContrastiveLossSettings):
+    pass
+
+
 class RunnerSettings(FromConfigMixinSettings):
     pass
 
@@ -139,17 +189,6 @@ class TrainRunnerSettings[TCTrainer: "TrainerSettings"](RunnerSettings):
     trainer: TCTrainer
 
 
-class HardNegativesSettings(BaseSettings):
-    range_min: int
-    range_max: int
-    max_score: float
-    relative_margin: float
-    num_negatives: int
-    sampling_strategy: Literal["random", "top"]
-    batch_size: int
-    use_faiss: bool
-
-
 class EvalutationSettings(BaseSettings):
     query_column: str
     document_column: str
```
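The fields of the removed top-level `HardNegativesSettings` all reappear on `SentenceTransformerHardNegativeMinerSettings`, now with defaults and the mining-backend fields alongside them. A minimal construction sketch, assuming pydantic-style `BaseSettings` semantics for required versus defaulted fields; the checkpoint name is an assumed example, and `tokenizer_settings` stands in for a `TokenizerSettings` instance whose fields are not part of this diff:

```python
from embed_train.settings import SentenceTransformerHardNegativeMinerSettings

# Only model_name_or_path, tokenizer, and pooling lack defaults above;
# everything else can be omitted.
settings = SentenceTransformerHardNegativeMinerSettings(
    model_name_or_path="BAAI/bge-small-en-v1.5",  # assumed example checkpoint
    tokenizer=tokenizer_settings,                 # a TokenizerSettings instance
    pooling="mean",
)

# Old HardNegativesSettings fields, now defaulted on the miner settings.
assert settings.range_min == 0
assert settings.num_negatives == 3
assert settings.sampling_strategy == "top"
assert settings.batch_size == 32
assert settings.use_faiss is False
```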
```diff
--- a/src/embed_train/train/dataset/__init__.py
+++ b/src/embed_train/train/dataset/__init__.py
@@ -16,7 +16,7 @@ from retrievalbase.mixins import FromConfigMixin
 _logger = logging.getLogger(__name__)
 
 if TYPE_CHECKING:
-    from embed_train.settings import CollateFnSettings, TorchDatasetSettings
+    from embed_train.settings import CollateFnSettings, HardNegativeMinerSettings, TorchDatasetSettings
 
 
 class CollateFn[TCCollateFn: "CollateFnSettings[Any]", T: dict[str, Any]](ABC):
@@ -107,3 +107,12 @@ class TorchDataset[TCTorchDataset: "TorchDatasetSettings[Any]", T: dict[str, Any
         """
         rows: list[dict[str, Any]] = [self[i] for i in range(len(self))]
         return HFDataset.from_list(rows)
+
+
+class HardNegativeMiner[TCHardNegativeMiner: "HardNegativeMinerSettings"](FromConfigMixin[TCHardNegativeMiner], ABC):
+    def __init__(self, config: TCHardNegativeMiner) -> None:
+        self.config = config
+
+    @abstractmethod
+    def mine(self, dataset: HFDataset) -> HFDataset:
+        raise NotImplementedError
```
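`HardNegativeMiner` is deliberately small: construction stores the settings, and `mine` maps one Hugging Face `Dataset` to another. A hypothetical subclass to illustrate the contract (the class name and its canned negative are invented for this sketch):

```python
from datasets import Dataset as HFDataset

from embed_train.settings import HardNegativeMinerSettings
from embed_train.train.dataset import HardNegativeMiner


class StaticNegativeMiner(HardNegativeMiner[HardNegativeMinerSettings]):
    """Illustration only: attach the same canned negative to every row."""

    def mine(self, dataset: HFDataset) -> HFDataset:
        # Adds a "negative" column, matching the flat hard-negative row contract.
        return dataset.map(lambda row: {**row, "negative": "An unrelated passage."})
```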
```diff
--- a/src/embed_train/train/dataset/collate.py
+++ b/src/embed_train/train/dataset/collate.py
@@ -2,6 +2,7 @@ import random
 from typing import Any
 
 from embed_train.settings import (
+    HardNegativeCollateFnSettings,
     InBatchNegativeCollateFnSettings,
     MultiPositiveInBatchCollateFnSettings,
 )
@@ -50,3 +51,36 @@ class MultiPositiveInBatchCollateFn(CollateFn[MultiPositiveInBatchCollateFnSetti
             queries.append(query)
             passages.extend(sampled_positives)
         return queries, passages
+
+
+class HardNegativeCollateFn(CollateFn[HardNegativeCollateFnSettings, dict[str, Any]]):
+    def __init__(
+        self,
+        config: HardNegativeCollateFnSettings,
+        context: dict[str, Any] | None,
+    ) -> None:
+        super().__init__(config, context)
+
+    def _process_batch(
+        self,
+        batch: list[dict[str, Any]],
+    ) -> tuple[list[str], list[str]]:
+        queries: list[str] = []
+        passages: list[str] = []
+
+        for item in batch:
+            queries.append(item["query"])
+            passages.append(item["positive"])
+            passages.extend(self._negative_passages(item))
+
+        return queries, passages
+
+    def _negative_passages(self, item: dict[str, Any]) -> list[str]:
+        if "negative" in item:
+            return [item["negative"]]
+
+        negative_keys = sorted(
+            (key for key in item if key.startswith("negative_")),
+            key=lambda key: int(key.rsplit("_", 1)[1]),
+        )
+        return [item[key] for key in negative_keys]
```
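To make the flattened layout of `HardNegativeCollateFn._process_batch` concrete, a worked example on a two-row batch (the strings are made up):

```python
batch = [
    {"query": "q1", "positive": "p1", "negative_1": "n11", "negative_2": "n12"},
    {"query": "q2", "positive": "p2", "negative_1": "n21", "negative_2": "n22"},
]

# _process_batch(batch) returns:
#   queries  == ["q1", "q2"]
#   passages == ["p1", "n11", "n12", "p2", "n21", "n22"]
# Each query contributes its positive first, then its negatives in numeric
# order of the negative_<n> suffix; a row carrying a single "negative" key
# would contribute its positive followed by that one negative instead.
```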
```diff
--- /dev/null
+++ b/src/embed_train/train/dataset/hard_negatives.py
@@ -0,0 +1,136 @@
+import gc
+import logging
+
+import torch
+from datasets import Dataset as HFDataset  # type: ignore[import-untyped]
+from sentence_transformers import SentenceTransformer
+from sentence_transformers.cross_encoder import CrossEncoder
+from sentence_transformers.sentence_transformer.modules import Pooling, Transformer
+from sentence_transformers.util import mine_hard_negatives
+
+from embed_train.settings import SentenceTransformerHardNegativeMinerSettings
+from embed_train.train.dataset import HardNegativeMiner
+
+_logger = logging.getLogger(__name__)
+
+
+class SentenceTransformerHardNegativeMiner(HardNegativeMiner[SentenceTransformerHardNegativeMinerSettings]):
+    def __init__(self, config: "SentenceTransformerHardNegativeMinerSettings") -> None:
+        super().__init__(config)
+
+    def _load_sentence_transformer(self) -> SentenceTransformer:
+        trust = self.config.trust_remote_code or self.config.tokenizer.trust_remote_code
+
+        _logger.info(
+            f"Loading SentenceTransformer | "
+            f"model={self.config.model_name_or_path} | "
+            f"pooling={self.config.pooling} | "
+            f"max_length={self.config.tokenizer.max_length}"
+        )
+
+        transformer = Transformer(
+            model_name_or_path=self.config.model_name_or_path,
+            max_seq_length=self.config.tokenizer.max_length,
+            tokenizer_name_or_path=self.config.tokenizer.name,
+            model_kwargs={
+                "trust_remote_code": trust,
+            },
+            config_kwargs={
+                "trust_remote_code": trust,
+            },
+            processor_kwargs={
+                "trust_remote_code": trust,
+                "padding": self.config.tokenizer.padding,
+                "truncation": self.config.tokenizer.truncation,
+                "model_max_length": self.config.tokenizer.max_length,
+            },
+        )
+
+        pooling = Pooling(transformer.get_embedding_dimension(), pooling_mode=self.config.pooling)
+
+        model = SentenceTransformer(modules=[transformer, pooling])
+
+        _logger.info(f"SentenceTransformer loaded successfully | embedding_dim={transformer.get_embedding_dimension()}")
+
+        return model
+
+    def _load_cross_encoder(self) -> "CrossEncoder | None":
+        if not self.config.cross_encoder_model_name_or_path:
+            _logger.info("No CrossEncoder configured")
+            return None
+
+        _logger.info(f"Loading CrossEncoder | model={self.config.cross_encoder_model_name_or_path}")
+
+        model = CrossEncoder(
+            self.config.cross_encoder_model_name_or_path,
+            trust_remote_code=self.config.trust_remote_code,
+        )
+
+        _logger.info("CrossEncoder loaded successfully")
+
+        return model
+
+    def mine(self, dataset: HFDataset) -> HFDataset:
+        _logger.info(
+            f"Starting hard negative mining | "
+            f"dataset_size={len(dataset)} | "
+            f"num_negatives={self.config.num_negatives} | "
+            f"sampling_strategy={self.config.sampling_strategy} | "
+            f"use_faiss={self.config.use_faiss}"
+        )
+
+        model = self._load_sentence_transformer()
+        cross_encoder = self._load_cross_encoder()
+
+        try:
+            mined = mine_hard_negatives(
+                dataset=dataset,
+                model=model,
+                anchor_column_name=self.config.anchor_column_name,
+                positive_column_name=self.config.positive_column_name,
+                cross_encoder=cross_encoder,
+                range_min=self.config.range_min,
+                range_max=self.config.range_max,
+                max_score=self.config.max_score,
+                min_score=self.config.min_score,
+                absolute_margin=self.config.absolute_margin,
+                relative_margin=self.config.relative_margin,
+                num_negatives=self.config.num_negatives,
+                sampling_strategy=self.config.sampling_strategy,
+                query_prompt_name=self.config.query_prompt_name,
+                query_prompt=self.config.query_prompt,
+                corpus_prompt_name=self.config.corpus_prompt_name,
+                corpus_prompt=self.config.corpus_prompt,
+                include_positives=self.config.include_positives,
+                output_format=self.config.output_format,
+                output_scores=self.config.output_scores,
+                batch_size=self.config.batch_size,
+                faiss_batch_size=self.config.faiss_batch_size,
+                use_faiss=self.config.use_faiss,
+                use_multi_process=self.config.use_multi_process,
+                verbose=self.config.verbose,
+                cache_folder=self.config.cache_folder,
+            )
+
+            _logger.info(f"Hard negative mining completed successfully | mined_dataset_size={len(mined)}")
+        finally:
+            _logger.info("Cleaning GPU memory after hard negative mining")
+            del model
+            if cross_encoder is not None:
+                del cross_encoder
+            gc.collect()
+            if torch.cuda.is_available():
+                allocated_before = torch.cuda.memory_allocated() / 1024**3
+                reserved_before = torch.cuda.memory_reserved() / 1024**3
+                torch.cuda.empty_cache()
+                torch.cuda.ipc_collect()
+                allocated_after = torch.cuda.memory_allocated() / 1024**3
+                reserved_after = torch.cuda.memory_reserved() / 1024**3
+                _logger.info(
+                    f"CUDA memory cleanup completed | "
+                    f"allocated_before={allocated_before:.2f}GB | "
+                    f"reserved_before={reserved_before:.2f}GB | "
+                    f"allocated_after={allocated_after:.2f}GB | "
+                    f"reserved_after={reserved_after:.2f}GB"
+                )
+        return mined
```
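End to end, the miner is built from its settings and applied to a query/positive `Dataset`. A usage sketch: `from_config` comes from `retrievalbase`'s `FromConfigMixin` and its exact signature is an assumption here, `miner_settings` is a settings object like the one sketched after the settings.py diff, and the two rows are made up:

```python
from datasets import Dataset as HFDataset

from embed_train.train.dataset.hard_negatives import SentenceTransformerHardNegativeMiner

pairs = HFDataset.from_list([
    {"query": "what is rust?", "positive": "Rust is a systems language."},
    {"query": "what is go?", "positive": "Go is a compiled language from Google."},
])

miner = SentenceTransformerHardNegativeMiner.from_config(miner_settings)
mined = miner.mine(pairs)  # adds negative_<n> columns under the default n-tuple output_format
```

The try/finally split is the notable design choice: mining loads a full `SentenceTransformer` (and optionally a `CrossEncoder`) that the subsequent training run does not need, so the models are deleted and CUDA caches emptied even if mining fails.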
```diff
--- a/src/embed_train/train/dataset/torch_datasets.py
+++ b/src/embed_train/train/dataset/torch_datasets.py
@@ -1,9 +1,16 @@
 from typing import Any
 
 import polars as pl
+from datasets import Dataset as HFDataset  # type: ignore[import-untyped]
 
-from embed_train.settings import QueryMultiPositiveDatasetSettings, QueryPositiveDatasetSettings
+from embed_train.settings import (
+    HardNegativeDatasetSettings,
+    QueryMultiPositiveDatasetSettings,
+    QueryPositiveDatasetSettings,
+)
 from embed_train.train.dataset import TorchDataset
+from embed_train.train.dataset.hard_negatives import HardNegativeMiner
+from embed_train.utils import load_class
 
 
 class QueryMultiPositiveDataset(TorchDataset[QueryMultiPositiveDatasetSettings[Any], dict[str, Any]]):
@@ -71,3 +78,36 @@ class QueryPositiveDataset(
             "query": self._queries[index],
             "positive": positive,
         }
+
+
+class HardNegativeDataset(
+    TorchDataset[
+        HardNegativeDatasetSettings[Any, Any],
+        dict[str, Any],
+    ]
+):
+    def __init__(
+        self,
+        config: HardNegativeDatasetSettings[Any, Any],
+    ) -> None:
+        super().__init__(config)
+        input_dataset = self._to_query_positive_hf_dataset()
+        miner = self._load_hard_negative_miner()
+        self._dataset = miner.mine(input_dataset)
+
+    def _to_query_positive_hf_dataset(self) -> HFDataset:
+        rows = self.dataset.polars.select(
+            pl.col("metadata").struct.field("query").alias("query"),
+            pl.col("page_content").alias("positive"),
+        )
+        return HFDataset.from_list(rows.to_dicts())
+
+    def _load_hard_negative_miner(self) -> HardNegativeMiner[Any]:
+        miner_cls = load_class(self.config.hard_negative_miner.module_path)
+        return miner_cls.from_config(self.config.hard_negative_miner)
+
+    def __len__(self) -> int:
+        return len(self._dataset)
+
+    def __getitem__(self, index: int) -> dict[str, Any]:
+        return self._dataset[index]
```
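The `metadata.query`/`page_content` projection in `_to_query_positive_hf_dataset` can be exercised standalone. A sketch with a made-up two-row frame shaped like the connector's output:

```python
import polars as pl
from datasets import Dataset as HFDataset

# Struct column "metadata" carries the query; "page_content" is the positive text.
frame = pl.DataFrame({
    "metadata": [{"query": "what is rust?"}, {"query": "what is go?"}],
    "page_content": ["Rust is a systems language.", "Go is a compiled language."],
})

rows = frame.select(
    pl.col("metadata").struct.field("query").alias("query"),
    pl.col("page_content").alias("positive"),
)
hf_pairs = HFDataset.from_list(rows.to_dicts())  # columns: query, positive
```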
```diff
--- a/src/embed_train/train/trainers/hf/__init__.py
+++ b/src/embed_train/train/trainers/hf/__init__.py
@@ -64,7 +64,7 @@ class SentenceTransformersTrainer[TCHFTrainRunner: "SentenceTransformersTrainerS
         _logger.info("Starting SentenceTransformers training...")
         trainer.train()
 
-    def _get_warmup_steps(self, dataset: Dataset) ->
+    def _get_warmup_steps(self, dataset: Dataset) -> int:
         train_size = len(dataset)
         steps_per_epoch = train_size // self.config.per_device_train_batch_size
         total_steps = steps_per_epoch * self.config.num_epochs
@@ -96,13 +96,13 @@ class SentenceTransformersTrainer[TCHFTrainRunner: "SentenceTransformersTrainerS
             model_name_or_path=model_path,
             max_seq_length=self.config.tokenizer.max_length,
             tokenizer_name_or_path=self.config.tokenizer.name,
-
+            model_kwargs={
                 "trust_remote_code": trust,
             },
-
+            config_kwargs={
                 "trust_remote_code": trust,
             },
-
+            processor_kwargs={
                 "trust_remote_code": trust,
                 "padding": self.config.tokenizer.padding,
                 "truncation": self.config.tokenizer.truncation,
@@ -110,7 +110,7 @@ class SentenceTransformersTrainer[TCHFTrainRunner: "SentenceTransformersTrainerS
             },
         )
         pooling = Pooling(
-            transformer.
+            transformer.get_embedding_dimension(),
             pooling_mode_mean_tokens=self.config.pooling == "mean_tokens",
             pooling_mode_cls_token=self.config.pooling == "cls",
             pooling_mode_max_tokens=self.config.pooling == "max_tokens",
```
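For the now fully typed `_get_warmup_steps`, worked numbers for the part visible in this hunk; the warmup fraction applied to `total_steps` is outside the hunk, so the 10% ratio below is an assumption:

```python
train_size = 10_000
per_device_train_batch_size = 32
num_epochs = 3

steps_per_epoch = train_size // per_device_train_batch_size  # 312
total_steps = steps_per_epoch * num_epochs                   # 936
warmup_steps = int(total_steps * 0.1)                        # 93, assuming a 10% ratio
```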
```diff
--- a/src/embed_train/train/trainers/torch/__init__.py
+++ b/src/embed_train/train/trainers/torch/__init__.py
@@ -189,7 +189,7 @@ class PyTorchTrainer[TCPyTorchTrainer: "PyTorchTrainerSettings[Any, Any, Any, An
             self.model.load_state_dict(ckpt["model_state"])
             self.optimizer.load_state_dict(ckpt["optimizer_state"])
             _logger.info(f"Resuming training from epoch {ckpt['epoch']}")
-        return ckpt["epoch"]
+            return ckpt["epoch"]
         return 0
 
     def run_epoch(
```