unstructured-ingest 1.2.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of unstructured-ingest might be problematic. Click here for more details.
- unstructured_ingest/__init__.py +1 -0
- unstructured_ingest/__version__.py +1 -0
- unstructured_ingest/cli/README.md +28 -0
- unstructured_ingest/cli/__init__.py +0 -0
- unstructured_ingest/cli/base/__init__.py +4 -0
- unstructured_ingest/cli/base/cmd.py +269 -0
- unstructured_ingest/cli/base/dest.py +84 -0
- unstructured_ingest/cli/base/importer.py +34 -0
- unstructured_ingest/cli/base/src.py +75 -0
- unstructured_ingest/cli/cli.py +24 -0
- unstructured_ingest/cli/cmds.py +14 -0
- unstructured_ingest/cli/utils/__init__.py +0 -0
- unstructured_ingest/cli/utils/click.py +237 -0
- unstructured_ingest/cli/utils/model_conversion.py +222 -0
- unstructured_ingest/data_types/__init__.py +0 -0
- unstructured_ingest/data_types/entities.py +17 -0
- unstructured_ingest/data_types/file_data.py +116 -0
- unstructured_ingest/embed/__init__.py +0 -0
- unstructured_ingest/embed/azure_openai.py +63 -0
- unstructured_ingest/embed/bedrock.py +323 -0
- unstructured_ingest/embed/huggingface.py +69 -0
- unstructured_ingest/embed/interfaces.py +146 -0
- unstructured_ingest/embed/mixedbreadai.py +134 -0
- unstructured_ingest/embed/octoai.py +133 -0
- unstructured_ingest/embed/openai.py +142 -0
- unstructured_ingest/embed/togetherai.py +116 -0
- unstructured_ingest/embed/vertexai.py +109 -0
- unstructured_ingest/embed/voyageai.py +130 -0
- unstructured_ingest/error.py +156 -0
- unstructured_ingest/errors_v2.py +156 -0
- unstructured_ingest/interfaces/__init__.py +27 -0
- unstructured_ingest/interfaces/connector.py +56 -0
- unstructured_ingest/interfaces/downloader.py +90 -0
- unstructured_ingest/interfaces/indexer.py +29 -0
- unstructured_ingest/interfaces/process.py +22 -0
- unstructured_ingest/interfaces/processor.py +88 -0
- unstructured_ingest/interfaces/upload_stager.py +89 -0
- unstructured_ingest/interfaces/uploader.py +67 -0
- unstructured_ingest/logger.py +39 -0
- unstructured_ingest/main.py +11 -0
- unstructured_ingest/otel.py +128 -0
- unstructured_ingest/pipeline/__init__.py +0 -0
- unstructured_ingest/pipeline/interfaces.py +211 -0
- unstructured_ingest/pipeline/otel.py +32 -0
- unstructured_ingest/pipeline/pipeline.py +408 -0
- unstructured_ingest/pipeline/steps/__init__.py +0 -0
- unstructured_ingest/pipeline/steps/chunk.py +78 -0
- unstructured_ingest/pipeline/steps/download.py +206 -0
- unstructured_ingest/pipeline/steps/embed.py +77 -0
- unstructured_ingest/pipeline/steps/filter.py +35 -0
- unstructured_ingest/pipeline/steps/index.py +86 -0
- unstructured_ingest/pipeline/steps/partition.py +77 -0
- unstructured_ingest/pipeline/steps/stage.py +65 -0
- unstructured_ingest/pipeline/steps/uncompress.py +50 -0
- unstructured_ingest/pipeline/steps/upload.py +58 -0
- unstructured_ingest/processes/__init__.py +18 -0
- unstructured_ingest/processes/chunker.py +131 -0
- unstructured_ingest/processes/connector_registry.py +69 -0
- unstructured_ingest/processes/connectors/__init__.py +129 -0
- unstructured_ingest/processes/connectors/airtable.py +238 -0
- unstructured_ingest/processes/connectors/assets/__init__.py +0 -0
- unstructured_ingest/processes/connectors/assets/databricks_delta_table_schema.sql +9 -0
- unstructured_ingest/processes/connectors/assets/weaviate_collection_config.json +23 -0
- unstructured_ingest/processes/connectors/astradb.py +592 -0
- unstructured_ingest/processes/connectors/azure_ai_search.py +275 -0
- unstructured_ingest/processes/connectors/chroma.py +193 -0
- unstructured_ingest/processes/connectors/confluence.py +527 -0
- unstructured_ingest/processes/connectors/couchbase.py +336 -0
- unstructured_ingest/processes/connectors/databricks/__init__.py +58 -0
- unstructured_ingest/processes/connectors/databricks/volumes.py +233 -0
- unstructured_ingest/processes/connectors/databricks/volumes_aws.py +93 -0
- unstructured_ingest/processes/connectors/databricks/volumes_azure.py +108 -0
- unstructured_ingest/processes/connectors/databricks/volumes_gcp.py +91 -0
- unstructured_ingest/processes/connectors/databricks/volumes_native.py +92 -0
- unstructured_ingest/processes/connectors/databricks/volumes_table.py +187 -0
- unstructured_ingest/processes/connectors/delta_table.py +310 -0
- unstructured_ingest/processes/connectors/discord.py +161 -0
- unstructured_ingest/processes/connectors/duckdb/__init__.py +15 -0
- unstructured_ingest/processes/connectors/duckdb/base.py +103 -0
- unstructured_ingest/processes/connectors/duckdb/duckdb.py +130 -0
- unstructured_ingest/processes/connectors/duckdb/motherduck.py +130 -0
- unstructured_ingest/processes/connectors/elasticsearch/__init__.py +19 -0
- unstructured_ingest/processes/connectors/elasticsearch/elasticsearch.py +478 -0
- unstructured_ingest/processes/connectors/elasticsearch/opensearch.py +523 -0
- unstructured_ingest/processes/connectors/fsspec/__init__.py +37 -0
- unstructured_ingest/processes/connectors/fsspec/azure.py +203 -0
- unstructured_ingest/processes/connectors/fsspec/box.py +176 -0
- unstructured_ingest/processes/connectors/fsspec/dropbox.py +238 -0
- unstructured_ingest/processes/connectors/fsspec/fsspec.py +475 -0
- unstructured_ingest/processes/connectors/fsspec/gcs.py +203 -0
- unstructured_ingest/processes/connectors/fsspec/s3.py +253 -0
- unstructured_ingest/processes/connectors/fsspec/sftp.py +177 -0
- unstructured_ingest/processes/connectors/fsspec/utils.py +17 -0
- unstructured_ingest/processes/connectors/github.py +226 -0
- unstructured_ingest/processes/connectors/gitlab.py +270 -0
- unstructured_ingest/processes/connectors/google_drive.py +848 -0
- unstructured_ingest/processes/connectors/ibm_watsonx/__init__.py +10 -0
- unstructured_ingest/processes/connectors/ibm_watsonx/ibm_watsonx_s3.py +367 -0
- unstructured_ingest/processes/connectors/jira.py +522 -0
- unstructured_ingest/processes/connectors/kafka/__init__.py +17 -0
- unstructured_ingest/processes/connectors/kafka/cloud.py +121 -0
- unstructured_ingest/processes/connectors/kafka/kafka.py +275 -0
- unstructured_ingest/processes/connectors/kafka/local.py +103 -0
- unstructured_ingest/processes/connectors/kdbai.py +156 -0
- unstructured_ingest/processes/connectors/lancedb/__init__.py +30 -0
- unstructured_ingest/processes/connectors/lancedb/aws.py +43 -0
- unstructured_ingest/processes/connectors/lancedb/azure.py +43 -0
- unstructured_ingest/processes/connectors/lancedb/cloud.py +42 -0
- unstructured_ingest/processes/connectors/lancedb/gcp.py +44 -0
- unstructured_ingest/processes/connectors/lancedb/lancedb.py +181 -0
- unstructured_ingest/processes/connectors/lancedb/local.py +44 -0
- unstructured_ingest/processes/connectors/local.py +227 -0
- unstructured_ingest/processes/connectors/milvus.py +311 -0
- unstructured_ingest/processes/connectors/mongodb.py +389 -0
- unstructured_ingest/processes/connectors/neo4j.py +534 -0
- unstructured_ingest/processes/connectors/notion/__init__.py +0 -0
- unstructured_ingest/processes/connectors/notion/client.py +349 -0
- unstructured_ingest/processes/connectors/notion/connector.py +350 -0
- unstructured_ingest/processes/connectors/notion/helpers.py +448 -0
- unstructured_ingest/processes/connectors/notion/ingest_backoff/__init__.py +3 -0
- unstructured_ingest/processes/connectors/notion/ingest_backoff/_common.py +102 -0
- unstructured_ingest/processes/connectors/notion/ingest_backoff/_wrapper.py +126 -0
- unstructured_ingest/processes/connectors/notion/ingest_backoff/types.py +24 -0
- unstructured_ingest/processes/connectors/notion/interfaces.py +32 -0
- unstructured_ingest/processes/connectors/notion/types/__init__.py +0 -0
- unstructured_ingest/processes/connectors/notion/types/block.py +96 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/__init__.py +63 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/bookmark.py +40 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/breadcrumb.py +21 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/bulleted_list_item.py +31 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/callout.py +131 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/child_database.py +23 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/child_page.py +23 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/code.py +43 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/column_list.py +35 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/divider.py +22 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/embed.py +36 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/equation.py +23 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/file.py +49 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/heading.py +37 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/image.py +21 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/link_preview.py +24 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/link_to_page.py +29 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/numbered_list.py +29 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/paragraph.py +31 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/pdf.py +49 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/quote.py +37 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/synced_block.py +109 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/table.py +60 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/table_of_contents.py +23 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/template.py +30 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/todo.py +42 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/toggle.py +37 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/unsupported.py +20 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/video.py +22 -0
- unstructured_ingest/processes/connectors/notion/types/database.py +73 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/__init__.py +125 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/checkbox.py +39 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/created_by.py +36 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/created_time.py +35 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/date.py +42 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/email.py +37 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/files.py +38 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/formula.py +50 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/last_edited_by.py +34 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/last_edited_time.py +35 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/multiselect.py +74 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/number.py +50 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/people.py +42 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/phone_number.py +37 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/relation.py +68 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/rich_text.py +44 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/rollup.py +57 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/select.py +70 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/status.py +82 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/title.py +38 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/unique_id.py +51 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/url.py +38 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/verification.py +79 -0
- unstructured_ingest/processes/connectors/notion/types/date.py +29 -0
- unstructured_ingest/processes/connectors/notion/types/file.py +54 -0
- unstructured_ingest/processes/connectors/notion/types/page.py +52 -0
- unstructured_ingest/processes/connectors/notion/types/parent.py +66 -0
- unstructured_ingest/processes/connectors/notion/types/rich_text.py +189 -0
- unstructured_ingest/processes/connectors/notion/types/user.py +83 -0
- unstructured_ingest/processes/connectors/onedrive.py +485 -0
- unstructured_ingest/processes/connectors/outlook.py +242 -0
- unstructured_ingest/processes/connectors/pinecone.py +400 -0
- unstructured_ingest/processes/connectors/qdrant/__init__.py +16 -0
- unstructured_ingest/processes/connectors/qdrant/cloud.py +59 -0
- unstructured_ingest/processes/connectors/qdrant/local.py +58 -0
- unstructured_ingest/processes/connectors/qdrant/qdrant.py +163 -0
- unstructured_ingest/processes/connectors/qdrant/server.py +60 -0
- unstructured_ingest/processes/connectors/redisdb.py +214 -0
- unstructured_ingest/processes/connectors/salesforce.py +307 -0
- unstructured_ingest/processes/connectors/sharepoint.py +282 -0
- unstructured_ingest/processes/connectors/slack.py +249 -0
- unstructured_ingest/processes/connectors/sql/__init__.py +41 -0
- unstructured_ingest/processes/connectors/sql/databricks_delta_tables.py +228 -0
- unstructured_ingest/processes/connectors/sql/postgres.py +168 -0
- unstructured_ingest/processes/connectors/sql/singlestore.py +176 -0
- unstructured_ingest/processes/connectors/sql/snowflake.py +298 -0
- unstructured_ingest/processes/connectors/sql/sql.py +456 -0
- unstructured_ingest/processes/connectors/sql/sqlite.py +179 -0
- unstructured_ingest/processes/connectors/sql/teradata.py +254 -0
- unstructured_ingest/processes/connectors/sql/vastdb.py +263 -0
- unstructured_ingest/processes/connectors/utils.py +60 -0
- unstructured_ingest/processes/connectors/vectara.py +348 -0
- unstructured_ingest/processes/connectors/weaviate/__init__.py +22 -0
- unstructured_ingest/processes/connectors/weaviate/cloud.py +166 -0
- unstructured_ingest/processes/connectors/weaviate/embedded.py +90 -0
- unstructured_ingest/processes/connectors/weaviate/local.py +73 -0
- unstructured_ingest/processes/connectors/weaviate/weaviate.py +337 -0
- unstructured_ingest/processes/connectors/zendesk/__init__.py +0 -0
- unstructured_ingest/processes/connectors/zendesk/client.py +314 -0
- unstructured_ingest/processes/connectors/zendesk/zendesk.py +241 -0
- unstructured_ingest/processes/embedder.py +203 -0
- unstructured_ingest/processes/filter.py +60 -0
- unstructured_ingest/processes/partitioner.py +233 -0
- unstructured_ingest/processes/uncompress.py +61 -0
- unstructured_ingest/processes/utils/__init__.py +8 -0
- unstructured_ingest/processes/utils/blob_storage.py +32 -0
- unstructured_ingest/processes/utils/logging/connector.py +365 -0
- unstructured_ingest/processes/utils/logging/sanitizer.py +117 -0
- unstructured_ingest/unstructured_api.py +140 -0
- unstructured_ingest/utils/__init__.py +5 -0
- unstructured_ingest/utils/chunking.py +56 -0
- unstructured_ingest/utils/compression.py +72 -0
- unstructured_ingest/utils/constants.py +2 -0
- unstructured_ingest/utils/data_prep.py +216 -0
- unstructured_ingest/utils/dep_check.py +78 -0
- unstructured_ingest/utils/filesystem.py +27 -0
- unstructured_ingest/utils/html.py +174 -0
- unstructured_ingest/utils/ndjson.py +52 -0
- unstructured_ingest/utils/pydantic_models.py +52 -0
- unstructured_ingest/utils/string_and_date_utils.py +74 -0
- unstructured_ingest/utils/table.py +80 -0
- unstructured_ingest/utils/tls.py +15 -0
- unstructured_ingest-1.2.32.dist-info/METADATA +235 -0
- unstructured_ingest-1.2.32.dist-info/RECORD +243 -0
- unstructured_ingest-1.2.32.dist-info/WHEEL +4 -0
- unstructured_ingest-1.2.32.dist-info/entry_points.txt +2 -0
- unstructured_ingest-1.2.32.dist-info/licenses/LICENSE.md +201 -0
|
@@ -0,0 +1,323 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
from contextlib import asynccontextmanager
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import TYPE_CHECKING, AsyncIterable
|
|
7
|
+
|
|
8
|
+
from pydantic import Field, SecretStr
|
|
9
|
+
|
|
10
|
+
from unstructured_ingest.embed.interfaces import (
|
|
11
|
+
EMBEDDINGS_KEY,
|
|
12
|
+
AsyncBaseEmbeddingEncoder,
|
|
13
|
+
BaseEmbeddingEncoder,
|
|
14
|
+
EmbeddingConfig,
|
|
15
|
+
)
|
|
16
|
+
from unstructured_ingest.error import (
|
|
17
|
+
ProviderError,
|
|
18
|
+
RateLimitError,
|
|
19
|
+
UserAuthError,
|
|
20
|
+
UserError,
|
|
21
|
+
is_internal_error,
|
|
22
|
+
)
|
|
23
|
+
from unstructured_ingest.logger import logger
|
|
24
|
+
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
25
|
+
|
|
26
|
+
if TYPE_CHECKING:
    from botocore.client import BaseClient

    # Typing-only stubs describing the subset of the boto3 client APIs this module calls.
    # They live under TYPE_CHECKING and are never defined or instantiated at runtime.

    class BedrockRuntimeClient(BaseClient):
        def invoke_model(
            self,
            body: str,
            modelId: str,
            accept: str,
            contentType: str,
        ) -> dict:
            """Synchronous ``bedrock-runtime`` InvokeModel.

            An inference profile is selected by passing its ID or ARN as ``modelId``;
            the API has no separate ``inferenceProfileId`` argument.
            """
            pass

    class AsyncBedrockRuntimeClient(BaseClient):
        async def invoke_model(
            self,
            body: str,
            modelId: str,
            accept: str,
            contentType: str,
        ) -> dict:
            """Asynchronous (aioboto3) ``bedrock-runtime`` InvokeModel; see the sync stub."""
            pass

    class BedrockClient(BaseClient):
        def list_foundation_models(self, byOutputModality: str) -> dict:
            """``bedrock`` control-plane ListFoundationModels, filtered by output modality."""
            pass
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def conform_query(query: str, provider: str, *, input_type: str = "search_document") -> dict:
    """Build the JSON request body for a single Bedrock embedding call.

    Args:
        query: Text to embed.
        provider: Provider prefix of the model id (the part before the first ".").
            ``"cohere"`` gets a Cohere-style body; anything else (e.g. ``"amazon"``)
            gets a Titan-style ``inputText`` body.
        input_type: Cohere ``input_type`` value; ignored for non-cohere providers.

    Returns:
        A dict ready to be serialized with ``json.dumps``.
    """
    # Newlines can negatively affect embedding performance, so flatten them to spaces.
    # Handle "\r\n", lone "\r" and "\n" explicitly so the result does not depend on the
    # host platform's os.linesep.
    text = query.replace("\r\n", " ").replace("\r", " ").replace("\n", " ")

    if provider == "cohere":
        return {"input_type": input_type, "texts": [text]}
    # includes common provider == "amazon"
    return {"inputText": text}
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class BedrockEmbeddingConfig(EmbeddingConfig):
    """Connection, credential and model settings for AWS Bedrock embeddings.

    ``access_method`` is either ``"credentials"`` (explicit access key id + secret)
    or ``"iam"`` (boto3's default credential chain).
    """

    aws_access_key_id: SecretStr | None = Field(description="aws access key id", default=None)
    aws_secret_access_key: SecretStr | None = Field(
        description="aws secret access key", default=None
    )
    # Resolution order: BEDROCK_REGION_NAME, then AWS_DEFAULT_REGION, then "us-west-2".
    region_name: str = Field(
        description="aws region name",
        default_factory=lambda: (
            os.getenv("BEDROCK_REGION_NAME") or os.getenv("AWS_DEFAULT_REGION") or "us-west-2"
        ),
    )
    endpoint_url: str | None = Field(description="custom bedrock endpoint url", default=None)
    access_method: str = Field(
        description="authentication method", default="credentials"
    )  # "credentials" or "iam"
    embedder_model_name: str = Field(
        default="amazon.titan-embed-text-v1",
        alias="model_name",
        description="AWS Bedrock model name",
    )
    inference_profile_id: str | None = Field(
        description="AWS Bedrock inference profile ID",
        default_factory=lambda: os.getenv("BEDROCK_INFERENCE_PROFILE_ID"),
    )

    def wrap_error(self, e: Exception) -> Exception:
        """Translate a Bedrock/botocore exception into an ingest error type.

        Returns (does not raise) the exception the caller should raise. Internal
        ingest errors and unrecognized exceptions are returned unchanged; the latter
        are also logged.
        """
        if is_internal_error(e=e):
            return e
        from botocore.exceptions import ClientError

        if isinstance(e, ClientError):
            # https://docs.aws.amazon.com/awssupport/latest/APIReference/CommonErrors.html
            http_response = e.response
            meta = http_response["ResponseMetadata"]
            http_response_code = meta["HTTPStatusCode"]
            error = http_response["Error"]
            error_code = error["Code"]

            # Bedrock reports throttling as ThrottlingException with HTTP 429; classify it
            # before the generic 4xx branch so it is not mistaken for a user error.
            if error_code == "ThrottlingException" or http_response_code == 429:
                return RateLimitError(error)
            if error_code in ("NotAuthorized", "AccessDeniedException") or http_response_code == 403:
                return UserAuthError(error)
            if 400 <= http_response_code < 500:
                return UserError(error)
            if http_response_code >= 500:
                return ProviderError(error)

        logger.error(f"unhandled exception from bedrock: {e}", exc_info=True)
        return e

    def run_precheck(self) -> None:
        """Validate credential settings and that the configured model is listed by Bedrock.

        Raises:
            ValueError: ``access_method`` is invalid, or "credentials" is selected
                without both key fields set.
            UserError: ``embedder_model_name`` matches neither a model id nor a model
                ARN returned by ``list_foundation_models``.
            Exception: errors from the Bedrock API, translated by ``wrap_error``.
        """
        # Validate access method and credentials configuration
        if self.access_method == "credentials":
            if not (self.aws_access_key_id and self.aws_secret_access_key):
                raise ValueError(
                    "Credentials access method requires aws_access_key_id and aws_secret_access_key"
                )
        elif self.access_method != "iam":
            # For "iam", credentials are handled by the AWS SDK.
            raise ValueError(
                f"Invalid access_method: {self.access_method}. Must be 'credentials' or 'iam'"
            )

        client = self.get_bedrock_client()
        try:
            model_info = client.list_foundation_models(byOutputModality="EMBEDDING")
        except Exception as e:
            raise self.wrap_error(e=e)

        summaries = model_info.get("modelSummaries", [])
        model_ids = [m["modelId"] for m in summaries]
        # modelArn is already a complete ARN string; ":".join() would interleave ":"
        # between its characters and make ARN matching impossible.
        arns = [m["modelArn"] for m in summaries]

        if self.embedder_model_name not in model_ids and self.embedder_model_name not in arns:
            raise UserError(
                "model '{}' not found either : {} or {}".format(
                    self.embedder_model_name, ", ".join(model_ids), ", ".join(arns)
                )
            )

    def get_client_kwargs(self) -> dict:
        """Build keyword arguments shared by every boto3/aioboto3 client this config creates.

        Raises:
            ValueError: invalid ``access_method`` or missing explicit credentials.
        """
        kwargs = {
            "region_name": self.region_name,
        }

        if self.endpoint_url:
            kwargs["endpoint_url"] = self.endpoint_url

        if self.access_method == "credentials":
            if self.aws_access_key_id and self.aws_secret_access_key:
                kwargs["aws_access_key_id"] = self.aws_access_key_id.get_secret_value()
                kwargs["aws_secret_access_key"] = self.aws_secret_access_key.get_secret_value()
            else:
                raise ValueError(
                    "Credentials access method requires aws_access_key_id and aws_secret_access_key"
                )
        elif self.access_method == "iam":
            # For IAM, boto3 will use default credential chain (IAM roles, environment, etc.)
            pass
        else:
            raise ValueError(
                f"Invalid access_method: {self.access_method}. Must be 'credentials' or 'iam'"
            )

        return kwargs

    @requires_dependencies(
        ["boto3"],
        extras="bedrock",
    )
    def get_bedrock_client(self) -> "BedrockClient":
        """Create a boto3 ``bedrock`` (control-plane) client, used for model listing."""
        import boto3

        bedrock_client = boto3.client(service_name="bedrock", **self.get_client_kwargs())

        return bedrock_client

    @requires_dependencies(
        ["boto3", "numpy", "botocore"],
        extras="bedrock",
    )
    def get_client(self) -> "BedrockRuntimeClient":
        """Create a synchronous boto3 ``bedrock-runtime`` client, used for inference."""
        import boto3

        bedrock_client = boto3.client(service_name="bedrock-runtime", **self.get_client_kwargs())

        return bedrock_client

    @requires_dependencies(
        ["aioboto3"],
        extras="bedrock",
    )
    @asynccontextmanager
    async def get_async_client(self) -> AsyncIterable["AsyncBedrockRuntimeClient"]:
        """Yield an aioboto3 ``bedrock-runtime`` client that is closed on context exit."""
        import aioboto3

        session = aioboto3.Session()
        async with session.client("bedrock-runtime", **self.get_client_kwargs()) as aws_bedrock:
            yield aws_bedrock
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
@dataclass
class BedrockEmbeddingEncoder(BaseEmbeddingEncoder):
    """Embed text with AWS Bedrock through a synchronous boto3 ``bedrock-runtime`` client."""

    config: BedrockEmbeddingConfig

    def precheck(self):
        self.config.run_precheck()

    def wrap_error(self, e: Exception) -> Exception:
        return self.config.wrap_error(e=e)

    def _invoke_params(self, body: dict) -> dict:
        """Build the keyword arguments for ``invoke_model``.

        Bedrock's InvokeModel API has no dedicated inference-profile argument: a profile
        is invoked by passing its ID or ARN as ``modelId``. When ``inference_profile_id``
        is configured it therefore takes the place of the model name.
        """
        return {
            "body": json.dumps(body),
            "modelId": self.config.inference_profile_id or self.config.embedder_model_name,
            "accept": "application/json",
            "contentType": "application/json",
        }

    def _embed_with_client(self, client: "BedrockRuntimeClient", query: str) -> list[float]:
        """Embed one query using an already-created runtime client."""
        # The provider prefix of the model name decides the request/response shape.
        provider = self.config.embedder_model_name.split(".")[0]
        body = conform_query(query=query, provider=provider)

        # invoke bedrock API
        try:
            response = client.invoke_model(**self._invoke_params(body))
        except Exception as e:
            raise self.wrap_error(e=e)

        # format output based on provider
        response_body = json.loads(response.get("body").read())
        if provider == "cohere":
            return response_body.get("embeddings")[0]
        # includes common provider == "amazon"
        return response_body.get("embedding")

    def embed_query(self, query: str) -> list[float]:
        """Call out to Bedrock embedding endpoint."""
        return self._embed_with_client(self.config.get_client(), query)

    def embed_documents(self, elements: list[dict]) -> list[dict]:
        """Store an embedding under EMBEDDINGS_KEY on every element with non-empty text.

        All embeddings are computed before any element is modified, so a failure
        part-way through leaves the element dicts untouched.
        """
        elements = elements.copy()
        elements_with_text = [e for e in elements if e.get("text")]
        if not elements_with_text:
            return elements
        # Reuse one runtime client for the whole batch instead of creating one per element.
        client = self.config.get_client()
        embeddings = [self._embed_with_client(client, e["text"]) for e in elements_with_text]
        for element, embedding in zip(elements_with_text, embeddings):
            element[EMBEDDINGS_KEY] = embedding
        return elements
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
@dataclass
class AsyncBedrockEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
    """Embed text with AWS Bedrock through an aioboto3 ``bedrock-runtime`` client."""

    config: BedrockEmbeddingConfig

    def precheck(self):
        self.config.run_precheck()

    def wrap_error(self, e: Exception) -> Exception:
        return self.config.wrap_error(e=e)

    def _invoke_params(self, body: dict) -> dict:
        """Build the keyword arguments for ``invoke_model``.

        Bedrock's InvokeModel API has no dedicated inference-profile argument: a profile
        is invoked by passing its ID or ARN as ``modelId``. When ``inference_profile_id``
        is configured it therefore takes the place of the model name.
        """
        return {
            "body": json.dumps(body),
            "modelId": self.config.inference_profile_id or self.config.embedder_model_name,
            "accept": "application/json",
            "contentType": "application/json",
        }

    async def embed_query(self, query: str) -> list[float]:
        """Call out to Bedrock embedding endpoint.

        Raises:
            Exception: an ``invoke_model`` failure, translated by ``wrap_error``
                (e.g. RateLimitError, UserAuthError) and propagated as-is.
            ValueError: any other failure (client setup, reading or parsing the
                response), reported as "Error raised by inference endpoint".
        """
        provider = self.config.embedder_model_name.split(".")[0]
        body = conform_query(query=query, provider=provider)
        # Remember the translated API error so the catch-all below does not mask it.
        api_error = None
        try:
            async with self.config.get_async_client() as bedrock_client:
                # invoke bedrock API
                try:
                    response = await bedrock_client.invoke_model(**self._invoke_params(body))
                except Exception as e:
                    api_error = self.wrap_error(e=e)
                    raise api_error
                async with response.get("body") as client_response:
                    response_body = await client_response.json()

                # format output based on provider
                if provider == "cohere":
                    return response_body.get("embeddings")[0]
                # includes common provider == "amazon"
                return response_body.get("embedding")
        except Exception as e:
            if e is api_error:
                raise
            raise ValueError(f"Error raised by inference endpoint: {e}") from e

    async def embed_documents(self, elements: list[dict]) -> list[dict]:
        """Concurrently embed every element with non-empty text, storing it under EMBEDDINGS_KEY."""
        elements = elements.copy()
        elements_with_text = [e for e in elements if e.get("text")]
        embeddings = await asyncio.gather(
            *[self.embed_query(query=e["text"]) for e in elements_with_text]
        )
        for element, embedding in zip(elements_with_text, embeddings):
            element[EMBEDDINGS_KEY] = embedding
        return elements
|
|
@@ -0,0 +1,69 @@
|
|
|
1
|
+
from dataclasses import dataclass, field
|
|
2
|
+
from typing import TYPE_CHECKING, Optional
|
|
3
|
+
|
|
4
|
+
from pydantic import Field
|
|
5
|
+
|
|
6
|
+
from unstructured_ingest.embed.interfaces import (
|
|
7
|
+
EMBEDDINGS_KEY,
|
|
8
|
+
BaseEmbeddingEncoder,
|
|
9
|
+
EmbeddingConfig,
|
|
10
|
+
)
|
|
11
|
+
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
12
|
+
|
|
13
|
+
if TYPE_CHECKING:
|
|
14
|
+
from sentence_transformers import SentenceTransformer
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class HuggingFaceEmbeddingConfig(EmbeddingConfig):
    """Settings for embedding with a local sentence-transformers model."""

    embedder_model_name: Optional[str] = Field(
        default="all-MiniLM-L6-v2", alias="model_name", description="HuggingFace model name"
    )
    embedder_model_kwargs: Optional[dict] = Field(
        default_factory=lambda: {"device": "cpu"},
        alias="model_kwargs",
        description="additional model parameters",
    )
    encode_kwargs: Optional[dict] = Field(
        default_factory=lambda: {"normalize_embeddings": False},
        description="additional embedding parameters",
    )

    @requires_dependencies(
        ["sentence_transformers"],
        extras="huggingface",
    )
    def get_client(self) -> "SentenceTransformer":
        """Instantiate the SentenceTransformer named by ``embedder_model_name``."""
        from sentence_transformers import SentenceTransformer

        # model_kwargs is Optional: treat None as "no extra kwargs" instead of failing on **None.
        return SentenceTransformer(
            model_name_or_path=self.embedder_model_name,
            **(self.embedder_model_kwargs or {}),
        )

    def get_encoder_kwargs(self) -> dict:
        """Return keyword arguments for ``SentenceTransformer.encode``.

        A new dict is returned so the configured ``encode_kwargs`` is never mutated.
        ``batch_size`` is added only when configured; otherwise encode's own default
        applies instead of receiving ``None``.
        """
        encoder_kwargs = dict(self.encode_kwargs or {})
        if self.batch_size is not None:
            encoder_kwargs["batch_size"] = self.batch_size
        return encoder_kwargs
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
class HuggingFaceEmbeddingEncoder(BaseEmbeddingEncoder):
    """Embedding encoder that runs a sentence-transformers model in-process."""

    config: HuggingFaceEmbeddingConfig = field(default_factory=HuggingFaceEmbeddingConfig)

    def _embed_query(self, query: str) -> list[float]:
        """Embed one string by routing it through the batch path."""
        vectors = self._embed_documents(texts=[query])
        return vectors[0]

    def _embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Encode all ``texts`` in a single call and convert the result to Python lists."""
        model = self.config.get_client()
        encode_options = self.config.get_encoder_kwargs()
        return model.encode(texts, **encode_options).tolist()

    def embed_documents(self, elements: list[dict]) -> list[dict]:
        """Attach an embedding under EMBEDDINGS_KEY to each element whose text is non-empty."""
        result = elements.copy()
        targets = [item for item in result if item.get("text")]
        if targets:
            vectors = self._embed_documents([item["text"] for item in targets])
            for item, vector in zip(targets, vectors):
                item[EMBEDDINGS_KEY] = vector
        return result
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Any, Optional
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, Field
|
|
6
|
+
|
|
7
|
+
from unstructured_ingest.utils.data_prep import batch_generator
|
|
8
|
+
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
9
|
+
|
|
10
|
+
# Dictionary key under which each element stores its computed embedding vector.
EMBEDDINGS_KEY: str = "embeddings"
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class EmbeddingConfig(BaseModel):
    """Settings shared by every embedding-provider configuration."""

    # Maximum number of texts sent in one embedding request.
    batch_size: Optional[int] = Field(32, description="Optional batch size for embedding requests.")
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class BaseEncoder(ABC):
    """Common scaffolding shared by all embedding encoders.

    Holds the provider configuration and exposes lifecycle hooks that concrete
    encoders may override.
    """

    config: EmbeddingConfig

    def precheck(self):
        """Hook for validating the encoder before use; a no-op by default."""
        return None

    def initialize(self):
        """Initialize the encoder.

        Implementations should also verify that the instance is configured
        correctly, e.g. by embedding a single element. The default does nothing.
        """

    def wrap_error(self, e: Exception) -> Exception:
        """Translate an error raised by the embedding service.

        Subclasses may return a more descriptive exception. The default hands
        ``e`` back unchanged so callers can ``raise`` the result directly.
        """
        return e
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class BaseEmbeddingEncoder(BaseEncoder, ABC):
    """Synchronous embedding encoder.

    Subclasses implement ``get_client`` and ``embed_batch``. This class adds
    batching, error wrapping and a few derived properties on top of them.
    """

    def initialize(self):
        """Initialize the encoder.

        Implementations should also verify that the instance is configured
        correctly, e.g. by embedding a single element. The default does nothing.
        """

    @property
    def dimension(self):
        """Length of the vectors this encoder produces (costs one embedding call)."""
        exemplary_embedding = self.get_exemplary_embedding()
        return len(exemplary_embedding)

    def get_exemplary_embedding(self) -> list[float]:
        """Embed a fixed one-character query, used to probe the model's output."""
        return self.embed_query(query="Q")

    @property
    @requires_dependencies(["numpy"])
    def is_unit_vector(self) -> bool:
        """Whether the produced vectors have unit L2 norm (relative tolerance 1e-3)."""
        import numpy as np

        exemplary_embedding = self.get_exemplary_embedding()
        # np.isclose yields numpy.bool_; convert so the result matches the annotation.
        return bool(np.isclose(np.linalg.norm(exemplary_embedding), 1.0, rtol=1e-03))

    def get_client(self):
        """Return the provider client handed to ``embed_batch``; must be overridden."""
        raise NotImplementedError

    def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
        """Embed one batch of texts using ``client``; must be overridden."""
        raise NotImplementedError

    def embed_documents(self, elements: list[dict]) -> list[dict]:
        """Attach an embedding under ``EMBEDDINGS_KEY`` to each element that has text.

        Texts are sent in batches of ``config.batch_size``, or as a single batch
        when it is unset. Elements whose ``"text"`` is missing or empty are left
        untouched. The returned list is a shallow copy, so the element dicts are
        updated in place.

        Raises:
            The exception produced by ``wrap_error`` for any failure while embedding.
        """
        elements = elements.copy()
        elements_with_text = [e for e in elements if e.get("text")]
        if not elements_with_text:
            # Nothing to embed. Skip creating a client, and avoid a zero batch size
            # (``batch_size or len(texts)``) when batch_size is None.
            return elements
        client = self.get_client()
        texts = [e["text"] for e in elements_with_text]
        all_embeddings: list[list[float]] = []
        try:
            for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
                all_embeddings.extend(self.embed_batch(client=client, batch=batch))
        except Exception as e:
            raise self.wrap_error(e=e)
        # strict=True fails loudly if the provider returned a different number of vectors.
        for element, embedding in zip(elements_with_text, all_embeddings, strict=True):
            element[EMBEDDINGS_KEY] = embedding
        return elements

    def _embed_query(self, query: str) -> list[float]:
        """Embed a single query as a one-item batch, without error wrapping."""
        client = self.get_client()
        return self.embed_batch(client=client, batch=[query])[0]

    def embed_query(self, query: str) -> list[float]:
        """Embed a single query, translating failures through ``wrap_error``."""
        try:
            return self._embed_query(query=query)
        except Exception as e:
            raise self.wrap_error(e=e)
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
@dataclass
class AsyncBaseEmbeddingEncoder(BaseEncoder, ABC):
    """Asynchronous embedding encoder.

    Subclasses implement ``get_client`` (synchronous) and the coroutine
    ``embed_batch``. This class adds batching, error wrapping and a few
    derived values on top of them.
    """

    async def initialize(self):
        """Initialize the encoder.

        Implementations should also verify that the instance is configured
        correctly, e.g. by embedding a single element. The default does nothing.
        """

    @property
    async def dimension(self):
        """Vector length; the property yields a coroutine, so use ``await encoder.dimension``.

        Costs one embedding call.
        """
        exemplary_embedding = await self.get_exemplary_embedding()
        return len(exemplary_embedding)

    async def get_exemplary_embedding(self) -> list[float]:
        """Embed a fixed one-character query, used to probe the model's output."""
        return await self.embed_query(query="Q")

    @property
    @requires_dependencies(["numpy"])
    async def is_unit_vector(self) -> bool:
        """Whether the produced vectors have unit L2 norm (relative tolerance 1e-3)."""
        import numpy as np

        exemplary_embedding = await self.get_exemplary_embedding()
        # np.isclose yields numpy.bool_; convert so the result matches the annotation.
        return bool(np.isclose(np.linalg.norm(exemplary_embedding), 1.0, rtol=1e-03))

    def get_client(self):
        """Return the provider client handed to ``embed_batch``; must be overridden."""
        raise NotImplementedError

    async def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
        """Embed one batch of texts using ``client``; must be overridden."""
        raise NotImplementedError

    async def embed_documents(self, elements: list[dict]) -> list[dict]:
        """Attach an embedding under ``EMBEDDINGS_KEY`` to each element that has text.

        Batches are awaited one after another, each of ``config.batch_size``
        texts, or a single batch when it is unset. Elements whose ``"text"`` is
        missing or empty are left untouched. The returned list is a shallow copy,
        so the element dicts are updated in place.

        Raises:
            The exception produced by ``wrap_error`` for any failure while embedding.
        """
        elements = elements.copy()
        elements_with_text = [e for e in elements if e.get("text")]
        if not elements_with_text:
            # Nothing to embed. Skip creating a client, and avoid a zero batch size
            # (``batch_size or len(texts)``) when batch_size is None.
            return elements
        client = self.get_client()
        texts = [e["text"] for e in elements_with_text]
        all_embeddings: list[list[float]] = []
        try:
            for batch in batch_generator(texts, batch_size=self.config.batch_size or len(texts)):
                all_embeddings.extend(await self.embed_batch(client=client, batch=batch))
        except Exception as e:
            raise self.wrap_error(e=e)
        # strict=True fails loudly if the provider returned a different number of vectors.
        for element, embedding in zip(elements_with_text, all_embeddings, strict=True):
            element[EMBEDDINGS_KEY] = embedding
        return elements

    async def _embed_query(self, query: str) -> list[float]:
        """Embed a single query as a one-item batch, without error wrapping."""
        client = self.get_client()
        embeddings = await self.embed_batch(client=client, batch=[query])
        return embeddings[0]

    async def embed_query(self, query: str) -> list[float]:
        """Embed a single query, translating failures through ``wrap_error``."""
        try:
            return await self._embed_query(query=query)
        except Exception as e:
            raise self.wrap_error(e=e)
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
from pydantic import Field, SecretStr
|
|
6
|
+
|
|
7
|
+
from unstructured_ingest.embed.interfaces import (
|
|
8
|
+
AsyncBaseEmbeddingEncoder,
|
|
9
|
+
BaseEmbeddingEncoder,
|
|
10
|
+
EmbeddingConfig,
|
|
11
|
+
)
|
|
12
|
+
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
13
|
+
|
|
14
|
+
# Sent as the "User-Agent" header on every embed request.
USER_AGENT: str = "@mixedbread-ai/unstructured"
# Per-request timeout passed to client.embed (presumably seconds — confirm against the SDK).
TIMEOUT: int = 60
# Retry count handed to both Mixedbread client constructors.
MAX_RETRIES: int = 3
# Vector encoding requested from the embed endpoint.
ENCODING_FORMAT: str = "float"
# NOTE(review): not referenced in this module; kept in case other code imports it.
TRUNCATION_STRATEGY: str = "end"
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from mixedbread import AsyncMixedbread, Mixedbread
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class MixedbreadAIEmbeddingConfig(EmbeddingConfig):
    """
    Settings for the Mixedbread AI embedding encoders.

    Attributes:
        api_key (SecretStr): Mixedbread AI API key; defaults to the
            ``MXBAI_API_KEY`` environment variable.
        embedder_model_name (str): Model used to compute embeddings
            (populated from the ``model_name`` alias).
    """

    api_key: SecretStr = Field(
        default_factory=lambda: SecretStr(os.getenv("MXBAI_API_KEY")),
        description="API key for Mixedbread AI",
    )

    embedder_model_name: str = Field(
        default="mixedbread-ai/mxbai-embed-large-v1",
        alias="model_name",
        description="Mixedbread AI model name",
    )

    @requires_dependencies(
        ["mixedbread"],
        extras="embed-mixedbreadai",
    )
    def get_client(self) -> "Mixedbread":
        """
        Build a synchronous Mixedbread AI client.

        Returns:
            Mixedbread: Client authenticated with ``api_key`` and retrying up
            to ``MAX_RETRIES`` times.
        """
        from mixedbread import Mixedbread

        client_options = {
            "api_key": self.api_key.get_secret_value(),
            "max_retries": MAX_RETRIES,
        }
        return Mixedbread(**client_options)

    @requires_dependencies(
        ["mixedbread"],
        extras="embed-mixedbreadai",
    )
    def get_async_client(self) -> "AsyncMixedbread":
        """
        Build an asynchronous Mixedbread AI client with the same settings.

        Returns:
            AsyncMixedbread: Client authenticated with ``api_key`` and retrying
            up to ``MAX_RETRIES`` times.
        """
        from mixedbread import AsyncMixedbread

        client_options = {
            "api_key": self.api_key.get_secret_value(),
            "max_retries": MAX_RETRIES,
        }
        return AsyncMixedbread(**client_options)
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@dataclass
class MixedbreadAIEmbeddingEncoder(BaseEmbeddingEncoder):
    """
    Synchronous embedding encoder that calls the Mixedbread AI embed endpoint.

    Attributes:
        config (MixedbreadAIEmbeddingConfig): Credentials and model selection.
    """

    config: MixedbreadAIEmbeddingConfig

    def get_exemplary_embedding(self) -> list[float]:
        """Embed a sample query; used to infer vector size and unit-norm status."""
        return self.embed_query(query="Q")

    @requires_dependencies(
        ["mixedbread"],
        extras="embed-mixedbreadai",
    )
    def get_client(self) -> "Mixedbread":
        """Build a synchronous client from this encoder's config."""
        return self.config.get_client()

    def embed_batch(self, client: "Mixedbread", batch: list[str]) -> list[list[float]]:
        """Request normalized float embeddings for ``batch``, in response order."""
        request = dict(
            model=self.config.embedder_model_name,
            input=batch,
            normalized=True,
            encoding_format=ENCODING_FORMAT,
            extra_headers={"User-Agent": USER_AGENT},
            timeout=TIMEOUT,
        )
        result = client.embed(**request)
        return [item.embedding for item in result.data]
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@dataclass
class AsyncMixedbreadAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
    """
    Asynchronous embedding encoder that calls the Mixedbread AI embed endpoint.

    Attributes:
        config (MixedbreadAIEmbeddingConfig): Credentials and model selection.
    """

    config: MixedbreadAIEmbeddingConfig

    async def get_exemplary_embedding(self) -> list[float]:
        """Embed a sample query; used to infer vector size and unit-norm status."""
        return await self.embed_query(query="Q")

    @requires_dependencies(
        ["mixedbread"],
        extras="embed-mixedbreadai",
    )
    def get_client(self) -> "AsyncMixedbread":
        """Build an asynchronous client from this encoder's config."""
        return self.config.get_async_client()

    async def embed_batch(self, client: "AsyncMixedbread", batch: list[str]) -> list[list[float]]:
        """Request normalized float embeddings for ``batch``, in response order."""
        request = dict(
            model=self.config.embedder_model_name,
            input=batch,
            normalized=True,
            encoding_format=ENCODING_FORMAT,
            extra_headers={"User-Agent": USER_AGENT},
            timeout=TIMEOUT,
        )
        result = await client.embed(**request)
        return [item.embedding for item in result.data]
|