pylate 0.0.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pylate-0.0.1/LICENSE +21 -0
- pylate-0.0.1/PKG-INFO +292 -0
- pylate-0.0.1/README.md +248 -0
- pylate-0.0.1/pylate/__init__.py +10 -0
- pylate-0.0.1/pylate/__version__.py +3 -0
- pylate-0.0.1/pylate/evaluation/__init__.py +11 -0
- pylate-0.0.1/pylate/evaluation/beir.py +218 -0
- pylate-0.0.1/pylate/evaluation/colbert_distillation.py +198 -0
- pylate-0.0.1/pylate/evaluation/colbert_triplet.py +259 -0
- pylate-0.0.1/pylate/indexes/__init__.py +3 -0
- pylate-0.0.1/pylate/indexes/base.py +36 -0
- pylate-0.0.1/pylate/indexes/voyager.py +354 -0
- pylate-0.0.1/pylate/losses/__init__.py +4 -0
- pylate-0.0.1/pylate/losses/contrastive.py +170 -0
- pylate-0.0.1/pylate/losses/distillation.py +125 -0
- pylate-0.0.1/pylate/models/Dense.py +102 -0
- pylate-0.0.1/pylate/models/__init__.py +4 -0
- pylate-0.0.1/pylate/models/colbert.py +1165 -0
- pylate-0.0.1/pylate/rank/__init__.py +3 -0
- pylate-0.0.1/pylate/rank/rank.py +138 -0
- pylate-0.0.1/pylate/retrieve/__init__.py +3 -0
- pylate-0.0.1/pylate/retrieve/colbert.py +144 -0
- pylate-0.0.1/pylate/scores/__init__.py +3 -0
- pylate-0.0.1/pylate/scores/scores.py +172 -0
- pylate-0.0.1/pylate/utils/__init__.py +15 -0
- pylate-0.0.1/pylate/utils/collator.py +104 -0
- pylate-0.0.1/pylate/utils/huggingface_models.py +73 -0
- pylate-0.0.1/pylate/utils/iter_batch.py +40 -0
- pylate-0.0.1/pylate/utils/multi_process.py +110 -0
- pylate-0.0.1/pylate/utils/processing.py +186 -0
- pylate-0.0.1/pylate/utils/tensor.py +69 -0
- pylate-0.0.1/pylate.egg-info/PKG-INFO +292 -0
- pylate-0.0.1/pylate.egg-info/SOURCES.txt +40 -0
- pylate-0.0.1/pylate.egg-info/dependency_links.txt +1 -0
- pylate-0.0.1/pylate.egg-info/requires.txt +34 -0
- pylate-0.0.1/pylate.egg-info/top_level.txt +1 -0
- pylate-0.0.1/setup.cfg +7 -0
- pylate-0.0.1/setup.py +53 -0
- pylate-0.0.1/tests/test_contrastive.py +86 -0
- pylate-0.0.1/tests/test_kd.py +69 -0
- pylate-0.0.1/tests/test_retriever.py +88 -0
pylate-0.0.1/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2024 LightOn
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
pylate-0.0.1/PKG-INFO
ADDED
|
@@ -0,0 +1,292 @@
|
|
|
1
|
+
Metadata-Version: 2.1
|
|
2
|
+
Name: pylate
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Home-page: https://github.com/lightonai/giga-cherche
|
|
5
|
+
Author: LightON
|
|
6
|
+
Classifier: Programming Language :: Python :: 3
|
|
7
|
+
Classifier: Operating System :: OS Independent
|
|
8
|
+
Requires-Python: >=3.8
|
|
9
|
+
Description-Content-Type: text/markdown
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
Requires-Dist: sentence-transformers>=3.0.1
|
|
12
|
+
Requires-Dist: datasets>=2.20.0
|
|
13
|
+
Requires-Dist: accelerate>=0.31.0
|
|
14
|
+
Requires-Dist: voyager>=2.0.9
|
|
15
|
+
Requires-Dist: sqlitedict>=2.1.0
|
|
16
|
+
Requires-Dist: pandas>=2.2.1
|
|
17
|
+
Provides-Extra: eval
|
|
18
|
+
Requires-Dist: sentence-transformers>=3.0.1; extra == "eval"
|
|
19
|
+
Requires-Dist: datasets>=2.20.0; extra == "eval"
|
|
20
|
+
Requires-Dist: accelerate>=0.31.0; extra == "eval"
|
|
21
|
+
Requires-Dist: voyager>=2.0.9; extra == "eval"
|
|
22
|
+
Requires-Dist: sqlitedict>=2.1.0; extra == "eval"
|
|
23
|
+
Requires-Dist: pandas>=2.2.1; extra == "eval"
|
|
24
|
+
Requires-Dist: ranx>=0.3.16; extra == "eval"
|
|
25
|
+
Requires-Dist: beir>=2.0.0; extra == "eval"
|
|
26
|
+
Provides-Extra: dev
|
|
27
|
+
Requires-Dist: sentence-transformers>=3.0.1; extra == "dev"
|
|
28
|
+
Requires-Dist: datasets>=2.20.0; extra == "dev"
|
|
29
|
+
Requires-Dist: accelerate>=0.31.0; extra == "dev"
|
|
30
|
+
Requires-Dist: voyager>=2.0.9; extra == "dev"
|
|
31
|
+
Requires-Dist: sqlitedict>=2.1.0; extra == "dev"
|
|
32
|
+
Requires-Dist: pandas>=2.2.1; extra == "dev"
|
|
33
|
+
Requires-Dist: ruff>=0.4.9; extra == "dev"
|
|
34
|
+
Requires-Dist: pytest-cov>=5.0.0; extra == "dev"
|
|
35
|
+
Requires-Dist: pytest>=8.2.1; extra == "dev"
|
|
36
|
+
Requires-Dist: pandas>=2.2.1; extra == "dev"
|
|
37
|
+
Requires-Dist: mkdocs-material==9.5.32; extra == "dev"
|
|
38
|
+
Requires-Dist: mkdocs-awesome-pages-plugin==2.9.3; extra == "dev"
|
|
39
|
+
Requires-Dist: mkdocs-jupyter==0.24.8; extra == "dev"
|
|
40
|
+
Requires-Dist: mkdocs_charts_plugin==0.0.10; extra == "dev"
|
|
41
|
+
Requires-Dist: numpydoc==1.8.0; extra == "dev"
|
|
42
|
+
Requires-Dist: ranx>=0.3.16; extra == "dev"
|
|
43
|
+
Requires-Dist: beir>=2.0.0; extra == "dev"
|
|
44
|
+
|
|
45
|
+
<div align="center">
|
|
46
|
+
<h1>PyLate</h1>
|
|
47
|
+
<p>Flexible Training and Retrieval for Late Interaction Models</p>
|
|
48
|
+
</div>
|
|
49
|
+
|
|
50
|
+
<p align="center"><img width=500 src="docs/img/logo.png"/></p>
|
|
51
|
+
|
|
52
|
+
<div align="center">
|
|
53
|
+
<!-- Documentation -->
|
|
54
|
+
<a href="https://lightonai.github.io/pylate/"><img src="https://img.shields.io/badge/Documentation-purple.svg?style=flat-square" alt="documentation"></a>
|
|
55
|
+
<!-- License -->
|
|
56
|
+
<a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-blue.svg?style=flat-square" alt="license"></a>
|
|
57
|
+
</div>
|
|
58
|
+
|
|
59
|
+
PyLate is a library built on top of Sentence Transformers, designed to simplify and optimize fine-tuning, inference, and retrieval with state-of-the-art ColBERT models. It enables easy fine-tuning on both single and multiple GPUs, providing flexibility for various hardware setups. PyLate also streamlines document retrieval and allows you to load a wide range of models, enabling you to construct ColBERT models from most pre-trained language models.
|
|
60
|
+
|
|
61
|
+
## Installation
|
|
62
|
+
|
|
63
|
+
You can install PyLate using pip:
|
|
64
|
+
|
|
65
|
+
```bash
|
|
66
|
+
pip install pylate
|
|
67
|
+
```
|
|
68
|
+
|
|
69
|
+
For evaluation dependencies, use:
|
|
70
|
+
|
|
71
|
+
```bash
|
|
72
|
+
pip install "pylate[eval]"
|
|
73
|
+
```
|
|
74
|
+
|
|
75
|
+
## Documentation
|
|
76
|
+
|
|
77
|
+
The complete documentation is available [here](https://lightonai.github.io/pylate/), which includes in-depth guides, examples, and API references.
|
|
78
|
+
|
|
79
|
+
## Datasets
|
|
80
|
+
|
|
81
|
+
PyLate supports Hugging Face [Datasets](https://huggingface.co/docs/datasets/en/index), enabling seamless triplet / knowledge distillation based training. Below is an example of creating a custom dataset for training:
|
|
82
|
+
|
|
83
|
+
```python
|
|
84
|
+
from datasets import Dataset
|
|
85
|
+
|
|
86
|
+
dataset = [
|
|
87
|
+
{
|
|
88
|
+
"query": "example query 1",
|
|
89
|
+
"positive": "example positive document 1",
|
|
90
|
+
"negative": "example negative document 1",
|
|
91
|
+
},
|
|
92
|
+
{
|
|
93
|
+
"query": "example query 2",
|
|
94
|
+
"positive": "example positive document 2",
|
|
95
|
+
"negative": "example negative document 2",
|
|
96
|
+
},
|
|
97
|
+
{
|
|
98
|
+
"query": "example query 3",
|
|
99
|
+
"positive": "example positive document 3",
|
|
100
|
+
"negative": "example negative document 3",
|
|
101
|
+
},
|
|
102
|
+
]
|
|
103
|
+
|
|
104
|
+
dataset = Dataset.from_list(mapping=dataset)
|
|
105
|
+
|
|
106
|
+
train_dataset, test_dataset = dataset.train_test_split(test_size=0.3).values()
|
|
107
|
+
```
|
|
108
|
+
|
|
109
|
+
## Training
|
|
110
|
+
|
|
111
|
+
Here’s a simple example of training a ColBERT model on the MSMARCO dataset using PyLate. This script demonstrates training with a contrastive loss on (query, positive, negative) triplets and evaluating the model on a held-out test set.
|
|
112
|
+
|
|
113
|
+
```python
|
|
114
|
+
from datasets import load_dataset
|
|
115
|
+
from sentence_transformers import (
|
|
116
|
+
SentenceTransformerTrainer,
|
|
117
|
+
SentenceTransformerTrainingArguments,
|
|
118
|
+
)
|
|
119
|
+
from sentence_transformers.training_args import BatchSamplers
|
|
120
|
+
|
|
121
|
+
from pylate import evaluation, losses, models, utils
|
|
122
|
+
|
|
123
|
+
# Define the model
|
|
124
|
+
model = models.ColBERT(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")
|
|
125
|
+
|
|
126
|
+
# Load dataset
|
|
127
|
+
dataset = load_dataset("sentence-transformers/msmarco-bm25", "triplet", split="train")
|
|
128
|
+
|
|
129
|
+
# Split the dataset to create a test set
|
|
130
|
+
train_dataset, eval_dataset = dataset.train_test_split(test_size=0.01).values()
|
|
131
|
+
|
|
132
|
+
# Shuffle and select a subset of the dataset for demonstration purposes
|
|
133
|
+
MAX_TRAIN_SIZE, MAX_EVAL_SIZE = 100, 100
|
|
134
|
+
train_dataset = train_dataset.shuffle(seed=21).select(range(MAX_TRAIN_SIZE))
|
|
135
|
+
eval_dataset = eval_dataset.shuffle(seed=21).select(range(MAX_EVAL_SIZE))
|
|
136
|
+
|
|
137
|
+
# Define the loss function
|
|
138
|
+
train_loss = losses.Contrastive(model=model)
|
|
139
|
+
|
|
140
|
+
args = SentenceTransformerTrainingArguments(
|
|
141
|
+
output_dir="colbert-training",
|
|
142
|
+
num_train_epochs=1,
|
|
143
|
+
per_device_train_batch_size=32,
|
|
144
|
+
per_device_eval_batch_size=32,
|
|
145
|
+
fp16=False, # Some GPUs support FP16 which is faster than FP32
|
|
146
|
+
bf16=False, # Some GPUs support BF16, which is a faster alternative to FP16
|
|
147
|
+
batch_sampler=BatchSamplers.NO_DUPLICATES,
|
|
148
|
+
# Tracking parameters:
|
|
149
|
+
eval_strategy="steps",
|
|
150
|
+
eval_steps=0.1,
|
|
151
|
+
save_strategy="steps",
|
|
152
|
+
save_steps=5000,
|
|
153
|
+
save_total_limit=2,
|
|
154
|
+
learning_rate=3e-6,
|
|
155
|
+
)
|
|
156
|
+
|
|
157
|
+
# Evaluation procedure
|
|
158
|
+
dev_evaluator = evaluation.ColBERTTripletEvaluator(
|
|
159
|
+
anchors=eval_dataset["query"],
|
|
160
|
+
positives=eval_dataset["positive"],
|
|
161
|
+
negatives=eval_dataset["negative"],
|
|
162
|
+
)
|
|
163
|
+
|
|
164
|
+
trainer = SentenceTransformerTrainer(
|
|
165
|
+
model=model,
|
|
166
|
+
args=args,
|
|
167
|
+
train_dataset=train_dataset,
|
|
168
|
+
eval_dataset=eval_dataset,
|
|
169
|
+
loss=train_loss,
|
|
170
|
+
evaluator=dev_evaluator,
|
|
171
|
+
data_collator=utils.ColBERTCollator(tokenize_fn=model.tokenize),
|
|
172
|
+
)
|
|
173
|
+
|
|
174
|
+
trainer.train()
|
|
175
|
+
|
|
176
|
+
model.save_pretrained("custom-colbert-model")
|
|
177
|
+
```
|
|
178
|
+
|
|
179
|
+
After training, the model can be loaded like this:
|
|
180
|
+
|
|
181
|
+
```python
|
|
182
|
+
from pylate import models
|
|
183
|
+
|
|
184
|
+
model = models.ColBERT(model_name_or_path="custom-colbert-model")
|
|
185
|
+
```
|
|
186
|
+
|
|
187
|
+
## Retrieve
|
|
188
|
+
|
|
189
|
+
PyLate allows easy retrieval of top documents for a given query set using the trained ColBERT model and Voyager index.
|
|
190
|
+
|
|
191
|
+
```python
|
|
192
|
+
from pylate import indexes, models, retrieve
|
|
193
|
+
|
|
194
|
+
model = models.ColBERT(
|
|
195
|
+
model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
|
|
196
|
+
)
|
|
197
|
+
|
|
198
|
+
index = indexes.Voyager(
|
|
199
|
+
index_folder="pylate-index",
|
|
200
|
+
index_name="index",
|
|
201
|
+
override=True,
|
|
202
|
+
)
|
|
203
|
+
|
|
204
|
+
retriever = retrieve.ColBERT(index=index)
|
|
205
|
+
```
|
|
206
|
+
|
|
207
|
+
Once the model and index are set up, we can add documents to the index:
|
|
208
|
+
|
|
209
|
+
```python
|
|
210
|
+
documents_ids = ["1", "2", "3"]
|
|
211
|
+
|
|
212
|
+
documents = [
|
|
213
|
+
"document 1 text", "document 2 text", "document 3 text"
|
|
214
|
+
]
|
|
215
|
+
|
|
216
|
+
# Encode the documents
|
|
217
|
+
documents_embeddings = model.encode(
|
|
218
|
+
documents,
|
|
219
|
+
batch_size=32,
|
|
220
|
+
is_query=False, # Encoding documents
|
|
221
|
+
show_progress_bar=True,
|
|
222
|
+
)
|
|
223
|
+
|
|
224
|
+
# Add the documents ids and embeddings to the Voyager index
|
|
225
|
+
index.add_documents(
|
|
226
|
+
documents_ids=documents_ids,
|
|
227
|
+
documents_embeddings=documents_embeddings,
|
|
228
|
+
)
|
|
229
|
+
```
|
|
230
|
+
|
|
231
|
+
Then we can retrieve the top-k documents for a given query set:
|
|
232
|
+
|
|
233
|
+
```python
|
|
234
|
+
queries_embeddings = model.encode(
|
|
235
|
+
["query for document 3", "query for document 1"],
|
|
236
|
+
batch_size=32,
|
|
237
|
+
is_query=True, # Encoding queries
|
|
238
|
+
show_progress_bar=True,
|
|
239
|
+
)
|
|
240
|
+
|
|
241
|
+
scores = retriever.retrieve(
|
|
242
|
+
queries_embeddings=queries_embeddings,
|
|
243
|
+
k=10,
|
|
244
|
+
)
|
|
245
|
+
|
|
246
|
+
print(scores)
|
|
247
|
+
```
|
|
248
|
+
|
|
249
|
+
Sample Output:
|
|
250
|
+
|
|
251
|
+
```python
|
|
252
|
+
[
|
|
253
|
+
[
|
|
254
|
+
{"id": "3", "score": 11.266985893249512},
|
|
255
|
+
{"id": "1", "score": 10.303335189819336},
|
|
256
|
+
{"id": "2", "score": 9.502392768859863},
|
|
257
|
+
],
|
|
258
|
+
[
|
|
259
|
+
{"id": "1", "score": 10.88800048828125},
|
|
260
|
+
{"id": "3", "score": 9.950843811035156},
|
|
261
|
+
{"id": "2", "score": 9.602447509765625},
|
|
262
|
+
],
|
|
263
|
+
]
|
|
264
|
+
```
|
|
265
|
+
|
|
266
|
+
## Contributing
|
|
267
|
+
|
|
268
|
+
We welcome contributions! To get started:
|
|
269
|
+
|
|
270
|
+
1. Install the development dependencies:
|
|
271
|
+
|
|
272
|
+
```bash
|
|
273
|
+
pip install "pylate[dev]"
|
|
274
|
+
```
|
|
275
|
+
|
|
276
|
+
2. Run tests:
|
|
277
|
+
|
|
278
|
+
```bash
|
|
279
|
+
make test
|
|
280
|
+
```
|
|
281
|
+
|
|
282
|
+
3. Format code with Ruff:
|
|
283
|
+
|
|
284
|
+
```bash
|
|
285
|
+
make ruff
|
|
286
|
+
```
|
|
287
|
+
|
|
288
|
+
4. Build the documentation:
|
|
289
|
+
|
|
290
|
+
```bash
|
|
291
|
+
make livedoc
|
|
292
|
+
```
|
pylate-0.0.1/README.md
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
1
|
+
<div align="center">
|
|
2
|
+
<h1>PyLate</h1>
|
|
3
|
+
<p>Flexible Training and Retrieval for Late Interaction Models</p>
|
|
4
|
+
</div>
|
|
5
|
+
|
|
6
|
+
<p align="center"><img width=500 src="docs/img/logo.png"/></p>
|
|
7
|
+
|
|
8
|
+
<div align="center">
|
|
9
|
+
<!-- Documentation -->
|
|
10
|
+
<a href="https://lightonai.github.io/pylate/"><img src="https://img.shields.io/badge/Documentation-purple.svg?style=flat-square" alt="documentation"></a>
|
|
11
|
+
<!-- License -->
|
|
12
|
+
<a href="https://opensource.org/licenses/MIT"><img src="https://img.shields.io/badge/License-MIT-blue.svg?style=flat-square" alt="license"></a>
|
|
13
|
+
</div>
|
|
14
|
+
|
|
15
|
+
PyLate is a library built on top of Sentence Transformers, designed to simplify and optimize fine-tuning, inference, and retrieval with state-of-the-art ColBERT models. It enables easy fine-tuning on both single and multiple GPUs, providing flexibility for various hardware setups. PyLate also streamlines document retrieval and allows you to load a wide range of models, enabling you to construct ColBERT models from most pre-trained language models.
|
|
16
|
+
|
|
17
|
+
## Installation
|
|
18
|
+
|
|
19
|
+
You can install PyLate using pip:
|
|
20
|
+
|
|
21
|
+
```bash
|
|
22
|
+
pip install pylate
|
|
23
|
+
```
|
|
24
|
+
|
|
25
|
+
For evaluation dependencies, use:
|
|
26
|
+
|
|
27
|
+
```bash
|
|
28
|
+
pip install "pylate[eval]"
|
|
29
|
+
```
|
|
30
|
+
|
|
31
|
+
## Documentation
|
|
32
|
+
|
|
33
|
+
The complete documentation is available [here](https://lightonai.github.io/pylate/), which includes in-depth guides, examples, and API references.
|
|
34
|
+
|
|
35
|
+
## Datasets
|
|
36
|
+
|
|
37
|
+
PyLate supports Hugging Face [Datasets](https://huggingface.co/docs/datasets/en/index), enabling seamless triplet / knowledge distillation based training. Below is an example of creating a custom dataset for training:
|
|
38
|
+
|
|
39
|
+
```python
|
|
40
|
+
from datasets import Dataset
|
|
41
|
+
|
|
42
|
+
dataset = [
|
|
43
|
+
{
|
|
44
|
+
"query": "example query 1",
|
|
45
|
+
"positive": "example positive document 1",
|
|
46
|
+
"negative": "example negative document 1",
|
|
47
|
+
},
|
|
48
|
+
{
|
|
49
|
+
"query": "example query 2",
|
|
50
|
+
"positive": "example positive document 2",
|
|
51
|
+
"negative": "example negative document 2",
|
|
52
|
+
},
|
|
53
|
+
{
|
|
54
|
+
"query": "example query 3",
|
|
55
|
+
"positive": "example positive document 3",
|
|
56
|
+
"negative": "example negative document 3",
|
|
57
|
+
},
|
|
58
|
+
]
|
|
59
|
+
|
|
60
|
+
dataset = Dataset.from_list(mapping=dataset)
|
|
61
|
+
|
|
62
|
+
train_dataset, test_dataset = dataset.train_test_split(test_size=0.3).values()
|
|
63
|
+
```
|
|
64
|
+
|
|
65
|
+
## Training
|
|
66
|
+
|
|
67
|
+
Here’s a simple example of training a ColBERT model on the MSMARCO dataset using PyLate. This script demonstrates training with a contrastive loss on (query, positive, negative) triplets and evaluating the model on a held-out test set.
|
|
68
|
+
|
|
69
|
+
```python
|
|
70
|
+
from datasets import load_dataset
|
|
71
|
+
from sentence_transformers import (
|
|
72
|
+
SentenceTransformerTrainer,
|
|
73
|
+
SentenceTransformerTrainingArguments,
|
|
74
|
+
)
|
|
75
|
+
from sentence_transformers.training_args import BatchSamplers
|
|
76
|
+
|
|
77
|
+
from pylate import evaluation, losses, models, utils
|
|
78
|
+
|
|
79
|
+
# Define the model
|
|
80
|
+
model = models.ColBERT(model_name_or_path="sentence-transformers/all-MiniLM-L6-v2")
|
|
81
|
+
|
|
82
|
+
# Load dataset
|
|
83
|
+
dataset = load_dataset("sentence-transformers/msmarco-bm25", "triplet", split="train")
|
|
84
|
+
|
|
85
|
+
# Split the dataset to create a test set
|
|
86
|
+
train_dataset, eval_dataset = dataset.train_test_split(test_size=0.01).values()
|
|
87
|
+
|
|
88
|
+
# Shuffle and select a subset of the dataset for demonstration purposes
|
|
89
|
+
MAX_TRAIN_SIZE, MAX_EVAL_SIZE = 100, 100
|
|
90
|
+
train_dataset = train_dataset.shuffle(seed=21).select(range(MAX_TRAIN_SIZE))
|
|
91
|
+
eval_dataset = eval_dataset.shuffle(seed=21).select(range(MAX_EVAL_SIZE))
|
|
92
|
+
|
|
93
|
+
# Define the loss function
|
|
94
|
+
train_loss = losses.Contrastive(model=model)
|
|
95
|
+
|
|
96
|
+
args = SentenceTransformerTrainingArguments(
|
|
97
|
+
output_dir="colbert-training",
|
|
98
|
+
num_train_epochs=1,
|
|
99
|
+
per_device_train_batch_size=32,
|
|
100
|
+
per_device_eval_batch_size=32,
|
|
101
|
+
fp16=False, # Some GPUs support FP16 which is faster than FP32
|
|
102
|
+
bf16=False, # Some GPUs support BF16, which is a faster alternative to FP16
|
|
103
|
+
batch_sampler=BatchSamplers.NO_DUPLICATES,
|
|
104
|
+
# Tracking parameters:
|
|
105
|
+
eval_strategy="steps",
|
|
106
|
+
eval_steps=0.1,
|
|
107
|
+
save_strategy="steps",
|
|
108
|
+
save_steps=5000,
|
|
109
|
+
save_total_limit=2,
|
|
110
|
+
learning_rate=3e-6,
|
|
111
|
+
)
|
|
112
|
+
|
|
113
|
+
# Evaluation procedure
|
|
114
|
+
dev_evaluator = evaluation.ColBERTTripletEvaluator(
|
|
115
|
+
anchors=eval_dataset["query"],
|
|
116
|
+
positives=eval_dataset["positive"],
|
|
117
|
+
negatives=eval_dataset["negative"],
|
|
118
|
+
)
|
|
119
|
+
|
|
120
|
+
trainer = SentenceTransformerTrainer(
|
|
121
|
+
model=model,
|
|
122
|
+
args=args,
|
|
123
|
+
train_dataset=train_dataset,
|
|
124
|
+
eval_dataset=eval_dataset,
|
|
125
|
+
loss=train_loss,
|
|
126
|
+
evaluator=dev_evaluator,
|
|
127
|
+
data_collator=utils.ColBERTCollator(tokenize_fn=model.tokenize),
|
|
128
|
+
)
|
|
129
|
+
|
|
130
|
+
trainer.train()
|
|
131
|
+
|
|
132
|
+
model.save_pretrained("custom-colbert-model")
|
|
133
|
+
```
|
|
134
|
+
|
|
135
|
+
After training, the model can be loaded like this:
|
|
136
|
+
|
|
137
|
+
```python
|
|
138
|
+
from pylate import models
|
|
139
|
+
|
|
140
|
+
model = models.ColBERT(model_name_or_path="custom-colbert-model")
|
|
141
|
+
```
|
|
142
|
+
|
|
143
|
+
## Retrieve
|
|
144
|
+
|
|
145
|
+
PyLate allows easy retrieval of top documents for a given query set using the trained ColBERT model and Voyager index.
|
|
146
|
+
|
|
147
|
+
```python
|
|
148
|
+
from pylate import indexes, models, retrieve
|
|
149
|
+
|
|
150
|
+
model = models.ColBERT(
|
|
151
|
+
model_name_or_path="sentence-transformers/all-MiniLM-L6-v2",
|
|
152
|
+
)
|
|
153
|
+
|
|
154
|
+
index = indexes.Voyager(
|
|
155
|
+
index_folder="pylate-index",
|
|
156
|
+
index_name="index",
|
|
157
|
+
override=True,
|
|
158
|
+
)
|
|
159
|
+
|
|
160
|
+
retriever = retrieve.ColBERT(index=index)
|
|
161
|
+
```
|
|
162
|
+
|
|
163
|
+
Once the model and index are set up, we can add documents to the index:
|
|
164
|
+
|
|
165
|
+
```python
|
|
166
|
+
documents_ids = ["1", "2", "3"]
|
|
167
|
+
|
|
168
|
+
documents = [
|
|
169
|
+
"document 1 text", "document 2 text", "document 3 text"
|
|
170
|
+
]
|
|
171
|
+
|
|
172
|
+
# Encode the documents
|
|
173
|
+
documents_embeddings = model.encode(
|
|
174
|
+
documents,
|
|
175
|
+
batch_size=32,
|
|
176
|
+
is_query=False, # Encoding documents
|
|
177
|
+
show_progress_bar=True,
|
|
178
|
+
)
|
|
179
|
+
|
|
180
|
+
# Add the documents ids and embeddings to the Voyager index
|
|
181
|
+
index.add_documents(
|
|
182
|
+
documents_ids=documents_ids,
|
|
183
|
+
documents_embeddings=documents_embeddings,
|
|
184
|
+
)
|
|
185
|
+
```
|
|
186
|
+
|
|
187
|
+
Then we can retrieve the top-k documents for a given query set:
|
|
188
|
+
|
|
189
|
+
```python
|
|
190
|
+
queries_embeddings = model.encode(
|
|
191
|
+
["query for document 3", "query for document 1"],
|
|
192
|
+
batch_size=32,
|
|
193
|
+
is_query=True, # Encoding queries
|
|
194
|
+
show_progress_bar=True,
|
|
195
|
+
)
|
|
196
|
+
|
|
197
|
+
scores = retriever.retrieve(
|
|
198
|
+
queries_embeddings=queries_embeddings,
|
|
199
|
+
k=10,
|
|
200
|
+
)
|
|
201
|
+
|
|
202
|
+
print(scores)
|
|
203
|
+
```
|
|
204
|
+
|
|
205
|
+
Sample Output:
|
|
206
|
+
|
|
207
|
+
```python
|
|
208
|
+
[
|
|
209
|
+
[
|
|
210
|
+
{"id": "3", "score": 11.266985893249512},
|
|
211
|
+
{"id": "1", "score": 10.303335189819336},
|
|
212
|
+
{"id": "2", "score": 9.502392768859863},
|
|
213
|
+
],
|
|
214
|
+
[
|
|
215
|
+
{"id": "1", "score": 10.88800048828125},
|
|
216
|
+
{"id": "3", "score": 9.950843811035156},
|
|
217
|
+
{"id": "2", "score": 9.602447509765625},
|
|
218
|
+
],
|
|
219
|
+
]
|
|
220
|
+
```
|
|
221
|
+
|
|
222
|
+
## Contributing
|
|
223
|
+
|
|
224
|
+
We welcome contributions! To get started:
|
|
225
|
+
|
|
226
|
+
1. Install the development dependencies:
|
|
227
|
+
|
|
228
|
+
```bash
|
|
229
|
+
pip install "pylate[dev]"
|
|
230
|
+
```
|
|
231
|
+
|
|
232
|
+
2. Run tests:
|
|
233
|
+
|
|
234
|
+
```bash
|
|
235
|
+
make test
|
|
236
|
+
```
|
|
237
|
+
|
|
238
|
+
3. Format code with Ruff:
|
|
239
|
+
|
|
240
|
+
```bash
|
|
241
|
+
make ruff
|
|
242
|
+
```
|
|
243
|
+
|
|
244
|
+
4. Build the documentation:
|
|
245
|
+
|
|
246
|
+
```bash
|
|
247
|
+
make livedoc
|
|
248
|
+
```
|
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
from .beir import evaluate, get_beir_triples, load_beir
|
|
2
|
+
from .colbert_distillation import ColBERTDistillationEvaluator
|
|
3
|
+
from .colbert_triplet import ColBERTTripletEvaluator
|
|
4
|
+
|
|
5
|
+
__all__ = [
|
|
6
|
+
"ColBERTTripletEvaluator",
|
|
7
|
+
"ColBERTDistillationEvaluator",
|
|
8
|
+
"get_beir_triples",
|
|
9
|
+
"load_beir",
|
|
10
|
+
"evaluate",
|
|
11
|
+
]
|