datapizza-ai-embedders-cohere 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,3 @@
1
+ from .cohere import CohereEmbedder
2
+
3
+ __all__ = ["CohereEmbedder"]
@@ -0,0 +1,74 @@
1
+ from datapizza.core.embedder import BaseEmbedder
2
+
3
+
4
class CohereEmbedder(BaseEmbedder):
    """Embedder that requests float embeddings from Cohere's v2 clients.

    The synchronous (``cohere.ClientV2``) and asynchronous
    (``cohere.AsyncClientV2``) clients are not built in ``__init__``; they are
    created by ``_set_client`` / ``_set_a_client`` the first time those hooks
    run.

    NOTE(review): ``_get_client`` / ``_get_a_client`` are not defined in this
    class; they are presumably inherited from ``BaseEmbedder`` and expected to
    call the ``_set_*`` hooks before returning the client -- confirm against
    ``datapizza.core.embedder``.
    """

    def __init__(
        self,
        *,
        api_key: str,
        model_name: str | None = None,
        base_url: str | None = None,
        input_type: str = "search_document",
    ):
        """Store the connection settings; no client is created here.

        Args:
            api_key: Key passed to the Cohere client constructors.
            model_name: Default model used when ``embed`` / ``a_embed`` are
                called without an explicit ``model_name``.
            base_url: Passed unchanged (including ``None``) as ``base_url`` to
                the Cohere client constructors.
            input_type: Value sent as ``input_type`` with every embed request.
        """
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name

        self.input_type = input_type

        # Filled in on demand by _set_client / _set_a_client.
        self.client = None
        self.a_client = None

    def _set_client(self):
        """Create the synchronous ``cohere.ClientV2`` unless one already exists."""
        # Imported at call time rather than at module import.
        import cohere

        if self.client is None:
            self.client = cohere.ClientV2(
                base_url=self.base_url,
                api_key=self.api_key,
            )

    def _set_a_client(self):
        """Create the asynchronous ``cohere.AsyncClientV2`` unless one already exists."""
        # Imported at call time rather than at module import.
        import cohere

        if self.a_client is None:
            self.a_client = cohere.AsyncClientV2(
                base_url=self.base_url,
                api_key=self.api_key,
            )

    def _embed_kwargs(self, text: str | list[str], model_name: str | None) -> dict:
        """Build the keyword arguments shared by the sync and async embed calls.

        Raises:
            ValueError: If neither ``model_name`` nor ``self.model_name`` is set.
        """
        model = model_name or self.model_name
        if not model:
            raise ValueError("Model name is required.")

        return {
            # A single string is sent as a one-element batch.
            "texts": [text] if isinstance(text, str) else text,
            "model": model,
            "input_type": self.input_type,
            "embedding_types": ["float"],
        }

    def embed(
        self, text: str | list[str], model_name: str | None = None
    ) -> list[float] | list[list[float]]:
        """Embed one text or a list of texts with the synchronous client.

        Args:
            text: A single string, or a list of strings to embed together.
            model_name: Model to use for this call; falls back to
                ``self.model_name`` when omitted.

        Returns:
            One vector when ``text`` is a string, otherwise the list of
            vectors taken from the response. An empty list input yields ``[]``.

        Raises:
            ValueError: If no model name is available.
        """
        request = self._embed_kwargs(text, model_name)
        if not request["texts"]:
            # Nothing to embed: skip the network round-trip.
            return []

        client = self._get_client()
        response = client.embed(**request)
        embeddings = response.embeddings.float
        return embeddings[0] if isinstance(text, str) else embeddings

    async def a_embed(
        self, text: str | list[str], model_name: str | None = None
    ) -> list[float] | list[list[float]]:
        """Embed one text or a list of texts with the asynchronous client.

        Args:
            text: A single string, or a list of strings to embed together.
            model_name: Model to use for this call; falls back to
                ``self.model_name`` when omitted.

        Returns:
            One vector when ``text`` is a string, otherwise the list of
            vectors taken from the response. An empty list input yields ``[]``.

        Raises:
            ValueError: If no model name is available.
        """
        request = self._embed_kwargs(text, model_name)
        if not request["texts"]:
            # Nothing to embed: skip the network round-trip.
            return []

        client = self._get_a_client()
        response = await client.embed(**request)
        embeddings = response.embeddings.float
        return embeddings[0] if isinstance(text, str) else embeddings
@@ -0,0 +1,12 @@
1
+ Metadata-Version: 2.4
2
+ Name: datapizza-ai-embedders-cohere
3
+ Version: 0.0.2
4
+ Summary: Cohere embedder for the datapizza-ai framework
5
+ Author-email: Datapizza <datapizza@datapizza.tech>
6
+ License: MIT
7
+ Classifier: License :: OSI Approved :: MIT License
8
+ Classifier: Operating System :: OS Independent
9
+ Classifier: Programming Language :: Python :: 3
10
+ Requires-Python: <4,>=3.10.0
11
+ Requires-Dist: cohere<6.0.0,>=5.14.0
12
+ Requires-Dist: datapizza-ai-core>=0.0.0
@@ -0,0 +1,5 @@
1
+ datapizza/embedders/cohere/__init__.py,sha256=OO44WbItRJdFA3setpBHi0eb9ngEfEElh7oGWYZzQ9I,65
2
+ datapizza/embedders/cohere/cohere.py,sha256=94z2nX4Le7bHrSKiCRteLujA8meIEMkuq037XCsv_Vs,2133
3
+ datapizza_ai_embedders_cohere-0.0.2.dist-info/METADATA,sha256=hH2DYq82hoVti5OSHT3CL_Bm54JSQvH4qKNlaRGjhNY,445
4
+ datapizza_ai_embedders_cohere-0.0.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
5
+ datapizza_ai_embedders_cohere-0.0.2.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.27.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any