embed-train 3.0.0__tar.gz → 3.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. {embed_train-3.0.0 → embed_train-3.2.0}/AGENTS.md +9 -4
  2. {embed_train-3.0.0 → embed_train-3.2.0}/CHANGELOG.md +14 -0
  3. embed_train-3.0.0/README.md → embed_train-3.2.0/PKG-INFO +40 -5
  4. embed_train-3.0.0/PKG-INFO → embed_train-3.2.0/README.md +26 -18
  5. {embed_train-3.0.0 → embed_train-3.2.0}/pyproject.toml +8 -7
  6. {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/settings.py +50 -11
  7. {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/dataset/__init__.py +10 -1
  8. {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/dataset/collate.py +34 -0
  9. embed_train-3.2.0/src/embed_train/train/dataset/hard_negatives.py +136 -0
  10. {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/dataset/torch_datasets.py +41 -1
  11. {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/trainers/hf/__init__.py +5 -5
  12. {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/trainers/torch/__init__.py +1 -1
  13. {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/trainers/torch/loss.py +24 -0
  14. {embed_train-3.0.0 → embed_train-3.2.0}/tests/fixtures/components.py +54 -0
  15. {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_push_to_hf.py +23 -0
  16. {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_settings.py +4 -14
  17. embed_train-3.2.0/tests/unit/test_train/test_collate.py +83 -0
  18. {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_train/test_dataset.py +15 -1
  19. embed_train-3.2.0/tests/unit/test_train/test_hard_negatives.py +117 -0
  20. {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_train/test_hf_trainer.py +1 -1
  21. embed_train-3.2.0/tests/unit/test_train/test_loss.py +83 -0
  22. embed_train-3.2.0/tests/unit/test_train/test_torch_datasets.py +83 -0
  23. {embed_train-3.0.0 → embed_train-3.2.0}/uv.lock +22 -65
  24. embed_train-3.0.0/tests/unit/test_train/test_collate.py +0 -42
  25. embed_train-3.0.0/tests/unit/test_train/test_loss.py +0 -39
  26. embed_train-3.0.0/tests/unit/test_train/test_torch_datasets.py +0 -33
  27. {embed_train-3.0.0 → embed_train-3.2.0}/.gitignore +0 -0
  28. {embed_train-3.0.0 → embed_train-3.2.0}/.gitlab-ci.yml +0 -0
  29. {embed_train-3.0.0 → embed_train-3.2.0}/.pre-commit-config.yaml +0 -0
  30. {embed_train-3.0.0 → embed_train-3.2.0}/.releaserc.json +0 -0
  31. {embed_train-3.0.0 → embed_train-3.2.0}/Makefile +0 -0
  32. {embed_train-3.0.0 → embed_train-3.2.0}/codecov.yml +0 -0
  33. {embed_train-3.0.0 → embed_train-3.2.0}/commitlint.config.cjs +0 -0
  34. {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/__init__.py +0 -0
  35. {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/constants.py +0 -0
  36. {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/exceptions.py +0 -0
  37. {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/models/__init__.py +0 -0
  38. {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/push_to_hf/__init__.py +0 -0
  39. {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/py.typed +0 -0
  40. {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/__init__.py +0 -0
  41. {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/dataset/sampling/__init__.py +0 -0
  42. {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/dataset/sampling/samplers.py +0 -0
  43. {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/trainers/__init__.py +0 -0
  44. {embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/utils.py +0 -0
  45. {embed_train-3.0.0 → embed_train-3.2.0}/tests/__init__.py +0 -0
  46. {embed_train-3.0.0 → embed_train-3.2.0}/tests/conftest.py +0 -0
  47. {embed_train-3.0.0 → embed_train-3.2.0}/tests/fixtures/__init__.py +0 -0
  48. {embed_train-3.0.0 → embed_train-3.2.0}/tests/fixtures/data.py +0 -0
  49. {embed_train-3.0.0 → embed_train-3.2.0}/tests/integration/__init__.py +0 -0
  50. {embed_train-3.0.0 → embed_train-3.2.0}/tests/integration/test_dataset/__init__.py +0 -0
  51. {embed_train-3.0.0 → embed_train-3.2.0}/tests/integration/test_dataset/test_to_hf_dataset.py +0 -0
  52. {embed_train-3.0.0 → embed_train-3.2.0}/tests/integration/test_train_runner/__init__.py +0 -0
  53. {embed_train-3.0.0 → embed_train-3.2.0}/tests/integration/test_train_runner/test_train_runner_flow.py +0 -0
  54. {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/__init__.py +0 -0
  55. {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_abstract_guards.py +0 -0
  56. {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_embed_train.py +0 -0
  57. {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_exceptions.py +0 -0
  58. {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_models.py +0 -0
  59. {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_train/__init__.py +0 -0
  60. {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_train/test_runner.py +0 -0
  61. {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_train/test_samplers.py +0 -0
  62. {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_train/test_sampling.py +0 -0
  63. {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_train/test_torch_trainer.py +0 -0
  64. {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_train/test_trainers.py +0 -0
  65. {embed_train-3.0.0 → embed_train-3.2.0}/tests/unit/test_utils.py +0 -0

{embed_train-3.0.0 → embed_train-3.2.0}/AGENTS.md
@@ -36,13 +36,15 @@ Key modules:
  - `src/embed_train/train/trainers/hf/__init__.py`
    SentenceTransformers-based training path with `InformationRetrievalEvaluator`.
  - `src/embed_train/train/dataset/__init__.py`
-   Base `TorchDataset` and `CollateFn` abstractions.
+   Base `TorchDataset`, `HardNegativeMiner`, and `CollateFn` abstractions.
  - `src/embed_train/train/dataset/collate.py`
-   Built-in collate functions for in-batch positive training.
+   Built-in collate functions for in-batch positive and hard-negative training.
+ - `src/embed_train/train/dataset/hard_negatives.py`
+   SentenceTransformers-backed hard-negative mining implementation.
  - `src/embed_train/train/dataset/torch_datasets.py`
-   Built-in grouped and flattened query/positive dataset views.
+   Built-in grouped, flattened query/positive, and hard-negative dataset views.
  - `src/embed_train/train/trainers/torch/loss.py`
-   Built-in contrastive losses.
+   Built-in contrastive losses, including hard-negative candidate ranking.
  - `src/embed_train/push_to_hf/__init__.py`
    `PushToHFRunner` for checkpoint restore, local repo export, and HF upload.
  - `src/embed_train/utils.py`

@@ -102,6 +104,9 @@ For most changes, inspect the matching implementation and tests together:
  - Preserve the row contracts expected by built-in collate functions:
    - grouped format: `{"query": str, "positives": list[str]}`
    - flattened format: `{"query": str, "positive": str}`
+   - hard-negative format: `{"query": str, "positive": str, "negative": str}` or `negative_<n>` columns
+ - Keep hard-negative mining isolated behind `HardNegativeMiner`/`HardNegativeDataset` unless a task explicitly changes the abstraction boundary.
+ - Preserve the hard-negative candidate order expected by `HardNegativeContrastiveLoss`: one positive first, followed by one or more negatives for the same query.
  - When changing row formats, update the collate functions and tests together.
  - Keep tokenizer behavior explicit through settings rather than hidden defaults.

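The row contracts and candidate ordering in the hunk above are easiest to see with concrete rows. A minimal sketch (all values invented; only the key names come from this diff):

```python
# Row shapes expected by the built-in collate functions.
grouped_row = {"query": "capital of france", "positives": ["Paris is the capital of France."]}
flattened_row = {"query": "capital of france", "positive": "Paris is the capital of France."}

# Hard-negative rows carry either a single `negative` field or numbered
# `negative_<n>` columns.
hard_negative_row = {
    "query": "capital of france",
    "positive": "Paris is the capital of France.",
    "negative_1": "Lyon is the third-largest city in France.",
    "negative_2": "Berlin is the capital of Germany.",
}

# Candidate order expected by HardNegativeContrastiveLoss for this query:
# the positive first, then that query's negatives.
candidates = [
    hard_negative_row["positive"],
    hard_negative_row["negative_1"],
    hard_negative_row["negative_2"],
]
```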

{embed_train-3.0.0 → embed_train-3.2.0}/CHANGELOG.md
@@ -1,3 +1,17 @@
+ # [3.2.0](https://gitlab.com/efysent/agentic-core/embed-train/compare/v3.1.0...v3.2.0) (2026-05-12)
+
+
+ ### Features
+
+ * remove sentence transformers deprecations ([581032b](https://gitlab.com/efysent/agentic-core/embed-train/commit/581032b97a0bfc50236cc0fa64a857bc65f20314))
+
+ # [3.1.0](https://gitlab.com/efysent/agentic-core/embed-train/compare/v3.0.0...v3.1.0) (2026-05-12)
+
+
+ ### Features
+
+ * add hard negative training ([4d79e25](https://gitlab.com/efysent/agentic-core/embed-train/commit/4d79e25877f7483dd7c6f13a97b1dd846c998c9a))
+
  # [3.0.0](https://gitlab.com/efysent/agentic-core/embed-train/compare/v2.0.0...v3.0.0) (2026-05-09)



embed_train-3.0.0/README.md → embed_train-3.2.0/PKG-INFO
@@ -1,3 +1,17 @@
+ Metadata-Version: 2.4
+ Name: embed-train
+ Version: 3.2.0
+ Author-email: jalal <jalalkhaldi3@gmail.com>
+ Requires-Python: <3.13,>=3.11
+ Requires-Dist: accelerate==1.13.0
+ Requires-Dist: datasets==4.8.4
+ Requires-Dist: retrievalbase<3.0.0,>=2.1.0
+ Requires-Dist: sentence-transformers==5.4.1
+ Requires-Dist: tensorboard==2.20.0
+ Requires-Dist: torch==2.11.0
+ Requires-Dist: transformers==4.57.6
+ Description-Content-Type: text/markdown
+
  # embed-train

  `embed-train` is a config-driven library for training, evaluating, checkpointing, and publishing embedding models.

@@ -61,9 +75,10 @@ src/embed_train/
  ├── train/
  │   ├── __init__.py              # TrainRunner
  │   ├── dataset/
- │   │   ├── __init__.py          # Base TorchDataset and CollateFn abstractions
- │   │   ├── collate.py           # Built-in in-batch positive collate functions
- │   │   ├── torch_datasets.py    # Built-in query/positive dataset views
+ │   │   ├── __init__.py          # Base TorchDataset, HardNegativeMiner, and CollateFn abstractions
+ │   │   ├── collate.py           # Built-in in-batch and hard-negative collate functions
+ │   │   ├── hard_negatives.py    # SentenceTransformers hard-negative miner
+ │   │   ├── torch_datasets.py    # Built-in query/positive and hard-negative dataset views
  │   │   └── sampling/
  │   │       └── samplers.py      # Built-in positive sampler(s)
  │   └── trainers/

@@ -107,7 +122,8 @@ The base class already provides:

  The PyTorch flow separates concerns clearly:

- - `TorchDataset`: how rows are loaded and exposed
+ - `TorchDataset`: how rows are loaded, optionally mined, and exposed
+ - `HardNegativeMiner`: how query/positive rows can be expanded with mined negatives
  - `CollateFn`: how rows become query/document text pairs and tokenized tensors
  - `Processor`: text normalization or preprocessing, provided by `retrievalbase`


@@ -124,8 +140,12 @@ The current repository includes these reusable implementations:

  - `embed_train.train.dataset.torch_datasets.QueryMultiPositiveDataset`
  - `embed_train.train.dataset.torch_datasets.QueryPositiveDataset`
+ - `embed_train.train.dataset.torch_datasets.HardNegativeDataset`
+ - `embed_train.train.dataset.hard_negatives.SentenceTransformerHardNegativeMiner`
+ - `embed_train.train.dataset.collate.HardNegativeCollateFn`
  - `embed_train.train.dataset.collate.InBatchNegativeCollateFn`
  - `embed_train.train.dataset.collate.MultiPositiveInBatchCollateFn`
+ - `embed_train.train.trainers.torch.loss.HardNegativeContrastiveLoss`
  - `embed_train.train.trainers.torch.loss.InBatchNegativeContrastiveLoss`
  - `embed_train.train.trainers.torch.loss.MultiPositiveContrastiveLoss`
  - `embed_train.train.trainers.hf.SentenceTransformersTrainer`

@@ -147,6 +167,20 @@ What the built-in trainer does:
  - logs to TensorBoard
  - saves checkpoints to `data_dir/checkpoints/...`

+ For hard-negative training, configure the trainer with:
+
+ - `torch_dataset`: `embed_train.train.dataset.torch_datasets.HardNegativeDataset`
+ - `torch_dataset.hard_negative_miner`: `embed_train.train.dataset.hard_negatives.SentenceTransformerHardNegativeMiner`
+ - `collate_fn`: `embed_train.train.dataset.collate.HardNegativeCollateFn`
+ - `loss`: `embed_train.train.trainers.torch.loss.HardNegativeContrastiveLoss`
+
+ `HardNegativeDataset` converts rows with `metadata.query` and `page_content` into a Hugging Face
+ `Dataset`, mines negatives with `sentence_transformers.util.mine_hard_negatives`, and then exposes
+ the mined rows through the normal `TorchDataset` interface. `HardNegativeCollateFn` expects each row
+ to contain `query`, `positive`, and either `negative` or numbered `negative_<n>` fields. It emits one
+ positive followed by that row's negatives, which is the candidate layout required by
+ `HardNegativeContrastiveLoss`.
+
  ### SentenceTransformers Training

  Use this when your data is naturally represented as a Hugging Face `Dataset` and you want a standard SentenceTransformers training path with IR evaluation.
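As a sketch of how the four hard-negative components named in the hunk above fit together in a config, the dotted class paths are from this diff, while the surrounding schema and the `module_path` convention are assumptions inferred from `HardNegativeDataset._load_hard_negative_miner` later in this diff:

```python
# Hypothetical trainer wiring; only the dotted class paths appear in the diff.
hard_negative_trainer_config = {
    "torch_dataset": {
        "module_path": "embed_train.train.dataset.torch_datasets.HardNegativeDataset",
        "hard_negative_miner": {
            "module_path": "embed_train.train.dataset.hard_negatives.SentenceTransformerHardNegativeMiner",
            # fields from SentenceTransformerHardNegativeMinerSettings go here
        },
    },
    "collate_fn": {"module_path": "embed_train.train.dataset.collate.HardNegativeCollateFn"},
    "loss": {"module_path": "embed_train.train.trainers.torch.loss.HardNegativeContrastiveLoss"},
}
```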

@@ -204,10 +238,11 @@ This is the right place to define pooling, projection heads, shared or separate

  Subclass `TorchDataset` when you need a different row shape or data-loading strategy.

- The built-in datasets show two common patterns:
+ The built-in datasets show these common patterns:

  - grouped query -> many positives
  - flattened query -> single positive
+ - flattened query -> single positive plus mined hard negatives

  ### Add a Custom Collate Function


embed_train-3.0.0/PKG-INFO → embed_train-3.2.0/README.md
@@ -1,16 +1,3 @@
- Metadata-Version: 2.4
- Name: embed-train
- Version: 3.0.0
- Author-email: jalal <jalalkhaldi3@gmail.com>
- Requires-Python: <3.13,>=3.11
- Requires-Dist: accelerate<2.0.0,>=1.13.0
- Requires-Dist: datasets<5.0.0,>=4.5.0
- Requires-Dist: retrievalbase<2.0.0,>=1.0.0
- Requires-Dist: sentence-transformers<6.0.0,>=5.1.2
- Requires-Dist: tensorboard<3.0.0,>=2.20.0
- Requires-Dist: torch<3.0.0,>=2.9.0
- Description-Content-Type: text/markdown
-
  # embed-train

  `embed-train` is a config-driven library for training, evaluating, checkpointing, and publishing embedding models.

@@ -74,9 +61,10 @@ src/embed_train/
  ├── train/
  │   ├── __init__.py              # TrainRunner
  │   ├── dataset/
- │   │   ├── __init__.py          # Base TorchDataset and CollateFn abstractions
- │   │   ├── collate.py           # Built-in in-batch positive collate functions
- │   │   ├── torch_datasets.py    # Built-in query/positive dataset views
+ │   │   ├── __init__.py          # Base TorchDataset, HardNegativeMiner, and CollateFn abstractions
+ │   │   ├── collate.py           # Built-in in-batch and hard-negative collate functions
+ │   │   ├── hard_negatives.py    # SentenceTransformers hard-negative miner
+ │   │   ├── torch_datasets.py    # Built-in query/positive and hard-negative dataset views
  │   │   └── sampling/
  │   │       └── samplers.py      # Built-in positive sampler(s)
  │   └── trainers/

@@ -120,7 +108,8 @@ The base class already provides:

  The PyTorch flow separates concerns clearly:

- - `TorchDataset`: how rows are loaded and exposed
+ - `TorchDataset`: how rows are loaded, optionally mined, and exposed
+ - `HardNegativeMiner`: how query/positive rows can be expanded with mined negatives
  - `CollateFn`: how rows become query/document text pairs and tokenized tensors
  - `Processor`: text normalization or preprocessing, provided by `retrievalbase`


@@ -137,8 +126,12 @@ The current repository includes these reusable implementations:

  - `embed_train.train.dataset.torch_datasets.QueryMultiPositiveDataset`
  - `embed_train.train.dataset.torch_datasets.QueryPositiveDataset`
+ - `embed_train.train.dataset.torch_datasets.HardNegativeDataset`
+ - `embed_train.train.dataset.hard_negatives.SentenceTransformerHardNegativeMiner`
+ - `embed_train.train.dataset.collate.HardNegativeCollateFn`
  - `embed_train.train.dataset.collate.InBatchNegativeCollateFn`
  - `embed_train.train.dataset.collate.MultiPositiveInBatchCollateFn`
+ - `embed_train.train.trainers.torch.loss.HardNegativeContrastiveLoss`
  - `embed_train.train.trainers.torch.loss.InBatchNegativeContrastiveLoss`
  - `embed_train.train.trainers.torch.loss.MultiPositiveContrastiveLoss`
  - `embed_train.train.trainers.hf.SentenceTransformersTrainer`

@@ -160,6 +153,20 @@ What the built-in trainer does:
  - logs to TensorBoard
  - saves checkpoints to `data_dir/checkpoints/...`

+ For hard-negative training, configure the trainer with:
+
+ - `torch_dataset`: `embed_train.train.dataset.torch_datasets.HardNegativeDataset`
+ - `torch_dataset.hard_negative_miner`: `embed_train.train.dataset.hard_negatives.SentenceTransformerHardNegativeMiner`
+ - `collate_fn`: `embed_train.train.dataset.collate.HardNegativeCollateFn`
+ - `loss`: `embed_train.train.trainers.torch.loss.HardNegativeContrastiveLoss`
+
+ `HardNegativeDataset` converts rows with `metadata.query` and `page_content` into a Hugging Face
+ `Dataset`, mines negatives with `sentence_transformers.util.mine_hard_negatives`, and then exposes
+ the mined rows through the normal `TorchDataset` interface. `HardNegativeCollateFn` expects each row
+ to contain `query`, `positive`, and either `negative` or numbered `negative_<n>` fields. It emits one
+ positive followed by that row's negatives, which is the candidate layout required by
+ `HardNegativeContrastiveLoss`.
+
  ### SentenceTransformers Training

  Use this when your data is naturally represented as a Hugging Face `Dataset` and you want a standard SentenceTransformers training path with IR evaluation.

@@ -217,10 +224,11 @@ This is the right place to define pooling, projection heads, shared or separate

  Subclass `TorchDataset` when you need a different row shape or data-loading strategy.

- The built-in datasets show two common patterns:
+ The built-in datasets show these common patterns:

  - grouped query -> many positives
  - flattened query -> single positive
+ - flattened query -> single positive plus mined hard negatives

  ### Add a Custom Collate Function


{embed_train-3.0.0 → embed_train-3.2.0}/pyproject.toml
@@ -1,6 +1,6 @@
  [project]
  name = "embed-train"
- version = "3.0.0"
+ version = "3.2.0"
  description = ""
  authors = [
      { name = "jalal", email = "jalalkhaldi3@gmail.com" }

@@ -9,12 +9,13 @@ readme = "README.md"
  requires-python = ">=3.11,<3.13"

  dependencies = [
-     "torch>=2.9.0,<3.0.0",
-     "sentence-transformers>=5.1.2,<6.0.0",
-     "datasets>=4.5.0,<5.0.0",
-     "tensorboard>=2.20.0,<3.0.0",
-     "accelerate>=1.13.0,<2.0.0",
-     "retrievalbase>=1.0.0,<2.0.0",
+     "torch==2.11.0",
+     "sentence-transformers==5.4.1",
+     "transformers==4.57.6",
+     "datasets==4.8.4",
+     "tensorboard==2.20.0",
+     "accelerate==1.13.0",
+     "retrievalbase>=2.1.0,<3.0.0",
  ]

  [build-system]

{embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/settings.py
@@ -57,6 +57,10 @@ class MultiPositiveInBatchCollateFnSettings[TCProcessor: "ProcessorSettings"](Co
      n_pos: int


+ class HardNegativeCollateFnSettings[TCProcessor: "ProcessorSettings"](CollateFnSettings[TCProcessor]):
+     pass
+
+
  class ModelSettings(FromConfigMixinSettings):
      pass


@@ -69,6 +73,41 @@ class TrainerSettings(FromConfigMixinSettings):
      data_dir: Path


+ class HardNegativeMinerSettings(FromConfigMixinSettings):
+     pass
+
+
+ class SentenceTransformerHardNegativeMinerSettings(HardNegativeMinerSettings):
+     model_name_or_path: str
+     cross_encoder_model_name_or_path: str | None = None
+     tokenizer: TokenizerSettings
+     pooling: Literal["cls", "max", "mean", "mean_sqrt_len_tokens", "weightedmean", "lasttoken"]
+     anchor_column_name: str = "query"
+     positive_column_name: str = "positive"
+     range_min: int = 0
+     range_max: int | None = None
+     max_score: float | None = None
+     min_score: float | None = None
+     absolute_margin: float | None = None
+     relative_margin: float | None = None
+     num_negatives: int = 3
+     sampling_strategy: Literal["random", "top"] = "top"
+     query_prompt_name: str | None = None
+     query_prompt: str | None = None
+     corpus_prompt_name: str | None = None
+     corpus_prompt: str | None = None
+     include_positives: bool = False
+     output_format: Literal["triplet", "n-tuple", "labeled-pair", "labeled-list"] = "n-tuple"
+     output_scores: bool = False
+     batch_size: int = 32
+     faiss_batch_size: int = 16384
+     use_faiss: bool = False
+     use_multi_process: bool | list[str] = False
+     verbose: bool = True
+     cache_folder: str | None = None
+     trust_remote_code: bool = TRUST_REMOTE_CODE
+
+
  class PyTorchTrainerSettings[
      TCModel: "ModelSettings",
      TCLoss: "LossSettings",
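A hedged sketch of constructing the new miner settings. `TokenizerSettings` is not shown in this diff, so its field names below are inferred from their usage in `hard_negatives.py`; every value is illustrative, and any extra fields the `FromConfigMixin` settings require (for example a `module_path`) are omitted:

```python
from embed_train.settings import (
    SentenceTransformerHardNegativeMinerSettings,
    TokenizerSettings,
)

# All values invented; defaults such as num_negatives=3 are shown explicitly.
miner_settings = SentenceTransformerHardNegativeMinerSettings(
    model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
    tokenizer=TokenizerSettings(
        name="sentence-transformers/all-MiniLM-L6-v2",
        max_length=256,
        padding=True,
        truncation=True,
        trust_remote_code=False,
    ),
    pooling="mean",
    num_negatives=3,
    sampling_strategy="top",
    use_faiss=False,
)
```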

@@ -107,6 +146,13 @@ class QueryPositiveDatasetSettings[TCDatasetConnector: "DatasetConnectorSettings
      pass


+ class HardNegativeDatasetSettings[
+     TCDatasetConnector: "DatasetConnectorSettings",
+     TCHardNegativeMiner: "HardNegativeMinerSettings",
+ ](TorchDatasetSettings[TCDatasetConnector]):
+     hard_negative_miner: TCHardNegativeMiner
+
+
  class InBatchNegativeContrastiveLossSettings(ContrastiveLossSettings):
      pass


@@ -115,6 +161,10 @@ class MultiPositiveContrastiveLossSettings(ContrastiveLossSettings):
      n_pos: int


+ class HardNegativeContrastiveLossSettings(ContrastiveLossSettings):
+     pass
+
+
  class RunnerSettings(FromConfigMixinSettings):
      pass


@@ -139,17 +189,6 @@ class TrainRunnerSettings[TCTrainer: "TrainerSettings"](RunnerSettings):
      trainer: TCTrainer


- class HardNegativesSettings(BaseSettings):
-     range_min: int
-     range_max: int
-     max_score: float
-     relative_margin: float
-     num_negatives: int
-     sampling_strategy: Literal["random", "top"]
-     batch_size: int
-     use_faiss: bool
-
-
  class EvalutationSettings(BaseSettings):
      query_column: str
      document_column: str

{embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/dataset/__init__.py
@@ -16,7 +16,7 @@ from retrievalbase.mixins import FromConfigMixin
  _logger = logging.getLogger(__name__)

  if TYPE_CHECKING:
-     from embed_train.settings import CollateFnSettings, TorchDatasetSettings
+     from embed_train.settings import CollateFnSettings, HardNegativeMinerSettings, TorchDatasetSettings


  class CollateFn[TCCollateFn: "CollateFnSettings[Any]", T: dict[str, Any]](ABC):

@@ -107,3 +107,12 @@ class TorchDataset[TCTorchDataset: "TorchDatasetSettings[Any]", T: dict[str, Any
          """
          rows: list[dict[str, Any]] = [self[i] for i in range(len(self))]
          return HFDataset.from_list(rows)
+
+
+ class HardNegativeMiner[TCHardNegativeMiner: "HardNegativeMinerSettings"](FromConfigMixin[TCHardNegativeMiner], ABC):
+     def __init__(self, config: TCHardNegativeMiner) -> None:
+         self.config = config
+
+     @abstractmethod
+     def mine(self, dataset: HFDataset) -> HFDataset:
+         raise NotImplementedError
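The new base class is small enough that a toy subclass shows the whole contract: `mine()` receives a query/positive HF `Dataset` and returns a dataset with negative columns added. A sketch, with the random strategy invented for illustration (it is not part of the package):

```python
import random
from typing import Any

from datasets import Dataset as HFDataset

class RandomNegativeMiner(HardNegativeMiner[Any]):
    """Toy miner: reuses another row's positive as a (weak) negative."""

    def mine(self, dataset: HFDataset) -> HFDataset:
        positives = dataset["positive"]

        def add_negative(row: dict[str, Any], index: int) -> dict[str, Any]:
            # Pick any other row's positive so the negative never matches
            # this row's own positive.
            row["negative"] = random.choice(
                [p for i, p in enumerate(positives) if i != index]
            )
            return row

        return dataset.map(add_negative, with_indices=True)
```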

{embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/dataset/collate.py
@@ -2,6 +2,7 @@ import random
  from typing import Any

  from embed_train.settings import (
+     HardNegativeCollateFnSettings,
      InBatchNegativeCollateFnSettings,
      MultiPositiveInBatchCollateFnSettings,
  )

@@ -50,3 +51,36 @@ class MultiPositiveInBatchCollateFn(CollateFn[MultiPositiveInBatchCollateFnSetti
              queries.append(query)
              passages.extend(sampled_positives)
          return queries, passages
+
+
+ class HardNegativeCollateFn(CollateFn[HardNegativeCollateFnSettings, dict[str, Any]]):
+     def __init__(
+         self,
+         config: HardNegativeCollateFnSettings,
+         context: dict[str, Any] | None,
+     ) -> None:
+         super().__init__(config, context)
+
+     def _process_batch(
+         self,
+         batch: list[dict[str, Any]],
+     ) -> tuple[list[str], list[str]]:
+         queries: list[str] = []
+         passages: list[str] = []
+
+         for item in batch:
+             queries.append(item["query"])
+             passages.append(item["positive"])
+             passages.extend(self._negative_passages(item))
+
+         return queries, passages
+
+     def _negative_passages(self, item: dict[str, Any]) -> list[str]:
+         if "negative" in item:
+             return [item["negative"]]
+
+         negative_keys = sorted(
+             (key for key in item if key.startswith("negative_")),
+             key=lambda key: int(key.rsplit("_", 1)[1]),
+         )
+         return [item[key] for key in negative_keys]
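A quick trace of the new collate path with invented rows. Note the numeric sort in `_negative_passages`: `negative_2` comes before `negative_10`, which a plain lexicographic sort would get wrong:

```python
batch = [
    {"query": "q1", "positive": "p1", "negative_10": "n10", "negative_2": "n2", "negative_1": "n1"},
    {"query": "q2", "positive": "p2", "negative": "n"},
]

# Given a HardNegativeCollateFn instance, _process_batch(batch) returns:
#   queries  == ["q1", "q2"]
#   passages == ["p1", "n1", "n2", "n10", "p2", "n"]
# i.e. each query's positive, then that query's negatives in numeric order.
```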

embed_train-3.2.0/src/embed_train/train/dataset/hard_negatives.py
@@ -0,0 +1,136 @@
+ import gc
+ import logging
+
+ import torch
+ from datasets import Dataset as HFDataset  # type: ignore[import-untyped]
+ from sentence_transformers import SentenceTransformer
+ from sentence_transformers.cross_encoder import CrossEncoder
+ from sentence_transformers.models import Pooling, Transformer
+ from sentence_transformers.util import mine_hard_negatives
+
+ from embed_train.settings import SentenceTransformerHardNegativeMinerSettings
+ from embed_train.train.dataset import HardNegativeMiner
+
+ _logger = logging.getLogger(__name__)
+
+
+ class SentenceTransformerHardNegativeMiner(HardNegativeMiner[SentenceTransformerHardNegativeMinerSettings]):
+     def __init__(self, config: "SentenceTransformerHardNegativeMinerSettings") -> None:
+         super().__init__(config)
+
+     def _load_sentence_transformer(self) -> SentenceTransformer:
+         trust = self.config.trust_remote_code or self.config.tokenizer.trust_remote_code
+
+         _logger.info(
+             f"Loading SentenceTransformer | "
+             f"model={self.config.model_name_or_path} | "
+             f"pooling={self.config.pooling} | "
+             f"max_length={self.config.tokenizer.max_length}"
+         )
+
+         transformer = Transformer(
+             model_name_or_path=self.config.model_name_or_path,
+             max_seq_length=self.config.tokenizer.max_length,
+             tokenizer_name_or_path=self.config.tokenizer.name,
+             model_kwargs={
+                 "trust_remote_code": trust,
+             },
+             config_kwargs={
+                 "trust_remote_code": trust,
+             },
+             processor_kwargs={
+                 "trust_remote_code": trust,
+                 "padding": self.config.tokenizer.padding,
+                 "truncation": self.config.tokenizer.truncation,
+                 "model_max_length": self.config.tokenizer.max_length,
+             },
+         )
+
+         pooling = Pooling(transformer.get_embedding_dimension(), pooling_mode=self.config.pooling)
+
+         model = SentenceTransformer(modules=[transformer, pooling])
+
+         _logger.info(f"SentenceTransformer loaded successfully | embedding_dim={transformer.get_embedding_dimension()}")
+
+         return model
+
+     def _load_cross_encoder(self) -> "CrossEncoder | None":
+         if not self.config.cross_encoder_model_name_or_path:
+             _logger.info("No CrossEncoder configured")
+             return None
+
+         _logger.info(f"Loading CrossEncoder | model={self.config.cross_encoder_model_name_or_path}")
+
+         model = CrossEncoder(
+             self.config.cross_encoder_model_name_or_path,
+             trust_remote_code=self.config.trust_remote_code,
+         )
+
+         _logger.info("CrossEncoder loaded successfully")
+
+         return model
+
+     def mine(self, dataset: HFDataset) -> HFDataset:
+         _logger.info(
+             f"Starting hard negative mining | "
+             f"dataset_size={len(dataset)} | "
+             f"num_negatives={self.config.num_negatives} | "
+             f"sampling_strategy={self.config.sampling_strategy} | "
+             f"use_faiss={self.config.use_faiss}"
+         )
+
+         model = self._load_sentence_transformer()
+         cross_encoder = self._load_cross_encoder()
+
+         try:
+             mined = mine_hard_negatives(
+                 dataset=dataset,
+                 model=model,
+                 anchor_column_name=self.config.anchor_column_name,
+                 positive_column_name=self.config.positive_column_name,
+                 cross_encoder=cross_encoder,
+                 range_min=self.config.range_min,
+                 range_max=self.config.range_max,
+                 max_score=self.config.max_score,
+                 min_score=self.config.min_score,
+                 absolute_margin=self.config.absolute_margin,
+                 relative_margin=self.config.relative_margin,
+                 num_negatives=self.config.num_negatives,
+                 sampling_strategy=self.config.sampling_strategy,
+                 query_prompt_name=self.config.query_prompt_name,
+                 query_prompt=self.config.query_prompt,
+                 corpus_prompt_name=self.config.corpus_prompt_name,
+                 corpus_prompt=self.config.corpus_prompt,
+                 include_positives=self.config.include_positives,
+                 output_format=self.config.output_format,
+                 output_scores=self.config.output_scores,
+                 batch_size=self.config.batch_size,
+                 faiss_batch_size=self.config.faiss_batch_size,
+                 use_faiss=self.config.use_faiss,
+                 use_multi_process=self.config.use_multi_process,
+                 verbose=self.config.verbose,
+                 cache_folder=self.config.cache_folder,
+             )
+
+             _logger.info(f"Hard negative mining completed successfully | mined_dataset_size={len(mined)}")
+         finally:
+             _logger.info("Cleaning GPU memory after hard negative mining")
+             del model
+             if cross_encoder is not None:
+                 del cross_encoder
+             gc.collect()
+             if torch.cuda.is_available():
+                 allocated_before = torch.cuda.memory_allocated() / 1024**3
+                 reserved_before = torch.cuda.memory_reserved() / 1024**3
+                 torch.cuda.empty_cache()
+                 torch.cuda.ipc_collect()
+                 allocated_after = torch.cuda.memory_allocated() / 1024**3
+                 reserved_after = torch.cuda.memory_reserved() / 1024**3
+                 _logger.info(
+                     f"CUDA memory cleanup completed | "
+                     f"allocated_before={allocated_before:.2f}GB | "
+                     f"reserved_before={reserved_before:.2f}GB | "
+                     f"allocated_after={allocated_after:.2f}GB | "
+                     f"reserved_after={reserved_after:.2f}GB"
+                 )
+         return mined
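A hedged usage sketch for this miner, reusing the hypothetical `miner_settings` from the settings.py note earlier; the rows are invented, and the `negative_<n>` output columns follow the `output_format="n-tuple"` default described in this diff:

```python
from datasets import Dataset as HFDataset

pairs = HFDataset.from_list([
    {"query": "what is rust", "positive": "Rust is a systems programming language."},
    {"query": "what is go", "positive": "Go is a statically typed language from Google."},
])

miner = SentenceTransformerHardNegativeMiner(miner_settings)
mined = miner.mine(pairs)  # rows gain negative_<n> columns for the collate fn
```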

{embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/dataset/torch_datasets.py
@@ -1,9 +1,16 @@
  from typing import Any

  import polars as pl
+ from datasets import Dataset as HFDataset  # type: ignore[import-untyped]

- from embed_train.settings import QueryMultiPositiveDatasetSettings, QueryPositiveDatasetSettings
+ from embed_train.settings import (
+     HardNegativeDatasetSettings,
+     QueryMultiPositiveDatasetSettings,
+     QueryPositiveDatasetSettings,
+ )
  from embed_train.train.dataset import TorchDataset
+ from embed_train.train.dataset.hard_negatives import HardNegativeMiner
+ from embed_train.utils import load_class


  class QueryMultiPositiveDataset(TorchDataset[QueryMultiPositiveDatasetSettings[Any], dict[str, Any]]):

@@ -71,3 +78,36 @@ class QueryPositiveDataset(
              "query": self._queries[index],
              "positive": positive,
          }
+
+
+ class HardNegativeDataset(
+     TorchDataset[
+         HardNegativeDatasetSettings[Any, Any],
+         dict[str, Any],
+     ]
+ ):
+     def __init__(
+         self,
+         config: HardNegativeDatasetSettings[Any, Any],
+     ) -> None:
+         super().__init__(config)
+         input_dataset = self._to_query_positive_hf_dataset()
+         miner = self._load_hard_negative_miner()
+         self._dataset = miner.mine(input_dataset)
+
+     def _to_query_positive_hf_dataset(self) -> HFDataset:
+         rows = self.dataset.polars.select(
+             pl.col("metadata").struct.field("query").alias("query"),
+             pl.col("page_content").alias("positive"),
+         )
+         return HFDataset.from_list(rows.to_dicts())
+
+     def _load_hard_negative_miner(self) -> HardNegativeMiner[Any]:
+         miner_cls = load_class(self.config.hard_negative_miner.module_path)
+         return miner_cls.from_config(self.config.hard_negative_miner)
+
+     def __len__(self) -> int:
+         return len(self._dataset)
+
+     def __getitem__(self, index: int) -> dict[str, Any]:
+         return self._dataset[index]
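`_to_query_positive_hf_dataset` is just a polars projection. A standalone illustration, with an invented two-row frame shaped like the connector output:

```python
import polars as pl
from datasets import Dataset as HFDataset

# metadata is a struct column; page_content holds the positive passage text.
frame = pl.DataFrame({
    "metadata": [{"query": "q1"}, {"query": "q2"}],
    "page_content": ["doc one", "doc two"],
})

rows = frame.select(
    pl.col("metadata").struct.field("query").alias("query"),
    pl.col("page_content").alias("positive"),
)

hf = HFDataset.from_list(rows.to_dicts())
# hf[0] == {"query": "q1", "positive": "doc one"}
```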

{embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/trainers/hf/__init__.py
@@ -64,7 +64,7 @@ class SentenceTransformersTrainer[TCHFTrainRunner: "SentenceTransformersTrainerS
          _logger.info("Starting SentenceTransformers training...")
          trainer.train()

-     def _get_warmup_steps(self, dataset: Dataset) -> float:
+     def _get_warmup_steps(self, dataset: Dataset) -> int:
          train_size = len(dataset)
          steps_per_epoch = train_size // self.config.per_device_train_batch_size
          total_steps = steps_per_epoch * self.config.num_epochs
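The `float` to `int` return-type fix matters because `transformers.TrainingArguments` takes `warmup_steps` as an integer step count. With invented numbers, the visible part of the computation behaves like:

```python
train_size, per_device_train_batch_size, num_epochs = 10_000, 32, 3

steps_per_epoch = train_size // per_device_train_batch_size  # 312
total_steps = steps_per_epoch * num_epochs                   # 936
# The remainder of _get_warmup_steps is not shown in this diff; presumably it
# scales total_steps by a warmup ratio and must truncate to an int,
# e.g. int(total_steps * 0.1) == 93.
```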

@@ -96,13 +96,13 @@ class SentenceTransformersTrainer[TCHFTrainRunner: "SentenceTransformersTrainerS
              model_name_or_path=model_path,
              max_seq_length=self.config.tokenizer.max_length,
              tokenizer_name_or_path=self.config.tokenizer.name,
-             model_args={
+             model_kwargs={
                  "trust_remote_code": trust,
              },
-             config_args={
+             config_kwargs={
                  "trust_remote_code": trust,
              },
-             tokenizer_args={
+             processor_kwargs={
                  "trust_remote_code": trust,
                  "padding": self.config.tokenizer.padding,
                  "truncation": self.config.tokenizer.truncation,

@@ -110,7 +110,7 @@ class SentenceTransformersTrainer[TCHFTrainRunner: "SentenceTransformersTrainerS
              },
          )
          pooling = Pooling(
-             transformer.get_word_embedding_dimension(),
+             transformer.get_embedding_dimension(),
              pooling_mode_mean_tokens=self.config.pooling == "mean_tokens",
              pooling_mode_cls_token=self.config.pooling == "cls",
              pooling_mode_max_tokens=self.config.pooling == "max_tokens",

{embed_train-3.0.0 → embed_train-3.2.0}/src/embed_train/train/trainers/torch/__init__.py
@@ -189,7 +189,7 @@ class PyTorchTrainer[TCPyTorchTrainer: "PyTorchTrainerSettings[Any, Any, Any, An
              self.model.load_state_dict(ckpt["model_state"])
              self.optimizer.load_state_dict(ckpt["optimizer_state"])
              _logger.info(f"Resuming training from epoch {ckpt['epoch']}")
-             return ckpt["epoch"]  # type:ignore[no-any-return]
+             return ckpt["epoch"]
          return 0

      def run_epoch(