unstructured-ingest 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of unstructured-ingest might be problematic. Click here for more details.

Files changed (25)
  1. test/integration/connectors/test_google_drive.py +141 -0
  2. test/unit/v2/embedders/test_bedrock.py +1 -1
  3. test/unit/v2/embedders/test_huggingface.py +1 -1
  4. unstructured_ingest/__version__.py +1 -1
  5. unstructured_ingest/embed/azure_openai.py +6 -0
  6. unstructured_ingest/embed/bedrock.py +29 -12
  7. unstructured_ingest/embed/huggingface.py +14 -5
  8. unstructured_ingest/embed/interfaces.py +63 -44
  9. unstructured_ingest/embed/mixedbreadai.py +28 -105
  10. unstructured_ingest/embed/octoai.py +19 -44
  11. unstructured_ingest/embed/openai.py +17 -48
  12. unstructured_ingest/embed/togetherai.py +16 -49
  13. unstructured_ingest/embed/vertexai.py +15 -39
  14. unstructured_ingest/embed/voyageai.py +16 -42
  15. unstructured_ingest/v2/errors.py +7 -0
  16. unstructured_ingest/v2/processes/connectors/google_drive.py +132 -3
  17. unstructured_ingest/v2/processes/connectors/neo4j.py +129 -43
  18. unstructured_ingest/v2/processes/connectors/sql/snowflake.py +53 -3
  19. unstructured_ingest/v2/processes/embedder.py +9 -7
  20. {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/METADATA +99 -87
  21. {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/RECORD +25 -25
  22. {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/WHEEL +1 -1
  23. {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/LICENSE.md +0 -0
  24. {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/entry_points.txt +0 -0
  25. {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,8 @@
1
1
  import os
2
+ import uuid
2
3
 
3
4
  import pytest
5
+ from googleapiclient.errors import HttpError
4
6
 
5
7
  from test.integration.connectors.utils.constants import (
6
8
  SOURCE_TAG,
@@ -13,6 +15,9 @@ from test.integration.connectors.utils.validation.source import (
13
15
  update_fixtures,
14
16
  )
15
17
  from test.integration.utils import requires_env
18
+ from unstructured_ingest.error import (
19
+ SourceConnectionError,
20
+ )
16
21
  from unstructured_ingest.v2.interfaces import Downloader, Indexer
17
22
  from unstructured_ingest.v2.processes.connectors.google_drive import (
18
23
  CONNECTOR_TYPE,
@@ -25,6 +30,49 @@ from unstructured_ingest.v2.processes.connectors.google_drive import (
25
30
  )
26
31
 
27
32
 
33
@pytest.fixture
def google_drive_connection_config():
    """
    Provide a GoogleDriveConnectionConfig assembled from the environment.

    Reads GOOGLE_DRIVE_ID and GOOGLE_DRIVE_SERVICE_KEY; when either one is
    unset or empty the requesting test is skipped.
    """
    env_drive_id = os.getenv("GOOGLE_DRIVE_ID")
    env_service_key = os.getenv("GOOGLE_DRIVE_SERVICE_KEY")
    if not (env_drive_id and env_service_key):
        pytest.skip("Google Drive credentials not provided in environment variables.")

    return GoogleDriveConnectionConfig(
        drive_id=env_drive_id,
        access_config=GoogleDriveAccessConfig(service_account_key=env_service_key),
    )
48
+
49
+
50
@pytest.fixture
def google_drive_empty_folder(google_drive_connection_config):
    """
    Yield the ID of a newly created, empty Google Drive folder.

    The folder is deleted in the fixture's ``finally`` clause, so cleanup is
    attempted even when the consuming test fails.
    """
    from google.oauth2 import service_account
    from googleapiclient.discovery import build

    secrets = google_drive_connection_config.access_config.get_secret_value()
    credentials = service_account.Credentials.from_service_account_info(
        secrets.service_account_key
    )
    drive = build("drive", "v3", credentials=credentials)

    # A unique suffix keeps concurrent test runs from colliding on the name.
    folder_spec = {
        "name": f"utic-empty-folder-{uuid.uuid4()}",
        "mimeType": "application/vnd.google-apps.folder",
    }
    created = drive.files().create(body=folder_spec, fields="id, name").execute()
    folder_id = created.get("id")
    try:
        yield folder_id
    finally:
        drive.files().delete(fileId=folder_id).execute()
74
+
75
+
28
76
  @requires_env("GOOGLE_DRIVE_SERVICE_KEY")
29
77
  @pytest.mark.tags(SOURCE_TAG, CONNECTOR_TYPE)
30
78
  def test_google_drive_source(temp_dir):
@@ -114,3 +162,96 @@ def source_connector_validation(
114
162
  save_downloads=configs.validate_downloaded_files,
115
163
  save_filedata=configs.validate_file_data,
116
164
  )
165
+
166
+
167
+ # Precheck fails when the drive ID has an appended parameter (simulate copy-paste error)
168
@pytest.mark.tags("google-drive", "precheck")
@requires_env("GOOGLE_DRIVE_ID", "GOOGLE_DRIVE_SERVICE_KEY")
def test_google_drive_precheck_invalid_parameter(google_drive_connection_config):
    """Precheck must reject a drive ID that carries a trailing URL query string."""
    # Mimic an ID copied straight out of a sharing URL.
    polluted_drive_id = google_drive_connection_config.drive_id + "?usp=sharing"
    indexer = GoogleDriveIndexer(
        connection_config=GoogleDriveConnectionConfig(
            drive_id=polluted_drive_id,
            access_config=google_drive_connection_config.access_config,
        ),
        index_config=GoogleDriveIndexerConfig(recursive=True),
    )
    with pytest.raises(SourceConnectionError) as excinfo:
        indexer.precheck()
    message = str(excinfo.value).lower()
    assert "invalid" in message or "not found" in message
182
+
183
+
184
+ # Precheck fails due to lack of permission (simulate via monkeypatching).
185
@pytest.mark.tags("google-drive", "precheck")
@requires_env("GOOGLE_DRIVE_ID", "GOOGLE_DRIVE_SERVICE_KEY")
def test_google_drive_precheck_no_permission(google_drive_connection_config, monkeypatch):
    """Precheck must report an HTTP 403 from the root lookup as SourceConnectionError."""
    indexer = GoogleDriveIndexer(
        connection_config=google_drive_connection_config,
        index_config=GoogleDriveIndexerConfig(recursive=True),
    )

    def fake_get_root_info(files_client, object_id):
        # Minimal stand-in for the HTTP response object passed to HttpError.
        forbidden_response_type = type("Response", (), {"status": 403, "reason": "Forbidden"})
        raise HttpError(resp=forbidden_response_type(), content=b"Forbidden")

    # Replace the root lookup on this instance so it always fails with 403.
    monkeypatch.setattr(indexer, "get_root_info", fake_get_root_info)
    with pytest.raises(SourceConnectionError) as excinfo:
        indexer.precheck()
    message = str(excinfo.value).lower()
    assert "forbidden" in message or "permission" in message
205
+
206
+
207
+ # Precheck fails when the folder is empty.
208
+ # @pytest.mark.tags("google-drive", "precheck")
209
+ # @requires_env("GOOGLE_DRIVE_ID", "GOOGLE_DRIVE_SERVICE_KEY")
210
+ # def test_google_drive_precheck_empty_folder(
211
+ # google_drive_connection_config, google_drive_empty_folder
212
+ # ):
213
+ # # Use the empty folder's ID as the target.
214
+ # connection_config = GoogleDriveConnectionConfig(
215
+ # drive_id=google_drive_empty_folder,
216
+ # access_config=google_drive_connection_config.access_config,
217
+ # )
218
+
219
+ # index_config = GoogleDriveIndexerConfig(recursive=True)
220
+ # indexer = GoogleDriveIndexer(connection_config=connection_config, index_config=index_config)
221
+ # with pytest.raises(SourceConnectionError) as excinfo:
222
+ # indexer.precheck()
223
+ # assert "empty folder" in str(excinfo.value).lower()
224
+
225
+
226
@pytest.mark.tags("google-drive", "count", "integration")
@requires_env("GOOGLE_DRIVE_ID", "GOOGLE_DRIVE_SERVICE_KEY")
def test_google_drive_count_files(google_drive_connection_config):
    """
    Check that count_files_recursively reports 4 files for the test drive
    when called with the pdf/docx extension filter.
    """
    # NOTE(review): the original author states the test directory holds 6 files
    # in total and assumes the same extension filter as the other tests narrows
    # that to 4 (3 in the root, 1 nested) - confirm against the fixture drive.
    expected_count = 4
    wanted_extensions = ["pdf", "docx"]
    with google_drive_connection_config.get_client() as files_client:
        count = GoogleDriveIndexer.count_files_recursively(
            files_client, google_drive_connection_config.drive_id, wanted_extensions
        )
    assert count == expected_count, f"Expected file count of 4, but got {count}"
242
+
243
+
244
+ # Precheck fails with a completely invalid drive ID.
245
@pytest.mark.tags("google-drive", "precheck")
@requires_env("GOOGLE_DRIVE_ID", "GOOGLE_DRIVE_SERVICE_KEY")
def test_google_drive_precheck_invalid_drive_id(google_drive_connection_config):
    """Precheck must reject a drive ID that does not refer to any real object."""
    indexer = GoogleDriveIndexer(
        connection_config=GoogleDriveConnectionConfig(
            drive_id="invalid_drive_id",
            access_config=google_drive_connection_config.access_config,
        ),
        index_config=GoogleDriveIndexerConfig(recursive=True),
    )
    with pytest.raises(SourceConnectionError) as excinfo:
        indexer.precheck()
    message = str(excinfo.value).lower()
    assert "invalid" in message or "not found" in message
@@ -15,7 +15,7 @@ def generate_embedder_config_params() -> dict:
15
15
  "region_name": fake.city(),
16
16
  }
17
17
  if random.random() < 0.5:
18
- params["embed_model_name"] = fake.word()
18
+ params["embedder_model_name"] = fake.word()
19
19
  return params
20
20
 
21
21
 
@@ -16,7 +16,7 @@ fake = faker.Faker()
16
16
  def generate_embedder_config_params() -> dict:
17
17
  params = {}
18
18
  if random.random() < 0.5:
19
- params["embed_model_name"] = fake.word() if random.random() < 0.5 else None
19
+ params["embedder_model_name"] = fake.word() if random.random() < 0.5 else None
20
20
  params["embedder_model_kwargs"] = (
21
21
  generate_random_dictionary(key_type=str, value_type=Any)
22
22
  if random.random() < 0.5
@@ -1 +1 @@
1
- __version__ = "0.5.1" # pragma: no cover
1
+ __version__ = "0.5.3" # pragma: no cover
@@ -44,7 +44,13 @@ class AzureOpenAIEmbeddingConfig(OpenAIEmbeddingConfig):
44
44
  class AzureOpenAIEmbeddingEncoder(OpenAIEmbeddingEncoder):
45
45
  config: AzureOpenAIEmbeddingConfig
46
46
 
47
def get_client(self) -> "AzureOpenAI":
    """Return the synchronous client obtained from ``self.config.get_client()``."""
    azure_client = self.config.get_client()
    return azure_client
49
+
47
50
 
48
51
  @dataclass
49
52
  class AsyncAzureOpenAIEmbeddingEncoder(AsyncOpenAIEmbeddingEncoder):
50
53
  config: AzureOpenAIEmbeddingConfig
54
+
55
def get_client(self) -> "AsyncAzureOpenAI":
    """Return the asynchronous client obtained from ``self.config.get_async_client()``."""
    async_azure_client = self.config.get_async_client()
    return async_azure_client
@@ -8,13 +8,20 @@ from typing import TYPE_CHECKING, AsyncIterable
8
8
  from pydantic import Field, SecretStr
9
9
 
10
10
  from unstructured_ingest.embed.interfaces import (
11
+ EMBEDDINGS_KEY,
11
12
  AsyncBaseEmbeddingEncoder,
12
13
  BaseEmbeddingEncoder,
13
14
  EmbeddingConfig,
14
15
  )
15
16
  from unstructured_ingest.logger import logger
16
17
  from unstructured_ingest.utils.dep_check import requires_dependencies
17
- from unstructured_ingest.v2.errors import ProviderError, RateLimitError, UserAuthError, UserError
18
+ from unstructured_ingest.v2.errors import (
19
+ ProviderError,
20
+ RateLimitError,
21
+ UserAuthError,
22
+ UserError,
23
+ is_internal_error,
24
+ )
18
25
 
19
26
  if TYPE_CHECKING:
20
27
  from botocore.client import BaseClient
@@ -50,9 +57,11 @@ class BedrockEmbeddingConfig(EmbeddingConfig):
50
57
  aws_access_key_id: SecretStr
51
58
  aws_secret_access_key: SecretStr
52
59
  region_name: str = "us-west-2"
53
- embed_model_name: str = Field(default="amazon.titan-embed-text-v1", alias="model_name")
60
+ embedder_model_name: str = Field(default="amazon.titan-embed-text-v1", alias="model_name")
54
61
 
55
62
  def wrap_error(self, e: Exception) -> Exception:
63
+ if is_internal_error(e=e):
64
+ return e
56
65
  from botocore.exceptions import ClientError
57
66
 
58
67
  if isinstance(e, ClientError):
@@ -121,7 +130,7 @@ class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
121
130
 
122
131
  def embed_query(self, query: str) -> list[float]:
123
132
  """Call out to Bedrock embedding endpoint."""
124
- provider = self.config.embed_model_name.split(".")[0]
133
+ provider = self.config.embedder_model_name.split(".")[0]
125
134
  body = conform_query(query=query, provider=provider)
126
135
 
127
136
  bedrock_client = self.config.get_client()
@@ -129,7 +138,7 @@ class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
129
138
  try:
130
139
  response = bedrock_client.invoke_model(
131
140
  body=json.dumps(body),
132
- modelId=self.config.embed_model_name,
141
+ modelId=self.config.embedder_model_name,
133
142
  accept="application/json",
134
143
  contentType="application/json",
135
144
  )
@@ -145,9 +154,14 @@ class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
145
154
  return response_body.get("embedding")
146
155
 
147
156
def embed_documents(self, elements: list[dict]) -> list[dict]:
    """
    Attach an embedding to every element that has non-empty text.

    Works on a shallow copy of the input list; the element dicts themselves
    are updated in place under ``EMBEDDINGS_KEY``. Elements without text are
    returned untouched.
    """
    copied = elements.copy()
    embeddable = [item for item in copied if item.get("text")]
    if embeddable:
        # Compute every vector first so nothing is mutated if a query fails.
        vectors = [self.embed_query(query=item["text"]) for item in embeddable]
        for item, vector in zip(embeddable, vectors):
            item[EMBEDDINGS_KEY] = vector
    return copied
151
165
 
152
166
 
153
167
  @dataclass
@@ -159,7 +173,7 @@ class AsyncBedrockEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
159
173
 
160
174
  async def embed_query(self, query: str) -> list[float]:
161
175
  """Call out to Bedrock embedding endpoint."""
162
- provider = self.config.embed_model_name.split(".")[0]
176
+ provider = self.config.embedder_model_name.split(".")[0]
163
177
  body = conform_query(query=query, provider=provider)
164
178
  try:
165
179
  async with self.config.get_async_client() as bedrock_client:
@@ -167,7 +181,7 @@ class AsyncBedrockEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
167
181
  try:
168
182
  response = await bedrock_client.invoke_model(
169
183
  body=json.dumps(body),
170
- modelId=self.config.embed_model_name,
184
+ modelId=self.config.embedder_model_name,
171
185
  accept="application/json",
172
186
  contentType="application/json",
173
187
  )
@@ -186,8 +200,11 @@ class AsyncBedrockEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
186
200
  raise ValueError(f"Error raised by inference endpoint: {e}")
187
201
 
188
202
async def embed_documents(self, elements: list[dict]) -> list[dict]:
    """
    Attach an embedding to every element that has non-empty text.

    Works on a shallow copy of the input list; the element dicts themselves
    are updated in place under ``EMBEDDINGS_KEY``. All queries are issued
    concurrently through ``asyncio.gather``.

    Returns:
        The copied list; elements without text are returned untouched.
    """
    elements = elements.copy()
    elements_with_text = [e for e in elements if e.get("text")]
    # Mirror the synchronous encoder: nothing to embed means nothing to await.
    if not elements_with_text:
        return elements
    embeddings = await asyncio.gather(
        # The filter above guarantees "text" is present, so index directly.
        *[self.embed_query(query=e["text"]) for e in elements_with_text]
    )
    for element, embedding in zip(elements_with_text, embeddings):
        element[EMBEDDINGS_KEY] = embedding
    return elements
@@ -3,7 +3,11 @@ from typing import TYPE_CHECKING, Optional
3
3
 
4
4
  from pydantic import Field
5
5
 
6
- from unstructured_ingest.embed.interfaces import BaseEmbeddingEncoder, EmbeddingConfig
6
+ from unstructured_ingest.embed.interfaces import (
7
+ EMBEDDINGS_KEY,
8
+ BaseEmbeddingEncoder,
9
+ EmbeddingConfig,
10
+ )
7
11
  from unstructured_ingest.utils.dep_check import requires_dependencies
8
12
 
9
13
  if TYPE_CHECKING:
@@ -43,7 +47,7 @@ class HuggingFaceEmbeddingConfig(EmbeddingConfig):
43
47
  class HuggingFaceEmbeddingEncoder(BaseEmbeddingEncoder):
44
48
  config: HuggingFaceEmbeddingConfig
45
49
 
46
- def embed_query(self, query: str) -> list[float]:
50
+ def _embed_query(self, query: str) -> list[float]:
47
51
  return self._embed_documents(texts=[query])[0]
48
52
 
49
53
  def _embed_documents(self, texts: list[str]) -> list[list[float]]:
@@ -52,6 +56,11 @@ class HuggingFaceEmbeddingEncoder(BaseEmbeddingEncoder):
52
56
  return embeddings.tolist()
53
57
 
54
58
def embed_documents(self, elements: list[dict]) -> list[dict]:
    """
    Attach an embedding to every element that has non-empty text.

    Works on a shallow copy of the input list; the element dicts themselves
    are updated in place under ``EMBEDDINGS_KEY``. All texts are passed to
    ``_embed_documents`` in a single call.
    """
    copied = elements.copy()
    with_text = [item for item in copied if item.get("text")]
    if with_text:
        vectors = self._embed_documents([item["text"] for item in with_text])
        for item, vector in zip(with_text, vectors):
            item[EMBEDDINGS_KEY] = vector
    return copied
@@ -1,11 +1,14 @@
1
- import asyncio
2
- from abc import ABC, abstractmethod
1
+ from abc import ABC
3
2
  from dataclasses import dataclass
4
- from typing import Optional
3
+ from typing import Any, Optional
5
4
 
6
5
  import numpy as np
7
6
  from pydantic import BaseModel, Field
8
7
 
8
+ from unstructured_ingest.utils.data_prep import batch_generator
9
+
10
+ EMBEDDINGS_KEY = "embeddings"
11
+
9
12
 
10
13
  class EmbeddingConfig(BaseModel):
11
14
  batch_size: Optional[int] = Field(
@@ -26,27 +29,6 @@ class BaseEncoder(ABC):
26
29
  if possible"""
27
30
  return e
28
31
 
29
- @staticmethod
30
- def _add_embeddings_to_elements(
31
- elements: list[dict], embeddings: list[list[float]]
32
- ) -> list[dict]:
33
- """
34
- Add embeddings to elements.
35
-
36
- Args:
37
- elements (list[Element]): List of elements.
38
- embeddings (list[list[float]]): List of embeddings.
39
-
40
- Returns:
41
- list[Element]: Elements with embeddings added.
42
- """
43
- assert len(elements) == len(embeddings)
44
- elements_w_embedding = []
45
- for i, element in enumerate(elements):
46
- element["embeddings"] = embeddings[i]
47
- elements_w_embedding.append(element)
48
- return elements
49
-
50
32
 
51
33
  @dataclass
52
34
  class BaseEmbeddingEncoder(BaseEncoder, ABC):
@@ -69,21 +51,37 @@ class BaseEmbeddingEncoder(BaseEncoder, ABC):
69
51
  exemplary_embedding = self.get_exemplary_embedding()
70
52
  return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
71
53
 
72
- @abstractmethod
73
- def embed_documents(self, elements: list[dict]) -> list[dict]:
74
- pass
54
def get_client(self):
    """Return the provider client; this base implementation always raises NotImplementedError."""
    raise NotImplementedError()
75
56
 
76
- @abstractmethod
77
- def embed_query(self, query: str) -> list[float]:
78
- pass
57
def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
    """Embed one batch of texts; this base implementation always raises NotImplementedError."""
    raise NotImplementedError()
79
59
 
80
- def _embed_documents(self, elements: list[str]) -> list[list[float]]:
81
- results = []
82
- for text in elements:
83
- response = self.embed_query(query=text)
84
- results.append(response)
60
def embed_documents(self, elements: list[dict]) -> list[dict]:
    """
    Attach an embedding to every element that has non-empty text.

    Texts are sent to ``embed_batch`` in chunks of ``self.config.batch_size``
    (or all at once when no batch size is configured). Works on a shallow
    copy of the input list; the element dicts themselves are updated in place
    under ``EMBEDDINGS_KEY``.

    Returns:
        The copied list; elements without text are returned untouched.

    Raises:
        Whatever ``self.wrap_error`` returns for an exception raised while
        embedding a batch.
    """
    client = self.get_client()
    elements = elements.copy()
    elements_with_text = [e for e in elements if e.get("text")]
    # Nothing to embed: skip batching (which would otherwise get batch_size=0).
    if not elements_with_text:
        return elements
    texts = [e["text"] for e in elements_with_text]
    embeddings: list[list[float]] = []
    try:
        for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
            # Keep the per-batch result in its own name; rebinding the
            # accumulator here would drop every earlier batch.
            batch_embeddings = self.embed_batch(client=client, batch=batch)
            embeddings.extend(batch_embeddings)
    except Exception as e:
        raise self.wrap_error(e=e)
    for element, embedding in zip(elements_with_text, embeddings):
        element[EMBEDDINGS_KEY] = embedding
    return elements
85
75
 
86
- return results
76
+ def _embed_query(self, query: str) -> list[float]:
77
+ client = self.get_client()
78
+ return self.embed_batch(client=client, batch=[query])[0]
79
+
80
def embed_query(self, query: str) -> list[float]:
    """
    Embed a single query string.

    Any exception raised by ``_embed_query`` is passed to ``wrap_error`` and
    the result of that call is raised instead.
    """
    try:
        vector = self._embed_query(query=query)
    except Exception as err:
        raise self.wrap_error(e=err)
    return vector
87
85
 
88
86
 
89
87
  @dataclass
@@ -107,14 +105,35 @@ class AsyncBaseEmbeddingEncoder(BaseEncoder, ABC):
107
105
  exemplary_embedding = await self.get_exemplary_embedding()
108
106
  return np.isclose(np.linalg.norm(exemplary_embedding), 1.0)
109
107
 
110
- @abstractmethod
108
def get_client(self):
    """Return the provider client; this base implementation always raises NotImplementedError."""
    raise NotImplementedError()
110
+
111
async def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
    """Embed one batch of texts; this base implementation always raises NotImplementedError."""
    raise NotImplementedError()
113
+
111
114
async def embed_documents(self, elements: list[dict]) -> list[dict]:
    """
    Attach an embedding to every element that has non-empty text.

    Texts are sent to ``embed_batch`` one chunk at a time, in chunks of
    ``self.config.batch_size`` (or all at once when no batch size is
    configured). Works on a shallow copy of the input list; the element dicts
    themselves are updated in place under ``EMBEDDINGS_KEY``.

    Returns:
        The copied list; elements without text are returned untouched.

    Raises:
        Whatever ``self.wrap_error`` returns for an exception raised while
        embedding a batch.
    """
    client = self.get_client()
    elements = elements.copy()
    elements_with_text = [e for e in elements if e.get("text")]
    # Nothing to embed: skip batching (which would otherwise get batch_size=0).
    if not elements_with_text:
        return elements
    texts = [e["text"] for e in elements_with_text]
    embeddings: list[list[float]] = []
    try:
        for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
            # Keep the per-batch result in its own name; rebinding the
            # accumulator here would drop every earlier batch.
            batch_embeddings = await self.embed_batch(client=client, batch=batch)
            embeddings.extend(batch_embeddings)
    except Exception as e:
        raise self.wrap_error(e=e)
    for element, embedding in zip(elements_with_text, embeddings):
        element[EMBEDDINGS_KEY] = embedding
    return elements
113
129
 
114
- @abstractmethod
115
- async def embed_query(self, query: str) -> list[float]:
116
- pass
130
+ async def _embed_query(self, query: str) -> list[float]:
131
+ client = self.get_client()
132
+ embeddings = await self.embed_batch(client=client, batch=[query])
133
+ return embeddings[0]
117
134
 
118
- async def _embed_documents(self, elements: list[str]) -> list[list[float]]:
119
- results = await asyncio.gather(*[self.embed_query(query=text) for text in elements])
120
- return results
135
async def embed_query(self, query: str) -> list[float]:
    """
    Embed a single query string.

    Any exception raised while awaiting ``_embed_query`` is passed to
    ``wrap_error`` and the result of that call is raised instead.
    """
    try:
        vector = await self._embed_query(query=query)
    except Exception as err:
        raise self.wrap_error(e=err)
    return vector
@@ -1,4 +1,3 @@
1
- import asyncio
2
1
  import os
3
2
  from dataclasses import dataclass
4
3
  from typing import TYPE_CHECKING
@@ -10,7 +9,6 @@ from unstructured_ingest.embed.interfaces import (
10
9
  BaseEmbeddingEncoder,
11
10
  EmbeddingConfig,
12
11
  )
13
- from unstructured_ingest.utils.data_prep import batch_generator
14
12
  from unstructured_ingest.utils.dep_check import requires_dependencies
15
13
 
16
14
  USER_AGENT = "@mixedbread-ai/unstructured"
@@ -84,7 +82,7 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
84
82
 
85
83
def get_exemplary_embedding(self) -> list[float]:
    """Embed the fixed probe string "Q" so callers can inspect dimension and norm."""
    probe_vector = self.embed_query(query="Q")
    return probe_vector
88
86
 
89
87
  @requires_dependencies(
90
88
  ["mixedbread_ai"],
@@ -99,55 +97,19 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
99
97
  additional_headers={"User-Agent": USER_AGENT},
100
98
  )
101
99
 
102
- def _embed(self, texts: list[str]) -> list[list[float]]:
103
- """
104
- Embed a list of texts using the Mixedbread AI API.
105
-
106
- Args:
107
- texts (list[str]): List of texts to embed.
108
-
109
- Returns:
110
- list[list[float]]: List of embeddings.
111
- """
112
-
113
- responses = []
114
- client = self.config.get_client()
115
- for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
116
- response = client.embeddings(
117
- model=self.config.embedder_model_name,
118
- normalized=True,
119
- encoding_format=ENCODING_FORMAT,
120
- truncation_strategy=TRUNCATION_STRATEGY,
121
- request_options=self.get_request_options(),
122
- input=batch,
123
- )
124
- responses.append(response)
125
- return [item.embedding for response in responses for item in response.data]
126
-
127
- def embed_documents(self, elements: list[dict]) -> list[dict]:
128
- """
129
- Embed a list of document elements.
130
-
131
- Args:
132
- elements (list[Element]): List of document elements.
133
-
134
- Returns:
135
- list[Element]: Elements with embeddings.
136
- """
137
- embeddings = self._embed([e.get("text", "") for e in elements])
138
- return self._add_embeddings_to_elements(elements, embeddings)
139
-
140
- def embed_query(self, query: str) -> list[float]:
141
- """
142
- Embed a query string.
143
-
144
- Args:
145
- query (str): Query string to embed.
146
-
147
- Returns:
148
- list[float]: Embedding of the query.
149
- """
150
- return self._embed([query])[0]
100
def get_client(self) -> "MixedbreadAI":
    """Return the synchronous client obtained from ``self.config.get_client()``."""
    mixedbread_client = self.config.get_client()
    return mixedbread_client
102
+
103
def embed_batch(self, client: "MixedbreadAI", batch: list[str]) -> list[list[float]]:
    """
    Embed one batch of texts with a single ``client.embeddings`` call.

    Returns one embedding per entry of ``response.data``, in response order.
    """
    request = {
        "model": self.config.embedder_model_name,
        "normalized": True,
        "encoding_format": ENCODING_FORMAT,
        "truncation_strategy": TRUNCATION_STRATEGY,
        "request_options": self.get_request_options(),
        "input": batch,
    }
    response = client.embeddings(**request)
    return [item.embedding for item in response.data]
151
113
 
152
114
 
153
115
  @dataclass
@@ -157,8 +119,7 @@ class AsyncMixedbreadAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
157
119
 
158
120
async def get_exemplary_embedding(self) -> list[float]:
    """Embed the fixed probe string "Q" so callers can inspect dimension and norm."""
    probe_vector = await self.embed_query(query="Q")
    return probe_vector
162
123
 
163
124
  @requires_dependencies(
164
125
  ["mixedbread_ai"],
@@ -173,54 +134,16 @@ class AsyncMixedbreadAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
173
134
  additional_headers={"User-Agent": USER_AGENT},
174
135
  )
175
136
 
176
- async def _embed(self, texts: list[str]) -> list[list[float]]:
177
- """
178
- Embed a list of texts using the Mixedbread AI API.
179
-
180
- Args:
181
- texts (list[str]): List of texts to embed.
182
-
183
- Returns:
184
- list[list[float]]: List of embeddings.
185
- """
186
- client = self.config.get_async_client()
187
- tasks = []
188
- for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
189
- tasks.append(
190
- client.embeddings(
191
- model=self.config.embedder_model_name,
192
- normalized=True,
193
- encoding_format=ENCODING_FORMAT,
194
- truncation_strategy=TRUNCATION_STRATEGY,
195
- request_options=self.get_request_options(),
196
- input=batch,
197
- )
198
- )
199
- responses = await asyncio.gather(*tasks)
200
- return [item.embedding for response in responses for item in response.data]
201
-
202
- async def embed_documents(self, elements: list[dict]) -> list[dict]:
203
- """
204
- Embed a list of document elements.
205
-
206
- Args:
207
- elements (list[Element]): List of document elements.
208
-
209
- Returns:
210
- list[Element]: Elements with embeddings.
211
- """
212
- embeddings = await self._embed([e.get("text", "") for e in elements])
213
- return self._add_embeddings_to_elements(elements, embeddings)
214
-
215
- async def embed_query(self, query: str) -> list[float]:
216
- """
217
- Embed a query string.
218
-
219
- Args:
220
- query (str): Query string to embed.
221
-
222
- Returns:
223
- list[float]: Embedding of the query.
224
- """
225
- embedding = await self._embed([query])
226
- return embedding[0]
137
def get_client(self) -> "AsyncMixedbreadAI":
    """Return the asynchronous client obtained from ``self.config.get_async_client()``."""
    async_mixedbread_client = self.config.get_async_client()
    return async_mixedbread_client
139
+
140
async def embed_batch(self, client: "AsyncMixedbreadAI", batch: list[str]) -> list[list[float]]:
    """
    Embed one batch of texts by awaiting a single ``client.embeddings`` call.

    Returns one embedding per entry of ``response.data``, in response order.
    """
    request = {
        "model": self.config.embedder_model_name,
        "normalized": True,
        "encoding_format": ENCODING_FORMAT,
        "truncation_strategy": TRUNCATION_STRATEGY,
        "request_options": self.get_request_options(),
        "input": batch,
    }
    response = await client.embeddings(**request)
    return [item.embedding for item in response.data]