unstructured-ingest 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of unstructured-ingest might be problematic. See the package registry's advisory for more details.

Files changed (27)
  1. test/integration/connectors/sql/test_vastdb.py +34 -0
  2. test/integration/connectors/test_google_drive.py +257 -0
  3. test/unit/v2/connectors/motherduck/__init__.py +0 -0
  4. test/unit/v2/connectors/motherduck/test_base.py +74 -0
  5. unstructured_ingest/__version__.py +1 -1
  6. unstructured_ingest/embed/bedrock.py +13 -6
  7. unstructured_ingest/embed/huggingface.py +11 -4
  8. unstructured_ingest/embed/interfaces.py +2 -21
  9. unstructured_ingest/embed/mixedbreadai.py +13 -4
  10. unstructured_ingest/embed/octoai.py +13 -6
  11. unstructured_ingest/embed/openai.py +13 -6
  12. unstructured_ingest/embed/togetherai.py +13 -4
  13. unstructured_ingest/embed/vertexai.py +13 -6
  14. unstructured_ingest/embed/voyageai.py +13 -4
  15. unstructured_ingest/v2/processes/connectors/duckdb/base.py +2 -0
  16. unstructured_ingest/v2/processes/connectors/duckdb/motherduck.py +5 -4
  17. unstructured_ingest/v2/processes/connectors/google_drive.py +144 -13
  18. unstructured_ingest/v2/processes/connectors/pinecone.py +1 -0
  19. unstructured_ingest/v2/processes/connectors/sql/snowflake.py +53 -3
  20. unstructured_ingest/v2/processes/connectors/sql/sql.py +3 -47
  21. unstructured_ingest/v2/processes/connectors/sql/vastdb.py +4 -12
  22. {unstructured_ingest-0.5.0.dist-info → unstructured_ingest-0.5.2.dist-info}/METADATA +18 -18
  23. {unstructured_ingest-0.5.0.dist-info → unstructured_ingest-0.5.2.dist-info}/RECORD +27 -23
  24. {unstructured_ingest-0.5.0.dist-info → unstructured_ingest-0.5.2.dist-info}/LICENSE.md +0 -0
  25. {unstructured_ingest-0.5.0.dist-info → unstructured_ingest-0.5.2.dist-info}/WHEEL +0 -0
  26. {unstructured_ingest-0.5.0.dist-info → unstructured_ingest-0.5.2.dist-info}/entry_points.txt +0 -0
  27. {unstructured_ingest-0.5.0.dist-info → unstructured_ingest-0.5.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,34 @@
1
+ from pathlib import Path
2
+
3
+ import pytest
4
+ from _pytest.fixtures import TopRequest
5
+
6
+ from test.integration.connectors.utils.constants import DESTINATION_TAG, SQL_TAG
7
+ from test.integration.connectors.utils.validation.destination import (
8
+ StagerValidationConfigs,
9
+ stager_validation,
10
+ )
11
+ from unstructured_ingest.v2.processes.connectors.sql.vastdb import (
12
+ CONNECTOR_TYPE,
13
+ VastdbUploadStager,
14
+ VastdbUploadStagerConfig,
15
+ )
16
+
17
+
18
@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG, SQL_TAG)
@pytest.mark.parametrize("upload_file_str", ["upload_file_ndjson", "upload_file"])
def test_vast_stager(
    request: TopRequest,
    upload_file_str: str,
    tmp_path: Path,
):
    """Run the VAST DB upload stager over an upload fixture and validate the result.

    ``upload_file_str`` is the *name* of a fixture (``upload_file_ndjson`` or
    ``upload_file``); it is resolved lazily through ``request`` so a single test
    body covers both inputs. The stager renames the ``page_number`` column to
    ``page`` and the shared ``stager_validation`` helper checks the staged output
    with an expected count of 22.
    """
    input_path: Path = request.getfixturevalue(upload_file_str)
    renaming_config = VastdbUploadStagerConfig(rename_columns_map={"page_number": "page"})
    vast_stager = VastdbUploadStager(upload_stager_config=renaming_config)
    expectations = StagerValidationConfigs(test_id=CONNECTOR_TYPE, expected_count=22)
    stager_validation(
        configs=expectations,
        input_file=input_path,
        stager=vast_stager,
        tmp_dir=tmp_path,
    )
@@ -0,0 +1,257 @@
1
+ import os
2
+ import uuid
3
+
4
+ import pytest
5
+ from googleapiclient.errors import HttpError
6
+
7
+ from test.integration.connectors.utils.constants import (
8
+ SOURCE_TAG,
9
+ UNCATEGORIZED_TAG,
10
+ )
11
+ from test.integration.connectors.utils.validation.source import (
12
+ SourceValidationConfigs,
13
+ get_all_file_data,
14
+ run_all_validations,
15
+ update_fixtures,
16
+ )
17
+ from test.integration.utils import requires_env
18
+ from unstructured_ingest.error import (
19
+ SourceConnectionError,
20
+ )
21
+ from unstructured_ingest.v2.interfaces import Downloader, Indexer
22
+ from unstructured_ingest.v2.processes.connectors.google_drive import (
23
+ CONNECTOR_TYPE,
24
+ GoogleDriveAccessConfig,
25
+ GoogleDriveConnectionConfig,
26
+ GoogleDriveDownloader,
27
+ GoogleDriveDownloaderConfig,
28
+ GoogleDriveIndexer,
29
+ GoogleDriveIndexerConfig,
30
+ )
31
+
32
+
33
@pytest.fixture
def google_drive_connection_config():
    """Return a GoogleDriveConnectionConfig built from environment variables.

    Reads ``GOOGLE_DRIVE_ID`` and ``GOOGLE_DRIVE_SERVICE_KEY``. If either one is
    unset or empty, the requesting test is skipped rather than failed.
    """
    drive_id = os.getenv("GOOGLE_DRIVE_ID")
    service_key = os.getenv("GOOGLE_DRIVE_SERVICE_KEY")
    have_credentials = bool(drive_id) and bool(service_key)
    if not have_credentials:
        pytest.skip("Google Drive credentials not provided in environment variables.")

    return GoogleDriveConnectionConfig(
        drive_id=drive_id,
        access_config=GoogleDriveAccessConfig(service_account_key=service_key),
    )
48
+
49
+
50
@pytest.fixture
def google_drive_empty_folder(google_drive_connection_config):
    """Yield the ID of a freshly created, empty Google Drive folder.

    The folder is named ``utic-empty-folder-<uuid4>`` and is created with the
    service-account credentials from ``google_drive_connection_config``. It is
    deleted during fixture teardown via ``try``/``finally``.
    """
    # Imported lazily so the Google client libraries are only needed when used.
    from google.oauth2 import service_account
    from googleapiclient.discovery import build

    secret_access = google_drive_connection_config.access_config.get_secret_value()
    credentials = service_account.Credentials.from_service_account_info(
        secret_access.service_account_key
    )
    drive_service = build("drive", "v3", credentials=credentials)

    folder_body = {
        "name": f"utic-empty-folder-{uuid.uuid4()}",
        "mimeType": "application/vnd.google-apps.folder",
    }
    created = drive_service.files().create(body=folder_body, fields="id, name").execute()
    new_folder_id = created.get("id")
    try:
        yield new_folder_id
    finally:
        drive_service.files().delete(fileId=new_folder_id).execute()
74
+
75
+
76
@requires_env("GOOGLE_DRIVE_SERVICE_KEY")
@pytest.mark.tags(SOURCE_TAG, CONNECTOR_TYPE)
def test_google_drive_source(temp_dir):
    """Index and download a fixed Google Drive folder end to end.

    Authenticates with the service-account key in ``GOOGLE_DRIVE_SERVICE_KEY``
    against a hard-coded test drive ID, indexes recursively, downloads into
    ``temp_dir``, and hands both halves to ``source_connector_validation`` under
    the ``google_drive_source`` test id (one file expected, downloads validated).
    """
    connection_config = GoogleDriveConnectionConfig(
        drive_id="1XidSOO76VpZ4m0i3gJN2m1X0Obol3UAi",
        access_config=GoogleDriveAccessConfig(
            service_account_key=os.environ["GOOGLE_DRIVE_SERVICE_KEY"]
        ),
    )

    indexer = GoogleDriveIndexer(
        connection_config=connection_config,
        index_config=GoogleDriveIndexerConfig(recursive=True),
    )
    downloader = GoogleDriveDownloader(
        connection_config=connection_config,
        download_config=GoogleDriveDownloaderConfig(download_dir=temp_dir),
    )

    validation = SourceValidationConfigs(
        test_id="google_drive_source",
        expected_num_files=1,
        validate_downloaded_files=True,
    )
    source_connector_validation(indexer=indexer, downloader=downloader, configs=validation)
112
+
113
+
114
@pytest.mark.tags(SOURCE_TAG, CONNECTOR_TYPE, UNCATEGORIZED_TAG)
def source_connector_validation(
    indexer: Indexer,
    downloader: Downloader,
    configs: SourceValidationConfigs,
    overwrite_fixtures: bool = os.getenv("OVERWRITE_FIXTURES", "False").lower() == "true",
) -> None:
    """Run an indexer/downloader pair end to end, then validate or record the results.

    Every FileData yielded by ``indexer.run()`` is deep-copied before it is passed
    to ``downloader.run()``, and each FileData in the downloader's response is
    deep-copied afterwards.

    - When ``overwrite_fixtures`` is False, these snapshots and the download
      directory are checked with ``run_all_validations`` against the expected
      values for ``configs``.
    - When it is True, validators are skipped. The expected values are instead
      overwritten with what this run produced (``update_fixtures``).

    Args:
        indexer: Source indexer; ``precheck()`` is called before indexing starts.
        downloader: Downloader whose ``download_config.download_dir`` holds output.
        configs: Validation settings, including where expected fixtures live.
        overwrite_fixtures: Defaults to whether the ``OVERWRITE_FIXTURES``
            environment variable equals ``"true"`` (case-insensitive). Note the
            default is evaluated once, when this module is imported.
    """
    all_predownload_file_data = []
    all_postdownload_file_data = []
    indexer.precheck()
    download_dir = downloader.download_config.download_dir
    test_output_dir = configs.test_output_dir()

    for file_data in indexer.run():
        assert file_data
        # Snapshot before downloading, since the downloader receives this object.
        all_predownload_file_data.append(file_data.model_copy(deep=True))
        resp = downloader.run(file_data=file_data)
        # The downloader may return either a single response or a list of them;
        # normalize to a list so both shapes share one code path.
        responses = resp if isinstance(resp, list) else [resp]
        all_postdownload_file_data.extend(r["file_data"].model_copy(deep=True) for r in responses)

    if not overwrite_fixtures:
        print("Running validation")
        run_all_validations(
            configs=configs,
            predownload_file_data=all_predownload_file_data,
            postdownload_file_data=all_postdownload_file_data,
            download_dir=download_dir,
            test_output_dir=test_output_dir,
        )
    else:
        print("Running fixtures update")
        update_fixtures(
            output_dir=test_output_dir,
            download_dir=download_dir,
            all_file_data=get_all_file_data(
                all_predownload_file_data=all_predownload_file_data,
                all_postdownload_file_data=all_postdownload_file_data,
            ),
            save_downloads=configs.validate_downloaded_files,
            save_filedata=configs.validate_file_data,
        )
165
+
166
+
167
+ # Precheck fails when the drive ID has an appended parameter (simulate copy-paste error)
168
@pytest.mark.tags("google-drive", "precheck")
@requires_env("GOOGLE_DRIVE_ID", "GOOGLE_DRIVE_SERVICE_KEY")
def test_google_drive_precheck_invalid_parameter(google_drive_connection_config):
    """Precheck must reject a drive ID that has a URL query string appended.

    This simulates pasting an ID straight from a sharing link (``<id>?usp=sharing``).
    """
    tainted_id = google_drive_connection_config.drive_id + "?usp=sharing"
    indexer = GoogleDriveIndexer(
        connection_config=GoogleDriveConnectionConfig(
            drive_id=tainted_id,
            access_config=google_drive_connection_config.access_config,
        ),
        index_config=GoogleDriveIndexerConfig(recursive=True),
    )

    with pytest.raises(SourceConnectionError) as excinfo:
        indexer.precheck()

    message = str(excinfo.value).lower()
    assert "invalid" in message or "not found" in message
182
+
183
+
184
+ # Precheck fails due to lack of permission (simulate via monkeypatching).
185
@pytest.mark.tags("google-drive", "precheck")
@requires_env("GOOGLE_DRIVE_ID", "GOOGLE_DRIVE_SERVICE_KEY")
def test_google_drive_precheck_no_permission(google_drive_connection_config, monkeypatch):
    """Precheck must surface an HTTP 403 from Drive as a SourceConnectionError.

    ``get_root_info`` is replaced on the indexer instance with a stub that always
    raises ``HttpError`` carrying status 403, standing in for a missing permission.
    """
    indexer = GoogleDriveIndexer(
        connection_config=google_drive_connection_config,
        index_config=GoogleDriveIndexerConfig(recursive=True),
    )

    # Minimal stand-in for an httplib2 response: HttpError reads .status/.reason.
    forbidden_response = type("Response", (), {"status": 403, "reason": "Forbidden"})

    def fake_get_root_info(files_client, object_id):
        raise HttpError(resp=forbidden_response(), content=b"Forbidden")

    monkeypatch.setattr(indexer, "get_root_info", fake_get_root_info)

    with pytest.raises(SourceConnectionError) as excinfo:
        indexer.precheck()

    message = str(excinfo.value).lower()
    assert "forbidden" in message or "permission" in message
205
+
206
+
207
+ # Precheck fails when the folder is empty.
208
+ # @pytest.mark.tags("google-drive", "precheck")
209
+ # @requires_env("GOOGLE_DRIVE_ID", "GOOGLE_DRIVE_SERVICE_KEY")
210
+ # def test_google_drive_precheck_empty_folder(
211
+ # google_drive_connection_config, google_drive_empty_folder
212
+ # ):
213
+ # # Use the empty folder's ID as the target.
214
+ # connection_config = GoogleDriveConnectionConfig(
215
+ # drive_id=google_drive_empty_folder,
216
+ # access_config=google_drive_connection_config.access_config,
217
+ # )
218
+
219
+ # index_config = GoogleDriveIndexerConfig(recursive=True)
220
+ # indexer = GoogleDriveIndexer(connection_config=connection_config, index_config=index_config)
221
+ # with pytest.raises(SourceConnectionError) as excinfo:
222
+ # indexer.precheck()
223
+ # assert "empty folder" in str(excinfo.value).lower()
224
+
225
+
226
@pytest.mark.tags("google-drive", "count", "integration")
@requires_env("GOOGLE_DRIVE_ID", "GOOGLE_DRIVE_SERVICE_KEY")
def test_google_drive_count_files(google_drive_connection_config):
    """Check that ``count_files_recursively`` returns the expected filtered count.

    Only ``pdf`` and ``docx`` files are counted, assuming the same extension
    filter the other tests use. The test folder is noted as holding 6 files in
    total, of which 4 are expected to match the filter (3 at the root and 1 nested).
    NOTE(review): confirm these figures against the live test folder.
    """
    with google_drive_connection_config.get_client() as client:
        count = GoogleDriveIndexer.count_files_recursively(
            client,
            google_drive_connection_config.drive_id,
            ["pdf", "docx"],
        )
    assert count == 4, f"Expected file count of 4, but got {count}"
242
+
243
+
244
+ # Precheck fails with a completely invalid drive ID.
245
@pytest.mark.tags("google-drive", "precheck")
@requires_env("GOOGLE_DRIVE_ID", "GOOGLE_DRIVE_SERVICE_KEY")
def test_google_drive_precheck_invalid_drive_id(google_drive_connection_config):
    """Precheck must reject a completely bogus drive ID with valid credentials."""
    bogus_config = GoogleDriveConnectionConfig(
        drive_id="invalid_drive_id",
        access_config=google_drive_connection_config.access_config,
    )
    indexer = GoogleDriveIndexer(
        connection_config=bogus_config,
        index_config=GoogleDriveIndexerConfig(recursive=True),
    )

    with pytest.raises(SourceConnectionError) as excinfo:
        indexer.precheck()

    message = str(excinfo.value).lower()
    assert "invalid" in message or "not found" in message
File without changes
@@ -0,0 +1,74 @@
1
+ from pathlib import Path
2
+
3
+ import pytest
4
+ from pytest_mock import MockerFixture
5
+
6
+ from unstructured_ingest.v2.interfaces import FileData
7
+ from unstructured_ingest.v2.interfaces.file_data import SourceIdentifiers
8
+ from unstructured_ingest.v2.interfaces.upload_stager import UploadStagerConfig
9
+ from unstructured_ingest.v2.processes.connectors.duckdb.base import BaseDuckDBUploadStager
10
+
11
+
12
@pytest.fixture
def mock_instance() -> BaseDuckDBUploadStager:
    """Provide a BaseDuckDBUploadStager constructed with a default UploadStagerConfig."""
    default_config = UploadStagerConfig()
    return BaseDuckDBUploadStager(default_config)
15
+
16
+
17
@pytest.mark.parametrize(
    ("input_filepath", "output_filename", "expected"),
    [
        ("/path/to/input_file.ndjson", "output_file.ndjson", "output_file.ndjson"),
        ("input_file.txt", "output_file.json", "output_file.txt"),
        ("/path/to/input_file.json", "output_file", "output_file.json"),
    ],
)
def test_run_output_filename_suffix(
    mocker: MockerFixture,
    mock_instance: BaseDuckDBUploadStager,
    input_filepath: str,
    output_filename: str,
    expected: str,
):
    """The staged file keeps the requested output stem but takes the input's suffix.

    Loading, per-element conforming, output-path resolution and writing are all
    mocked, so only ``run``'s filename handling and its call wiring are exercised.
    """
    output_dir = Path("/tmp/test/output_dir")
    expected_path = output_dir / expected

    # Arrange: stub every collaborator run() reaches in the duckdb base module.
    get_data_mock = mocker.patch(
        "unstructured_ingest.v2.processes.connectors.duckdb.base.get_data",
        return_value=[{"key": "value"}, {"key": "value2"}],
    )
    conform_dict_mock = mocker.patch.object(
        BaseDuckDBUploadStager,
        "conform_dict",
        side_effect=lambda element_dict, file_data: element_dict,
    )
    get_output_path_mock = mocker.patch.object(
        BaseDuckDBUploadStager, "get_output_path", return_value=expected_path
    )
    write_data_mock = mocker.patch(
        "unstructured_ingest.v2.processes.connectors.duckdb.base.write_data", return_value=None
    )

    # Act
    input_path = Path(input_filepath)
    file_data = FileData(
        identifier="test",
        connector_type="test",
        source_identifiers=SourceIdentifiers(filename=input_filepath, fullpath=input_filepath),
    )
    staged_path = mock_instance.run(
        elements_filepath=input_path,
        file_data=file_data,
        output_dir=output_dir,
        output_filename=output_filename,
    )

    # Assert
    get_data_mock.assert_called_once_with(path=Path(input_filepath))
    assert conform_dict_mock.call_count == 2
    get_output_path_mock.assert_called_once_with(output_filename=expected, output_dir=output_dir)
    write_data_mock.assert_called_once_with(
        path=output_dir / expected, data=[{"key": "value"}, {"key": "value2"}]
    )
    assert staged_path.name == expected
@@ -1 +1 @@
1
- __version__ = "0.5.0" # pragma: no cover
1
+ __version__ = "0.5.2" # pragma: no cover
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, AsyncIterable
8
8
  from pydantic import Field, SecretStr
9
9
 
10
10
  from unstructured_ingest.embed.interfaces import (
11
+ EMBEDDINGS_KEY,
11
12
  AsyncBaseEmbeddingEncoder,
12
13
  BaseEmbeddingEncoder,
13
14
  EmbeddingConfig,
@@ -145,9 +146,12 @@ class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
145
146
  return response_body.get("embedding")
146
147
 
147
148
  def embed_documents(self, elements: list[dict]) -> list[dict]:
148
- embeddings = [self.embed_query(query=e.get("text", "")) for e in elements]
149
- elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
150
- return elements_with_embeddings
149
+ elements = elements.copy()
150
+ elements_with_text = [e for e in elements if e.get("text")]
151
+ embeddings = [self.embed_query(query=e["text"]) for e in elements_with_text]
152
+ for element, embedding in zip(elements_with_text, embeddings):
153
+ element[EMBEDDINGS_KEY] = embedding
154
+ return elements
151
155
 
152
156
 
153
157
  @dataclass
@@ -186,8 +190,11 @@ class AsyncBedrockEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
186
190
  raise ValueError(f"Error raised by inference endpoint: {e}")
187
191
 
188
192
  async def embed_documents(self, elements: list[dict]) -> list[dict]:
193
+ elements = elements.copy()
194
+ elements_with_text = [e for e in elements if e.get("text")]
189
195
  embeddings = await asyncio.gather(
190
- *[self.embed_query(query=e.get("text", "")) for e in elements]
196
+ *[self.embed_query(query=e.get("text", "")) for e in elements_with_text]
191
197
  )
192
- elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
193
- return elements_with_embeddings
198
+ for element, embedding in zip(elements_with_text, embeddings):
199
+ element[EMBEDDINGS_KEY] = embedding
200
+ return elements
@@ -3,7 +3,11 @@ from typing import TYPE_CHECKING, Optional
3
3
 
4
4
  from pydantic import Field
5
5
 
6
- from unstructured_ingest.embed.interfaces import BaseEmbeddingEncoder, EmbeddingConfig
6
+ from unstructured_ingest.embed.interfaces import (
7
+ EMBEDDINGS_KEY,
8
+ BaseEmbeddingEncoder,
9
+ EmbeddingConfig,
10
+ )
7
11
  from unstructured_ingest.utils.dep_check import requires_dependencies
8
12
 
9
13
  if TYPE_CHECKING:
@@ -52,6 +56,9 @@ class HuggingFaceEmbeddingEncoder(BaseEmbeddingEncoder):
52
56
  return embeddings.tolist()
53
57
 
54
58
  def embed_documents(self, elements: list[dict]) -> list[dict]:
55
- embeddings = self._embed_documents([e.get("text", "") for e in elements])
56
- elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
57
- return elements_with_embeddings
59
+ elements = elements.copy()
60
+ elements_with_text = [e for e in elements if e.get("text")]
61
+ embeddings = self._embed_documents([e["text"] for e in elements_with_text])
62
+ for element, embedding in zip(elements_with_text, embeddings):
63
+ element[EMBEDDINGS_KEY] = embedding
64
+ return elements
@@ -6,6 +6,8 @@ from typing import Optional
6
6
  import numpy as np
7
7
  from pydantic import BaseModel, Field
8
8
 
9
+ EMBEDDINGS_KEY = "embeddings"
10
+
9
11
 
10
12
  class EmbeddingConfig(BaseModel):
11
13
  batch_size: Optional[int] = Field(
@@ -26,27 +28,6 @@ class BaseEncoder(ABC):
26
28
  if possible"""
27
29
  return e
28
30
 
29
- @staticmethod
30
- def _add_embeddings_to_elements(
31
- elements: list[dict], embeddings: list[list[float]]
32
- ) -> list[dict]:
33
- """
34
- Add embeddings to elements.
35
-
36
- Args:
37
- elements (list[Element]): List of elements.
38
- embeddings (list[list[float]]): List of embeddings.
39
-
40
- Returns:
41
- list[Element]: Elements with embeddings added.
42
- """
43
- assert len(elements) == len(embeddings)
44
- elements_w_embedding = []
45
- for i, element in enumerate(elements):
46
- element["embeddings"] = embeddings[i]
47
- elements_w_embedding.append(element)
48
- return elements
49
-
50
31
 
51
32
  @dataclass
52
33
  class BaseEmbeddingEncoder(BaseEncoder, ABC):
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING
6
6
  from pydantic import Field, SecretStr
7
7
 
8
8
  from unstructured_ingest.embed.interfaces import (
9
+ EMBEDDINGS_KEY,
9
10
  AsyncBaseEmbeddingEncoder,
10
11
  BaseEmbeddingEncoder,
11
12
  EmbeddingConfig,
@@ -134,8 +135,12 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
134
135
  Returns:
135
136
  list[Element]: Elements with embeddings.
136
137
  """
137
- embeddings = self._embed([e.get("text", "") for e in elements])
138
- return self._add_embeddings_to_elements(elements, embeddings)
138
+ elements = elements.copy()
139
+ elements_with_text = [e for e in elements if e.get("text")]
140
+ embeddings = self._embed([e["text"] for e in elements_with_text])
141
+ for element, embedding in zip(elements_with_text, embeddings):
142
+ element[EMBEDDINGS_KEY] = embedding
143
+ return elements
139
144
 
140
145
  def embed_query(self, query: str) -> list[float]:
141
146
  """
@@ -209,8 +214,12 @@ class AsyncMixedbreadAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
209
214
  Returns:
210
215
  list[Element]: Elements with embeddings.
211
216
  """
212
- embeddings = await self._embed([e.get("text", "") for e in elements])
213
- return self._add_embeddings_to_elements(elements, embeddings)
217
+ elements = elements.copy()
218
+ elements_with_text = [e for e in elements if e.get("text")]
219
+ embeddings = await self._embed([e["text"] for e in elements_with_text])
220
+ for element, embedding in zip(elements_with_text, embeddings):
221
+ element[EMBEDDINGS_KEY] = embedding
222
+ return elements
214
223
 
215
224
  async def embed_query(self, query: str) -> list[float]:
216
225
  """
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
4
4
  from pydantic import Field, SecretStr
5
5
 
6
6
  from unstructured_ingest.embed.interfaces import (
7
+ EMBEDDINGS_KEY,
7
8
  AsyncBaseEmbeddingEncoder,
8
9
  BaseEmbeddingEncoder,
9
10
  EmbeddingConfig,
@@ -89,7 +90,9 @@ class OctoAIEmbeddingEncoder(BaseEmbeddingEncoder):
89
90
  return response.data[0].embedding
90
91
 
91
92
  def embed_documents(self, elements: list[dict]) -> list[dict]:
92
- texts = [e.get("text", "") for e in elements]
93
+ elements = elements.copy()
94
+ elements_with_text = [e for e in elements if e.get("text")]
95
+ texts = [e["text"] for e in elements_with_text]
93
96
  embeddings = []
94
97
  client = self.config.get_client()
95
98
  try:
@@ -100,8 +103,9 @@ class OctoAIEmbeddingEncoder(BaseEmbeddingEncoder):
100
103
  embeddings.extend([data.embedding for data in response.data])
101
104
  except Exception as e:
102
105
  raise self.wrap_error(e=e)
103
- elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
104
- return elements_with_embeddings
106
+ for element, embedding in zip(elements_with_text, embeddings):
107
+ element[EMBEDDINGS_KEY] = embedding
108
+ return elements
105
109
 
106
110
 
107
111
  @dataclass
@@ -122,7 +126,9 @@ class AsyncOctoAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
122
126
  return response.data[0].embedding
123
127
 
124
128
  async def embed_documents(self, elements: list[dict]) -> list[dict]:
125
- texts = [e.get("text", "") for e in elements]
129
+ elements = elements.copy()
130
+ elements_with_text = [e for e in elements if e.get("text")]
131
+ texts = [e["text"] for e in elements_with_text]
126
132
  client = self.config.get_async_client()
127
133
  embeddings = []
128
134
  try:
@@ -133,5 +139,6 @@ class AsyncOctoAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
133
139
  embeddings.extend([data.embedding for data in response.data])
134
140
  except Exception as e:
135
141
  raise self.wrap_error(e=e)
136
- elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
137
- return elements_with_embeddings
142
+ for element, embedding in zip(elements_with_text, embeddings):
143
+ element[EMBEDDINGS_KEY] = embedding
144
+ return elements
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
4
4
  from pydantic import Field, SecretStr
5
5
 
6
6
  from unstructured_ingest.embed.interfaces import (
7
+ EMBEDDINGS_KEY,
7
8
  AsyncBaseEmbeddingEncoder,
8
9
  BaseEmbeddingEncoder,
9
10
  EmbeddingConfig,
@@ -82,7 +83,9 @@ class OpenAIEmbeddingEncoder(BaseEmbeddingEncoder):
82
83
 
83
84
  def embed_documents(self, elements: list[dict]) -> list[dict]:
84
85
  client = self.config.get_client()
85
- texts = [e.get("text", "") for e in elements]
86
+ elements = elements.copy()
87
+ elements_with_text = [e for e in elements if e.get("text")]
88
+ texts = [e["text"] for e in elements_with_text]
86
89
  embeddings = []
87
90
  try:
88
91
  for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
@@ -92,8 +95,9 @@ class OpenAIEmbeddingEncoder(BaseEmbeddingEncoder):
92
95
  embeddings.extend([data.embedding for data in response.data])
93
96
  except Exception as e:
94
97
  raise self.wrap_error(e=e)
95
- elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
96
- return elements_with_embeddings
98
+ for element, embedding in zip(elements_with_text, embeddings):
99
+ element[EMBEDDINGS_KEY] = embedding
100
+ return elements
97
101
 
98
102
 
99
103
  @dataclass
@@ -115,7 +119,9 @@ class AsyncOpenAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
115
119
 
116
120
  async def embed_documents(self, elements: list[dict]) -> list[dict]:
117
121
  client = self.config.get_async_client()
118
- texts = [e.get("text", "") for e in elements]
122
+ elements = elements.copy()
123
+ elements_with_text = [e for e in elements if e.get("text")]
124
+ texts = [e["text"] for e in elements_with_text]
119
125
  embeddings = []
120
126
  try:
121
127
  for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
@@ -125,5 +131,6 @@ class AsyncOpenAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
125
131
  embeddings.extend([data.embedding for data in response.data])
126
132
  except Exception as e:
127
133
  raise self.wrap_error(e=e)
128
- elements_with_embeddings = self._add_embeddings_to_elements(elements, embeddings)
129
- return elements_with_embeddings
134
+ for element, embedding in zip(elements_with_text, embeddings):
135
+ element[EMBEDDINGS_KEY] = embedding
136
+ return elements
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
4
4
  from pydantic import Field, SecretStr
5
5
 
6
6
  from unstructured_ingest.embed.interfaces import (
7
+ EMBEDDINGS_KEY,
7
8
  AsyncBaseEmbeddingEncoder,
8
9
  BaseEmbeddingEncoder,
9
10
  EmbeddingConfig,
@@ -67,8 +68,12 @@ class TogetherAIEmbeddingEncoder(BaseEmbeddingEncoder):
67
68
  return self._embed_documents(elements=[query])[0]
68
69
 
69
70
  def embed_documents(self, elements: list[dict]) -> list[dict]:
70
- embeddings = self._embed_documents([e.get("text", "") for e in elements])
71
- return self._add_embeddings_to_elements(elements, embeddings)
71
+ elements = elements.copy()
72
+ elements_with_text = [e for e in elements if e.get("text")]
73
+ embeddings = self._embed_documents([e["text"] for e in elements_with_text])
74
+ for element, embedding in zip(elements_with_text, embeddings):
75
+ element[EMBEDDINGS_KEY] = embedding
76
+ return elements
72
77
 
73
78
  def _embed_documents(self, elements: list[str]) -> list[list[float]]:
74
79
  client = self.config.get_client()
@@ -98,8 +103,12 @@ class AsyncTogetherAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
98
103
  return embedding[0]
99
104
 
100
105
  async def embed_documents(self, elements: list[dict]) -> list[dict]:
101
- embeddings = await self._embed_documents([e.get("text", "") for e in elements])
102
- return self._add_embeddings_to_elements(elements, embeddings)
106
+ elements = elements.copy()
107
+ elements_with_text = [e for e in elements if e.get("text")]
108
+ embeddings = await self._embed_documents([e["text"] for e in elements_with_text])
109
+ for element, embedding in zip(elements_with_text, embeddings):
110
+ element[EMBEDDINGS_KEY] = embedding
111
+ return elements
103
112
 
104
113
  async def _embed_documents(self, elements: list[str]) -> list[list[float]]:
105
114
  client = self.config.get_async_client()