unstructured-ingest 1.2.32__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of unstructured-ingest might be problematic. Click here for more details.
- unstructured_ingest/__init__.py +1 -0
- unstructured_ingest/__version__.py +1 -0
- unstructured_ingest/cli/README.md +28 -0
- unstructured_ingest/cli/__init__.py +0 -0
- unstructured_ingest/cli/base/__init__.py +4 -0
- unstructured_ingest/cli/base/cmd.py +269 -0
- unstructured_ingest/cli/base/dest.py +84 -0
- unstructured_ingest/cli/base/importer.py +34 -0
- unstructured_ingest/cli/base/src.py +75 -0
- unstructured_ingest/cli/cli.py +24 -0
- unstructured_ingest/cli/cmds.py +14 -0
- unstructured_ingest/cli/utils/__init__.py +0 -0
- unstructured_ingest/cli/utils/click.py +237 -0
- unstructured_ingest/cli/utils/model_conversion.py +222 -0
- unstructured_ingest/data_types/__init__.py +0 -0
- unstructured_ingest/data_types/entities.py +17 -0
- unstructured_ingest/data_types/file_data.py +116 -0
- unstructured_ingest/embed/__init__.py +0 -0
- unstructured_ingest/embed/azure_openai.py +63 -0
- unstructured_ingest/embed/bedrock.py +323 -0
- unstructured_ingest/embed/huggingface.py +69 -0
- unstructured_ingest/embed/interfaces.py +146 -0
- unstructured_ingest/embed/mixedbreadai.py +134 -0
- unstructured_ingest/embed/octoai.py +133 -0
- unstructured_ingest/embed/openai.py +142 -0
- unstructured_ingest/embed/togetherai.py +116 -0
- unstructured_ingest/embed/vertexai.py +109 -0
- unstructured_ingest/embed/voyageai.py +130 -0
- unstructured_ingest/error.py +156 -0
- unstructured_ingest/errors_v2.py +156 -0
- unstructured_ingest/interfaces/__init__.py +27 -0
- unstructured_ingest/interfaces/connector.py +56 -0
- unstructured_ingest/interfaces/downloader.py +90 -0
- unstructured_ingest/interfaces/indexer.py +29 -0
- unstructured_ingest/interfaces/process.py +22 -0
- unstructured_ingest/interfaces/processor.py +88 -0
- unstructured_ingest/interfaces/upload_stager.py +89 -0
- unstructured_ingest/interfaces/uploader.py +67 -0
- unstructured_ingest/logger.py +39 -0
- unstructured_ingest/main.py +11 -0
- unstructured_ingest/otel.py +128 -0
- unstructured_ingest/pipeline/__init__.py +0 -0
- unstructured_ingest/pipeline/interfaces.py +211 -0
- unstructured_ingest/pipeline/otel.py +32 -0
- unstructured_ingest/pipeline/pipeline.py +408 -0
- unstructured_ingest/pipeline/steps/__init__.py +0 -0
- unstructured_ingest/pipeline/steps/chunk.py +78 -0
- unstructured_ingest/pipeline/steps/download.py +206 -0
- unstructured_ingest/pipeline/steps/embed.py +77 -0
- unstructured_ingest/pipeline/steps/filter.py +35 -0
- unstructured_ingest/pipeline/steps/index.py +86 -0
- unstructured_ingest/pipeline/steps/partition.py +77 -0
- unstructured_ingest/pipeline/steps/stage.py +65 -0
- unstructured_ingest/pipeline/steps/uncompress.py +50 -0
- unstructured_ingest/pipeline/steps/upload.py +58 -0
- unstructured_ingest/processes/__init__.py +18 -0
- unstructured_ingest/processes/chunker.py +131 -0
- unstructured_ingest/processes/connector_registry.py +69 -0
- unstructured_ingest/processes/connectors/__init__.py +129 -0
- unstructured_ingest/processes/connectors/airtable.py +238 -0
- unstructured_ingest/processes/connectors/assets/__init__.py +0 -0
- unstructured_ingest/processes/connectors/assets/databricks_delta_table_schema.sql +9 -0
- unstructured_ingest/processes/connectors/assets/weaviate_collection_config.json +23 -0
- unstructured_ingest/processes/connectors/astradb.py +592 -0
- unstructured_ingest/processes/connectors/azure_ai_search.py +275 -0
- unstructured_ingest/processes/connectors/chroma.py +193 -0
- unstructured_ingest/processes/connectors/confluence.py +527 -0
- unstructured_ingest/processes/connectors/couchbase.py +336 -0
- unstructured_ingest/processes/connectors/databricks/__init__.py +58 -0
- unstructured_ingest/processes/connectors/databricks/volumes.py +233 -0
- unstructured_ingest/processes/connectors/databricks/volumes_aws.py +93 -0
- unstructured_ingest/processes/connectors/databricks/volumes_azure.py +108 -0
- unstructured_ingest/processes/connectors/databricks/volumes_gcp.py +91 -0
- unstructured_ingest/processes/connectors/databricks/volumes_native.py +92 -0
- unstructured_ingest/processes/connectors/databricks/volumes_table.py +187 -0
- unstructured_ingest/processes/connectors/delta_table.py +310 -0
- unstructured_ingest/processes/connectors/discord.py +161 -0
- unstructured_ingest/processes/connectors/duckdb/__init__.py +15 -0
- unstructured_ingest/processes/connectors/duckdb/base.py +103 -0
- unstructured_ingest/processes/connectors/duckdb/duckdb.py +130 -0
- unstructured_ingest/processes/connectors/duckdb/motherduck.py +130 -0
- unstructured_ingest/processes/connectors/elasticsearch/__init__.py +19 -0
- unstructured_ingest/processes/connectors/elasticsearch/elasticsearch.py +478 -0
- unstructured_ingest/processes/connectors/elasticsearch/opensearch.py +523 -0
- unstructured_ingest/processes/connectors/fsspec/__init__.py +37 -0
- unstructured_ingest/processes/connectors/fsspec/azure.py +203 -0
- unstructured_ingest/processes/connectors/fsspec/box.py +176 -0
- unstructured_ingest/processes/connectors/fsspec/dropbox.py +238 -0
- unstructured_ingest/processes/connectors/fsspec/fsspec.py +475 -0
- unstructured_ingest/processes/connectors/fsspec/gcs.py +203 -0
- unstructured_ingest/processes/connectors/fsspec/s3.py +253 -0
- unstructured_ingest/processes/connectors/fsspec/sftp.py +177 -0
- unstructured_ingest/processes/connectors/fsspec/utils.py +17 -0
- unstructured_ingest/processes/connectors/github.py +226 -0
- unstructured_ingest/processes/connectors/gitlab.py +270 -0
- unstructured_ingest/processes/connectors/google_drive.py +848 -0
- unstructured_ingest/processes/connectors/ibm_watsonx/__init__.py +10 -0
- unstructured_ingest/processes/connectors/ibm_watsonx/ibm_watsonx_s3.py +367 -0
- unstructured_ingest/processes/connectors/jira.py +522 -0
- unstructured_ingest/processes/connectors/kafka/__init__.py +17 -0
- unstructured_ingest/processes/connectors/kafka/cloud.py +121 -0
- unstructured_ingest/processes/connectors/kafka/kafka.py +275 -0
- unstructured_ingest/processes/connectors/kafka/local.py +103 -0
- unstructured_ingest/processes/connectors/kdbai.py +156 -0
- unstructured_ingest/processes/connectors/lancedb/__init__.py +30 -0
- unstructured_ingest/processes/connectors/lancedb/aws.py +43 -0
- unstructured_ingest/processes/connectors/lancedb/azure.py +43 -0
- unstructured_ingest/processes/connectors/lancedb/cloud.py +42 -0
- unstructured_ingest/processes/connectors/lancedb/gcp.py +44 -0
- unstructured_ingest/processes/connectors/lancedb/lancedb.py +181 -0
- unstructured_ingest/processes/connectors/lancedb/local.py +44 -0
- unstructured_ingest/processes/connectors/local.py +227 -0
- unstructured_ingest/processes/connectors/milvus.py +311 -0
- unstructured_ingest/processes/connectors/mongodb.py +389 -0
- unstructured_ingest/processes/connectors/neo4j.py +534 -0
- unstructured_ingest/processes/connectors/notion/__init__.py +0 -0
- unstructured_ingest/processes/connectors/notion/client.py +349 -0
- unstructured_ingest/processes/connectors/notion/connector.py +350 -0
- unstructured_ingest/processes/connectors/notion/helpers.py +448 -0
- unstructured_ingest/processes/connectors/notion/ingest_backoff/__init__.py +3 -0
- unstructured_ingest/processes/connectors/notion/ingest_backoff/_common.py +102 -0
- unstructured_ingest/processes/connectors/notion/ingest_backoff/_wrapper.py +126 -0
- unstructured_ingest/processes/connectors/notion/ingest_backoff/types.py +24 -0
- unstructured_ingest/processes/connectors/notion/interfaces.py +32 -0
- unstructured_ingest/processes/connectors/notion/types/__init__.py +0 -0
- unstructured_ingest/processes/connectors/notion/types/block.py +96 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/__init__.py +63 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/bookmark.py +40 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/breadcrumb.py +21 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/bulleted_list_item.py +31 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/callout.py +131 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/child_database.py +23 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/child_page.py +23 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/code.py +43 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/column_list.py +35 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/divider.py +22 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/embed.py +36 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/equation.py +23 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/file.py +49 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/heading.py +37 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/image.py +21 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/link_preview.py +24 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/link_to_page.py +29 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/numbered_list.py +29 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/paragraph.py +31 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/pdf.py +49 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/quote.py +37 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/synced_block.py +109 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/table.py +60 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/table_of_contents.py +23 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/template.py +30 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/todo.py +42 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/toggle.py +37 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/unsupported.py +20 -0
- unstructured_ingest/processes/connectors/notion/types/blocks/video.py +22 -0
- unstructured_ingest/processes/connectors/notion/types/database.py +73 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/__init__.py +125 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/checkbox.py +39 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/created_by.py +36 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/created_time.py +35 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/date.py +42 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/email.py +37 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/files.py +38 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/formula.py +50 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/last_edited_by.py +34 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/last_edited_time.py +35 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/multiselect.py +74 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/number.py +50 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/people.py +42 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/phone_number.py +37 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/relation.py +68 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/rich_text.py +44 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/rollup.py +57 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/select.py +70 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/status.py +82 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/title.py +38 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/unique_id.py +51 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/url.py +38 -0
- unstructured_ingest/processes/connectors/notion/types/database_properties/verification.py +79 -0
- unstructured_ingest/processes/connectors/notion/types/date.py +29 -0
- unstructured_ingest/processes/connectors/notion/types/file.py +54 -0
- unstructured_ingest/processes/connectors/notion/types/page.py +52 -0
- unstructured_ingest/processes/connectors/notion/types/parent.py +66 -0
- unstructured_ingest/processes/connectors/notion/types/rich_text.py +189 -0
- unstructured_ingest/processes/connectors/notion/types/user.py +83 -0
- unstructured_ingest/processes/connectors/onedrive.py +485 -0
- unstructured_ingest/processes/connectors/outlook.py +242 -0
- unstructured_ingest/processes/connectors/pinecone.py +400 -0
- unstructured_ingest/processes/connectors/qdrant/__init__.py +16 -0
- unstructured_ingest/processes/connectors/qdrant/cloud.py +59 -0
- unstructured_ingest/processes/connectors/qdrant/local.py +58 -0
- unstructured_ingest/processes/connectors/qdrant/qdrant.py +163 -0
- unstructured_ingest/processes/connectors/qdrant/server.py +60 -0
- unstructured_ingest/processes/connectors/redisdb.py +214 -0
- unstructured_ingest/processes/connectors/salesforce.py +307 -0
- unstructured_ingest/processes/connectors/sharepoint.py +282 -0
- unstructured_ingest/processes/connectors/slack.py +249 -0
- unstructured_ingest/processes/connectors/sql/__init__.py +41 -0
- unstructured_ingest/processes/connectors/sql/databricks_delta_tables.py +228 -0
- unstructured_ingest/processes/connectors/sql/postgres.py +168 -0
- unstructured_ingest/processes/connectors/sql/singlestore.py +176 -0
- unstructured_ingest/processes/connectors/sql/snowflake.py +298 -0
- unstructured_ingest/processes/connectors/sql/sql.py +456 -0
- unstructured_ingest/processes/connectors/sql/sqlite.py +179 -0
- unstructured_ingest/processes/connectors/sql/teradata.py +254 -0
- unstructured_ingest/processes/connectors/sql/vastdb.py +263 -0
- unstructured_ingest/processes/connectors/utils.py +60 -0
- unstructured_ingest/processes/connectors/vectara.py +348 -0
- unstructured_ingest/processes/connectors/weaviate/__init__.py +22 -0
- unstructured_ingest/processes/connectors/weaviate/cloud.py +166 -0
- unstructured_ingest/processes/connectors/weaviate/embedded.py +90 -0
- unstructured_ingest/processes/connectors/weaviate/local.py +73 -0
- unstructured_ingest/processes/connectors/weaviate/weaviate.py +337 -0
- unstructured_ingest/processes/connectors/zendesk/__init__.py +0 -0
- unstructured_ingest/processes/connectors/zendesk/client.py +314 -0
- unstructured_ingest/processes/connectors/zendesk/zendesk.py +241 -0
- unstructured_ingest/processes/embedder.py +203 -0
- unstructured_ingest/processes/filter.py +60 -0
- unstructured_ingest/processes/partitioner.py +233 -0
- unstructured_ingest/processes/uncompress.py +61 -0
- unstructured_ingest/processes/utils/__init__.py +8 -0
- unstructured_ingest/processes/utils/blob_storage.py +32 -0
- unstructured_ingest/processes/utils/logging/connector.py +365 -0
- unstructured_ingest/processes/utils/logging/sanitizer.py +117 -0
- unstructured_ingest/unstructured_api.py +140 -0
- unstructured_ingest/utils/__init__.py +5 -0
- unstructured_ingest/utils/chunking.py +56 -0
- unstructured_ingest/utils/compression.py +72 -0
- unstructured_ingest/utils/constants.py +2 -0
- unstructured_ingest/utils/data_prep.py +216 -0
- unstructured_ingest/utils/dep_check.py +78 -0
- unstructured_ingest/utils/filesystem.py +27 -0
- unstructured_ingest/utils/html.py +174 -0
- unstructured_ingest/utils/ndjson.py +52 -0
- unstructured_ingest/utils/pydantic_models.py +52 -0
- unstructured_ingest/utils/string_and_date_utils.py +74 -0
- unstructured_ingest/utils/table.py +80 -0
- unstructured_ingest/utils/tls.py +15 -0
- unstructured_ingest-1.2.32.dist-info/METADATA +235 -0
- unstructured_ingest-1.2.32.dist-info/RECORD +243 -0
- unstructured_ingest-1.2.32.dist-info/WHEEL +4 -0
- unstructured_ingest-1.2.32.dist-info/entry_points.txt +2 -0
- unstructured_ingest-1.2.32.dist-info/licenses/LICENSE.md +201 -0
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import os.path
|
|
3
|
+
from datetime import date, datetime
|
|
4
|
+
from gettext import gettext, ngettext
|
|
5
|
+
from gettext import gettext as _
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Optional, Type, TypeVar, Union
|
|
8
|
+
|
|
9
|
+
import click
|
|
10
|
+
from pydantic import BaseModel, ConfigDict, Secret, TypeAdapter, ValidationError
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def conform_click_options(options: dict):
    """Normalize click option values in place.

    Click collects ``multiple=True`` options into tuples; every tuple value in
    ``options`` is replaced with an equivalent list. Other values are untouched.
    Returns ``None``.
    """
    options.update({key: list(val) for key, val in options.items() if isinstance(val, tuple)})
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Dict(click.ParamType):
    """Click parameter type producing a value decoded from JSON.

    Accepted inputs:
    - a ``dict``: returned unchanged;
    - a ``Path`` that points at an existing file: the file is decoded as JSON;
    - a ``str``: decoded as JSON.
    Invalid JSON is reported through ``self.fail``. Any other input (including a
    ``Path`` that is not a file) yields ``None``.
    """

    name = "dict"

    def convert(
        self,
        value: Any,
        param: Optional[click.Parameter] = None,
        ctx: Optional[click.Context] = None,
    ) -> Any:
        if isinstance(value, dict):
            return value
        try:
            if isinstance(value, Path) and value.is_file():
                with value.open() as json_file:
                    return json.load(json_file)
            if isinstance(value, str):
                return json.loads(value)
        except json.JSONDecodeError:
            message = gettext("{value} is not a valid json value.")
            self.fail(message.format(value=value), param, ctx)
        return None
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class FileOrJson(click.ParamType):
    """Click parameter type accepting either a path to an existing file or inline JSON.

    - An existing file path (``~`` is expanded) is returned as a resolved absolute
      path string.
    - Otherwise a string is decoded as JSON and the decoded value is returned.
    - A string that is not valid JSON is returned unchanged when ``allow_raw_str``
      is set; otherwise the parameter fails.
    - Already-decoded JSON containers (``dict``/``list``, e.g. a programmatic
      default) are returned unchanged.
    Anything else fails with a click usage error.
    """

    name = "file-or-json"

    def __init__(self, allow_raw_str: bool = False):
        # When True, a string that is neither a file nor valid JSON is passed through as-is.
        self.allow_raw_str = allow_raw_str

    def convert(
        self,
        value: Any,
        param: Optional[click.Parameter] = None,
        ctx: Optional[click.Context] = None,
    ) -> Any:
        # Already-decoded JSON needs no conversion; os.path.expanduser() used to
        # raise a bare TypeError on these values.
        if isinstance(value, (dict, list)):
            return value
        # Only path-like values can name a file on disk.
        if isinstance(value, (str, os.PathLike)):
            full_path = os.path.abspath(os.path.expanduser(value))
            if os.path.isfile(full_path):
                return str(Path(full_path).resolve())
        if isinstance(value, str):
            try:
                return json.loads(value)
            except json.JSONDecodeError:
                if self.allow_raw_str:
                    return value
        # Neither a usable file nor JSON: report it instead of silently returning None
        # (e.g. for a Path that does not exist) or crashing on unexpected types.
        self.fail(
            gettext(
                "{value} is neither a valid json string nor an existing filepath.",
            ).format(value=value),
            param,
            ctx,
        )
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
class DelimitedString(click.ParamType):
    """Click parameter type that splits a delimited string into a list of stripped strings.

    :param delimiter: separator used to split string input (default ``","``).
    :param choices: optional allowed values; every resulting item must be one of them,
        otherwise the parameter fails.
    """

    name = "delimited-string"

    def __init__(self, delimiter: str = ",", choices: Optional[list[str]] = None):
        self.choices = choices if choices else []
        self.delimiter = delimiter

    def convert(
        self,
        value: Any,
        param: Optional[click.Parameter] = None,
        ctx: Optional[click.Context] = None,
    ) -> Any:
        # A list or tuple (e.g. a programmatic default) is already split; only normalize
        # its items. Tuples previously crashed with AttributeError on ``.split``.
        if isinstance(value, (list, tuple)):
            split = [str(v).strip() for v in value]
        else:
            split = [v.strip() for v in value.split(self.delimiter)]
        if not self.choices:
            return split
        choices_str = ", ".join(map(repr, self.choices))
        for s in split:
            if s not in self.choices:
                self.fail(
                    ngettext(
                        "{value!r} is not {choice}.",
                        "{value!r} is not one of {choices}.",
                        len(self.choices),
                    ).format(value=s, choice=choices_str, choices=choices_str),
                    param,
                    ctx,
                )
        return split
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
class PydanticDateTime(click.ParamType):
    """Click parameter type that parses its input into a ``datetime`` via pydantic.

    Parsing is delegated to ``TypeAdapter(datetime).validate_strings``; a pydantic
    ``ValidationError`` is turned into a click parameter failure.
    """

    name = "datetime"

    def convert(
        self,
        value: Any,
        param: Optional[click.Parameter] = None,
        ctx: Optional[click.Context] = None,
    ) -> Any:
        adapter = TypeAdapter(datetime)
        try:
            parsed = adapter.validate_strings(value)
        except ValidationError:
            self.fail(f"{value} is not a valid datetime", param, ctx)
        return parsed
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class PydanticDate(click.ParamType):
    """Click parameter type that parses its input into a ``date`` via pydantic.

    Parsing is delegated to ``TypeAdapter(date).validate_strings``; a pydantic
    ``ValidationError`` is turned into a click parameter failure.
    """

    name = "date"

    def convert(
        self,
        value: Any,
        param: Optional[click.Parameter] = None,
        ctx: Optional[click.Context] = None,
    ) -> Any:
        adapter = TypeAdapter(date)
        try:
            parsed = adapter.validate_strings(value)
        except ValidationError:
            self.fail(f"{value} is not a valid date", param, ctx)
        return parsed
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
# Type variable bound to pydantic BaseModel so that extract_config() is typed to
# return an instance of the exact model class passed in as ``config``.
BaseModelT = TypeVar("BaseModelT", bound=BaseModel)
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def unwrap_optional(val: Any) -> tuple[Any, bool]:
    """Strip ``Optional`` from a type annotation.

    :param val: a type annotation, e.g. ``Optional[Foo]`` or ``Union[None, Foo]``.
    :returns: ``(inner_type, True)`` when ``val`` is a two-member ``Union`` containing
        ``None`` (in either position), otherwise ``(val, False)``.
    """
    none_type = type(None)
    if (
        hasattr(val, "__origin__")
        and hasattr(val, "__args__")
        and val.__origin__ is Union
        and len(val.__args__) == 2
        and none_type in val.__args__
    ):
        # Union members are *types*, so filter out NoneType rather than None; the old
        # ``a is not None`` filter kept both members and returned NoneType for
        # ``Union[None, Foo]``.
        inner = next(arg for arg in val.__args__ if arg is not none_type)
        return inner, True
    return val, False
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def extract_config(flat_data: dict, config: Type[BaseModelT]) -> BaseModelT:
    """Build an instance of ``config`` from a flat mapping of option values.

    Only keys matching a field name (or its alias) of ``config`` are used, and
    ``None`` values are dropped. If ``config`` declares an ``access_config`` field,
    its sub-fields are gathered from the same flat mapping and nested under the
    ``access_config`` key. An optional ``access_config`` with no matching values is
    set to ``None``.

    :param flat_data: option name -> value mapping.
    :param config: the pydantic model class to instantiate.
    :returns: the result of ``config.model_validate`` on the gathered data.
    :raises TypeError: if the ``access_config`` annotation (after removing
        ``Optional``) is neither a pydantic model class nor ``Secret[<model>]``.
    """
    fields = config.model_fields
    # Merge rather than replace: assigning a fresh ConfigDict wiped every other
    # setting declared on the model (frozen, validate_assignment, ...), several of
    # which pydantic reads at runtime.
    config.model_config = ConfigDict(**{**config.model_config, "extra": "ignore"})
    field_names = {v.alias or k for k, v in fields.items()}
    data = {k: v for k, v in flat_data.items() if k in field_names and v is not None}
    if access_config := fields.get("access_config"):
        access_config_type, is_optional = unwrap_optional(access_config.annotation)
        # Check if raw type is wrapped by a secret, e.g. Secret[SomeAccessConfig]
        if (
            hasattr(access_config_type, "__origin__")
            and hasattr(access_config_type, "__args__")
            and access_config_type.__origin__ is Secret
        ):
            ac_fields = access_config_type.__args__[0].model_fields
        elif isinstance(access_config_type, type) and issubclass(access_config_type, BaseModel):
            # The isinstance guard keeps non-class annotations (e.g. multi-member
            # Unions) on the descriptive TypeError below instead of issubclass's own.
            ac_fields = access_config_type.model_fields
        else:
            raise TypeError(f"Unrecognized access_config type: {access_config_type}")
        ac_field_names = {v.alias or k for k, v in ac_fields.items()}
        access_config_data = {
            k: v for k, v in flat_data.items() if k in ac_field_names and v is not None
        }
        if not access_config_data and is_optional:
            data["access_config"] = None
        else:
            data["access_config"] = access_config_data
    return config.model_validate(obj=data)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
class Group(click.Group):
    """``click.Group`` that lets subcommand help render even when the group's own
    required parameters are missing, and lists its subcommands under a
    "Destinations" heading instead of "Commands"."""

    def parse_args(self, ctx, args):
        """
        Allow subcommands to be called with a help flag without failing when the
        parent command is missing any of its required parameters.
        """
        try:
            return super().parse_args(ctx, args)
        except click.MissingParameter:
            # Honor every configured help flag (e.g. "-h" via context_settings),
            # not only the literal "--help".
            help_flags = ctx.help_option_names or ["--help"]
            if not any(flag in args for flag in help_flags):
                raise
            # remove the required params so that help can display
            for param in self.params:
                param.required = False
            return super().parse_args(ctx, args)

    def format_commands(self, ctx: click.Context, formatter: click.HelpFormatter) -> None:
        """
        Copy of the original click.Group format_commands() method but replacing
        'Commands' -> 'Destinations'
        """
        commands = []
        for subcommand in self.list_commands(ctx):
            cmd = self.get_command(ctx, subcommand)
            # Skip names that list_commands() reported but cannot be resolved, and
            # hidden commands.
            if cmd is None or cmd.hidden:
                continue
            commands.append((subcommand, cmd))

        if not commands:
            return

        widest_name = max(len(name) for name, _cmd in commands)
        if formatter.width:
            # Room left for the short help after the name column and 6 columns of padding.
            limit = formatter.width - 6 - widest_name
        else:
            # Without a width the old code produced a negative limit; use click's
            # default short-help length instead.
            limit = 45

        rows = [(name, cmd.get_short_help_str(limit)) for name, cmd in commands]
        with formatter.section(_("Destinations")):
            formatter.write_dl(rows)
|
|
@@ -0,0 +1,222 @@
|
|
|
1
|
+
import contextlib
|
|
2
|
+
import datetime
|
|
3
|
+
from collections import Counter
|
|
4
|
+
from enum import EnumMeta
|
|
5
|
+
from pathlib import Path
|
|
6
|
+
from typing import (
|
|
7
|
+
Annotated,
|
|
8
|
+
Any,
|
|
9
|
+
Callable,
|
|
10
|
+
Literal,
|
|
11
|
+
Optional,
|
|
12
|
+
Type,
|
|
13
|
+
TypedDict,
|
|
14
|
+
Union,
|
|
15
|
+
get_args,
|
|
16
|
+
get_origin,
|
|
17
|
+
)
|
|
18
|
+
from uuid import UUID
|
|
19
|
+
|
|
20
|
+
import click
|
|
21
|
+
from annotated_types import Ge, Gt, Le, Lt, SupportsGe, SupportsGt, SupportsLe, SupportsLt
|
|
22
|
+
from click import Option
|
|
23
|
+
from pydantic import BaseModel, Secret, SecretStr
|
|
24
|
+
from pydantic.fields import FieldInfo
|
|
25
|
+
from pydantic.types import _SecretBase
|
|
26
|
+
from pydantic_core import PydanticUndefined
|
|
27
|
+
|
|
28
|
+
from unstructured_ingest.cli.utils.click import (
|
|
29
|
+
DelimitedString,
|
|
30
|
+
Dict,
|
|
31
|
+
PydanticDate,
|
|
32
|
+
PydanticDateTime,
|
|
33
|
+
)
|
|
34
|
+
|
|
35
|
+
# The class of ``None``. ``typing.get_args(Optional[X])`` yields ``(X, NoneType)``,
# so Optional-detection compares annotation arguments against this type, not ``None``.
NoneType = type(None)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class _RangeDict(TypedDict, total=False):
    """Keyword arguments for ``click.IntRange`` / ``click.FloatRange``.

    Every key is optional (``total=False``); only the bounds a field actually
    declares are present.
    """

    # Lower bound, and whether it excludes the bound itself (``Gt``) or not (``Ge``).
    min: Union[SupportsGt, SupportsGe]
    min_open: bool
    # Upper bound, and whether it excludes the bound itself (``Lt``) or not (``Le``).
    max: Union[SupportsLt, SupportsLe]
    max_open: bool
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def get_range_from_metadata(metadata: list[Any]) -> _RangeDict:
    """Translate annotated_types bound constraints into click range kwargs.

    ``Ge``/``Gt`` set ``min`` (inclusive/exclusive), ``Le``/``Lt`` set ``max``
    (inclusive/exclusive). Other metadata entries are ignored; when several
    constraints target the same bound, the last one wins.
    """
    bounds: _RangeDict = {}
    for item in metadata:
        if isinstance(item, Gt):
            bounds["min"], bounds["min_open"] = item.gt, True
        elif isinstance(item, Ge):
            bounds["min"], bounds["min_open"] = item.ge, False
        elif isinstance(item, Lt):
            bounds["max"], bounds["max_open"] = item.lt, True
        elif isinstance(item, Le):
            bounds["max"], bounds["max_open"] = item.le, False
    return bounds
|
|
63
|
+
|
|
64
|
+
|
|
65
|
+
def is_boolean_flag(field_info: FieldInfo) -> bool:
    """Return True when the field's annotation, unwrapped by get_raw_type, is exactly ``bool``."""
    return get_raw_type(field_info.annotation) is bool
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def get_raw_type(val: Any) -> Any:
    """Peel one wrapper layer off an annotation to expose the underlying type.

    - ``Optional[X]`` / ``Union[None, X]`` -> ``X``
    - ``Secret[X]`` -> ``X``
    - ``SecretStr`` -> ``str``
    - anything else is returned unchanged.
    """
    field_args = get_args(val)
    field_origin = get_origin(val)
    if field_origin is Union and len(field_args) == 2 and NoneType in field_args:
        # Union members are types: filter NoneType, not None (the old ``is not None``
        # test returned NoneType for ``Union[None, X]``).
        return next(field_arg for field_arg in field_args if field_arg is not NoneType)
    if field_origin is Secret and len(field_args) == 1:
        return field_args[0]
    if val is SecretStr:
        return str
    return val
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def get_default_value_from_field(field: FieldInfo) -> Optional[Union[Any, Callable[[], Any]]]:
    """Return the default to hand to click for a pydantic field.

    An explicit ``default`` takes precedence; otherwise the ``default_factory``
    callable itself (not its result) is returned, or ``None`` if there is none.
    """
    if field.default is PydanticUndefined:
        # default_factory is None when the field declares no factory.
        return field.default_factory
    return field.default
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def get_option_name(field_name: str, field_info: FieldInfo) -> str:
    """Build the click option declaration for a pydantic field.

    The field's alias (when set) takes precedence over its attribute name. A leading
    ``--`` is dropped, the name is lower-cased and underscores become dashes. Boolean
    fields get an on/off pair such as ``--name/--no-name``.
    """
    slug = (field_info.alias or field_name).removeprefix("--").lower().replace("_", "-")
    if is_boolean_flag(field_info):
        return f"--{slug}/--no-{slug}"
    return f"--{slug}"
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def get_numerical_type(field: FieldInfo) -> click.ParamType:
    """Pick the click numeric type for an int/float pydantic field.

    Fields whose raw type (after unwrapping ``Optional``/``Secret`` via get_raw_type)
    is ``int`` map to ``click.INT`` or ``click.IntRange``; every other numeric type
    maps to ``click.FLOAT`` or ``click.FloatRange``. Range bounds come from the
    field's ``Ge``/``Gt``/``Le``/``Lt`` metadata.
    """
    range_args = get_range_from_metadata(field.metadata)
    # Compare the unwrapped type: checking ``field.annotation is int`` directly made
    # Optional[int] / Secret[int] fields accept (and advertise) floats.
    if get_raw_type(field.annotation) is int:
        if range_args:
            return click.IntRange(**range_args)  # type: ignore[arg-type]
        return click.INT
    # Non-integer numerical data_types default to float
    if range_args:
        return click.FloatRange(**range_args)  # type: ignore[arg-type]
    return click.FLOAT
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def get_type_from_annotation(field_type: Any) -> click.ParamType:
    """Map a type annotation to the click parameter type used to parse it.

    Supported: ``Optional[X]`` (either member order), ``Annotated[X, ...]`` (a
    ``click.ParamType`` in the metadata wins, otherwise ``X`` is resolved),
    ``Secret[X]``, ``list[str]``, ``SecretStr``, ``dict``/``dict[...]``, ``str``,
    ``bool``, ``UUID``, ``Path``, ``datetime.datetime``, ``datetime.date``,
    ``Literal[...]`` and ``Enum`` subclasses.

    :raises TypeError: for any other annotation.
    """
    field_origin = get_origin(field_type)
    field_args = get_args(field_type)
    if field_origin is Union and len(field_args) == 2 and NoneType in field_args:
        # Union members are types: filter NoneType, not None.
        inner_type = next(field_arg for field_arg in field_args if field_arg is not NoneType)
        return get_type_from_annotation(field_type=inner_type)
    if field_origin is Annotated:
        inner_type, *metadata = field_args
        # An explicit click type attached to the annotation takes precedence.
        for meta in metadata:
            if isinstance(meta, click.ParamType):
                return meta
        # Otherwise resolve the wrapped type; previously only Annotated[dict, ...]
        # worked and every other inner type fell through to the TypeError below.
        return get_type_from_annotation(field_type=inner_type)
    if field_origin is Secret and len(field_args) == 1:
        return get_type_from_annotation(field_type=field_args[0])
    if field_origin is list and len(field_args) == 1 and field_args[0] is str:
        return DelimitedString()
    if field_type is SecretStr:
        return click.STRING
    if dict in [field_type, field_origin]:
        return Dict()
    if field_type is str:
        return click.STRING
    if field_type is bool:
        return click.BOOL
    if field_type is UUID:
        return click.UUID
    if field_type is Path:
        return click.Path(path_type=Path)
    if field_type is datetime.datetime:
        return PydanticDateTime()
    if field_type is datetime.date:
        return PydanticDate()
    if field_origin is Literal:
        return click.Choice(field_args)
    if isinstance(field_type, EnumMeta):
        values = [i.value for i in field_type]
        return click.Choice(values)
    raise TypeError(f"Unexpected field type: {field_type}")
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _get_type_from_field(field: FieldInfo) -> click.ParamType:
    """Resolve the click type for a pydantic field.

    Fields whose raw type (via get_raw_type) is ``int`` or ``float`` go through
    get_numerical_type so range constraints are honored; everything else is
    mapped straight from the annotation.
    """
    if get_raw_type(field.annotation) not in (int, float):
        return get_type_from_annotation(field_type=field.annotation)
    return get_numerical_type(field)
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
def get_option_from_field(option_name: str, field_info: FieldInfo) -> Option:
    """Build a ``click.Option`` for a single pydantic field.

    The help text is the field description, with any ``field_info.examples``
    appended as ``[Examples: a, b]``. The click type, default value and
    flag-ness come from ``_get_type_from_field``,
    ``get_default_value_from_field`` and ``is_boolean_flag``. The default is
    shown in ``--help`` only when the field declares an explicit ``default``.
    A ``default_factory`` leaves ``default`` as ``PydanticUndefined``, so it
    does not count.

    Args:
        option_name: The single parameter declaration handed to click.
        field_info: The pydantic field to convert.

    Returns:
        The constructed ``click.Option``.
    """
    param_decls = [option_name]
    help_text = field_info.description or ""
    if examples := field_info.examples:
        # pydantic allows examples of any type (ints, dicts, ...), but
        # str.join only accepts strings, so stringify each example first.
        help_text += f" [Examples: {', '.join(str(example) for example in examples)}]"
    option_kwargs = {
        "type": _get_type_from_field(field_info),
        "default": get_default_value_from_field(field_info),
        "required": field_info.is_required(),
        "help": str(help_text),
        "is_flag": is_boolean_flag(field_info),
        "show_default": field_info.default is not PydanticUndefined,
    }
    return click.Option(param_decls=param_decls, **option_kwargs)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def is_subclass(x: Any, y: Any) -> bool:
    """Return ``issubclass(x, y)``, or ``False`` if that call raises TypeError.

    ``issubclass`` raises TypeError when either argument is not a class (for
    example, an instance). This wrapper makes it safe to call on arbitrary
    annotation values.
    """
    try:
        return issubclass(x, y)
    except TypeError:
        return False
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def post_check(options: list[Option], name: str):
    """Ensure every generated option has a unique name.

    Args:
        options: The options to check; only their ``.name`` is read.
        name: Label (e.g. the model name) used to prefix the error message.

    Raises:
        ValueError: Listing each option name that occurs more than once, in
            first-seen order.
    """
    counts = Counter(option.name for option in options)
    repeated = [option_name for option_name, seen in counts.items() if seen > 1]
    if not repeated:
        return
    raise ValueError(
        "[{}] the following field name were reused, all must be unique: {}".format(
            name, ", ".join(repeated)
        )
    )
|
|
195
|
+
|
|
196
|
+
|
|
197
|
+
def is_secret(value: Any) -> bool:
    """Return True if ``value`` is a pydantic secret type.

    Parameterized generics such as ``Secret[int]`` expose ``__origin__`` and
    ``__args__``, and their origin is the class that gets checked. Plain
    classes such as ``SecretStr`` are checked directly. Both are tested
    against ``_SecretBase`` via ``is_subclass``, so non-class values yield
    False.

    NOTE(review): ``Optional[SecretStr]`` has origin ``typing.Union``, which is
    not a class, so it is reported as non-secret. Confirm this is intended.
    """
    candidate = value
    if hasattr(value, "__origin__") and hasattr(value, "__args__"):
        candidate = value.__origin__
    return is_subclass(candidate, _SecretBase)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
def options_from_base_model(model: Union[BaseModel, Type[BaseModel]]) -> list[Option]:
    """Generate click options for the fields of a pydantic model.

    Fields declared with ``init=False`` are skipped. A field whose raw
    annotation (per ``get_raw_type``) is itself a ``BaseModel`` is flattened:
    the nested model's options are generated recursively and appended in its
    place. Fields whose annotation is a secret type get a ``[sensitive]``
    prefix on their help text.

    Args:
        model: A pydantic model class or an instance of one.

    Returns:
        The generated options, in field declaration order.

    Raises:
        ValueError: From ``post_check``, if two options share the same name.
    """
    import copy

    # Accept instances as well as classes. Instances have no ``__name__``,
    # and ``model_fields`` is class-level metadata anyway.
    model_cls = model if isinstance(model, type) else type(model)
    options = []
    for field_name, field_info in model_cls.model_fields.items():
        if field_info.init is False:
            continue
        option_name = get_option_name(field_name=field_name, field_info=field_info)
        raw_annotation = get_raw_type(field_info.annotation)
        if is_subclass(raw_annotation, BaseModel):
            options.extend(options_from_base_model(model=raw_annotation))
            continue
        if is_secret(field_info.annotation):
            # ``model_fields`` holds the class's shared FieldInfo objects.
            # Editing them in place would stack another prefix on every call,
            # so annotate a shallow copy instead.
            field_info = copy.copy(field_info)
            description = field_info.description
            field_info.description = f"[sensitive] {description}" if description else "[sensitive]"
        options.append(get_option_from_field(option_name=option_name, field_info=field_info))

    post_check(options=options, name=model_cls.__name__)
    return options
|
|
File without changes
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class Entity(BaseModel):
    # An entity with a free-form ``type`` label and its ``entity`` text; both
    # are required strings. Documented with comments rather than a docstring
    # because a class docstring would become the JSON-schema description.
    type: str
    entity: str
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class EntityRelationship(BaseModel):
    # A named relationship between two entities, referenced by the ``from``
    # and ``to`` strings.
    #
    # ``from`` is a Python keyword, so the field is stored as ``from_`` and
    # mapped to the ``"from"`` key through an alias. ``populate_by_name`` also
    # accepts the field name. Without it, ``EntityRelationship(from_=...)``
    # fails, and so does re-validating ``model_dump()`` output, because
    # ``model_dump()`` emits ``from_`` unless ``by_alias=True``. Aliased input
    # and the JSON schema, which uses the alias, are unchanged.
    model_config = ConfigDict(populate_by_name=True)

    to: str
    from_: str = Field(..., alias="from")
    relationship: str
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class EntitiesData(BaseModel):
    # Container for extracted entities and the relationships between them.
    # Both lists default to empty via default_factory, so instances never
    # share a list object.
    items: list[Entity] = Field(default_factory=list)
    relationships: list[EntityRelationship] = Field(default_factory=list)
|
|
@@ -0,0 +1,116 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Any, Optional
|
|
4
|
+
from uuid import NAMESPACE_DNS, uuid5
|
|
5
|
+
|
|
6
|
+
from pydantic import BaseModel, Field, ValidationError, field_validator, model_validator
|
|
7
|
+
|
|
8
|
+
from unstructured_ingest.logger import logger
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class SourceIdentifiers(BaseModel):
    # Identifies a source document:
    # - ``filename``: the bare file name.
    # - ``fullpath``: its full path.
    # - ``rel_path``: an optional relative path, which takes precedence in
    #   ``relative_path`` when set.
    filename: str
    fullpath: str
    rel_path: Optional[str] = None

    @property
    def filename_stem(self) -> str:
        """Return ``filename`` without its final suffix (``Path.stem``)."""
        stem = Path(self.filename).stem
        return stem

    @property
    def relative_path(self) -> str:
        """Return ``rel_path`` when it is non-empty, otherwise ``fullpath``."""
        if self.rel_path:
            return self.rel_path
        return self.fullpath
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class FileDataSourceMetadata(BaseModel):
    # Optional metadata about a source document. Every field defaults to None.
    url: Optional[str] = None
    version: Optional[str] = None
    # Free-form mapping. Presumably lets a connector re-locate the record;
    # confirm against the connectors that populate it.
    record_locator: Optional[dict[str, Any]] = None
    # Dates are kept as plain strings; no parsing or format validation here.
    date_created: Optional[str] = None
    date_modified: Optional[str] = None
    date_processed: Optional[str] = None
    permissions_data: Optional[list[dict[str, Any]]] = None
    filesize_bytes: Optional[int] = None
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class FileData(BaseModel):
    # Per-document record: identifier, connector type, source identifiers and
    # metadata. Persisted as indented JSON by ``to_file`` and loaded back by
    # ``from_file``.
    identifier: str
    connector_type: str
    source_identifiers: SourceIdentifiers
    metadata: FileDataSourceMetadata = Field(default_factory=lambda: FileDataSourceMetadata())
    additional_metadata: dict[str, Any] = Field(default_factory=dict)
    reprocess: bool = False
    local_download_path: Optional[str] = None
    display_name: Optional[str] = None

    @classmethod
    def from_file(cls, path: str) -> "FileData":
        """Load and validate an instance of ``cls`` from a JSON file.

        Raises:
            ValueError: If ``path`` does not resolve to an existing regular file.
        """
        resolved = Path(path).resolve()
        if not (resolved.exists() and resolved.is_file()):
            raise ValueError(f"file path not valid: {resolved}")
        with resolved.open("rb") as handle:
            raw = json.load(handle)
        return cls.model_validate(raw)

    @classmethod
    def cast(cls, file_data: "FileData", **kwargs) -> "FileData":
        """Re-validate ``file_data``'s dumped fields as an instance of ``cls``.

        Extra keyword arguments are forwarded to ``model_validate``.
        """
        return cls.model_validate(file_data.model_dump(), **kwargs)

    def to_file(self, path: str) -> None:
        """Write this instance as indented JSON, creating parent directories."""
        target = Path(path).resolve()
        target.parent.mkdir(parents=True, exist_ok=True)
        with target.open("w") as handle:
            json.dump(self.model_dump(), handle, indent=2)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class BatchItem(BaseModel):
    # One member of a BatchFileData batch: a required identifier and an
    # optional version string. BatchFileData derives its own identifier from
    # these (identifier -> version) pairs.
    identifier: str
    version: Optional[str] = None
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
class BatchFileData(FileData):
    # A FileData representing a batch of items rather than a single document.
    # ``source_identifiers`` becomes optional. When no ``identifier`` is
    # supplied, one is derived deterministically from the batch items (see
    # ``populate_identifier``).
    identifier: str = Field(init=False)
    batch_items: list[BatchItem]
    source_identifiers: Optional[SourceIdentifiers] = None

    @field_validator("batch_items")
    @classmethod
    def check_batch_items(cls, v: list[BatchItem]) -> list[BatchItem]:
        """Reject empty or duplicate batches and return items sorted by identifier.

        Raises:
            ValueError: If the batch is empty or an identifier repeats.
        """
        if not v:
            raise ValueError("batch items cannot be empty")
        all_identifiers = [item.identifier for item in v]
        if len(all_identifiers) != len(set(all_identifiers)):
            raise ValueError(f"duplicate identifiers: {all_identifiers}")
        # Sorting makes the stored order independent of the input order.
        return sorted(v, key=lambda item: item.identifier)

    @model_validator(mode="before")
    @classmethod
    def populate_identifier(cls, data: Any) -> Any:
        """Derive ``identifier`` from the batch items when it is not provided.

        The identifier is a UUIDv5 (DNS namespace) over the key-sorted JSON of
        ``{item identifier: item version}``. It therefore depends on the items
        and their versions, not on their order.

        This hook runs before field validation, so ``batch_items`` may contain
        ``BatchItem`` instances or raw dicts (e.g. loaded from JSON). If the
        key is missing or an item is malformed, ``data`` is returned untouched
        so pydantic's normal validation reports a ``ValidationError``.
        Raising KeyError/AttributeError here would escape pydantic's error
        handling and the ``except ValidationError`` fallbacks of callers.

        The caller's dict is never mutated; a shallow copy is returned.
        """
        if not isinstance(data, dict) or "identifier" in data:
            return data
        batch_items = data.get("batch_items")
        if not isinstance(batch_items, (list, tuple)):
            return data
        versions: dict[str, Optional[str]] = {}
        for item in batch_items:
            if isinstance(item, BatchItem):
                item_id, item_version = item.identifier, item.version
            elif isinstance(item, dict):
                item_id, item_version = item.get("identifier"), item.get("version")
            else:
                return data
            # Anything that is not valid for BatchItem is left for field validation.
            if not isinstance(item_id, str):
                return data
            if item_version is not None and not isinstance(item_version, str):
                return data
            versions[item_id] = item_version
        identifier_data = json.dumps(versions, sort_keys=True)
        return {**data, "identifier": str(uuid5(NAMESPACE_DNS, identifier_data))}
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def file_data_from_file(path: str) -> FileData:
    """Load file data from a JSON file, preferring ``BatchFileData``.

    The file is first validated as ``BatchFileData``. If that raises a
    ``ValidationError``, the miss is logged at debug level and the file is
    read again as plain ``FileData``.
    """
    try:
        batch_file_data = BatchFileData.from_file(path=path)
    except ValidationError:
        logger.debug(f"{path} not detected as batch file data")
    else:
        return batch_file_data
    return FileData.from_file(path=path)
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def file_data_from_dict(data: dict) -> FileData:
    """Validate ``data`` as ``BatchFileData`` if possible, else as ``FileData``.

    Raises:
        pydantic.ValidationError: If ``data`` is not valid ``FileData`` either.
    """
    try:
        # Validate a shallow copy. BatchFileData's "before" validator can write
        # into the dict it receives (e.g. a derived "identifier"), and that
        # must not leak into the caller's dict or the FileData fallback below.
        return BatchFileData.model_validate(dict(data) if isinstance(data, dict) else data)
    except ValidationError:
        # Lazy %-style args: the (possibly large) dict is only stringified when
        # debug logging is enabled. The logged text is unchanged.
        logger.debug("%s not valid for batch file data", data)

    return FileData.model_validate(data)
|
|
File without changes
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from typing import TYPE_CHECKING
|
|
3
|
+
|
|
4
|
+
from pydantic import Field
|
|
5
|
+
|
|
6
|
+
from unstructured_ingest.embed.openai import (
|
|
7
|
+
AsyncOpenAIEmbeddingEncoder,
|
|
8
|
+
OpenAIEmbeddingConfig,
|
|
9
|
+
OpenAIEmbeddingEncoder,
|
|
10
|
+
)
|
|
11
|
+
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
12
|
+
from unstructured_ingest.utils.tls import ssl_context_with_optional_ca_override
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from openai import AsyncAzureOpenAI, AzureOpenAI
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class AzureOpenAIEmbeddingConfig(OpenAIEmbeddingConfig):
    # OpenAI embedding config that targets Azure OpenAI. It adds the Azure API
    # version and endpoint, and overrides the default model name (settable via
    # the "model_name" alias).
    api_version: str = Field(description="Azure API version", default="2024-06-01")
    azure_endpoint: str = Field(description="Azure endpoint")
    embedder_model_name: str = Field(
        default="text-embedding-ada-002", alias="model_name", description="Azure OpenAI model name"
    )

    @requires_dependencies(["openai"], extras="openai")
    def get_client(self) -> "AzureOpenAI":
        """Build a synchronous ``AzureOpenAI`` client from this config.

        Its httpx client verifies TLS with the SSL context returned by
        ``ssl_context_with_optional_ca_override()``.
        """
        from openai import AzureOpenAI, DefaultHttpxClient

        http_client = DefaultHttpxClient(verify=ssl_context_with_optional_ca_override())
        connection = {
            "api_key": self.api_key.get_secret_value(),
            "api_version": self.api_version,
            "azure_endpoint": self.azure_endpoint,
        }
        return AzureOpenAI(http_client=http_client, **connection)

    @requires_dependencies(["openai"], extras="openai")
    def get_async_client(self) -> "AsyncAzureOpenAI":
        """Build an ``AsyncAzureOpenAI`` client from this config.

        Its async httpx client verifies TLS with the SSL context returned by
        ``ssl_context_with_optional_ca_override()``.
        """
        from openai import AsyncAzureOpenAI, DefaultAsyncHttpxClient

        http_client = DefaultAsyncHttpxClient(verify=ssl_context_with_optional_ca_override())
        connection = {
            "api_key": self.api_key.get_secret_value(),
            "api_version": self.api_version,
            "azure_endpoint": self.azure_endpoint,
        }
        return AsyncAzureOpenAI(http_client=http_client, **connection)
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass
class AzureOpenAIEmbeddingEncoder(OpenAIEmbeddingEncoder):
    # Synchronous OpenAI embedding encoder bound to an Azure config; the
    # client itself is built by AzureOpenAIEmbeddingConfig.get_client().
    config: AzureOpenAIEmbeddingConfig

    def get_client(self) -> "AzureOpenAI":
        """Return the synchronous Azure client built by this encoder's config."""
        azure_config = self.config
        return azure_config.get_client()
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@dataclass
class AsyncAzureOpenAIEmbeddingEncoder(AsyncOpenAIEmbeddingEncoder):
    # Async OpenAI embedding encoder bound to an Azure config; the client is
    # built by AzureOpenAIEmbeddingConfig.get_async_client().
    config: AzureOpenAIEmbeddingConfig

    def get_client(self) -> "AsyncAzureOpenAI":
        """Return the async Azure client built by this encoder's config."""
        azure_config = self.config
        return azure_config.get_async_client()
|