unstructured-ingest 0.5.1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of unstructured-ingest might be problematic. Click here for more details.
- test/integration/connectors/test_google_drive.py +141 -0
- test/unit/v2/embedders/test_bedrock.py +1 -1
- test/unit/v2/embedders/test_huggingface.py +1 -1
- unstructured_ingest/__version__.py +1 -1
- unstructured_ingest/embed/azure_openai.py +6 -0
- unstructured_ingest/embed/bedrock.py +29 -12
- unstructured_ingest/embed/huggingface.py +14 -5
- unstructured_ingest/embed/interfaces.py +63 -44
- unstructured_ingest/embed/mixedbreadai.py +28 -105
- unstructured_ingest/embed/octoai.py +19 -44
- unstructured_ingest/embed/openai.py +17 -48
- unstructured_ingest/embed/togetherai.py +16 -49
- unstructured_ingest/embed/vertexai.py +15 -39
- unstructured_ingest/embed/voyageai.py +16 -42
- unstructured_ingest/v2/errors.py +7 -0
- unstructured_ingest/v2/processes/connectors/google_drive.py +132 -3
- unstructured_ingest/v2/processes/connectors/neo4j.py +129 -43
- unstructured_ingest/v2/processes/connectors/sql/snowflake.py +53 -3
- unstructured_ingest/v2/processes/embedder.py +9 -7
- {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/METADATA +99 -87
- {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/RECORD +25 -25
- {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/WHEEL +1 -1
- {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/LICENSE.md +0 -0
- {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/entry_points.txt +0 -0
- {unstructured_ingest-0.5.1.dist-info → unstructured_ingest-0.5.3.dist-info}/top_level.txt +0 -0
|
@@ -8,9 +8,9 @@ from contextlib import asynccontextmanager
|
|
|
8
8
|
from dataclasses import dataclass
|
|
9
9
|
from enum import Enum
|
|
10
10
|
from pathlib import Path
|
|
11
|
-
from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional
|
|
11
|
+
from typing import TYPE_CHECKING, Any, AsyncGenerator, Literal, Optional
|
|
12
12
|
|
|
13
|
-
from pydantic import BaseModel, ConfigDict, Field, Secret
|
|
13
|
+
from pydantic import BaseModel, ConfigDict, Field, Secret, field_validator
|
|
14
14
|
|
|
15
15
|
from unstructured_ingest.error import DestinationConnectionError
|
|
16
16
|
from unstructured_ingest.logger import logger
|
|
@@ -30,6 +30,8 @@ from unstructured_ingest.v2.processes.connector_registry import (
|
|
|
30
30
|
DestinationRegistryEntry,
|
|
31
31
|
)
|
|
32
32
|
|
|
33
|
+
SimilarityFunction = Literal["cosine"]
|
|
34
|
+
|
|
33
35
|
if TYPE_CHECKING:
|
|
34
36
|
from neo4j import AsyncDriver, Auth
|
|
35
37
|
from networkx import Graph, MultiDiGraph
|
|
@@ -44,9 +46,9 @@ class Neo4jAccessConfig(AccessConfig):
|
|
|
44
46
|
class Neo4jConnectionConfig(ConnectionConfig):
|
|
45
47
|
access_config: Secret[Neo4jAccessConfig]
|
|
46
48
|
connector_type: str = Field(default=CONNECTOR_TYPE, init=False)
|
|
47
|
-
username: str
|
|
49
|
+
username: str = Field(default="neo4j")
|
|
48
50
|
uri: str = Field(description="Neo4j Connection URI <scheme>://<host>:<port>")
|
|
49
|
-
database: str = Field(description="Name of the target database")
|
|
51
|
+
database: str = Field(default="neo4j", description="Name of the target database")
|
|
50
52
|
|
|
51
53
|
@requires_dependencies(["neo4j"], extras="neo4j")
|
|
52
54
|
@asynccontextmanager
|
|
@@ -186,8 +188,8 @@ class _GraphData(BaseModel):
|
|
|
186
188
|
nodes = list(nx_graph.nodes())
|
|
187
189
|
edges = [
|
|
188
190
|
_Edge(
|
|
189
|
-
|
|
190
|
-
|
|
191
|
+
source=u,
|
|
192
|
+
destination=v,
|
|
191
193
|
relationship=Relationship(data_dict["relationship"]),
|
|
192
194
|
)
|
|
193
195
|
for u, v, data_dict in nx_graph.edges(data=True)
|
|
@@ -198,19 +200,30 @@ class _GraphData(BaseModel):
|
|
|
198
200
|
class _Node(BaseModel):
|
|
199
201
|
model_config = ConfigDict()
|
|
200
202
|
|
|
201
|
-
|
|
202
|
-
labels: list[Label] = Field(default_factory=list)
|
|
203
|
+
labels: list[Label]
|
|
203
204
|
properties: dict = Field(default_factory=dict)
|
|
205
|
+
id_: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
|
204
206
|
|
|
205
207
|
def __hash__(self):
|
|
206
208
|
return hash(self.id_)
|
|
207
209
|
|
|
210
|
+
@property
|
|
211
|
+
def main_label(self) -> Label:
|
|
212
|
+
return self.labels[0]
|
|
213
|
+
|
|
214
|
+
@classmethod
|
|
215
|
+
@field_validator("labels", mode="after")
|
|
216
|
+
def require_at_least_one_label(cls, value: list[Label]) -> list[Label]:
|
|
217
|
+
if not value:
|
|
218
|
+
raise ValueError("Node must have at least one label.")
|
|
219
|
+
return value
|
|
220
|
+
|
|
208
221
|
|
|
209
222
|
class _Edge(BaseModel):
|
|
210
223
|
model_config = ConfigDict()
|
|
211
224
|
|
|
212
|
-
|
|
213
|
-
|
|
225
|
+
source: _Node
|
|
226
|
+
destination: _Node
|
|
214
227
|
relationship: Relationship
|
|
215
228
|
|
|
216
229
|
|
|
@@ -229,7 +242,14 @@ class Relationship(Enum):
|
|
|
229
242
|
|
|
230
243
|
class Neo4jUploaderConfig(UploaderConfig):
|
|
231
244
|
batch_size: int = Field(
|
|
232
|
-
default=
|
|
245
|
+
default=1000, description="Maximal number of nodes/relationships created per transaction."
|
|
246
|
+
)
|
|
247
|
+
similarity_function: SimilarityFunction = Field(
|
|
248
|
+
default="cosine",
|
|
249
|
+
description="Vector similarity function used to create index on Chunk nodes",
|
|
250
|
+
)
|
|
251
|
+
create_destination: bool = Field(
|
|
252
|
+
default=True, description="Create destination if it does not exist"
|
|
233
253
|
)
|
|
234
254
|
|
|
235
255
|
|
|
@@ -257,6 +277,13 @@ class Neo4jUploader(Uploader):
|
|
|
257
277
|
graph_data = _GraphData.model_validate(staged_data)
|
|
258
278
|
async with self.connection_config.get_client() as client:
|
|
259
279
|
await self._create_uniqueness_constraints(client)
|
|
280
|
+
embedding_dimensions = self._get_embedding_dimensions(graph_data)
|
|
281
|
+
if embedding_dimensions and self.upload_config.create_destination:
|
|
282
|
+
await self._create_vector_index(
|
|
283
|
+
client,
|
|
284
|
+
dimensions=embedding_dimensions,
|
|
285
|
+
similarity_function=self.upload_config.similarity_function,
|
|
286
|
+
)
|
|
260
287
|
await self._delete_old_data_if_exists(file_data, client=client)
|
|
261
288
|
await self._merge_graph(graph_data=graph_data, client=client)
|
|
262
289
|
|
|
@@ -274,13 +301,33 @@ class Neo4jUploader(Uploader):
|
|
|
274
301
|
"""
|
|
275
302
|
)
|
|
276
303
|
|
|
304
|
+
async def _create_vector_index(
|
|
305
|
+
self, client: AsyncDriver, dimensions: int, similarity_function: SimilarityFunction
|
|
306
|
+
) -> None:
|
|
307
|
+
label = Label.CHUNK
|
|
308
|
+
logger.info(
|
|
309
|
+
f"Creating index on nodes labeled '{label.value}' if it does not already exist."
|
|
310
|
+
)
|
|
311
|
+
index_name = f"{label.value.lower()}_vector"
|
|
312
|
+
await client.execute_query(
|
|
313
|
+
f"""
|
|
314
|
+
CREATE VECTOR INDEX {index_name} IF NOT EXISTS
|
|
315
|
+
FOR (n:{label.value}) ON n.embedding
|
|
316
|
+
OPTIONS {{indexConfig: {{
|
|
317
|
+
`vector.similarity_function`: '{similarity_function}',
|
|
318
|
+
`vector.dimensions`: {dimensions}}}
|
|
319
|
+
}}
|
|
320
|
+
"""
|
|
321
|
+
)
|
|
322
|
+
|
|
277
323
|
async def _delete_old_data_if_exists(self, file_data: FileData, client: AsyncDriver) -> None:
|
|
278
324
|
logger.info(f"Deleting old data for the record '{file_data.identifier}' (if present).")
|
|
279
325
|
_, summary, _ = await client.execute_query(
|
|
280
326
|
f"""
|
|
281
|
-
MATCH (n: {Label.DOCUMENT.value} {{id: $identifier}})
|
|
282
|
-
MATCH (n)--(m: {Label.CHUNK.value}
|
|
283
|
-
DETACH DELETE m
|
|
327
|
+
MATCH (n: `{Label.DOCUMENT.value}` {{id: $identifier}})
|
|
328
|
+
MATCH (n)--(m: `{Label.CHUNK.value}`|`{Label.UNSTRUCTURED_ELEMENT.value}`)
|
|
329
|
+
DETACH DELETE m
|
|
330
|
+
DETACH DELETE n""",
|
|
284
331
|
identifier=file_data.identifier,
|
|
285
332
|
)
|
|
286
333
|
logger.info(
|
|
@@ -289,16 +336,15 @@ class Neo4jUploader(Uploader):
|
|
|
289
336
|
)
|
|
290
337
|
|
|
291
338
|
async def _merge_graph(self, graph_data: _GraphData, client: AsyncDriver) -> None:
|
|
292
|
-
nodes_by_labels: defaultdict[
|
|
339
|
+
nodes_by_labels: defaultdict[Label, list[_Node]] = defaultdict(list)
|
|
293
340
|
for node in graph_data.nodes:
|
|
294
|
-
nodes_by_labels[
|
|
295
|
-
|
|
341
|
+
nodes_by_labels[node.main_label].append(node)
|
|
296
342
|
logger.info(f"Merging {len(graph_data.nodes)} graph nodes.")
|
|
297
343
|
# NOTE: Processed in parallel as there's no overlap between accessed nodes
|
|
298
344
|
await self._execute_queries(
|
|
299
345
|
[
|
|
300
|
-
self._create_nodes_query(nodes_batch,
|
|
301
|
-
for
|
|
346
|
+
self._create_nodes_query(nodes_batch, label)
|
|
347
|
+
for label, nodes in nodes_by_labels.items()
|
|
302
348
|
for nodes_batch in batch_generator(nodes, batch_size=self.upload_config.batch_size)
|
|
303
349
|
],
|
|
304
350
|
client=client,
|
|
@@ -306,16 +352,23 @@ class Neo4jUploader(Uploader):
|
|
|
306
352
|
)
|
|
307
353
|
logger.info(f"Finished merging {len(graph_data.nodes)} graph nodes.")
|
|
308
354
|
|
|
309
|
-
edges_by_relationship: defaultdict[Relationship, list[_Edge]] =
|
|
355
|
+
edges_by_relationship: defaultdict[tuple[Relationship, Label, Label], list[_Edge]] = (
|
|
356
|
+
defaultdict(list)
|
|
357
|
+
)
|
|
310
358
|
for edge in graph_data.edges:
|
|
311
|
-
|
|
359
|
+
key = (edge.relationship, edge.source.main_label, edge.destination.main_label)
|
|
360
|
+
edges_by_relationship[key].append(edge)
|
|
312
361
|
|
|
313
362
|
logger.info(f"Merging {len(graph_data.edges)} graph relationships (edges).")
|
|
314
363
|
# NOTE: Processed sequentially to avoid queries locking node access to one another
|
|
315
364
|
await self._execute_queries(
|
|
316
365
|
[
|
|
317
|
-
self._create_edges_query(edges_batch, relationship)
|
|
318
|
-
for
|
|
366
|
+
self._create_edges_query(edges_batch, relationship, source_label, destination_label)
|
|
367
|
+
for (
|
|
368
|
+
relationship,
|
|
369
|
+
source_label,
|
|
370
|
+
destination_label,
|
|
371
|
+
), edges in edges_by_relationship.items()
|
|
319
372
|
for edges_batch in batch_generator(edges, batch_size=self.upload_config.batch_size)
|
|
320
373
|
],
|
|
321
374
|
client=client,
|
|
@@ -328,53 +381,86 @@ class Neo4jUploader(Uploader):
|
|
|
328
381
|
client: AsyncDriver,
|
|
329
382
|
in_parallel: bool = False,
|
|
330
383
|
) -> None:
|
|
384
|
+
from neo4j import EagerResult
|
|
385
|
+
|
|
386
|
+
results: list[EagerResult] = []
|
|
387
|
+
logger.info(
|
|
388
|
+
f"Executing {len(queries_with_parameters)} "
|
|
389
|
+
+ f"{'parallel' if in_parallel else 'sequential'} Cypher statements."
|
|
390
|
+
)
|
|
331
391
|
if in_parallel:
|
|
332
|
-
|
|
333
|
-
await asyncio.gather(
|
|
392
|
+
results = await asyncio.gather(
|
|
334
393
|
*[
|
|
335
394
|
client.execute_query(query, parameters_=parameters)
|
|
336
395
|
for query, parameters in queries_with_parameters
|
|
337
396
|
]
|
|
338
397
|
)
|
|
339
|
-
logger.info("Finished executing parallel queries.")
|
|
340
398
|
else:
|
|
341
|
-
logger.info(f"Executing {len(queries_with_parameters)} queries sequentially.")
|
|
342
399
|
for i, (query, parameters) in enumerate(queries_with_parameters):
|
|
343
|
-
logger.info(f"
|
|
344
|
-
await client.execute_query(query, parameters_=parameters)
|
|
345
|
-
logger.info(f"
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
400
|
+
logger.info(f"Statement #{i} started.")
|
|
401
|
+
results.append(await client.execute_query(query, parameters_=parameters))
|
|
402
|
+
logger.info(f"Statement #{i} finished.")
|
|
403
|
+
nodeCount = sum([res.summary.counters.nodes_created for res in results])
|
|
404
|
+
relCount = sum([res.summary.counters.relationships_created for res in results])
|
|
405
|
+
logger.info(
|
|
406
|
+
f"Finished executing all ({len(queries_with_parameters)}) "
|
|
407
|
+
+ f"{'parallel' if in_parallel else 'sequential'} Cypher statements. "
|
|
408
|
+
+ f"Created {nodeCount} nodes, {relCount} relationships."
|
|
409
|
+
)
|
|
349
410
|
|
|
350
411
|
@staticmethod
|
|
351
|
-
def _create_nodes_query(nodes: list[_Node],
|
|
352
|
-
|
|
353
|
-
logger.info(f"Preparing MERGE query for {len(nodes)} nodes labeled '{labels_string}'.")
|
|
412
|
+
def _create_nodes_query(nodes: list[_Node], label: Label) -> tuple[str, dict]:
|
|
413
|
+
logger.info(f"Preparing MERGE query for {len(nodes)} nodes labeled '{label}'.")
|
|
354
414
|
query_string = f"""
|
|
355
415
|
UNWIND $nodes AS node
|
|
356
|
-
MERGE (n: {
|
|
416
|
+
MERGE (n: `{label.value}` {{id: node.id}})
|
|
357
417
|
SET n += node.properties
|
|
418
|
+
SET n:$(node.labels)
|
|
419
|
+
WITH * WHERE node.vector IS NOT NULL
|
|
420
|
+
CALL db.create.setNodeVectorProperty(n, 'embedding', node.vector)
|
|
358
421
|
"""
|
|
359
|
-
parameters = {
|
|
422
|
+
parameters = {
|
|
423
|
+
"nodes": [
|
|
424
|
+
{
|
|
425
|
+
"id": node.id_,
|
|
426
|
+
"labels": [l.value for l in node.labels if l != label], # noqa: E741
|
|
427
|
+
"vector": node.properties.pop("embedding", None),
|
|
428
|
+
"properties": node.properties,
|
|
429
|
+
}
|
|
430
|
+
for node in nodes
|
|
431
|
+
]
|
|
432
|
+
}
|
|
360
433
|
return query_string, parameters
|
|
361
434
|
|
|
362
435
|
@staticmethod
|
|
363
|
-
def _create_edges_query(
|
|
436
|
+
def _create_edges_query(
|
|
437
|
+
edges: list[_Edge],
|
|
438
|
+
relationship: Relationship,
|
|
439
|
+
source_label: Label,
|
|
440
|
+
destination_label: Label,
|
|
441
|
+
) -> tuple[str, dict]:
|
|
364
442
|
logger.info(f"Preparing MERGE query for {len(edges)} {relationship} relationships.")
|
|
365
443
|
query_string = f"""
|
|
366
444
|
UNWIND $edges AS edge
|
|
367
|
-
MATCH (u {{id: edge.source}})
|
|
368
|
-
MATCH (v {{id: edge.destination}})
|
|
369
|
-
MERGE (u)-[
|
|
445
|
+
MATCH (u: `{source_label.value}` {{id: edge.source}})
|
|
446
|
+
MATCH (v: `{destination_label.value}` {{id: edge.destination}})
|
|
447
|
+
MERGE (u)-[:`{relationship.value}`]->(v)
|
|
370
448
|
"""
|
|
371
449
|
parameters = {
|
|
372
450
|
"edges": [
|
|
373
|
-
{"source": edge.
|
|
451
|
+
{"source": edge.source.id_, "destination": edge.destination.id_} for edge in edges
|
|
374
452
|
]
|
|
375
453
|
}
|
|
376
454
|
return query_string, parameters
|
|
377
455
|
|
|
456
|
+
def _get_embedding_dimensions(self, graph_data: _GraphData) -> int | None:
|
|
457
|
+
"""Embedding dimensions inferred from chunk nodes or None if it can't be determined."""
|
|
458
|
+
for node in graph_data.nodes:
|
|
459
|
+
if Label.CHUNK in node.labels and "embeddings" in node.properties:
|
|
460
|
+
return len(node.properties["embeddings"])
|
|
461
|
+
|
|
462
|
+
return None
|
|
463
|
+
|
|
378
464
|
|
|
379
465
|
neo4j_destination_entry = DestinationRegistryEntry(
|
|
380
466
|
connection_config=Neo4jConnectionConfig,
|
|
@@ -1,6 +1,7 @@
|
|
|
1
|
+
import json
|
|
1
2
|
from contextlib import contextmanager
|
|
2
3
|
from dataclasses import dataclass, field
|
|
3
|
-
from typing import TYPE_CHECKING, Generator, Optional
|
|
4
|
+
from typing import TYPE_CHECKING, Any, Generator, Optional
|
|
4
5
|
|
|
5
6
|
import numpy as np
|
|
6
7
|
import pandas as pd
|
|
@@ -15,6 +16,7 @@ from unstructured_ingest.v2.processes.connector_registry import (
|
|
|
15
16
|
SourceRegistryEntry,
|
|
16
17
|
)
|
|
17
18
|
from unstructured_ingest.v2.processes.connectors.sql.sql import (
|
|
19
|
+
_DATE_COLUMNS,
|
|
18
20
|
SQLAccessConfig,
|
|
19
21
|
SqlBatchFileData,
|
|
20
22
|
SQLConnectionConfig,
|
|
@@ -26,6 +28,7 @@ from unstructured_ingest.v2.processes.connectors.sql.sql import (
|
|
|
26
28
|
SQLUploaderConfig,
|
|
27
29
|
SQLUploadStager,
|
|
28
30
|
SQLUploadStagerConfig,
|
|
31
|
+
parse_date_string,
|
|
29
32
|
)
|
|
30
33
|
|
|
31
34
|
if TYPE_CHECKING:
|
|
@@ -34,6 +37,17 @@ if TYPE_CHECKING:
|
|
|
34
37
|
|
|
35
38
|
CONNECTOR_TYPE = "snowflake"
|
|
36
39
|
|
|
40
|
+
_ARRAY_COLUMNS = (
|
|
41
|
+
"embeddings",
|
|
42
|
+
"languages",
|
|
43
|
+
"link_urls",
|
|
44
|
+
"link_texts",
|
|
45
|
+
"sent_from",
|
|
46
|
+
"sent_to",
|
|
47
|
+
"emphasized_text_contents",
|
|
48
|
+
"emphasized_text_tags",
|
|
49
|
+
)
|
|
50
|
+
|
|
37
51
|
|
|
38
52
|
class SnowflakeAccessConfig(SQLAccessConfig):
|
|
39
53
|
password: Optional[str] = Field(default=None, description="DB password")
|
|
@@ -160,6 +174,42 @@ class SnowflakeUploader(SQLUploader):
|
|
|
160
174
|
connector_type: str = CONNECTOR_TYPE
|
|
161
175
|
values_delimiter: str = "?"
|
|
162
176
|
|
|
177
|
+
def prepare_data(
|
|
178
|
+
self, columns: list[str], data: tuple[tuple[Any, ...], ...]
|
|
179
|
+
) -> list[tuple[Any, ...]]:
|
|
180
|
+
output = []
|
|
181
|
+
for row in data:
|
|
182
|
+
parsed = []
|
|
183
|
+
for column_name, value in zip(columns, row):
|
|
184
|
+
if column_name in _DATE_COLUMNS:
|
|
185
|
+
if value is None or pd.isna(value): # pandas is nan
|
|
186
|
+
parsed.append(None)
|
|
187
|
+
else:
|
|
188
|
+
parsed.append(parse_date_string(value))
|
|
189
|
+
elif column_name in _ARRAY_COLUMNS:
|
|
190
|
+
if not isinstance(value, list) and (
|
|
191
|
+
value is None or pd.isna(value)
|
|
192
|
+
): # pandas is nan
|
|
193
|
+
parsed.append(None)
|
|
194
|
+
else:
|
|
195
|
+
parsed.append(json.dumps(value))
|
|
196
|
+
else:
|
|
197
|
+
parsed.append(value)
|
|
198
|
+
output.append(tuple(parsed))
|
|
199
|
+
return output
|
|
200
|
+
|
|
201
|
+
def _parse_values(self, columns: list[str]) -> str:
|
|
202
|
+
return ",".join(
|
|
203
|
+
[
|
|
204
|
+
(
|
|
205
|
+
f"PARSE_JSON({self.values_delimiter})"
|
|
206
|
+
if col in _ARRAY_COLUMNS
|
|
207
|
+
else self.values_delimiter
|
|
208
|
+
)
|
|
209
|
+
for col in columns
|
|
210
|
+
]
|
|
211
|
+
)
|
|
212
|
+
|
|
163
213
|
def upload_dataframe(self, df: pd.DataFrame, file_data: FileData) -> None:
|
|
164
214
|
if self.can_delete():
|
|
165
215
|
self.delete_by_record_id(file_data=file_data)
|
|
@@ -173,10 +223,10 @@ class SnowflakeUploader(SQLUploader):
|
|
|
173
223
|
self._fit_to_schema(df=df)
|
|
174
224
|
|
|
175
225
|
columns = list(df.columns)
|
|
176
|
-
stmt = "INSERT INTO {table_name} ({columns})
|
|
226
|
+
stmt = "INSERT INTO {table_name} ({columns}) SELECT {values}".format(
|
|
177
227
|
table_name=self.upload_config.table_name,
|
|
178
228
|
columns=",".join(columns),
|
|
179
|
-
values=
|
|
229
|
+
values=self._parse_values(columns),
|
|
180
230
|
)
|
|
181
231
|
logger.info(
|
|
182
232
|
f"writing a total of {len(df)} elements via"
|
|
@@ -92,18 +92,20 @@ class EmbedderConfig(BaseModel):
|
|
|
92
92
|
|
|
93
93
|
return OctoAIEmbeddingEncoder(config=OctoAiEmbeddingConfig.model_validate(embedding_kwargs))
|
|
94
94
|
|
|
95
|
-
def get_bedrock_embedder(self) -> "BaseEmbeddingEncoder":
|
|
95
|
+
def get_bedrock_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder":
|
|
96
96
|
from unstructured_ingest.embed.bedrock import (
|
|
97
97
|
BedrockEmbeddingConfig,
|
|
98
98
|
BedrockEmbeddingEncoder,
|
|
99
99
|
)
|
|
100
100
|
|
|
101
|
+
embedding_kwargs = embedding_kwargs | {
|
|
102
|
+
"aws_access_key_id": self.embedding_aws_access_key_id,
|
|
103
|
+
"aws_secret_access_key": self.embedding_aws_secret_access_key.get_secret_value(),
|
|
104
|
+
"region_name": self.embedding_aws_region,
|
|
105
|
+
}
|
|
106
|
+
|
|
101
107
|
return BedrockEmbeddingEncoder(
|
|
102
|
-
config=BedrockEmbeddingConfig(
|
|
103
|
-
aws_access_key_id=self.embedding_aws_access_key_id,
|
|
104
|
-
aws_secret_access_key=self.embedding_aws_secret_access_key.get_secret_value(),
|
|
105
|
-
region_name=self.embedding_aws_region,
|
|
106
|
-
)
|
|
108
|
+
config=BedrockEmbeddingConfig.model_validate(embedding_kwargs)
|
|
107
109
|
)
|
|
108
110
|
|
|
109
111
|
def get_vertexai_embedder(self, embedding_kwargs: dict) -> "BaseEmbeddingEncoder":
|
|
@@ -163,7 +165,7 @@ class EmbedderConfig(BaseModel):
|
|
|
163
165
|
return self.get_octoai_embedder(embedding_kwargs=kwargs)
|
|
164
166
|
|
|
165
167
|
if self.embedding_provider == "bedrock":
|
|
166
|
-
return self.get_bedrock_embedder()
|
|
168
|
+
return self.get_bedrock_embedder(embedding_kwargs=kwargs)
|
|
167
169
|
|
|
168
170
|
if self.embedding_provider == "vertexai":
|
|
169
171
|
return self.get_vertexai_embedder(embedding_kwargs=kwargs)
|