embed-train 3.1.0__tar.gz → 3.2.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63):
  1. {embed_train-3.1.0 → embed_train-3.2.0}/CHANGELOG.md +7 -0
  2. {embed_train-3.1.0 → embed_train-3.2.0}/PKG-INFO +7 -6
  3. {embed_train-3.1.0 → embed_train-3.2.0}/pyproject.toml +7 -6
  4. {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/settings.py +1 -1
  5. {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/dataset/hard_negatives.py +5 -12
  6. {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/trainers/hf/__init__.py +5 -5
  7. {embed_train-3.1.0 → embed_train-3.2.0}/tests/fixtures/components.py +1 -1
  8. {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_push_to_hf.py +23 -0
  9. {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_train/test_collate.py +12 -0
  10. {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_train/test_dataset.py +15 -1
  11. embed_train-3.2.0/tests/unit/test_train/test_hard_negatives.py +117 -0
  12. {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_train/test_hf_trainer.py +1 -1
  13. {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_train/test_loss.py +23 -0
  14. embed_train-3.2.0/tests/unit/test_train/test_torch_datasets.py +83 -0
  15. {embed_train-3.1.0 → embed_train-3.2.0}/uv.lock +18 -61
  16. embed_train-3.1.0/tests/unit/test_train/test_torch_datasets.py +0 -33
  17. {embed_train-3.1.0 → embed_train-3.2.0}/.gitignore +0 -0
  18. {embed_train-3.1.0 → embed_train-3.2.0}/.gitlab-ci.yml +0 -0
  19. {embed_train-3.1.0 → embed_train-3.2.0}/.pre-commit-config.yaml +0 -0
  20. {embed_train-3.1.0 → embed_train-3.2.0}/.releaserc.json +0 -0
  21. {embed_train-3.1.0 → embed_train-3.2.0}/AGENTS.md +0 -0
  22. {embed_train-3.1.0 → embed_train-3.2.0}/Makefile +0 -0
  23. {embed_train-3.1.0 → embed_train-3.2.0}/README.md +0 -0
  24. {embed_train-3.1.0 → embed_train-3.2.0}/codecov.yml +0 -0
  25. {embed_train-3.1.0 → embed_train-3.2.0}/commitlint.config.cjs +0 -0
  26. {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/__init__.py +0 -0
  27. {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/constants.py +0 -0
  28. {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/exceptions.py +0 -0
  29. {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/models/__init__.py +0 -0
  30. {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/push_to_hf/__init__.py +0 -0
  31. {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/py.typed +0 -0
  32. {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/__init__.py +0 -0
  33. {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/dataset/__init__.py +0 -0
  34. {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/dataset/collate.py +0 -0
  35. {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/dataset/sampling/__init__.py +0 -0
  36. {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/dataset/sampling/samplers.py +0 -0
  37. {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/dataset/torch_datasets.py +0 -0
  38. {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/trainers/__init__.py +0 -0
  39. {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/trainers/torch/__init__.py +0 -0
  40. {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/trainers/torch/loss.py +0 -0
  41. {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/utils.py +0 -0
  42. {embed_train-3.1.0 → embed_train-3.2.0}/tests/__init__.py +0 -0
  43. {embed_train-3.1.0 → embed_train-3.2.0}/tests/conftest.py +0 -0
  44. {embed_train-3.1.0 → embed_train-3.2.0}/tests/fixtures/__init__.py +0 -0
  45. {embed_train-3.1.0 → embed_train-3.2.0}/tests/fixtures/data.py +0 -0
  46. {embed_train-3.1.0 → embed_train-3.2.0}/tests/integration/__init__.py +0 -0
  47. {embed_train-3.1.0 → embed_train-3.2.0}/tests/integration/test_dataset/__init__.py +0 -0
  48. {embed_train-3.1.0 → embed_train-3.2.0}/tests/integration/test_dataset/test_to_hf_dataset.py +0 -0
  49. {embed_train-3.1.0 → embed_train-3.2.0}/tests/integration/test_train_runner/__init__.py +0 -0
  50. {embed_train-3.1.0 → embed_train-3.2.0}/tests/integration/test_train_runner/test_train_runner_flow.py +0 -0
  51. {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/__init__.py +0 -0
  52. {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_abstract_guards.py +0 -0
  53. {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_embed_train.py +0 -0
  54. {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_exceptions.py +0 -0
  55. {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_models.py +0 -0
  56. {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_settings.py +0 -0
  57. {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_train/__init__.py +0 -0
  58. {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_train/test_runner.py +0 -0
  59. {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_train/test_samplers.py +0 -0
  60. {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_train/test_sampling.py +0 -0
  61. {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_train/test_torch_trainer.py +0 -0
  62. {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_train/test_trainers.py +0 -0
  63. {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_utils.py +0 -0
@@ -1,3 +1,10 @@
1
+ # [3.2.0](https://gitlab.com/efysent/agentic-core/embed-train/compare/v3.1.0...v3.2.0) (2026-05-12)
2
+
3
+
4
+ ### Features
5
+
6
+ * remove sentence transformers deprecations ([581032b](https://gitlab.com/efysent/agentic-core/embed-train/commit/581032b97a0bfc50236cc0fa64a857bc65f20314))
7
+
1
8
  # [3.1.0](https://gitlab.com/efysent/agentic-core/embed-train/compare/v3.0.0...v3.1.0) (2026-05-12)
2
9
 
3
10
 
@@ -1,14 +1,15 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: embed-train
3
- Version: 3.1.0
3
+ Version: 3.2.0
4
4
  Author-email: jalal <jalalkhaldi3@gmail.com>
5
5
  Requires-Python: <3.13,>=3.11
6
- Requires-Dist: accelerate<2.0.0,>=1.13.0
7
- Requires-Dist: datasets<5.0.0,>=4.5.0
6
+ Requires-Dist: accelerate==1.13.0
7
+ Requires-Dist: datasets==4.8.4
8
8
  Requires-Dist: retrievalbase<3.0.0,>=2.1.0
9
- Requires-Dist: sentence-transformers<6.0.0,>=5.1.2
10
- Requires-Dist: tensorboard<3.0.0,>=2.20.0
11
- Requires-Dist: torch<3.0.0,>=2.9.0
9
+ Requires-Dist: sentence-transformers==5.4.1
10
+ Requires-Dist: tensorboard==2.20.0
11
+ Requires-Dist: torch==2.11.0
12
+ Requires-Dist: transformers==4.57.6
12
13
  Description-Content-Type: text/markdown
13
14
 
14
15
  # embed-train
@@ -1,6 +1,6 @@
1
1
  [project]
2
2
  name = "embed-train"
3
- version = "3.1.0"
3
+ version = "3.2.0"
4
4
  description = ""
5
5
  authors = [
6
6
  { name = "jalal", email = "jalalkhaldi3@gmail.com" }
@@ -9,11 +9,12 @@ readme = "README.md"
9
9
  requires-python = ">=3.11,<3.13"
10
10
 
11
11
  dependencies = [
12
- "torch>=2.9.0,<3.0.0",
13
- "sentence-transformers>=5.1.2,<6.0.0",
14
- "datasets>=4.5.0,<5.0.0",
15
- "tensorboard>=2.20.0,<3.0.0",
16
- "accelerate>=1.13.0,<2.0.0",
12
+ "torch==2.11.0",
13
+ "sentence-transformers==5.4.1",
14
+ "transformers==4.57.6",
15
+ "datasets==4.8.4",
16
+ "tensorboard==2.20.0",
17
+ "accelerate==1.13.0",
17
18
  "retrievalbase>=2.1.0,<3.0.0",
18
19
  ]
19
20
 
@@ -81,7 +81,7 @@ class SentenceTransformerHardNegativeMinerSettings(HardNegativeMinerSettings):
81
81
  model_name_or_path: str
82
82
  cross_encoder_model_name_or_path: str | None = None
83
83
  tokenizer: TokenizerSettings
84
- pooling: Literal["cls", "mean_tokens", "max_tokens"]
84
+ pooling: Literal["cls", "max", "mean", "mean_sqrt_len_tokens", "weightedmean", "lasttoken"]
85
85
  anchor_column_name: str = "query"
86
86
  positive_column_name: str = "positive"
87
87
  range_min: int = 0
@@ -32,13 +32,13 @@ class SentenceTransformerHardNegativeMiner(HardNegativeMiner[SentenceTransformer
32
32
  model_name_or_path=self.config.model_name_or_path,
33
33
  max_seq_length=self.config.tokenizer.max_length,
34
34
  tokenizer_name_or_path=self.config.tokenizer.name,
35
- model_args={
35
+ model_kwargs={
36
36
  "trust_remote_code": trust,
37
37
  },
38
- config_args={
38
+ config_kwargs={
39
39
  "trust_remote_code": trust,
40
40
  },
41
- tokenizer_args={
41
+ processor_kwargs={
42
42
  "trust_remote_code": trust,
43
43
  "padding": self.config.tokenizer.padding,
44
44
  "truncation": self.config.tokenizer.truncation,
@@ -46,18 +46,11 @@ class SentenceTransformerHardNegativeMiner(HardNegativeMiner[SentenceTransformer
46
46
  },
47
47
  )
48
48
 
49
- pooling = Pooling(
50
- transformer.get_word_embedding_dimension(),
51
- pooling_mode_mean_tokens=self.config.pooling == "mean_tokens",
52
- pooling_mode_cls_token=self.config.pooling == "cls",
53
- pooling_mode_max_tokens=self.config.pooling == "max_tokens",
54
- )
49
+ pooling = Pooling(transformer.get_embedding_dimension(), pooling_mode=self.config.pooling)
55
50
 
56
51
  model = SentenceTransformer(modules=[transformer, pooling])
57
52
 
58
- _logger.info(
59
- f"SentenceTransformer loaded successfully | embedding_dim={transformer.get_word_embedding_dimension()}"
60
- )
53
+ _logger.info(f"SentenceTransformer loaded successfully | embedding_dim={transformer.get_embedding_dimension()}")
61
54
 
62
55
  return model
63
56
 
@@ -64,7 +64,7 @@ class SentenceTransformersTrainer[TCHFTrainRunner: "SentenceTransformersTrainerS
64
64
  _logger.info("Starting SentenceTransformers training...")
65
65
  trainer.train()
66
66
 
67
- def _get_warmup_steps(self, dataset: Dataset) -> float:
67
+ def _get_warmup_steps(self, dataset: Dataset) -> int:
68
68
  train_size = len(dataset)
69
69
  steps_per_epoch = train_size // self.config.per_device_train_batch_size
70
70
  total_steps = steps_per_epoch * self.config.num_epochs
@@ -96,13 +96,13 @@ class SentenceTransformersTrainer[TCHFTrainRunner: "SentenceTransformersTrainerS
96
96
  model_name_or_path=model_path,
97
97
  max_seq_length=self.config.tokenizer.max_length,
98
98
  tokenizer_name_or_path=self.config.tokenizer.name,
99
- model_args={
99
+ model_kwargs={
100
100
  "trust_remote_code": trust,
101
101
  },
102
- config_args={
102
+ config_kwargs={
103
103
  "trust_remote_code": trust,
104
104
  },
105
- tokenizer_args={
105
+ processor_kwargs={
106
106
  "trust_remote_code": trust,
107
107
  "padding": self.config.tokenizer.padding,
108
108
  "truncation": self.config.tokenizer.truncation,
@@ -110,7 +110,7 @@ class SentenceTransformersTrainer[TCHFTrainRunner: "SentenceTransformersTrainerS
110
110
  },
111
111
  )
112
112
  pooling = Pooling(
113
- transformer.get_word_embedding_dimension(),
113
+ transformer.get_embedding_dimension(),
114
114
  pooling_mode_mean_tokens=self.config.pooling == "mean_tokens",
115
115
  pooling_mode_cls_token=self.config.pooling == "cls",
116
116
  pooling_mode_max_tokens=self.config.pooling == "max_tokens",
@@ -264,7 +264,7 @@ def build_hard_negative_miner_settings(**overrides: Any) -> SentenceTransformerH
264
264
  "module_path": "embed_train.train.dataset.hard_negatives.SentenceTransformerHardNegativeMiner",
265
265
  "model_name_or_path": "dummy-miner",
266
266
  "tokenizer": build_tokenizer_settings(),
267
- "pooling": "mean_tokens",
267
+ "pooling": "mean",
268
268
  "range_min": 0,
269
269
  "range_max": 2,
270
270
  "relative_margin": 0.1,
@@ -158,6 +158,29 @@ def test_save_model_locally_passes_expected_arguments(tmp_path: Path) -> None:
158
158
  ]
159
159
 
160
160
 
161
+ def test_ensure_hf_repo_exists_creates_repo_when_enabled(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
162
+ config = build_push_to_hf_runner_settings(tmp_path, create_repo=True)
163
+ runner = PushToHFRunner(config)
164
+ payloads: list[dict[str, object]] = []
165
+
166
+ class DummyApi:
167
+ def create_repo(self, **kwargs: object) -> None:
168
+ payloads.append(kwargs)
169
+
170
+ monkeypatch.setattr("embed_train.push_to_hf.HfApi", DummyApi)
171
+
172
+ runner._ensure_hf_repo_exists()
173
+
174
+ assert payloads == [
175
+ {
176
+ "repo_id": config.hf.repo,
177
+ "repo_type": "model",
178
+ "private": config.hf.private,
179
+ "exist_ok": True,
180
+ }
181
+ ]
182
+
183
+
161
184
  def test_push_repo_to_hf_calls_upload_folder(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
162
185
  config = build_push_to_hf_runner_settings(tmp_path)
163
186
  runner = PushToHFRunner(config)
@@ -69,3 +69,15 @@ def test_hard_negative_collate_processes_negative_columns(monkeypatch) -> None:
69
69
 
70
70
  assert cast(Any, q_tok)["input_ids"].shape == torch.Size([2, 3])
71
71
  assert cast(Any, c_tok)["input_ids"].shape == torch.Size([6, 3])
72
+
73
+
74
+ def test_hard_negative_collate_processes_single_negative(monkeypatch) -> None:
75
+ monkeypatch.setattr(
76
+ "embed_train.train.dataset.AutoTokenizer.from_pretrained", lambda *args, **kwargs: DummyTokenizer()
77
+ )
78
+ collate = HardNegativeCollateFn(build_hard_negative_collate_settings(), context=None)
79
+
80
+ q_tok, c_tok = collate(cast(Any, [{"query": "q1", "positive": "p1", "negative": "n1"}]))
81
+
82
+ assert cast(Any, q_tok)["input_ids"].shape == torch.Size([1, 3])
83
+ assert cast(Any, c_tok)["input_ids"].shape == torch.Size([2, 3])
@@ -2,13 +2,16 @@ from __future__ import annotations
2
2
 
3
3
  from typing import Any, cast
4
4
 
5
+ import pytest
5
6
  import torch
7
+ from datasets import Dataset
6
8
 
7
- from embed_train.train.dataset import CollateFn
9
+ from embed_train.train.dataset import CollateFn, HardNegativeMiner
8
10
  from tests.fixtures.components import (
9
11
  DummyDatasetConnector,
10
12
  DummyTokenizer,
11
13
  build_collate_settings,
14
+ build_hard_negative_miner_settings,
12
15
  build_multi_positive_torch_dataset_settings,
13
16
  )
14
17
 
@@ -62,3 +65,14 @@ def test_torch_dataset_loads_runtime_dataset_and_converts_to_hf() -> None:
62
65
  assert len(dataset) == 1
63
66
  assert hf_dataset[0]["query"] == "query-a"
64
67
  assert hf_dataset[0]["positives"] == ["doc-1", "doc-2"]
68
+
69
+
70
+ def test_hard_negative_miner_base_mine_raises_not_implemented() -> None:
71
+ class ConcreteHardNegativeMiner(HardNegativeMiner):
72
+ def mine(self, dataset: Dataset) -> Dataset:
73
+ return super().mine(dataset)
74
+
75
+ miner = ConcreteHardNegativeMiner(build_hard_negative_miner_settings())
76
+
77
+ with pytest.raises(NotImplementedError):
78
+ miner.mine(Dataset.from_list([]))
@@ -0,0 +1,117 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
+ from datasets import Dataset
6
+
7
+ from embed_train.train.dataset.hard_negatives import SentenceTransformerHardNegativeMiner
8
+ from tests.fixtures.components import build_hard_negative_miner_settings
9
+
10
+
11
+ class DummyTransformer:
12
+ calls: list[dict[str, Any]] = []
13
+
14
+ def __init__(self, **kwargs: Any) -> None:
15
+ self.calls.append(kwargs)
16
+
17
+ def get_embedding_dimension(self) -> int:
18
+ return 8
19
+
20
+
21
+ class DummyPooling:
22
+ calls: list[dict[str, Any]] = []
23
+
24
+ def __init__(self, embedding_dimension: int, *, pooling_mode: str) -> None:
25
+ self.calls.append({"embedding_dimension": embedding_dimension, "pooling_mode": pooling_mode})
26
+
27
+
28
+ class DummySentenceTransformer:
29
+ calls: list[list[object]] = []
30
+
31
+ def __init__(self, *, modules: list[object]) -> None:
32
+ self.calls.append(modules)
33
+
34
+
35
+ class DummyCrossEncoder:
36
+ calls: list[dict[str, Any]] = []
37
+
38
+ def __init__(self, model_name_or_path: str, *, trust_remote_code: bool) -> None:
39
+ self.calls.append(
40
+ {
41
+ "model_name_or_path": model_name_or_path,
42
+ "trust_remote_code": trust_remote_code,
43
+ }
44
+ )
45
+
46
+
47
+ def test_sentence_transformer_hard_negative_miner_mines_with_config(
48
+ monkeypatch,
49
+ ) -> None:
50
+ captured: dict[str, Any] = {}
51
+ cleanup_calls: list[str] = []
52
+ mined = Dataset.from_list([{"query": "q1", "positive": "p1", "negative": "n1"}])
53
+
54
+ def fake_mine_hard_negatives(**kwargs: Any) -> Dataset:
55
+ captured.update(kwargs)
56
+ return mined
57
+
58
+ class DummyCuda:
59
+ @staticmethod
60
+ def is_available() -> bool:
61
+ return True
62
+
63
+ @staticmethod
64
+ def memory_allocated() -> int:
65
+ cleanup_calls.append("allocated")
66
+ return 1024**3
67
+
68
+ @staticmethod
69
+ def memory_reserved() -> int:
70
+ cleanup_calls.append("reserved")
71
+ return 2 * 1024**3
72
+
73
+ @staticmethod
74
+ def empty_cache() -> None:
75
+ cleanup_calls.append("empty_cache")
76
+
77
+ @staticmethod
78
+ def ipc_collect() -> None:
79
+ cleanup_calls.append("ipc_collect")
80
+
81
+ monkeypatch.setattr("embed_train.train.dataset.hard_negatives.Transformer", DummyTransformer)
82
+ monkeypatch.setattr("embed_train.train.dataset.hard_negatives.Pooling", DummyPooling)
83
+ monkeypatch.setattr("embed_train.train.dataset.hard_negatives.SentenceTransformer", DummySentenceTransformer)
84
+ monkeypatch.setattr("embed_train.train.dataset.hard_negatives.CrossEncoder", DummyCrossEncoder)
85
+ monkeypatch.setattr("embed_train.train.dataset.hard_negatives.mine_hard_negatives", fake_mine_hard_negatives)
86
+ monkeypatch.setattr("embed_train.train.dataset.hard_negatives.torch.cuda", DummyCuda)
87
+
88
+ config = build_hard_negative_miner_settings(
89
+ cross_encoder_model_name_or_path="dummy-reranker",
90
+ trust_remote_code=True,
91
+ )
92
+ miner = SentenceTransformerHardNegativeMiner(config)
93
+
94
+ result = miner.mine(Dataset.from_list([{"query": "q1", "positive": "p1"}]))
95
+
96
+ assert result is mined
97
+ assert DummyTransformer.calls[-1]["model_name_or_path"] == "dummy-miner"
98
+ assert DummyTransformer.calls[-1]["processor_kwargs"]["model_max_length"] == config.tokenizer.max_length
99
+ assert DummyPooling.calls[-1] == {"embedding_dimension": 8, "pooling_mode": "mean"}
100
+ assert len(DummySentenceTransformer.calls[-1]) == 2
101
+ assert DummyCrossEncoder.calls[-1] == {
102
+ "model_name_or_path": "dummy-reranker",
103
+ "trust_remote_code": True,
104
+ }
105
+ assert captured["anchor_column_name"] == "query"
106
+ assert captured["positive_column_name"] == "positive"
107
+ assert captured["cross_encoder"] is not None
108
+ assert captured["num_negatives"] == 2
109
+ assert captured["sampling_strategy"] == "top"
110
+ assert "empty_cache" in cleanup_calls
111
+ assert "ipc_collect" in cleanup_calls
112
+
113
+
114
+ def test_sentence_transformer_hard_negative_miner_skips_cross_encoder() -> None:
115
+ miner = SentenceTransformerHardNegativeMiner(build_hard_negative_miner_settings())
116
+
117
+ assert miner._load_cross_encoder() is None
@@ -19,7 +19,7 @@ class DummyTransformer:
19
19
  def __init__(self, **kwargs):
20
20
  self.kwargs = kwargs
21
21
 
22
- def get_word_embedding_dimension(self) -> int:
22
+ def get_embedding_dimension(self) -> int:
23
23
  return 16
24
24
 
25
25
 
@@ -58,3 +58,26 @@ def test_hard_negative_contrastive_loss_rejects_invalid_candidate_layout() -> No
58
58
 
59
59
  with pytest.raises(EmbedTrainValueError, match="one positive and at least one negative"):
60
60
  loss(q_emb, c_emb)
61
+
62
+
63
+ def test_hard_negative_contrastive_loss_returns_scalar_for_valid_layout() -> None:
64
+ loss = HardNegativeContrastiveLoss(
65
+ HardNegativeContrastiveLossSettings(
66
+ module_path="embed_train.train.trainers.torch.loss.HardNegativeContrastiveLoss",
67
+ temperature=0.5,
68
+ )
69
+ )
70
+ q_emb = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
71
+ c_emb = torch.tensor(
72
+ [
73
+ [1.0, 0.0],
74
+ [0.0, 1.0],
75
+ [0.0, 1.0],
76
+ [1.0, 0.0],
77
+ ]
78
+ )
79
+
80
+ result = loss(q_emb, c_emb)
81
+
82
+ assert result.ndim == 0
83
+ assert result.item() >= 0
@@ -0,0 +1,83 @@
1
+ from __future__ import annotations
2
+
3
+ import pytest
4
+ from datasets import Dataset
5
+
6
+ from tests.fixtures.components import (
7
+ DummyDatasetConnector,
8
+ build_hard_negative_torch_dataset_settings,
9
+ build_multi_positive_torch_dataset_settings,
10
+ build_torch_dataset_settings,
11
+ )
12
+ from tests.fixtures.data import build_query_rows
13
+
14
+
15
+ def test_query_multi_positive_dataset_groups_rows() -> None:
16
+ from embed_train.train.dataset.torch_datasets import QueryMultiPositiveDataset
17
+
18
+ DummyDatasetConnector.rows = build_query_rows()
19
+ dataset = QueryMultiPositiveDataset(build_multi_positive_torch_dataset_settings())
20
+
21
+ assert len(dataset) == 2
22
+ grouped = {item["query"]: item["positives"] for item in (dataset[i] for i in range(len(dataset)))}
23
+ assert grouped["query-a"] == ["doc-1", "doc-2"]
24
+ assert grouped["query-b"] == ["doc-3"]
25
+
26
+
27
+ def test_query_positive_dataset_flattens_rows() -> None:
28
+ from embed_train.train.dataset.torch_datasets import QueryPositiveDataset
29
+
30
+ DummyDatasetConnector.rows = build_query_rows()
31
+ dataset = QueryPositiveDataset(build_torch_dataset_settings())
32
+
33
+ assert len(dataset) == 3
34
+ rows = {tuple(item.values()) for item in (dataset[i] for i in range(len(dataset)))}
35
+ assert ("query-a", "doc-1") in rows
36
+ assert ("query-a", "doc-2") in rows
37
+ assert ("query-b", "doc-3") in rows
38
+
39
+
40
+ def test_query_positive_dataset_rejects_empty_positive() -> None:
41
+ from embed_train.train.dataset.torch_datasets import QueryPositiveDataset
42
+
43
+ DummyDatasetConnector.rows = [
44
+ {
45
+ "page_content": "",
46
+ "metadata": {"query": "query-a", "step_range": (0, 1), "section": "alpha"},
47
+ }
48
+ ]
49
+ dataset = QueryPositiveDataset(build_torch_dataset_settings())
50
+
51
+ with pytest.raises(ValueError, match="No positive passage found for index 0"):
52
+ dataset[0]
53
+
54
+
55
+ class DummyHardNegativeMiner:
56
+ seen_dataset: Dataset | None = None
57
+
58
+ @classmethod
59
+ def from_config(cls, _config: object) -> DummyHardNegativeMiner:
60
+ return cls()
61
+
62
+ def mine(self, dataset: Dataset) -> Dataset:
63
+ self.__class__.seen_dataset = dataset
64
+ return Dataset.from_list(
65
+ [
66
+ {"query": "query-a", "positive": "doc-1", "negative": "doc-2"},
67
+ {"query": "query-b", "positive": "doc-3", "negative": "doc-1"},
68
+ ]
69
+ )
70
+
71
+
72
+ def test_hard_negative_dataset_mines_and_exposes_rows(monkeypatch: pytest.MonkeyPatch) -> None:
73
+ from embed_train.train.dataset.torch_datasets import HardNegativeDataset
74
+
75
+ DummyDatasetConnector.rows = build_query_rows()
76
+ monkeypatch.setattr("embed_train.train.dataset.torch_datasets.load_class", lambda _: DummyHardNegativeMiner)
77
+
78
+ dataset = HardNegativeDataset(build_hard_negative_torch_dataset_settings())
79
+
80
+ assert len(dataset) == 2
81
+ assert dataset[0] == {"query": "query-a", "positive": "doc-1", "negative": "doc-2"}
82
+ assert DummyHardNegativeMiner.seen_dataset is not None
83
+ assert DummyHardNegativeMiner.seen_dataset[0] == {"query": "query-a", "positive": "doc-1"}
@@ -110,15 +110,6 @@ wheels = [
110
110
  { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" },
111
111
  ]
112
112
 
113
- [[package]]
114
- name = "annotated-doc"
115
- version = "0.0.4"
116
- source = { registry = "https://pypi.org/simple" }
117
- sdist = { url = "https://files.pythonhosted.org/packages/57/ba/046ceea27344560984e26a590f90bc7f4a75b06701f653222458922b558c/annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4", size = 7288, upload-time = "2025-11-10T22:07:42.062Z" }
118
- wheels = [
119
- { url = "https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320", size = 5303, upload-time = "2025-11-10T22:07:40.673Z" },
120
- ]
121
-
122
113
  [[package]]
123
114
  name = "annotated-types"
124
115
  version = "0.7.0"
@@ -284,18 +275,6 @@ wheels = [
284
275
  { url = "https://files.pythonhosted.org/packages/db/8f/61959034484a4a7c527811f4721e75d02d653a35afb0b6054474d8185d4c/charset_normalizer-3.4.7-py3-none-any.whl", hash = "sha256:3dce51d0f5e7951f8bb4900c257dad282f49190fdbebecd4ba99bcc41fef404d", size = 61958, upload-time = "2026-04-02T09:28:37.794Z" },
285
276
  ]
286
277
 
287
- [[package]]
288
- name = "click"
289
- version = "8.3.2"
290
- source = { registry = "https://pypi.org/simple" }
291
- dependencies = [
292
- { name = "colorama", marker = "sys_platform == 'win32'" },
293
- ]
294
- sdist = { url = "https://files.pythonhosted.org/packages/57/75/31212c6bf2503fdf920d87fee5d7a86a2e3bcf444984126f13d8e4016804/click-8.3.2.tar.gz", hash = "sha256:14162b8b3b3550a7d479eafa77dfd3c38d9dc8951f6f69c78913a8f9a7540fd5", size = 302856, upload-time = "2026-04-03T19:14:45.118Z" }
295
- wheels = [
296
- { url = "https://files.pythonhosted.org/packages/e4/20/71885d8b97d4f3dde17b1fdb92dbd4908b00541c5a3379787137285f602e/click-8.3.2-py3-none-any.whl", hash = "sha256:1924d2c27c5653561cd2cae4548d1406039cb79b858b747cfea24924bbc1616d", size = 108379, upload-time = "2026-04-03T19:14:43.505Z" },
297
- ]
298
-
299
278
  [[package]]
300
279
  name = "colorama"
301
280
  version = "0.4.6"
@@ -459,7 +438,7 @@ wheels = [
459
438
 
460
439
  [[package]]
461
440
  name = "embed-train"
462
- version = "3.1.0"
441
+ version = "3.2.0"
463
442
  source = { editable = "." }
464
443
  dependencies = [
465
444
  { name = "accelerate" },
@@ -468,6 +447,7 @@ dependencies = [
468
447
  { name = "sentence-transformers" },
469
448
  { name = "tensorboard" },
470
449
  { name = "torch" },
450
+ { name = "transformers" },
471
451
  ]
472
452
 
473
453
  [package.dev-dependencies]
@@ -484,12 +464,13 @@ dev = [
484
464
 
485
465
  [package.metadata]
486
466
  requires-dist = [
487
- { name = "accelerate", specifier = ">=1.13.0,<2.0.0" },
488
- { name = "datasets", specifier = ">=4.5.0,<5.0.0" },
467
+ { name = "accelerate", specifier = "==1.13.0" },
468
+ { name = "datasets", specifier = "==4.8.4" },
489
469
  { name = "retrievalbase", specifier = ">=2.1.0,<3.0.0" },
490
- { name = "sentence-transformers", specifier = ">=5.1.2,<6.0.0" },
491
- { name = "tensorboard", specifier = ">=2.20.0,<3.0.0" },
492
- { name = "torch", specifier = ">=2.9.0,<3.0.0" },
470
+ { name = "sentence-transformers", specifier = "==5.4.1" },
471
+ { name = "tensorboard", specifier = "==2.20.0" },
472
+ { name = "torch", specifier = "==2.11.0" },
473
+ { name = "transformers", specifier = "==4.57.6" },
493
474
  ]
494
475
 
495
476
  [package.metadata.requires-dev]
@@ -702,22 +683,21 @@ http2 = [
702
683
 
703
684
  [[package]]
704
685
  name = "huggingface-hub"
705
- version = "1.11.0"
686
+ version = "0.36.2"
706
687
  source = { registry = "https://pypi.org/simple" }
707
688
  dependencies = [
708
689
  { name = "filelock" },
709
690
  { name = "fsspec" },
710
- { name = "hf-xet", marker = "platform_machine == 'AMD64' or platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" },
711
- { name = "httpx" },
691
+ { name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" },
712
692
  { name = "packaging" },
713
693
  { name = "pyyaml" },
694
+ { name = "requests" },
714
695
  { name = "tqdm" },
715
- { name = "typer" },
716
696
  { name = "typing-extensions" },
717
697
  ]
718
- sdist = { url = "https://files.pythonhosted.org/packages/dc/89/e7aa12d8a6b9259bed10671abb25ae6fa437c0f88a86ecbf59617bae7759/huggingface_hub-1.11.0.tar.gz", hash = "sha256:15fb3713c7f9cdff7b808a94fd91664f661ab142796bb48c9cd9493e8d166278", size = 761749, upload-time = "2026-04-16T13:07:39.73Z" }
698
+ sdist = { url = "https://files.pythonhosted.org/packages/7c/b7/8cb61d2eece5fb05a83271da168186721c450eb74e3c31f7ef3169fa475b/huggingface_hub-0.36.2.tar.gz", hash = "sha256:1934304d2fb224f8afa3b87007d58501acfda9215b334eed53072dd5e815ff7a", size = 649782, upload-time = "2026-02-06T09:24:13.098Z" }
719
699
  wheels = [
720
- { url = "https://files.pythonhosted.org/packages/37/02/4f3f8997d1ea7fe0146b343e5e14bd065fa87af790d07e5576d31b31cc18/huggingface_hub-1.11.0-py3-none-any.whl", hash = "sha256:42a6de0afbfeb5e022222d36398f029679db4eb4778801aafda32257ae9131ab", size = 645499, upload-time = "2026-04-16T13:07:37.716Z" },
700
+ { url = "https://files.pythonhosted.org/packages/a8/af/48ac8483240de756d2438c380746e7130d1c6f75802ef22f3c6d49982787/huggingface_hub-0.36.2-py3-none-any.whl", hash = "sha256:48f0c8eac16145dfce371e9d2d7772854a4f591bcb56c9cf548accf531d54270", size = 566395, upload-time = "2026-02-06T09:24:11.133Z" },
721
701
  ]
722
702
 
723
703
  [[package]]
@@ -2085,15 +2065,6 @@ wheels = [
2085
2065
  { url = "https://files.pythonhosted.org/packages/e1/e3/c164c88b2e5ce7b24d667b9bd83589cf4f3520d97cad01534cd3c4f55fdb/setuptools-81.0.0-py3-none-any.whl", hash = "sha256:fdd925d5c5d9f62e4b74b30d6dd7828ce236fd6ed998a08d81de62ce5a6310d6", size = 1062021, upload-time = "2026-02-06T21:10:37.175Z" },
2086
2066
  ]
2087
2067
 
2088
- [[package]]
2089
- name = "shellingham"
2090
- version = "1.5.4"
2091
- source = { registry = "https://pypi.org/simple" }
2092
- sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310, upload-time = "2023-10-24T04:13:40.426Z" }
2093
- wheels = [
2094
- { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" },
2095
- ]
2096
-
2097
2068
  [[package]]
2098
2069
  name = "six"
2099
2070
  version = "1.17.0"
@@ -2279,22 +2250,23 @@ wheels = [
2279
2250
 
2280
2251
  [[package]]
2281
2252
  name = "transformers"
2282
- version = "5.5.4"
2253
+ version = "4.57.6"
2283
2254
  source = { registry = "https://pypi.org/simple" }
2284
2255
  dependencies = [
2256
+ { name = "filelock" },
2285
2257
  { name = "huggingface-hub" },
2286
2258
  { name = "numpy" },
2287
2259
  { name = "packaging" },
2288
2260
  { name = "pyyaml" },
2289
2261
  { name = "regex" },
2262
+ { name = "requests" },
2290
2263
  { name = "safetensors" },
2291
2264
  { name = "tokenizers" },
2292
2265
  { name = "tqdm" },
2293
- { name = "typer" },
2294
2266
  ]
2295
- sdist = { url = "https://files.pythonhosted.org/packages/a5/1e/1e244ab2ab50a863e6b52cc55761910567fa532b69a6740f6e99c5fdbd98/transformers-5.5.4.tar.gz", hash = "sha256:2e67cadba81fc7608cc07c4dd54f524820bc3d95b1cabd0ef3db7733c4f8b82e", size = 8227649, upload-time = "2026-04-13T16:55:55.181Z" }
2267
+ sdist = { url = "https://files.pythonhosted.org/packages/c4/35/67252acc1b929dc88b6602e8c4a982e64f31e733b804c14bc24b47da35e6/transformers-4.57.6.tar.gz", hash = "sha256:55e44126ece9dc0a291521b7e5492b572e6ef2766338a610b9ab5afbb70689d3", size = 10134912, upload-time = "2026-01-16T10:38:39.284Z" }
2296
2268
  wheels = [
2297
- { url = "https://files.pythonhosted.org/packages/29/fb/162a66789c65e5afa3b051309240c26bf37fbc8fea285b4546ae747995a2/transformers-5.5.4-py3-none-any.whl", hash = "sha256:0bd6281b82966fe5a7a16f553ea517a9db1dee6284d7cb224dfd88fc0dd1c167", size = 10236696, upload-time = "2026-04-13T16:55:51.497Z" },
2269
+ { url = "https://files.pythonhosted.org/packages/03/b8/e484ef633af3887baeeb4b6ad12743363af7cce68ae51e938e00aaa0529d/transformers-4.57.6-py3-none-any.whl", hash = "sha256:4c9e9de11333ddfe5114bc872c9f370509198acf0b87a832a0ab9458e2bd0550", size = 11993498, upload-time = "2026-01-16T10:38:31.289Z" },
2298
2270
  ]
2299
2271
 
2300
2272
  [[package]]
@@ -2332,21 +2304,6 @@ wheels = [
2332
2304
  { url = "https://files.pythonhosted.org/packages/88/39/bca669095ccf0a400af941fdf741578d4c2d6719f1b7f10e6dbec10aa862/ty-0.0.31-py3-none-win_arm64.whl", hash = "sha256:e9cb15fad26545c6a608f40f227af3a5513cb376998ca6feddd47ca7d93ffafa", size = 10590392, upload-time = "2026-04-15T15:47:57.968Z" },
2333
2305
  ]
2334
2306
 
2335
- [[package]]
2336
- name = "typer"
2337
- version = "0.24.1"
2338
- source = { registry = "https://pypi.org/simple" }
2339
- dependencies = [
2340
- { name = "annotated-doc" },
2341
- { name = "click" },
2342
- { name = "rich" },
2343
- { name = "shellingham" },
2344
- ]
2345
- sdist = { url = "https://files.pythonhosted.org/packages/f5/24/cb09efec5cc954f7f9b930bf8279447d24618bb6758d4f6adf2574c41780/typer-0.24.1.tar.gz", hash = "sha256:e39b4732d65fbdcde189ae76cf7cd48aeae72919dea1fdfc16593be016256b45", size = 118613, upload-time = "2026-02-21T16:54:40.609Z" }
2346
- wheels = [
2347
- { url = "https://files.pythonhosted.org/packages/4a/91/48db081e7a63bb37284f9fbcefda7c44c277b18b0e13fbc36ea2335b71e6/typer-0.24.1-py3-none-any.whl", hash = "sha256:112c1f0ce578bfb4cab9ffdabc68f031416ebcc216536611ba21f04e9aa84c9e", size = 56085, upload-time = "2026-02-21T16:54:41.616Z" },
2348
- ]
2349
-
2350
2307
  [[package]]
2351
2308
  name = "types-pyyaml"
2352
2309
  version = "6.0.12.20260408"
@@ -1,33 +0,0 @@
1
- from __future__ import annotations
2
-
3
- from tests.fixtures.components import (
4
- DummyDatasetConnector,
5
- build_multi_positive_torch_dataset_settings,
6
- build_torch_dataset_settings,
7
- )
8
- from tests.fixtures.data import build_query_rows
9
-
10
-
11
- def test_query_multi_positive_dataset_groups_rows() -> None:
12
- from embed_train.train.dataset.torch_datasets import QueryMultiPositiveDataset
13
-
14
- DummyDatasetConnector.rows = build_query_rows()
15
- dataset = QueryMultiPositiveDataset(build_multi_positive_torch_dataset_settings())
16
-
17
- assert len(dataset) == 2
18
- grouped = {item["query"]: item["positives"] for item in (dataset[i] for i in range(len(dataset)))}
19
- assert grouped["query-a"] == ["doc-1", "doc-2"]
20
- assert grouped["query-b"] == ["doc-3"]
21
-
22
-
23
- def test_query_positive_dataset_flattens_rows() -> None:
24
- from embed_train.train.dataset.torch_datasets import QueryPositiveDataset
25
-
26
- DummyDatasetConnector.rows = build_query_rows()
27
- dataset = QueryPositiveDataset(build_torch_dataset_settings())
28
-
29
- assert len(dataset) == 3
30
- rows = {tuple(item.values()) for item in (dataset[i] for i in range(len(dataset)))}
31
- assert ("query-a", "doc-1") in rows
32
- assert ("query-a", "doc-2") in rows
33
- assert ("query-b", "doc-3") in rows
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes