unstructured-ingest 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of unstructured-ingest might be problematic.
- test/integration/connectors/sql/test_singlestore.py +156 -0
- test/integration/connectors/test_s3.py +1 -1
- test/integration/connectors/utils/docker_compose.py +23 -8
- unstructured_ingest/__version__.py +1 -1
- unstructured_ingest/v2/interfaces/file_data.py +1 -0
- unstructured_ingest/v2/processes/connectors/__init__.py +3 -6
- unstructured_ingest/v2/processes/connectors/astradb.py +278 -55
- unstructured_ingest/v2/processes/connectors/databricks/volumes.py +3 -1
- unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py +1 -0
- unstructured_ingest/v2/processes/connectors/sql/__init__.py +5 -0
- unstructured_ingest/v2/processes/connectors/sql/postgres.py +1 -20
- unstructured_ingest/v2/processes/connectors/sql/singlestore.py +168 -0
- unstructured_ingest/v2/processes/connectors/sql/snowflake.py +2 -4
- unstructured_ingest/v2/processes/connectors/sql/sql.py +13 -2
- unstructured_ingest/v2/unstructured_api.py +1 -1
- {unstructured_ingest-0.2.0.dist-info → unstructured_ingest-0.2.1.dist-info}/METADATA +17 -17
- {unstructured_ingest-0.2.0.dist-info → unstructured_ingest-0.2.1.dist-info}/RECORD +21 -20
- unstructured_ingest/v2/processes/connectors/singlestore.py +0 -156
- {unstructured_ingest-0.2.0.dist-info → unstructured_ingest-0.2.1.dist-info}/LICENSE.md +0 -0
- {unstructured_ingest-0.2.0.dist-info → unstructured_ingest-0.2.1.dist-info}/WHEEL +0 -0
- {unstructured_ingest-0.2.0.dist-info → unstructured_ingest-0.2.1.dist-info}/entry_points.txt +0 -0
- {unstructured_ingest-0.2.0.dist-info → unstructured_ingest-0.2.1.dist-info}/top_level.txt +0 -0
test/integration/connectors/sql/test_singlestore.py
@@ -0,0 +1,156 @@
+import tempfile
+from contextlib import contextmanager
+from pathlib import Path
+
+import pandas as pd
+import pytest
+import singlestoredb as s2
+
+from test.integration.connectors.utils.constants import DESTINATION_TAG, SOURCE_TAG, env_setup_path
+from test.integration.connectors.utils.docker_compose import docker_compose_context
+from test.integration.connectors.utils.validation import (
+    ValidationConfigs,
+    source_connector_validation,
+)
+from unstructured_ingest.v2.interfaces import FileData
+from unstructured_ingest.v2.processes.connectors.sql.singlestore import (
+    CONNECTOR_TYPE,
+    SingleStoreAccessConfig,
+    SingleStoreConnectionConfig,
+    SingleStoreDownloader,
+    SingleStoreDownloaderConfig,
+    SingleStoreIndexer,
+    SingleStoreIndexerConfig,
+    SingleStoreUploader,
+    SingleStoreUploaderConfig,
+    SingleStoreUploadStager,
+)
+
+SEED_DATA_ROWS = 20
+
+
+@contextmanager
+def singlestore_download_setup(connect_params: dict) -> None:
+    with docker_compose_context(
+        docker_compose_path=env_setup_path / "sql" / "singlestore" / "source"
+    ):
+        with s2.connect(**connect_params) as connection:
+            with connection.cursor() as cursor:
+                for i in range(SEED_DATA_ROWS):
+                    sql_statment = f"INSERT INTO cars (brand, price) VALUES " f"('brand_{i}', {i})"
+                    cursor.execute(sql_statment)
+            connection.commit()
+        yield
+
+
+@pytest.mark.asyncio
+@pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG, "sql")
+async def test_singlestore_source():
+    connect_params = {
+        "host": "localhost",
+        "port": 3306,
+        "database": "ingest_test",
+        "user": "root",
+        "password": "password",
+    }
+    with singlestore_download_setup(connect_params=connect_params):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            connection_config = SingleStoreConnectionConfig(
+                host=connect_params["host"],
+                port=connect_params["port"],
+                database=connect_params["database"],
+                user=connect_params["user"],
+                access_config=SingleStoreAccessConfig(password=connect_params["password"]),
+            )
+            indexer = SingleStoreIndexer(
+                connection_config=connection_config,
+                index_config=SingleStoreIndexerConfig(
+                    table_name="cars", id_column="car_id", batch_size=5
+                ),
+            )
+            downloader = SingleStoreDownloader(
+                connection_config=connection_config,
+                download_config=SingleStoreDownloaderConfig(
+                    fields=["car_id", "brand"], download_dir=Path(tmpdir)
+                ),
+            )
+            await source_connector_validation(
+                indexer=indexer,
+                downloader=downloader,
+                configs=ValidationConfigs(
+                    test_id="singlestore",
+                    expected_num_files=SEED_DATA_ROWS,
+                    expected_number_indexed_file_data=4,
+                    validate_downloaded_files=True,
+                ),
+            )
+
+
+def validate_destination(
+    connect_params: dict,
+    expected_num_elements: int,
+):
+    with s2.connect(**connect_params) as connection:
+        with connection.cursor() as cursor:
+            query = "select count(*) from elements;"
+            cursor.execute(query)
+            count = cursor.fetchone()[0]
+            assert (
+                count == expected_num_elements
+            ), f"dest check failed: got {count}, expected {expected_num_elements}"
+
+
+@pytest.mark.asyncio
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG, "sql")
+async def test_singlestore_destination(upload_file: Path):
+    mock_file_data = FileData(identifier="mock file data", connector_type=CONNECTOR_TYPE)
+    with docker_compose_context(
+        docker_compose_path=env_setup_path / "sql" / "singlestore" / "destination"
+    ):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            stager = SingleStoreUploadStager()
+            stager_params = {
+                "elements_filepath": upload_file,
+                "file_data": mock_file_data,
+                "output_dir": Path(tmpdir),
+                "output_filename": "test_db",
+            }
+            if stager.is_async():
+                staged_path = await stager.run_async(**stager_params)
+            else:
+                staged_path = stager.run(**stager_params)
+
+            # The stager should append the `.json` suffix to the output filename passed in.
+            assert staged_path.name == "test_db.json"
+
+            connect_params = {
+                "host": "localhost",
+                "port": 3306,
+                "database": "ingest_test",
+                "user": "root",
+                "password": "password",
+            }
+
+            uploader = SingleStoreUploader(
+                connection_config=SingleStoreConnectionConfig(
+                    host=connect_params["host"],
+                    port=connect_params["port"],
+                    database=connect_params["database"],
+                    user=connect_params["user"],
+                    access_config=SingleStoreAccessConfig(password=connect_params["password"]),
+                ),
+                upload_config=SingleStoreUploaderConfig(
+                    table_name="elements",
+                ),
+            )
+            if uploader.is_async():
+                await uploader.run_async(path=staged_path, file_data=mock_file_data)
+            else:
+                uploader.run(path=staged_path, file_data=mock_file_data)
+
+            staged_df = pd.read_json(staged_path, orient="records", lines=True)
+            expected_num_elements = len(staged_df)
+            validate_destination(
+                connect_params=connect_params,
+                expected_num_elements=expected_num_elements,
+            )
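The source test above assumes the compose environment already provides a `cars` table (with `car_id`, `brand`, and `price` columns) in the `ingest_test` database; the init script that creates it is not part of this diff. A minimal sketch of an equivalent setup, with assumed column types:

    import singlestoredb as s2

    # Hypothetical one-off setup mirroring what the compose init script presumably
    # does before the test's INSERT statements run; column types are guesses.
    connect_params = {
        "host": "localhost",
        "port": 3306,
        "database": "ingest_test",
        "user": "root",
        "password": "password",
    }

    with s2.connect(**connect_params) as connection:
        with connection.cursor() as cursor:
            # car_id is the id_column that SingleStoreIndexerConfig batches on.
            cursor.execute(
                "CREATE TABLE IF NOT EXISTS cars ("
                "car_id INT AUTO_INCREMENT PRIMARY KEY, "
                "brand TEXT, "
                "price INT)"
            )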
test/integration/connectors/test_s3.py
@@ -85,7 +85,7 @@ async def test_s3_source_no_access(anon_connection_config: S3ConnectionConfig):
 async def test_s3_minio_source(anon_connection_config: S3ConnectionConfig):
     anon_connection_config.endpoint_url = "http://localhost:9000"
     indexer_config = S3IndexerConfig(remote_url="s3://utic-dev-tech-fixtures/")
-    with docker_compose_context(docker_compose_path=env_setup_path / "minio"):
+    with docker_compose_context(docker_compose_path=env_setup_path / "minio" / "source"):
         with tempfile.TemporaryDirectory() as tempdir:
             tempdir_path = Path(tempdir)
             download_config = S3DownloaderConfig(download_dir=tempdir_path)
test/integration/connectors/utils/docker_compose.py
@@ -3,6 +3,23 @@ from contextlib import contextmanager
 from pathlib import Path
 
 
+def docker_compose_down(docker_compose_path: Path):
+    cmd = f"docker compose -f {docker_compose_path.resolve()} down --remove-orphans -v --rmi all"
+    print(f"Running command: {cmd}")
+    final_resp = subprocess.run(
+        cmd,
+        shell=True,
+        capture_output=True,
+    )
+    if final_resp.returncode != 0:
+        print("STDOUT: {}".format(final_resp.stdout.decode("utf-8")))
+        print("STDERR: {}".format(final_resp.stderr.decode("utf-8")))
+
+
+def run_cleanup(docker_compose_path: Path):
+    docker_compose_down(docker_compose_path=docker_compose_path)
+
+
 @contextmanager
 def docker_compose_context(docker_compose_path: Path):
     # Dynamically run a specific docker compose file and make sure it gets cleanup by
@@ -30,15 +47,13 @@ def docker_compose_context(docker_compose_path: Path):
         if resp:
             print("STDOUT: {}".format(resp.stdout.decode("utf-8")))
             print("STDERR: {}".format(resp.stderr.decode("utf-8")))
-
-
-        cmd = f"docker compose -f {docker_compose_path.resolve()} down --remove-orphans -v"
-        print(f"Running command: {cmd}")
-        final_resp = subprocess.run(
+        cmd = f"docker compose -f {docker_compose_path.resolve()} logs"
+        logs = subprocess.run(
             cmd,
             shell=True,
             capture_output=True,
         )
-
-
-
+        print("DOCKER LOGS: {}".format(logs.stdout.decode("utf-8")))
+        raise e
+    finally:
+        run_cleanup(docker_compose_path=docker_compose_path)
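Only fragments of `docker_compose_context` are visible across these two hunks. A hedged sketch of what the updated context manager plausibly looks like after this change, reconstructed from the lines shown (the bring-up command and the exact exception handling are assumptions):

    import subprocess
    from contextlib import contextmanager
    from pathlib import Path


    def docker_compose_down(docker_compose_path: Path):
        # Teardown helper added in the first hunk above.
        cmd = f"docker compose -f {docker_compose_path.resolve()} down --remove-orphans -v --rmi all"
        resp = subprocess.run(cmd, shell=True, capture_output=True)
        if resp.returncode != 0:
            print("STDOUT: {}".format(resp.stdout.decode("utf-8")))
            print("STDERR: {}".format(resp.stderr.decode("utf-8")))


    @contextmanager
    def docker_compose_context(docker_compose_path: Path):
        # The bring-up step is assumed; only the log dump and teardown are shown in the diff.
        up_cmd = f"docker compose -f {docker_compose_path.resolve()} up -d --wait"
        try:
            subprocess.run(up_cmd, shell=True, capture_output=True, check=True)
            yield
        except Exception as e:
            # New behavior: dump the compose logs before re-raising the original error.
            logs_cmd = f"docker compose -f {docker_compose_path.resolve()} logs"
            logs = subprocess.run(logs_cmd, shell=True, capture_output=True)
            print("DOCKER LOGS: {}".format(logs.stdout.decode("utf-8")))
            raise e
        finally:
            # Cleanup now always runs, so containers are removed even when the test body fails.
            docker_compose_down(docker_compose_path=docker_compose_path)

The net effect is that the `docker compose down` logic moved into reusable helpers (`docker_compose_down`, `run_cleanup`) and the failure path now prints container logs instead of silently tearing the environment down.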
unstructured_ingest/__version__.py
@@ -1 +1 @@
-__version__ = "0.2.0"  # pragma: no cover
+__version__ = "0.2.1"  # pragma: no cover
unstructured_ingest/v2/interfaces/file_data.py
@@ -43,6 +43,7 @@ class FileData(DataClassJsonMixin):
     additional_metadata: dict[str, Any] = field(default_factory=dict)
     reprocess: bool = False
     local_download_path: Optional[str] = None
+    display_name: Optional[str] = None
 
     @classmethod
     def from_file(cls, path: str) -> "FileData":
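The new `display_name` field is optional and defaults to `None`, so existing FileData payloads remain valid. A quick illustration (values are placeholders):

    from unstructured_ingest.v2.interfaces import FileData

    file_data = FileData(
        identifier="example-id",        # placeholder identifier
        connector_type="s3",            # placeholder connector type
        display_name="reports/q1.pdf",  # new in 0.2.1, optional
    )
    print(file_data.display_name)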
unstructured_ingest/v2/processes/connectors/__init__.py
@@ -11,7 +11,7 @@ from unstructured_ingest.v2.processes.connector_registry import (
 from .airtable import CONNECTOR_TYPE as AIRTABLE_CONNECTOR_TYPE
 from .airtable import airtable_source_entry
 from .astradb import CONNECTOR_TYPE as ASTRA_DB_CONNECTOR_TYPE
-from .astradb import astra_db_destination_entry
+from .astradb import astra_db_destination_entry, astra_db_source_entry
 from .azure_cognitive_search import CONNECTOR_TYPE as AZURE_COGNTIVE_SEARCH_CONNECTOR_TYPE
 from .azure_cognitive_search import azure_cognitive_search_destination_entry
 from .chroma import CONNECTOR_TYPE as CHROMA_CONNECTOR_TYPE
@@ -44,13 +44,12 @@ from .salesforce import CONNECTOR_TYPE as SALESFORCE_CONNECTOR_TYPE
 from .salesforce import salesforce_source_entry
 from .sharepoint import CONNECTOR_TYPE as SHAREPOINT_CONNECTOR_TYPE
 from .sharepoint import sharepoint_source_entry
-from .singlestore import CONNECTOR_TYPE as SINGLESTORE_CONNECTOR_TYPE
-from .singlestore import singlestore_destination_entry
 from .slack import CONNECTOR_TYPE as SLACK_CONNECTOR_TYPE
 from .slack import slack_source_entry
 from .weaviate import CONNECTOR_TYPE as WEAVIATE_CONNECTOR_TYPE
 from .weaviate import weaviate_destination_entry
 
+add_source_entry(source_type=ASTRA_DB_CONNECTOR_TYPE, entry=astra_db_source_entry)
 add_destination_entry(destination_type=ASTRA_DB_CONNECTOR_TYPE, entry=astra_db_destination_entry)
 
 add_destination_entry(destination_type=CHROMA_CONNECTOR_TYPE, entry=chroma_destination_entry)
@@ -88,9 +87,7 @@ add_source_entry(source_type=MONGODB_CONNECTOR_TYPE, entry=mongodb_source_entry)
 
 add_destination_entry(destination_type=PINECONE_CONNECTOR_TYPE, entry=pinecone_destination_entry)
 add_source_entry(source_type=SHAREPOINT_CONNECTOR_TYPE, entry=sharepoint_source_entry)
-add_destination_entry(
-    destination_type=SINGLESTORE_CONNECTOR_TYPE, entry=singlestore_destination_entry
-)
+
 add_destination_entry(destination_type=MILVUS_CONNECTOR_TYPE, entry=milvus_destination_entry)
 add_destination_entry(
     destination_type=AZURE_COGNTIVE_SEARCH_CONNECTOR_TYPE,
unstructured_ingest/v2/processes/connectors/astradb.py
@@ -1,31 +1,50 @@
+import copy
+import csv
+import hashlib
 import json
+import sys
 from dataclasses import dataclass, field
 from pathlib import Path
-from
+from time import time
+from typing import TYPE_CHECKING, Any, Generator, Optional
 
 from pydantic import Field, Secret
 
 from unstructured_ingest import __name__ as integration_name
 from unstructured_ingest.__version__ import __version__ as integration_version
-from unstructured_ingest.error import
+from unstructured_ingest.error import (
+    DestinationConnectionError,
+    SourceConnectionError,
+    SourceConnectionNetworkError,
+)
 from unstructured_ingest.utils.data_prep import batch_generator
 from unstructured_ingest.utils.dep_check import requires_dependencies
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
     ConnectionConfig,
+    Downloader,
+    DownloaderConfig,
+    DownloadResponse,
     FileData,
+    FileDataSourceMetadata,
+    Indexer,
+    IndexerConfig,
     Uploader,
     UploaderConfig,
     UploadStager,
     UploadStagerConfig,
+    download_responses,
 )
 from unstructured_ingest.v2.logger import logger
 from unstructured_ingest.v2.processes.connector_registry import (
     DestinationRegistryEntry,
+    SourceRegistryEntry,
 )
 
 if TYPE_CHECKING:
+    from astrapy import AsyncCollection as AstraDBAsyncCollection
     from astrapy import Collection as AstraDBCollection
+    from astrapy import DataAPIClient as AstraDBClient
 
 
 CONNECTOR_TYPE = "astradb"
@@ -37,14 +56,253 @@ class AstraDBAccessConfig(AccessConfig):
 
 
 class AstraDBConnectionConfig(ConnectionConfig):
-
+    connector_type: str = Field(default=CONNECTOR_TYPE, init=False)
     access_config: Secret[AstraDBAccessConfig]
 
+    @requires_dependencies(["astrapy"], extras="astradb")
+    def get_client(self) -> "AstraDBClient":
+        from astrapy import DataAPIClient as AstraDBClient
+
+        # Create a client object to interact with the Astra DB
+        # caller_name/version for Astra DB tracking
+        return AstraDBClient(
+            caller_name=integration_name,
+            caller_version=integration_version,
+        )
+
+
+def get_astra_collection(
+    connection_config: AstraDBConnectionConfig,
+    collection_name: str,
+    keyspace: str,
+) -> "AstraDBCollection":
+    # Build the Astra DB object.
+    access_configs = connection_config.access_config.get_secret_value()
+
+    # Create a client object to interact with the Astra DB
+    # caller_name/version for Astra DB tracking
+    client = connection_config.get_client()
+
+    # Get the database object
+    astra_db = client.get_database(
+        api_endpoint=access_configs.api_endpoint,
+        token=access_configs.token,
+        keyspace=keyspace,
+    )
+
+    # Connect to the collection
+    astra_db_collection = astra_db.get_collection(name=collection_name)
+    return astra_db_collection
+
+
+async def get_async_astra_collection(
+    connection_config: AstraDBConnectionConfig,
+    collection_name: str,
+    keyspace: str,
+) -> "AstraDBAsyncCollection":
+    # Build the Astra DB object.
+    access_configs = connection_config.access_config.get_secret_value()
+
+    # Create a client object to interact with the Astra DB
+    client = connection_config.get_client()
+
+    # Get the async database object
+    async_astra_db = client.get_async_database(
+        api_endpoint=access_configs.api_endpoint,
+        token=access_configs.token,
+        keyspace=keyspace,
+    )
+
+    # Get async collection from AsyncDatabase
+    async_astra_db_collection = await async_astra_db.get_collection(name=collection_name)
+    return async_astra_db_collection
+
 
 class AstraDBUploadStagerConfig(UploadStagerConfig):
     pass
 
 
+class AstraDBIndexerConfig(IndexerConfig):
+    collection_name: str = Field(
+        description="The name of the Astra DB collection. "
+        "Note that the collection name must only include letters, "
+        "numbers, and underscores."
+    )
+    keyspace: Optional[str] = Field(default=None, description="The Astra DB connection keyspace.")
+    namespace: Optional[str] = Field(
+        default=None,
+        description="The Astra DB connection namespace.",
+        deprecated="Please use 'keyspace' instead.",
+    )
+    batch_size: int = Field(default=20, description="Number of records per batch")
+
+
+class AstraDBDownloaderConfig(DownloaderConfig):
+    fields: list[str] = field(default_factory=list)
+
+
+class AstraDBUploaderConfig(UploaderConfig):
+    collection_name: str = Field(
+        description="The name of the Astra DB collection. "
+        "Note that the collection name must only include letters, "
+        "numbers, and underscores."
+    )
+    embedding_dimension: int = Field(
+        default=384, description="The dimensionality of the embeddings"
+    )
+    keyspace: Optional[str] = Field(default=None, description="The Astra DB connection keyspace.")
+    namespace: Optional[str] = Field(
+        default=None,
+        description="The Astra DB connection namespace.",
+        deprecated="Please use 'keyspace' instead.",
+    )
+    requested_indexing_policy: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="The indexing policy to use for the collection.",
+        examples=['{"deny": ["metadata"]}'],
+    )
+    batch_size: int = Field(default=20, description="Number of records per batch")
+
+
+@dataclass
+class AstraDBIndexer(Indexer):
+    connection_config: AstraDBConnectionConfig
+    index_config: AstraDBIndexerConfig
+
+    def get_collection(self) -> "AstraDBCollection":
+        return get_astra_collection(
+            connection_config=self.connection_config,
+            collection_name=self.index_config.collection_name,
+            keyspace=self.index_config.keyspace or self.index_config.namespace,
+        )
+
+    def precheck(self) -> None:
+        try:
+            self.get_collection()
+        except Exception as e:
+            logger.error(f"Failed to validate connection {e}", exc_info=True)
+            raise SourceConnectionError(f"failed to validate connection: {e}")
+
+    def _get_doc_ids(self) -> set[str]:
+        """Fetches all document ids in an index"""
+        # Initialize set of ids
+        ids = set()
+
+        # Get the collection
+        collection = self.get_collection()
+
+        # Perform the find operation to get all items
+        astra_db_docs_cursor = collection.find({}, projection={"_id": True})
+
+        # Iterate over the cursor
+        astra_db_docs = []
+        for result in astra_db_docs_cursor:
+            astra_db_docs.append(result)
+
+        # Create file data for each astra record
+        for astra_record in astra_db_docs:
+            ids.add(astra_record["_id"])
+
+        return ids
+
+    def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
+        all_ids = self._get_doc_ids()
+        ids = list(all_ids)
+        id_batches = batch_generator(ids, self.index_config.batch_size)
+
+        for batch in id_batches:
+            # Make sure the hash is always a positive number to create identified
+            identified = str(hash(batch) + sys.maxsize + 1)
+            fd = FileData(
+                identifier=identified,
+                connector_type=CONNECTOR_TYPE,
+                doc_type="batch",
+                metadata=FileDataSourceMetadata(
+                    date_processed=str(time()),
+                ),
+                additional_metadata={
+                    "ids": list(batch),
+                    "collection_name": self.index_config.collection_name,
+                    "keyspace": self.index_config.keyspace or self.index_config.namespace,
+                },
+            )
+            yield fd
+
+
+@dataclass
+class AstraDBDownloader(Downloader):
+    connection_config: AstraDBConnectionConfig
+    download_config: AstraDBDownloaderConfig
+    connector_type: str = CONNECTOR_TYPE
+
+    def is_async(self) -> bool:
+        return True
+
+    def get_identifier(self, record_id: str) -> str:
+        f = f"{record_id}"
+        if self.download_config.fields:
+            f = "{}-{}".format(
+                f,
+                hashlib.sha256(",".join(self.download_config.fields).encode()).hexdigest()[:8],
+            )
+        return f
+
+    def write_astra_result_to_csv(self, astra_result: dict, download_path: str) -> None:
+        with open(download_path, "w", encoding="utf8") as f:
+            writer = csv.writer(f)
+            writer.writerow(astra_result.keys())
+            writer.writerow(astra_result.values())
+
+    def generate_download_response(self, result: dict, file_data: FileData) -> DownloadResponse:
+        record_id = result["_id"]
+        filename_id = self.get_identifier(record_id=record_id)
+        filename = f"{filename_id}.csv"  # csv to preserve column info
+        download_path = self.download_dir / Path(filename)
+        logger.debug(f"Downloading results from record {record_id} as csv to {download_path}")
+        download_path.parent.mkdir(parents=True, exist_ok=True)
+        try:
+            self.write_astra_result_to_csv(astra_result=result, download_path=download_path)
+        except Exception as e:
+            logger.error(
+                f"failed to download from record {record_id} to {download_path}: {e}",
+                exc_info=True,
+            )
+            raise SourceConnectionNetworkError(f"failed to download file {file_data.identifier}")
+
+        # modify input file_data for download_response
+        copied_file_data = copy.deepcopy(file_data)
+        copied_file_data.identifier = filename
+        copied_file_data.doc_type = "file"
+        copied_file_data.metadata.date_processed = str(time())
+        copied_file_data.metadata.record_locator = {"document_id": record_id}
+        copied_file_data.additional_metadata.pop("ids", None)
+        return super().generate_download_response(
+            file_data=copied_file_data, download_path=download_path
+        )
+
+    def run(self, file_data: FileData, **kwargs: Any) -> download_responses:
+        raise NotImplementedError("Use astradb run_async instead")
+
+    async def run_async(self, file_data: FileData, **kwargs: Any) -> download_responses:
+        # Get metadata from file_data
+        ids: list[str] = file_data.additional_metadata["ids"]
+        collection_name: str = file_data.additional_metadata["collection_name"]
+        keyspace: str = file_data.additional_metadata["keyspace"]
+
+        # Retrieve results from async collection
+        download_responses = []
+        async_astra_collection = await get_async_astra_collection(
+            connection_config=self.connection_config,
+            collection_name=collection_name,
+            keyspace=keyspace,
+        )
+        async for result in async_astra_collection.find({"_id": {"$in": ids}}):
+            download_responses.append(
+                self.generate_download_response(result=result, file_data=file_data)
+            )
+        return download_responses
+
+
 @dataclass
 class AstraDBUploadStager(UploadStager):
     upload_stager_config: AstraDBUploadStagerConfig = field(
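Taken together, the new classes turn the Astra DB connector into a source as well as a destination: `AstraDBIndexer.run()` batches document `_id`s into `FileData` records, and `AstraDBDownloader.run_async()` resolves each batch into one CSV file per document. A hedged sketch of driving the two classes directly (endpoint, token, collection, and path values are placeholders; a real pipeline would wire this through the registry entry instead):

    import asyncio
    from pathlib import Path

    from unstructured_ingest.v2.processes.connectors.astradb import (
        AstraDBAccessConfig,
        AstraDBConnectionConfig,
        AstraDBDownloader,
        AstraDBDownloaderConfig,
        AstraDBIndexer,
        AstraDBIndexerConfig,
    )


    async def dump_collection() -> None:
        connection_config = AstraDBConnectionConfig(
            access_config=AstraDBAccessConfig(
                token="AstraCS:placeholder",  # placeholder credential
                api_endpoint="https://example.apps.astra.datastax.com",  # placeholder endpoint
            )
        )
        indexer = AstraDBIndexer(
            connection_config=connection_config,
            index_config=AstraDBIndexerConfig(
                collection_name="elements", keyspace="default_keyspace"
            ),
        )
        downloader = AstraDBDownloader(
            connection_config=connection_config,
            download_config=AstraDBDownloaderConfig(download_dir=Path("/tmp/astra_downloads")),
        )
        # Each FileData from the indexer carries a batch of ids in additional_metadata;
        # run_async fetches that batch and writes one CSV per returned document.
        for file_data in indexer.run():
            await downloader.run_async(file_data=file_data)


    asyncio.run(dump_collection())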
@@ -77,29 +335,6 @@ class AstraDBUploadStager(UploadStager):
         return output_path
 
 
-class AstraDBUploaderConfig(UploaderConfig):
-    collection_name: str = Field(
-        description="The name of the Astra DB collection. "
-        "Note that the collection name must only include letters, "
-        "numbers, and underscores."
-    )
-    embedding_dimension: int = Field(
-        default=384, description="The dimensionality of the embeddings"
-    )
-    keyspace: Optional[str] = Field(default=None, description="The Astra DB connection keyspace.")
-    namespace: Optional[str] = Field(
-        default=None,
-        description="The Astra DB connection namespace.",
-        deprecated="Please use 'keyspace' instead.",
-    )
-    requested_indexing_policy: Optional[dict[str, Any]] = Field(
-        default=None,
-        description="The indexing policy to use for the collection.",
-        examples=['{"deny": ["metadata"]}'],
-    )
-    batch_size: int = Field(default=20, description="Number of records per batch")
-
-
 @dataclass
 class AstraDBUploader(Uploader):
     connection_config: AstraDBConnectionConfig
@@ -108,43 +343,23 @@ class AstraDBUploader(Uploader):
 
     def precheck(self) -> None:
         try:
-
+            get_astra_collection(
+                connection_config=self.connection_config,
+                collection_name=self.upload_config.collection_name,
+                keyspace=self.upload_config.keyspace or self.upload_config.namespace,
+            )
         except Exception as e:
             logger.error(f"Failed to validate connection {e}", exc_info=True)
             raise DestinationConnectionError(f"failed to validate connection: {e}")
 
     @requires_dependencies(["astrapy"], extras="astradb")
     def get_collection(self) -> "AstraDBCollection":
-
-
-
-
-
-        # Get the collection_name
-        collection_name = self.upload_config.collection_name
-
-        # Build the Astra DB object.
-        access_configs = self.connection_config.access_config.get_secret_value()
-
-        # Create a client object to interact with the Astra DB
-        # caller_name/version for Astra DB tracking
-        my_client = AstraDBClient(
-            caller_name=integration_name,
-            caller_version=integration_version,
-        )
-
-        # Get the database object
-        astra_db = my_client.get_database(
-            api_endpoint=access_configs.api_endpoint,
-            token=access_configs.token,
-            keyspace=keyspace_param,
+        return get_astra_collection(
+            connection_config=self.connection_config,
+            collection_name=self.upload_config.collection_name,
+            keyspace=self.upload_config.keyspace or self.upload_config.namespace,
         )
 
-        # Connect to the newly created collection
-        astra_db_collection = astra_db.get_collection(name=collection_name)
-
-        return astra_db_collection
-
     def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
         with path.open("r") as file:
             elements_dict = json.load(file)
@@ -160,6 +375,14 @@ class AstraDBUploader(Uploader):
             collection.insert_many(chunk)
 
 
+astra_db_source_entry = SourceRegistryEntry(
+    indexer=AstraDBIndexer,
+    indexer_config=AstraDBIndexerConfig,
+    downloader=AstraDBDownloader,
+    downloader_config=AstraDBDownloaderConfig,
+    connection_config=AstraDBConnectionConfig,
+)
+
 astra_db_destination_entry = DestinationRegistryEntry(
     connection_config=AstraDBConnectionConfig,
     upload_stager_config=AstraDBUploadStagerConfig,
unstructured_ingest/v2/processes/connectors/databricks/volumes.py
@@ -166,7 +166,9 @@ class DatabricksVolumesUploader(Uploader, ABC):
             raise DestinationConnectionError(f"failed to validate connection: {e}")
 
     def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
-        output_path = os.path.join(
+        output_path = os.path.join(
+            self.upload_config.path, f"{file_data.source_identifiers.filename}.json"
+        )
         with open(path, "rb") as elements_file:
             self.connection_config.get_client().files.upload(
                 file_path=output_path,