embed-train 3.1.0.tar.gz → 3.2.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {embed_train-3.1.0 → embed_train-3.2.0}/CHANGELOG.md +7 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/PKG-INFO +7 -6
- {embed_train-3.1.0 → embed_train-3.2.0}/pyproject.toml +7 -6
- {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/settings.py +1 -1
- {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/dataset/hard_negatives.py +5 -12
- {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/trainers/hf/__init__.py +5 -5
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/fixtures/components.py +1 -1
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_push_to_hf.py +23 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_train/test_collate.py +12 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_train/test_dataset.py +15 -1
- embed_train-3.2.0/tests/unit/test_train/test_hard_negatives.py +117 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_train/test_hf_trainer.py +1 -1
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_train/test_loss.py +23 -0
- embed_train-3.2.0/tests/unit/test_train/test_torch_datasets.py +83 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/uv.lock +18 -61
- embed_train-3.1.0/tests/unit/test_train/test_torch_datasets.py +0 -33
- {embed_train-3.1.0 → embed_train-3.2.0}/.gitignore +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/.gitlab-ci.yml +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/.pre-commit-config.yaml +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/.releaserc.json +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/AGENTS.md +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/Makefile +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/README.md +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/codecov.yml +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/commitlint.config.cjs +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/__init__.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/constants.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/exceptions.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/models/__init__.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/push_to_hf/__init__.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/py.typed +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/__init__.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/dataset/__init__.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/dataset/collate.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/dataset/sampling/__init__.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/dataset/sampling/samplers.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/dataset/torch_datasets.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/trainers/__init__.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/trainers/torch/__init__.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/trainers/torch/loss.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/utils.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/__init__.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/conftest.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/fixtures/__init__.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/fixtures/data.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/integration/__init__.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/integration/test_dataset/__init__.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/integration/test_dataset/test_to_hf_dataset.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/integration/test_train_runner/__init__.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/integration/test_train_runner/test_train_runner_flow.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/__init__.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_abstract_guards.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_embed_train.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_exceptions.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_models.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_settings.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_train/__init__.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_train/test_runner.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_train/test_samplers.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_train/test_sampling.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_train/test_torch_trainer.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_train/test_trainers.py +0 -0
- {embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_utils.py +0 -0
{embed_train-3.1.0 → embed_train-3.2.0}/CHANGELOG.md
@@ -1,3 +1,10 @@
+# [3.2.0](https://gitlab.com/efysent/agentic-core/embed-train/compare/v3.1.0...v3.2.0) (2026-05-12)
+
+
+### Features
+
+* remove sentence transformers deprecations ([581032b](https://gitlab.com/efysent/agentic-core/embed-train/commit/581032b97a0bfc50236cc0fa64a857bc65f20314))
+
 # [3.1.0](https://gitlab.com/efysent/agentic-core/embed-train/compare/v3.0.0...v3.1.0) (2026-05-12)
 
 
{embed_train-3.1.0 → embed_train-3.2.0}/PKG-INFO
@@ -1,14 +1,15 @@
 Metadata-Version: 2.4
 Name: embed-train
-Version: 3.1.0
+Version: 3.2.0
 Author-email: jalal <jalalkhaldi3@gmail.com>
 Requires-Python: <3.13,>=3.11
-Requires-Dist: accelerate
-Requires-Dist: datasets
+Requires-Dist: accelerate==1.13.0
+Requires-Dist: datasets==4.8.4
 Requires-Dist: retrievalbase<3.0.0,>=2.1.0
-Requires-Dist: sentence-transformers
-Requires-Dist: tensorboard
-Requires-Dist: torch
+Requires-Dist: sentence-transformers==5.4.1
+Requires-Dist: tensorboard==2.20.0
+Requires-Dist: torch==2.11.0
+Requires-Dist: transformers==4.57.6
 Description-Content-Type: text/markdown
 
 # embed-train
{embed_train-3.1.0 → embed_train-3.2.0}/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "embed-train"
-version = "3.1.0"
+version = "3.2.0"
 description = ""
 authors = [
     { name = "jalal", email = "jalalkhaldi3@gmail.com" }
@@ -9,11 +9,12 @@ readme = "README.md"
 requires-python = ">=3.11,<3.13"
 
 dependencies = [
-    "torch
-    "sentence-transformers
-    "
-    "
-    "
+    "torch==2.11.0",
+    "sentence-transformers==5.4.1",
+    "transformers==4.57.6",
+    "datasets==4.8.4",
+    "tensorboard==2.20.0",
+    "accelerate==1.13.0",
     "retrievalbase>=2.1.0,<3.0.0",
 ]
 
{embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/settings.py
@@ -81,7 +81,7 @@ class SentenceTransformerHardNegativeMinerSettings(HardNegativeMinerSettings):
     model_name_or_path: str
     cross_encoder_model_name_or_path: str | None = None
     tokenizer: TokenizerSettings
-    pooling: Literal["cls", "
+    pooling: Literal["cls", "max", "mean", "mean_sqrt_len_tokens", "weightedmean", "lasttoken"]
     anchor_column_name: str = "query"
     positive_column_name: str = "positive"
     range_min: int = 0
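The widened `pooling` field now has to be one of the literal mode names that the new sentence-transformers `Pooling` module accepts directly. A minimal sketch of the effect, using a hypothetical pydantic-style stand-in (the real `SentenceTransformerHardNegativeMinerSettings` has more required fields and may use a different settings base):

from typing import Literal

from pydantic import BaseModel, ValidationError

# Mirrors the Literal added in the hunk above.
PoolingMode = Literal["cls", "max", "mean", "mean_sqrt_len_tokens", "weightedmean", "lasttoken"]


class MinerPoolingConfig(BaseModel):
    """Illustrative stand-in, not the package's actual settings class."""

    pooling: PoolingMode = "mean"


MinerPoolingConfig(pooling="lasttoken")        # accepted: one of the new literal values
try:
    MinerPoolingConfig(pooling="mean_tokens")  # the old flag-style name used by the removed Pooling kwargs
except ValidationError as exc:
    print(exc.errors()[0]["type"])             # "literal_error" in pydantic v2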
{embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/dataset/hard_negatives.py
@@ -32,13 +32,13 @@ class SentenceTransformerHardNegativeMiner(HardNegativeMiner[SentenceTransformer
             model_name_or_path=self.config.model_name_or_path,
             max_seq_length=self.config.tokenizer.max_length,
             tokenizer_name_or_path=self.config.tokenizer.name,
-
+            model_kwargs={
                 "trust_remote_code": trust,
             },
-
+            config_kwargs={
                 "trust_remote_code": trust,
             },
-
+            processor_kwargs={
                 "trust_remote_code": trust,
                 "padding": self.config.tokenizer.padding,
                 "truncation": self.config.tokenizer.truncation,
@@ -46,18 +46,11 @@ class SentenceTransformerHardNegativeMiner(HardNegativeMiner[SentenceTransformer
             },
         )
 
-        pooling = Pooling(
-            transformer.get_word_embedding_dimension(),
-            pooling_mode_mean_tokens=self.config.pooling == "mean_tokens",
-            pooling_mode_cls_token=self.config.pooling == "cls",
-            pooling_mode_max_tokens=self.config.pooling == "max_tokens",
-        )
+        pooling = Pooling(transformer.get_embedding_dimension(), pooling_mode=self.config.pooling)
 
         model = SentenceTransformer(modules=[transformer, pooling])
 
-        _logger.info(
-            f"SentenceTransformer loaded successfully | embedding_dim={transformer.get_word_embedding_dimension()}"
-        )
+        _logger.info(f"SentenceTransformer loaded successfully | embedding_dim={transformer.get_embedding_dimension()}")
 
         return model
 
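For orientation, a minimal sketch of what the migrated loader now does, assuming the sentence-transformers release pinned by this package (where the transformer module exposes `get_embedding_dimension()` and `Pooling` takes a single `pooling_mode` string; older releases expose `get_word_embedding_dimension()` and the per-mode boolean flags removed above). The model name is a placeholder:

from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Pooling, Transformer

# Placeholder checkpoint; any Hugging Face encoder works here.
transformer = Transformer("sentence-transformers/all-MiniLM-L6-v2", max_seq_length=256)

# One pooling_mode string instead of the pooling_mode_mean_tokens / _cls_token / _max_tokens flags.
pooling = Pooling(transformer.get_embedding_dimension(), pooling_mode="mean")

model = SentenceTransformer(modules=[transformer, pooling])
print(model.encode(["a short probe sentence"]).shape)  # (1, embedding_dim)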
{embed_train-3.1.0 → embed_train-3.2.0}/src/embed_train/train/trainers/hf/__init__.py
@@ -64,7 +64,7 @@ class SentenceTransformersTrainer[TCHFTrainRunner: "SentenceTransformersTrainerS
         _logger.info("Starting SentenceTransformers training...")
         trainer.train()
 
-    def _get_warmup_steps(self, dataset: Dataset) ->
+    def _get_warmup_steps(self, dataset: Dataset) -> int:
         train_size = len(dataset)
         steps_per_epoch = train_size // self.config.per_device_train_batch_size
         total_steps = steps_per_epoch * self.config.num_epochs
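The corrected return annotation aside, the warmup arithmetic in this hunk is simple enough to show standalone. A hedged sketch (the 10% warmup ratio and the helper name are assumptions for illustration; the hunk cuts off before the method's final return):

def estimate_warmup_steps(
    train_size: int,
    per_device_train_batch_size: int,
    num_epochs: int,
    warmup_ratio: float = 0.1,  # assumed ratio, not shown in this diff
) -> int:
    # Same arithmetic as the hunk above: steps per epoch, then total steps.
    steps_per_epoch = train_size // per_device_train_batch_size
    total_steps = steps_per_epoch * num_epochs
    return int(total_steps * warmup_ratio)


print(estimate_warmup_steps(10_000, 32, 3))  # 312 steps/epoch * 3 epochs = 936 -> 93 warmup steps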
@@ -96,13 +96,13 @@ class SentenceTransformersTrainer[TCHFTrainRunner: "SentenceTransformersTrainerS
             model_name_or_path=model_path,
             max_seq_length=self.config.tokenizer.max_length,
             tokenizer_name_or_path=self.config.tokenizer.name,
-
+            model_kwargs={
                 "trust_remote_code": trust,
             },
-
+            config_kwargs={
                 "trust_remote_code": trust,
             },
-
+            processor_kwargs={
                 "trust_remote_code": trust,
                 "padding": self.config.tokenizer.padding,
                 "truncation": self.config.tokenizer.truncation,
@@ -110,7 +110,7 @@ class SentenceTransformersTrainer[TCHFTrainRunner: "SentenceTransformersTrainerS
             },
         )
         pooling = Pooling(
-            transformer.get_word_embedding_dimension(),
+            transformer.get_embedding_dimension(),
             pooling_mode_mean_tokens=self.config.pooling == "mean_tokens",
             pooling_mode_cls_token=self.config.pooling == "cls",
             pooling_mode_max_tokens=self.config.pooling == "max_tokens",
{embed_train-3.1.0 → embed_train-3.2.0}/tests/fixtures/components.py
@@ -264,7 +264,7 @@ def build_hard_negative_miner_settings(**overrides: Any) -> SentenceTransformerH
         "module_path": "embed_train.train.dataset.hard_negatives.SentenceTransformerHardNegativeMiner",
         "model_name_or_path": "dummy-miner",
         "tokenizer": build_tokenizer_settings(),
-        "pooling": "
+        "pooling": "mean",
         "range_min": 0,
         "range_max": 2,
         "relative_margin": 0.1,
{embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_push_to_hf.py
@@ -158,6 +158,29 @@ def test_save_model_locally_passes_expected_arguments(tmp_path: Path) -> None:
     ]
 
 
+def test_ensure_hf_repo_exists_creates_repo_when_enabled(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
+    config = build_push_to_hf_runner_settings(tmp_path, create_repo=True)
+    runner = PushToHFRunner(config)
+    payloads: list[dict[str, object]] = []
+
+    class DummyApi:
+        def create_repo(self, **kwargs: object) -> None:
+            payloads.append(kwargs)
+
+    monkeypatch.setattr("embed_train.push_to_hf.HfApi", DummyApi)
+
+    runner._ensure_hf_repo_exists()
+
+    assert payloads == [
+        {
+            "repo_id": config.hf.repo,
+            "repo_type": "model",
+            "private": config.hf.private,
+            "exist_ok": True,
+        }
+    ]
+
+
 def test_push_repo_to_hf_calls_upload_folder(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None:
     config = build_push_to_hf_runner_settings(tmp_path)
     runner = PushToHFRunner(config)
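The new test above pins down the exact `create_repo` payload. A minimal sketch of the call it expects, assuming the standard `huggingface_hub` client (the runner's internals are not part of this diff, and the function name here is hypothetical):

from huggingface_hub import HfApi


def ensure_repo_exists(repo_id: str, private: bool) -> None:
    # Matches the payload asserted by test_ensure_hf_repo_exists_creates_repo_when_enabled:
    # a model repo, created idempotently.
    HfApi().create_repo(repo_id=repo_id, repo_type="model", private=private, exist_ok=True)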
{embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_train/test_collate.py
@@ -69,3 +69,15 @@ def test_hard_negative_collate_processes_negative_columns(monkeypatch) -> None:
 
     assert cast(Any, q_tok)["input_ids"].shape == torch.Size([2, 3])
     assert cast(Any, c_tok)["input_ids"].shape == torch.Size([6, 3])
+
+
+def test_hard_negative_collate_processes_single_negative(monkeypatch) -> None:
+    monkeypatch.setattr(
+        "embed_train.train.dataset.AutoTokenizer.from_pretrained", lambda *args, **kwargs: DummyTokenizer()
+    )
+    collate = HardNegativeCollateFn(build_hard_negative_collate_settings(), context=None)
+
+    q_tok, c_tok = collate(cast(Any, [{"query": "q1", "positive": "p1", "negative": "n1"}]))
+
+    assert cast(Any, q_tok)["input_ids"].shape == torch.Size([1, 3])
+    assert cast(Any, c_tok)["input_ids"].shape == torch.Size([2, 3])
{embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_train/test_dataset.py
@@ -2,13 +2,16 @@ from __future__ import annotations
 
 from typing import Any, cast
 
+import pytest
 import torch
+from datasets import Dataset
 
-from embed_train.train.dataset import CollateFn
+from embed_train.train.dataset import CollateFn, HardNegativeMiner
 from tests.fixtures.components import (
     DummyDatasetConnector,
     DummyTokenizer,
     build_collate_settings,
+    build_hard_negative_miner_settings,
     build_multi_positive_torch_dataset_settings,
 )
 
@@ -62,3 +65,14 @@ def test_torch_dataset_loads_runtime_dataset_and_converts_to_hf() -> None:
     assert len(dataset) == 1
     assert hf_dataset[0]["query"] == "query-a"
     assert hf_dataset[0]["positives"] == ["doc-1", "doc-2"]
+
+
+def test_hard_negative_miner_base_mine_raises_not_implemented() -> None:
+    class ConcreteHardNegativeMiner(HardNegativeMiner):
+        def mine(self, dataset: Dataset) -> Dataset:
+            return super().mine(dataset)
+
+    miner = ConcreteHardNegativeMiner(build_hard_negative_miner_settings())
+
+    with pytest.raises(NotImplementedError):
+        miner.mine(Dataset.from_list([]))
embed_train-3.2.0/tests/unit/test_train/test_hard_negatives.py (new file)
@@ -0,0 +1,117 @@
+from __future__ import annotations
+
+from typing import Any
+
+from datasets import Dataset
+
+from embed_train.train.dataset.hard_negatives import SentenceTransformerHardNegativeMiner
+from tests.fixtures.components import build_hard_negative_miner_settings
+
+
+class DummyTransformer:
+    calls: list[dict[str, Any]] = []
+
+    def __init__(self, **kwargs: Any) -> None:
+        self.calls.append(kwargs)
+
+    def get_embedding_dimension(self) -> int:
+        return 8
+
+
+class DummyPooling:
+    calls: list[dict[str, Any]] = []
+
+    def __init__(self, embedding_dimension: int, *, pooling_mode: str) -> None:
+        self.calls.append({"embedding_dimension": embedding_dimension, "pooling_mode": pooling_mode})
+
+
+class DummySentenceTransformer:
+    calls: list[list[object]] = []
+
+    def __init__(self, *, modules: list[object]) -> None:
+        self.calls.append(modules)
+
+
+class DummyCrossEncoder:
+    calls: list[dict[str, Any]] = []
+
+    def __init__(self, model_name_or_path: str, *, trust_remote_code: bool) -> None:
+        self.calls.append(
+            {
+                "model_name_or_path": model_name_or_path,
+                "trust_remote_code": trust_remote_code,
+            }
+        )
+
+
+def test_sentence_transformer_hard_negative_miner_mines_with_config(
+    monkeypatch,
+) -> None:
+    captured: dict[str, Any] = {}
+    cleanup_calls: list[str] = []
+    mined = Dataset.from_list([{"query": "q1", "positive": "p1", "negative": "n1"}])
+
+    def fake_mine_hard_negatives(**kwargs: Any) -> Dataset:
+        captured.update(kwargs)
+        return mined
+
+    class DummyCuda:
+        @staticmethod
+        def is_available() -> bool:
+            return True
+
+        @staticmethod
+        def memory_allocated() -> int:
+            cleanup_calls.append("allocated")
+            return 1024**3
+
+        @staticmethod
+        def memory_reserved() -> int:
+            cleanup_calls.append("reserved")
+            return 2 * 1024**3
+
+        @staticmethod
+        def empty_cache() -> None:
+            cleanup_calls.append("empty_cache")
+
+        @staticmethod
+        def ipc_collect() -> None:
+            cleanup_calls.append("ipc_collect")
+
+    monkeypatch.setattr("embed_train.train.dataset.hard_negatives.Transformer", DummyTransformer)
+    monkeypatch.setattr("embed_train.train.dataset.hard_negatives.Pooling", DummyPooling)
+    monkeypatch.setattr("embed_train.train.dataset.hard_negatives.SentenceTransformer", DummySentenceTransformer)
+    monkeypatch.setattr("embed_train.train.dataset.hard_negatives.CrossEncoder", DummyCrossEncoder)
+    monkeypatch.setattr("embed_train.train.dataset.hard_negatives.mine_hard_negatives", fake_mine_hard_negatives)
+    monkeypatch.setattr("embed_train.train.dataset.hard_negatives.torch.cuda", DummyCuda)
+
+    config = build_hard_negative_miner_settings(
+        cross_encoder_model_name_or_path="dummy-reranker",
+        trust_remote_code=True,
+    )
+    miner = SentenceTransformerHardNegativeMiner(config)
+
+    result = miner.mine(Dataset.from_list([{"query": "q1", "positive": "p1"}]))
+
+    assert result is mined
+    assert DummyTransformer.calls[-1]["model_name_or_path"] == "dummy-miner"
+    assert DummyTransformer.calls[-1]["processor_kwargs"]["model_max_length"] == config.tokenizer.max_length
+    assert DummyPooling.calls[-1] == {"embedding_dimension": 8, "pooling_mode": "mean"}
+    assert len(DummySentenceTransformer.calls[-1]) == 2
+    assert DummyCrossEncoder.calls[-1] == {
+        "model_name_or_path": "dummy-reranker",
+        "trust_remote_code": True,
+    }
+    assert captured["anchor_column_name"] == "query"
+    assert captured["positive_column_name"] == "positive"
+    assert captured["cross_encoder"] is not None
+    assert captured["num_negatives"] == 2
+    assert captured["sampling_strategy"] == "top"
+    assert "empty_cache" in cleanup_calls
+    assert "ipc_collect" in cleanup_calls
+
+
+def test_sentence_transformer_hard_negative_miner_skips_cross_encoder() -> None:
+    miner = SentenceTransformerHardNegativeMiner(build_hard_negative_miner_settings())
+
+    assert miner._load_cross_encoder() is None
{embed_train-3.1.0 → embed_train-3.2.0}/tests/unit/test_train/test_loss.py
@@ -58,3 +58,26 @@ def test_hard_negative_contrastive_loss_rejects_invalid_candidate_layout() -> No
 
     with pytest.raises(EmbedTrainValueError, match="one positive and at least one negative"):
         loss(q_emb, c_emb)
+
+
+def test_hard_negative_contrastive_loss_returns_scalar_for_valid_layout() -> None:
+    loss = HardNegativeContrastiveLoss(
+        HardNegativeContrastiveLossSettings(
+            module_path="embed_train.train.trainers.torch.loss.HardNegativeContrastiveLoss",
+            temperature=0.5,
+        )
+    )
+    q_emb = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
+    c_emb = torch.tensor(
+        [
+            [1.0, 0.0],
+            [0.0, 1.0],
+            [0.0, 1.0],
+            [1.0, 0.0],
+        ]
+    )
+
+    result = loss(q_emb, c_emb)
+
+    assert result.ndim == 0
+    assert result.item() >= 0
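The new test feeds 2 query embeddings and 4 candidate embeddings, i.e. one positive plus one negative per query. A rough sketch of an InfoNCE-style loss over that layout, for orientation only (this assumes one plausible reading of the candidate ordering, with each query's positive listed first; the package's actual `HardNegativeContrastiveLoss` is not shown in this diff):

import torch
import torch.nn.functional as F


def hard_negative_contrastive_loss(q_emb: torch.Tensor, c_emb: torch.Tensor, temperature: float = 0.5) -> torch.Tensor:
    batch, dim = q_emb.shape
    candidates = c_emb.view(batch, -1, dim)             # (batch, 1 positive + N negatives, dim)
    sims = torch.einsum("bd,bcd->bc", q_emb, candidates) / temperature
    targets = torch.zeros(batch, dtype=torch.long)      # the positive is candidate 0 for every query
    return F.cross_entropy(sims, targets)


q = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
c = torch.tensor([[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]])
print(hard_negative_contrastive_loss(q, c))  # 0-dim tensor, value >= 0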
embed_train-3.2.0/tests/unit/test_train/test_torch_datasets.py (new file)
@@ -0,0 +1,83 @@
+from __future__ import annotations
+
+import pytest
+from datasets import Dataset
+
+from tests.fixtures.components import (
+    DummyDatasetConnector,
+    build_hard_negative_torch_dataset_settings,
+    build_multi_positive_torch_dataset_settings,
+    build_torch_dataset_settings,
+)
+from tests.fixtures.data import build_query_rows
+
+
+def test_query_multi_positive_dataset_groups_rows() -> None:
+    from embed_train.train.dataset.torch_datasets import QueryMultiPositiveDataset
+
+    DummyDatasetConnector.rows = build_query_rows()
+    dataset = QueryMultiPositiveDataset(build_multi_positive_torch_dataset_settings())
+
+    assert len(dataset) == 2
+    grouped = {item["query"]: item["positives"] for item in (dataset[i] for i in range(len(dataset)))}
+    assert grouped["query-a"] == ["doc-1", "doc-2"]
+    assert grouped["query-b"] == ["doc-3"]
+
+
+def test_query_positive_dataset_flattens_rows() -> None:
+    from embed_train.train.dataset.torch_datasets import QueryPositiveDataset
+
+    DummyDatasetConnector.rows = build_query_rows()
+    dataset = QueryPositiveDataset(build_torch_dataset_settings())
+
+    assert len(dataset) == 3
+    rows = {tuple(item.values()) for item in (dataset[i] for i in range(len(dataset)))}
+    assert ("query-a", "doc-1") in rows
+    assert ("query-a", "doc-2") in rows
+    assert ("query-b", "doc-3") in rows
+
+
+def test_query_positive_dataset_rejects_empty_positive() -> None:
+    from embed_train.train.dataset.torch_datasets import QueryPositiveDataset
+
+    DummyDatasetConnector.rows = [
+        {
+            "page_content": "",
+            "metadata": {"query": "query-a", "step_range": (0, 1), "section": "alpha"},
+        }
+    ]
+    dataset = QueryPositiveDataset(build_torch_dataset_settings())
+
+    with pytest.raises(ValueError, match="No positive passage found for index 0"):
+        dataset[0]
+
+
+class DummyHardNegativeMiner:
+    seen_dataset: Dataset | None = None
+
+    @classmethod
+    def from_config(cls, _config: object) -> DummyHardNegativeMiner:
+        return cls()
+
+    def mine(self, dataset: Dataset) -> Dataset:
+        self.__class__.seen_dataset = dataset
+        return Dataset.from_list(
+            [
+                {"query": "query-a", "positive": "doc-1", "negative": "doc-2"},
+                {"query": "query-b", "positive": "doc-3", "negative": "doc-1"},
+            ]
+        )
+
+
+def test_hard_negative_dataset_mines_and_exposes_rows(monkeypatch: pytest.MonkeyPatch) -> None:
+    from embed_train.train.dataset.torch_datasets import HardNegativeDataset
+
+    DummyDatasetConnector.rows = build_query_rows()
+    monkeypatch.setattr("embed_train.train.dataset.torch_datasets.load_class", lambda _: DummyHardNegativeMiner)
+
+    dataset = HardNegativeDataset(build_hard_negative_torch_dataset_settings())
+
+    assert len(dataset) == 2
+    assert dataset[0] == {"query": "query-a", "positive": "doc-1", "negative": "doc-2"}
+    assert DummyHardNegativeMiner.seen_dataset is not None
+    assert DummyHardNegativeMiner.seen_dataset[0] == {"query": "query-a", "positive": "doc-1"}
{embed_train-3.1.0 → embed_train-3.2.0}/uv.lock
@@ -110,15 +110,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/fb/76/641ae371508676492379f16e2fa48f4e2c11741bd63c48be4b12a6b09cba/aiosignal-1.4.0-py3-none-any.whl", hash = "sha256:053243f8b92b990551949e63930a839ff0cf0b0ebbe0597b0f3fb19e1a0fe82e", size = 7490, upload-time = "2025-07-03T22:54:42.156Z" },
 ]
 
-[[package]]
-name = "annotated-doc"
-version = "0.0.4"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/57/ba/046ceea27344560984e26a590f90bc7f4a75b06701f653222458922b558c/annotated_doc-0.0.4.tar.gz", hash = "sha256:fbcda96e87e9c92ad167c2e53839e57503ecfda18804ea28102353485033faa4", size = 7288, upload-time = "2025-11-10T22:07:42.062Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/1e/d3/26bf1008eb3d2daa8ef4cacc7f3bfdc11818d111f7e2d0201bc6e3b49d45/annotated_doc-0.0.4-py3-none-any.whl", hash = "sha256:571ac1dc6991c450b25a9c2d84a3705e2ae7a53467b5d111c24fa8baabbed320", size = 5303, upload-time = "2025-11-10T22:07:40.673Z" },
-]
-
 [[package]]
 name = "annotated-types"
 version = "0.7.0"
@@ -284,18 +275,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/db/8f/61959034484a4a7c527811f4721e75d02d653a35afb0b6054474d8185d4c/charset_normalizer-3.4.7-py3-none-any.whl", hash = "sha256:3dce51d0f5e7951f8bb4900c257dad282f49190fdbebecd4ba99bcc41fef404d", size = 61958, upload-time = "2026-04-02T09:28:37.794Z" },
 ]
 
-[[package]]
-name = "click"
-version = "8.3.2"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
-    { name = "colorama", marker = "sys_platform == 'win32'" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/57/75/31212c6bf2503fdf920d87fee5d7a86a2e3bcf444984126f13d8e4016804/click-8.3.2.tar.gz", hash = "sha256:14162b8b3b3550a7d479eafa77dfd3c38d9dc8951f6f69c78913a8f9a7540fd5", size = 302856, upload-time = "2026-04-03T19:14:45.118Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/e4/20/71885d8b97d4f3dde17b1fdb92dbd4908b00541c5a3379787137285f602e/click-8.3.2-py3-none-any.whl", hash = "sha256:1924d2c27c5653561cd2cae4548d1406039cb79b858b747cfea24924bbc1616d", size = 108379, upload-time = "2026-04-03T19:14:43.505Z" },
-]
-
 [[package]]
 name = "colorama"
 version = "0.4.6"
@@ -459,7 +438,7 @@ wheels = [
 
 [[package]]
 name = "embed-train"
-version = "3.1.0"
+version = "3.2.0"
 source = { editable = "." }
 dependencies = [
     { name = "accelerate" },
@@ -468,6 +447,7 @@ dependencies = [
     { name = "sentence-transformers" },
     { name = "tensorboard" },
     { name = "torch" },
+    { name = "transformers" },
 ]
 
 [package.dev-dependencies]
@@ -484,12 +464,13 @@ dev = [
 
 [package.metadata]
 requires-dist = [
-    { name = "accelerate", specifier = "
-    { name = "datasets", specifier = "
+    { name = "accelerate", specifier = "==1.13.0" },
+    { name = "datasets", specifier = "==4.8.4" },
     { name = "retrievalbase", specifier = ">=2.1.0,<3.0.0" },
-    { name = "sentence-transformers", specifier = "
-    { name = "tensorboard", specifier = "
-    { name = "torch", specifier = "
+    { name = "sentence-transformers", specifier = "==5.4.1" },
+    { name = "tensorboard", specifier = "==2.20.0" },
+    { name = "torch", specifier = "==2.11.0" },
+    { name = "transformers", specifier = "==4.57.6" },
 ]
 
 [package.metadata.requires-dev]
@@ -702,22 +683,21 @@ http2 = [
 
 [[package]]
 name = "huggingface-hub"
-version = "
+version = "0.36.2"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "filelock" },
     { name = "fsspec" },
-    { name = "hf-xet", marker = "platform_machine == '
-    { name = "httpx" },
+    { name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" },
     { name = "packaging" },
     { name = "pyyaml" },
+    { name = "requests" },
     { name = "tqdm" },
-    { name = "typer" },
     { name = "typing-extensions" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/
+sdist = { url = "https://files.pythonhosted.org/packages/7c/b7/8cb61d2eece5fb05a83271da168186721c450eb74e3c31f7ef3169fa475b/huggingface_hub-0.36.2.tar.gz", hash = "sha256:1934304d2fb224f8afa3b87007d58501acfda9215b334eed53072dd5e815ff7a", size = 649782, upload-time = "2026-02-06T09:24:13.098Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/
+    { url = "https://files.pythonhosted.org/packages/a8/af/48ac8483240de756d2438c380746e7130d1c6f75802ef22f3c6d49982787/huggingface_hub-0.36.2-py3-none-any.whl", hash = "sha256:48f0c8eac16145dfce371e9d2d7772854a4f591bcb56c9cf548accf531d54270", size = 566395, upload-time = "2026-02-06T09:24:11.133Z" },
 ]
 
 [[package]]
@@ -2085,15 +2065,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/e1/e3/c164c88b2e5ce7b24d667b9bd83589cf4f3520d97cad01534cd3c4f55fdb/setuptools-81.0.0-py3-none-any.whl", hash = "sha256:fdd925d5c5d9f62e4b74b30d6dd7828ce236fd6ed998a08d81de62ce5a6310d6", size = 1062021, upload-time = "2026-02-06T21:10:37.175Z" },
 ]
 
-[[package]]
-name = "shellingham"
-version = "1.5.4"
-source = { registry = "https://pypi.org/simple" }
-sdist = { url = "https://files.pythonhosted.org/packages/58/15/8b3609fd3830ef7b27b655beb4b4e9c62313a4e8da8c676e142cc210d58e/shellingham-1.5.4.tar.gz", hash = "sha256:8dbca0739d487e5bd35ab3ca4b36e11c4078f3a234bfce294b0a0291363404de", size = 10310, upload-time = "2023-10-24T04:13:40.426Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/e0/f9/0595336914c5619e5f28a1fb793285925a8cd4b432c9da0a987836c7f822/shellingham-1.5.4-py2.py3-none-any.whl", hash = "sha256:7ecfff8f2fd72616f7481040475a65b2bf8af90a56c89140852d1120324e8686", size = 9755, upload-time = "2023-10-24T04:13:38.866Z" },
-]
-
 [[package]]
 name = "six"
 version = "1.17.0"
@@ -2279,22 +2250,23 @@ wheels = [
 
 [[package]]
 name = "transformers"
-version = "
+version = "4.57.6"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
+    { name = "filelock" },
     { name = "huggingface-hub" },
     { name = "numpy" },
     { name = "packaging" },
     { name = "pyyaml" },
     { name = "regex" },
+    { name = "requests" },
    { name = "safetensors" },
     { name = "tokenizers" },
     { name = "tqdm" },
-    { name = "typer" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/
+sdist = { url = "https://files.pythonhosted.org/packages/c4/35/67252acc1b929dc88b6602e8c4a982e64f31e733b804c14bc24b47da35e6/transformers-4.57.6.tar.gz", hash = "sha256:55e44126ece9dc0a291521b7e5492b572e6ef2766338a610b9ab5afbb70689d3", size = 10134912, upload-time = "2026-01-16T10:38:39.284Z" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/
+    { url = "https://files.pythonhosted.org/packages/03/b8/e484ef633af3887baeeb4b6ad12743363af7cce68ae51e938e00aaa0529d/transformers-4.57.6-py3-none-any.whl", hash = "sha256:4c9e9de11333ddfe5114bc872c9f370509198acf0b87a832a0ab9458e2bd0550", size = 11993498, upload-time = "2026-01-16T10:38:31.289Z" },
 ]
 
 [[package]]
@@ -2332,21 +2304,6 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/88/39/bca669095ccf0a400af941fdf741578d4c2d6719f1b7f10e6dbec10aa862/ty-0.0.31-py3-none-win_arm64.whl", hash = "sha256:e9cb15fad26545c6a608f40f227af3a5513cb376998ca6feddd47ca7d93ffafa", size = 10590392, upload-time = "2026-04-15T15:47:57.968Z" },
 ]
 
-[[package]]
-name = "typer"
-version = "0.24.1"
-source = { registry = "https://pypi.org/simple" }
-dependencies = [
-    { name = "annotated-doc" },
-    { name = "click" },
-    { name = "rich" },
-    { name = "shellingham" },
-]
-sdist = { url = "https://files.pythonhosted.org/packages/f5/24/cb09efec5cc954f7f9b930bf8279447d24618bb6758d4f6adf2574c41780/typer-0.24.1.tar.gz", hash = "sha256:e39b4732d65fbdcde189ae76cf7cd48aeae72919dea1fdfc16593be016256b45", size = 118613, upload-time = "2026-02-21T16:54:40.609Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/4a/91/48db081e7a63bb37284f9fbcefda7c44c277b18b0e13fbc36ea2335b71e6/typer-0.24.1-py3-none-any.whl", hash = "sha256:112c1f0ce578bfb4cab9ffdabc68f031416ebcc216536611ba21f04e9aa84c9e", size = 56085, upload-time = "2026-02-21T16:54:41.616Z" },
-]
-
 [[package]]
 name = "types-pyyaml"
 version = "6.0.12.20260408"
embed_train-3.1.0/tests/unit/test_train/test_torch_datasets.py (removed)
@@ -1,33 +0,0 @@
-from __future__ import annotations
-
-from tests.fixtures.components import (
-    DummyDatasetConnector,
-    build_multi_positive_torch_dataset_settings,
-    build_torch_dataset_settings,
-)
-from tests.fixtures.data import build_query_rows
-
-
-def test_query_multi_positive_dataset_groups_rows() -> None:
-    from embed_train.train.dataset.torch_datasets import QueryMultiPositiveDataset
-
-    DummyDatasetConnector.rows = build_query_rows()
-    dataset = QueryMultiPositiveDataset(build_multi_positive_torch_dataset_settings())
-
-    assert len(dataset) == 2
-    grouped = {item["query"]: item["positives"] for item in (dataset[i] for i in range(len(dataset)))}
-    assert grouped["query-a"] == ["doc-1", "doc-2"]
-    assert grouped["query-b"] == ["doc-3"]
-
-
-def test_query_positive_dataset_flattens_rows() -> None:
-    from embed_train.train.dataset.torch_datasets import QueryPositiveDataset
-
-    DummyDatasetConnector.rows = build_query_rows()
-    dataset = QueryPositiveDataset(build_torch_dataset_settings())
-
-    assert len(dataset) == 3
-    rows = {tuple(item.values()) for item in (dataset[i] for i in range(len(dataset)))}
-    assert ("query-a", "doc-1") in rows
-    assert ("query-a", "doc-2") in rows
-    assert ("query-b", "doc-3") in rows

All other files are renamed from embed_train-3.1.0 to embed_train-3.2.0 without content changes, as listed at the top of this diff.