unstructured-ingest 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of unstructured-ingest might be problematic. Click here for more details.
- test/integration/connectors/test_sharepoint.py +161 -10
- test/unit/v2/embedders/test_bedrock.py +1 -1
- test/unit/v2/embedders/test_huggingface.py +1 -1
- unstructured_ingest/__version__.py +1 -1
- unstructured_ingest/embed/azure_openai.py +6 -0
- unstructured_ingest/embed/bedrock.py +16 -6
- unstructured_ingest/embed/huggingface.py +3 -1
- unstructured_ingest/embed/interfaces.py +61 -23
- unstructured_ingest/embed/mixedbreadai.py +28 -114
- unstructured_ingest/embed/octoai.py +19 -51
- unstructured_ingest/embed/openai.py +17 -55
- unstructured_ingest/embed/togetherai.py +16 -58
- unstructured_ingest/embed/vertexai.py +15 -46
- unstructured_ingest/embed/voyageai.py +17 -52
- unstructured_ingest/v2/errors.py +7 -0
- unstructured_ingest/v2/processes/connectors/neo4j.py +129 -43
- unstructured_ingest/v2/processes/connectors/sharepoint.py +9 -4
- unstructured_ingest/v2/processes/embedder.py +9 -7
- {unstructured_ingest-0.5.2.dist-info → unstructured_ingest-0.5.4.dist-info}/METADATA +101 -89
- {unstructured_ingest-0.5.2.dist-info → unstructured_ingest-0.5.4.dist-info}/RECORD +24 -24
- {unstructured_ingest-0.5.2.dist-info → unstructured_ingest-0.5.4.dist-info}/WHEEL +1 -1
- {unstructured_ingest-0.5.2.dist-info → unstructured_ingest-0.5.4.dist-info}/LICENSE.md +0 -0
- {unstructured_ingest-0.5.2.dist-info → unstructured_ingest-0.5.4.dist-info}/entry_points.txt +0 -0
- {unstructured_ingest-0.5.2.dist-info → unstructured_ingest-0.5.4.dist-info}/top_level.txt +0 -0
|
@@ -19,24 +19,31 @@ from unstructured_ingest.v2.processes.connectors.sharepoint import (
|
|
|
19
19
|
)
|
|
20
20
|
|
|
21
21
|
|
|
22
|
+
def sharepoint_config():
|
|
23
|
+
class SharepointTestConfig:
|
|
24
|
+
def __init__(self):
|
|
25
|
+
self.client_id = os.environ["SHAREPOINT_CLIENT_ID"]
|
|
26
|
+
self.client_cred = os.environ["SHAREPOINT_CRED"]
|
|
27
|
+
self.user_pname = os.environ["MS_USER_PNAME"]
|
|
28
|
+
self.tenant = os.environ["MS_TENANT_ID"]
|
|
29
|
+
|
|
30
|
+
return SharepointTestConfig()
|
|
31
|
+
|
|
32
|
+
|
|
22
33
|
@pytest.mark.asyncio
|
|
23
34
|
@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG, BLOB_STORAGE_TAG)
|
|
24
35
|
@requires_env("SHAREPOINT_CLIENT_ID", "SHAREPOINT_CRED", "MS_TENANT_ID", "MS_USER_PNAME")
|
|
25
36
|
async def test_sharepoint_source(temp_dir):
|
|
26
|
-
# Retrieve environment variables
|
|
27
37
|
site = "https://unstructuredio.sharepoint.com/sites/utic-platform-test-source"
|
|
28
|
-
|
|
29
|
-
client_cred = os.environ["SHAREPOINT_CRED"]
|
|
30
|
-
user_pname = os.environ["MS_USER_PNAME"]
|
|
31
|
-
tenant = os.environ["MS_TENANT_ID"]
|
|
38
|
+
config = sharepoint_config()
|
|
32
39
|
|
|
33
40
|
# Create connection and indexer configurations
|
|
34
|
-
access_config = SharepointAccessConfig(client_cred=client_cred)
|
|
41
|
+
access_config = SharepointAccessConfig(client_cred=config.client_cred)
|
|
35
42
|
connection_config = SharepointConnectionConfig(
|
|
36
|
-
client_id=client_id,
|
|
43
|
+
client_id=config.client_id,
|
|
37
44
|
site=site,
|
|
38
|
-
tenant=tenant,
|
|
39
|
-
user_pname=user_pname,
|
|
45
|
+
tenant=config.tenant,
|
|
46
|
+
user_pname=config.user_pname,
|
|
40
47
|
access_config=access_config,
|
|
41
48
|
)
|
|
42
49
|
index_config = SharepointIndexerConfig(recursive=True)
|
|
@@ -58,7 +65,151 @@ async def test_sharepoint_source(temp_dir):
|
|
|
58
65
|
indexer=indexer,
|
|
59
66
|
downloader=downloader,
|
|
60
67
|
configs=SourceValidationConfigs(
|
|
61
|
-
test_id="
|
|
68
|
+
test_id="sharepoint1",
|
|
69
|
+
expected_num_files=4,
|
|
70
|
+
validate_downloaded_files=True,
|
|
71
|
+
exclude_fields_extend=[
|
|
72
|
+
"metadata.date_created",
|
|
73
|
+
"metadata.date_modified",
|
|
74
|
+
"additional_metadata.LastModified",
|
|
75
|
+
"additional_metadata.@microsoft.graph.downloadUrl",
|
|
76
|
+
],
|
|
77
|
+
),
|
|
78
|
+
)
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
@pytest.mark.asyncio
|
|
82
|
+
@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG, BLOB_STORAGE_TAG)
|
|
83
|
+
@requires_env("SHAREPOINT_CLIENT_ID", "SHAREPOINT_CRED", "MS_TENANT_ID", "MS_USER_PNAME")
|
|
84
|
+
async def test_sharepoint_source_with_path(temp_dir):
|
|
85
|
+
site = "https://unstructuredio.sharepoint.com/sites/utic-platform-test-source"
|
|
86
|
+
config = sharepoint_config()
|
|
87
|
+
|
|
88
|
+
# Create connection and indexer configurations
|
|
89
|
+
access_config = SharepointAccessConfig(client_cred=config.client_cred)
|
|
90
|
+
connection_config = SharepointConnectionConfig(
|
|
91
|
+
client_id=config.client_id,
|
|
92
|
+
site=site,
|
|
93
|
+
tenant=config.tenant,
|
|
94
|
+
user_pname=config.user_pname,
|
|
95
|
+
access_config=access_config,
|
|
96
|
+
)
|
|
97
|
+
index_config = SharepointIndexerConfig(recursive=True, path="Folder1")
|
|
98
|
+
|
|
99
|
+
download_config = SharepointDownloaderConfig(download_dir=temp_dir)
|
|
100
|
+
|
|
101
|
+
# Instantiate indexer and downloader
|
|
102
|
+
indexer = SharepointIndexer(
|
|
103
|
+
connection_config=connection_config,
|
|
104
|
+
index_config=index_config,
|
|
105
|
+
)
|
|
106
|
+
downloader = SharepointDownloader(
|
|
107
|
+
connection_config=connection_config,
|
|
108
|
+
download_config=download_config,
|
|
109
|
+
)
|
|
110
|
+
|
|
111
|
+
# Run the source connector validation
|
|
112
|
+
await source_connector_validation(
|
|
113
|
+
indexer=indexer,
|
|
114
|
+
downloader=downloader,
|
|
115
|
+
configs=SourceValidationConfigs(
|
|
116
|
+
test_id="sharepoint2",
|
|
117
|
+
expected_num_files=2,
|
|
118
|
+
validate_downloaded_files=True,
|
|
119
|
+
exclude_fields_extend=[
|
|
120
|
+
"metadata.date_created",
|
|
121
|
+
"metadata.date_modified",
|
|
122
|
+
"additional_metadata.LastModified",
|
|
123
|
+
"additional_metadata.@microsoft.graph.downloadUrl",
|
|
124
|
+
],
|
|
125
|
+
),
|
|
126
|
+
)
|
|
127
|
+
|
|
128
|
+
|
|
129
|
+
@pytest.mark.asyncio
|
|
130
|
+
@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG, BLOB_STORAGE_TAG)
|
|
131
|
+
@requires_env("SHAREPOINT_CLIENT_ID", "SHAREPOINT_CRED", "MS_TENANT_ID", "MS_USER_PNAME")
|
|
132
|
+
async def test_sharepoint_root_with_path(temp_dir):
|
|
133
|
+
site = "https://unstructuredio.sharepoint.com/"
|
|
134
|
+
config = sharepoint_config()
|
|
135
|
+
|
|
136
|
+
# Create connection and indexer configurations
|
|
137
|
+
access_config = SharepointAccessConfig(client_cred=config.client_cred)
|
|
138
|
+
connection_config = SharepointConnectionConfig(
|
|
139
|
+
client_id=config.client_id,
|
|
140
|
+
site=site,
|
|
141
|
+
tenant=config.tenant,
|
|
142
|
+
user_pname=config.user_pname,
|
|
143
|
+
access_config=access_config,
|
|
144
|
+
)
|
|
145
|
+
index_config = SharepointIndexerConfig(recursive=True, path="e2e-test-folder")
|
|
146
|
+
|
|
147
|
+
download_config = SharepointDownloaderConfig(download_dir=temp_dir)
|
|
148
|
+
|
|
149
|
+
# Instantiate indexer and downloader
|
|
150
|
+
indexer = SharepointIndexer(
|
|
151
|
+
connection_config=connection_config,
|
|
152
|
+
index_config=index_config,
|
|
153
|
+
)
|
|
154
|
+
downloader = SharepointDownloader(
|
|
155
|
+
connection_config=connection_config,
|
|
156
|
+
download_config=download_config,
|
|
157
|
+
)
|
|
158
|
+
|
|
159
|
+
# Run the source connector validation
|
|
160
|
+
await source_connector_validation(
|
|
161
|
+
indexer=indexer,
|
|
162
|
+
downloader=downloader,
|
|
163
|
+
configs=SourceValidationConfigs(
|
|
164
|
+
test_id="sharepoint3",
|
|
165
|
+
expected_num_files=1,
|
|
166
|
+
validate_downloaded_files=True,
|
|
167
|
+
exclude_fields_extend=[
|
|
168
|
+
"metadata.date_created",
|
|
169
|
+
"metadata.date_modified",
|
|
170
|
+
"additional_metadata.LastModified",
|
|
171
|
+
"additional_metadata.@microsoft.graph.downloadUrl",
|
|
172
|
+
],
|
|
173
|
+
),
|
|
174
|
+
)
|
|
175
|
+
|
|
176
|
+
|
|
177
|
+
@pytest.mark.asyncio
|
|
178
|
+
@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG, BLOB_STORAGE_TAG)
|
|
179
|
+
@requires_env("SHAREPOINT_CLIENT_ID", "SHAREPOINT_CRED", "MS_TENANT_ID", "MS_USER_PNAME")
|
|
180
|
+
async def test_sharepoint_shared_documents(temp_dir):
|
|
181
|
+
site = "https://unstructuredio.sharepoint.com/sites/utic-platform-test-source"
|
|
182
|
+
config = sharepoint_config()
|
|
183
|
+
|
|
184
|
+
# Create connection and indexer configurations
|
|
185
|
+
access_config = SharepointAccessConfig(client_cred=config.client_cred)
|
|
186
|
+
connection_config = SharepointConnectionConfig(
|
|
187
|
+
client_id=config.client_id,
|
|
188
|
+
site=site,
|
|
189
|
+
tenant=config.tenant,
|
|
190
|
+
user_pname=config.user_pname,
|
|
191
|
+
access_config=access_config,
|
|
192
|
+
)
|
|
193
|
+
index_config = SharepointIndexerConfig(recursive=True, path="Shared Documents")
|
|
194
|
+
|
|
195
|
+
download_config = SharepointDownloaderConfig(download_dir=temp_dir)
|
|
196
|
+
|
|
197
|
+
# Instantiate indexer and downloader
|
|
198
|
+
indexer = SharepointIndexer(
|
|
199
|
+
connection_config=connection_config,
|
|
200
|
+
index_config=index_config,
|
|
201
|
+
)
|
|
202
|
+
downloader = SharepointDownloader(
|
|
203
|
+
connection_config=connection_config,
|
|
204
|
+
download_config=download_config,
|
|
205
|
+
)
|
|
206
|
+
|
|
207
|
+
# Run the source connector validation
|
|
208
|
+
await source_connector_validation(
|
|
209
|
+
indexer=indexer,
|
|
210
|
+
downloader=downloader,
|
|
211
|
+
configs=SourceValidationConfigs(
|
|
212
|
+
test_id="sharepoint4",
|
|
62
213
|
expected_num_files=4,
|
|
63
214
|
validate_downloaded_files=True,
|
|
64
215
|
exclude_fields_extend=[
|
|
@@ -16,7 +16,7 @@ fake = faker.Faker()
|
|
|
16
16
|
def generate_embedder_config_params() -> dict:
|
|
17
17
|
params = {}
|
|
18
18
|
if random.random() < 0.5:
|
|
19
|
-
params["
|
|
19
|
+
params["embedder_model_name"] = fake.word() if random.random() < 0.5 else None
|
|
20
20
|
params["embedder_model_kwargs"] = (
|
|
21
21
|
generate_random_dictionary(key_type=str, value_type=Any)
|
|
22
22
|
if random.random() < 0.5
|
|
@@ -1 +1 @@
|
|
|
1
|
-
__version__ = "0.5.2"  # pragma: no cover
|
|
1
|
+
__version__ = "0.5.4" # pragma: no cover
|
|
@@ -44,7 +44,13 @@ class AzureOpenAIEmbeddingConfig(OpenAIEmbeddingConfig):
|
|
|
44
44
|
class AzureOpenAIEmbeddingEncoder(OpenAIEmbeddingEncoder):
|
|
45
45
|
config: AzureOpenAIEmbeddingConfig
|
|
46
46
|
|
|
47
|
+
def get_client(self) -> "AzureOpenAI":
|
|
48
|
+
return self.config.get_client()
|
|
49
|
+
|
|
47
50
|
|
|
48
51
|
@dataclass
|
|
49
52
|
class AsyncAzureOpenAIEmbeddingEncoder(AsyncOpenAIEmbeddingEncoder):
|
|
50
53
|
config: AzureOpenAIEmbeddingConfig
|
|
54
|
+
|
|
55
|
+
def get_client(self) -> "AsyncAzureOpenAI":
|
|
56
|
+
return self.config.get_async_client()
|
|
@@ -15,7 +15,13 @@ from unstructured_ingest.embed.interfaces import (
|
|
|
15
15
|
)
|
|
16
16
|
from unstructured_ingest.logger import logger
|
|
17
17
|
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
18
|
-
from unstructured_ingest.v2.errors import
|
|
18
|
+
from unstructured_ingest.v2.errors import (
|
|
19
|
+
ProviderError,
|
|
20
|
+
RateLimitError,
|
|
21
|
+
UserAuthError,
|
|
22
|
+
UserError,
|
|
23
|
+
is_internal_error,
|
|
24
|
+
)
|
|
19
25
|
|
|
20
26
|
if TYPE_CHECKING:
|
|
21
27
|
from botocore.client import BaseClient
|
|
@@ -51,9 +57,11 @@ class BedrockEmbeddingConfig(EmbeddingConfig):
|
|
|
51
57
|
aws_access_key_id: SecretStr
|
|
52
58
|
aws_secret_access_key: SecretStr
|
|
53
59
|
region_name: str = "us-west-2"
|
|
54
|
-
|
|
60
|
+
embedder_model_name: str = Field(default="amazon.titan-embed-text-v1", alias="model_name")
|
|
55
61
|
|
|
56
62
|
def wrap_error(self, e: Exception) -> Exception:
|
|
63
|
+
if is_internal_error(e=e):
|
|
64
|
+
return e
|
|
57
65
|
from botocore.exceptions import ClientError
|
|
58
66
|
|
|
59
67
|
if isinstance(e, ClientError):
|
|
@@ -122,7 +130,7 @@ class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
122
130
|
|
|
123
131
|
def embed_query(self, query: str) -> list[float]:
|
|
124
132
|
"""Call out to Bedrock embedding endpoint."""
|
|
125
|
-
provider = self.config.
|
|
133
|
+
provider = self.config.embedder_model_name.split(".")[0]
|
|
126
134
|
body = conform_query(query=query, provider=provider)
|
|
127
135
|
|
|
128
136
|
bedrock_client = self.config.get_client()
|
|
@@ -130,7 +138,7 @@ class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
130
138
|
try:
|
|
131
139
|
response = bedrock_client.invoke_model(
|
|
132
140
|
body=json.dumps(body),
|
|
133
|
-
modelId=self.config.
|
|
141
|
+
modelId=self.config.embedder_model_name,
|
|
134
142
|
accept="application/json",
|
|
135
143
|
contentType="application/json",
|
|
136
144
|
)
|
|
@@ -148,6 +156,8 @@ class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
148
156
|
def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
149
157
|
elements = elements.copy()
|
|
150
158
|
elements_with_text = [e for e in elements if e.get("text")]
|
|
159
|
+
if not elements_with_text:
|
|
160
|
+
return elements
|
|
151
161
|
embeddings = [self.embed_query(query=e["text"]) for e in elements_with_text]
|
|
152
162
|
for element, embedding in zip(elements_with_text, embeddings):
|
|
153
163
|
element[EMBEDDINGS_KEY] = embedding
|
|
@@ -163,7 +173,7 @@ class AsyncBedrockEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
|
|
|
163
173
|
|
|
164
174
|
async def embed_query(self, query: str) -> list[float]:
|
|
165
175
|
"""Call out to Bedrock embedding endpoint."""
|
|
166
|
-
provider = self.config.
|
|
176
|
+
provider = self.config.embedder_model_name.split(".")[0]
|
|
167
177
|
body = conform_query(query=query, provider=provider)
|
|
168
178
|
try:
|
|
169
179
|
async with self.config.get_async_client() as bedrock_client:
|
|
@@ -171,7 +181,7 @@ class AsyncBedrockEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
|
|
|
171
181
|
try:
|
|
172
182
|
response = await bedrock_client.invoke_model(
|
|
173
183
|
body=json.dumps(body),
|
|
174
|
-
modelId=self.config.
|
|
184
|
+
modelId=self.config.embedder_model_name,
|
|
175
185
|
accept="application/json",
|
|
176
186
|
contentType="application/json",
|
|
177
187
|
)
|
|
@@ -47,7 +47,7 @@ class HuggingFaceEmbeddingConfig(EmbeddingConfig):
|
|
|
47
47
|
class HuggingFaceEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
48
48
|
config: HuggingFaceEmbeddingConfig
|
|
49
49
|
|
|
50
|
-
def
|
|
50
|
+
def _embed_query(self, query: str) -> list[float]:
|
|
51
51
|
return self._embed_documents(texts=[query])[0]
|
|
52
52
|
|
|
53
53
|
def _embed_documents(self, texts: list[str]) -> list[list[float]]:
|
|
@@ -58,6 +58,8 @@ class HuggingFaceEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
58
58
|
def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
59
59
|
elements = elements.copy()
|
|
60
60
|
elements_with_text = [e for e in elements if e.get("text")]
|
|
61
|
+
if not elements_with_text:
|
|
62
|
+
return elements
|
|
61
63
|
embeddings = self._embed_documents([e["text"] for e in elements_with_text])
|
|
62
64
|
for element, embedding in zip(elements_with_text, embeddings):
|
|
63
65
|
element[EMBEDDINGS_KEY] = embedding
|
|
@@ -1,11 +1,12 @@
|
|
|
1
|
-
import
|
|
2
|
-
from abc import ABC, abstractmethod
|
|
1
|
+
from abc import ABC
|
|
3
2
|
from dataclasses import dataclass
|
|
4
|
-
from typing import Optional
|
|
3
|
+
from typing import Any, Optional
|
|
5
4
|
|
|
6
5
|
import numpy as np
|
|
7
6
|
from pydantic import BaseModel, Field
|
|
8
7
|
|
|
8
|
+
from unstructured_ingest.utils.data_prep import batch_generator
|
|
9
|
+
|
|
9
10
|
EMBEDDINGS_KEY = "embeddings"
|
|
10
11
|
|
|
11
12
|
|
|
@@ -50,21 +51,37 @@ class BaseEmbeddingEncoder(BaseEncoder, ABC):
|
|
|
50
51
|
exemplary_embedding = self.get_exemplary_embedding()
|
|
51
52
|
return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
|
|
52
53
|
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
pass
|
|
54
|
+
def get_client(self):
|
|
55
|
+
raise NotImplementedError
|
|
56
56
|
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
pass
|
|
57
|
+
def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
|
|
58
|
+
raise NotImplementedError
|
|
60
59
|
|
|
61
|
-
def
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
60
|
+
def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
61
|
+
client = self.get_client()
|
|
62
|
+
elements = elements.copy()
|
|
63
|
+
elements_with_text = [e for e in elements if e.get("text")]
|
|
64
|
+
texts = [e["text"] for e in elements_with_text]
|
|
65
|
+
embeddings = []
|
|
66
|
+
try:
|
|
67
|
+
for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
|
|
68
|
+
embeddings = self.embed_batch(client=client, batch=batch)
|
|
69
|
+
embeddings.extend(embeddings)
|
|
70
|
+
except Exception as e:
|
|
71
|
+
raise self.wrap_error(e=e)
|
|
72
|
+
for element, embedding in zip(elements_with_text, embeddings):
|
|
73
|
+
element[EMBEDDINGS_KEY] = embedding
|
|
74
|
+
return elements
|
|
75
|
+
|
|
76
|
+
def _embed_query(self, query: str) -> list[float]:
|
|
77
|
+
client = self.get_client()
|
|
78
|
+
return self.embed_batch(client=client, batch=[query])[0]
|
|
66
79
|
|
|
67
|
-
|
|
80
|
+
def embed_query(self, query: str) -> list[float]:
|
|
81
|
+
try:
|
|
82
|
+
return self._embed_query(query=query)
|
|
83
|
+
except Exception as e:
|
|
84
|
+
raise self.wrap_error(e=e)
|
|
68
85
|
|
|
69
86
|
|
|
70
87
|
@dataclass
|
|
@@ -88,14 +105,35 @@ class AsyncBaseEmbeddingEncoder(BaseEncoder, ABC):
|
|
|
88
105
|
exemplary_embedding = await self.get_exemplary_embedding()
|
|
89
106
|
return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
|
|
90
107
|
|
|
91
|
-
|
|
108
|
+
def get_client(self):
|
|
109
|
+
raise NotImplementedError
|
|
110
|
+
|
|
111
|
+
async def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
|
|
112
|
+
raise NotImplementedError
|
|
113
|
+
|
|
92
114
|
async def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
93
|
-
|
|
115
|
+
client = self.get_client()
|
|
116
|
+
elements = elements.copy()
|
|
117
|
+
elements_with_text = [e for e in elements if e.get("text")]
|
|
118
|
+
texts = [e["text"] for e in elements_with_text]
|
|
119
|
+
embeddings = []
|
|
120
|
+
try:
|
|
121
|
+
for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
|
|
122
|
+
embeddings = await self.embed_batch(client=client, batch=batch)
|
|
123
|
+
embeddings.extend(embeddings)
|
|
124
|
+
except Exception as e:
|
|
125
|
+
raise self.wrap_error(e=e)
|
|
126
|
+
for element, embedding in zip(elements_with_text, embeddings):
|
|
127
|
+
element[EMBEDDINGS_KEY] = embedding
|
|
128
|
+
return elements
|
|
129
|
+
|
|
130
|
+
async def _embed_query(self, query: str) -> list[float]:
|
|
131
|
+
client = self.get_client()
|
|
132
|
+
embeddings = await self.embed_batch(client=client, batch=[query])
|
|
133
|
+
return embeddings[0]
|
|
94
134
|
|
|
95
|
-
@abstractmethod
|
|
96
135
|
async def embed_query(self, query: str) -> list[float]:
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
return results
|
|
136
|
+
try:
|
|
137
|
+
return await self._embed_query(query=query)
|
|
138
|
+
except Exception as e:
|
|
139
|
+
raise self.wrap_error(e=e)
|
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import asyncio
|
|
2
1
|
import os
|
|
3
2
|
from dataclasses import dataclass
|
|
4
3
|
from typing import TYPE_CHECKING
|
|
@@ -6,12 +5,10 @@ from typing import TYPE_CHECKING
|
|
|
6
5
|
from pydantic import Field, SecretStr
|
|
7
6
|
|
|
8
7
|
from unstructured_ingest.embed.interfaces import (
|
|
9
|
-
EMBEDDINGS_KEY,
|
|
10
8
|
AsyncBaseEmbeddingEncoder,
|
|
11
9
|
BaseEmbeddingEncoder,
|
|
12
10
|
EmbeddingConfig,
|
|
13
11
|
)
|
|
14
|
-
from unstructured_ingest.utils.data_prep import batch_generator
|
|
15
12
|
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
16
13
|
|
|
17
14
|
USER_AGENT = "@mixedbread-ai/unstructured"
|
|
@@ -85,7 +82,7 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
85
82
|
|
|
86
83
|
def get_exemplary_embedding(self) -> list[float]:
|
|
87
84
|
"""Get an exemplary embedding to determine dimensions and unit vector status."""
|
|
88
|
-
return self.
|
|
85
|
+
return self.embed_query(query="Q")
|
|
89
86
|
|
|
90
87
|
@requires_dependencies(
|
|
91
88
|
["mixedbread_ai"],
|
|
@@ -100,59 +97,19 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
|
|
|
100
97
|
additional_headers={"User-Agent": USER_AGENT},
|
|
101
98
|
)
|
|
102
99
|
|
|
103
|
-
def
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
|
|
117
|
-
response = client.embeddings(
|
|
118
|
-
model=self.config.embedder_model_name,
|
|
119
|
-
normalized=True,
|
|
120
|
-
encoding_format=ENCODING_FORMAT,
|
|
121
|
-
truncation_strategy=TRUNCATION_STRATEGY,
|
|
122
|
-
request_options=self.get_request_options(),
|
|
123
|
-
input=batch,
|
|
124
|
-
)
|
|
125
|
-
responses.append(response)
|
|
126
|
-
return [item.embedding for response in responses for item in response.data]
|
|
127
|
-
|
|
128
|
-
def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
129
|
-
"""
|
|
130
|
-
Embed a list of document elements.
|
|
131
|
-
|
|
132
|
-
Args:
|
|
133
|
-
elements (list[Element]): List of document elements.
|
|
134
|
-
|
|
135
|
-
Returns:
|
|
136
|
-
list[Element]: Elements with embeddings.
|
|
137
|
-
"""
|
|
138
|
-
elements = elements.copy()
|
|
139
|
-
elements_with_text = [e for e in elements if e.get("text")]
|
|
140
|
-
embeddings = self._embed([e["text"] for e in elements_with_text])
|
|
141
|
-
for element, embedding in zip(elements_with_text, embeddings):
|
|
142
|
-
element[EMBEDDINGS_KEY] = embedding
|
|
143
|
-
return elements
|
|
144
|
-
|
|
145
|
-
def embed_query(self, query: str) -> list[float]:
|
|
146
|
-
"""
|
|
147
|
-
Embed a query string.
|
|
148
|
-
|
|
149
|
-
Args:
|
|
150
|
-
query (str): Query string to embed.
|
|
151
|
-
|
|
152
|
-
Returns:
|
|
153
|
-
list[float]: Embedding of the query.
|
|
154
|
-
"""
|
|
155
|
-
return self._embed([query])[0]
|
|
100
|
+
def get_client(self) -> "MixedbreadAI":
|
|
101
|
+
return self.config.get_client()
|
|
102
|
+
|
|
103
|
+
def embed_batch(self, client: "MixedbreadAI", batch: list[str]) -> list[list[float]]:
|
|
104
|
+
response = client.embeddings(
|
|
105
|
+
model=self.config.embedder_model_name,
|
|
106
|
+
normalized=True,
|
|
107
|
+
encoding_format=ENCODING_FORMAT,
|
|
108
|
+
truncation_strategy=TRUNCATION_STRATEGY,
|
|
109
|
+
request_options=self.get_request_options(),
|
|
110
|
+
input=batch,
|
|
111
|
+
)
|
|
112
|
+
return [datum.embedding for datum in response.data]
|
|
156
113
|
|
|
157
114
|
|
|
158
115
|
@dataclass
|
|
@@ -162,8 +119,7 @@ class AsyncMixedbreadAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
|
|
|
162
119
|
|
|
163
120
|
async def get_exemplary_embedding(self) -> list[float]:
|
|
164
121
|
"""Get an exemplary embedding to determine dimensions and unit vector status."""
|
|
165
|
-
|
|
166
|
-
return embedding[0]
|
|
122
|
+
return await self.embed_query(query="Q")
|
|
167
123
|
|
|
168
124
|
@requires_dependencies(
|
|
169
125
|
["mixedbread_ai"],
|
|
@@ -178,58 +134,16 @@ class AsyncMixedbreadAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
|
|
|
178
134
|
additional_headers={"User-Agent": USER_AGENT},
|
|
179
135
|
)
|
|
180
136
|
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
for
|
|
194
|
-
tasks.append(
|
|
195
|
-
client.embeddings(
|
|
196
|
-
model=self.config.embedder_model_name,
|
|
197
|
-
normalized=True,
|
|
198
|
-
encoding_format=ENCODING_FORMAT,
|
|
199
|
-
truncation_strategy=TRUNCATION_STRATEGY,
|
|
200
|
-
request_options=self.get_request_options(),
|
|
201
|
-
input=batch,
|
|
202
|
-
)
|
|
203
|
-
)
|
|
204
|
-
responses = await asyncio.gather(*tasks)
|
|
205
|
-
return [item.embedding for response in responses for item in response.data]
|
|
206
|
-
|
|
207
|
-
async def embed_documents(self, elements: list[dict]) -> list[dict]:
|
|
208
|
-
"""
|
|
209
|
-
Embed a list of document elements.
|
|
210
|
-
|
|
211
|
-
Args:
|
|
212
|
-
elements (list[Element]): List of document elements.
|
|
213
|
-
|
|
214
|
-
Returns:
|
|
215
|
-
list[Element]: Elements with embeddings.
|
|
216
|
-
"""
|
|
217
|
-
elements = elements.copy()
|
|
218
|
-
elements_with_text = [e for e in elements if e.get("text")]
|
|
219
|
-
embeddings = await self._embed([e["text"] for e in elements_with_text])
|
|
220
|
-
for element, embedding in zip(elements_with_text, embeddings):
|
|
221
|
-
element[EMBEDDINGS_KEY] = embedding
|
|
222
|
-
return elements
|
|
223
|
-
|
|
224
|
-
async def embed_query(self, query: str) -> list[float]:
|
|
225
|
-
"""
|
|
226
|
-
Embed a query string.
|
|
227
|
-
|
|
228
|
-
Args:
|
|
229
|
-
query (str): Query string to embed.
|
|
230
|
-
|
|
231
|
-
Returns:
|
|
232
|
-
list[float]: Embedding of the query.
|
|
233
|
-
"""
|
|
234
|
-
embedding = await self._embed([query])
|
|
235
|
-
return embedding[0]
|
|
137
|
+
def get_client(self) -> "AsyncMixedbreadAI":
|
|
138
|
+
return self.config.get_async_client()
|
|
139
|
+
|
|
140
|
+
async def embed_batch(self, client: "AsyncMixedbreadAI", batch: list[str]) -> list[list[float]]:
|
|
141
|
+
response = await client.embeddings(
|
|
142
|
+
model=self.config.embedder_model_name,
|
|
143
|
+
normalized=True,
|
|
144
|
+
encoding_format=ENCODING_FORMAT,
|
|
145
|
+
truncation_strategy=TRUNCATION_STRATEGY,
|
|
146
|
+
request_options=self.get_request_options(),
|
|
147
|
+
input=batch,
|
|
148
|
+
)
|
|
149
|
+
return [datum.embedding for datum in response.data]
|