unstructured-ingest 0.2.0__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff shows the content of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.

Files changed (22)
  1. test/integration/connectors/sql/test_singlestore.py +156 -0
  2. test/integration/connectors/test_s3.py +1 -1
  3. test/integration/connectors/utils/docker_compose.py +23 -8
  4. unstructured_ingest/__version__.py +1 -1
  5. unstructured_ingest/v2/interfaces/file_data.py +1 -0
  6. unstructured_ingest/v2/processes/connectors/__init__.py +3 -6
  7. unstructured_ingest/v2/processes/connectors/astradb.py +278 -55
  8. unstructured_ingest/v2/processes/connectors/databricks/volumes.py +3 -1
  9. unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py +1 -0
  10. unstructured_ingest/v2/processes/connectors/sql/__init__.py +5 -0
  11. unstructured_ingest/v2/processes/connectors/sql/postgres.py +1 -20
  12. unstructured_ingest/v2/processes/connectors/sql/singlestore.py +168 -0
  13. unstructured_ingest/v2/processes/connectors/sql/snowflake.py +2 -4
  14. unstructured_ingest/v2/processes/connectors/sql/sql.py +13 -2
  15. unstructured_ingest/v2/unstructured_api.py +1 -1
  16. {unstructured_ingest-0.2.0.dist-info → unstructured_ingest-0.2.1.dist-info}/METADATA +17 -17
  17. {unstructured_ingest-0.2.0.dist-info → unstructured_ingest-0.2.1.dist-info}/RECORD +21 -20
  18. unstructured_ingest/v2/processes/connectors/singlestore.py +0 -156
  19. {unstructured_ingest-0.2.0.dist-info → unstructured_ingest-0.2.1.dist-info}/LICENSE.md +0 -0
  20. {unstructured_ingest-0.2.0.dist-info → unstructured_ingest-0.2.1.dist-info}/WHEEL +0 -0
  21. {unstructured_ingest-0.2.0.dist-info → unstructured_ingest-0.2.1.dist-info}/entry_points.txt +0 -0
  22. {unstructured_ingest-0.2.0.dist-info → unstructured_ingest-0.2.1.dist-info}/top_level.txt +0 -0
test/integration/connectors/sql/test_singlestore.py
@@ -0,0 +1,156 @@
+ import tempfile
+ from contextlib import contextmanager
+ from pathlib import Path
+
+ import pandas as pd
+ import pytest
+ import singlestoredb as s2
+
+ from test.integration.connectors.utils.constants import DESTINATION_TAG, SOURCE_TAG, env_setup_path
+ from test.integration.connectors.utils.docker_compose import docker_compose_context
+ from test.integration.connectors.utils.validation import (
+     ValidationConfigs,
+     source_connector_validation,
+ )
+ from unstructured_ingest.v2.interfaces import FileData
+ from unstructured_ingest.v2.processes.connectors.sql.singlestore import (
+     CONNECTOR_TYPE,
+     SingleStoreAccessConfig,
+     SingleStoreConnectionConfig,
+     SingleStoreDownloader,
+     SingleStoreDownloaderConfig,
+     SingleStoreIndexer,
+     SingleStoreIndexerConfig,
+     SingleStoreUploader,
+     SingleStoreUploaderConfig,
+     SingleStoreUploadStager,
+ )
+
+ SEED_DATA_ROWS = 20
+
+
+ @contextmanager
+ def singlestore_download_setup(connect_params: dict) -> None:
+     with docker_compose_context(
+         docker_compose_path=env_setup_path / "sql" / "singlestore" / "source"
+     ):
+         with s2.connect(**connect_params) as connection:
+             with connection.cursor() as cursor:
+                 for i in range(SEED_DATA_ROWS):
+                     sql_statment = f"INSERT INTO cars (brand, price) VALUES " f"('brand_{i}', {i})"
+                     cursor.execute(sql_statment)
+             connection.commit()
+         yield
+
+
+ @pytest.mark.asyncio
+ @pytest.mark.tags(CONNECTOR_TYPE, SOURCE_TAG, "sql")
+ async def test_singlestore_source():
+     connect_params = {
+         "host": "localhost",
+         "port": 3306,
+         "database": "ingest_test",
+         "user": "root",
+         "password": "password",
+     }
+     with singlestore_download_setup(connect_params=connect_params):
+         with tempfile.TemporaryDirectory() as tmpdir:
+             connection_config = SingleStoreConnectionConfig(
+                 host=connect_params["host"],
+                 port=connect_params["port"],
+                 database=connect_params["database"],
+                 user=connect_params["user"],
+                 access_config=SingleStoreAccessConfig(password=connect_params["password"]),
+             )
+             indexer = SingleStoreIndexer(
+                 connection_config=connection_config,
+                 index_config=SingleStoreIndexerConfig(
+                     table_name="cars", id_column="car_id", batch_size=5
+                 ),
+             )
+             downloader = SingleStoreDownloader(
+                 connection_config=connection_config,
+                 download_config=SingleStoreDownloaderConfig(
+                     fields=["car_id", "brand"], download_dir=Path(tmpdir)
+                 ),
+             )
+             await source_connector_validation(
+                 indexer=indexer,
+                 downloader=downloader,
+                 configs=ValidationConfigs(
+                     test_id="singlestore",
+                     expected_num_files=SEED_DATA_ROWS,
+                     expected_number_indexed_file_data=4,
+                     validate_downloaded_files=True,
+                 ),
+             )
+
+
+ def validate_destination(
+     connect_params: dict,
+     expected_num_elements: int,
+ ):
+     with s2.connect(**connect_params) as connection:
+         with connection.cursor() as cursor:
+             query = "select count(*) from elements;"
+             cursor.execute(query)
+             count = cursor.fetchone()[0]
+             assert (
+                 count == expected_num_elements
+             ), f"dest check failed: got {count}, expected {expected_num_elements}"
+
+
+ @pytest.mark.asyncio
+ @pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG, "sql")
+ async def test_singlestore_destination(upload_file: Path):
+     mock_file_data = FileData(identifier="mock file data", connector_type=CONNECTOR_TYPE)
+     with docker_compose_context(
+         docker_compose_path=env_setup_path / "sql" / "singlestore" / "destination"
+     ):
+         with tempfile.TemporaryDirectory() as tmpdir:
+             stager = SingleStoreUploadStager()
+             stager_params = {
+                 "elements_filepath": upload_file,
+                 "file_data": mock_file_data,
+                 "output_dir": Path(tmpdir),
+                 "output_filename": "test_db",
+             }
+             if stager.is_async():
+                 staged_path = await stager.run_async(**stager_params)
+             else:
+                 staged_path = stager.run(**stager_params)
+
+             # The stager should append the `.json` suffix to the output filename passed in.
+             assert staged_path.name == "test_db.json"
+
+             connect_params = {
+                 "host": "localhost",
+                 "port": 3306,
+                 "database": "ingest_test",
+                 "user": "root",
+                 "password": "password",
+             }
+
+             uploader = SingleStoreUploader(
+                 connection_config=SingleStoreConnectionConfig(
+                     host=connect_params["host"],
+                     port=connect_params["port"],
+                     database=connect_params["database"],
+                     user=connect_params["user"],
+                     access_config=SingleStoreAccessConfig(password=connect_params["password"]),
+                 ),
+                 upload_config=SingleStoreUploaderConfig(
+                     table_name="elements",
+                 ),
+             )
+             if uploader.is_async():
+                 await uploader.run_async(path=staged_path, file_data=mock_file_data)
+             else:
+                 uploader.run(path=staged_path, file_data=mock_file_data)
+
+             staged_df = pd.read_json(staged_path, orient="records", lines=True)
+             expected_num_elements = len(staged_df)
+             validate_destination(
+                 connect_params=connect_params,
+                 expected_num_elements=expected_num_elements,
+             )
test/integration/connectors/test_s3.py
@@ -85,7 +85,7 @@ async def test_s3_source_no_access(anon_connection_config: S3ConnectionConfig):
  async def test_s3_minio_source(anon_connection_config: S3ConnectionConfig):
      anon_connection_config.endpoint_url = "http://localhost:9000"
      indexer_config = S3IndexerConfig(remote_url="s3://utic-dev-tech-fixtures/")
-     with docker_compose_context(docker_compose_path=env_setup_path / "minio"):
+     with docker_compose_context(docker_compose_path=env_setup_path / "minio" / "source"):
          with tempfile.TemporaryDirectory() as tempdir:
              tempdir_path = Path(tempdir)
              download_config = S3DownloaderConfig(download_dir=tempdir_path)
test/integration/connectors/utils/docker_compose.py
@@ -3,6 +3,23 @@ from contextlib import contextmanager
  from pathlib import Path


+ def docker_compose_down(docker_compose_path: Path):
+     cmd = f"docker compose -f {docker_compose_path.resolve()} down --remove-orphans -v --rmi all"
+     print(f"Running command: {cmd}")
+     final_resp = subprocess.run(
+         cmd,
+         shell=True,
+         capture_output=True,
+     )
+     if final_resp.returncode != 0:
+         print("STDOUT: {}".format(final_resp.stdout.decode("utf-8")))
+         print("STDERR: {}".format(final_resp.stderr.decode("utf-8")))
+
+
+ def run_cleanup(docker_compose_path: Path):
+     docker_compose_down(docker_compose_path=docker_compose_path)
+
+
  @contextmanager
  def docker_compose_context(docker_compose_path: Path):
      # Dynamically run a specific docker compose file and make sure it gets cleanup by
@@ -30,15 +47,13 @@ def docker_compose_context(docker_compose_path: Path):
          if resp:
              print("STDOUT: {}".format(resp.stdout.decode("utf-8")))
              print("STDERR: {}".format(resp.stderr.decode("utf-8")))
-         raise e
-     finally:
-         cmd = f"docker compose -f {docker_compose_path.resolve()} down --remove-orphans -v"
-         print(f"Running command: {cmd}")
-         final_resp = subprocess.run(
+         cmd = f"docker compose -f {docker_compose_path.resolve()} logs"
+         logs = subprocess.run(
              cmd,
              shell=True,
              capture_output=True,
          )
-         if final_resp.returncode != 0:
-             print("STDOUT: {}".format(final_resp.stdout.decode("utf-8")))
-             print("STDERR: {}".format(final_resp.stderr.decode("utf-8")))
+         print("DOCKER LOGS: {}".format(logs.stdout.decode("utf-8")))
+         raise e
+     finally:
+         run_cleanup(docker_compose_path=docker_compose_path)
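The teardown logic is now factored into docker_compose_down and run_cleanup, cleanup also removes volumes and pulled images (--rmi all), and on failure the compose logs are printed before the exception is re-raised. Only fragments of docker_compose_context are visible in this hunk, so the following is a rough sketch of how the pieces plausibly fit together; the compose "up" command and the handling of the resp variable are assumptions, not taken from this diff.

# Rough sketch under assumptions: the startup command and the resp handling are
# NOT shown in the hunk above and are guessed here; the except/finally behavior
# mirrors the new 0.2.1 code.
import subprocess
from contextlib import contextmanager
from pathlib import Path


def run_cleanup(docker_compose_path: Path) -> None:
    # Same intent as the helper added in this diff: tear down containers,
    # volumes, and pulled images.
    subprocess.run(
        f"docker compose -f {docker_compose_path.resolve()} down --remove-orphans -v --rmi all",
        shell=True,
        capture_output=True,
    )


@contextmanager
def docker_compose_context(docker_compose_path: Path):
    resp = None
    try:
        # Assumed startup command; the real flags live outside this hunk.
        cmd = f"docker compose -f {docker_compose_path.resolve()} up -d"
        resp = subprocess.run(cmd, shell=True, capture_output=True)
        yield
    except Exception as e:
        if resp:
            print("STDOUT: {}".format(resp.stdout.decode("utf-8")))
            print("STDERR: {}".format(resp.stderr.decode("utf-8")))
        # New in 0.2.1: dump the compose logs before re-raising so failures are debuggable.
        logs = subprocess.run(
            f"docker compose -f {docker_compose_path.resolve()} logs",
            shell=True,
            capture_output=True,
        )
        print("DOCKER LOGS: {}".format(logs.stdout.decode("utf-8")))
        raise e
    finally:
        run_cleanup(docker_compose_path=docker_compose_path)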
unstructured_ingest/__version__.py
@@ -1 +1 @@
- __version__ = "0.2.0" # pragma: no cover
+ __version__ = "0.2.1" # pragma: no cover
unstructured_ingest/v2/interfaces/file_data.py
@@ -43,6 +43,7 @@ class FileData(DataClassJsonMixin):
      additional_metadata: dict[str, Any] = field(default_factory=dict)
      reprocess: bool = False
      local_download_path: Optional[str] = None
+     display_name: Optional[str] = None

      @classmethod
      def from_file(cls, path: str) -> "FileData":
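The new optional display_name field gives downstream consumers a human-readable label that is separate from the stable identifier; the fsspec indexer change later in this diff populates it with the file path. A minimal illustration, with placeholder values that are not taken from this diff:

# Illustrative only; the field values here are placeholders.
from unstructured_ingest.v2.interfaces import FileData

file_data = FileData(
    identifier="a1b2c3d4",  # stable id used for bookkeeping
    connector_type="s3",
    display_name="s3://my-bucket/reports/q3.pdf",  # new in 0.2.1: human-readable label
)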
unstructured_ingest/v2/processes/connectors/__init__.py
@@ -11,7 +11,7 @@ from unstructured_ingest.v2.processes.connector_registry import (
  from .airtable import CONNECTOR_TYPE as AIRTABLE_CONNECTOR_TYPE
  from .airtable import airtable_source_entry
  from .astradb import CONNECTOR_TYPE as ASTRA_DB_CONNECTOR_TYPE
- from .astradb import astra_db_destination_entry
+ from .astradb import astra_db_destination_entry, astra_db_source_entry
  from .azure_cognitive_search import CONNECTOR_TYPE as AZURE_COGNTIVE_SEARCH_CONNECTOR_TYPE
  from .azure_cognitive_search import azure_cognitive_search_destination_entry
  from .chroma import CONNECTOR_TYPE as CHROMA_CONNECTOR_TYPE
@@ -44,13 +44,12 @@ from .salesforce import CONNECTOR_TYPE as SALESFORCE_CONNECTOR_TYPE
  from .salesforce import salesforce_source_entry
  from .sharepoint import CONNECTOR_TYPE as SHAREPOINT_CONNECTOR_TYPE
  from .sharepoint import sharepoint_source_entry
- from .singlestore import CONNECTOR_TYPE as SINGLESTORE_CONNECTOR_TYPE
- from .singlestore import singlestore_destination_entry
  from .slack import CONNECTOR_TYPE as SLACK_CONNECTOR_TYPE
  from .slack import slack_source_entry
  from .weaviate import CONNECTOR_TYPE as WEAVIATE_CONNECTOR_TYPE
  from .weaviate import weaviate_destination_entry

+ add_source_entry(source_type=ASTRA_DB_CONNECTOR_TYPE, entry=astra_db_source_entry)
  add_destination_entry(destination_type=ASTRA_DB_CONNECTOR_TYPE, entry=astra_db_destination_entry)

  add_destination_entry(destination_type=CHROMA_CONNECTOR_TYPE, entry=chroma_destination_entry)
@@ -88,9 +87,7 @@ add_source_entry(source_type=MONGODB_CONNECTOR_TYPE, entry=mongodb_source_entry)

  add_destination_entry(destination_type=PINECONE_CONNECTOR_TYPE, entry=pinecone_destination_entry)
  add_source_entry(source_type=SHAREPOINT_CONNECTOR_TYPE, entry=sharepoint_source_entry)
- add_destination_entry(
-     destination_type=SINGLESTORE_CONNECTOR_TYPE, entry=singlestore_destination_entry
- )
+
  add_destination_entry(destination_type=MILVUS_CONNECTOR_TYPE, entry=milvus_destination_entry)
  add_destination_entry(
      destination_type=AZURE_COGNTIVE_SEARCH_CONNECTOR_TYPE,
unstructured_ingest/v2/processes/connectors/astradb.py
@@ -1,31 +1,50 @@
+ import copy
+ import csv
+ import hashlib
  import json
+ import sys
  from dataclasses import dataclass, field
  from pathlib import Path
- from typing import TYPE_CHECKING, Any, Optional
+ from time import time
+ from typing import TYPE_CHECKING, Any, Generator, Optional

  from pydantic import Field, Secret

  from unstructured_ingest import __name__ as integration_name
  from unstructured_ingest.__version__ import __version__ as integration_version
- from unstructured_ingest.error import DestinationConnectionError
+ from unstructured_ingest.error import (
+     DestinationConnectionError,
+     SourceConnectionError,
+     SourceConnectionNetworkError,
+ )
  from unstructured_ingest.utils.data_prep import batch_generator
  from unstructured_ingest.utils.dep_check import requires_dependencies
  from unstructured_ingest.v2.interfaces import (
      AccessConfig,
      ConnectionConfig,
+     Downloader,
+     DownloaderConfig,
+     DownloadResponse,
      FileData,
+     FileDataSourceMetadata,
+     Indexer,
+     IndexerConfig,
      Uploader,
      UploaderConfig,
      UploadStager,
      UploadStagerConfig,
+     download_responses,
  )
  from unstructured_ingest.v2.logger import logger
  from unstructured_ingest.v2.processes.connector_registry import (
      DestinationRegistryEntry,
+     SourceRegistryEntry,
  )

  if TYPE_CHECKING:
+     from astrapy import AsyncCollection as AstraDBAsyncCollection
      from astrapy import Collection as AstraDBCollection
+     from astrapy import DataAPIClient as AstraDBClient


  CONNECTOR_TYPE = "astradb"
@@ -37,14 +56,253 @@ class AstraDBAccessConfig(AccessConfig):


  class AstraDBConnectionConfig(ConnectionConfig):
-     connection_type: str = Field(default=CONNECTOR_TYPE, init=False)
+     connector_type: str = Field(default=CONNECTOR_TYPE, init=False)
      access_config: Secret[AstraDBAccessConfig]

+     @requires_dependencies(["astrapy"], extras="astradb")
+     def get_client(self) -> "AstraDBClient":
+         from astrapy import DataAPIClient as AstraDBClient
+
+         # Create a client object to interact with the Astra DB
+         # caller_name/version for Astra DB tracking
+         return AstraDBClient(
+             caller_name=integration_name,
+             caller_version=integration_version,
+         )
+
+
+ def get_astra_collection(
+     connection_config: AstraDBConnectionConfig,
+     collection_name: str,
+     keyspace: str,
+ ) -> "AstraDBCollection":
+     # Build the Astra DB object.
+     access_configs = connection_config.access_config.get_secret_value()
+
+     # Create a client object to interact with the Astra DB
+     # caller_name/version for Astra DB tracking
+     client = connection_config.get_client()
+
+     # Get the database object
+     astra_db = client.get_database(
+         api_endpoint=access_configs.api_endpoint,
+         token=access_configs.token,
+         keyspace=keyspace,
+     )
+
+     # Connect to the collection
+     astra_db_collection = astra_db.get_collection(name=collection_name)
+     return astra_db_collection
+
+
+ async def get_async_astra_collection(
+     connection_config: AstraDBConnectionConfig,
+     collection_name: str,
+     keyspace: str,
+ ) -> "AstraDBAsyncCollection":
+     # Build the Astra DB object.
+     access_configs = connection_config.access_config.get_secret_value()
+
+     # Create a client object to interact with the Astra DB
+     client = connection_config.get_client()
+
+     # Get the async database object
+     async_astra_db = client.get_async_database(
+         api_endpoint=access_configs.api_endpoint,
+         token=access_configs.token,
+         keyspace=keyspace,
+     )
+
+     # Get async collection from AsyncDatabase
+     async_astra_db_collection = await async_astra_db.get_collection(name=collection_name)
+     return async_astra_db_collection
+

  class AstraDBUploadStagerConfig(UploadStagerConfig):
      pass


+ class AstraDBIndexerConfig(IndexerConfig):
+     collection_name: str = Field(
+         description="The name of the Astra DB collection. "
+         "Note that the collection name must only include letters, "
+         "numbers, and underscores."
+     )
+     keyspace: Optional[str] = Field(default=None, description="The Astra DB connection keyspace.")
+     namespace: Optional[str] = Field(
+         default=None,
+         description="The Astra DB connection namespace.",
+         deprecated="Please use 'keyspace' instead.",
+     )
+     batch_size: int = Field(default=20, description="Number of records per batch")
+
+
+ class AstraDBDownloaderConfig(DownloaderConfig):
+     fields: list[str] = field(default_factory=list)
+
+
+ class AstraDBUploaderConfig(UploaderConfig):
+     collection_name: str = Field(
+         description="The name of the Astra DB collection. "
+         "Note that the collection name must only include letters, "
+         "numbers, and underscores."
+     )
+     embedding_dimension: int = Field(
+         default=384, description="The dimensionality of the embeddings"
+     )
+     keyspace: Optional[str] = Field(default=None, description="The Astra DB connection keyspace.")
+     namespace: Optional[str] = Field(
+         default=None,
+         description="The Astra DB connection namespace.",
+         deprecated="Please use 'keyspace' instead.",
+     )
+     requested_indexing_policy: Optional[dict[str, Any]] = Field(
+         default=None,
+         description="The indexing policy to use for the collection.",
+         examples=['{"deny": ["metadata"]}'],
+     )
+     batch_size: int = Field(default=20, description="Number of records per batch")
+
+
+ @dataclass
+ class AstraDBIndexer(Indexer):
+     connection_config: AstraDBConnectionConfig
+     index_config: AstraDBIndexerConfig
+
+     def get_collection(self) -> "AstraDBCollection":
+         return get_astra_collection(
+             connection_config=self.connection_config,
+             collection_name=self.index_config.collection_name,
+             keyspace=self.index_config.keyspace or self.index_config.namespace,
+         )
+
+     def precheck(self) -> None:
+         try:
+             self.get_collection()
+         except Exception as e:
+             logger.error(f"Failed to validate connection {e}", exc_info=True)
+             raise SourceConnectionError(f"failed to validate connection: {e}")
+
+     def _get_doc_ids(self) -> set[str]:
+         """Fetches all document ids in an index"""
+         # Initialize set of ids
+         ids = set()
+
+         # Get the collection
+         collection = self.get_collection()
+
+         # Perform the find operation to get all items
+         astra_db_docs_cursor = collection.find({}, projection={"_id": True})
+
+         # Iterate over the cursor
+         astra_db_docs = []
+         for result in astra_db_docs_cursor:
+             astra_db_docs.append(result)
+
+         # Create file data for each astra record
+         for astra_record in astra_db_docs:
+             ids.add(astra_record["_id"])
+
+         return ids
+
+     def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
+         all_ids = self._get_doc_ids()
+         ids = list(all_ids)
+         id_batches = batch_generator(ids, self.index_config.batch_size)
+
+         for batch in id_batches:
+             # Make sure the hash is always a positive number to create identified
+             identified = str(hash(batch) + sys.maxsize + 1)
+             fd = FileData(
+                 identifier=identified,
+                 connector_type=CONNECTOR_TYPE,
+                 doc_type="batch",
+                 metadata=FileDataSourceMetadata(
+                     date_processed=str(time()),
+                 ),
+                 additional_metadata={
+                     "ids": list(batch),
+                     "collection_name": self.index_config.collection_name,
+                     "keyspace": self.index_config.keyspace or self.index_config.namespace,
+                 },
+             )
+             yield fd
+
+
+ @dataclass
+ class AstraDBDownloader(Downloader):
+     connection_config: AstraDBConnectionConfig
+     download_config: AstraDBDownloaderConfig
+     connector_type: str = CONNECTOR_TYPE
+
+     def is_async(self) -> bool:
+         return True
+
+     def get_identifier(self, record_id: str) -> str:
+         f = f"{record_id}"
+         if self.download_config.fields:
+             f = "{}-{}".format(
+                 f,
+                 hashlib.sha256(",".join(self.download_config.fields).encode()).hexdigest()[:8],
+             )
+         return f
+
+     def write_astra_result_to_csv(self, astra_result: dict, download_path: str) -> None:
+         with open(download_path, "w", encoding="utf8") as f:
+             writer = csv.writer(f)
+             writer.writerow(astra_result.keys())
+             writer.writerow(astra_result.values())
+
+     def generate_download_response(self, result: dict, file_data: FileData) -> DownloadResponse:
+         record_id = result["_id"]
+         filename_id = self.get_identifier(record_id=record_id)
+         filename = f"{filename_id}.csv"  # csv to preserve column info
+         download_path = self.download_dir / Path(filename)
+         logger.debug(f"Downloading results from record {record_id} as csv to {download_path}")
+         download_path.parent.mkdir(parents=True, exist_ok=True)
+         try:
+             self.write_astra_result_to_csv(astra_result=result, download_path=download_path)
+         except Exception as e:
+             logger.error(
+                 f"failed to download from record {record_id} to {download_path}: {e}",
+                 exc_info=True,
+             )
+             raise SourceConnectionNetworkError(f"failed to download file {file_data.identifier}")
+
+         # modify input file_data for download_response
+         copied_file_data = copy.deepcopy(file_data)
+         copied_file_data.identifier = filename
+         copied_file_data.doc_type = "file"
+         copied_file_data.metadata.date_processed = str(time())
+         copied_file_data.metadata.record_locator = {"document_id": record_id}
+         copied_file_data.additional_metadata.pop("ids", None)
+         return super().generate_download_response(
+             file_data=copied_file_data, download_path=download_path
+         )
+
+     def run(self, file_data: FileData, **kwargs: Any) -> download_responses:
+         raise NotImplementedError("Use astradb run_async instead")
+
+     async def run_async(self, file_data: FileData, **kwargs: Any) -> download_responses:
+         # Get metadata from file_data
+         ids: list[str] = file_data.additional_metadata["ids"]
+         collection_name: str = file_data.additional_metadata["collection_name"]
+         keyspace: str = file_data.additional_metadata["keyspace"]
+
+         # Retrieve results from async collection
+         download_responses = []
+         async_astra_collection = await get_async_astra_collection(
+             connection_config=self.connection_config,
+             collection_name=collection_name,
+             keyspace=keyspace,
+         )
+         async for result in async_astra_collection.find({"_id": {"$in": ids}}):
+             download_responses.append(
+                 self.generate_download_response(result=result, file_data=file_data)
+             )
+         return download_responses
+
+
  @dataclass
  class AstraDBUploadStager(UploadStager):
      upload_stager_config: AstraDBUploadStagerConfig = field(
@@ -77,29 +335,6 @@ class AstraDBUploadStager(UploadStager):
          return output_path


- class AstraDBUploaderConfig(UploaderConfig):
-     collection_name: str = Field(
-         description="The name of the Astra DB collection. "
-         "Note that the collection name must only include letters, "
-         "numbers, and underscores."
-     )
-     embedding_dimension: int = Field(
-         default=384, description="The dimensionality of the embeddings"
-     )
-     keyspace: Optional[str] = Field(default=None, description="The Astra DB connection keyspace.")
-     namespace: Optional[str] = Field(
-         default=None,
-         description="The Astra DB connection namespace.",
-         deprecated="Please use 'keyspace' instead.",
-     )
-     requested_indexing_policy: Optional[dict[str, Any]] = Field(
-         default=None,
-         description="The indexing policy to use for the collection.",
-         examples=['{"deny": ["metadata"]}'],
-     )
-     batch_size: int = Field(default=20, description="Number of records per batch")
-
-
  @dataclass
  class AstraDBUploader(Uploader):
      connection_config: AstraDBConnectionConfig
@@ -108,43 +343,23 @@ class AstraDBUploader(Uploader):

      def precheck(self) -> None:
          try:
-             self.get_collection()
+             get_astra_collection(
+                 connection_config=self.connection_config,
+                 collection_name=self.upload_config.collection_name,
+                 keyspace=self.upload_config.keyspace or self.upload_config.namespace,
+             )
          except Exception as e:
              logger.error(f"Failed to validate connection {e}", exc_info=True)
              raise DestinationConnectionError(f"failed to validate connection: {e}")

      @requires_dependencies(["astrapy"], extras="astradb")
      def get_collection(self) -> "AstraDBCollection":
-         from astrapy import DataAPIClient as AstraDBClient
-
-         # Choose keyspace or deprecated namespace
-         keyspace_param = self.upload_config.keyspace or self.upload_config.namespace
-
-         # Get the collection_name
-         collection_name = self.upload_config.collection_name
-
-         # Build the Astra DB object.
-         access_configs = self.connection_config.access_config.get_secret_value()
-
-         # Create a client object to interact with the Astra DB
-         # caller_name/version for Astra DB tracking
-         my_client = AstraDBClient(
-             caller_name=integration_name,
-             caller_version=integration_version,
-         )
-
-         # Get the database object
-         astra_db = my_client.get_database(
-             api_endpoint=access_configs.api_endpoint,
-             token=access_configs.token,
-             keyspace=keyspace_param,
+         return get_astra_collection(
+             connection_config=self.connection_config,
+             collection_name=self.upload_config.collection_name,
+             keyspace=self.upload_config.keyspace or self.upload_config.namespace,
          )

-         # Connect to the newly created collection
-         astra_db_collection = astra_db.get_collection(name=collection_name)
-
-         return astra_db_collection
-
      def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
          with path.open("r") as file:
              elements_dict = json.load(file)
@@ -160,6 +375,14 @@ class AstraDBUploader(Uploader):
              collection.insert_many(chunk)


+ astra_db_source_entry = SourceRegistryEntry(
+     indexer=AstraDBIndexer,
+     indexer_config=AstraDBIndexerConfig,
+     downloader=AstraDBDownloader,
+     downloader_config=AstraDBDownloaderConfig,
+     connection_config=AstraDBConnectionConfig,
+ )
+
  astra_db_destination_entry = DestinationRegistryEntry(
      connection_config=AstraDBConnectionConfig,
      upload_stager_config=AstraDBUploadStagerConfig,
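With this change astradb.py gains a full source connector: AstraDBIndexer batches document ids into FileData records, and AstraDBDownloader (async-only) fetches each record and writes it to its own CSV. The following is a hedged sketch of wiring the new classes together, following the same pattern as the SingleStore test earlier in this diff; the credentials, collection name, and download directory are placeholders, and AstraDBAccessConfig is assumed to take token and api_endpoint fields based on how they are read above.

# Hypothetical wiring of the new Astra DB source connector. Class and field
# names come from the diff above; credential values and paths are placeholders.
import asyncio
from pathlib import Path

from unstructured_ingest.v2.processes.connectors.astradb import (
    AstraDBAccessConfig,
    AstraDBConnectionConfig,
    AstraDBDownloader,
    AstraDBDownloaderConfig,
    AstraDBIndexer,
    AstraDBIndexerConfig,
)


async def dump_collection() -> None:
    connection_config = AstraDBConnectionConfig(
        access_config=AstraDBAccessConfig(
            token="AstraCS:placeholder",  # placeholder token
            api_endpoint="https://example-region.apps.astra.datastax.com",  # placeholder endpoint
        )
    )
    # The indexer yields one FileData per batch of document ids
    # (additional_metadata carries the ids, collection name, and keyspace).
    indexer = AstraDBIndexer(
        connection_config=connection_config,
        index_config=AstraDBIndexerConfig(collection_name="elements", batch_size=20),
    )
    # The downloader is async-only (run() raises NotImplementedError) and writes
    # each fetched record to its own CSV file under download_dir.
    downloader = AstraDBDownloader(
        connection_config=connection_config,
        download_config=AstraDBDownloaderConfig(download_dir=Path("/tmp/astra-download")),
    )
    for file_data in indexer.run():
        await downloader.run_async(file_data=file_data)


asyncio.run(dump_collection())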
unstructured_ingest/v2/processes/connectors/databricks/volumes.py
@@ -166,7 +166,9 @@ class DatabricksVolumesUploader(Uploader, ABC):
          raise DestinationConnectionError(f"failed to validate connection: {e}")

      def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
-         output_path = os.path.join(self.upload_config.path, file_data.source_identifiers.filename)
+         output_path = os.path.join(
+             self.upload_config.path, f"{file_data.source_identifiers.filename}.json"
+         )
          with open(path, "rb") as elements_file:
              self.connection_config.get_client().files.upload(
                  file_path=output_path,
unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py
@@ -176,6 +176,7 @@ class FsspecIndexer(Indexer):
              ),
              metadata=self.get_metadata(file_data=file_data),
              additional_metadata=additional_metadata,
+             display_name=file_path,
          )