datapizza-ai-embedders-fastembedder 0.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datapizza/embedders/fastembedder/__init__.py +3 -0
- datapizza/embedders/fastembedder/fastembedder.py +44 -0
- datapizza_ai_embedders_fastembedder-0.0.2.dist-info/METADATA +12 -0
- datapizza_ai_embedders_fastembedder-0.0.2.dist-info/RECORD +5 -0
- datapizza_ai_embedders_fastembedder-0.0.2.dist-info/WHEEL +4 -0
@@ -0,0 +1,44 @@
|
|
1
|
+
import logging
|
2
|
+
|
3
|
+
import fastembed
|
4
|
+
from datapizza.core.embedder import BaseEmbedder
|
5
|
+
from datapizza.type import SparseEmbedding
|
6
|
+
|
7
|
+
# Module-level logger for this embedder; not referenced by the code visible in this module.
log = logging.getLogger(__name__)
|
8
|
+
|
9
|
+
|
10
|
+
class FastEmbedder(BaseEmbedder):
    """Sparse text embedder backed by ``fastembed.SparseTextEmbedding``.

    Every produced ``SparseEmbedding`` is tagged with ``self.embedding_name``.
    """

    def __init__(
        self,
        model_name: str,
        embedding_name: str | None = None,
        cache_dir: str | None = None,
    ):
        """Load the fastembed sparse model.

        Args:
            model_name: Name of the fastembed sparse model to load.
            embedding_name: Name attached to every produced ``SparseEmbedding``.
                Falls back to ``model_name`` when ``None`` or empty.
            cache_dir: Directory forwarded to fastembed as ``cache_dir``;
                ``None`` leaves the choice to fastembed.
        """
        self.model_name = model_name
        # Any falsy value (None or "") falls back to the model name.
        self.embedding_name = embedding_name or model_name
        self.cache_dir = cache_dir
        self.embedder = fastembed.SparseTextEmbedding(
            model_name=model_name, cache_dir=cache_dir
        )

    def _to_sparse_embedding(self, embedding) -> SparseEmbedding:
        """Convert one item yielded by ``self.embedder.embed`` into a ``SparseEmbedding``.

        The item is expected to expose ``values`` and ``indices`` arrays
        (converted to plain lists via ``tolist()``).
        """
        return SparseEmbedding(
            name=self.embedding_name,
            values=embedding.values.tolist(),
            indices=embedding.indices.tolist(),
        )

    def embed(self, text: str | list[str]) -> SparseEmbedding | list[SparseEmbedding]:
        """Embed a single text or a list of texts.

        Args:
            text: A string, or a list of strings.

        Returns:
            A single ``SparseEmbedding`` for a string input, or a list with one
            ``SparseEmbedding`` per input text, in input order, for a list input.
        """
        if isinstance(text, list):
            if not text:
                return []
            # Pass the whole list in one call so fastembed can batch the texts,
            # instead of running the model once per text.
            return [self._to_sparse_embedding(e) for e in self.embedder.embed(text)]
        embedding = next(iter(self.embedder.embed(text)))
        return self._to_sparse_embedding(embedding)

    def a_embed(self, text: str | list[str]) -> SparseEmbedding | list[SparseEmbedding]:
        """Embed text exactly like ``embed``; see that method for arguments and return.

        NOTE(review): despite the ``a_`` prefix this is a plain synchronous method
        (kept as such to preserve the existing interface); confirm whether
        ``BaseEmbedder`` expects ``async def a_embed`` before changing it.
        """
        return self.embed(text)
|
44
|
+
|
@@ -0,0 +1,12 @@
|
|
1
|
+
Metadata-Version: 2.4
|
2
|
+
Name: datapizza-ai-embedders-fastembedder
|
3
|
+
Version: 0.0.2
|
4
|
+
Summary: FastEmbed embedder for the datapizza-ai framework
|
5
|
+
Author-email: Datapizza <datapizza@datapizza.tech>
|
6
|
+
License: MIT
|
7
|
+
Classifier: License :: OSI Approved :: MIT License
|
8
|
+
Classifier: Operating System :: OS Independent
|
9
|
+
Classifier: Programming Language :: Python :: 3
|
10
|
+
Requires-Python: <4,>=3.10.0
|
11
|
+
Requires-Dist: datapizza-ai-core>=0.0.0
|
12
|
+
Requires-Dist: fastembed>=0.6.1
|
@@ -0,0 +1,5 @@
|
|
1
|
+
datapizza/embedders/fastembedder/__init__.py,sha256=D5Oac7f0KwSqgjE-OlKgzaMxAg-uqwDP003mkjlxhcw,67
|
2
|
+
datapizza/embedders/fastembedder/fastembedder.py,sha256=fU3U340slc7WLuXxRvyyIkJJe15jLHCQLoBbWT8jU08,1722
|
3
|
+
datapizza_ai_embedders_fastembedder-0.0.2.dist-info/METADATA,sha256=EXpdNnJ6KAO70LpoDYsP1iua8Wn7ZDBLge__olDtrZA,449
|
4
|
+
datapizza_ai_embedders_fastembedder-0.0.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
5
|
+
datapizza_ai_embedders_fastembedder-0.0.2.dist-info/RECORD,,
|