unstructured-ingest 0.5.2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of unstructured-ingest has been flagged as potentially problematic; see the registry's advisory page for more details.
- test/unit/v2/embedders/test_bedrock.py +1 -1
- test/unit/v2/embedders/test_huggingface.py +1 -1
- unstructured_ingest/__version__.py +1 -1
- unstructured_ingest/embed/azure_openai.py +6 -0
- unstructured_ingest/embed/bedrock.py +16 -6
- unstructured_ingest/embed/huggingface.py +3 -1
- unstructured_ingest/embed/interfaces.py +61 -23
- unstructured_ingest/embed/mixedbreadai.py +28 -114
- unstructured_ingest/embed/octoai.py +19 -51
- unstructured_ingest/embed/openai.py +17 -55
- unstructured_ingest/embed/togetherai.py +16 -58
- unstructured_ingest/embed/vertexai.py +15 -46
- unstructured_ingest/embed/voyageai.py +17 -52
- unstructured_ingest/v2/errors.py +7 -0
- unstructured_ingest/v2/processes/connectors/neo4j.py +129 -43
- unstructured_ingest/v2/processes/embedder.py +9 -7
- {unstructured_ingest-0.5.2.dist-info → unstructured_ingest-0.5.3.dist-info}/METADATA +96 -84
- {unstructured_ingest-0.5.2.dist-info → unstructured_ingest-0.5.3.dist-info}/RECORD +22 -22
- {unstructured_ingest-0.5.2.dist-info → unstructured_ingest-0.5.3.dist-info}/WHEEL +1 -1
- {unstructured_ingest-0.5.2.dist-info → unstructured_ingest-0.5.3.dist-info}/LICENSE.md +0 -0
- {unstructured_ingest-0.5.2.dist-info → unstructured_ingest-0.5.3.dist-info}/entry_points.txt +0 -0
- {unstructured_ingest-0.5.2.dist-info → unstructured_ingest-0.5.3.dist-info}/top_level.txt +0 -0
|
@@ -16,7 +16,7 @@ fake = faker.Faker()
|
|
|
16
16
|
def generate_embedder_config_params() -> dict:
|
|
17
17
|
params = {}
|
|
18
18
|
if random.random() < 0.5:
|
|
19
|
-
params["
|
|
19
|
+
params["embedder_model_name"] = fake.word() if random.random() < 0.5 else None
|
|
20
20
|
params["embedder_model_kwargs"] = (
|
|
21
21
|
generate_random_dictionary(key_type=str, value_type=Any)
|
|
22
22
|
if random.random() < 0.5
|
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "0.5.
|
|
1
|
+
# Package version string for this release (0.5.3).
__version__ = "0.5.3" # pragma: no cover
|
|
@@ -44,7 +44,13 @@ class AzureOpenAIEmbeddingConfig(OpenAIEmbeddingConfig):
|
|
|
44
44
|
class AzureOpenAIEmbeddingEncoder(OpenAIEmbeddingEncoder):
    """OpenAI embedding encoder that talks to Azure OpenAI instead of api.openai.com.

    All batching and error handling is inherited; only the client factory differs.
    """

    config: AzureOpenAIEmbeddingConfig

    def get_client(self) -> "AzureOpenAI":
        """Return the synchronous Azure client built by the Azure config."""
        azure_config = self.config
        return azure_config.get_client()
|
|
49
|
+
|
|
47
50
|
|
|
48
51
|
@dataclass
class AsyncAzureOpenAIEmbeddingEncoder(AsyncOpenAIEmbeddingEncoder):
    """Async OpenAI embedding encoder that targets an Azure OpenAI deployment.

    Only the client factory is overridden; the async batching flow is inherited.
    """

    config: AzureOpenAIEmbeddingConfig

    def get_client(self) -> "AsyncAzureOpenAI":
        """Return the asynchronous Azure client built by the Azure config."""
        azure_config = self.config
        return azure_config.get_async_client()
|
|
@@ -15,7 +15,13 @@ from unstructured_ingest.embed.interfaces import (
|
|
|
15
15
|
)
|
|
16
16
|
from unstructured_ingest.logger import logger
|
|
17
17
|
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
18
|
-
from unstructured_ingest.v2.errors import
|
|
18
|
+
from unstructured_ingest.v2.errors import (
|
|
19
|
+
ProviderError,
|
|
20
|
+
RateLimitError,
|
|
21
|
+
UserAuthError,
|
|
22
|
+
UserError,
|
|
23
|
+
is_internal_error,
|
|
24
|
+
)
|
|
19
25
|
|
|
20
26
|
if TYPE_CHECKING:
|
|
21
27
|
from botocore.client import BaseClient
|
|
@@ -51,9 +57,11 @@ class BedrockEmbeddingConfig(EmbeddingConfig):
|
|
|
51
57
|
aws_access_key_id: SecretStr
|
|
52
58
|
aws_secret_access_key: SecretStr
|
|
53
59
|
region_name: str = "us-west-2"
|
|
54
|
-
|
|
60
|
+
embedder_model_name: str = Field(default="amazon.titan-embed-text-v1", alias="model_name")
|
|
55
61
|
|
|
56
62
|
def wrap_error(self, e: Exception) -> Exception:
|
|
63
|
+
if is_internal_error(e=e):
|
|
64
|
+
return e
|
|
57
65
|
from botocore.exceptions import ClientError
|
|
58
66
|
|
|
59
67
|
if isinstance(e, ClientError):
|
|
@@ -122,7 +130,7 @@ class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
122
130
|
|
|
123
131
|
def embed_query(self, query: str) -> list[float]:
|
|
124
132
|
"""Call out to Bedrock embedding endpoint."""
|
|
125
|
-
provider = self.config.
|
|
133
|
+
provider = self.config.embedder_model_name.split(".")[0]
|
|
126
134
|
body = conform_query(query=query, provider=provider)
|
|
127
135
|
|
|
128
136
|
bedrock_client = self.config.get_client()
|
|
@@ -130,7 +138,7 @@ class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
130
138
|
try:
|
|
131
139
|
response = bedrock_client.invoke_model(
|
|
132
140
|
body=json.dumps(body),
|
|
133
|
-
modelId=self.config.
|
|
141
|
+
modelId=self.config.embedder_model_name,
|
|
134
142
|
accept="application/json",
|
|
135
143
|
contentType="application/json",
|
|
136
144
|
)
|
|
@@ -148,6 +156,8 @@ class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
148
156
|
def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
149
157
|
elements = elements.copy()
|
|
150
158
|
elements_with_text = [e for e in elements if e.get("text")]
|
|
159
|
+
if not elements_with_text:
|
|
160
|
+
return elements
|
|
151
161
|
embeddings = [self.embed_query(query=e["text"]) for e in elements_with_text]
|
|
152
162
|
for element, embedding in zip(elements_with_text, embeddings):
|
|
153
163
|
element[EMBEDDINGS_KEY] = embedding
|
|
@@ -163,7 +173,7 @@ class AsyncBedrockEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
|
|
|
163
173
|
|
|
164
174
|
async def embed_query(self, query: str) -> list[float]:
|
|
165
175
|
"""Call out to Bedrock embedding endpoint."""
|
|
166
|
-
provider = self.config.
|
|
176
|
+
provider = self.config.embedder_model_name.split(".")[0]
|
|
167
177
|
body = conform_query(query=query, provider=provider)
|
|
168
178
|
try:
|
|
169
179
|
async with self.config.get_async_client() as bedrock_client:
|
|
@@ -171,7 +181,7 @@ class AsyncBedrockEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
|
|
|
171
181
|
try:
|
|
172
182
|
response = await bedrock_client.invoke_model(
|
|
173
183
|
body=json.dumps(body),
|
|
174
|
-
modelId=self.config.
|
|
184
|
+
modelId=self.config.embedder_model_name,
|
|
175
185
|
accept="application/json",
|
|
176
186
|
contentType="application/json",
|
|
177
187
|
)
|
|
@@ -47,7 +47,7 @@ class HuggingFaceEmbeddingConfig(EmbeddingConfig):
|
|
|
47
47
|
class HuggingFaceEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
48
48
|
config: HuggingFaceEmbeddingConfig
|
|
49
49
|
|
|
50
|
-
def
|
|
50
|
+
def _embed_query(self, query: str) -> list[float]:
    """Embed one query by running it through the document path as a single-item list."""
    vectors = self._embed_documents(texts=[query])
    return vectors[0]
|
|
52
52
|
|
|
53
53
|
def _embed_documents(self, texts: list[str]) -> list[list[float]]:
|
|
@@ -58,6 +58,8 @@ class HuggingFaceEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
58
58
|
def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
59
59
|
elements = elements.copy()
|
|
60
60
|
elements_with_text = [e for e in elements if e.get("text")]
|
|
61
|
+
if not elements_with_text:
|
|
62
|
+
return elements
|
|
61
63
|
embeddings = self._embed_documents([e["text"] for e in elements_with_text])
|
|
62
64
|
for element, embedding in zip(elements_with_text, embeddings):
|
|
63
65
|
element[EMBEDDINGS_KEY] = embedding
|
|
@@ -1,11 +1,12 @@
|
|
|
1
|
-
import
|
|
2
|
-
from abc import ABC, abstractmethod
|
|
1
|
+
from abc import ABC
|
|
3
2
|
from dataclasses import dataclass
|
|
4
|
-
from typing import Optional
|
|
3
|
+
from typing import Any, Optional
|
|
5
4
|
|
|
6
5
|
import numpy as np
|
|
7
6
|
from pydantic import BaseModel, Field
|
|
8
7
|
|
|
8
|
+
from unstructured_ingest.utils.data_prep import batch_generator
|
|
9
|
+
|
|
9
10
|
EMBEDDINGS_KEY = "embeddings"
|
|
10
11
|
|
|
11
12
|
|
|
@@ -50,21 +51,37 @@ class BaseEmbeddingEncoder(BaseEncoder, ABC):
|
|
|
50
51
|
exemplary_embedding = self.get_exemplary_embedding()
|
|
51
52
|
return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
|
|
52
53
|
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
pass
|
|
54
|
+
def get_client(self):
    """Return the provider client handed to ``embed_batch``.

    Subclasses must override this hook; the base class has no client.
    """
    raise NotImplementedError()
|
|
56
56
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
pass
|
|
57
|
+
def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
    """Embed one batch of texts with *client*, returning one vector per text.

    Subclasses must override this hook; the base class cannot embed anything.
    """
    raise NotImplementedError()
|
|
60
59
|
|
|
61
|
-
def
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
60
|
+
def embed_documents(self, elements: list[dict]) -> list[dict]:
    """Attach an embedding to every element that has non-empty ``"text"``.

    Texts are sent to ``embed_batch`` in chunks of ``config.batch_size`` (or all
    at once when no batch size is configured). The vectors from every batch are
    concatenated in order and stored under ``EMBEDDINGS_KEY`` on the matching
    element.

    Note: only the outer list is copied; the element dicts themselves are
    updated in place and are shared with the caller's list.

    Args:
        elements: Element dicts; entries without a truthy ``"text"`` are left untouched.

    Returns:
        A shallow copy of ``elements`` with embeddings added.

    Raises:
        Whatever ``wrap_error`` maps an exception raised while batching/embedding to.
    """
    elements = elements.copy()
    elements_with_text = [e for e in elements if e.get("text")]
    if not elements_with_text:
        # Nothing to embed: avoid creating a provider client at all.
        return elements
    client = self.get_client()
    texts = [e["text"] for e in elements_with_text]
    embeddings: list[list[float]] = []
    try:
        for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
            # Keep each batch's result separate from the accumulator; rebinding
            # `embeddings` here would drop every earlier batch.
            batch_embeddings = self.embed_batch(client=client, batch=batch)
            embeddings.extend(batch_embeddings)
    except Exception as e:
        raise self.wrap_error(e=e)
    for element, embedding in zip(elements_with_text, embeddings):
        element[EMBEDDINGS_KEY] = embedding
    return elements
|
|
75
|
+
|
|
76
|
+
def _embed_query(self, query: str) -> list[float]:
    """Embed a single query by submitting it to ``embed_batch`` as a one-item batch."""
    vectors = self.embed_batch(client=self.get_client(), batch=[query])
    return vectors[0]
|
|
66
79
|
|
|
67
|
-
|
|
80
|
+
def embed_query(self, query: str) -> list[float]:
    """Embed *query*, re-raising any failure as translated by ``wrap_error``."""
    try:
        vector = self._embed_query(query=query)
    except Exception as exc:
        raise self.wrap_error(e=exc)
    return vector
|
|
68
85
|
|
|
69
86
|
|
|
70
87
|
@dataclass
|
|
@@ -88,14 +105,35 @@ class AsyncBaseEmbeddingEncoder(BaseEncoder, ABC):
|
|
|
88
105
|
exemplary_embedding = await self.get_exemplary_embedding()
|
|
89
106
|
return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
|
|
90
107
|
|
|
91
|
-
|
|
108
|
+
def get_client(self):
    """Return the (async-capable) provider client handed to ``embed_batch``.

    Subclasses must override this hook; the base class has no client.
    """
    raise NotImplementedError()
|
|
110
|
+
|
|
111
|
+
async def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
    """Asynchronously embed one batch of texts, returning one vector per text.

    Subclasses must override this coroutine; the base class cannot embed anything.
    """
    raise NotImplementedError()
|
|
113
|
+
|
|
92
114
|
async def embed_documents(self, elements: list[dict]) -> list[dict]:
    """Asynchronously attach an embedding to every element with non-empty ``"text"``.

    Texts are sent to ``embed_batch`` in chunks of ``config.batch_size`` (or all
    at once when no batch size is configured). Batches are awaited one after
    another, and their vectors are concatenated in order and stored under
    ``EMBEDDINGS_KEY`` on the matching element.

    Note: only the outer list is copied; the element dicts themselves are
    updated in place and are shared with the caller's list.

    Args:
        elements: Element dicts; entries without a truthy ``"text"`` are left untouched.

    Returns:
        A shallow copy of ``elements`` with embeddings added.

    Raises:
        Whatever ``wrap_error`` maps an exception raised while batching/embedding to.
    """
    elements = elements.copy()
    elements_with_text = [e for e in elements if e.get("text")]
    if not elements_with_text:
        # Nothing to embed: avoid creating a provider client at all.
        return elements
    client = self.get_client()
    texts = [e["text"] for e in elements_with_text]
    embeddings: list[list[float]] = []
    try:
        for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
            # Keep each batch's result separate from the accumulator; rebinding
            # `embeddings` here would drop every earlier batch.
            batch_embeddings = await self.embed_batch(client=client, batch=batch)
            embeddings.extend(batch_embeddings)
    except Exception as e:
        raise self.wrap_error(e=e)
    for element, embedding in zip(elements_with_text, embeddings):
        element[EMBEDDINGS_KEY] = embedding
    return elements
|
|
129
|
+
|
|
130
|
+
async def _embed_query(self, query: str) -> list[float]:
    """Embed a single query by awaiting ``embed_batch`` on a one-item batch."""
    vectors = await self.embed_batch(client=self.get_client(), batch=[query])
    return vectors[0]
|
|
94
134
|
|
|
95
|
-
@abstractmethod
|
|
96
135
|
async def embed_query(self, query: str) -> list[float]:
    """Asynchronously embed *query*, re-raising failures as translated by ``wrap_error``."""
    try:
        vector = await self._embed_query(query=query)
    except Exception as exc:
        raise self.wrap_error(e=exc)
    return vector
|
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import asyncio
|
|
2
1
|
import os
|
|
3
2
|
from dataclasses import dataclass
|
|
4
3
|
from typing import TYPE_CHECKING
|
|
@@ -6,12 +5,10 @@ from typing import TYPE_CHECKING
|
|
|
6
5
|
from pydantic import Field, SecretStr
|
|
7
6
|
|
|
8
7
|
from unstructured_ingest.embed.interfaces import (
|
|
9
|
-
EMBEDDINGS_KEY,
|
|
10
8
|
AsyncBaseEmbeddingEncoder,
|
|
11
9
|
BaseEmbeddingEncoder,
|
|
12
10
|
EmbeddingConfig,
|
|
13
11
|
)
|
|
14
|
-
from unstructured_ingest.utils.data_prep import batch_generator
|
|
15
12
|
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
16
13
|
|
|
17
14
|
USER_AGENT = "@mixedbread-ai/unstructured"
|
|
@@ -85,7 +82,7 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
85
82
|
|
|
86
83
|
def get_exemplary_embedding(self) -> list[float]:
    """Embed a tiny probe string, used to learn the vector size and whether it is unit-length."""
    probe_text = "Q"
    return self.embed_query(query=probe_text)
|
|
89
86
|
|
|
90
87
|
@requires_dependencies(
|
|
91
88
|
["mixedbread_ai"],
|
|
@@ -100,59 +97,19 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
100
97
|
additional_headers={"User-Agent": USER_AGENT},
|
|
101
98
|
)
|
|
102
99
|
|
|
103
|
-
def
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
|
|
117
|
-
response = client.embeddings(
|
|
118
|
-
model=self.config.embedder_model_name,
|
|
119
|
-
normalized=True,
|
|
120
|
-
encoding_format=ENCODING_FORMAT,
|
|
121
|
-
truncation_strategy=TRUNCATION_STRATEGY,
|
|
122
|
-
request_options=self.get_request_options(),
|
|
123
|
-
input=batch,
|
|
124
|
-
)
|
|
125
|
-
responses.append(response)
|
|
126
|
-
return [item.embedding for response in responses for item in response.data]
|
|
127
|
-
|
|
128
|
-
def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
129
|
-
"""
|
|
130
|
-
Embed a list of document elements.
|
|
131
|
-
|
|
132
|
-
Args:
|
|
133
|
-
elements (list[Element]): List of document elements.
|
|
134
|
-
|
|
135
|
-
Returns:
|
|
136
|
-
list[Element]: Elements with embeddings.
|
|
137
|
-
"""
|
|
138
|
-
elements = elements.copy()
|
|
139
|
-
elements_with_text = [e for e in elements if e.get("text")]
|
|
140
|
-
embeddings = self._embed([e["text"] for e in elements_with_text])
|
|
141
|
-
for element, embedding in zip(elements_with_text, embeddings):
|
|
142
|
-
element[EMBEDDINGS_KEY] = embedding
|
|
143
|
-
return elements
|
|
144
|
-
|
|
145
|
-
def embed_query(self, query: str) -> list[float]:
|
|
146
|
-
"""
|
|
147
|
-
Embed a query string.
|
|
148
|
-
|
|
149
|
-
Args:
|
|
150
|
-
query (str): Query string to embed.
|
|
151
|
-
|
|
152
|
-
Returns:
|
|
153
|
-
list[float]: Embedding of the query.
|
|
154
|
-
"""
|
|
155
|
-
return self._embed([query])[0]
|
|
100
|
+
def get_client(self) -> "MixedbreadAI":
    """Return the synchronous Mixedbread AI client built by the config."""
    mixedbread_config = self.config
    return mixedbread_config.get_client()
|
|
102
|
+
|
|
103
|
+
def embed_batch(self, client: "MixedbreadAI", batch: list[str]) -> list[list[float]]:
    """Embed one batch of texts with Mixedbread AI.

    Requests normalized vectors using the module-level ``ENCODING_FORMAT`` and
    ``TRUNCATION_STRATEGY`` settings plus ``get_request_options()``, and returns
    the vectors in the order the response lists them.
    """
    request = {
        "model": self.config.embedder_model_name,
        "normalized": True,
        "encoding_format": ENCODING_FORMAT,
        "truncation_strategy": TRUNCATION_STRATEGY,
        "request_options": self.get_request_options(),
        "input": batch,
    }
    response = client.embeddings(**request)
    return [item.embedding for item in response.data]
|
|
156
113
|
|
|
157
114
|
|
|
158
115
|
@dataclass
|
|
@@ -162,8 +119,7 @@ class AsyncMixedbreadAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
|
|
|
162
119
|
|
|
163
120
|
async def get_exemplary_embedding(self) -> list[float]:
    """Asynchronously embed a tiny probe string to learn vector size and unit-norm status."""
    probe_text = "Q"
    return await self.embed_query(query=probe_text)
|
|
167
123
|
|
|
168
124
|
@requires_dependencies(
|
|
169
125
|
["mixedbread_ai"],
|
|
@@ -178,58 +134,16 @@ class AsyncMixedbreadAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
|
|
|
178
134
|
additional_headers={"User-Agent": USER_AGENT},
|
|
179
135
|
)
|
|
180
136
|
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
for
|
|
194
|
-
tasks.append(
|
|
195
|
-
client.embeddings(
|
|
196
|
-
model=self.config.embedder_model_name,
|
|
197
|
-
normalized=True,
|
|
198
|
-
encoding_format=ENCODING_FORMAT,
|
|
199
|
-
truncation_strategy=TRUNCATION_STRATEGY,
|
|
200
|
-
request_options=self.get_request_options(),
|
|
201
|
-
input=batch,
|
|
202
|
-
)
|
|
203
|
-
)
|
|
204
|
-
responses = await asyncio.gather(*tasks)
|
|
205
|
-
return [item.embedding for response in responses for item in response.data]
|
|
206
|
-
|
|
207
|
-
async def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
208
|
-
"""
|
|
209
|
-
Embed a list of document elements.
|
|
210
|
-
|
|
211
|
-
Args:
|
|
212
|
-
elements (list[Element]): List of document elements.
|
|
213
|
-
|
|
214
|
-
Returns:
|
|
215
|
-
list[Element]: Elements with embeddings.
|
|
216
|
-
"""
|
|
217
|
-
elements = elements.copy()
|
|
218
|
-
elements_with_text = [e for e in elements if e.get("text")]
|
|
219
|
-
embeddings = await self._embed([e["text"] for e in elements_with_text])
|
|
220
|
-
for element, embedding in zip(elements_with_text, embeddings):
|
|
221
|
-
element[EMBEDDINGS_KEY] = embedding
|
|
222
|
-
return elements
|
|
223
|
-
|
|
224
|
-
async def embed_query(self, query: str) -> list[float]:
|
|
225
|
-
"""
|
|
226
|
-
Embed a query string.
|
|
227
|
-
|
|
228
|
-
Args:
|
|
229
|
-
query (str): Query string to embed.
|
|
230
|
-
|
|
231
|
-
Returns:
|
|
232
|
-
list[float]: Embedding of the query.
|
|
233
|
-
"""
|
|
234
|
-
embedding = await self._embed([query])
|
|
235
|
-
return embedding[0]
|
|
137
|
+
def get_client(self) -> "AsyncMixedbreadAI":
    """Return the asynchronous Mixedbread AI client built by the config."""
    mixedbread_config = self.config
    return mixedbread_config.get_async_client()
|
|
139
|
+
|
|
140
|
+
async def embed_batch(self, client: "AsyncMixedbreadAI", batch: list[str]) -> list[list[float]]:
    """Asynchronously embed one batch of texts with Mixedbread AI.

    Requests normalized vectors using the module-level ``ENCODING_FORMAT`` and
    ``TRUNCATION_STRATEGY`` settings plus ``get_request_options()``, and returns
    the vectors in the order the response lists them.
    """
    request = {
        "model": self.config.embedder_model_name,
        "normalized": True,
        "encoding_format": ENCODING_FORMAT,
        "truncation_strategy": TRUNCATION_STRATEGY,
        "request_options": self.get_request_options(),
        "input": batch,
    }
    response = await client.embeddings(**request)
    return [item.embedding for item in response.data]
|
|
@@ -4,13 +4,11 @@ from typing import TYPE_CHECKING
|
|
|
4
4
|
from pydantic import Field, SecretStr
|
|
5
5
|
|
|
6
6
|
from unstructured_ingest.embed.interfaces import (
|
|
7
|
-
EMBEDDINGS_KEY,
|
|
8
7
|
AsyncBaseEmbeddingEncoder,
|
|
9
8
|
BaseEmbeddingEncoder,
|
|
10
9
|
EmbeddingConfig,
|
|
11
10
|
)
|
|
12
11
|
from unstructured_ingest.logger import logger
|
|
13
|
-
from unstructured_ingest.utils.data_prep import batch_generator
|
|
14
12
|
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
15
13
|
from unstructured_ingest.v2.errors import (
|
|
16
14
|
ProviderError,
|
|
@@ -18,6 +16,7 @@ from unstructured_ingest.v2.errors import (
|
|
|
18
16
|
RateLimitError,
|
|
19
17
|
UserAuthError,
|
|
20
18
|
UserError,
|
|
19
|
+
is_internal_error,
|
|
21
20
|
)
|
|
22
21
|
|
|
23
22
|
if TYPE_CHECKING:
|
|
@@ -30,6 +29,8 @@ class OctoAiEmbeddingConfig(EmbeddingConfig):
|
|
|
30
29
|
base_url: str = Field(default="https://text.octoai.run/v1")
|
|
31
30
|
|
|
32
31
|
def wrap_error(self, e: Exception) -> Exception:
|
|
32
|
+
if is_internal_error(e=e):
|
|
33
|
+
return e
|
|
33
34
|
# https://platform.openai.com/docs/guides/error-codes/api-errors
|
|
34
35
|
from openai import APIStatusError
|
|
35
36
|
|
|
@@ -81,31 +82,17 @@ class OctoAIEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
81
82
|
def wrap_error(self, e: Exception) -> Exception:
    """Translate a raw OctoAI/OpenAI-SDK exception using the config's mapping."""
    octo_config = self.config
    return octo_config.wrap_error(e=e)
|
|
83
84
|
|
|
84
|
-
def
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
response = client.embeddings.create(input=query, model=self.config.embedder_model_name)
|
|
88
|
-
except Exception as e:
|
|
89
|
-
raise self.wrap_error(e=e)
|
|
85
|
+
def _embed_query(self, query: str) -> list[float]:
    """Embed a single query string with OctoAI.

    Unlike the base implementation, the query is sent as a bare string
    ``input`` (not a one-item list), and the first returned vector is used.
    The return annotation matches the base ``_embed_query`` hook signature.
    """
    client = self.get_client()
    response = client.embeddings.create(input=query, model=self.config.embedder_model_name)
    return response.data[0].embedding
|
|
91
89
|
|
|
92
|
-
def
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
embeddings =
|
|
97
|
-
|
|
98
|
-
try:
|
|
99
|
-
for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
|
|
100
|
-
response = client.embeddings.create(
|
|
101
|
-
input=batch, model=self.config.embedder_model_name
|
|
102
|
-
)
|
|
103
|
-
embeddings.extend([data.embedding for data in response.data])
|
|
104
|
-
except Exception as e:
|
|
105
|
-
raise self.wrap_error(e=e)
|
|
106
|
-
for element, embedding in zip(elements_with_text, embeddings):
|
|
107
|
-
element[EMBEDDINGS_KEY] = embedding
|
|
108
|
-
return elements
|
|
90
|
+
def get_client(self) -> "OpenAI":
    """Return the synchronous OpenAI-compatible client configured for OctoAI."""
    octo_config = self.config
    return octo_config.get_client()
|
|
92
|
+
|
|
93
|
+
def embed_batch(self, client: "OpenAI", batch: list[str]) -> list[list[float]]:
    """Embed *batch* with one ``embeddings.create`` call; vectors keep response order."""
    model_name = self.config.embedder_model_name
    response = client.embeddings.create(input=batch, model=model_name)
    vectors = []
    for item in response.data:
        vectors.append(item.embedding)
    return vectors
|
|
109
96
|
|
|
110
97
|
|
|
111
98
|
@dataclass
|
|
@@ -115,30 +102,11 @@ class AsyncOctoAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
|
|
|
115
102
|
def wrap_error(self, e: Exception) -> Exception:
    """Translate a raw OctoAI/OpenAI-SDK exception using the config's mapping."""
    octo_config = self.config
    return octo_config.wrap_error(e=e)
|
|
117
104
|
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
try:
|
|
121
|
-
response = await client.embeddings.create(
|
|
122
|
-
input=query, model=self.config.embedder_model_name
|
|
123
|
-
)
|
|
124
|
-
except Exception as e:
|
|
125
|
-
raise self.wrap_error(e=e)
|
|
126
|
-
return response.data[0].embedding
|
|
105
|
+
def get_client(self) -> "AsyncOpenAI":
    """Return the asynchronous OpenAI-compatible client configured for OctoAI."""
    octo_config = self.config
    return octo_config.get_async_client()
|
|
127
107
|
|
|
128
|
-
async def
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
embeddings = []
|
|
134
|
-
try:
|
|
135
|
-
for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
|
|
136
|
-
response = await client.embeddings.create(
|
|
137
|
-
input=batch, model=self.config.embedder_model_name
|
|
138
|
-
)
|
|
139
|
-
embeddings.extend([data.embedding for data in response.data])
|
|
140
|
-
except Exception as e:
|
|
141
|
-
raise self.wrap_error(e=e)
|
|
142
|
-
for element, embedding in zip(elements_with_text, embeddings):
|
|
143
|
-
element[EMBEDDINGS_KEY] = embedding
|
|
144
|
-
return elements
|
|
108
|
+
async def embed_batch(self, client: "AsyncOpenAI", batch: list[str]) -> list[list[float]]:
    """Await one ``embeddings.create`` call for *batch*; vectors keep response order."""
    model_name = self.config.embedder_model_name
    response = await client.embeddings.create(input=batch, model=model_name)
    return [item.embedding for item in response.data]
|
|
@@ -4,13 +4,11 @@ from typing import TYPE_CHECKING
|
|
|
4
4
|
from pydantic import Field, SecretStr
|
|
5
5
|
|
|
6
6
|
from unstructured_ingest.embed.interfaces import (
|
|
7
|
-
EMBEDDINGS_KEY,
|
|
8
7
|
AsyncBaseEmbeddingEncoder,
|
|
9
8
|
BaseEmbeddingEncoder,
|
|
10
9
|
EmbeddingConfig,
|
|
11
10
|
)
|
|
12
11
|
from unstructured_ingest.logger import logger
|
|
13
|
-
from unstructured_ingest.utils.data_prep import batch_generator
|
|
14
12
|
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
15
13
|
from unstructured_ingest.v2.errors import (
|
|
16
14
|
ProviderError,
|
|
@@ -18,6 +16,7 @@ from unstructured_ingest.v2.errors import (
|
|
|
18
16
|
RateLimitError,
|
|
19
17
|
UserAuthError,
|
|
20
18
|
UserError,
|
|
19
|
+
is_internal_error,
|
|
21
20
|
)
|
|
22
21
|
|
|
23
22
|
if TYPE_CHECKING:
|
|
@@ -29,6 +28,8 @@ class OpenAIEmbeddingConfig(EmbeddingConfig):
|
|
|
29
28
|
embedder_model_name: str = Field(default="text-embedding-ada-002", alias="model_name")
|
|
30
29
|
|
|
31
30
|
def wrap_error(self, e: Exception) -> Exception:
|
|
31
|
+
if is_internal_error(e=e):
|
|
32
|
+
return e
|
|
32
33
|
# https://platform.openai.com/docs/guides/error-codes/api-errors
|
|
33
34
|
from openai import APIStatusError
|
|
34
35
|
|
|
@@ -72,32 +73,12 @@ class OpenAIEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
72
73
|
def wrap_error(self, e: Exception) -> Exception:
    """Translate a raw OpenAI-SDK exception using the config's error mapping."""
    openai_config = self.config
    return openai_config.wrap_error(e=e)
|
|
74
75
|
|
|
75
|
-
def
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
|
|
81
|
-
raise self.wrap_error(e=e)
|
|
82
|
-
return response.data[0].embedding
|
|
83
|
-
|
|
84
|
-
def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
85
|
-
client = self.config.get_client()
|
|
86
|
-
elements = elements.copy()
|
|
87
|
-
elements_with_text = [e for e in elements if e.get("text")]
|
|
88
|
-
texts = [e["text"] for e in elements_with_text]
|
|
89
|
-
embeddings = []
|
|
90
|
-
try:
|
|
91
|
-
for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
|
|
92
|
-
response = client.embeddings.create(
|
|
93
|
-
input=batch, model=self.config.embedder_model_name
|
|
94
|
-
)
|
|
95
|
-
embeddings.extend([data.embedding for data in response.data])
|
|
96
|
-
except Exception as e:
|
|
97
|
-
raise self.wrap_error(e=e)
|
|
98
|
-
for element, embedding in zip(elements_with_text, embeddings):
|
|
99
|
-
element[EMBEDDINGS_KEY] = embedding
|
|
100
|
-
return elements
|
|
76
|
+
def get_client(self) -> "OpenAI":
    """Return the synchronous OpenAI client built by the config."""
    openai_config = self.config
    return openai_config.get_client()
|
|
78
|
+
|
|
79
|
+
def embed_batch(self, client: "OpenAI", batch: list[str]) -> list[list[float]]:
    """Embed *batch* with a single OpenAI ``embeddings.create`` request.

    Returns one vector per item in ``response.data``, in response order.
    """
    model_name = self.config.embedder_model_name
    response = client.embeddings.create(input=batch, model=model_name)
    vectors = []
    for item in response.data:
        vectors.append(item.embedding)
    return vectors
|
|
101
82
|
|
|
102
83
|
|
|
103
84
|
@dataclass
|
|
@@ -107,30 +88,11 @@ class AsyncOpenAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
|
|
|
107
88
|
def wrap_error(self, e: Exception) -> Exception:
    """Translate a raw OpenAI-SDK exception using the config's error mapping."""
    openai_config = self.config
    return openai_config.wrap_error(e=e)
|
|
109
90
|
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
return response.data[0].embedding
|
|
119
|
-
|
|
120
|
-
async def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
121
|
-
client = self.config.get_async_client()
|
|
122
|
-
elements = elements.copy()
|
|
123
|
-
elements_with_text = [e for e in elements if e.get("text")]
|
|
124
|
-
texts = [e["text"] for e in elements_with_text]
|
|
125
|
-
embeddings = []
|
|
126
|
-
try:
|
|
127
|
-
for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
|
|
128
|
-
response = await client.embeddings.create(
|
|
129
|
-
input=batch, model=self.config.embedder_model_name
|
|
130
|
-
)
|
|
131
|
-
embeddings.extend([data.embedding for data in response.data])
|
|
132
|
-
except Exception as e:
|
|
133
|
-
raise self.wrap_error(e=e)
|
|
134
|
-
for element, embedding in zip(elements_with_text, embeddings):
|
|
135
|
-
element[EMBEDDINGS_KEY] = embedding
|
|
136
|
-
return elements
|
|
91
|
+
def get_client(self) -> "AsyncOpenAI":
    """Return the asynchronous OpenAI client built by the config."""
    openai_config = self.config
    return openai_config.get_async_client()
|
|
93
|
+
|
|
94
|
+
async def embed_batch(self, client: "AsyncOpenAI", batch: list[str]) -> list[list[float]]:
    """Await a single OpenAI ``embeddings.create`` request for *batch*.

    Returns one vector per item in ``response.data``, in response order.
    """
    model_name = self.config.embedder_model_name
    response = await client.embeddings.create(input=batch, model=model_name)
    return [item.embedding for item in response.data]
|