unstructured-ingest 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of unstructured-ingest might be problematic.
- test/integration/connectors/sql/test_vastdb.py +34 -0
- test/integration/connectors/test_google_drive.py +257 -0
- test/unit/v2/connectors/motherduck/__init__.py +0 -0
- test/unit/v2/connectors/motherduck/test_base.py +74 -0
- unstructured_ingest/__version__.py +1 -1
- unstructured_ingest/embed/bedrock.py +13 -6
- unstructured_ingest/embed/huggingface.py +11 -4
- unstructured_ingest/embed/interfaces.py +2 -21
- unstructured_ingest/embed/mixedbreadai.py +13 -4
- unstructured_ingest/embed/octoai.py +13 -6
- unstructured_ingest/embed/openai.py +13 -6
- unstructured_ingest/embed/togetherai.py +13 -4
- unstructured_ingest/embed/vertexai.py +13 -6
- unstructured_ingest/embed/voyageai.py +13 -4
- unstructured_ingest/v2/processes/connectors/duckdb/base.py +2 -0
- unstructured_ingest/v2/processes/connectors/duckdb/motherduck.py +5 -4
- unstructured_ingest/v2/processes/connectors/google_drive.py +144 -13
- unstructured_ingest/v2/processes/connectors/pinecone.py +1 -0
- unstructured_ingest/v2/processes/connectors/sql/snowflake.py +53 -3
- unstructured_ingest/v2/processes/connectors/sql/sql.py +3 -47
- unstructured_ingest/v2/processes/connectors/sql/vastdb.py +4 -12
- {unstructured_ingest-0.5.0.dist-info → unstructured_ingest-0.5.2.dist-info}/METADATA +18 -18
- {unstructured_ingest-0.5.0.dist-info → unstructured_ingest-0.5.2.dist-info}/RECORD +27 -23
- {unstructured_ingest-0.5.0.dist-info → unstructured_ingest-0.5.2.dist-info}/LICENSE.md +0 -0
- {unstructured_ingest-0.5.0.dist-info → unstructured_ingest-0.5.2.dist-info}/WHEEL +0 -0
- {unstructured_ingest-0.5.0.dist-info → unstructured_ingest-0.5.2.dist-info}/entry_points.txt +0 -0
- {unstructured_ingest-0.5.0.dist-info → unstructured_ingest-0.5.2.dist-info}/top_level.txt +0 -0
```diff
--- /dev/null
+++ b/test/integration/connectors/sql/test_vastdb.py
@@ -0,0 +1,34 @@
+from pathlib import Path
+
+import pytest
+from _pytest.fixtures import TopRequest
+
+from test.integration.connectors.utils.constants import DESTINATION_TAG, SQL_TAG
+from test.integration.connectors.utils.validation.destination import (
+    StagerValidationConfigs,
+    stager_validation,
+)
+from unstructured_ingest.v2.processes.connectors.sql.vastdb import (
+    CONNECTOR_TYPE,
+    VastdbUploadStager,
+    VastdbUploadStagerConfig,
+)
+
+
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG, SQL_TAG)
+@pytest.mark.parametrize("upload_file_str", ["upload_file_ndjson", "upload_file"])
+def test_vast_stager(
+    request: TopRequest,
+    upload_file_str: str,
+    tmp_path: Path,
+):
+    upload_file: Path = request.getfixturevalue(upload_file_str)
+    stager = VastdbUploadStager(
+        upload_stager_config=VastdbUploadStagerConfig(rename_columns_map={"page_number": "page"})
+    )
+    stager_validation(
+        configs=StagerValidationConfigs(test_id=CONNECTOR_TYPE, expected_count=22),
+        input_file=upload_file,
+        stager=stager,
+        tmp_dir=tmp_path,
+    )
```
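The new test drives `VastdbUploadStager` through the shared `stager_validation` harness and exercises its `rename_columns_map` option. A minimal sketch of what a rename map like `{"page_number": "page"}` is expected to do to each staged record (the `rename_columns` helper below is hypothetical, for illustration only, not the connector's actual code):

```python
def rename_columns(record: dict, rename_columns_map: dict[str, str]) -> dict:
    # Rename any key found in the map; leave all other keys untouched.
    return {rename_columns_map.get(key, key): value for key, value in record.items()}


record = {"text": "hello", "page_number": 3}
assert rename_columns(record, {"page_number": "page"}) == {"text": "hello", "page": 3}
```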
```diff
--- /dev/null
+++ b/test/integration/connectors/test_google_drive.py
@@ -0,0 +1,257 @@
+import os
+import uuid
+
+import pytest
+from googleapiclient.errors import HttpError
+
+from test.integration.connectors.utils.constants import (
+    SOURCE_TAG,
+    UNCATEGORIZED_TAG,
+)
+from test.integration.connectors.utils.validation.source import (
+    SourceValidationConfigs,
+    get_all_file_data,
+    run_all_validations,
+    update_fixtures,
+)
+from test.integration.utils import requires_env
+from unstructured_ingest.error import (
+    SourceConnectionError,
+)
+from unstructured_ingest.v2.interfaces import Downloader, Indexer
+from unstructured_ingest.v2.processes.connectors.google_drive import (
+    CONNECTOR_TYPE,
+    GoogleDriveAccessConfig,
+    GoogleDriveConnectionConfig,
+    GoogleDriveDownloader,
+    GoogleDriveDownloaderConfig,
+    GoogleDriveIndexer,
+    GoogleDriveIndexerConfig,
+)
+
+
+@pytest.fixture
+def google_drive_connection_config():
+    """
+    Build a valid GoogleDriveConnectionConfig using the environment variables.
+    Expects:
+      - GOOGLE_DRIVE_ID
+      - GOOGLE_DRIVE_SERVICE_KEY
+    """
+    drive_id = os.getenv("GOOGLE_DRIVE_ID")
+    service_key = os.getenv("GOOGLE_DRIVE_SERVICE_KEY")
+    if not drive_id or not service_key:
+        pytest.skip("Google Drive credentials not provided in environment variables.")
+
+    access_config = GoogleDriveAccessConfig(service_account_key=service_key)
+    return GoogleDriveConnectionConfig(drive_id=drive_id, access_config=access_config)
+
+
+@pytest.fixture
+def google_drive_empty_folder(google_drive_connection_config):
+    """
+    Creates an empty folder on Google Drive for testing the "empty folder" case.
+    The folder is deleted after the test.
+    """
+    from google.oauth2 import service_account
+    from googleapiclient.discovery import build
+
+    access_config = google_drive_connection_config.access_config.get_secret_value()
+    creds = service_account.Credentials.from_service_account_info(access_config.service_account_key)
+    service = build("drive", "v3", credentials=creds)
+
+    # Create an empty folder.
+    file_metadata = {
+        "name": f"utic-empty-folder-{uuid.uuid4()}",
+        "mimeType": "application/vnd.google-apps.folder",
+    }
+    folder = service.files().create(body=file_metadata, fields="id, name").execute()
+    folder_id = folder.get("id")
+    try:
+        yield folder_id
+    finally:
+        service.files().delete(fileId=folder_id).execute()
+
+
+@requires_env("GOOGLE_DRIVE_SERVICE_KEY")
+@pytest.mark.tags(SOURCE_TAG, CONNECTOR_TYPE)
+def test_google_drive_source(temp_dir):
+    # Retrieve environment variables
+    service_account_key = os.environ["GOOGLE_DRIVE_SERVICE_KEY"]
+
+    # Create connection and indexer configurations
+    access_config = GoogleDriveAccessConfig(service_account_key=service_account_key)
+    connection_config = GoogleDriveConnectionConfig(
+        drive_id="1XidSOO76VpZ4m0i3gJN2m1X0Obol3UAi",
+        access_config=access_config,
+    )
+    index_config = GoogleDriveIndexerConfig(recursive=True)
+
+    download_config = GoogleDriveDownloaderConfig(download_dir=temp_dir)
+
+    # Instantiate indexer and downloader
+    indexer = GoogleDriveIndexer(
+        connection_config=connection_config,
+        index_config=index_config,
+    )
+    downloader = GoogleDriveDownloader(
+        connection_config=connection_config,
+        download_config=download_config,
+    )
+
+    # Run the source connector validation
+    source_connector_validation(
+        indexer=indexer,
+        downloader=downloader,
+        configs=SourceValidationConfigs(
+            test_id="google_drive_source",
+            expected_num_files=1,
+            validate_downloaded_files=True,
+        ),
+    )
+
+
+@pytest.mark.tags(SOURCE_TAG, CONNECTOR_TYPE, UNCATEGORIZED_TAG)
+def source_connector_validation(
+    indexer: Indexer,
+    downloader: Downloader,
+    configs: SourceValidationConfigs,
+    overwrite_fixtures: bool = os.getenv("OVERWRITE_FIXTURES", "False").lower() == "true",
+) -> None:
+    # Run common validations on the process of running a source connector, supporting dynamic
+    # validators that get passed in along with comparisons on the saved expected values.
+    # If overwrite_fixtures is set to True, will ignore all validators but instead overwrite the
+    # expected values with what gets generated by this test.
+    all_predownload_file_data = []
+    all_postdownload_file_data = []
+    indexer.precheck()
+    download_dir = downloader.download_config.download_dir
+    test_output_dir = configs.test_output_dir()
+
+    for file_data in indexer.run():
+        assert file_data
+        predownload_file_data = file_data.model_copy(deep=True)
+        all_predownload_file_data.append(predownload_file_data)
+        resp = downloader.run(file_data=file_data)
+        if isinstance(resp, list):
+            for r in resp:
+                postdownload_file_data = r["file_data"].model_copy(deep=True)
+                all_postdownload_file_data.append(postdownload_file_data)
+        else:
+            postdownload_file_data = resp["file_data"].model_copy(deep=True)
+            all_postdownload_file_data.append(postdownload_file_data)
+
+    if not overwrite_fixtures:
+        print("Running validation")
+        run_all_validations(
+            configs=configs,
+            predownload_file_data=all_predownload_file_data,
+            postdownload_file_data=all_postdownload_file_data,
+            download_dir=download_dir,
+            test_output_dir=test_output_dir,
+        )
+    else:
+        print("Running fixtures update")
+        update_fixtures(
+            output_dir=test_output_dir,
+            download_dir=download_dir,
+            all_file_data=get_all_file_data(
+                all_predownload_file_data=all_predownload_file_data,
+                all_postdownload_file_data=all_postdownload_file_data,
+            ),
+            save_downloads=configs.validate_downloaded_files,
+            save_filedata=configs.validate_file_data,
+        )
+
+
+# Precheck fails when the drive ID has an appended parameter (simulate copy-paste error)
+@pytest.mark.tags("google-drive", "precheck")
+@requires_env("GOOGLE_DRIVE_ID", "GOOGLE_DRIVE_SERVICE_KEY")
+def test_google_drive_precheck_invalid_parameter(google_drive_connection_config):
+    # Append a query parameter as often happens when copying from a URL.
+    invalid_drive_id = google_drive_connection_config.drive_id + "?usp=sharing"
+    connection_config = GoogleDriveConnectionConfig(
+        drive_id=invalid_drive_id,
+        access_config=google_drive_connection_config.access_config,
+    )
+    index_config = GoogleDriveIndexerConfig(recursive=True)
+    indexer = GoogleDriveIndexer(connection_config=connection_config, index_config=index_config)
+    with pytest.raises(SourceConnectionError) as excinfo:
+        indexer.precheck()
+    assert "invalid" in str(excinfo.value).lower() or "not found" in str(excinfo.value).lower()
+
+
+# Precheck fails due to lack of permission (simulate via monkeypatching).
+@pytest.mark.tags("google-drive", "precheck")
+@requires_env("GOOGLE_DRIVE_ID", "GOOGLE_DRIVE_SERVICE_KEY")
+def test_google_drive_precheck_no_permission(google_drive_connection_config, monkeypatch):
+    index_config = GoogleDriveIndexerConfig(recursive=True)
+    indexer = GoogleDriveIndexer(
+        connection_config=google_drive_connection_config,
+        index_config=index_config,
+    )
+
+    # Monkeypatch get_root_info to always raise an HTTP 403 error.
+    def fake_get_root_info(files_client, object_id):
+        raise HttpError(
+            resp=type("Response", (), {"status": 403, "reason": "Forbidden"})(),
+            content=b"Forbidden",
+        )
+
+    monkeypatch.setattr(indexer, "get_root_info", fake_get_root_info)
+    with pytest.raises(SourceConnectionError) as excinfo:
+        indexer.precheck()
+    assert "forbidden" in str(excinfo.value).lower() or "permission" in str(excinfo.value).lower()
+
+
+# Precheck fails when the folder is empty.
+# @pytest.mark.tags("google-drive", "precheck")
+# @requires_env("GOOGLE_DRIVE_ID", "GOOGLE_DRIVE_SERVICE_KEY")
+# def test_google_drive_precheck_empty_folder(
+#     google_drive_connection_config, google_drive_empty_folder
+# ):
+#     # Use the empty folder's ID as the target.
+#     connection_config = GoogleDriveConnectionConfig(
+#         drive_id=google_drive_empty_folder,
+#         access_config=google_drive_connection_config.access_config,
+#     )
+
+#     index_config = GoogleDriveIndexerConfig(recursive=True)
+#     indexer = GoogleDriveIndexer(connection_config=connection_config, index_config=index_config)
+#     with pytest.raises(SourceConnectionError) as excinfo:
+#         indexer.precheck()
+#     assert "empty folder" in str(excinfo.value).lower()
+
+
+@pytest.mark.tags("google-drive", "count", "integration")
+@requires_env("GOOGLE_DRIVE_ID", "GOOGLE_DRIVE_SERVICE_KEY")
+def test_google_drive_count_files(google_drive_connection_config):
+    """
+    This test verifies that the count_files_recursively method returns the expected count of files.
+    According to the test credentials, there are 3 files in the root directory and 1 nested file,
+    so the total count should be 4.
+    """
+    # I assumed that we're applying the same extension filter as with other tests
+    # However there's 6 files in total in the test dir
+    extensions_filter = ["pdf", "docx"]
+    with google_drive_connection_config.get_client() as client:
+        count = GoogleDriveIndexer.count_files_recursively(
+            client, google_drive_connection_config.drive_id, extensions_filter
+        )
+    assert count == 4, f"Expected file count of 4, but got {count}"
+
+
+# Precheck fails with a completely invalid drive ID.
+@pytest.mark.tags("google-drive", "precheck")
+@requires_env("GOOGLE_DRIVE_ID", "GOOGLE_DRIVE_SERVICE_KEY")
+def test_google_drive_precheck_invalid_drive_id(google_drive_connection_config):
+    invalid_drive_id = "invalid_drive_id"
+    connection_config = GoogleDriveConnectionConfig(
+        drive_id=invalid_drive_id,
+        access_config=google_drive_connection_config.access_config,
+    )
+    index_config = GoogleDriveIndexerConfig(recursive=True)
+    indexer = GoogleDriveIndexer(connection_config=connection_config, index_config=index_config)
+    with pytest.raises(SourceConnectionError) as excinfo:
+        indexer.precheck()
+    assert "invalid" in str(excinfo.value).lower() or "not found" in str(excinfo.value).lower()
```
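These tests gate themselves on credentials via the `requires_env` marker from `test.integration.utils`, whose implementation is not part of this diff. A plausible sketch of what such a marker does, assuming it skips the test unless every named environment variable is set:

```python
import os

import pytest


def requires_env(*env_vars: str):
    # Plausible sketch only: skip the decorated test unless every named
    # environment variable is present in the test environment.
    missing = [v for v in env_vars if not os.getenv(v)]
    return pytest.mark.skipif(
        bool(missing), reason=f"Missing required environment variables: {missing}"
    )
```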
test/unit/v2/connectors/motherduck/__init__.py: file without changes (new, empty).
```diff
--- /dev/null
+++ b/test/unit/v2/connectors/motherduck/test_base.py
@@ -0,0 +1,74 @@
+from pathlib import Path
+
+import pytest
+from pytest_mock import MockerFixture
+
+from unstructured_ingest.v2.interfaces import FileData
+from unstructured_ingest.v2.interfaces.file_data import SourceIdentifiers
+from unstructured_ingest.v2.interfaces.upload_stager import UploadStagerConfig
+from unstructured_ingest.v2.processes.connectors.duckdb.base import BaseDuckDBUploadStager
+
+
+@pytest.fixture
+def mock_instance() -> BaseDuckDBUploadStager:
+    return BaseDuckDBUploadStager(UploadStagerConfig())
+
+
+@pytest.mark.parametrize(
+    ("input_filepath", "output_filename", "expected"),
+    [
+        (
+            "/path/to/input_file.ndjson",
+            "output_file.ndjson",
+            "output_file.ndjson",
+        ),
+        ("input_file.txt", "output_file.json", "output_file.txt"),
+        ("/path/to/input_file.json", "output_file", "output_file.json"),
+    ],
+)
+def test_run_output_filename_suffix(
+    mocker: MockerFixture,
+    mock_instance: BaseDuckDBUploadStager,
+    input_filepath: str,
+    output_filename: str,
+    expected: str,
+):
+    output_dir = Path("/tmp/test/output_dir")
+
+    # Mocks
+    mock_get_data = mocker.patch(
+        "unstructured_ingest.v2.processes.connectors.duckdb.base.get_data",
+        return_value=[{"key": "value"}, {"key": "value2"}],
+    )
+    mock_conform_dict = mocker.patch.object(
+        BaseDuckDBUploadStager,
+        "conform_dict",
+        side_effect=lambda element_dict, file_data: element_dict,
+    )
+    mock_get_output_path = mocker.patch.object(
+        BaseDuckDBUploadStager, "get_output_path", return_value=output_dir / expected
+    )
+    mock_write_output = mocker.patch(
+        "unstructured_ingest.v2.processes.connectors.duckdb.base.write_data", return_value=None
+    )
+
+    # Act
+    result = mock_instance.run(
+        elements_filepath=Path(input_filepath),
+        file_data=FileData(
+            identifier="test",
+            connector_type="test",
+            source_identifiers=SourceIdentifiers(filename=input_filepath, fullpath=input_filepath),
+        ),
+        output_dir=output_dir,
+        output_filename=output_filename,
+    )
+
+    # Assert
+    mock_get_data.assert_called_once_with(path=Path(input_filepath))
+    assert mock_conform_dict.call_count == 2
+    mock_get_output_path.assert_called_once_with(output_filename=expected, output_dir=output_dir)
+    mock_write_output.assert_called_once_with(
+        path=output_dir / expected, data=[{"key": "value"}, {"key": "value2"}]
+    )
+    assert result.name == expected
```
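All three parametrized cases assert one rule: the staged output keeps the caller-supplied filename but adopts the input file's suffix. A standalone sketch of that rule, using a hypothetical `resolve_output_filename` helper rather than the stager's actual internals:

```python
from pathlib import Path


def resolve_output_filename(input_filepath: str, output_filename: str) -> str:
    # The output name takes on the input file's suffix, matching the
    # expectations asserted in test_run_output_filename_suffix.
    return str(Path(output_filename).with_suffix(Path(input_filepath).suffix))


assert resolve_output_filename("/path/to/input_file.ndjson", "output_file.ndjson") == "output_file.ndjson"
assert resolve_output_filename("input_file.txt", "output_file.json") == "output_file.txt"
assert resolve_output_filename("/path/to/input_file.json", "output_file") == "output_file.json"
```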
```diff
--- a/unstructured_ingest/__version__.py
+++ b/unstructured_ingest/__version__.py
@@ -1 +1 @@
-__version__ = "0.5.0"  # pragma: no cover
+__version__ = "0.5.2"  # pragma: no cover
```
```diff
--- a/unstructured_ingest/embed/bedrock.py
+++ b/unstructured_ingest/embed/bedrock.py
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, AsyncIterable
 from pydantic import Field, SecretStr
 
 from unstructured_ingest.embed.interfaces import (
+    EMBEDDINGS_KEY,
     AsyncBaseEmbeddingEncoder,
     BaseEmbeddingEncoder,
     EmbeddingConfig,
@@ -145,9 +146,12 @@ class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
         return response_body.get("embedding")
 
     def embed_documents(self, elements: list[dict]) -> list[dict]:
-
-
-
+        elements = elements.copy()
+        elements_with_text = [e for e in elements if e.get("text")]
+        embeddings = [self.embed_query(query=e["text"]) for e in elements_with_text]
+        for element, embedding in zip(elements_with_text, embeddings):
+            element[EMBEDDINGS_KEY] = embedding
+        return elements
 
 
 @dataclass
@@ -186,8 +190,11 @@ class AsyncBedrockEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
             raise ValueError(f"Error raised by inference endpoint: {e}")
 
     async def embed_documents(self, elements: list[dict]) -> list[dict]:
+        elements = elements.copy()
+        elements_with_text = [e for e in elements if e.get("text")]
         embeddings = await asyncio.gather(
-            *[self.embed_query(query=e.get("text", "")) for e in elements]
+            *[self.embed_query(query=e.get("text", "")) for e in elements_with_text]
         )
-
-
+        for element, embedding in zip(elements_with_text, embeddings):
+            element[EMBEDDINGS_KEY] = embedding
+        return elements
```
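This `embed_documents` rewrite repeats across every encoder touched in this release: copy the element list, embed only the elements that actually carry text, and attach each vector under the shared `EMBEDDINGS_KEY`. A self-contained sketch of the pattern (`FakeEncoder` and its one-float "embedding" are stand-ins for illustration, not a real encoder):

```python
EMBEDDINGS_KEY = "embeddings"  # mirrors the constant added in embed/interfaces.py


class FakeEncoder:
    def embed_query(self, query: str) -> list[float]:
        return [float(len(query))]  # stand-in for a real model call

    def embed_documents(self, elements: list[dict]) -> list[dict]:
        elements = elements.copy()
        elements_with_text = [e for e in elements if e.get("text")]
        embeddings = [self.embed_query(query=e["text"]) for e in elements_with_text]
        for element, embedding in zip(elements_with_text, embeddings):
            element[EMBEDDINGS_KEY] = embedding
        return elements


elements = [{"text": "hello"}, {"type": "PageBreak"}, {"text": ""}]
result = FakeEncoder().embed_documents(elements)
assert result[0][EMBEDDINGS_KEY] == [5.0]
assert EMBEDDINGS_KEY not in result[1]  # no "text" key: passed through unchanged
assert EMBEDDINGS_KEY not in result[2]  # empty text: passed through unchanged
```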
```diff
--- a/unstructured_ingest/embed/huggingface.py
+++ b/unstructured_ingest/embed/huggingface.py
@@ -3,7 +3,11 @@ from typing import TYPE_CHECKING, Optional
 
 from pydantic import Field
 
-from unstructured_ingest.embed.interfaces import BaseEmbeddingEncoder, EmbeddingConfig
+from unstructured_ingest.embed.interfaces import (
+    EMBEDDINGS_KEY,
+    BaseEmbeddingEncoder,
+    EmbeddingConfig,
+)
 from unstructured_ingest.utils.dep_check import requires_dependencies
 
 if TYPE_CHECKING:
@@ -52,6 +56,9 @@ class HuggingFaceEmbeddingEncoder(BaseEmbeddingEncoder):
         return embeddings.tolist()
 
     def embed_documents(self, elements: list[dict]) -> list[dict]:
-
-
-
+        elements = elements.copy()
+        elements_with_text = [e for e in elements if e.get("text")]
+        embeddings = self._embed_documents([e["text"] for e in elements_with_text])
+        for element, embedding in zip(elements_with_text, embeddings):
+            element[EMBEDDINGS_KEY] = embedding
+        return elements
```
```diff
--- a/unstructured_ingest/embed/interfaces.py
+++ b/unstructured_ingest/embed/interfaces.py
@@ -6,6 +6,8 @@ from typing import Optional
 import numpy as np
 from pydantic import BaseModel, Field
 
+EMBEDDINGS_KEY = "embeddings"
+
 
 class EmbeddingConfig(BaseModel):
     batch_size: Optional[int] = Field(
@@ -26,27 +28,6 @@ class BaseEncoder(ABC):
         if possible"""
         return e
 
-    @staticmethod
-    def _add_embeddings_to_elements(
-        elements: list[dict], embeddings: list[list[float]]
-    ) -> list[dict]:
-        """
-        Add embeddings to elements.
-
-        Args:
-            elements (list[Element]): List of elements.
-            embeddings (list[list[float]]): List of embeddings.
-
-        Returns:
-            list[Element]: Elements with embeddings added.
-        """
-        assert len(elements) == len(embeddings)
-        elements_w_embedding = []
-        for i, element in enumerate(elements):
-            element["embeddings"] = embeddings[i]
-            elements_w_embedding.append(element)
-        return elements
-
 
 @dataclass
 class BaseEmbeddingEncoder(BaseEncoder, ABC):
```
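One subtlety of the new encoder bodies above: `elements.copy()` makes a shallow copy of the list, so the element dicts themselves are still mutated in place, and callers holding references to the input dicts will observe the new key too. A two-line demonstration:

```python
original = [{"text": "hi"}]
copied = original.copy()  # new list, but the same dict objects inside
copied[0]["embeddings"] = [0.1]
assert "embeddings" in original[0]  # the shared dict was mutated
```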
```diff
--- a/unstructured_ingest/embed/mixedbreadai.py
+++ b/unstructured_ingest/embed/mixedbreadai.py
@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING
 from pydantic import Field, SecretStr
 
 from unstructured_ingest.embed.interfaces import (
+    EMBEDDINGS_KEY,
     AsyncBaseEmbeddingEncoder,
     BaseEmbeddingEncoder,
     EmbeddingConfig,
@@ -134,8 +135,12 @@ class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
         Returns:
             list[Element]: Elements with embeddings.
         """
-
-
+        elements = elements.copy()
+        elements_with_text = [e for e in elements if e.get("text")]
+        embeddings = self._embed([e["text"] for e in elements_with_text])
+        for element, embedding in zip(elements_with_text, embeddings):
+            element[EMBEDDINGS_KEY] = embedding
+        return elements
 
     def embed_query(self, query: str) -> list[float]:
         """
@@ -209,8 +214,12 @@ class AsyncMixedbreadAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
         Returns:
             list[Element]: Elements with embeddings.
         """
-
-
+        elements = elements.copy()
+        elements_with_text = [e for e in elements if e.get("text")]
+        embeddings = await self._embed([e["text"] for e in elements_with_text])
+        for element, embedding in zip(elements_with_text, embeddings):
+            element[EMBEDDINGS_KEY] = embedding
+        return elements
 
     async def embed_query(self, query: str) -> list[float]:
         """
```
```diff
--- a/unstructured_ingest/embed/octoai.py
+++ b/unstructured_ingest/embed/octoai.py
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
 from pydantic import Field, SecretStr
 
 from unstructured_ingest.embed.interfaces import (
+    EMBEDDINGS_KEY,
     AsyncBaseEmbeddingEncoder,
     BaseEmbeddingEncoder,
     EmbeddingConfig,
@@ -89,7 +90,9 @@ class OctoAIEmbeddingEncoder(BaseEmbeddingEncoder):
         return response.data[0].embedding
 
     def embed_documents(self, elements: list[dict]) -> list[dict]:
-
+        elements = elements.copy()
+        elements_with_text = [e for e in elements if e.get("text")]
+        texts = [e["text"] for e in elements_with_text]
         embeddings = []
         client = self.config.get_client()
         try:
@@ -100,8 +103,9 @@ class OctoAIEmbeddingEncoder(BaseEmbeddingEncoder):
                 embeddings.extend([data.embedding for data in response.data])
         except Exception as e:
             raise self.wrap_error(e=e)
-
-
+        for element, embedding in zip(elements_with_text, embeddings):
+            element[EMBEDDINGS_KEY] = embedding
+        return elements
 
 
 @dataclass
@@ -122,7 +126,9 @@ class AsyncOctoAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
         return response.data[0].embedding
 
     async def embed_documents(self, elements: list[dict]) -> list[dict]:
-
+        elements = elements.copy()
+        elements_with_text = [e for e in elements if e.get("text")]
+        texts = [e["text"] for e in elements_with_text]
         client = self.config.get_async_client()
         embeddings = []
         try:
@@ -133,5 +139,6 @@ class AsyncOctoAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
                 embeddings.extend([data.embedding for data in response.data])
         except Exception as e:
             raise self.wrap_error(e=e)
-
-
+        for element, embedding in zip(elements_with_text, embeddings):
+            element[EMBEDDINGS_KEY] = embedding
+        return elements
```
```diff
--- a/unstructured_ingest/embed/openai.py
+++ b/unstructured_ingest/embed/openai.py
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
 from pydantic import Field, SecretStr
 
 from unstructured_ingest.embed.interfaces import (
+    EMBEDDINGS_KEY,
     AsyncBaseEmbeddingEncoder,
     BaseEmbeddingEncoder,
     EmbeddingConfig,
@@ -82,7 +83,9 @@ class OpenAIEmbeddingEncoder(BaseEmbeddingEncoder):
 
     def embed_documents(self, elements: list[dict]) -> list[dict]:
         client = self.config.get_client()
-
+        elements = elements.copy()
+        elements_with_text = [e for e in elements if e.get("text")]
+        texts = [e["text"] for e in elements_with_text]
         embeddings = []
         try:
             for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
@@ -92,8 +95,9 @@ class OpenAIEmbeddingEncoder(BaseEmbeddingEncoder):
                 embeddings.extend([data.embedding for data in response.data])
         except Exception as e:
             raise self.wrap_error(e=e)
-
-
+        for element, embedding in zip(elements_with_text, embeddings):
+            element[EMBEDDINGS_KEY] = embedding
+        return elements
 
 
 @dataclass
@@ -115,7 +119,9 @@ class AsyncOpenAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
 
     async def embed_documents(self, elements: list[dict]) -> list[dict]:
         client = self.config.get_async_client()
-
+        elements = elements.copy()
+        elements_with_text = [e for e in elements if e.get("text")]
+        texts = [e["text"] for e in elements_with_text]
         embeddings = []
         try:
             for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
@@ -125,5 +131,6 @@ class AsyncOpenAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
                 embeddings.extend([data.embedding for data in response.data])
         except Exception as e:
             raise self.wrap_error(e=e)
-
-
+        for element, embedding in zip(elements_with_text, embeddings):
+            element[EMBEDDINGS_KEY] = embedding
+        return elements
```
```diff
--- a/unstructured_ingest/embed/togetherai.py
+++ b/unstructured_ingest/embed/togetherai.py
@@ -4,6 +4,7 @@ from typing import TYPE_CHECKING
 from pydantic import Field, SecretStr
 
 from unstructured_ingest.embed.interfaces import (
+    EMBEDDINGS_KEY,
     AsyncBaseEmbeddingEncoder,
     BaseEmbeddingEncoder,
     EmbeddingConfig,
@@ -67,8 +68,12 @@ class TogetherAIEmbeddingEncoder(BaseEmbeddingEncoder):
         return self._embed_documents(elements=[query])[0]
 
     def embed_documents(self, elements: list[dict]) -> list[dict]:
-
-
+        elements = elements.copy()
+        elements_with_text = [e for e in elements if e.get("text")]
+        embeddings = self._embed_documents([e["text"] for e in elements_with_text])
+        for element, embedding in zip(elements_with_text, embeddings):
+            element[EMBEDDINGS_KEY] = embedding
+        return elements
 
     def _embed_documents(self, elements: list[str]) -> list[list[float]]:
         client = self.config.get_client()
@@ -98,8 +103,12 @@ class AsyncTogetherAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
         return embedding[0]
 
     async def embed_documents(self, elements: list[dict]) -> list[dict]:
-
-
+        elements = elements.copy()
+        elements_with_text = [e for e in elements if e.get("text")]
+        embeddings = await self._embed_documents([e["text"] for e in elements_with_text])
+        for element, embedding in zip(elements_with_text, embeddings):
+            element[EMBEDDINGS_KEY] = embedding
+        return elements
 
     async def _embed_documents(self, elements: list[str]) -> list[list[float]]:
         client = self.config.get_async_client()
```