fastembed-bio 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fastembed/__init__.py +24 -0
- fastembed/bio/__init__.py +3 -0
- fastembed/bio/protein_embedding.py +456 -0
- fastembed/common/__init__.py +3 -0
- fastembed/common/model_description.py +52 -0
- fastembed/common/model_management.py +471 -0
- fastembed/common/onnx_model.py +188 -0
- fastembed/common/preprocessor_utils.py +84 -0
- fastembed/common/types.py +27 -0
- fastembed/common/utils.py +69 -0
- fastembed/embedding.py +24 -0
- fastembed/image/__init__.py +3 -0
- fastembed/image/image_embedding.py +135 -0
- fastembed/image/image_embedding_base.py +55 -0
- fastembed/image/onnx_embedding.py +217 -0
- fastembed/image/onnx_image_model.py +156 -0
- fastembed/image/transform/functional.py +221 -0
- fastembed/image/transform/operators.py +499 -0
- fastembed/late_interaction/__init__.py +5 -0
- fastembed/late_interaction/colbert.py +301 -0
- fastembed/late_interaction/jina_colbert.py +58 -0
- fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
- fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
- fastembed/late_interaction/token_embeddings.py +83 -0
- fastembed/late_interaction_multimodal/__init__.py +5 -0
- fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
- fastembed/late_interaction_multimodal/colpali.py +327 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
- fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
- fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
- fastembed/parallel_processor.py +253 -0
- fastembed/postprocess/__init__.py +3 -0
- fastembed/postprocess/muvera.py +362 -0
- fastembed/py.typed +1 -0
- fastembed/rerank/cross_encoder/__init__.py +3 -0
- fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
- fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
- fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
- fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
- fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
- fastembed/sparse/__init__.py +4 -0
- fastembed/sparse/bm25.py +359 -0
- fastembed/sparse/bm42.py +369 -0
- fastembed/sparse/minicoil.py +372 -0
- fastembed/sparse/sparse_embedding_base.py +90 -0
- fastembed/sparse/sparse_text_embedding.py +143 -0
- fastembed/sparse/splade_pp.py +196 -0
- fastembed/sparse/utils/minicoil_encoder.py +146 -0
- fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
- fastembed/sparse/utils/tokenizer.py +120 -0
- fastembed/sparse/utils/vocab_resolver.py +202 -0
- fastembed/text/__init__.py +3 -0
- fastembed/text/clip_embedding.py +56 -0
- fastembed/text/custom_text_embedding.py +97 -0
- fastembed/text/multitask_embedding.py +109 -0
- fastembed/text/onnx_embedding.py +353 -0
- fastembed/text/onnx_text_model.py +180 -0
- fastembed/text/pooled_embedding.py +136 -0
- fastembed/text/pooled_normalized_embedding.py +164 -0
- fastembed/text/text_embedding.py +228 -0
- fastembed/text/text_embedding_base.py +75 -0
- fastembed_bio-0.1.0.dist-info/METADATA +339 -0
- fastembed_bio-0.1.0.dist-info/RECORD +66 -0
- fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
- fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
- fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
|
@@ -0,0 +1,291 @@
|
|
|
1
|
+
import contextlib
|
|
2
|
+
import os
|
|
3
|
+
from multiprocessing import get_all_start_methods
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any, Iterable, Sequence, Type
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from PIL import Image
|
|
9
|
+
from tokenizers import Encoding, Tokenizer
|
|
10
|
+
|
|
11
|
+
from fastembed.common import OnnxProvider, ImageInput
|
|
12
|
+
from fastembed.common.onnx_model import EmbeddingWorker, OnnxModel, OnnxOutputContext, T
|
|
13
|
+
from fastembed.common.preprocessor_utils import load_tokenizer, load_preprocessor
|
|
14
|
+
from fastembed.common.types import NumpyArray, Device
|
|
15
|
+
from fastembed.common.utils import iter_batch
|
|
16
|
+
from fastembed.image.transform.operators import Compose
|
|
17
|
+
from fastembed.parallel_processor import ParallelWorkerPool
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class OnnxMultimodalModel(OnnxModel[T]):
    """Shared ONNX plumbing for models that embed both text and images.

    On top of what ``OnnxModel`` provides, this class keeps a tokenizer for
    text inputs and an image preprocessing pipeline, and offers in-process and
    multi-process embedding loops for both modalities.

    Subclasses must implement ``load_onnx_model``, the two worker-class hooks
    (``_get_text_worker_class`` / ``_get_image_worker_class``) and the two
    output hooks (``_post_process_onnx_text_output`` /
    ``_post_process_onnx_image_output``).
    """

    # Output names passed to ``self.model.run`` for text inference.
    ONNX_OUTPUT_NAMES: list[str] | None = None

    def __init__(self) -> None:
        super().__init__()
        # All three are populated by ``_load_onnx_model``.
        self.tokenizer: Tokenizer | None = None
        self.processor: Compose | None = None
        self.special_token_to_id: dict[str, int] = {}

    def _preprocess_onnx_text_input(
        self, onnx_input: dict[str, NumpyArray], **kwargs: Any
    ) -> dict[str, NumpyArray]:
        """Hook to adjust the text input feed before inference.

        The base implementation returns ``onnx_input`` unchanged.
        """
        return onnx_input

    def _preprocess_onnx_image_input(
        self, onnx_input: dict[str, NumpyArray], **kwargs: Any
    ) -> dict[str, NumpyArray]:
        """Hook to adjust the image input feed before inference.

        The base implementation returns ``onnx_input`` unchanged.
        """
        return onnx_input

    @classmethod
    def _get_text_worker_class(cls) -> Type["TextEmbeddingWorker[T]"]:
        """Return the worker class used for multi-process text embedding."""
        raise NotImplementedError("Subclasses must implement this method")

    @classmethod
    def _get_image_worker_class(cls) -> Type["ImageEmbeddingWorker[T]"]:
        """Return the worker class used for multi-process image embedding."""
        raise NotImplementedError("Subclasses must implement this method")

    def _post_process_onnx_image_output(self, output: OnnxOutputContext) -> Iterable[T]:
        """Turn the raw image inference output into the final embeddings."""
        raise NotImplementedError("Subclasses must implement this method")

    def _post_process_onnx_text_output(self, output: OnnxOutputContext) -> Iterable[T]:
        """Turn the raw text inference output into the final embeddings."""
        raise NotImplementedError("Subclasses must implement this method")

    def _load_onnx_model(
        self,
        model_dir: Path,
        model_file: str,
        threads: int | None,
        providers: Sequence[OnnxProvider] | None = None,
        cuda: bool | Device = Device.AUTO,
        device_id: int | None = None,
        extra_session_options: dict[str, Any] | None = None,
    ) -> None:
        """Load the ONNX model, then the tokenizer and image preprocessor.

        All arguments are forwarded unchanged to ``OnnxModel._load_onnx_model``;
        the tokenizer and the preprocessor are both loaded from ``model_dir``.
        """
        super()._load_onnx_model(
            model_dir=model_dir,
            model_file=model_file,
            threads=threads,
            providers=providers,
            cuda=cuda,
            device_id=device_id,
            extra_session_options=extra_session_options,
        )
        self.tokenizer, self.special_token_to_id = load_tokenizer(model_dir=model_dir)
        assert self.tokenizer is not None
        self.processor = load_preprocessor(model_dir=model_dir)

    def load_onnx_model(self) -> None:
        """Load the model into this instance; called lazily by the embed loops."""
        raise NotImplementedError("Subclasses must implement this method")

    def tokenize(self, documents: list[str], **kwargs: Any) -> list[Encoding]:
        """Encode ``documents`` with the loaded tokenizer.

        ``kwargs`` is accepted for signature compatibility and not used here.
        """
        return self.tokenizer.encode_batch(documents)  # type: ignore[union-attr]

    def onnx_embed_text(
        self,
        documents: list[str],
        **kwargs: Any,
    ) -> OnnxOutputContext:
        """Run text inference on one batch of documents.

        Args:
            documents: Batch of texts to embed.
            **kwargs: Forwarded to ``tokenize`` and
                ``_preprocess_onnx_text_input``.

        Returns:
            Context holding the first model output together with the attention
            mask and input ids that were fed to the model.
        """
        encoded = self.tokenize(documents, **kwargs)
        # Convert once, directly to the int64 dtype used for the input feed.
        input_ids = np.array([e.ids for e in encoded], dtype=np.int64)
        attention_mask = np.array([e.attention_mask for e in encoded], dtype=np.int64)

        # Only feed the optional inputs that the loaded model actually declares.
        input_names = {node.name for node in self.model.get_inputs()}  # type: ignore[union-attr]
        onnx_input: dict[str, NumpyArray] = {"input_ids": input_ids}
        if "attention_mask" in input_names:
            onnx_input["attention_mask"] = attention_mask
        if "token_type_ids" in input_names:
            onnx_input["token_type_ids"] = np.zeros_like(input_ids)

        onnx_input = self._preprocess_onnx_text_input(onnx_input, **kwargs)
        model_output = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)  # type: ignore[union-attr]
        return OnnxOutputContext(
            model_output=model_output[0],
            # Prefer the (possibly preprocessed) arrays that were actually fed.
            attention_mask=onnx_input.get("attention_mask", attention_mask),
            input_ids=onnx_input.get("input_ids", input_ids),
        )

    @staticmethod
    def _parallel_worker_params(
        model_name: str,
        cache_dir: str,
        providers: Sequence[OnnxProvider] | None,
        local_files_only: bool,
        specific_model_path: str | None,
        extra_session_options: dict[str, Any] | None,
        extra_kwargs: dict[str, Any],
    ) -> dict[str, Any]:
        """Build the keyword arguments handed to every worker process."""
        params: dict[str, Any] = {
            "model_name": model_name,
            "cache_dir": cache_dir,
            "providers": providers,
            "local_files_only": local_files_only,
            "specific_model_path": specific_model_path,
            **extra_kwargs,
        }
        # Session options are merged in flat, next to the other parameters.
        if extra_session_options is not None:
            params.update(extra_session_options)
        return params

    @staticmethod
    def _make_parallel_pool(
        worker: Type[EmbeddingWorker[Any]],
        parallel: int,
        cuda: bool | Device,
        device_ids: list[int] | None,
    ) -> ParallelWorkerPool:
        """Create the worker pool; ``parallel == 0`` means one worker per CPU."""
        if parallel == 0:
            # os.cpu_count() may return None; ``or 1`` below covers that case.
            parallel = os.cpu_count()  # type: ignore[assignment]

        start_method = "forkserver" if "forkserver" in get_all_start_methods() else "spawn"
        return ParallelWorkerPool(
            num_workers=parallel or 1,
            worker=worker,
            cuda=cuda,
            device_ids=device_ids,
            start_method=start_method,
        )

    def _embed_documents(
        self,
        model_name: str,
        cache_dir: str,
        documents: str | Iterable[str],
        batch_size: int = 256,
        parallel: int | None = None,
        providers: Sequence[OnnxProvider] | None = None,
        cuda: bool | Device = Device.AUTO,
        device_ids: list[int] | None = None,
        local_files_only: bool = False,
        specific_model_path: str | None = None,
        extra_session_options: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Iterable[T]:
        """Embed documents, in-process or through a pool of worker processes.

        Args:
            model_name: Passed to the workers in the parallel path.
            cache_dir: Passed to the workers in the parallel path.
            documents: A single text or an iterable of texts.
            batch_size: Batch size handed to ``iter_batch``.
            parallel: ``None`` runs in this process; ``0`` uses
                ``os.cpu_count()`` workers; any other value is the number of
                workers. Ignored when the input is a single string or a list
                shorter than ``batch_size``.
            providers: Passed to the workers in the parallel path.
            cuda: Passed to the worker pool.
            device_ids: Passed to the worker pool.
            local_files_only: Passed to the workers in the parallel path.
            specific_model_path: Passed to the workers in the parallel path.
            extra_session_options: Merged into the worker parameters.
            **kwargs: Extra worker parameters (parallel path only).

        Yields:
            Whatever ``_post_process_onnx_text_output`` yields for each batch.
        """
        is_small = False

        if isinstance(documents, str):
            documents = [documents]
            is_small = True

        if isinstance(documents, list) and len(documents) < batch_size:
            is_small = True

        if parallel is None or is_small:
            if not hasattr(self, "model") or self.model is None:
                self.load_onnx_model()
            for batch in iter_batch(documents, batch_size):
                yield from self._post_process_onnx_text_output(self.onnx_embed_text(batch))
            return

        params = self._parallel_worker_params(
            model_name,
            cache_dir,
            providers,
            local_files_only,
            specific_model_path,
            extra_session_options,
            kwargs,
        )
        pool = self._make_parallel_pool(
            self._get_text_worker_class(), parallel, cuda, device_ids
        )
        for batch in pool.ordered_map(iter_batch(documents, batch_size), **params):
            yield from self._post_process_onnx_text_output(batch)  # type: ignore

    def onnx_embed_image(self, images: list[ImageInput], **kwargs: Any) -> OnnxOutputContext:
        """Run image inference on one batch of images.

        Inputs that are not already ``PIL.Image.Image`` objects are opened with
        ``Image.open`` and closed again once inference is done.

        Args:
            images: Batch of images (PIL images or anything ``Image.open``
                accepts).
            **kwargs: Forwarded to ``_preprocess_onnx_image_input``.

        Returns:
            Context whose ``model_output`` is the first model output reshaped
            to one row per input image.
        """
        with contextlib.ExitStack() as stack:
            image_files = [
                stack.enter_context(Image.open(image))
                if not isinstance(image, Image.Image)
                else image
                for image in images
            ]
            assert self.processor is not None, "Processor is not initialized"
            encoded = np.array(self.processor(image_files))
        onnx_input = {"pixel_values": encoded}
        onnx_input = self._preprocess_onnx_image_input(onnx_input, **kwargs)
        model_output = self.model.run(None, onnx_input)  # type: ignore[union-attr]
        embeddings = model_output[0].reshape(len(images), -1)
        return OnnxOutputContext(model_output=embeddings)

    def _embed_images(
        self,
        model_name: str,
        cache_dir: str,
        images: Iterable[ImageInput] | ImageInput,
        batch_size: int = 256,
        parallel: int | None = None,
        providers: Sequence[OnnxProvider] | None = None,
        cuda: bool | Device = Device.AUTO,
        device_ids: list[int] | None = None,
        local_files_only: bool = False,
        specific_model_path: str | None = None,
        extra_session_options: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Iterable[T]:
        """Embed images, in-process or through a pool of worker processes.

        Mirrors ``_embed_documents``: a single ``str`` / ``Path`` / PIL image is
        wrapped in a list, and a single input or a list shorter than
        ``batch_size`` is always processed in this process regardless of
        ``parallel``.

        Yields:
            Whatever ``_post_process_onnx_image_output`` yields for each batch.
        """
        is_small = False

        if isinstance(images, (str, Path, Image.Image)):
            images = [images]
            is_small = True

        if isinstance(images, list) and len(images) < batch_size:
            is_small = True

        if parallel is None or is_small:
            if not hasattr(self, "model") or self.model is None:
                self.load_onnx_model()
            for batch in iter_batch(images, batch_size):
                yield from self._post_process_onnx_image_output(self.onnx_embed_image(batch))
            return

        params = self._parallel_worker_params(
            model_name,
            cache_dir,
            providers,
            local_files_only,
            specific_model_path,
            extra_session_options,
            kwargs,
        )
        pool = self._make_parallel_pool(
            self._get_image_worker_class(), parallel, cuda, device_ids
        )
        for batch in pool.ordered_map(iter_batch(images, batch_size), **params):
            yield from self._post_process_onnx_image_output(batch)  # type: ignore
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
class TextEmbeddingWorker(EmbeddingWorker[T]):
    """Worker that runs ``onnx_embed_text`` on batches received from the pool."""

    def __init__(self, model_name: str, cache_dir: str, **kwargs: Any):
        # Annotation only: declares the model type for type checkers.
        # NOTE(review): the attribute itself is presumably assigned by the
        # base-class initializer - confirm against EmbeddingWorker.
        self.model: OnnxMultimodalModel
        super().__init__(model_name, cache_dir, **kwargs)

    def init_embedding(
        self, model_name: str, cache_dir: str, **kwargs: Any
    ) -> OnnxMultimodalModel:
        """Create the model used by this worker; concrete workers override it."""
        raise NotImplementedError

    def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
        """Yield ``(index, raw text output)`` for every ``(index, batch)`` pair."""
        for position, documents in items:
            yield position, self.model.onnx_embed_text(documents)
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
class ImageEmbeddingWorker(EmbeddingWorker[T]):
    """Worker that runs ``onnx_embed_image`` on batches received from the pool."""

    def __init__(self, model_name: str, cache_dir: str, **kwargs: Any):
        # Annotation only: declares the model type for type checkers.
        # NOTE(review): the attribute itself is presumably assigned by the
        # base-class initializer - confirm against EmbeddingWorker.
        self.model: OnnxMultimodalModel
        super().__init__(model_name, cache_dir, **kwargs)

    def init_embedding(
        self, model_name: str, cache_dir: str, **kwargs: Any
    ) -> OnnxMultimodalModel:
        """Create the model used by this worker; concrete workers override it."""
        raise NotImplementedError

    def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
        """Yield ``(index, raw image output)`` for every ``(index, batch)`` pair."""
        for position, images in items:
            yield position, self.model.onnx_embed_image(images)
|
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
from collections import defaultdict
|
|
4
|
+
from copy import deepcopy
|
|
5
|
+
from enum import Enum
|
|
6
|
+
from multiprocessing import Queue, get_context
|
|
7
|
+
from multiprocessing.context import BaseContext
|
|
8
|
+
from multiprocessing.process import BaseProcess
|
|
9
|
+
from multiprocessing.sharedctypes import Synchronized as BaseValue
|
|
10
|
+
from queue import Empty
|
|
11
|
+
from typing import Any, Iterable, Type
|
|
12
|
+
|
|
13
|
+
from fastembed.common.types import Device
|
|
14
|
+
|
|
15
|
+
# Upper bound, in seconds, that the pool waits on the output queue for a
# single processed item before giving up.
processing_timeout = 10 * 60  # seconds

# Per-worker factor used to size the pool's input and output queues
# (queue size = number of workers * this value).
max_internal_batch_size = 200
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class QueueSignals(str, Enum):
    """Control messages sent through the worker queues alongside regular items."""

    # Put on the input queue once per worker to make it stop consuming.
    stop = "stop"
    # Not referenced by the code in this module.
    confirm = "confirm"
    # Put on the output queue by a worker whose processing raised an exception.
    error = "error"
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class Worker:
    """Interface for the objects that do the work inside a pool process."""

    @classmethod
    def start(cls, *args: Any, **kwargs: Any) -> "Worker":
        """Build a worker instance; called once inside each worker process."""
        raise NotImplementedError

    def process(self, items: Iterable[tuple[int, Any]]) -> Iterable[tuple[int, Any]]:
        """Consume ``(index, item)`` pairs and produce ``(index, result)`` pairs."""
        raise NotImplementedError
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _worker(
    worker_class: Type[Worker],
    input_queue: Queue,
    output_queue: Queue,
    num_active_workers: BaseValue,
    worker_id: int,
    kwargs: dict[str, Any] | None = None,
) -> None:
    """Entry point of a pool process.

    Builds a worker with ``worker_class.start(**kwargs)``, feeds it the items
    taken from ``input_queue`` until ``QueueSignals.stop`` is received, and
    puts every processed item on ``output_queue``. If anything raises, the
    exception is logged and ``QueueSignals.error`` is put on the output queue
    instead. In all cases both queues are closed and flushed, and
    ``num_active_workers`` is decremented, before the function returns.
    """
    kwargs = {} if kwargs is None else kwargs

    logging.info(
        f"Reader worker: {worker_id} PID: {os.getpid()} Device: {kwargs.get('device_id', 'CPU')}"
    )
    try:
        worker = worker_class.start(**kwargs)

        # Calls input_queue.get() repeatedly and stops at the first item equal
        # to the stop signal (which is consumed, not yielded).
        pending_items = iter(input_queue.get, QueueSignals.stop)

        for result in worker.process(pending_items):
            output_queue.put(result)
    except Exception as e:  # pylint: disable=broad-except
        logging.exception(e)
        output_queue.put(QueueSignals.error)
    finally:
        # The queues must be closed and their feeder threads joined *before*
        # num_active_workers is decremented. Otherwise the parent may join
        # this process before buffered items have reached the underlying pipe,
        # which results in a deadlock.
        #
        # See:
        # https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#pipes-and-queues
        # https://docs.python.org/3.6/library/multiprocessing.html?highlight=process#programming-guidelines
        queues = (input_queue, output_queue)
        for queue in queues:
            queue.close()
        for queue in queues:
            queue.join_thread()

        with num_active_workers.get_lock():
            num_active_workers.value -= 1

        logging.info(f"Reader worker {worker_id} finished")
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
class ParallelWorkerPool:
    """Pool of processes that maps a stream of items through a ``Worker``.

    Items are pushed to a shared input queue as ``(index, item)`` pairs, each
    process runs ``_worker`` with the configured worker class, and results are
    read back from a shared output queue.
    """

    def __init__(
        self,
        num_workers: int,
        worker: Type[Worker],
        start_method: str | None = None,
        device_ids: list[int] | None = None,
        cuda: bool | Device = Device.AUTO,
    ):
        """
        Args:
            num_workers: Number of worker processes to start.
            worker: Worker class instantiated inside every process.
            start_method: Multiprocessing start method passed to
                ``get_context``.
            device_ids: If given, workers are assigned these ids round-robin
                through the ``device_id`` worker parameter.
            cuda: Passed to the workers as ``cuda`` when ``device_ids`` is set.
        """
        self.worker_class = worker
        self.num_workers = num_workers
        self.input_queue: Queue | None = None
        self.output_queue: Queue | None = None
        self.ctx: BaseContext = get_context(start_method)
        self.processes: list[BaseProcess] = []
        self.queue_size = self.num_workers * max_internal_batch_size
        # When set, queue feeder threads are cancelled instead of joined.
        self.emergency_shutdown = False
        self.device_ids = device_ids
        self.cuda = cuda
        self.num_active_workers: BaseValue | None = None

    def start(self, **kwargs: Any) -> None:
        """Create the queues and start the worker processes.

        Args:
            **kwargs: Parameters handed to every worker; each process receives
                its own deep copy.
        """
        # Every run starts with a clean shutdown state.
        self.emergency_shutdown = False
        self.input_queue = self.ctx.Queue(self.queue_size)
        self.output_queue = self.ctx.Queue(self.queue_size)

        ctx_value = self.ctx.Value("i", self.num_workers)
        assert isinstance(ctx_value, BaseValue)
        self.num_active_workers = ctx_value

        for worker_id in range(0, self.num_workers):
            worker_kwargs = deepcopy(kwargs)
            if self.device_ids:
                # Spread the workers over the devices round-robin.
                device_id = self.device_ids[worker_id % len(self.device_ids)]
                worker_kwargs["device_id"] = device_id
                worker_kwargs["cuda"] = self.cuda

            assert hasattr(self.ctx, "Process")
            process = self.ctx.Process(
                target=_worker,
                args=(
                    self.worker_class,
                    self.input_queue,
                    self.output_queue,
                    self.num_active_workers,
                    worker_id,
                    worker_kwargs,
                ),
            )
            process.start()
            self.processes.append(process)

    def ordered_map(self, stream: Iterable[Any], *args: Any, **kwargs: Any) -> Iterable[Any]:
        """Map ``stream`` through the workers, yielding results in input order.

        Results arriving out of order are buffered until all their
        predecessors have been yielded.
        """
        buffer: dict[int, Any] = {}
        next_expected = 0

        for idx, item in self.semi_ordered_map(stream, *args, **kwargs):
            buffer[idx] = item
            while next_expected in buffer:
                yield buffer.pop(next_expected)
                next_expected += 1

    def semi_ordered_map(
        self, stream: Iterable[Any], *args: Any, **kwargs: Any
    ) -> Iterable[tuple[int, Any]]:
        """Map ``stream`` through the workers, yielding results as they arrive.

        Args:
            stream: Items to process; each is sent to the workers together
                with its position in the stream.
            *args: Unused.
            **kwargs: Worker parameters, forwarded to ``start``.

        Yields:
            The items the workers put on the output queue.

        Raises:
            RuntimeError: If a worker reports an error or dies with a non-zero
                exit code.
            queue.Empty: If no result arrives within ``processing_timeout``.
        """
        # Stays False unless every pushed item has been read back.
        completed = False
        try:
            self.start(**kwargs)

            assert self.input_queue is not None, "Input queue was not initialized"
            assert self.output_queue is not None, "Output queue was not initialized"

            pushed = 0
            read = 0
            for idx, item in enumerate(stream):
                self.check_worker_health()
                if pushed - read < self.queue_size:
                    # Room left for in-flight items: only poll for a result.
                    try:
                        out_item = self.output_queue.get_nowait()
                    except Empty:
                        out_item = None
                else:
                    # Too many items in flight: wait for a result first.
                    try:
                        out_item = self.output_queue.get(timeout=processing_timeout)
                    except Empty as e:
                        self.join_or_terminate()
                        raise e

                if out_item is not None:
                    if out_item == QueueSignals.error:
                        self.join_or_terminate()
                        raise RuntimeError("Thread unexpectedly terminated")
                    yield out_item
                    read += 1

                self.input_queue.put((idx, item))
                pushed += 1

            # One stop signal per worker.
            for _ in range(self.num_workers):
                self.input_queue.put(QueueSignals.stop)

            # Drain the results that are still outstanding.
            while read < pushed:
                self.check_worker_health()
                out_item = self.output_queue.get(timeout=processing_timeout)
                if out_item == QueueSignals.error:
                    self.join_or_terminate()
                    raise RuntimeError("Thread unexpectedly terminated")
                yield out_item
                read += 1

            completed = True
        finally:
            if completed:
                # Every worker has received its stop signal and will exit.
                self.join()
            else:
                # Error, timeout, or the consumer closed the generator early:
                # the workers may never see a stop signal, so a plain join()
                # could block forever. Terminate them instead, and do not wait
                # for queue feeder threads that may be unable to flush.
                self.emergency_shutdown = True
                self.join_or_terminate()

            # The queues may be missing if start() itself failed; do not let
            # that mask the original exception.
            queues = [q for q in (self.input_queue, self.output_queue) if q is not None]
            for queue in queues:
                queue.close()
            for queue in queues:
                if self.emergency_shutdown:
                    queue.cancel_join_thread()
                else:
                    queue.join_thread()

    def check_worker_health(self) -> None:
        """Raise if any worker process has exited with a non-zero exit code.

        Raises:
            RuntimeError: After shutting down the remaining workers.
        """
        for process in self.processes:
            if not process.is_alive() and process.exitcode != 0:
                self.emergency_shutdown = True
                self.join_or_terminate()
                raise RuntimeError(
                    f"Worker PID: {process.pid} terminated unexpectedly with code {process.exitcode}"
                )

    def join_or_terminate(self, timeout: int = 1) -> None:
        """Emergency shutdown of all worker processes.

        Args:
            timeout: Seconds to wait for each process to exit on its own
                before it is terminated.
        """
        for process in self.processes:
            process.join(timeout=timeout)
            if process.is_alive():
                process.terminate()
        self.processes.clear()

    def join(self) -> None:
        """Wait, without a timeout, for all worker processes to exit."""
        for process in self.processes:
            process.join()
        self.processes.clear()

    def __del__(self) -> None:
        """
        Terminate processes if the user hasn't joined. This is necessary as
        leaving stray processes running can corrupt shared state. In brief,
        we've observed shared memory counters being reused (when the memory was
        free from the perspective of the parent process) while the stray
        workers still held a reference to them.
        For a discussion of using destructors in Python in this manner, see
        https://eli.thegreenplace.net/2009/06/12/safely-using-destructors-in-python/.
        """
        # ``processes`` is missing when __init__ raised before assigning it.
        for process in getattr(self, "processes", ()):
            if process.is_alive():
                process.terminate()
|