unstructured-ingest 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this version of unstructured-ingest has been flagged as a potentially problematic release.
- test/integration/connectors/elasticsearch/__init__.py +0 -0
- test/integration/connectors/elasticsearch/conftest.py +34 -0
- test/integration/connectors/elasticsearch/test_elasticsearch.py +308 -0
- test/integration/connectors/elasticsearch/test_opensearch.py +302 -0
- test/integration/connectors/sql/test_postgres.py +10 -4
- test/integration/connectors/sql/test_singlestore.py +8 -4
- test/integration/connectors/sql/test_snowflake.py +10 -6
- test/integration/connectors/sql/test_sqlite.py +4 -4
- test/integration/connectors/test_astradb.py +50 -3
- test/integration/connectors/test_delta_table.py +46 -0
- test/integration/connectors/test_kafka.py +40 -6
- test/integration/connectors/test_lancedb.py +210 -0
- test/integration/connectors/test_milvus.py +141 -0
- test/integration/connectors/test_mongodb.py +332 -0
- test/integration/connectors/test_pinecone.py +53 -1
- test/integration/connectors/utils/docker.py +81 -15
- test/integration/connectors/utils/validation.py +10 -0
- test/integration/connectors/weaviate/__init__.py +0 -0
- test/integration/connectors/weaviate/conftest.py +15 -0
- test/integration/connectors/weaviate/test_local.py +131 -0
- unstructured_ingest/__version__.py +1 -1
- unstructured_ingest/pipeline/reformat/embedding.py +1 -1
- unstructured_ingest/utils/data_prep.py +9 -1
- unstructured_ingest/v2/processes/connectors/__init__.py +3 -16
- unstructured_ingest/v2/processes/connectors/astradb.py +2 -2
- unstructured_ingest/v2/processes/connectors/azure_ai_search.py +4 -0
- unstructured_ingest/v2/processes/connectors/delta_table.py +20 -4
- unstructured_ingest/v2/processes/connectors/elasticsearch/__init__.py +19 -0
- unstructured_ingest/v2/processes/connectors/{elasticsearch.py → elasticsearch/elasticsearch.py} +92 -46
- unstructured_ingest/v2/processes/connectors/{opensearch.py → elasticsearch/opensearch.py} +1 -1
- unstructured_ingest/v2/processes/connectors/google_drive.py +1 -1
- unstructured_ingest/v2/processes/connectors/kafka/kafka.py +6 -0
- unstructured_ingest/v2/processes/connectors/lancedb/__init__.py +17 -0
- unstructured_ingest/v2/processes/connectors/lancedb/aws.py +43 -0
- unstructured_ingest/v2/processes/connectors/lancedb/azure.py +43 -0
- unstructured_ingest/v2/processes/connectors/lancedb/gcp.py +44 -0
- unstructured_ingest/v2/processes/connectors/lancedb/lancedb.py +161 -0
- unstructured_ingest/v2/processes/connectors/lancedb/local.py +44 -0
- unstructured_ingest/v2/processes/connectors/milvus.py +72 -27
- unstructured_ingest/v2/processes/connectors/mongodb.py +122 -111
- unstructured_ingest/v2/processes/connectors/pinecone.py +24 -7
- unstructured_ingest/v2/processes/connectors/sql/sql.py +97 -26
- unstructured_ingest/v2/processes/connectors/weaviate/__init__.py +25 -0
- unstructured_ingest/v2/processes/connectors/weaviate/cloud.py +164 -0
- unstructured_ingest/v2/processes/connectors/weaviate/embedded.py +90 -0
- unstructured_ingest/v2/processes/connectors/weaviate/local.py +73 -0
- unstructured_ingest/v2/processes/connectors/weaviate/weaviate.py +299 -0
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/METADATA +19 -19
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/RECORD +54 -33
- unstructured_ingest/v2/processes/connectors/weaviate.py +0 -242
- /test/integration/connectors/{test_azure_cog_search.py → test_azure_ai_search.py} +0 -0
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/LICENSE.md +0 -0
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/WHEEL +0 -0
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/entry_points.txt +0 -0
- {unstructured_ingest-0.3.0.dist-info → unstructured_ingest-0.3.2.dist-info}/top_level.txt +0 -0
unstructured_ingest/v2/processes/connectors/lancedb/aws.py
@@ -0,0 +1,43 @@
+from dataclasses import dataclass
+
+from pydantic import Field, Secret
+
+from unstructured_ingest.v2.interfaces.connector import AccessConfig
+from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
+from unstructured_ingest.v2.processes.connectors.lancedb.lancedb import (
+    LanceDBRemoteConnectionConfig,
+    LanceDBUploader,
+    LanceDBUploaderConfig,
+    LanceDBUploadStager,
+    LanceDBUploadStagerConfig,
+)
+
+CONNECTOR_TYPE = "lancedb_aws"
+
+
+class LanceDBS3AccessConfig(AccessConfig):
+    aws_access_key_id: str = Field(description="The AWS access key ID to use.")
+    aws_secret_access_key: str = Field(description="The AWS secret access key to use.")
+
+
+class LanceDBS3ConnectionConfig(LanceDBRemoteConnectionConfig):
+    access_config: Secret[LanceDBS3AccessConfig]
+
+    def get_storage_options(self) -> dict:
+        return {**self.access_config.get_secret_value().model_dump(), "timeout": self.timeout}
+
+
+@dataclass
+class LanceDBS3Uploader(LanceDBUploader):
+    upload_config: LanceDBUploaderConfig
+    connection_config: LanceDBS3ConnectionConfig
+    connector_type: str = CONNECTOR_TYPE
+
+
+lancedb_aws_destination_entry = DestinationRegistryEntry(
+    connection_config=LanceDBS3ConnectionConfig,
+    uploader=LanceDBS3Uploader,
+    uploader_config=LanceDBUploaderConfig,
+    upload_stager_config=LanceDBUploadStagerConfig,
+    upload_stager=LanceDBUploadStager,
+)
unstructured_ingest/v2/processes/connectors/lancedb/azure.py
@@ -0,0 +1,43 @@
+from dataclasses import dataclass
+
+from pydantic import Field, Secret
+
+from unstructured_ingest.v2.interfaces.connector import AccessConfig
+from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
+from unstructured_ingest.v2.processes.connectors.lancedb.lancedb import (
+    LanceDBRemoteConnectionConfig,
+    LanceDBUploader,
+    LanceDBUploaderConfig,
+    LanceDBUploadStager,
+    LanceDBUploadStagerConfig,
+)
+
+CONNECTOR_TYPE = "lancedb_azure"
+
+
+class LanceDBAzureAccessConfig(AccessConfig):
+    azure_storage_account_name: str = Field(description="The name of the azure storage account.")
+    azure_storage_account_key: str = Field(description="The serialized azure service account key.")
+
+
+class LanceDBAzureConnectionConfig(LanceDBRemoteConnectionConfig):
+    access_config: Secret[LanceDBAzureAccessConfig]
+
+    def get_storage_options(self) -> dict:
+        return {**self.access_config.get_secret_value().model_dump(), "timeout": self.timeout}
+
+
+@dataclass
+class LanceDBAzureUploader(LanceDBUploader):
+    upload_config: LanceDBUploaderConfig
+    connection_config: LanceDBAzureConnectionConfig
+    connector_type: str = CONNECTOR_TYPE
+
+
+lancedb_azure_destination_entry = DestinationRegistryEntry(
+    connection_config=LanceDBAzureConnectionConfig,
+    uploader=LanceDBAzureUploader,
+    uploader_config=LanceDBUploaderConfig,
+    upload_stager_config=LanceDBUploadStagerConfig,
+    upload_stager=LanceDBUploadStager,
+)
unstructured_ingest/v2/processes/connectors/lancedb/gcp.py
@@ -0,0 +1,44 @@
+from dataclasses import dataclass
+
+from pydantic import Field, Secret
+
+from unstructured_ingest.v2.interfaces.connector import AccessConfig
+from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
+from unstructured_ingest.v2.processes.connectors.lancedb.lancedb import (
+    LanceDBRemoteConnectionConfig,
+    LanceDBUploader,
+    LanceDBUploaderConfig,
+    LanceDBUploadStager,
+    LanceDBUploadStagerConfig,
+)
+
+CONNECTOR_TYPE = "lancedb_gcs"
+
+
+class LanceDBGCSAccessConfig(AccessConfig):
+    google_service_account_key: str = Field(
+        description="The serialized google service account key."
+    )
+
+
+class LanceDBGCSConnectionConfig(LanceDBRemoteConnectionConfig):
+    access_config: Secret[LanceDBGCSAccessConfig]
+
+    def get_storage_options(self) -> dict:
+        return {**self.access_config.get_secret_value().model_dump(), "timeout": self.timeout}
+
+
+@dataclass
+class LanceDBGSPUploader(LanceDBUploader):
+    upload_config: LanceDBUploaderConfig
+    connection_config: LanceDBGCSConnectionConfig
+    connector_type: str = CONNECTOR_TYPE
+
+
+lancedb_gcp_destination_entry = DestinationRegistryEntry(
+    connection_config=LanceDBGCSConnectionConfig,
+    uploader=LanceDBGSPUploader,
+    uploader_config=LanceDBUploaderConfig,
+    upload_stager_config=LanceDBUploadStagerConfig,
+    upload_stager=LanceDBUploadStager,
+)
unstructured_ingest/v2/processes/connectors/lancedb/lancedb.py
@@ -0,0 +1,161 @@
+from __future__ import annotations
+
+import asyncio
+import json
+from abc import ABC, abstractmethod
+from contextlib import asynccontextmanager
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional
+
+import pandas as pd
+from pydantic import Field
+
+from unstructured_ingest.error import DestinationConnectionError
+from unstructured_ingest.logger import logger
+from unstructured_ingest.utils.data_prep import flatten_dict
+from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.interfaces.connector import ConnectionConfig
+from unstructured_ingest.v2.interfaces.file_data import FileData
+from unstructured_ingest.v2.interfaces.upload_stager import UploadStager, UploadStagerConfig
+from unstructured_ingest.v2.interfaces.uploader import Uploader, UploaderConfig
+
+CONNECTOR_TYPE = "lancedb"
+
+if TYPE_CHECKING:
+    from lancedb import AsyncConnection
+    from lancedb.table import AsyncTable
+
+
+class LanceDBConnectionConfig(ConnectionConfig, ABC):
+    uri: str = Field(description="The uri of the database.")
+
+    @abstractmethod
+    def get_storage_options(self) -> Optional[dict[str, str]]:
+        raise NotImplementedError
+
+    @asynccontextmanager
+    @requires_dependencies(["lancedb"], extras="lancedb")
+    @DestinationConnectionError.wrap
+    async def get_async_connection(self) -> AsyncGenerator["AsyncConnection", None]:
+        import lancedb
+
+        connection = await lancedb.connect_async(
+            self.uri,
+            storage_options=self.get_storage_options(),
+        )
+        try:
+            yield connection
+        finally:
+            connection.close()
+
+
+class LanceDBRemoteConnectionConfig(LanceDBConnectionConfig):
+    timeout: str = Field(
+        default="30s",
+        description=(
+            "Timeout for the entire request, from connection until the response body has finished"
+            "in a [0-9]+(ns|us|ms|[smhdwy]) format."
+        ),
+        pattern=r"[0-9]+(ns|us|ms|[smhdwy])",
+    )
+
+
+class LanceDBUploadStagerConfig(UploadStagerConfig):
+    pass
+
+
+@dataclass
+class LanceDBUploadStager(UploadStager):
+    upload_stager_config: LanceDBUploadStagerConfig = field(
+        default_factory=LanceDBUploadStagerConfig
+    )
+
+    def run(
+        self,
+        elements_filepath: Path,
+        file_data: FileData,
+        output_dir: Path,
+        output_filename: str,
+        **kwargs: Any,
+    ) -> Path:
+        with open(elements_filepath) as elements_file:
+            elements_contents: list[dict] = json.load(elements_file)
+
+        df = pd.DataFrame(
+            [
+                self._conform_element_contents(element_contents)
+                for element_contents in elements_contents
+            ]
+        )
+
+        output_path = (output_dir / output_filename).with_suffix(".feather")
+        df.to_feather(output_path)
+
+        return output_path
+
+    def _conform_element_contents(self, element: dict) -> dict:
+        return {
+            "vector": element.pop("embeddings", None),
+            **flatten_dict(element, separator="-"),
+        }
+
+
+class LanceDBUploaderConfig(UploaderConfig):
+    table_name: str = Field(description="The name of the table.")
+
+
+@dataclass
+class LanceDBUploader(Uploader):
+    upload_config: LanceDBUploaderConfig
+    connection_config: LanceDBConnectionConfig
+    connector_type: str = CONNECTOR_TYPE
+
+    @DestinationConnectionError.wrap
+    def precheck(self):
+        async def _precheck() -> None:
+            async with self.connection_config.get_async_connection() as conn:
+                table = await conn.open_table(self.upload_config.table_name)
+                table.close()
+
+        asyncio.run(_precheck())
+
+    @asynccontextmanager
+    async def get_table(self) -> AsyncGenerator["AsyncTable", None]:
+        async with self.connection_config.get_async_connection() as conn:
+            table = await conn.open_table(self.upload_config.table_name)
+            try:
+                yield table
+            finally:
+                table.close()
+
+    async def run_async(self, path, file_data, **kwargs):
+        df = pd.read_feather(path)
+        async with self.get_table() as table:
+            schema = await table.schema()
+            df = self._fit_to_schema(df, schema)
+            await table.add(data=df)
+
+    def _fit_to_schema(self, df: pd.DataFrame, schema) -> pd.DataFrame:
+        columns = set(df.columns)
+        schema_fields = set(schema.names)
+        columns_to_drop = columns - schema_fields
+        missing_columns = schema_fields - columns
+
+        if columns_to_drop:
+            logger.info(
+                "Following columns will be dropped to match the table's schema: "
+                f"{', '.join(columns_to_drop)}"
+            )
+        if missing_columns:
+            logger.info(
+                "Following null filled columns will be added to match the table's schema:"
+                f" {', '.join(missing_columns)} "
+            )
+
+        df = df.drop(columns=columns_to_drop)
+
+        for column in missing_columns:
+            df[column] = pd.Series()
+
+        return df
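For orientation, here is a minimal sketch of the transformation the new LanceDBUploadStager applies to each element before writing the Feather file. The element contents and the flattened key names are illustrative assumptions; the sketch assumes flatten_dict joins nested keys with the "-" separator passed above.

# Illustrative element as produced by partitioning (field values are assumptions).
element = {
    "text": "Hello world",
    "embeddings": [0.1, 0.2, 0.3],
    "metadata": {"filename": "report.pdf", "page_number": 1},
}

# The stager pops "embeddings" into a "vector" column and flattens the rest,
# so the staged row written to the .feather file looks roughly like this:
staged_row = {
    "vector": [0.1, 0.2, 0.3],
    "text": "Hello world",
    "metadata-filename": "report.pdf",
    "metadata-page_number": 1,
}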
unstructured_ingest/v2/processes/connectors/lancedb/local.py
@@ -0,0 +1,44 @@
+from dataclasses import dataclass
+
+from pydantic import Field, Secret
+
+from unstructured_ingest.v2.interfaces.connector import AccessConfig
+from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry
+from unstructured_ingest.v2.processes.connectors.lancedb.lancedb import (
+    LanceDBConnectionConfig,
+    LanceDBUploader,
+    LanceDBUploaderConfig,
+    LanceDBUploadStager,
+    LanceDBUploadStagerConfig,
+)
+
+CONNECTOR_TYPE = "lancedb_local"
+
+
+class LanceDBLocalAccessConfig(AccessConfig):
+    pass
+
+
+class LanceDBLocalConnectionConfig(LanceDBConnectionConfig):
+    access_config: Secret[LanceDBLocalAccessConfig] = Field(
+        default_factory=LanceDBLocalAccessConfig, validate_default=True
+    )
+
+    def get_storage_options(self) -> None:
+        return None
+
+
+@dataclass
+class LanceDBLocalUploader(LanceDBUploader):
+    upload_config: LanceDBUploaderConfig
+    connection_config: LanceDBLocalConnectionConfig
+    connector_type: str = CONNECTOR_TYPE
+
+
+lancedb_local_destination_entry = DestinationRegistryEntry(
+    connection_config=LanceDBLocalConnectionConfig,
+    uploader=LanceDBLocalUploader,
+    uploader_config=LanceDBUploaderConfig,
+    upload_stager_config=LanceDBUploadStagerConfig,
+    upload_stager=LanceDBUploadStager,
+)
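Note that the LanceDB uploaders only open an existing table (both precheck and get_table call open_table), so the target table must be created before ingest. Below is a hedged setup sketch using the lancedb and pyarrow libraries directly; the database path, table name, vector dimension, and column names are assumptions for illustration, not part of the connector itself.

# Hypothetical pre-creation of the table the lancedb_local destination expects.
import lancedb
import pyarrow as pa

schema = pa.schema(
    [
        pa.field("vector", pa.list_(pa.float32(), 384)),  # matches the "vector" column the stager emits
        pa.field("text", pa.string()),
        pa.field("metadata-filename", pa.string()),       # flattened metadata keys use "-" as separator
    ]
)

db = lancedb.connect("/tmp/lancedb")        # local uri, as used by the lancedb_local connector
db.create_table("elements", schema=schema)  # _fit_to_schema drops/adds columns to match this schema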
unstructured_ingest/v2/processes/connectors/milvus.py
@@ -1,7 +1,8 @@
 import json
+from contextlib import contextmanager
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Generator, Optional, Union
 
 import pandas as pd
 from dateutil import parser
@@ -10,6 +11,7 @@ from pydantic import Field, Secret
 from unstructured_ingest.error import WriteError
 from unstructured_ingest.utils.data_prep import flatten_dict
 from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.constants import RECORD_ID_LABEL
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
     ConnectionConfig,
@@ -90,24 +92,27 @@ class MilvusUploadStager(UploadStager):
             pass
         return parser.parse(date_string).timestamp()
 
-    def conform_dict(self, data: dict) ->
-
-
+    def conform_dict(self, data: dict, file_data: FileData) -> dict:
+        working_data = data.copy()
+        if self.upload_stager_config.flatten_metadata and (
+            metadata := working_data.pop("metadata", None)
+        ):
+            working_data.update(flatten_dict(metadata, keys_to_omit=["data_source_record_locator"]))
 
         # TODO: milvus sdk doesn't seem to support defaults via the schema yet,
         # remove once that gets updated
         defaults = {"is_continuation": False}
         for default in defaults:
-            if default not in
-
+            if default not in working_data:
+                working_data[default] = defaults[default]
 
         if self.upload_stager_config.fields_to_include:
-            data_keys = set(
+            data_keys = set(working_data.keys())
             for data_key in data_keys:
                 if data_key not in self.upload_stager_config.fields_to_include:
-
+                    working_data.pop(data_key)
             for field_include_key in self.upload_stager_config.fields_to_include:
-                if field_include_key not in
+                if field_include_key not in working_data:
                     raise KeyError(f"Field '{field_include_key}' is missing in data!")
 
         datetime_columns = [
@@ -120,11 +125,15 @@ class MilvusUploadStager(UploadStager):
         json_dumps_fields = ["languages", "data_source_permissions_data"]
 
         for datetime_column in datetime_columns:
-            if datetime_column in
-
+            if datetime_column in working_data:
+                working_data[datetime_column] = self.parse_date_string(
+                    working_data[datetime_column]
+                )
         for json_dumps_field in json_dumps_fields:
-            if json_dumps_field in
-
+            if json_dumps_field in working_data:
+                working_data[json_dumps_field] = json.dumps(working_data[json_dumps_field])
+        working_data[RECORD_ID_LABEL] = file_data.identifier
+        return working_data
 
     def run(
         self,
@@ -136,18 +145,27 @@ class MilvusUploadStager(UploadStager):
     ) -> Path:
         with open(elements_filepath) as elements_file:
             elements_contents: list[dict[str, Any]] = json.load(elements_file)
-
-            self.conform_dict(data=element)
-
-
+        new_content = [
+            self.conform_dict(data=element, file_data=file_data) for element in elements_contents
+        ]
+        output_filename_path = Path(output_filename)
+        if output_filename_path.suffix == ".json":
+            output_path = Path(output_dir) / output_filename_path
+        else:
+            output_path = Path(output_dir) / output_filename_path.with_suffix(".json")
         output_path.parent.mkdir(parents=True, exist_ok=True)
         with output_path.open("w") as output_file:
-            json.dump(
+            json.dump(new_content, output_file, indent=2)
         return output_path
 
 
 class MilvusUploaderConfig(UploaderConfig):
+    db_name: Optional[str] = Field(default=None, description="Milvus database name")
     collection_name: str = Field(description="Milvus collections to write to")
+    record_id_key: str = Field(
+        default=RECORD_ID_LABEL,
+        description="searchable key to find entries for the same record on previous runs",
+    )
 
 
 @dataclass
@@ -156,6 +174,16 @@ class MilvusUploader(Uploader):
     upload_config: MilvusUploaderConfig
     connector_type: str = CONNECTOR_TYPE
 
+    @contextmanager
+    def get_client(self) -> Generator["MilvusClient", None, None]:
+        client = self.connection_config.get_client()
+        if db_name := self.upload_config.db_name:
+            client.using_database(db_name=db_name)
+        try:
+            yield client
+        finally:
+            client.close()
+
     def upload(self, content: UploadContent) -> None:
         file_extension = content.path.suffix
         if file_extension == ".json":
@@ -165,23 +193,39 @@ class MilvusUploader(Uploader):
         else:
             raise ValueError(f"Unsupported file extension: {file_extension}")
 
+    def delete_by_record_id(self, file_data: FileData) -> None:
+        logger.info(
+            f"deleting any content with metadata {RECORD_ID_LABEL}={file_data.identifier} "
+            f"from milvus collection {self.upload_config.collection_name}"
+        )
+        with self.get_client() as client:
+            delete_filter = f'{self.upload_config.record_id_key} == "{file_data.identifier}"'
+            resp = client.delete(
+                collection_name=self.upload_config.collection_name, filter=delete_filter
+            )
+            logger.info(
+                "deleted {} records from milvus collection {}".format(
+                    resp["delete_count"], self.upload_config.collection_name
+                )
+            )
+
     @requires_dependencies(["pymilvus"], extras="milvus")
     def insert_results(self, data: Union[dict, list[dict]]):
         from pymilvus import MilvusException
 
-        logger.
+        logger.info(
             f"uploading {len(data)} entries to {self.connection_config.db_name} "
            f"db in collection {self.upload_config.collection_name}"
         )
-
+        with self.get_client() as client:
 
-
-
-
-
-
-
+            try:
+                res = client.insert(collection_name=self.upload_config.collection_name, data=data)
+            except MilvusException as milvus_exception:
+                raise WriteError("failed to upload records to milvus") from milvus_exception
+            if "err_count" in res and isinstance(res["err_count"], int) and res["err_count"] > 0:
+                err_count = res["err_count"]
+                raise WriteError(f"failed to upload {err_count} docs")
 
     def upload_csv(self, content: UploadContent) -> None:
         df = pd.read_csv(content.path)
@@ -194,6 +238,7 @@ class MilvusUploader(Uploader):
         self.insert_results(data=data)
 
     def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
+        self.delete_by_record_id(file_data=file_data)
         self.upload(content=UploadContent(path=path, file_data=file_data))
 
 
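The new delete_by_record_id plus the call at the top of run() gives the Milvus destination delete-then-insert semantics, so re-ingesting the same source document replaces its rows instead of duplicating them. A hedged sketch of the same pattern with a bare pymilvus MilvusClient follows; the uri, collection name, record-id field name, vector dimension, and row contents are illustrative assumptions.

# Hypothetical standalone version of the upsert-by-record-id pattern.
from pymilvus import MilvusClient

client = MilvusClient(uri="http://localhost:19530")
record_id = "doc-123"

# Remove rows written for this record on a previous run...
client.delete(collection_name="elements", filter=f'record_id == "{record_id}"')

# ...then insert the freshly staged rows tagged with the same record id.
rows = [
    {
        "record_id": record_id,
        "text": "Hello world",
        "is_continuation": False,
        "vector": [0.0] * 384,  # assumed embedding dimension
    }
]
client.insert(collection_name="elements", data=rows)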