unstructured-ingest 1.2.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of unstructured-ingest might be problematic. Click here for more details.
- unstructured_ingest/__init__.py +1 -0
- unstructured_ingest/__version__.py +1 -0
- unstructured_ingest/cli/README.md +28 -0
- unstructured_ingest/cli/__init__.py +0 -0
- unstructured_ingest/cli/base/__init__.py +4 -0
- unstructured_ingest/cli/base/cmd.py +269 -0
- unstructured_ingest/cli/base/dest.py +84 -0
- unstructured_ingest/cli/base/importer.py +34 -0
- unstructured_ingest/cli/base/src.py +75 -0
- unstructured_ingest/cli/cli.py +24 -0
- unstructured_ingest/cli/cmds.py +14 -0
- unstructured_ingest/cli/utils/__init__.py +0 -0
- unstructured_ingest/cli/utils/click.py +237 -0
- unstructured_ingest/cli/utils/model_conversion.py +222 -0
- unstructured_ingest/data_types/__init__.py +0 -0
- unstructured_ingest/data_types/entities.py +17 -0
- unstructured_ingest/data_types/file_data.py +116 -0
- unstructured_ingest/embed/__init__.py +0 -0
- unstructured_ingest/embed/azure_openai.py +63 -0
- unstructured_ingest/embed/bedrock.py +323 -0
- unstructured_ingest/embed/huggingface.py +69 -0
- unstructured_ingest/embed/interfaces.py +146 -0
- unstructured_ingest/embed/mixedbreadai.py +134 -0
- unstructured_ingest/embed/octoai.py +133 -0
- unstructured_ingest/embed/openai.py +142 -0
- unstructured_ingest/embed/togetherai.py +116 -0
- unstructured_ingest/embed/vertexai.py +109 -0
- unstructured_ingest/embed/voyageai.py +130 -0
- unstructured_ingest/error.py +156 -0
- unstructured_ingest/errors_v2.py +156 -0
- unstructured_ingest/interfaces/__init__.py +27 -0
- unstructured_ingest/interfaces/connector.py +56 -0
- unstructured_ingest/interfaces/downloader.py +90 -0
- unstructured_ingest/interfaces/indexer.py +29 -0
- unstructured_ingest/interfaces/process.py +22 -0
- unstructured_ingest/interfaces/processor.py +88 -0
- unstructured_ingest/interfaces/upload_stager.py +89 -0
- unstructured_ingest/interfaces/uploader.py +67 -0
- unstructured_ingest/logger.py +39 -0
- unstructured_ingest/main.py +11 -0
- unstructured_ingest/otel.py +128 -0
- unstructured_ingest/pipeline/__init__.py +0 -0
- unstructured_ingest/pipeline/interfaces.py +211 -0
- unstructured_ingest/pipeline/otel.py +32 -0
- unstructured_ingest/pipeline/pipeline.py +408 -0
- unstructured_ingest/pipeline/steps/__init__.py +0 -0
- unstructured_ingest/pipeline/steps/chunk.py +78 -0
- unstructured_ingest/pipeline/steps/download.py +206 -0
- unstructured_ingest/pipeline/steps/embed.py +77 -0
- unstructured_ingest/pipeline/steps/filter.py +35 -0
- unstructured_ingest/pipeline/steps/index.py +86 -0
- unstructured_ingest/pipeline/steps/partition.py +77 -0
- unstructured_ingest/pipeline/steps/stage.py +65 -0
- unstructured_ingest/pipeline/steps/uncompress.py +50 -0
- unstructured_ingest/pipeline/steps/upload.py +58 -0
- unstructured_ingest/processes/__init__.py +18 -0
- unstructured_ingest/processes/chunker.py +131 -0
- unstructured_ingest/processes/connector_registry.py +69 -0
- unstructured_ingest/processes/connectors/__init__.py +129 -0
- unstructured_ingest/processes/connectors/airtable.py +238 -0
- unstructured_ingest/processes/connectors/assets/__init__.py +0 -0
- unstructured_ingest/processes/connectors/assets/databricks_delta_table_schema.sql +9 -0
- unstructured_ingest/processes/connectors/assets/weaviate_collection_config.json +23 -0
- unstructured_ingest/processes/connectors/astradb.py +592 -0
- unstructured_ingest/processes/connectors/azure_ai_search.py +275 -0
- unstructured_ingest/processes/connectors/chroma.py +193 -0
- unstructured_ingest/processes/connectors/confluence.py +527 -0
- unstructured_ingest/processes/connectors/couchbase.py +336 -0
- unstructured_ingest/processes/connectors/databricks/__init__.py +58 -0
- unstructured_ingest/processes/connectors/databricks/volumes.py +233 -0
- unstructured_ingest/processes/connectors/databricks/volumes_aws.py +93 -0
- unstructured_ingest/processes/connectors/databricks/volumes_azure.py +108 -0
- unstructured_ingest/processes/connectors/databricks/volumes_gcp.py +91 -0
- unstructured_ingest/processes/connectors/databricks/volumes_native.py +92 -0
- unstructured_ingest/processes/connectors/databricks/volumes_table.py +187 -0
- unstructured_ingest/processes/connectors/delta_table.py +310 -0
- unstructured_ingest/processes/connectors/discord.py +161 -0
- unstructured_ingest/processes/connectors/duckdb/__init__.py +15 -0
- unstructured_ingest/processes/connectors/duckdb/base.py +103 -0
- unstructured_ingest/processes/connectors/duckdb/duckdb.py +130 -0
- unstructured_ingest/processes/connectors/duckdb/motherduck.py +130 -0
- unstructured_ingest/processes/connectors/elasticsearch/__init__.py +19 -0
- unstructured_ingest/processes/connectors/elasticsearch/elasticsearch.py +478 -0
- unstructured_ingest/processes/connectors/elasticsearch/opensearch.py +523 -0
- unstructured_ingest/processes/connectors/fsspec/__init__.py +37 -0
- unstructured_ingest/processes/connectors/fsspec/azure.py +203 -0
- unstructured_ingest/processes/connectors/fsspec/box.py +176 -0
- unstructured_ingest/processes/connectors/fsspec/dropbox.py +238 -0
- unstructured_ingest/processes/connectors/fsspec/fsspec.py +475 -0
- unstructured_ingest/processes/connectors/fsspec/gcs.py +203 -0
- unstructured_ingest/processes/connectors/fsspec/s3.py +253 -0
- unstructured_ingest/processes/connectors/fsspec/sftp.py +177 -0
- unstructured_ingest/processes/connectors/fsspec/utils.py +17 -0
- unstructured_ingest/processes/connectors/github.py +226 -0
- unstructured_ingest/processes/connectors/gitlab.py +270 -0
- unstructured_ingest/processes/connectors/google_drive.py +848 -0
- unstructured_ingest/processes/connectors/ibm_watsonx/__init__.py +10 -0
- unstructured_ingest/processes/connectors/ibm_watsonx/ibm_watsonx_s3.py +367 -0
- unstructured_ingest/processes/connectors/jira.py +522 -0
- unstructured_ingest/processes/connectors/kafka/__init__.py +17 -0
- unstructured_ingest/processes/connectors/kafka/cloud.py +121 -0
- unstructured_ingest/processes/connectors/kafka/kafka.py +275 -0
- unstructured_ingest/processes/connectors/kafka/local.py +103 -0
- unstructured_ingest/processes/connectors/kdbai.py +156 -0
- unstructured_ingest/processes/connectors/lancedb/__init__.py +30 -0
- unstructured_ingest/processes/connectors/lancedb/aws.py +43 -0
- unstructured_ingest/processes/connectors/lancedb/azure.py +43 -0
- unstructured_ingest/processes/connectors/lancedb/cloud.py +42 -0
- unstructured_ingest/processes/connectors/lancedb/gcp.py +44 -0
- unstructured_ingest/processes/connectors/lancedb/lancedb.py +181 -0
- unstructured_ingest/processes/connectors/lancedb/local.py +44 -0
- unstructured_ingest/processes/connectors/local.py +227 -0
- unstructured_ingest/processes/connectors/milvus.py +311 -0
- unstructured_ingest/processes/connectors/mongodb.py +389 -0
- unstructured_ingest/processes/connectors/neo4j.py +534 -0
- unstructured_ingest/processes/connectors/notion/__init__.py +0 -0
- unstructured_ingest/processes/connectors/notion/client.py +349 -0
- unstructured_ingest/processes/connectors/notion/connector.py +350 -0
- unstructured_ingest/processes/connectors/notion/helpers.py +448 -0
- unstructured_ingest/processes/connectors/notion/ingest_backoff/__init__.py +3 -0
- unstructured_ingest/processes/connectors/notion/ingest_backoff/_common.py +102 -0
- unstructured_ingest/processes/connectors/notion/ingest_backoff/_wrapper.py +126 -0
- unstructured_ingest/processes/connectors/notion/ingest_backoff/types.py +24 -0
- unstructured_ingest/processes/connectors/notion/interfaces.py +32 -0
- unstructured_ingest/processes/connectors/notion/types/__init__.py +0 -0
- unstructured_ingest/processes/connectors/notion/types/block.py +96 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/__init__.py +63 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/bookmark.py +40 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/breadcrumb.py +21 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/bulleted_list_item.py +31 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/callout.py +131 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/child_database.py +23 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/child_page.py +23 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/code.py +43 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/column_list.py +35 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/divider.py +22 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/embed.py +36 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/equation.py +23 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/file.py +49 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/heading.py +37 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/image.py +21 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/link_preview.py +24 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/link_to_page.py +29 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/numbered_list.py +29 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/paragraph.py +31 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/pdf.py +49 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/quote.py +37 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/synced_block.py +109 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/table.py +60 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/table_of_contents.py +23 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/template.py +30 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/todo.py +42 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/toggle.py +37 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/unsupported.py +20 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/video.py +22 -0
- unstructured_ingest/processes/connectors/notion/types/database.py +73 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/__init__.py +125 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/checkbox.py +39 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/created_by.py +36 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/created_time.py +35 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/date.py +42 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/email.py +37 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/files.py +38 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/formula.py +50 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/last_edited_by.py +34 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/last_edited_time.py +35 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/multiselect.py +74 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/number.py +50 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/people.py +42 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/phone_number.py +37 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/relation.py +68 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/rich_text.py +44 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/rollup.py +57 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/select.py +70 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/status.py +82 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/title.py +38 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/unique_id.py +51 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/url.py +38 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/verification.py +79 -0
- unstructured_ingest/processes/connectors/notion/types/date.py +29 -0
- unstructured_ingest/processes/connectors/notion/types/file.py +54 -0
- unstructured_ingest/processes/connectors/notion/types/page.py +52 -0
- unstructured_ingest/processes/connectors/notion/types/parent.py +66 -0
- unstructured_ingest/processes/connectors/notion/types/rich_text.py +189 -0
- unstructured_ingest/processes/connectors/notion/types/user.py +83 -0
- unstructured_ingest/processes/connectors/onedrive.py +485 -0
- unstructured_ingest/processes/connectors/outlook.py +242 -0
- unstructured_ingest/processes/connectors/pinecone.py +400 -0
- unstructured_ingest/processes/connectors/qdrant/__init__.py +16 -0
- unstructured_ingest/processes/connectors/qdrant/cloud.py +59 -0
- unstructured_ingest/processes/connectors/qdrant/local.py +58 -0
- unstructured_ingest/processes/connectors/qdrant/qdrant.py +163 -0
- unstructured_ingest/processes/connectors/qdrant/server.py +60 -0
- unstructured_ingest/processes/connectors/redisdb.py +214 -0
- unstructured_ingest/processes/connectors/salesforce.py +307 -0
- unstructured_ingest/processes/connectors/sharepoint.py +282 -0
- unstructured_ingest/processes/connectors/slack.py +249 -0
- unstructured_ingest/processes/connectors/sql/__init__.py +41 -0
- unstructured_ingest/processes/connectors/sql/databricks_delta_tables.py +228 -0
- unstructured_ingest/processes/connectors/sql/postgres.py +168 -0
- unstructured_ingest/processes/connectors/sql/singlestore.py +176 -0
- unstructured_ingest/processes/connectors/sql/snowflake.py +298 -0
- unstructured_ingest/processes/connectors/sql/sql.py +456 -0
- unstructured_ingest/processes/connectors/sql/sqlite.py +179 -0
- unstructured_ingest/processes/connectors/sql/teradata.py +254 -0
- unstructured_ingest/processes/connectors/sql/vastdb.py +263 -0
- unstructured_ingest/processes/connectors/utils.py +60 -0
- unstructured_ingest/processes/connectors/vectara.py +348 -0
- unstructured_ingest/processes/connectors/weaviate/__init__.py +22 -0
- unstructured_ingest/processes/connectors/weaviate/cloud.py +166 -0
- unstructured_ingest/processes/connectors/weaviate/embedded.py +90 -0
- unstructured_ingest/processes/connectors/weaviate/local.py +73 -0
- unstructured_ingest/processes/connectors/weaviate/weaviate.py +337 -0
- unstructured_ingest/processes/connectors/zendesk/__init__.py +0 -0
- unstructured_ingest/processes/connectors/zendesk/client.py +314 -0
- unstructured_ingest/processes/connectors/zendesk/zendesk.py +241 -0
- unstructured_ingest/processes/embedder.py +203 -0
- unstructured_ingest/processes/filter.py +60 -0
- unstructured_ingest/processes/partitioner.py +233 -0
- unstructured_ingest/processes/uncompress.py +61 -0
- unstructured_ingest/processes/utils/__init__.py +8 -0
- unstructured_ingest/processes/utils/blob_storage.py +32 -0
- unstructured_ingest/processes/utils/logging/connector.py +365 -0
- unstructured_ingest/processes/utils/logging/sanitizer.py +117 -0
- unstructured_ingest/unstructured_api.py +140 -0
- unstructured_ingest/utils/__init__.py +5 -0
- unstructured_ingest/utils/chunking.py +56 -0
- unstructured_ingest/utils/compression.py +72 -0
- unstructured_ingest/utils/constants.py +2 -0
- unstructured_ingest/utils/data_prep.py +216 -0
- unstructured_ingest/utils/dep_check.py +78 -0
- unstructured_ingest/utils/filesystem.py +27 -0
- unstructured_ingest/utils/html.py +174 -0
- unstructured_ingest/utils/ndjson.py +52 -0
- unstructured_ingest/utils/pydantic_models.py +52 -0
- unstructured_ingest/utils/string_and_date_utils.py +74 -0
- unstructured_ingest/utils/table.py +80 -0
- unstructured_ingest/utils/tls.py +15 -0
- unstructured_ingest-1.2.32.dist-info/METADATA +235 -0
- unstructured_ingest-1.2.32.dist-info/RECORD +243 -0
- unstructured_ingest-1.2.32.dist-info/WHEEL +4 -0
- unstructured_ingest-1.2.32.dist-info/entry_points.txt +2 -0
- unstructured_ingest-1.2.32.dist-info/licenses/LICENSE.md +201 -0
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import TYPE_CHECKING
|
|
3
|
+
|
|
4
|
+
from pydantic import Field, SecretStr
|
|
5
|
+
|
|
6
|
+
from unstructured_ingest.embed.interfaces import (
|
|
7
|
+
AsyncBaseEmbeddingEncoder,
|
|
8
|
+
BaseEmbeddingEncoder,
|
|
9
|
+
EmbeddingConfig,
|
|
10
|
+
)
|
|
11
|
+
from unstructured_ingest.error import (
|
|
12
|
+
ProviderError,
|
|
13
|
+
QuotaError,
|
|
14
|
+
RateLimitError,
|
|
15
|
+
UserAuthError,
|
|
16
|
+
UserError,
|
|
17
|
+
is_internal_error,
|
|
18
|
+
)
|
|
19
|
+
from unstructured_ingest.logger import logger
|
|
20
|
+
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
21
|
+
|
|
22
|
+
if TYPE_CHECKING:
|
|
23
|
+
from openai import AsyncOpenAI, OpenAI
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Configuration for the OctoAI embedding provider, which is reached through the
# OpenAI SDK pointed at ``base_url``.
# (Kept as a comment rather than a class docstring so the pydantic model's
# generated description is left untouched.)
class OctoAiEmbeddingConfig(EmbeddingConfig):
    # Secret API key handed to the OpenAI SDK clients built below.
    api_key: SecretStr = Field(description="API key for OctoAI")
    # Model requested for embeddings; the field is aliased as ``model_name``.
    embedder_model_name: str = Field(
        default="thenlper/gte-large", alias="model_name", description="octoai model name"
    )
    # Endpoint passed as ``base_url`` to the SDK clients.
    base_url: str = Field(
        default="https://text.octoai.run/v1", description="optional override for the base url"
    )

    def wrap_error(self, e: Exception) -> Exception:
        """Translate an exception raised around an OpenAI SDK call into an ingest error.

        Args:
            e: The exception that was caught.

        Returns:
            ``e`` itself when it is already an internal error or cannot be
            classified; otherwise ``UserAuthError`` (401), ``QuotaError`` or
            ``RateLimitError`` (429), ``UserError`` (any other 4xx) or
            ``ProviderError`` (5xx), each built from the SDK error message.
        """
        if is_internal_error(e=e):
            return e
        # https://platform.openai.com/docs/guides/error-codes/api-errors
        from openai import APIStatusError

        if not isinstance(e, APIStatusError):
            logger.error(f"unhandled exception from openai: {e}", exc_info=True)
            # Return rather than raise: this method's contract is to hand an
            # exception back, and every caller in this class raises the result.
            return e

        status = e.status_code
        if status == 401:
            return UserAuthError(e.message)
        if status == 429:
            # 429 indicates either rate limit exceeded or quota exceeded
            if e.code == "insufficient_quota":
                return QuotaError(e.message)
            return RateLimitError(e.message)
        if 400 <= status < 500:
            # any remaining 4xx is treated as a user error
            return UserError(e.message)
        if status >= 500:
            return ProviderError(e.message)
        logger.error(f"unhandled exception from openai: {e}", exc_info=True)
        return e

    def run_precheck(self) -> None:
        """Check that the configured model is among the models the endpoint lists.

        Raises:
            Exception: Whatever ``wrap_error`` produces for a failure while
                listing models, or a ``UserError`` routed through ``wrap_error``
                when ``embedder_model_name`` is not in the listing.
        """
        client = self.get_client()
        try:
            models = [m.id for m in client.models.list()]
            if self.embedder_model_name not in models:
                raise UserError(
                    "model '{}' not found: {}".format(self.embedder_model_name, ", ".join(models))
                )
        except Exception as e:
            raise self.wrap_error(e=e)

    @requires_dependencies(
        ["openai", "tiktoken"],
        extras="octoai",
    )
    def get_client(self) -> "OpenAI":
        """Create a synchronous OpenAI SDK client using ``api_key`` and ``base_url``."""
        from openai import OpenAI

        return OpenAI(api_key=self.api_key.get_secret_value(), base_url=self.base_url)

    @requires_dependencies(
        ["openai", "tiktoken"],
        extras="octoai",
    )
    def get_async_client(self) -> "AsyncOpenAI":
        """Create an asynchronous OpenAI SDK client using ``api_key`` and ``base_url``."""
        from openai import AsyncOpenAI

        return AsyncOpenAI(api_key=self.api_key.get_secret_value(), base_url=self.base_url)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
@dataclass
class OctoAIEmbeddingEncoder(BaseEmbeddingEncoder):
    """Synchronous encoder that sends embedding requests through the config's client."""

    config: OctoAiEmbeddingConfig

    def get_client(self) -> "OpenAI":
        """Return the synchronous client built by the config."""
        sync_client = self.config.get_client()
        return sync_client

    def wrap_error(self, e: Exception) -> Exception:
        """Delegate error translation to the config."""
        translated = self.config.wrap_error(e=e)
        return translated

    def precheck(self):
        """Run the config's precheck."""
        self.config.run_precheck()

    def _embed_query(self, query: str):
        """Embed a single query and return the first embedding of the response."""
        model_name = self.config.embedder_model_name
        result = self.get_client().embeddings.create(input=query, model=model_name)
        first_item = result.data[0]
        return first_item.embedding

    def embed_batch(self, client: "OpenAI", batch: list[str]) -> list[list[float]]:
        """Embed ``batch`` with ``client``; one embedding per item in the response data."""
        model_name = self.config.embedder_model_name
        result = client.embeddings.create(input=batch, model=model_name)
        vectors = []
        for item in result.data:
            vectors.append(item.embedding)
        return vectors
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
@dataclass
class AsyncOctoAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
    """Asynchronous encoder that sends embedding requests through the config's async client."""

    config: OctoAiEmbeddingConfig

    def get_client(self) -> "AsyncOpenAI":
        """Return the asynchronous client built by the config."""
        async_client = self.config.get_async_client()
        return async_client

    def wrap_error(self, e: Exception) -> Exception:
        """Delegate error translation to the config."""
        translated = self.config.wrap_error(e=e)
        return translated

    def precheck(self):
        """Run the config's precheck."""
        self.config.run_precheck()

    async def embed_batch(self, client: "AsyncOpenAI", batch: list[str]) -> list[list[float]]:
        """Await an embeddings request for ``batch`` and collect each item's embedding."""
        model_name = self.config.embedder_model_name
        result = await client.embeddings.create(input=batch, model=model_name)
        vectors = []
        for item in result.data:
            vectors.append(item.embedding)
        return vectors
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import TYPE_CHECKING, Optional
|
|
3
|
+
|
|
4
|
+
from pydantic import Field, SecretStr
|
|
5
|
+
|
|
6
|
+
from unstructured_ingest.embed.interfaces import (
|
|
7
|
+
AsyncBaseEmbeddingEncoder,
|
|
8
|
+
BaseEmbeddingEncoder,
|
|
9
|
+
EmbeddingConfig,
|
|
10
|
+
)
|
|
11
|
+
from unstructured_ingest.error import (
|
|
12
|
+
ProviderError,
|
|
13
|
+
QuotaError,
|
|
14
|
+
RateLimitError,
|
|
15
|
+
UserAuthError,
|
|
16
|
+
UserError,
|
|
17
|
+
is_internal_error,
|
|
18
|
+
)
|
|
19
|
+
from unstructured_ingest.logger import logger
|
|
20
|
+
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
21
|
+
from unstructured_ingest.utils.tls import ssl_context_with_optional_ca_override
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
from openai import AsyncOpenAI, OpenAI
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Configuration for the OpenAI embedding provider.
# (Kept as a comment rather than a class docstring so the pydantic model's
# generated description is left untouched.)
class OpenAIEmbeddingConfig(EmbeddingConfig):
    # Secret API key handed to the OpenAI SDK clients built below.
    api_key: SecretStr = Field(description="API key for OpenAI")
    # Model requested for embeddings; the field is aliased as ``model_name``.
    embedder_model_name: str = Field(
        default="text-embedding-ada-002", alias="model_name", description="OpenAI model name"
    )
    # Passed to the SDK clients as ``base_url``; ``None`` by default.
    base_url: Optional[str] = Field(default=None, description="optional override for the base url")

    @requires_dependencies(["openai"], extras="openai")
    def wrap_error(self, e: Exception) -> Exception:
        """Translate an exception raised around an OpenAI SDK call into an ingest error.

        Args:
            e: The exception that was caught.

        Returns:
            ``e`` itself when it is already an internal error or cannot be
            classified; otherwise ``UserAuthError`` (401), ``QuotaError`` or
            ``RateLimitError`` (429), ``UserError`` (any other 4xx) or
            ``ProviderError`` (5xx), each built from the SDK error message.
        """
        if is_internal_error(e=e):
            return e
        # https://platform.openai.com/docs/guides/error-codes/api-errors
        from openai import APIStatusError

        if not isinstance(e, APIStatusError):
            logger.error(f"unhandled exception from openai: {e}", exc_info=True)
            # Return rather than raise: this method's contract is to hand an
            # exception back, and every caller in this class raises the result.
            return e

        status = e.status_code
        if status == 401:
            return UserAuthError(e.message)
        if status == 429:
            # 429 indicates either rate limit exceeded or quota exceeded
            if e.code == "insufficient_quota":
                return QuotaError(e.message)
            return RateLimitError(e.message)
        if 400 <= status < 500:
            # any remaining 4xx is treated as a user error
            return UserError(e.message)
        if status >= 500:
            return ProviderError(e.message)
        logger.error(f"unhandled exception from openai: {e}", exc_info=True)
        return e

    @requires_dependencies(["openai"], extras="openai")
    def get_models(self) -> Optional[list[str]]:
        """List the model ids reported by the endpoint.

        Returns:
            The model ids, or ``None`` when the listing request answers 404.

        Raises:
            Exception: The result of ``wrap_error`` for any other failure.
        """
        # In case the list model endpoint isn't exposed, don't break
        from openai import APIStatusError

        client = self.get_client()
        try:
            return [m.id for m in client.models.list()]
        except APIStatusError as e:
            if e.status_code == 404:
                return None
            # Any other status (401, 429, 5xx, ...) is a real failure. Falling
            # through here would return None implicitly and let the precheck
            # pass as if the listing endpoint simply were not exposed.
            raise self.wrap_error(e=e)
        except Exception as e:
            raise self.wrap_error(e=e)

    def run_precheck(self) -> None:
        """Check that the configured model is available, when models can be listed.

        Does nothing further when ``get_models`` returns ``None``.

        Raises:
            Exception: Whatever ``wrap_error`` produces for a listing failure,
                or a ``UserError`` routed through ``wrap_error`` when
                ``embedder_model_name`` is not in the listing.
        """
        try:
            models = self.get_models()
            if models is None:
                return
            if self.embedder_model_name not in models:
                raise UserError(
                    "model '{}' not found: {}".format(self.embedder_model_name, ", ".join(models))
                )
        except Exception as e:
            raise self.wrap_error(e=e)

    @requires_dependencies(["openai"], extras="openai")
    def get_client(self) -> "OpenAI":
        """Create a synchronous OpenAI client.

        The underlying HTTP client verifies TLS with the context returned by
        ``ssl_context_with_optional_ca_override()``.
        """
        from openai import DefaultHttpxClient, OpenAI

        client = DefaultHttpxClient(verify=ssl_context_with_optional_ca_override())
        return OpenAI(
            api_key=self.api_key.get_secret_value(), http_client=client, base_url=self.base_url
        )

    @requires_dependencies(["openai"], extras="openai")
    def get_async_client(self) -> "AsyncOpenAI":
        """Create an asynchronous OpenAI client.

        The underlying HTTP client verifies TLS with the context returned by
        ``ssl_context_with_optional_ca_override()``.
        """
        from openai import AsyncOpenAI, DefaultAsyncHttpxClient

        client = DefaultAsyncHttpxClient(verify=ssl_context_with_optional_ca_override())
        return AsyncOpenAI(
            api_key=self.api_key.get_secret_value(), http_client=client, base_url=self.base_url
        )
|
|
105
|
+
|
|
106
|
+
|
|
107
|
+
@dataclass
class OpenAIEmbeddingEncoder(BaseEmbeddingEncoder):
    """Synchronous encoder that sends embedding requests through the config's client."""

    config: OpenAIEmbeddingConfig

    def get_client(self) -> "OpenAI":
        """Return the synchronous client built by the config."""
        sync_client = self.config.get_client()
        return sync_client

    def wrap_error(self, e: Exception) -> Exception:
        """Delegate error translation to the config."""
        translated = self.config.wrap_error(e=e)
        return translated

    def precheck(self):
        """Run the config's precheck."""
        self.config.run_precheck()

    def embed_batch(self, client: "OpenAI", batch: list[str]) -> list[list[float]]:
        """Embed ``batch`` with ``client``; one embedding per item in the response data."""
        model_name = self.config.embedder_model_name
        result = client.embeddings.create(input=batch, model=model_name)
        vectors = []
        for item in result.data:
            vectors.append(item.embedding)
        return vectors
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
@dataclass
class AsyncOpenAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
    """Asynchronous encoder that sends embedding requests through the config's async client."""

    config: OpenAIEmbeddingConfig

    def get_client(self) -> "AsyncOpenAI":
        """Return the asynchronous client built by the config."""
        async_client = self.config.get_async_client()
        return async_client

    def wrap_error(self, e: Exception) -> Exception:
        """Delegate error translation to the config."""
        translated = self.config.wrap_error(e=e)
        return translated

    def precheck(self):
        """Run the config's precheck."""
        self.config.run_precheck()

    async def embed_batch(self, client: "AsyncOpenAI", batch: list[str]) -> list[list[float]]:
        """Await an embeddings request for ``batch`` and collect each item's embedding."""
        model_name = self.config.embedder_model_name
        result = await client.embeddings.create(input=batch, model=model_name)
        vectors = []
        for item in result.data:
            vectors.append(item.embedding)
        return vectors
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import TYPE_CHECKING, Any
|
|
3
|
+
|
|
4
|
+
from pydantic import Field, SecretStr
|
|
5
|
+
|
|
6
|
+
from unstructured_ingest.embed.interfaces import (
|
|
7
|
+
AsyncBaseEmbeddingEncoder,
|
|
8
|
+
BaseEmbeddingEncoder,
|
|
9
|
+
EmbeddingConfig,
|
|
10
|
+
)
|
|
11
|
+
from unstructured_ingest.error import (
|
|
12
|
+
ProviderError,
|
|
13
|
+
UserAuthError,
|
|
14
|
+
UserError,
|
|
15
|
+
is_internal_error,
|
|
16
|
+
)
|
|
17
|
+
from unstructured_ingest.error import (
|
|
18
|
+
RateLimitError as CustomRateLimitError,
|
|
19
|
+
)
|
|
20
|
+
from unstructured_ingest.logger import logger
|
|
21
|
+
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
from together import AsyncTogether, Together
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
# Configuration for the Together AI embedding provider.
# (Kept as a comment rather than a class docstring so the pydantic model's
# generated description is left untouched.)
class TogetherAIEmbeddingConfig(EmbeddingConfig):
    # Secret API key handed to the Together clients built below.
    api_key: SecretStr = Field(description="API key for Together AI")
    # Model requested for embeddings; the field is aliased as ``model_name``.
    embedder_model_name: str = Field(
        default="togethercomputer/m2-bert-80M-32k-retrieval",
        alias="model_name",
        description="Together AI model name",
    )

    @requires_dependencies(["together"], extras="togetherai")
    def get_client(self) -> "Together":
        """Create a synchronous Together client from ``api_key``."""
        from together import Together

        secret = self.api_key.get_secret_value()
        return Together(api_key=secret)

    @requires_dependencies(["together"], extras="togetherai")
    def get_async_client(self) -> "AsyncTogether":
        """Create an asynchronous Together client from ``api_key``."""
        from together import AsyncTogether

        secret = self.api_key.get_secret_value()
        return AsyncTogether(api_key=secret)

    def wrap_error(self, e: Exception) -> Exception:
        """Translate an exception raised around a Together SDK call into an ingest error.

        Returns ``e`` unchanged when it is already an internal error, is not a
        ``TogetherException``, or carries no classifiable status code.
        """
        if is_internal_error(e=e):
            return e
        # https://docs.together.ai/docs/error-codes
        from together.error import AuthenticationError, RateLimitError, TogetherException

        if isinstance(e, TogetherException):
            detail = e.args[0]
            # SDK error type -> ingest error type, checked in this order
            type_mapping = (
                (AuthenticationError, UserAuthError),
                (RateLimitError, CustomRateLimitError),
            )
            for sdk_error, ingest_error in type_mapping:
                if isinstance(e, sdk_error):
                    return ingest_error(detail)

            # fall back to the HTTP status when the exception exposes one
            http_status = getattr(e, "status_code", None)
            if http_status is not None and http_status >= 400:
                if http_status < 500:
                    return UserError(detail)
                return ProviderError(detail)

        logger.error(f"unhandled exception from together: {e}", exc_info=True)
        return e

    def run_precheck(self) -> None:
        """Check that the configured model is among the models the client lists.

        Any exception raised along the way is passed through ``wrap_error``
        and the result is raised.
        """
        client = self.get_client()
        try:
            available = []
            for listed in list(client.models.list()):
                available.append(listed.id)
            if self.embedder_model_name in available:
                return
            raise UserError(
                "model '{}' not found: {}".format(self.embedder_model_name, ", ".join(available))
            )
        except Exception as e:
            raise self.wrap_error(e=e)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
@dataclass
class TogetherAIEmbeddingEncoder(BaseEmbeddingEncoder):
    """Synchronous encoder that sends embedding requests through the config's client."""

    config: TogetherAIEmbeddingConfig

    def get_client(self) -> "Together":
        """Return the synchronous client built by the config."""
        sync_client = self.config.get_client()
        return sync_client

    def wrap_error(self, e: Exception) -> Exception:
        """Delegate error translation to the config."""
        translated = self.config.wrap_error(e=e)
        return translated

    def precheck(self):
        """Run the config's precheck."""
        self.config.run_precheck()

    def embed_batch(self, client: "Together", batch: list[str]) -> list[list[float]]:
        """Embed ``batch`` with ``client``, reading one embedding per batch position."""
        model_name = self.config.embedder_model_name
        response = client.embeddings.create(model=model_name, input=batch)
        vectors = []
        # indexed by batch position, so the result length follows ``batch``
        for position, _ in enumerate(batch):
            vectors.append(response.data[position].embedding)
        return vectors
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
@dataclass
class AsyncTogetherAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
    """Asynchronous encoder that sends embedding requests through the config's async client."""

    config: TogetherAIEmbeddingConfig

    def get_client(self) -> "AsyncTogether":
        """Return the asynchronous client built by the config."""
        async_client = self.config.get_async_client()
        return async_client

    def wrap_error(self, e: Exception) -> Exception:
        """Delegate error translation to the config."""
        translated = self.config.wrap_error(e=e)
        return translated

    def precheck(self):
        """Run the config's precheck."""
        self.config.run_precheck()

    async def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
        """Await an embeddings request for ``batch``, reading one embedding per batch position."""
        model_name = self.config.embedder_model_name
        response = await client.embeddings.create(model=model_name, input=batch)
        vectors = []
        # indexed by batch position, so the result length follows ``batch``
        for position, _ in enumerate(batch):
            vectors.append(response.data[position].embedding)
        return vectors
|
|
@@ -0,0 +1,109 @@
|
|
|
1
|
+
# type: ignore
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import TYPE_CHECKING, Annotated, Any, Optional
|
|
7
|
+
|
|
8
|
+
from pydantic import Field, Secret, ValidationError
|
|
9
|
+
from pydantic.functional_validators import BeforeValidator
|
|
10
|
+
|
|
11
|
+
from unstructured_ingest.embed.interfaces import (
|
|
12
|
+
AsyncBaseEmbeddingEncoder,
|
|
13
|
+
BaseEmbeddingEncoder,
|
|
14
|
+
EmbeddingConfig,
|
|
15
|
+
)
|
|
16
|
+
from unstructured_ingest.error import UserAuthError, is_internal_error
|
|
17
|
+
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from vertexai.language_models import TextEmbeddingModel
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def conform_string_to_dict(value: Any) -> dict:
    """Coerce a credentials value into a dict before pydantic validates it.

    Args:
        value: Either an already-parsed dict or a JSON-encoded string.

    Returns:
        ``value`` itself when it is a dict, otherwise ``json.loads(value)``.

    Raises:
        ValueError: If ``value`` is neither a dict nor a string, or if the
            string is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError`` subclass).
    """
    if isinstance(value, dict):
        return value
    if isinstance(value, str):
        return json.loads(value)
    # pydantic v2's ValidationError cannot be constructed from a plain message;
    # a validator signals failure by raising ValueError, which pydantic then
    # reports as a ValidationError. Only the type is reported because the value
    # is credentials input and should not end up in logs.
    raise ValueError(
        f"Input could not be mapped to a valid dict: got {type(value).__name__}"
    )
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
# Secret-wrapped dict; the BeforeValidator runs conform_string_to_dict first, so
# the value may be supplied either as a dict or as a JSON string.
ApiKeyType = Secret[Annotated[dict, BeforeValidator(conform_string_to_dict)]]
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class VertexAIEmbeddingConfig(EmbeddingConfig):
    """Settings for Vertex AI embeddings: credentials, model name and client creation."""

    api_key: ApiKeyType = Field(description="API key for Vertex AI")
    embedder_model_name: Optional[str] = Field(
        default="text-embedding-005", alias="model_name", description="Vertex AI model name"
    )

    def wrap_error(self, e: Exception) -> Exception:
        """Map ``e`` to an ingest error type where a mapping exists.

        Internal errors are returned unchanged, ``GoogleAuthError`` instances are
        wrapped in ``UserAuthError``, and anything else is returned as-is.
        """
        if is_internal_error(e=e):
            return e
        from google.auth.exceptions import GoogleAuthError

        if isinstance(e, GoogleAuthError):
            return UserAuthError(e)
        return e

    def register_application_credentials(self):
        """Write the credentials to a temp file and point GOOGLE_APPLICATION_CREDENTIALS at it.

        The file is created with owner-only permissions (0600) in the platform's
        temporary directory, and symlinks are not followed where the OS supports
        ``O_NOFOLLOW``, so the secret is not exposed to other local users.
        """
        # TODO look into passing credentials in directly, rather than via env var and tmp file
        import tempfile

        application_credentials_path = (
            Path(tempfile.gettempdir()) / "google-vertex-app-credentials.json"
        )
        open_flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_NOFOLLOW", 0)
        fd = os.open(application_credentials_path, open_flags, 0o600)
        with os.fdopen(fd, "w") as credentials_file:
            # The mode passed to os.open only applies when the file is created;
            # tighten a file left over from an earlier run before writing the secret.
            os.chmod(application_credentials_path, 0o600)
            json.dump(self.api_key.get_secret_value(), credentials_file)
        os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = str(application_credentials_path)

    @requires_dependencies(
        ["vertexai"],
        extras="vertexai",
    )
    def get_client(self) -> "TextEmbeddingModel":
        """Creates a VertexAI python client to embed elements."""
        from vertexai.language_models import TextEmbeddingModel

        # Credentials must be registered before the model is loaded.
        self.register_application_credentials()
        return TextEmbeddingModel.from_pretrained(self.embedder_model_name)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
@dataclass
class VertexAIEmbeddingEncoder(BaseEmbeddingEncoder):
    """Synchronous embedding encoder that delegates to a VertexAIEmbeddingConfig."""

    config: VertexAIEmbeddingConfig

    def wrap_error(self, e: Exception) -> Exception:
        """Translate ``e`` using the config's error mapping."""
        return self.config.wrap_error(e=e)

    def get_client(self) -> "TextEmbeddingModel":
        """Return the client produced by the config."""
        return self.config.get_client()

    # NOTE(review): VertexAIEmbeddingConfig.get_client names the extra "vertexai"
    # while this decorator uses "embed-vertexai" — confirm which one is defined.
    @requires_dependencies(["vertexai"], extras="embed-vertexai")
    def embed_batch(self, client: "TextEmbeddingModel", batch: list[str]) -> list[list[float]]:
        """Wrap each text in a ``TextEmbeddingInput`` and return the embedding values."""
        from vertexai.language_models import TextEmbeddingInput

        embeddings = client.get_embeddings([TextEmbeddingInput(text=item) for item in batch])
        return [embedding.values for embedding in embeddings]
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@dataclass
class AsyncVertexAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
    """Asynchronous embedding encoder that delegates to a VertexAIEmbeddingConfig."""

    config: VertexAIEmbeddingConfig

    def wrap_error(self, e: Exception) -> Exception:
        """Translate ``e`` using the config's error mapping."""
        return self.config.wrap_error(e=e)

    def get_client(self) -> "TextEmbeddingModel":
        """Return the client produced by the config."""
        return self.config.get_client()

    # NOTE(review): VertexAIEmbeddingConfig.get_client names the extra "vertexai"
    # while this decorator uses "embed-vertexai" — confirm which one is defined.
    @requires_dependencies(["vertexai"], extras="embed-vertexai")
    async def embed_batch(self, client: Any, batch: list[str]) -> list[list[float]]:
        """Wrap each text in a ``TextEmbeddingInput`` and await the embedding values."""
        from vertexai.language_models import TextEmbeddingInput

        wrapped_inputs = [TextEmbeddingInput(text=item) for item in batch]
        embeddings = await client.get_embeddings_async(wrapped_inputs)
        return [embedding.values for embedding in embeddings]
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import TYPE_CHECKING, Optional
|
|
3
|
+
|
|
4
|
+
from pydantic import Field, SecretStr
|
|
5
|
+
|
|
6
|
+
from unstructured_ingest.embed.interfaces import (
|
|
7
|
+
AsyncBaseEmbeddingEncoder,
|
|
8
|
+
BaseEmbeddingEncoder,
|
|
9
|
+
EmbeddingConfig,
|
|
10
|
+
)
|
|
11
|
+
from unstructured_ingest.error import ProviderError, UserAuthError, UserError, is_internal_error
|
|
12
|
+
from unstructured_ingest.error import (
|
|
13
|
+
RateLimitError as CustomRateLimitError,
|
|
14
|
+
)
|
|
15
|
+
from unstructured_ingest.logger import logger
|
|
16
|
+
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from voyageai import AsyncClient as AsyncVoyageAIClient
|
|
20
|
+
from voyageai import Client as VoyageAIClient
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class VoyageAIEmbeddingConfig(EmbeddingConfig):
    """Settings for VoyageAI embeddings: batching, credentials, model and client options."""

    batch_size: int = Field(
        default=32,
        le=128,
        description="Batch size for embedding requests. VoyageAI has a limit of 128.",
    )
    api_key: SecretStr = Field(description="API key for VoyageAI")
    embedder_model_name: str = Field(
        default="voyage-3", alias="model_name", description="VoyageAI model name"
    )
    max_retries: int = Field(default=0, description="Max retries for embedding requests.")
    timeout_in_seconds: Optional[int] = Field(
        default=None, description="Optional timeout in seconds for embedding requests."
    )

    def wrap_error(self, e: Exception) -> Exception:
        """Map ``e`` to an ingest error type where a mapping exists.

        Returns:
            ``e`` unchanged for internal errors and for anything that cannot be
            classified (those are logged); ``UserAuthError`` for authentication
            failures; ``CustomRateLimitError`` for rate limiting; ``UserError``
            for other 4xx statuses; ``ProviderError`` for 5xx statuses.
        """
        if is_internal_error(e=e):
            return e
        # https://docs.voyageai.com/docs/error-codes
        from voyageai.error import AuthenticationError, RateLimitError, VoyageError

        if not isinstance(e, VoyageError):
            logger.error(f"unhandled exception from voyageai: {e}", exc_info=True)
            # Return rather than raise: this method hands the exception back to the caller.
            return e
        http_code = e.http_status
        message = e.user_message
        if isinstance(e, AuthenticationError):
            return UserAuthError(message)
        if isinstance(e, RateLimitError):
            return CustomRateLimitError(message)
        # http_status may be None (e.g. no HTTP response was received); comparing
        # None against ints would raise TypeError and mask the original error.
        if http_code is not None:
            if 400 <= http_code < 500:
                return UserError(message)
            if http_code >= 500:
                return ProviderError(message)
        logger.error(f"unhandled exception from voyageai: {e}", exc_info=True)
        return e

    @requires_dependencies(
        ["voyageai"],
        extras="embed-voyageai",
    )
    def get_client(self) -> "VoyageAIClient":
        """Creates a VoyageAI python client to embed elements."""
        from voyageai import Client as VoyageAIClient

        client = VoyageAIClient(
            api_key=self.api_key.get_secret_value(),
            max_retries=self.max_retries,
            timeout=self.timeout_in_seconds,
        )
        return client

    @requires_dependencies(
        ["voyageai"],
        extras="embed-voyageai",
    )
    def get_async_client(self) -> "AsyncVoyageAIClient":
        """Creates an async VoyageAI python client to embed elements."""
        from voyageai import AsyncClient as AsyncVoyageAIClient

        client = AsyncVoyageAIClient(
            api_key=self.api_key.get_secret_value(),
            max_retries=self.max_retries,
            timeout=self.timeout_in_seconds,
        )
        return client
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
@dataclass
class VoyageAIEmbeddingEncoder(BaseEmbeddingEncoder):
    """Synchronous embedding encoder that delegates to a VoyageAIEmbeddingConfig."""

    config: VoyageAIEmbeddingConfig

    def wrap_error(self, e: Exception) -> Exception:
        """Translate ``e`` using the config's error mapping."""
        return self.config.wrap_error(e=e)

    def get_client(self) -> "VoyageAIClient":
        """Return the synchronous client produced by the config."""
        return self.config.get_client()

    def embed_batch(self, client: "VoyageAIClient", batch: list[str]) -> list[list[float]]:
        """Embed ``batch`` and return the response's ``embeddings``.

        For the "voyage-multimodal-3" model each text is wrapped in its own
        single-item list and sent through ``multimodal_embed``; every other
        model goes through ``embed``.
        """
        model_name = self.config.embedder_model_name
        if model_name != "voyage-multimodal-3":
            return client.embed(texts=batch, model=model_name).embeddings
        wrapped = [[text] for text in batch]
        return client.multimodal_embed(inputs=wrapped, model=model_name).embeddings
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@dataclass
class AsyncVoyageAIEmbeddingEncoder(AsyncBaseEmbeddingEncoder):
    """Asynchronous embedding encoder that delegates to a VoyageAIEmbeddingConfig."""

    config: VoyageAIEmbeddingConfig

    def wrap_error(self, e: Exception) -> Exception:
        """Translate ``e`` using the config's error mapping."""
        return self.config.wrap_error(e=e)

    def get_client(self) -> "AsyncVoyageAIClient":
        """Return the asynchronous client produced by the config."""
        return self.config.get_async_client()

    async def embed_batch(
        self, client: "AsyncVoyageAIClient", batch: list[str]
    ) -> list[list[float]]:
        """Embed ``batch`` and return the response's ``embeddings``.

        For the "voyage-multimodal-3" model each text is wrapped in its own
        single-item list and sent through ``multimodal_embed``; every other
        model goes through ``embed``.
        """
        model_name = self.config.embedder_model_name
        if model_name != "voyage-multimodal-3":
            result = await client.embed(texts=batch, model=model_name)
            return result.embeddings
        wrapped = [[text] for text in batch]
        result = await client.multimodal_embed(inputs=wrapped, model=model_name)
        return result.embeddings
|