fastembed-bio 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. fastembed/__init__.py +24 -0
  2. fastembed/bio/__init__.py +3 -0
  3. fastembed/bio/protein_embedding.py +456 -0
  4. fastembed/common/__init__.py +3 -0
  5. fastembed/common/model_description.py +52 -0
  6. fastembed/common/model_management.py +471 -0
  7. fastembed/common/onnx_model.py +188 -0
  8. fastembed/common/preprocessor_utils.py +84 -0
  9. fastembed/common/types.py +27 -0
  10. fastembed/common/utils.py +69 -0
  11. fastembed/embedding.py +24 -0
  12. fastembed/image/__init__.py +3 -0
  13. fastembed/image/image_embedding.py +135 -0
  14. fastembed/image/image_embedding_base.py +55 -0
  15. fastembed/image/onnx_embedding.py +217 -0
  16. fastembed/image/onnx_image_model.py +156 -0
  17. fastembed/image/transform/functional.py +221 -0
  18. fastembed/image/transform/operators.py +499 -0
  19. fastembed/late_interaction/__init__.py +5 -0
  20. fastembed/late_interaction/colbert.py +301 -0
  21. fastembed/late_interaction/jina_colbert.py +58 -0
  22. fastembed/late_interaction/late_interaction_embedding_base.py +80 -0
  23. fastembed/late_interaction/late_interaction_text_embedding.py +180 -0
  24. fastembed/late_interaction/token_embeddings.py +83 -0
  25. fastembed/late_interaction_multimodal/__init__.py +5 -0
  26. fastembed/late_interaction_multimodal/colmodernvbert.py +532 -0
  27. fastembed/late_interaction_multimodal/colpali.py +327 -0
  28. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding.py +189 -0
  29. fastembed/late_interaction_multimodal/late_interaction_multimodal_embedding_base.py +86 -0
  30. fastembed/late_interaction_multimodal/onnx_multimodal_model.py +291 -0
  31. fastembed/parallel_processor.py +253 -0
  32. fastembed/postprocess/__init__.py +3 -0
  33. fastembed/postprocess/muvera.py +362 -0
  34. fastembed/py.typed +1 -0
  35. fastembed/rerank/cross_encoder/__init__.py +3 -0
  36. fastembed/rerank/cross_encoder/custom_text_cross_encoder.py +47 -0
  37. fastembed/rerank/cross_encoder/onnx_text_cross_encoder.py +239 -0
  38. fastembed/rerank/cross_encoder/onnx_text_model.py +204 -0
  39. fastembed/rerank/cross_encoder/text_cross_encoder.py +178 -0
  40. fastembed/rerank/cross_encoder/text_cross_encoder_base.py +63 -0
  41. fastembed/sparse/__init__.py +4 -0
  42. fastembed/sparse/bm25.py +359 -0
  43. fastembed/sparse/bm42.py +369 -0
  44. fastembed/sparse/minicoil.py +372 -0
  45. fastembed/sparse/sparse_embedding_base.py +90 -0
  46. fastembed/sparse/sparse_text_embedding.py +143 -0
  47. fastembed/sparse/splade_pp.py +196 -0
  48. fastembed/sparse/utils/minicoil_encoder.py +146 -0
  49. fastembed/sparse/utils/sparse_vectors_converter.py +244 -0
  50. fastembed/sparse/utils/tokenizer.py +120 -0
  51. fastembed/sparse/utils/vocab_resolver.py +202 -0
  52. fastembed/text/__init__.py +3 -0
  53. fastembed/text/clip_embedding.py +56 -0
  54. fastembed/text/custom_text_embedding.py +97 -0
  55. fastembed/text/multitask_embedding.py +109 -0
  56. fastembed/text/onnx_embedding.py +353 -0
  57. fastembed/text/onnx_text_model.py +180 -0
  58. fastembed/text/pooled_embedding.py +136 -0
  59. fastembed/text/pooled_normalized_embedding.py +164 -0
  60. fastembed/text/text_embedding.py +228 -0
  61. fastembed/text/text_embedding_base.py +75 -0
  62. fastembed_bio-0.1.0.dist-info/METADATA +339 -0
  63. fastembed_bio-0.1.0.dist-info/RECORD +66 -0
  64. fastembed_bio-0.1.0.dist-info/WHEEL +4 -0
  65. fastembed_bio-0.1.0.dist-info/licenses/LICENSE +201 -0
  66. fastembed_bio-0.1.0.dist-info/licenses/NOTICE +22 -0
fastembed/__init__.py ADDED
@@ -0,0 +1,24 @@
1
+ import importlib.metadata
2
+
3
+ from fastembed.bio import ProteinEmbedding
4
+ from fastembed.image import ImageEmbedding
5
+ from fastembed.late_interaction import LateInteractionTextEmbedding
6
+ from fastembed.late_interaction_multimodal import LateInteractionMultimodalEmbedding
7
+ from fastembed.sparse import SparseEmbedding, SparseTextEmbedding
8
+ from fastembed.text import TextEmbedding
9
+
10
# Resolve the package version from installed distribution metadata.
# The lookup originally knew only "fastembed" and "fastembed-gpu"; this wheel
# is published as "fastembed-bio", so that name is tried as well. The original
# precedence is kept so existing installs report the same version as before.
for _distribution_name in ("fastembed", "fastembed-gpu", "fastembed-bio"):
    try:
        version = importlib.metadata.version(_distribution_name)
    except importlib.metadata.PackageNotFoundError:
        continue
    break
else:
    # No known distribution is installed: fail with the same exception type
    # the metadata lookup itself raises.
    raise importlib.metadata.PackageNotFoundError("fastembed-bio")

__version__ = version
__all__ = [
    "TextEmbedding",
    "SparseTextEmbedding",
    "SparseEmbedding",
    "ImageEmbedding",
    "LateInteractionTextEmbedding",
    "LateInteractionMultimodalEmbedding",
    "ProteinEmbedding",
]
fastembed/bio/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from fastembed.bio.protein_embedding import ProteinEmbedding
2
+
3
# Names exported by ``from fastembed.bio import *``.
__all__ = [
    "ProteinEmbedding",
]
fastembed/bio/protein_embedding.py ADDED
@@ -0,0 +1,456 @@
1
+ import json
2
+ from dataclasses import asdict
3
+ from pathlib import Path
4
+ from typing import Any, Iterable, Sequence, Type
5
+
6
+ import numpy as np
7
+ from tokenizers import Tokenizer, pre_tokenizers, processors
8
+ from tokenizers.models import WordLevel
9
+
10
+ from fastembed.common.model_description import DenseModelDescription, ModelSource
11
+ from fastembed.common.model_management import ModelManagement
12
+ from fastembed.common.onnx_model import OnnxModel, OnnxOutputContext, EmbeddingWorker
13
+ from fastembed.common.types import NumpyArray, OnnxProvider, Device
14
+ from fastembed.common.utils import define_cache_dir, iter_batch, normalize
15
+
16
+
17
# Descriptions of the protein models served by this module. Each entry names
# the model, its embedding dimension, the repository holding the ONNX export,
# and the extra files listed next to the model file.
supported_protein_models: list[DenseModelDescription] = [
    DenseModelDescription(
        model="facebook/esm2_t12_35M_UR50D",
        sources=ModelSource(hf="nleroy917/esm2_t12_35M_UR50D-onnx"),
        model_file="model.onnx",
        additional_files=[
            "tokenizer.json",
            "tokenizer_config.json",
            "special_tokens_map.json",
        ],
        dim=480,
        description="Protein embeddings, ESM-2 35M parameters, 480 dimensions, 1024 max sequence length",
        license="mit",
        size_in_GB=0.13,
    ),
]
29
+
30
+
31
def _read_model_max_length(tokenizer_config_path: Path, max_length: int) -> int:
    """Return the truncation length to use, honouring ``model_max_length`` when sane.

    The configured value is used only when it is a positive number no larger than
    ``max_length``; otherwise ``max_length`` is returned unchanged. This keeps
    oversized placeholder values out of ``enable_truncation``.
    """
    if not tokenizer_config_path.exists():
        return max_length

    with open(tokenizer_config_path, encoding="utf-8") as f:
        config = json.load(f)

    configured = config.get("model_max_length", max_length)
    # ``bool`` is an ``int`` subclass, so exclude it explicitly.
    if isinstance(configured, bool) or not isinstance(configured, (int, float)):
        return max_length
    if 0 < configured <= max_length:
        return int(configured)
    return max_length


def load_protein_tokenizer(model_dir: Path, max_length: int = 1024) -> Tokenizer:
    """Load a protein tokenizer from model directory using HuggingFace tokenizers.

    Attempts to load in order:
    1. tokenizer.json (standard HuggingFace fast tokenizer format)
    2. Build from vocab.txt (fallback for models without tokenizer.json)

    In both cases truncation is enabled, and ``model_max_length`` from
    tokenizer_config.json is honoured only when it does not exceed ``max_length``.
    Padding is enabled whenever the loaded tokenizer does not already configure
    it, so that a batch of sequences of different lengths encodes to rows of
    equal length.

    Args:
        model_dir: Path to model directory containing tokenizer files
        max_length: Upper bound for the truncation length

    Returns:
        Configured Tokenizer instance

    Raises:
        ValueError: If neither tokenizer.json nor vocab.txt exists in ``model_dir``.
    """
    tokenizer_json_path = model_dir / "tokenizer.json"
    tokenizer_config_path = model_dir / "tokenizer_config.json"
    vocab_path = model_dir / "vocab.txt"

    unk_token = "<unk>"
    cls_token = "<cls>"
    eos_token = "<eos>"
    pad_token = "<pad>"

    # Try to load tokenizer.json directly (preferred)
    if tokenizer_json_path.exists():
        tokenizer = Tokenizer.from_file(str(tokenizer_json_path))
        max_length = _read_model_max_length(tokenizer_config_path, max_length)
        tokenizer.enable_truncation(max_length=max_length)
        # A tokenizer.json may carry no padding section. Without padding,
        # sequences of different lengths cannot be stacked into one array,
        # so enable it here, mirroring the vocab.txt path below.
        if tokenizer.padding is None:
            pad_token_id = tokenizer.token_to_id(pad_token)
            if pad_token_id is not None:
                tokenizer.enable_padding(pad_id=pad_token_id, pad_token=pad_token)
        return tokenizer

    # Fall back to building from vocab.txt
    if not vocab_path.exists():
        raise ValueError(
            f"Could not find tokenizer.json or vocab.txt in {model_dir}"
        )

    # Same capping rule as the tokenizer.json path.
    max_length = _read_model_max_length(tokenizer_config_path, max_length)

    # One token per line; the line index is the token id.
    vocab: dict[str, int] = {}
    with open(vocab_path, encoding="utf-8") as f:
        for idx, line in enumerate(f):
            vocab[line.strip()] = idx

    tokenizer = Tokenizer(WordLevel(vocab=vocab, unk_token=unk_token))

    # NOTE(review): the empty pattern is presumably meant to isolate every
    # residue as its own token -- confirm against the tokenizers `Split` docs.
    tokenizer.pre_tokenizer = pre_tokenizers.Split(
        pattern="", behavior="isolated", invert=False
    )

    # Fall back to ids 0 / 2 / 1 when a special token is missing from vocab.txt.
    cls_token_id = vocab.get(cls_token, 0)
    eos_token_id = vocab.get(eos_token, 2)
    pad_token_id = vocab.get(pad_token, 1)

    # Wrap every sequence as "<cls> sequence <eos>".
    tokenizer.post_processor = processors.TemplateProcessing(
        single=f"{cls_token}:0 $A:0 {eos_token}:0",
        special_tokens=[
            (cls_token, cls_token_id),
            (eos_token, eos_token_id),
        ],
    )

    tokenizer.enable_padding(pad_id=pad_token_id, pad_token=pad_token)
    tokenizer.enable_truncation(max_length=max_length)

    return tokenizer
108
+
109
+
110
class ProteinEmbeddingBase(ModelManagement[DenseModelDescription]):
    """Shared constructor state and embedding-size lookup for protein embedders."""

    def __init__(
        self,
        model_name: str,
        cache_dir: str | None = None,
        threads: int | None = None,
        **kwargs: Any,
    ):
        # Filled in lazily by the ``embedding_size`` property.
        self._embedding_size: int | None = None
        # ``local_files_only`` is read from the keyword arguments, default False.
        self._local_files_only = kwargs.pop("local_files_only", False)
        self.threads = threads
        self.cache_dir = cache_dir
        self.model_name = model_name

    def embed(
        self,
        sequences: str | Iterable[str],
        batch_size: int = 32,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[NumpyArray]:
        """
        Embed protein sequences; subclasses must override this.

        Args:
            sequences: Single protein sequence or iterable of sequences
            batch_size: Batch size for encoding
            parallel: Number of parallel workers (None for single-threaded)

        Yields:
            Embeddings as numpy arrays

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    @classmethod
    def get_embedding_size(cls, model_name: str) -> int:
        """
        Look up the embedding dimension declared for a supported model.

        The model name is matched case-insensitively against the supported models.

        Args:
            model_name: Name of the model

        Raises:
            ValueError: If no supported model with a declared dimension matches.
        """
        size = next(
            (
                candidate.dim
                for candidate in cls._list_supported_models()
                if candidate.model.lower() == model_name.lower()
                and candidate.dim is not None
            ),
            None,
        )
        if size is None:
            raise ValueError(f"Model {model_name} not found")
        return size

    @property
    def embedding_size(self) -> int:
        """
        Embedding dimension of the current model, looked up once and then cached.
        """
        cached = self._embedding_size
        if cached is None:
            cached = self.get_embedding_size(self.model_name)
            self._embedding_size = cached
        return cached
167
+
168
+
169
class OnnxProteinModel(OnnxModel[NumpyArray]):
    """
    ONNX session wrapper that tokenizes protein sequences and mean-pools the output.
    """

    # Output names requested from the ONNX session; None is passed through as-is.
    ONNX_OUTPUT_NAMES: list[str] | None = None

    def __init__(self) -> None:
        super().__init__()
        # Assigned by ``_load_onnx_model``.
        self.tokenizer: Tokenizer | None = None

    def _load_onnx_model(
        self,
        model_dir: Path,
        model_file: str,
        threads: int | None,
        providers: Sequence[OnnxProvider] | None = None,
        cuda: bool | Device = Device.AUTO,
        device_id: int | None = None,
        extra_session_options: dict[str, Any] | None = None,
    ) -> None:
        """Load the ONNX model via the parent class, then the tokenizer from ``model_dir``."""
        super()._load_onnx_model(
            model_dir=model_dir,
            model_file=model_file,
            threads=threads,
            providers=providers,
            cuda=cuda,
            device_id=device_id,
            extra_session_options=extra_session_options,
        )
        self.tokenizer = load_protein_tokenizer(model_dir)

    def onnx_embed(self, sequences: list[str], **kwargs: Any) -> OnnxOutputContext:
        """
        Tokenize a batch of protein sequences and run it through the ONNX model.

        Sequences are upper-cased before tokenization.

        Args:
            sequences: List of protein sequences
        Returns:
            OnnxOutputContext holding the first model output, the attention mask
            and the input ids
        """
        assert self.tokenizer is not None

        encodings = self.tokenizer.encode_batch([seq.upper() for seq in sequences])
        input_ids = np.array([enc.ids for enc in encodings], dtype=np.int64)
        attention_mask = np.array(
            [enc.attention_mask for enc in encodings], dtype=np.int64
        )

        # Feed the attention mask only when the model declares that input.
        declared_inputs = {node.name for node in self.model.get_inputs()}  # type: ignore[union-attr]
        onnx_input: dict[str, NumpyArray] = {"input_ids": input_ids}
        if "attention_mask" in declared_inputs:
            onnx_input["attention_mask"] = attention_mask

        outputs = self.model.run(self.ONNX_OUTPUT_NAMES, onnx_input)  # type: ignore[union-attr]

        return OnnxOutputContext(
            model_output=outputs[0],
            attention_mask=attention_mask,
            input_ids=input_ids,
        )

    def _post_process_onnx_output(
        self, output: OnnxOutputContext, **kwargs: Any
    ) -> Iterable[NumpyArray]:
        """Mean-pool the model output over unmasked positions, then normalize.

        Raises:
            ValueError: If the output context carries no attention mask.
        """
        token_embeddings = output.model_output
        attention_mask = output.attention_mask

        if attention_mask is None:
            raise ValueError("attention_mask is required for mean pooling")

        # assumes model_output is (batch, seq, dim) -- TODO confirm
        weights = np.expand_dims(attention_mask, axis=-1)
        weighted_sum = np.sum(token_embeddings * weights, axis=1)
        # Clamp the token count so an all-zero mask cannot divide by zero.
        token_counts = np.clip(np.sum(weights, axis=1), a_min=1e-9, a_max=None)

        return normalize(weighted_sum / token_counts)
249
+
250
+
251
class OnnxProteinEmbedding(ProteinEmbeddingBase, OnnxProteinModel):
    """
    Protein embedder backed by an ONNX model and the protein tokenizer.
    """

    @classmethod
    def _list_supported_models(cls) -> list[DenseModelDescription]:
        """Return the module-level list of supported protein models."""
        return supported_protein_models

    def __init__(
        self,
        model_name: str = "facebook/esm2_t12_35M_UR50D",
        cache_dir: str | None = None,
        threads: int | None = None,
        providers: Sequence[OnnxProvider] | None = None,
        cuda: bool | Device = Device.AUTO,
        device_ids: list[int] | None = None,
        lazy_load: bool = False,
        device_id: int | None = None,
        specific_model_path: str | None = None,
        **kwargs: Any,
    ):
        """
        Resolve the model description, fetch the model files and, unless
        ``lazy_load`` is set, load the ONNX model immediately.

        Args:
            model_name: Name of the model to use
            cache_dir: Path to cache directory
            threads: Number of threads for ONNX runtime
            providers: ONNX execution providers
            cuda: Whether to use CUDA
            device_ids: Candidate device IDs; the first is used when ``device_id`` is None
            lazy_load: Defer loading the ONNX model until the first ``embed`` call
            device_id: Explicit device ID; takes precedence over ``device_ids``
            specific_model_path: Forwarded to ``download_model``
        """
        super().__init__(model_name, cache_dir, threads, **kwargs)
        self.providers = providers
        self.lazy_load = lazy_load
        self._extra_session_options = self._select_exposed_session_options(kwargs)
        self.device_ids = device_ids
        self.cuda = cuda

        # An explicit device_id wins; otherwise use the first entry of device_ids.
        selected_device: int | None = device_id
        if selected_device is None and self.device_ids is not None:
            selected_device = self.device_ids[0]
        self.device_id: int | None = selected_device

        self.model_description = self._get_model_description(model_name)
        self.cache_dir = str(define_cache_dir(cache_dir))
        self._specific_model_path = specific_model_path
        self._model_dir = self.download_model(
            self.model_description,
            self.cache_dir,
            local_files_only=self._local_files_only,
            specific_model_path=self._specific_model_path,
        )

        if not self.lazy_load:
            self.load_onnx_model()

    def load_onnx_model(self) -> None:
        """Load the ONNX model and tokenizer using the settings stored on this instance."""
        self._load_onnx_model(
            model_dir=self._model_dir,
            model_file=self.model_description.model_file,
            threads=self.threads,
            providers=self.providers,
            cuda=self.cuda,
            device_id=self.device_id,
            extra_session_options=self._extra_session_options,
        )

    def embed(
        self,
        sequences: str | Iterable[str],
        batch_size: int = 32,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[NumpyArray]:
        """
        Embed protein sequences batch by batch.

        Args:
            sequences: Single protein sequence or iterable of sequences (amino acid strings)
            batch_size: Batch size for encoding
            parallel: Number of parallel workers (not yet supported)

        Yields:
            Embeddings as numpy arrays, one per sequence
        """
        # A bare string is treated as a single sequence.
        if isinstance(sequences, str):
            sequences = [sequences]

        # Covers lazy loading: the model attribute may be missing or None.
        if getattr(self, "model", None) is None:
            self.load_onnx_model()

        for batch in iter_batch(sequences, batch_size):
            context = self.onnx_embed(batch, **kwargs)
            yield from self._post_process_onnx_output(context, **kwargs)

    @classmethod
    def _get_worker_class(cls) -> Type["ProteinEmbeddingWorker"]:
        """Return the worker class associated with this embedder."""
        return ProteinEmbeddingWorker
340
+
341
+
342
class ProteinEmbeddingWorker(EmbeddingWorker[NumpyArray]):
    """Worker that builds a single-threaded OnnxProteinEmbedding and runs batches through it."""

    def init_embedding(
        self,
        model_name: str,
        cache_dir: str,
        **kwargs: Any,
    ) -> OnnxProteinEmbedding:
        """Create the worker's embedding model with the ONNX thread count set to 1."""
        worker_model = OnnxProteinEmbedding(
            model_name=model_name,
            cache_dir=cache_dir,
            threads=1,
            **kwargs,
        )
        return worker_model

    def process(
        self, items: Iterable[tuple[int, Any]]
    ) -> Iterable[tuple[int, OnnxOutputContext]]:
        """Yield ``(index, onnx_output)`` for every ``(index, batch)`` received."""
        for batch_index, sequences in items:
            yield batch_index, self.model.onnx_embed(sequences)
362
+
363
+
364
class ProteinEmbedding(ProteinEmbeddingBase):
    """
    Front-end that picks the registered implementation supporting the requested
    protein model (ESM-2 and similar) and delegates embedding to it.

    Example:
        >>> from fastembed.bio import ProteinEmbedding
        >>> model = ProteinEmbedding("facebook/esm2_t12_35M_UR50D")
        >>> embeddings = list(model.embed(["MKTVRQERLKS", "GKGDPKKPRGKM"]))
        >>> print(embeddings[0].shape)
        (480,)
    """

    # Implementations consulted, in order, when resolving a model name.
    EMBEDDINGS_REGISTRY: list[Type[ProteinEmbeddingBase]] = [OnnxProteinEmbedding]

    @classmethod
    def list_supported_models(cls) -> list[dict[str, Any]]:
        """Describe every supported model as a plain dictionary.

        Returns:
            list[dict[str, Any]]: One dictionary per supported model description.
        """
        return [asdict(description) for description in cls._list_supported_models()]

    @classmethod
    def _list_supported_models(cls) -> list[DenseModelDescription]:
        """Collect the model descriptions of all registered implementations, in registry order."""
        return [
            description
            for implementation in cls.EMBEDDINGS_REGISTRY
            for description in implementation._list_supported_models()
        ]

    def __init__(
        self,
        model_name: str = "facebook/esm2_t12_35M_UR50D",
        cache_dir: str | None = None,
        threads: int | None = None,
        providers: Sequence[OnnxProvider] | None = None,
        cuda: bool | Device = Device.AUTO,
        device_ids: list[int] | None = None,
        lazy_load: bool = False,
        **kwargs: Any,
    ):
        """
        Build the first registered implementation that supports ``model_name``.

        The model name is matched case-insensitively.

        Args:
            model_name: Name of the model to use
            cache_dir: Path to cache directory
            threads: Number of threads for ONNX runtime
            providers: ONNX execution providers
            cuda: Whether to use CUDA
            device_ids: List of device IDs for multi-GPU
            lazy_load: Whether to load model lazily

        Raises:
            ValueError: If no registered implementation supports ``model_name``.
        """
        super().__init__(model_name, cache_dir, threads, **kwargs)

        for implementation in self.EMBEDDINGS_REGISTRY:
            known_models = implementation._list_supported_models()
            supported = any(
                model_name.lower() == known.model.lower() for known in known_models
            )
            if not supported:
                continue
            self.model = implementation(
                model_name=model_name,
                cache_dir=cache_dir,
                threads=threads,
                providers=providers,
                cuda=cuda,
                device_ids=device_ids,
                lazy_load=lazy_load,
                **kwargs,
            )
            return

        raise ValueError(
            f"Model {model_name} is not supported in ProteinEmbedding. "
            "Please check the supported models using `ProteinEmbedding.list_supported_models()`"
        )

    def embed(
        self,
        sequences: str | Iterable[str],
        batch_size: int = 32,
        parallel: int | None = None,
        **kwargs: Any,
    ) -> Iterable[NumpyArray]:
        """Embed protein sequences by delegating to the selected implementation.

        Args:
            sequences: Single protein sequence or iterable of sequences (amino acid strings)
            batch_size: Batch size for encoding
            parallel: Number of parallel workers

        Yields:
            Embeddings as numpy arrays, one per sequence
        """
        delegate = self.model
        yield from delegate.embed(sequences, batch_size, parallel, **kwargs)
fastembed/common/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from fastembed.common.types import ImageInput, OnnxProvider, PathInput
2
+
3
# Names exported by ``from fastembed.common import *``.
__all__ = [
    "OnnxProvider",
    "ImageInput",
    "PathInput",
]
fastembed/common/model_description.py ADDED
@@ -0,0 +1,52 @@
1
+ from dataclasses import dataclass, field
2
+ from enum import Enum
3
+ from typing import Any
4
+
5
+
6
+ @dataclass(frozen=True)
7
+ class ModelSource:
8
+ hf: str | None = None
9
+ url: str | None = None
10
+ _deprecated_tar_struct: bool = False
11
+
12
+ @property
13
+ def deprecated_tar_struct(self) -> bool:
14
+ return self._deprecated_tar_struct
15
+
16
+ def __post_init__(self) -> None:
17
+ if self.hf is None and self.url is None:
18
+ raise ValueError(
19
+ f"At least one source should be set, current sources: hf={self.hf}, url={self.url}"
20
+ )
21
+
22
+
23
@dataclass(frozen=True)
class BaseModelDescription:
    """Immutable record of the fields common to every model description."""

    # Required fields. Their order is part of the positional constructor signature.
    model: str
    sources: ModelSource
    model_file: str
    description: str
    license: str
    size_in_GB: float
    # Optional: each instance gets its own empty list by default.
    additional_files: list[str] = field(default_factory=list)
32
+
33
+
34
+ @dataclass(frozen=True)
35
+ class DenseModelDescription(BaseModelDescription):
36
+ dim: int | None = None
37
+ tasks: dict[str, Any] | None = field(default_factory=dict)
38
+
39
+ def __post_init__(self) -> None:
40
+ assert self.dim is not None, "dim is required for dense model description"
41
+
42
+
43
+ @dataclass(frozen=True)
44
+ class SparseModelDescription(BaseModelDescription):
45
+ requires_idf: bool | None = None
46
+ vocab_size: int | None = None
47
+
48
+
49
class PoolingType(str, Enum):
    """String-valued enumeration of pooling modes; each member equals its own name."""

    CLS = "CLS"
    MEAN = "MEAN"
    DISABLED = "DISABLED"