unstructured-ingest 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of unstructured-ingest might be problematic; see the registry's advisory for this release for more details.
- test/integration/connectors/test_google_drive.py +141 -0
- test/unit/v2/embedders/test_bedrock.py +1 -1
- test/unit/v2/embedders/test_huggingface.py +1 -1
- unstructured_ingest/__version__.py +1 -1
- unstructured_ingest/embed/azure_openai.py +6 -0
- unstructured_ingest/embed/bedrock.py +29 -12
- unstructured_ingest/embed/huggingface.py +14 -5
- unstructured_ingest/embed/interfaces.py +63 -44
- unstructured_ingest/embed/mixedbreadai.py +28 -105
- unstructured_ingest/embed/octoai.py +19 -44
- unstructured_ingest/embed/openai.py +17 -48
- unstructured_ingest/embed/togetherai.py +16 -49
- unstructured_ingest/embed/vertexai.py +15 -39
- unstructured_ingest/embed/voyageai.py +16 -42
- unstructured_ingest/v2/errors.py +7 -0
- unstructured_ingest/v2/processes/connectors/google_drive.py +132 -3
- unstructured_ingest/v2/processes/connectors/neo4j.py +129 -43
- unstructured_ingest/v2/processes/connectors/sql/snowflake.py +53 -3
- unstructured_ingest/v2/processes/embedder.py +9 -7
- {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/METADATA +99 -87
- {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/RECORD +25 -25
- {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/WHEEL +1 -1
- {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/LICENSE.md +0 -0
- {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/entry_points.txt +0 -0
- {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/top_level.txt +0 -0
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
import os
|
|
2
|
+
import uuid
|
|
2
3
|
|
|
3
4
|
import pytest
|
|
5
|
+
from googleapiclient.errors import HttpError
|
|
4
6
|
|
|
5
7
|
from test.integration.connectors.utils.constants import (
|
|
6
8
|
SOURCE_TAG,
|
|
@@ -13,6 +15,9 @@ from test.integration.connectors.utils.validation.source import (
|
|
|
13
15
|
update_fixtures,
|
|
14
16
|
)
|
|
15
17
|
from test.integration.utils import requires_env
|
|
18
|
+
from unstructured_ingest.error import (
|
|
19
|
+
SourceConnectionError,
|
|
20
|
+
)
|
|
16
21
|
from unstructured_ingest.v2.interfaces import Downloader, Indexer
|
|
17
22
|
from unstructured_ingest.v2.processes.connectors.google_drive import (
|
|
18
23
|
CONNECTOR_TYPE,
|
|
@@ -25,6 +30,49 @@ from unstructured_ingest.v2.processes.connectors.google_drive import (
|
|
|
25
30
|
)
|
|
26
31
|
|
|
27
32
|
|
|
33
|
+
@pytest.fixture
|
|
34
|
+
def google_drive_connection_config():
|
|
35
|
+
"""
|
|
36
|
+
Build a valid GoogleDriveConnectionConfig using the environment variables.
|
|
37
|
+
Expects:
|
|
38
|
+
- GOOGLE_DRIVE_ID
|
|
39
|
+
- GOOGLE_DRIVE_SERVICE_KEY
|
|
40
|
+
"""
|
|
41
|
+
drive_id = os.getenv("GOOGLE_DRIVE_ID")
|
|
42
|
+
service_key = os.getenv("GOOGLE_DRIVE_SERVICE_KEY")
|
|
43
|
+
if not drive_id or not service_key:
|
|
44
|
+
pytest.skip("Google Drive credentials not provided in environment variables.")
|
|
45
|
+
|
|
46
|
+
access_config = GoogleDriveAccessConfig(service_account_key=service_key)
|
|
47
|
+
return GoogleDriveConnectionConfig(drive_id=drive_id, access_config=access_config)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@pytest.fixture
|
|
51
|
+
def google_drive_empty_folder(google_drive_connection_config):
|
|
52
|
+
"""
|
|
53
|
+
Creates an empty folder on Google Drive for testing the "empty folder" case.
|
|
54
|
+
The folder is deleted after the test.
|
|
55
|
+
"""
|
|
56
|
+
from google.oauth2 import service_account
|
|
57
|
+
from googleapiclient.discovery import build
|
|
58
|
+
|
|
59
|
+
access_config = google_drive_connection_config.access_config.get_secret_value()
|
|
60
|
+
creds = service_account.Credentials.from_service_account_info(access_config.service_account_key)
|
|
61
|
+
service = build("drive", "v3", credentials=creds)
|
|
62
|
+
|
|
63
|
+
# Create an empty folder.
|
|
64
|
+
file_metadata = {
|
|
65
|
+
"name": f"utic-empty-folder-{uuid.uuid4()}",
|
|
66
|
+
"mimeType": "application/vnd.google-apps.folder",
|
|
67
|
+
}
|
|
68
|
+
folder = service.files().create(body=file_metadata, fields="id, name").execute()
|
|
69
|
+
folder_id = folder.get("id")
|
|
70
|
+
try:
|
|
71
|
+
yield folder_id
|
|
72
|
+
finally:
|
|
73
|
+
service.files().delete(fileId=folder_id).execute()
|
|
74
|
+
|
|
75
|
+
|
|
28
76
|
@requires_env("GOOGLE_DRIVE_SERVICE_KEY")
|
|
29
77
|
@pytest.mark.tags(SOURCE_TAG, CONNECTOR_TYPE)
|
|
30
78
|
def test_google_drive_source(temp_dir):
|
|
@@ -114,3 +162,96 @@ def source_connector_validation(
|
|
|
114
162
|
save_downloads=configs.validate_downloaded_files,
|
|
115
163
|
save_filedata=configs.validate_file_data,
|
|
116
164
|
)
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
# Precheck fails when the drive ID has an appended parameter (simulate copy-paste error)
|
|
168
|
+
@pytest.mark.tags("google-drive", "precheck")
|
|
169
|
+
@requires_env("GOOGLE_DRIVE_ID", "GOOGLE_DRIVE_SERVICE_KEY")
|
|
170
|
+
def test_google_drive_precheck_invalid_parameter(google_drive_connection_config):
|
|
171
|
+
# Append a query parameter as often happens when copying from a URL.
|
|
172
|
+
invalid_drive_id = google_drive_connection_config.drive_id + "?usp=sharing"
|
|
173
|
+
connection_config = GoogleDriveConnectionConfig(
|
|
174
|
+
drive_id=invalid_drive_id,
|
|
175
|
+
access_config=google_drive_connection_config.access_config,
|
|
176
|
+
)
|
|
177
|
+
index_config = GoogleDriveIndexerConfig(recursive=True)
|
|
178
|
+
indexer = GoogleDriveIndexer(connection_config=connection_config, index_config=index_config)
|
|
179
|
+
with pytest.raises(SourceConnectionError) as excinfo:
|
|
180
|
+
indexer.precheck()
|
|
181
|
+
assert "invalid" in str(excinfo.value).lower() or "not found" in str(excinfo.value).lower()
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
# Precheck fails due to lack of permission (simulate via monkeypatching).
|
|
185
|
+
@pytest.mark.tags("google-drive", "precheck")
|
|
186
|
+
@requires_env("GOOGLE_DRIVE_ID", "GOOGLE_DRIVE_SERVICE_KEY")
|
|
187
|
+
def test_google_drive_precheck_no_permission(google_drive_connection_config, monkeypatch):
|
|
188
|
+
index_config = GoogleDriveIndexerConfig(recursive=True)
|
|
189
|
+
indexer = GoogleDriveIndexer(
|
|
190
|
+
connection_config=google_drive_connection_config,
|
|
191
|
+
index_config=index_config,
|
|
192
|
+
)
|
|
193
|
+
|
|
194
|
+
# Monkeypatch get_root_info to always raise an HTTP 403 error.
|
|
195
|
+
def fake_get_root_info(files_client, object_id):
|
|
196
|
+
raise HttpError(
|
|
197
|
+
resp=type("Response", (), {"status": 403, "reason": "Forbidden"})(),
|
|
198
|
+
content=b"Forbidden",
|
|
199
|
+
)
|
|
200
|
+
|
|
201
|
+
monkeypatch.setattr(indexer, "get_root_info", fake_get_root_info)
|
|
202
|
+
with pytest.raises(SourceConnectionError) as excinfo:
|
|
203
|
+
indexer.precheck()
|
|
204
|
+
assert "forbidden" in str(excinfo.value).lower() or "permission" in str(excinfo.value).lower()
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
# Precheck fails when the folder is empty.
|
|
208
|
+
# @pytest.mark.tags("google-drive", "precheck")
|
|
209
|
+
# @requires_env("GOOGLE_DRIVE_ID", "GOOGLE_DRIVE_SERVICE_KEY")
|
|
210
|
+
# def test_google_drive_precheck_empty_folder(
|
|
211
|
+
# google_drive_connection_config, google_drive_empty_folder
|
|
212
|
+
# ):
|
|
213
|
+
# # Use the empty folder's ID as the target.
|
|
214
|
+
# connection_config = GoogleDriveConnectionConfig(
|
|
215
|
+
# drive_id=google_drive_empty_folder,
|
|
216
|
+
# access_config=google_drive_connection_config.access_config,
|
|
217
|
+
# )
|
|
218
|
+
|
|
219
|
+
# index_config = GoogleDriveIndexerConfig(recursive=True)
|
|
220
|
+
# indexer = GoogleDriveIndexer(connection_config=connection_config, index_config=index_config)
|
|
221
|
+
# with pytest.raises(SourceConnectionError) as excinfo:
|
|
222
|
+
# indexer.precheck()
|
|
223
|
+
# assert "empty folder" in str(excinfo.value).lower()
|
|
224
|
+
|
|
225
|
+
|
|
226
|
+
@pytest.mark.tags("google-drive", "count", "integration")
|
|
227
|
+
@requires_env("GOOGLE_DRIVE_ID", "GOOGLE_DRIVE_SERVICE_KEY")
|
|
228
|
+
def test_google_drive_count_files(google_drive_connection_config):
|
|
229
|
+
"""
|
|
230
|
+
This test verifies that the count_files_recursively method returns the expected count of files.
|
|
231
|
+
According to the test credentials, there are 3 files in the root directory and 1 nested file,
|
|
232
|
+
so the total count should be 4.
|
|
233
|
+
"""
|
|
234
|
+
# I assumed that we're applying the same extension filter as with other tests
|
|
235
|
+
# However there's 6 files in total in the test dir
|
|
236
|
+
extensions_filter = ["pdf", "docx"]
|
|
237
|
+
with google_drive_connection_config.get_client() as client:
|
|
238
|
+
count = GoogleDriveIndexer.count_files_recursively(
|
|
239
|
+
client, google_drive_connection_config.drive_id, extensions_filter
|
|
240
|
+
)
|
|
241
|
+
assert count == 4, f"Expected file count of 4, but got {count}"
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
# Precheck fails with a completely invalid drive ID.
|
|
245
|
+
@pytest.mark.tags("google-drive", "precheck")
|
|
246
|
+
@requires_env("GOOGLE_DRIVE_ID", "GOOGLE_DRIVE_SERVICE_KEY")
|
|
247
|
+
def test_google_drive_precheck_invalid_drive_id(google_drive_connection_config):
|
|
248
|
+
invalid_drive_id = "invalid_drive_id"
|
|
249
|
+
connection_config = GoogleDriveConnectionConfig(
|
|
250
|
+
drive_id=invalid_drive_id,
|
|
251
|
+
access_config=google_drive_connection_config.access_config,
|
|
252
|
+
)
|
|
253
|
+
index_config = GoogleDriveIndexerConfig(recursive=True)
|
|
254
|
+
indexer = GoogleDriveIndexer(connection_config=connection_config, index_config=index_config)
|
|
255
|
+
with pytest.raises(SourceConnectionError) as excinfo:
|
|
256
|
+
indexer.precheck()
|
|
257
|
+
assert "invalid" in str(excinfo.value).lower() or "not found" in str(excinfo.value).lower()
|
|
@@ -16,7 +16,7 @@ fake = faker.Faker()
|
|
|
16
16
|
def generate_embedder_config_params() -> dict:
|
|
17
17
|
params = {}
|
|
18
18
|
if random.random() < 0.5:
|
|
19
|
-
params["
|
|
19
|
+
params["embedder_model_name"] = fake.word() if random.random() < 0.5 else None
|
|
20
20
|
params["embedder_model_kwargs"] = (
|
|
21
21
|
generate_random_dictionary(key_type=str, value_type=Any)
|
|
22
22
|
if random.random() < 0.5
|
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "0.5.
|
|
1
|
+
__version__ = "0.5.3" # pragma: no cover
|
|
@@ -44,7 +44,13 @@ class AzureOpenAIEmbeddingConfig(OpenAIEmbeddingConfig):
|
|
|
44
44
|
class AzureOpenAIEmbeddingEncoder(OpenAIEmbeddingEncoder):
|
|
45
45
|
config: AzureOpenAIEmbeddingConfig
|
|
46
46
|
|
|
47
|
+
def get_client(self) -> "AzureOpenAI":
|
|
48
|
+
return self.config.get_client()
|
|
49
|
+
|
|
47
50
|
|
|
48
51
|
@dataclass
|
|
49
52
|
class AsyncAzureOpenAIEmbeddingEncoder(AsyncOpenAIEmbeddingEncoder):
|
|
50
53
|
config: AzureOpenAIEmbeddingConfig
|
|
54
|
+
|
|
55
|
+
def get_client(self) -> "AsyncAzureOpenAI":
|
|
56
|
+
return self.config.get_async_client()
|
|
@@ -8,13 +8,20 @@ from typing import TYPE_CHECKING, AsyncIterable
|
|
|
8
8
|
from pydantic import Field, SecretStr
|
|
9
9
|
|
|
10
10
|
from unstructured_ingest.embed.interfaces import (
|
|
11
|
+
EMBEDDINGS_KEY,
|
|
11
12
|
AsyncBaseEmbeddingEncoder,
|
|
12
13
|
BaseEmbeddingEncoder,
|
|
13
14
|
EmbeddingConfig,
|
|
14
15
|
)
|
|
15
16
|
from unstructured_ingest.logger import logger
|
|
16
17
|
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
17
|
-
from unstructured_ingest.v2.errors import
|
|
18
|
+
from unstructured_ingest.v2.errors import (
|
|
19
|
+
ProviderError,
|
|
20
|
+
RateLimitError,
|
|
21
|
+
UserAuthError,
|
|
22
|
+
UserError,
|
|
23
|
+
is_internal_error,
|
|
24
|
+
)
|
|
18
25
|
|
|
19
26
|
if TYPE_CHECKING:
|
|
20
27
|
from botocore.client import BaseClient
|
|
@@ -50,9 +57,11 @@ class BedrockEmbeddingConfig(EmbeddingConfig):
|
|
|
50
57
|
aws_access_key_id: SecretStr
|
|
51
58
|
aws_secret_access_key: SecretStr
|
|
52
59
|
region_name: str = "us-west-2"
|
|
53
|
-
|
|
60
|
+
embedder_model_name: str = Field(default="amazon.titan-embed-text-v1", alias="model_name")
|
|
54
61
|
|
|
55
62
|
def wrap_error(self, e: Exception) -> Exception:
|
|
63
|
+
if is_internal_error(e=e):
|
|
64
|
+
return e
|
|
56
65
|
from botocore.exceptions import ClientError
|
|
57
66
|
|
|
58
67
|
if isinstance(e, ClientError):
|
|
@@ -121,7 +130,7 @@ class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
121
130
|
|
|
122
131
|
def embed_query(self, query: str) -> list[float]:
|
|
123
132
|
"""Call out to Bedrock embedding endpoint."""
|
|
124
|
-
provider = self.config.
|
|
133
|
+
provider = self.config.embedder_model_name.split(".")[0]
|
|
125
134
|
body = conform_query(query=query, provider=provider)
|
|
126
135
|
|
|
127
136
|
bedrock_client = self.config.get_client()
|
|
@@ -129,7 +138,7 @@ class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
129
138
|
try:
|
|
130
139
|
response = bedrock_client.invoke_model(
|
|
131
140
|
body=json.dumps(body),
|
|
132
|
-
modelId=self.config.
|
|
141
|
+
modelId=self.config.embedder_model_name,
|
|
133
142
|
accept="application/json",
|
|
134
143
|
contentType="application/json",
|
|
135
144
|
)
|
|
@@ -145,9 +154,14 @@ class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
145
154
|
return response_body.get("embedding")
|
|
146
155
|
|
|
147
156
|
def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
157
|
+
elements = elements.copy()
|
|
158
|
+
elements_with_text = [e for e in elements if e.get("text")]
|
|
159
|
+
if not elements_with_text:
|
|
160
|
+
return elements
|
|
161
|
+
embeddings = [self.embed_query(query=e["text"]) for e in elements_with_text]
|
|
162
|
+
for element, embedding in zip(elements_with_text, embeddings):
|
|
163
|
+
element[EMBEDDINGS_KEY] = embedding
|
|
164
|
+
return elements
|
|
151
165
|
|
|
152
166
|
|
|
153
167
|
@dataclass
|
|
@@ -159,7 +173,7 @@ class AsyncBedrockEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
|
|
|
159
173
|
|
|
160
174
|
async def embed_query(self, query: str) -> list[float]:
|
|
161
175
|
"""Call out to Bedrock embedding endpoint."""
|
|
162
|
-
provider = self.config.
|
|
176
|
+
provider = self.config.embedder_model_name.split(".")[0]
|
|
163
177
|
body = conform_query(query=query, provider=provider)
|
|
164
178
|
try:
|
|
165
179
|
async with self.config.get_async_client() as bedrock_client:
|
|
@@ -167,7 +181,7 @@ class AsyncBedrockEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
|
|
|
167
181
|
try:
|
|
168
182
|
response = await bedrock_client.invoke_model(
|
|
169
183
|
body=json.dumps(body),
|
|
170
|
-
modelId=self.config.
|
|
184
|
+
modelId=self.config.embedder_model_name,
|
|
171
185
|
accept="application/json",
|
|
172
186
|
contentType="application/json",
|
|
173
187
|
)
|
|
@@ -186,8 +200,11 @@ class AsyncBedrockEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
|
|
|
186
200
|
raise ValueError(f"Error raised by inference endpoint: {e}")
|
|
187
201
|
|
|
188
202
|
async def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
203
|
+
elements = elements.copy()
|
|
204
|
+
elements_with_text = [e for e in elements if e.get("text")]
|
|
189
205
|
embeddings = await asyncio.gather(
|
|
190
|
-
*[self.embed_query(query=e.get("text", "")) for e in
|
|
206
|
+
*[self.embed_query(query=e.get("text", "")) for e in elements_with_text]
|
|
191
207
|
)
|
|
192
|
-
|
|
193
|
-
|
|
208
|
+
for element, embedding in zip(elements_with_text, embeddings):
|
|
209
|
+
element[EMBEDDINGS_KEY] = embedding
|
|
210
|
+
return elements
|
|
@@ -3,7 +3,11 @@ from typing import TYPE_CHECKING, Optional
|
|
|
3
3
|
|
|
4
4
|
from pydantic import Field
|
|
5
5
|
|
|
6
|
-
from unstructured_ingest.embed.interfaces import
|
|
6
|
+
from unstructured_ingest.embed.interfaces import (
|
|
7
|
+
EMBEDDINGS_KEY,
|
|
8
|
+
BaseEmbeddingEncoder,
|
|
9
|
+
EmbeddingConfig,
|
|
10
|
+
)
|
|
7
11
|
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
8
12
|
|
|
9
13
|
if TYPE_CHECKING:
|
|
@@ -43,7 +47,7 @@ class HuggingFaceEmbeddingConfig(EmbeddingConfig):
|
|
|
43
47
|
class HuggingFaceEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
44
48
|
config: HuggingFaceEmbeddingConfig
|
|
45
49
|
|
|
46
|
-
def
|
|
50
|
+
def _embed_query(self, query: str) -> list[float]:
|
|
47
51
|
return self._embed_documents(texts=[query])[0]
|
|
48
52
|
|
|
49
53
|
def _embed_documents(self, texts: list[str]) -> list[list[float]]:
|
|
@@ -52,6 +56,11 @@ class HuggingFaceEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
52
56
|
return embeddings.tolist()
|
|
53
57
|
|
|
54
58
|
def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
59
|
+
elements = elements.copy()
|
|
60
|
+
elements_with_text = [e for e in elements if e.get("text")]
|
|
61
|
+
if not elements_with_text:
|
|
62
|
+
return elements
|
|
63
|
+
embeddings = self._embed_documents([e["text"] for e in elements_with_text])
|
|
64
|
+
for element, embedding in zip(elements_with_text, embeddings):
|
|
65
|
+
element[EMBEDDINGS_KEY] = embedding
|
|
66
|
+
return elements
|
|
@@ -1,11 +1,14 @@
|
|
|
1
|
-
import
|
|
2
|
-
from abc import ABC, abstractmethod
|
|
1
|
+
from abc import ABC
|
|
3
2
|
from dataclasses import dataclass
|
|
4
|
-
from typing import Optional
|
|
3
|
+
from typing import Any, Optional
|
|
5
4
|
|
|
6
5
|
import numpy as np
|
|
7
6
|
from pydantic import BaseModel, Field
|
|
8
7
|
|
|
8
|
+
from unstructured_ingest.utils.data_prep import batch_generator
|
|
9
|
+
|
|
10
|
+
EMBEDDINGS_KEY = "embeddings"
|
|
11
|
+
|
|
9
12
|
|
|
10
13
|
class EmbeddingConfig(BaseModel):
|
|
11
14
|
batch_size: Optional[int] = Field(
|
|
@@ -26,27 +29,6 @@ class BaseEncoder(ABC):
|
|
|
26
29
|
if possible"""
|
|
27
30
|
return e
|
|
28
31
|
|
|
29
|
-
@staticmethod
|
|
30
|
-
def _add_embeddings_to_elements(
|
|
31
|
-
elements: list[dict], embeddings: list[list[float]]
|
|
32
|
-
) -> list[dict]:
|
|
33
|
-
"""
|
|
34
|
-
Add embeddings to elements.
|
|
35
|
-
|
|
36
|
-
Args:
|
|
37
|
-
elements (list[Element]): List of elements.
|
|
38
|
-
embeddings (list[list[float]]): List of embeddings.
|
|
39
|
-
|
|
40
|
-
Returns:
|
|
41
|
-
list[Element]: Elements with embeddings added.
|
|
42
|
-
"""
|
|
43
|
-
assert len(elements) == len(embeddings)
|
|
44
|
-
elements_w_embedding = []
|
|
45
|
-
for i, element in enumerate(elements):
|
|
46
|
-
element["embeddings"] = embeddings[i]
|
|
47
|
-
elements_w_embedding.append(element)
|
|
48
|
-
return elements
|
|
49
|
-
|
|
50
32
|
|
|
51
33
|
@dataclass
|
|
52
34
|
class BaseEmbeddingEncoder(BaseEncoder, ABC):
|
|
@@ -69,21 +51,37 @@ class BaseEmbeddingEncoder(BaseEncoder, ABC):
|
|
|
69
51
|
exemplary_embedding = self.get_exemplary_embedding()
|
|
70
52
|
return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
|
|
71
53
|
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
pass
|
|
54
|
+
def get_client(self):
|
|
55
|
+
raise NotImplementedError
|
|
75
56
|
|
|
76
|
-
|
|
77
|
-
|
|
78
|
-
pass
|
|
57
|
+
def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
|
|
58
|
+
raise NotImplementedError
|
|
79
59
|
|
|
80
|
-
def
|
|
81
|
-
|
|
82
|
-
|
|
83
|
-
|
|
84
|
-
|
|
60
|
+
def embed_documents(self, elements: list[dict]) -> list[dict]:
    """Embed the ``"text"`` of every element that has one, batch by batch.

    Texts are sent to ``self.embed_batch`` in chunks of
    ``self.config.batch_size`` (all at once when no batch size is set). The
    resulting vectors are written back, in order, under ``EMBEDDINGS_KEY``.
    Elements with missing or empty text are returned untouched.

    Note: only the list is copied (shallow ``list.copy()``); the element
    dicts themselves are updated in place.

    Args:
        elements: Element dicts, each optionally carrying a ``"text"`` entry.

    Returns:
        The copied list of elements; text-bearing ones now hold an embedding.

    Raises:
        Any provider/client failure, passed through ``self.wrap_error``.
        ValueError: if the provider returns a different number of embeddings
            than the number of texts sent.
    """
    elements = elements.copy()
    elements_with_text = [e for e in elements if e.get("text")]
    if not elements_with_text:
        # Nothing to embed: avoid constructing a client or calling the provider.
        return elements
    texts = [e["text"] for e in elements_with_text]
    client = self.get_client()
    embeddings: list[list[float]] = []
    try:
        for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
            # Accumulate into a separate list. Rebinding `embeddings` to the batch
            # result would keep only the last batch and misalign vectors with elements.
            embeddings.extend(self.embed_batch(client=client, batch=batch))
    except Exception as e:
        raise self.wrap_error(e=e)
    # zip() truncates silently, so check the counts explicitly to avoid
    # dropping or misassigning embeddings.
    if len(embeddings) != len(elements_with_text):
        raise ValueError(
            f"Expected {len(elements_with_text)} embeddings, got {len(embeddings)}"
        )
    for element, embedding in zip(elements_with_text, embeddings):
        element[EMBEDDINGS_KEY] = embedding
    return elements
|
|
85
75
|
|
|
86
|
-
|
|
76
|
+
def _embed_query(self, query: str) -> list[float]:
|
|
77
|
+
client = self.get_client()
|
|
78
|
+
return self.embed_batch(client=client, batch=[query])[0]
|
|
79
|
+
|
|
80
|
+
def embed_query(self, query: str) -> list[float]:
|
|
81
|
+
try:
|
|
82
|
+
return self._embed_query(query=query)
|
|
83
|
+
except Exception as e:
|
|
84
|
+
raise self.wrap_error(e=e)
|
|
87
85
|
|
|
88
86
|
|
|
89
87
|
@dataclass
|
|
@@ -107,14 +105,35 @@ class AsyncBaseEmbeddingEncoder(BaseEncoder, ABC):
|
|
|
107
105
|
exemplary_embedding = await self.get_exemplary_embedding()
|
|
108
106
|
return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
|
|
109
107
|
|
|
110
|
-
|
|
108
|
+
def get_client(self):
|
|
109
|
+
raise NotImplementedError
|
|
110
|
+
|
|
111
|
+
async def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
|
|
112
|
+
raise NotImplementedError
|
|
113
|
+
|
|
111
114
|
async def embed_documents(self, elements: list[dict]) -> list[dict]:
    """Asynchronously embed the ``"text"`` of every element that has one.

    Texts are sent to ``await self.embed_batch`` in chunks of
    ``self.config.batch_size`` (all at once when no batch size is set). The
    resulting vectors are written back, in order, under ``EMBEDDINGS_KEY``.
    Batches are awaited one after another. Elements with missing or empty text
    are returned untouched.

    Note: only the list is copied (shallow ``list.copy()``); the element
    dicts themselves are updated in place.

    Args:
        elements: Element dicts, each optionally carrying a ``"text"`` entry.

    Returns:
        The copied list of elements; text-bearing ones now hold an embedding.

    Raises:
        Any provider/client failure, passed through ``self.wrap_error``.
        ValueError: if the provider returns a different number of embeddings
            than the number of texts sent.
    """
    elements = elements.copy()
    elements_with_text = [e for e in elements if e.get("text")]
    if not elements_with_text:
        # Nothing to embed: avoid constructing a client or calling the provider.
        return elements
    texts = [e["text"] for e in elements_with_text]
    client = self.get_client()
    embeddings: list[list[float]] = []
    try:
        for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
            # Accumulate into a separate list. Rebinding `embeddings` to the batch
            # result would keep only the last batch and misalign vectors with elements.
            batch_embeddings = await self.embed_batch(client=client, batch=batch)
            embeddings.extend(batch_embeddings)
    except Exception as e:
        raise self.wrap_error(e=e)
    # zip() truncates silently, so check the counts explicitly to avoid
    # dropping or misassigning embeddings.
    if len(embeddings) != len(elements_with_text):
        raise ValueError(
            f"Expected {len(elements_with_text)} embeddings, got {len(embeddings)}"
        )
    for element, embedding in zip(elements_with_text, embeddings):
        element[EMBEDDINGS_KEY] = embedding
    return elements
|
|
113
129
|
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
130
|
+
async def _embed_query(self, query: str) -> list[float]:
|
|
131
|
+
client = self.get_client()
|
|
132
|
+
embeddings = await self.embed_batch(client=client, batch=[query])
|
|
133
|
+
return embeddings[0]
|
|
117
134
|
|
|
118
|
-
async def
|
|
119
|
-
|
|
120
|
-
|
|
135
|
+
async def embed_query(self, query: str) -> list[float]:
|
|
136
|
+
try:
|
|
137
|
+
return await self._embed_query(query=query)
|
|
138
|
+
except Exception as e:
|
|
139
|
+
raise self.wrap_error(e=e)
|
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import asyncio
|
|
2
1
|
import os
|
|
3
2
|
from dataclasses import dataclass
|
|
4
3
|
from typing import TYPE_CHECKING
|
|
@@ -10,7 +9,6 @@ from unstructured_ingest.embed.interfaces import (
|
|
|
10
9
|
BaseEmbeddingEncoder,
|
|
11
10
|
EmbeddingConfig,
|
|
12
11
|
)
|
|
13
|
-
from unstructured_ingest.utils.data_prep import batch_generator
|
|
14
12
|
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
15
13
|
|
|
16
14
|
USER_AGENT = "@mixedbread-ai/unstructured"
|
|
@@ -84,7 +82,7 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
84
82
|
|
|
85
83
|
def get_exemplary_embedding(self) -> list[float]:
|
|
86
84
|
"""Get an exemplary embedding to determine dimensions and unit vector status."""
|
|
87
|
-
return self.
|
|
85
|
+
return self.embed_query(query="Q")
|
|
88
86
|
|
|
89
87
|
@requires_dependencies(
|
|
90
88
|
["mixedbread_ai"],
|
|
@@ -99,55 +97,19 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
99
97
|
additional_headers={"User-Agent": USER_AGENT},
|
|
100
98
|
)
|
|
101
99
|
|
|
102
|
-
def
|
|
103
|
-
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
|
|
116
|
-
response = client.embeddings(
|
|
117
|
-
model=self.config.embedder_model_name,
|
|
118
|
-
normalized=True,
|
|
119
|
-
encoding_format=ENCODING_FORMAT,
|
|
120
|
-
truncation_strategy=TRUNCATION_STRATEGY,
|
|
121
|
-
request_options=self.get_request_options(),
|
|
122
|
-
input=batch,
|
|
123
|
-
)
|
|
124
|
-
responses.append(response)
|
|
125
|
-
return [item.embedding for response in responses for item in response.data]
|
|
126
|
-
|
|
127
|
-
def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
128
|
-
"""
|
|
129
|
-
Embed a list of document elements.
|
|
130
|
-
|
|
131
|
-
Args:
|
|
132
|
-
elements (list[Element]): List of document elements.
|
|
133
|
-
|
|
134
|
-
Returns:
|
|
135
|
-
list[Element]: Elements with embeddings.
|
|
136
|
-
"""
|
|
137
|
-
embeddings = self._embed([e.get("text", "") for e in elements])
|
|
138
|
-
return self._add_embeddings_to_elements(elements, embeddings)
|
|
139
|
-
|
|
140
|
-
def embed_query(self, query: str) -> list[float]:
|
|
141
|
-
"""
|
|
142
|
-
Embed a query string.
|
|
143
|
-
|
|
144
|
-
Args:
|
|
145
|
-
query (str): Query string to embed.
|
|
146
|
-
|
|
147
|
-
Returns:
|
|
148
|
-
list[float]: Embedding of the query.
|
|
149
|
-
"""
|
|
150
|
-
return self._embed([query])[0]
|
|
100
|
+
def get_client(self) -> "MixedbreadAI":
|
|
101
|
+
return self.config.get_client()
|
|
102
|
+
|
|
103
|
+
def embed_batch(self, client: "MixedbreadAI", batch: list[str]) -> list[list[float]]:
|
|
104
|
+
response = client.embeddings(
|
|
105
|
+
model=self.config.embedder_model_name,
|
|
106
|
+
normalized=True,
|
|
107
|
+
encoding_format=ENCODING_FORMAT,
|
|
108
|
+
truncation_strategy=TRUNCATION_STRATEGY,
|
|
109
|
+
request_options=self.get_request_options(),
|
|
110
|
+
input=batch,
|
|
111
|
+
)
|
|
112
|
+
return [datum.embedding for datum in response.data]
|
|
151
113
|
|
|
152
114
|
|
|
153
115
|
@dataclass
|
|
@@ -157,8 +119,7 @@ class AsyncMixedbreadAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
|
|
|
157
119
|
|
|
158
120
|
async def get_exemplary_embedding(self) -> list[float]:
|
|
159
121
|
"""Get an exemplary embedding to determine dimensions and unit vector status."""
|
|
160
|
-
|
|
161
|
-
return embedding[0]
|
|
122
|
+
return await self.embed_query(query="Q")
|
|
162
123
|
|
|
163
124
|
@requires_dependencies(
|
|
164
125
|
["mixedbread_ai"],
|
|
@@ -173,54 +134,16 @@ class AsyncMixedbreadAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
|
|
|
173
134
|
additional_headers={"User-Agent": USER_AGENT},
|
|
174
135
|
)
|
|
175
136
|
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
for
|
|
189
|
-
tasks.append(
|
|
190
|
-
client.embeddings(
|
|
191
|
-
model=self.config.embedder_model_name,
|
|
192
|
-
normalized=True,
|
|
193
|
-
encoding_format=ENCODING_FORMAT,
|
|
194
|
-
truncation_strategy=TRUNCATION_STRATEGY,
|
|
195
|
-
request_options=self.get_request_options(),
|
|
196
|
-
input=batch,
|
|
197
|
-
)
|
|
198
|
-
)
|
|
199
|
-
responses = await asyncio.gather(*tasks)
|
|
200
|
-
return [item.embedding for response in responses for item in response.data]
|
|
201
|
-
|
|
202
|
-
async def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
203
|
-
"""
|
|
204
|
-
Embed a list of document elements.
|
|
205
|
-
|
|
206
|
-
Args:
|
|
207
|
-
elements (list[Element]): List of document elements.
|
|
208
|
-
|
|
209
|
-
Returns:
|
|
210
|
-
list[Element]: Elements with embeddings.
|
|
211
|
-
"""
|
|
212
|
-
embeddings = await self._embed([e.get("text", "") for e in elements])
|
|
213
|
-
return self._add_embeddings_to_elements(elements, embeddings)
|
|
214
|
-
|
|
215
|
-
async def embed_query(self, query: str) -> list[float]:
|
|
216
|
-
"""
|
|
217
|
-
Embed a query string.
|
|
218
|
-
|
|
219
|
-
Args:
|
|
220
|
-
query (str): Query string to embed.
|
|
221
|
-
|
|
222
|
-
Returns:
|
|
223
|
-
list[float]: Embedding of the query.
|
|
224
|
-
"""
|
|
225
|
-
embedding = await self._embed([query])
|
|
226
|
-
return embedding[0]
|
|
137
|
+
def get_client(self) -> "AsyncMixedbreadAI":
|
|
138
|
+
return self.config.get_async_client()
|
|
139
|
+
|
|
140
|
+
async def embed_batch(self, client: "AsyncMixedbreadAI", batch: list[str]) -> list[list[float]]:
|
|
141
|
+
response = await client.embeddings(
|
|
142
|
+
model=self.config.embedder_model_name,
|
|
143
|
+
normalized=True,
|
|
144
|
+
encoding_format=ENCODING_FORMAT,
|
|
145
|
+
truncation_strategy=TRUNCATION_STRATEGY,
|
|
146
|
+
request_options=self.get_request_options(),
|
|
147
|
+
input=batch,
|
|
148
|
+
)
|
|
149
|
+
return [datum.embedding for datum in response.data]
|