unstructured-ingest 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of unstructured-ingest might be problematic. Click here for more details.
- test/integration/connectors/elasticsearch/__init__.py +0 -0
- test/integration/connectors/elasticsearch/conftest.py +34 -0
- test/integration/connectors/elasticsearch/test_elasticsearch.py +308 -0
- test/integration/connectors/elasticsearch/test_opensearch.py +302 -0
- test/integration/connectors/sql/test_postgres.py +10 -4
- test/integration/connectors/sql/test_singlestore.py +8 -4
- test/integration/connectors/sql/test_snowflake.py +10 -6
- test/integration/connectors/sql/test_sqlite.py +4 -4
- test/integration/connectors/test_astradb.py +50 -3
- test/integration/connectors/test_delta_table.py +46 -0
- test/integration/connectors/test_kafka.py +40 -6
- test/integration/connectors/test_lancedb.py +210 -0
- test/integration/connectors/test_milvus.py +141 -0
- test/integration/connectors/test_mongodb.py +332 -0
- test/integration/connectors/test_pinecone.py +53 -1
- test/integration/connectors/utils/docker.py +81 -15
- test/integration/connectors/utils/validation.py +10 -0
- test/integration/connectors/weaviate/__init__.py +0 -0
- test/integration/connectors/weaviate/conftest.py +15 -0
- test/integration/connectors/weaviate/test_local.py +131 -0
- unstructured_ingest/__version__.py +1 -1
- unstructured_ingest/pipeline/reformat/embedding.py +1 -1
- unstructured_ingest/utils/data_prep.py +9 -1
- unstructured_ingest/v2/processes/connectors/__init__.py +3 -16
- unstructured_ingest/v2/processes/connectors/astradb.py +2 -2
- unstructured_ingest/v2/processes/connectors/azure_ai_search.py +4 -0
- unstructured_ingest/v2/processes/connectors/delta_table.py +20 -4
- unstructured_ingest/v2/processes/connectors/elasticsearch/__init__.py +19 -0
- unstructured_ingest/v2/processes/connectors/{elasticsearch.py → elasticsearch/elasticsearch.py} +92 -46
- unstructured_ingest/v2/processes/connectors/{opensearch.py → elasticsearch/opensearch.py} +1 -1
- unstructured_ingest/v2/processes/connectors/google_drive.py +1 -1
- unstructured_ingest/v2/processes/connectors/kafka/kafka.py +6 -0
- unstructured_ingest/v2/processes/connectors/lancedb/__init__.py +17 -0
- unstructured_ingest/v2/processes/connectors/lancedb/aws.py +43 -0
- unstructured_ingest/v2/processes/connectors/lancedb/azure.py +43 -0
- unstructured_ingest/v2/processes/connectors/lancedb/gcp.py +44 -0
- unstructured_ingest/v2/processes/connectors/lancedb/lancedb.py +161 -0
- unstructured_ingest/v2/processes/connectors/lancedb/local.py +44 -0
- unstructured_ingest/v2/processes/connectors/milvus.py +72 -27
- unstructured_ingest/v2/processes/connectors/mongodb.py +122 -111
- unstructured_ingest/v2/processes/connectors/pinecone.py +24 -7
- unstructured_ingest/v2/processes/connectors/sql/sql.py +97 -26
- unstructured_ingest/v2/processes/connectors/weaviate/__init__.py +25 -0
- unstructured_ingest/v2/processes/connectors/weaviate/cloud.py +164 -0
- unstructured_ingest/v2/processes/connectors/weaviate/embedded.py +90 -0
- unstructured_ingest/v2/processes/connectors/weaviate/local.py +73 -0
- unstructured_ingest/v2/processes/connectors/weaviate/weaviate.py +299 -0
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/METADATA +19 -19
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/RECORD +54 -33
- unstructured_ingest/v2/processes/connectors/weaviate.py +0 -242
- /test/integration/connectors/{test_azure_cog_search.py → test_azure_ai_search.py} +0 -0
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/LICENSE.md +0 -0
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/WHEEL +0 -0
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/entry_points.txt +0 -0
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
from contextlib import contextmanager
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import TYPE_CHECKING, Any, Generator, Optional
|
|
4
|
+
|
|
5
|
+
from pydantic import Field, Secret
|
|
6
|
+
|
|
7
|
+
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
8
|
+
from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
|
|
9
|
+
from unstructured_ingest.v2.processes.connectors.weaviate.weaviate import (
|
|
10
|
+
WeaviateAccessConfig,
|
|
11
|
+
WeaviateConnectionConfig,
|
|
12
|
+
WeaviateUploader,
|
|
13
|
+
WeaviateUploaderConfig,
|
|
14
|
+
WeaviateUploadStager,
|
|
15
|
+
WeaviateUploadStagerConfig,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from weaviate.auth import AuthCredentials
|
|
20
|
+
from weaviate.client import WeaviateClient
|
|
21
|
+
|
|
22
|
+
# Connector type identifier; assigned to CloudWeaviateUploader.connector_type below.
CONNECTOR_TYPE = "weaviate-cloud"
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# Secret values for authenticating against Weaviate Cloud. Each field backs one
# authentication approach handled by the cloud connection config.
class CloudWeaviateAccessConfig(WeaviateAccessConfig):
    # Token passed to ``Auth.bearer_token`` by the connection config.
    access_token: Optional[str] = Field(
        description="Used to create the bearer token.",
        default=None,
    )
    # Passed to ``Auth.api_key``.
    api_key: Optional[str] = None
    # Passed to ``Auth.client_credentials``.
    client_secret: Optional[str] = None
    # Combined with the connection config's ``username`` for ``Auth.client_password``.
    password: Optional[str] = None
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# Connection settings for a Weaviate Cloud (WCD) cluster. Unless ``anonymous`` is
# set, exactly one authentication approach must be configured.
class CloudWeaviateConnectionConfig(WeaviateConnectionConfig):
    cluster_url: str = Field(
        description="The WCD cluster URL or hostname to connect to. "
        "Usually in the form: rAnD0mD1g1t5.something.weaviate.cloud"
    )
    username: Optional[str] = None
    anonymous: bool = Field(default=False, description="if set, all auth values will be ignored")
    refresh_token: Optional[str] = Field(
        default=None,
        description="Will tie this value to the bearer token. If not provided, "
        "the authentication will expire once the lifetime of the access token is up.",
    )
    access_config: Secret[CloudWeaviateAccessConfig]

    def model_post_init(self, __context: Any) -> None:
        """Validate that exactly one authentication approach is configured.

        Validation is skipped entirely when ``anonymous`` is set.

        Raises:
            ValueError: if no approach, or more than one approach, is configured.
        """
        if self.anonymous:
            return
        access_config = self.access_config.get_secret_value()
        # Truthiness (rather than ``is not None``) mirrors the checks done by the
        # ``get_*_auth`` helpers, so both places agree on what "configured" means.
        auths = {
            "api_key": bool(access_config.api_key),
            "bearer_token": bool(access_config.access_token),
            "client_secret": bool(access_config.client_secret),
            "client_password": bool(access_config.password) and bool(self.username),
        }
        # Count only the configured approaches: the mapping itself always holds
        # every known approach, so its length says nothing about what was set.
        existing_auths = [auth_method for auth_method, flag in auths.items() if flag]
        if not existing_auths:
            raise ValueError("No auth values provided and anonymous is False")
        if len(existing_auths) > 1:
            raise ValueError(
                "Multiple auth values provided, only one approach can be used: {}".format(
                    ", ".join(existing_auths)
                )
            )

    @requires_dependencies(["weaviate"], extras="weaviate")
    def get_api_key_auth(self) -> Optional["AuthCredentials"]:
        """Return API-key credentials, or ``None`` when no API key is set."""
        from weaviate.classes.init import Auth

        if api_key := self.access_config.get_secret_value().api_key:
            return Auth.api_key(api_key=api_key)
        return None

    @requires_dependencies(["weaviate"], extras="weaviate")
    def get_bearer_token_auth(self) -> Optional["AuthCredentials"]:
        """Return bearer-token credentials, or ``None`` when no access token is set."""
        from weaviate.classes.init import Auth

        if access_token := self.access_config.get_secret_value().access_token:
            return Auth.bearer_token(access_token=access_token, refresh_token=self.refresh_token)
        return None

    @requires_dependencies(["weaviate"], extras="weaviate")
    def get_client_secret_auth(self) -> Optional["AuthCredentials"]:
        """Return client-credentials auth, or ``None`` when no client secret is set."""
        from weaviate.classes.init import Auth

        if client_secret := self.access_config.get_secret_value().client_secret:
            return Auth.client_credentials(client_secret=client_secret)
        return None

    @requires_dependencies(["weaviate"], extras="weaviate")
    def get_client_password_auth(self) -> Optional["AuthCredentials"]:
        """Return username/password auth, or ``None`` unless both values are set."""
        from weaviate.classes.init import Auth

        if (username := self.username) and (
            password := self.access_config.get_secret_value().password
        ):
            return Auth.client_password(username=username, password=password)
        return None

    @requires_dependencies(["weaviate"], extras="weaviate")
    def get_auth(self) -> "AuthCredentials":
        """Return the single configured set of credentials.

        Raises:
            ValueError: if no credentials, or more than one set, are configured.
        """
        auths = [
            self.get_api_key_auth(),
            self.get_client_secret_auth(),
            self.get_bearer_token_auth(),
            self.get_client_password_auth(),
        ]
        auths = [auth for auth in auths if auth]
        if not auths:
            raise ValueError("No auth values provided and anonymous is False")
        if len(auths) > 1:
            raise ValueError("Multiple auth values provided, only one approach can be used")
        return auths[0]

    @contextmanager
    @requires_dependencies(["weaviate"], extras="weaviate")
    def get_client(self) -> Generator["WeaviateClient", None, None]:
        """Yield a client connected to the configured cloud cluster.

        No credentials are sent when ``anonymous`` is set.
        """
        from weaviate import connect_to_weaviate_cloud
        from weaviate.classes.init import AdditionalConfig

        auth_credentials = None if self.anonymous else self.get_auth()
        with connect_to_weaviate_cloud(
            cluster_url=self.cluster_url,
            auth_credentials=auth_credentials,
            additional_config=AdditionalConfig(timeout=self.get_timeout()),
        ) as weaviate_client:
            yield weaviate_client
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
# Cloud-specific stager config; adds nothing to WeaviateUploadStagerConfig.
class CloudWeaviateUploadStagerConfig(WeaviateUploadStagerConfig):
    pass
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
@dataclass
class CloudWeaviateUploadStager(WeaviateUploadStager):
    """Upload stager for the Weaviate Cloud destination."""

    # Default to the cloud-specific config class so the default value matches
    # the declared field type (the base config class was constructed before).
    upload_stager_config: CloudWeaviateUploadStagerConfig = field(
        default_factory=CloudWeaviateUploadStagerConfig
    )
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
# Cloud-specific uploader config; adds nothing to WeaviateUploaderConfig.
class CloudWeaviateUploaderConfig(WeaviateUploaderConfig):
    pass
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
@dataclass
class CloudWeaviateUploader(WeaviateUploader):
    """Uploader bound to the Weaviate Cloud connection and uploader configs."""

    # The config classes are used directly as factories; each is called with
    # no arguments when a default is needed.
    connection_config: CloudWeaviateConnectionConfig = field(
        default_factory=CloudWeaviateConnectionConfig
    )
    upload_config: CloudWeaviateUploaderConfig = field(
        default_factory=CloudWeaviateUploaderConfig
    )
    connector_type: str = CONNECTOR_TYPE
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
# Bundles the Weaviate Cloud config, stager and uploader classes into a single
# DestinationRegistryEntry.
weaviate_cloud_destination_entry = DestinationRegistryEntry(
    upload_stager_config=CloudWeaviateUploadStagerConfig,
    upload_stager=CloudWeaviateUploadStager,
    uploader_config=CloudWeaviateUploaderConfig,
    uploader=CloudWeaviateUploader,
    connection_config=CloudWeaviateConnectionConfig,
)
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
from contextlib import contextmanager
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import TYPE_CHECKING, Generator, Optional
|
|
4
|
+
|
|
5
|
+
from pydantic import Field, Secret
|
|
6
|
+
|
|
7
|
+
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
8
|
+
from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
|
|
9
|
+
from unstructured_ingest.v2.processes.connectors.weaviate.weaviate import (
|
|
10
|
+
WeaviateAccessConfig,
|
|
11
|
+
WeaviateConnectionConfig,
|
|
12
|
+
WeaviateUploader,
|
|
13
|
+
WeaviateUploaderConfig,
|
|
14
|
+
WeaviateUploadStager,
|
|
15
|
+
WeaviateUploadStagerConfig,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from weaviate.client import WeaviateClient
|
|
20
|
+
|
|
21
|
+
# Connector type identifier; assigned to EmbeddedWeaviateUploader.connector_type below.
CONNECTOR_TYPE = "weaviate-embedded"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Embedded-specific access config; adds nothing to WeaviateAccessConfig.
class EmbeddedWeaviateAccessConfig(WeaviateAccessConfig):
    pass
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# Connection settings for an embedded Weaviate instance opened through
# ``weaviate.connect_to_embedded``.
class EmbeddedWeaviateConnectionConfig(WeaviateConnectionConfig):
    hostname: str = Field(description="hostname", default="127.0.0.1")
    port: int = Field(description="http port", default=8079)
    grpc_port: int = Field(description="grpc port", default=50050)
    data_path: Optional[str] = Field(
        description=(
            "directory where the files making up the "
            "database are stored. If not provided, will "
            "default to underlying SDK implementation"
        ),
        default=None,
    )
    access_config: Secret[WeaviateAccessConfig] = Field(
        validate_default=True, default=WeaviateAccessConfig()
    )

    @contextmanager
    @requires_dependencies(["weaviate"], extras="weaviate")
    def get_client(self) -> Generator["WeaviateClient", None, None]:
        """Yield a client for an embedded Weaviate instance.

        Host, ports, data path and timeouts are taken from this config.
        """
        from weaviate import connect_to_embedded
        from weaviate.classes.init import AdditionalConfig

        connect_kwargs = {
            "hostname": self.hostname,
            "port": self.port,
            "grpc_port": self.grpc_port,
            "persistence_data_path": self.data_path,
            "additional_config": AdditionalConfig(timeout=self.get_timeout()),
        }
        with connect_to_embedded(**connect_kwargs) as weaviate_client:
            yield weaviate_client
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
# Embedded-specific stager config; adds nothing to WeaviateUploadStagerConfig.
class EmbeddedWeaviateUploadStagerConfig(WeaviateUploadStagerConfig):
    pass
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@dataclass
class EmbeddedWeaviateUploadStager(WeaviateUploadStager):
    """Upload stager for the embedded Weaviate destination."""

    # Default to the embedded-specific config class so the default value matches
    # the declared field type (the base config class was constructed before).
    upload_stager_config: EmbeddedWeaviateUploadStagerConfig = field(
        default_factory=EmbeddedWeaviateUploadStagerConfig
    )
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
# Embedded-specific uploader config; adds nothing to WeaviateUploaderConfig.
class EmbeddedWeaviateUploaderConfig(WeaviateUploaderConfig):
    pass
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@dataclass
class EmbeddedWeaviateUploader(WeaviateUploader):
    """Uploader bound to the embedded Weaviate connection and uploader configs."""

    # The config classes are used directly as factories; each is called with
    # no arguments when a default is needed.
    connection_config: EmbeddedWeaviateConnectionConfig = field(
        default_factory=EmbeddedWeaviateConnectionConfig
    )
    upload_config: EmbeddedWeaviateUploaderConfig = field(
        default_factory=EmbeddedWeaviateUploaderConfig
    )
    connector_type: str = CONNECTOR_TYPE
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
# Bundles the embedded Weaviate config, stager and uploader classes into a
# single DestinationRegistryEntry.
weaviate_embedded_destination_entry = DestinationRegistryEntry(
    upload_stager_config=EmbeddedWeaviateUploadStagerConfig,
    upload_stager=EmbeddedWeaviateUploadStager,
    uploader_config=EmbeddedWeaviateUploaderConfig,
    uploader=EmbeddedWeaviateUploader,
    connection_config=EmbeddedWeaviateConnectionConfig,
)
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
from contextlib import contextmanager
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
from typing import TYPE_CHECKING, Generator
|
|
4
|
+
|
|
5
|
+
from pydantic import Field, Secret
|
|
6
|
+
|
|
7
|
+
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
8
|
+
from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
|
|
9
|
+
from unstructured_ingest.v2.processes.connectors.weaviate.weaviate import (
|
|
10
|
+
WeaviateAccessConfig,
|
|
11
|
+
WeaviateConnectionConfig,
|
|
12
|
+
WeaviateUploader,
|
|
13
|
+
WeaviateUploaderConfig,
|
|
14
|
+
WeaviateUploadStager,
|
|
15
|
+
WeaviateUploadStagerConfig,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
if TYPE_CHECKING:
|
|
19
|
+
from weaviate.client import WeaviateClient
|
|
20
|
+
|
|
21
|
+
# Connector type identifier; assigned to LocalWeaviateUploader.connector_type below.
CONNECTOR_TYPE = "weaviate-local"
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
# Local-specific access config; adds nothing to WeaviateAccessConfig.
class LocalWeaviateAccessConfig(WeaviateAccessConfig):
    pass
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
# Connection settings for a Weaviate instance reached through
# ``weaviate.connect_to_local``; only the inherited timeouts are configurable.
class LocalWeaviateConnectionConfig(WeaviateConnectionConfig):
    access_config: Secret[WeaviateAccessConfig] = Field(
        validate_default=True, default=WeaviateAccessConfig()
    )

    @contextmanager
    @requires_dependencies(["weaviate"], extras="weaviate")
    def get_client(self) -> Generator["WeaviateClient", None, None]:
        """Yield a client from ``connect_to_local`` using this config's timeouts."""
        from weaviate import connect_to_local
        from weaviate.classes.init import AdditionalConfig

        extra_config = AdditionalConfig(timeout=self.get_timeout())
        with connect_to_local(additional_config=extra_config) as weaviate_client:
            yield weaviate_client
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# Local-specific stager config; adds nothing to WeaviateUploadStagerConfig.
class LocalWeaviateUploadStagerConfig(WeaviateUploadStagerConfig):
    pass
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
class LocalWeaviateUploadStager(WeaviateUploadStager):
    """Upload stager for the local Weaviate destination."""

    # Default to the local-specific config class so the default value matches
    # the declared field type (the base config class was constructed before).
    upload_stager_config: LocalWeaviateUploadStagerConfig = field(
        default_factory=LocalWeaviateUploadStagerConfig
    )
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
# Local-specific uploader config; adds nothing to WeaviateUploaderConfig.
class LocalWeaviateUploaderConfig(WeaviateUploaderConfig):
    pass
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
@dataclass
class LocalWeaviateUploader(WeaviateUploader):
    """Uploader bound to the local Weaviate connection and uploader configs."""

    # Declaration order is kept exactly as written: dataclass field order is
    # significant.
    upload_config: LocalWeaviateUploaderConfig
    connector_type: str = CONNECTOR_TYPE
    connection_config: LocalWeaviateConnectionConfig
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# Bundles the local Weaviate config, stager and uploader classes into a single
# DestinationRegistryEntry.
weaviate_local_destination_entry = DestinationRegistryEntry(
    upload_stager_config=LocalWeaviateUploadStagerConfig,
    upload_stager=LocalWeaviateUploadStager,
    uploader_config=LocalWeaviateUploaderConfig,
    uploader=LocalWeaviateUploader,
    connection_config=LocalWeaviateConnectionConfig,
)
|
|
@@ -0,0 +1,299 @@
|
|
|
1
|
+
import json
|
|
2
|
+
from abc import ABC, abstractmethod
|
|
3
|
+
from contextlib import contextmanager
|
|
4
|
+
from dataclasses import dataclass, field
|
|
5
|
+
from datetime import date, datetime
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import TYPE_CHECKING, Any, Generator, Optional
|
|
8
|
+
|
|
9
|
+
from dateutil import parser
|
|
10
|
+
from pydantic import Field, Secret
|
|
11
|
+
|
|
12
|
+
from unstructured_ingest.error import DestinationConnectionError, WriteError
|
|
13
|
+
from unstructured_ingest.utils.dep_check import requires_dependencies
|
|
14
|
+
from unstructured_ingest.v2.constants import RECORD_ID_LABEL
|
|
15
|
+
from unstructured_ingest.v2.interfaces import (
|
|
16
|
+
AccessConfig,
|
|
17
|
+
ConnectionConfig,
|
|
18
|
+
FileData,
|
|
19
|
+
Uploader,
|
|
20
|
+
UploaderConfig,
|
|
21
|
+
UploadStager,
|
|
22
|
+
UploadStagerConfig,
|
|
23
|
+
)
|
|
24
|
+
from unstructured_ingest.v2.logger import logger
|
|
25
|
+
from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
|
|
26
|
+
|
|
27
|
+
if TYPE_CHECKING:
|
|
28
|
+
from weaviate.classes.init import Timeout
|
|
29
|
+
from weaviate.client import WeaviateClient
|
|
30
|
+
from weaviate.collections.batch.client import BatchClient
|
|
31
|
+
|
|
32
|
+
# Connector type identifier for the base Weaviate connector module.
CONNECTOR_TYPE = "weaviate"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Base access config shared by the Weaviate connectors; declares no fields.
class WeaviateAccessConfig(AccessConfig, ABC):
    pass
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
# Base connection config for the Weaviate connectors: holds the timeout values
# and leaves client creation to subclasses.
class WeaviateConnectionConfig(ConnectionConfig, ABC):
    init_timeout: int = Field(description="Timeout for initialization checks", default=2, ge=0)
    insert_timeout: int = Field(description="Timeout for insert operations", default=90, ge=0)
    query_timeout: int = Field(description="Timeout for query operations", default=30, ge=0)
    access_config: Secret[WeaviateAccessConfig] = Field(
        validate_default=True, default=WeaviateAccessConfig()
    )

    @requires_dependencies(["weaviate"], extras="weaviate")
    def get_timeout(self) -> "Timeout":
        """Build a weaviate ``Timeout`` from the three configured timeout values."""
        from weaviate.classes.init import Timeout

        timeout_values = {
            "init": self.init_timeout,
            "query": self.query_timeout,
            "insert": self.insert_timeout,
        }
        return Timeout(**timeout_values)

    @abstractmethod
    @contextmanager
    def get_client(self) -> Generator["WeaviateClient", None, None]:
        """Context manager yielding a Weaviate client; implemented by subclasses."""
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
# Base stager config for the Weaviate connectors; declares no fields.
class WeaviateUploadStagerConfig(UploadStagerConfig):
    pass
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@dataclass
class WeaviateUploadStager(UploadStager):
    """Rewrite element dictionaries so they conform to the Weaviate schema."""

    upload_stager_config: WeaviateUploadStagerConfig = field(
        default_factory=WeaviateUploadStagerConfig
    )

    @staticmethod
    def parse_date_string(date_string: str) -> date:
        """Parse a date given either as a numeric timestamp or as a date string.

        The value is first tried as a float timestamp; if that fails for any
        reason it is handed to ``dateutil.parser.parse``.
        """
        try:
            timestamp = float(date_string)
            return datetime.fromtimestamp(timestamp)
        except Exception as e:
            logger.debug(f"date {date_string} string not a timestamp: {e}")
        return parser.parse(date_string)

    @classmethod
    def _format_date(cls, date_string: str) -> str:
        """Parse ``date_string`` and render it in the format written to Weaviate."""
        return cls.parse_date_string(date_string).strftime("%Y-%m-%dT%H:%M:%S.%fZ")

    @classmethod
    def conform_dict(cls, data: dict, file_data: FileData) -> dict:
        """Return a copy of ``data`` updated to conform to the Weaviate schema.

        Nested dict/list metadata values are serialized to JSON strings, date
        fields are reformatted, ``version`` and ``page_number`` are cast to
        strings, and the record id of ``file_data`` is attached. The input
        ``data`` is left untouched.
        """
        from copy import deepcopy

        # Deep copy: the nested metadata dicts are rewritten below, and a shallow
        # copy would leak those edits back into the caller's ``data``.
        working_data = deepcopy(data)
        # ``or {}`` also covers keys that are present but set to None. The
        # placeholder dicts are never written to, because every write below is
        # guarded by a truthy read from the same dict.
        metadata = working_data.get("metadata") or {}
        data_source = metadata.get("data_source") or {}
        coordinates = metadata.get("coordinates") or {}

        # Dict as string formatting
        if record_locator := data_source.get("record_locator"):
            # Explicit casting otherwise fails schema type checking
            data_source["record_locator"] = json.dumps(record_locator)

        # Array of items as string formatting
        if points := coordinates.get("points"):
            coordinates["points"] = json.dumps(points)

        if links := metadata.get("links"):
            metadata["links"] = json.dumps(links)

        if permissions_data := data_source.get("permissions_data"):
            data_source["permissions_data"] = json.dumps(permissions_data)

        # Datetime formatting
        for date_key in ("date_created", "date_modified", "date_processed"):
            if date_value := data_source.get(date_key):
                data_source[date_key] = cls._format_date(date_value)

        if last_modified := metadata.get("last_modified"):
            metadata["last_modified"] = cls._format_date(last_modified)

        # String casting
        if version := data_source.get("version"):
            data_source["version"] = str(version)

        if page_number := metadata.get("page_number"):
            metadata["page_number"] = str(page_number)

        if regex_metadata := metadata.get("regex_metadata"):
            metadata["regex_metadata"] = json.dumps(regex_metadata)

        working_data[RECORD_ID_LABEL] = file_data.identifier
        return working_data

    def run(
        self,
        elements_filepath: Path,
        file_data: FileData,
        output_dir: Path,
        output_filename: str,
        **kwargs: Any,
    ) -> Path:
        """Conform every element in ``elements_filepath`` and write the result.

        Returns:
            Path of the written ``<output_filename>.json`` file inside
            ``output_dir`` (parent directories are created as needed).
        """
        with open(elements_filepath) as elements_file:
            elements_contents = json.load(elements_file)
        updated_elements = [
            self.conform_dict(data=element, file_data=file_data) for element in elements_contents
        ]
        output_path = Path(output_dir) / Path(f"{output_filename}.json")
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, "w") as output_file:
            json.dump(updated_elements, output_file, indent=2)
        return output_path
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
# Upload settings for Weaviate: target collection, the record-id property, and
# exactly one batching mode (fixed size, rate limited, or dynamic).
class WeaviateUploaderConfig(UploaderConfig):
    collection: str = Field(description="The name of the collection this object belongs to")
    batch_size: Optional[int] = Field(description="Number of records per batch", default=None)
    requests_per_minute: Optional[int] = Field(description="Rate limit for upload", default=None)
    dynamic_batch: bool = Field(description="Whether to use dynamic batch", default=True)
    record_id_key: str = Field(
        description="searchable key to find entries for the same record on previous runs",
        default=RECORD_ID_LABEL,
    )

    def model_post_init(self, __context: Any) -> None:
        """Ensure exactly one batch mode is enabled and log which one.

        Raises:
            ValueError: if no batch mode, or more than one, is enabled.
        """
        candidate_modes = (
            ("fixed_size", self.batch_size is not None),
            ("rate_limited", self.requests_per_minute is not None),
            ("dynamic", self.dynamic_batch),
        )
        enabled_batch_modes = [mode_name for mode_name, active in candidate_modes if active]
        mode_count = len(enabled_batch_modes)
        if mode_count == 0:
            raise ValueError("No batch mode enabled")
        if mode_count > 1:
            raise ValueError(
                "Multiple batch modes enabled, only one mode can be used: {}".format(
                    ", ".join(enabled_batch_modes)
                )
            )
        logger.info(f"Uploader config instantiated with {enabled_batch_modes[0]} batch mode")

    @contextmanager
    def get_batch_client(self, client: "WeaviateClient") -> Generator["BatchClient", None, None]:
        """Yield a batch client for the configured batch mode.

        Modes are checked in the order dynamic, fixed size, rate limited.

        Raises:
            ValueError: if none of the batch modes is enabled.
        """
        if self.dynamic_batch:
            batch_context = client.batch.dynamic()
        elif self.batch_size:
            batch_context = client.batch.fixed_size(batch_size=self.batch_size)
        elif self.requests_per_minute:
            batch_context = client.batch.rate_limit(
                requests_per_minute=self.requests_per_minute
            )
        else:
            raise ValueError("No batch mode enabled")
        with batch_context as batch_client:
            yield batch_client
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
@dataclass
class WeaviateUploader(Uploader, ABC):
    """Upload staged elements into a Weaviate collection.

    Before writing, any objects previously stored for the same record id are
    removed from the collection.
    """

    upload_config: WeaviateUploaderConfig
    connection_config: WeaviateConnectionConfig

    def precheck(self) -> None:
        """Validate the connection by opening and closing a client.

        Raises:
            DestinationConnectionError: if the client cannot be created.
        """
        try:
            # get_client() is a context manager: nothing is connected (or even
            # imported) until it is entered, so calling it alone checks nothing.
            with self.connection_config.get_client():
                pass
        except Exception as e:
            logger.error(f"Failed to validate connection {e}", exc_info=True)
            raise DestinationConnectionError(f"failed to validate connection: {e}") from e

    def check_for_errors(self, client: "WeaviateClient") -> None:
        """Log every failed batch object and raise if there were any.

        Raises:
            WriteError: if ``client.batch.failed_objects`` is non-empty.
        """
        failed_uploads = client.batch.failed_objects
        if failed_uploads:
            for failure in failed_uploads:
                logger.error(
                    f"Failed to upload object with id {failure.original_uuid}: {failure.message}"
                )
            raise WriteError("Failed to upload to weaviate")

    @requires_dependencies(["weaviate"], extras="weaviate")
    def delete_by_record_id(self, client: "WeaviateClient", file_data: FileData) -> None:
        """Delete all objects whose record-id property equals ``file_data.identifier``.

        Raises:
            WriteError: if any deletion is reported as failed.
        """
        from weaviate.classes.query import Filter

        record_id = file_data.identifier
        collection = client.collections.get(self.upload_config.collection)
        delete_filter = Filter.by_property(name=self.upload_config.record_id_key).equal(
            val=record_id
        )
        # There is a configurable maximum limit (QUERY_MAXIMUM_RESULTS) on the number of
        # objects that can be deleted in a single query (default 10,000). To delete
        # more objects than the limit, re-run the query until nothing is deleted.
        while True:
            resp = collection.data.delete_many(where=delete_filter)
            if resp.failed:
                raise WriteError(
                    f"failed to delete records in collection "
                    f"{self.upload_config.collection} with record "
                    f"id property {record_id}"
                )
            # No failures at this point; stop once a pass deletes nothing.
            if not resp.successful:
                break

    def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
        """Replace the record's objects in the collection with the staged elements.

        Loads the staged JSON at ``path``, deletes existing objects for the
        record, batch-inserts the new ones, and raises if any insert failed.
        """
        with path.open("r") as file:
            elements_dict = json.load(file)
        # Log the target collection; the access config is a secret and carries
        # no information about where the objects are written.
        logger.info(
            f"writing {len(elements_dict)} objects to destination "
            f"collection {self.upload_config.collection}"
        )

        with self.connection_config.get_client() as weaviate_client:
            self.delete_by_record_id(client=weaviate_client, file_data=file_data)
            with self.upload_config.get_batch_client(client=weaviate_client) as batch_client:
                for element in elements_dict:
                    # The embedding is sent as the object's vector, not as a property.
                    vector = element.pop("embeddings", None)
                    batch_client.add_object(
                        collection=self.upload_config.collection,
                        properties=element,
                        vector=vector,
                    )
            # Checked after the batch context exits, as in the original flow.
            self.check_for_errors(client=weaviate_client)
|
|
291
|
+
|
|
292
|
+
|
|
293
|
+
# Bundles the base Weaviate config, stager and uploader classes into a single
# DestinationRegistryEntry.
weaviate_destination_entry = DestinationRegistryEntry(
    upload_stager_config=WeaviateUploadStagerConfig,
    upload_stager=WeaviateUploadStager,
    uploader_config=WeaviateUploaderConfig,
    uploader=WeaviateUploader,
    connection_config=WeaviateConnectionConfig,
)
|