unstructured-ingest 0.3.8__py3-none-any.whl → 0.3.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of unstructured-ingest might be problematic.

Files changed (87)
  1. test/integration/chunkers/test_chunkers.py +0 -11
  2. test/integration/connectors/conftest.py +11 -1
  3. test/integration/connectors/databricks_tests/test_volumes_native.py +4 -3
  4. test/integration/connectors/duckdb/conftest.py +14 -0
  5. test/integration/connectors/duckdb/test_duckdb.py +51 -44
  6. test/integration/connectors/duckdb/test_motherduck.py +37 -48
  7. test/integration/connectors/elasticsearch/test_elasticsearch.py +26 -4
  8. test/integration/connectors/elasticsearch/test_opensearch.py +26 -3
  9. test/integration/connectors/sql/test_postgres.py +103 -92
  10. test/integration/connectors/sql/test_singlestore.py +112 -100
  11. test/integration/connectors/sql/test_snowflake.py +142 -117
  12. test/integration/connectors/sql/test_sqlite.py +87 -76
  13. test/integration/connectors/test_astradb.py +62 -1
  14. test/integration/connectors/test_azure_ai_search.py +25 -3
  15. test/integration/connectors/test_chroma.py +120 -0
  16. test/integration/connectors/test_confluence.py +4 -4
  17. test/integration/connectors/test_delta_table.py +1 -0
  18. test/integration/connectors/test_kafka.py +6 -6
  19. test/integration/connectors/test_milvus.py +21 -0
  20. test/integration/connectors/test_mongodb.py +7 -4
  21. test/integration/connectors/test_neo4j.py +236 -0
  22. test/integration/connectors/test_pinecone.py +25 -1
  23. test/integration/connectors/test_qdrant.py +25 -2
  24. test/integration/connectors/test_s3.py +9 -6
  25. test/integration/connectors/utils/docker.py +6 -0
  26. test/integration/connectors/utils/validation/__init__.py +0 -0
  27. test/integration/connectors/utils/validation/destination.py +88 -0
  28. test/integration/connectors/utils/validation/equality.py +75 -0
  29. test/integration/connectors/utils/{validation.py → validation/source.py} +42 -98
  30. test/integration/connectors/utils/validation/utils.py +36 -0
  31. unstructured_ingest/__version__.py +1 -1
  32. unstructured_ingest/utils/chunking.py +11 -0
  33. unstructured_ingest/utils/data_prep.py +36 -0
  34. unstructured_ingest/v2/interfaces/__init__.py +3 -1
  35. unstructured_ingest/v2/interfaces/file_data.py +58 -14
  36. unstructured_ingest/v2/interfaces/upload_stager.py +70 -6
  37. unstructured_ingest/v2/interfaces/uploader.py +11 -2
  38. unstructured_ingest/v2/pipeline/steps/chunk.py +2 -1
  39. unstructured_ingest/v2/pipeline/steps/download.py +5 -4
  40. unstructured_ingest/v2/pipeline/steps/embed.py +2 -1
  41. unstructured_ingest/v2/pipeline/steps/filter.py +2 -2
  42. unstructured_ingest/v2/pipeline/steps/index.py +4 -4
  43. unstructured_ingest/v2/pipeline/steps/partition.py +3 -2
  44. unstructured_ingest/v2/pipeline/steps/stage.py +5 -3
  45. unstructured_ingest/v2/pipeline/steps/uncompress.py +2 -2
  46. unstructured_ingest/v2/pipeline/steps/upload.py +3 -3
  47. unstructured_ingest/v2/processes/connectors/__init__.py +3 -0
  48. unstructured_ingest/v2/processes/connectors/astradb.py +43 -63
  49. unstructured_ingest/v2/processes/connectors/azure_ai_search.py +16 -40
  50. unstructured_ingest/v2/processes/connectors/chroma.py +36 -59
  51. unstructured_ingest/v2/processes/connectors/couchbase.py +92 -93
  52. unstructured_ingest/v2/processes/connectors/delta_table.py +11 -33
  53. unstructured_ingest/v2/processes/connectors/duckdb/base.py +26 -26
  54. unstructured_ingest/v2/processes/connectors/duckdb/duckdb.py +29 -20
  55. unstructured_ingest/v2/processes/connectors/duckdb/motherduck.py +37 -44
  56. unstructured_ingest/v2/processes/connectors/elasticsearch/elasticsearch.py +46 -75
  57. unstructured_ingest/v2/processes/connectors/fsspec/azure.py +12 -35
  58. unstructured_ingest/v2/processes/connectors/fsspec/box.py +12 -35
  59. unstructured_ingest/v2/processes/connectors/fsspec/dropbox.py +15 -42
  60. unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py +33 -29
  61. unstructured_ingest/v2/processes/connectors/fsspec/gcs.py +12 -34
  62. unstructured_ingest/v2/processes/connectors/fsspec/s3.py +13 -37
  63. unstructured_ingest/v2/processes/connectors/fsspec/sftp.py +19 -33
  64. unstructured_ingest/v2/processes/connectors/gitlab.py +32 -31
  65. unstructured_ingest/v2/processes/connectors/google_drive.py +32 -29
  66. unstructured_ingest/v2/processes/connectors/kafka/kafka.py +2 -4
  67. unstructured_ingest/v2/processes/connectors/kdbai.py +44 -70
  68. unstructured_ingest/v2/processes/connectors/lancedb/lancedb.py +8 -10
  69. unstructured_ingest/v2/processes/connectors/local.py +13 -2
  70. unstructured_ingest/v2/processes/connectors/milvus.py +16 -57
  71. unstructured_ingest/v2/processes/connectors/mongodb.py +99 -108
  72. unstructured_ingest/v2/processes/connectors/neo4j.py +383 -0
  73. unstructured_ingest/v2/processes/connectors/onedrive.py +1 -1
  74. unstructured_ingest/v2/processes/connectors/pinecone.py +3 -33
  75. unstructured_ingest/v2/processes/connectors/qdrant/qdrant.py +32 -41
  76. unstructured_ingest/v2/processes/connectors/sql/postgres.py +5 -5
  77. unstructured_ingest/v2/processes/connectors/sql/singlestore.py +5 -5
  78. unstructured_ingest/v2/processes/connectors/sql/snowflake.py +5 -5
  79. unstructured_ingest/v2/processes/connectors/sql/sql.py +72 -66
  80. unstructured_ingest/v2/processes/connectors/sql/sqlite.py +5 -5
  81. unstructured_ingest/v2/processes/connectors/weaviate/weaviate.py +9 -31
  82. {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.10.dist-info}/METADATA +20 -15
  83. {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.10.dist-info}/RECORD +87 -79
  84. {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.10.dist-info}/LICENSE.md +0 -0
  85. {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.10.dist-info}/WHEEL +0 -0
  86. {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.10.dist-info}/entry_points.txt +0 -0
  87. {unstructured_ingest-0.3.8.dist-info → unstructured_ingest-0.3.10.dist-info}/top_level.txt +0 -0
unstructured_ingest/v2/processes/connectors/pinecone.py:

@@ -1,6 +1,5 @@
 import json
 from dataclasses import dataclass, field
-from pathlib import Path
 from typing import TYPE_CHECKING, Any, Optional
 
 from pydantic import Field, Secret
@@ -159,33 +158,6 @@ class PineconeUploadStager(UploadStager):
             "metadata": metadata,
         }
 
-    def run(
-        self,
-        file_data: FileData,
-        elements_filepath: Path,
-        output_dir: Path,
-        output_filename: str,
-        **kwargs: Any,
-    ) -> Path:
-        with open(elements_filepath) as elements_file:
-            elements_contents = json.load(elements_file)
-
-        conformed_elements = [
-            self.conform_dict(element_dict=element, file_data=file_data)
-            for element in elements_contents
-        ]
-
-        if Path(output_filename).suffix != ".json":
-            output_filename = f"{output_filename}.json"
-        else:
-            output_filename = f"{Path(output_filename).stem}.json"
-        output_path = Path(output_dir) / Path(f"{output_filename}")
-        output_path.parent.mkdir(parents=True, exist_ok=True)
-
-        with open(output_path, "w") as output_file:
-            json.dump(conformed_elements, output_file)
-        return output_path
-
 
 @dataclass
 class PineconeUploader(Uploader):
@@ -278,11 +250,9 @@ class PineconeUploader(Uploader):
             raise DestinationConnectionError(f"http error: {api_error}") from api_error
         logger.debug(f"results: {results}")
 
-    def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
-        with path.open("r") as file:
-            elements_dict = json.load(file)
+    def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
         logger.info(
-            f"writing a total of {len(elements_dict)} elements via"
+            f"writing a total of {len(data)} elements via"
            f" document batches to destination"
            f" index named {self.connection_config.index_name}"
        )
@@ -295,7 +265,7 @@ class PineconeUploader(Uploader):
             self.pod_delete_by_record_id(file_data=file_data)
         else:
             raise ValueError(f"unexpected spec type in index description: {index_description}")
-        self.upsert_batches_async(elements_dict=elements_dict)
+        self.upsert_batches_async(elements_dict=data)
 
 
 pinecone_destination_entry = DestinationRegistryEntry(

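The two Pinecone hunks above capture the interface change in 0.3.10: the stager no longer carries its own file-based run(), and the uploader's run(path=...) is replaced by run_data(data=...), which receives already-deserialized element dictionaries. A minimal sketch of that shape, using a stand-in class rather than the package's Uploader base; everything except the run/run_data signatures is illustrative only:

import json
from pathlib import Path
from typing import Any


class ToyUploader:
    """Stand-in for a destination uploader; only the method shapes mirror the diff."""

    def run(self, path: Path, file_data: Any, **kwargs: Any) -> None:
        # 0.3.8-style entry point: load the staged JSON from disk, then upload.
        with path.open() as f:
            data = json.load(f)
        self.run_data(data=data, file_data=file_data, **kwargs)

    def run_data(self, data: list[dict], file_data: Any, **kwargs: Any) -> None:
        # 0.3.10-style entry point: elements arrive as in-memory dicts.
        print(f"writing a total of {len(data)} elements")
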
unstructured_ingest/v2/processes/connectors/qdrant/qdrant.py:

@@ -1,10 +1,9 @@
 import asyncio
 import json
 from abc import ABC, abstractmethod
-from contextlib import asynccontextmanager
+from contextlib import asynccontextmanager, contextmanager
 from dataclasses import dataclass, field
-from pathlib import Path
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Optional
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Generator, Optional
 
 from pydantic import Field, Secret
 
@@ -24,7 +23,7 @@ from unstructured_ingest.v2.logger import logger
 from unstructured_ingest.v2.utils import get_enhanced_element_id
 
 if TYPE_CHECKING:
-    from qdrant_client import AsyncQdrantClient
+    from qdrant_client import AsyncQdrantClient, QdrantClient
 
 
 class QdrantAccessConfig(AccessConfig, ABC):
@@ -42,8 +41,8 @@ class QdrantConnectionConfig(ConnectionConfig, ABC):
 
     @requires_dependencies(["qdrant_client"], extras="qdrant")
     @asynccontextmanager
-    async def get_client(self) -> AsyncGenerator["AsyncQdrantClient", None]:
-        from qdrant_client.async_qdrant_client import AsyncQdrantClient
+    async def get_async_client(self) -> AsyncGenerator["AsyncQdrantClient", None]:
+        from qdrant_client import AsyncQdrantClient
 
         client_kwargs = self.get_client_kwargs()
         client = AsyncQdrantClient(**client_kwargs)
@@ -52,6 +51,18 @@ class QdrantConnectionConfig(ConnectionConfig, ABC):
         finally:
             await client.close()
 
+    @requires_dependencies(["qdrant_client"], extras="qdrant")
+    @contextmanager
+    def get_client(self) -> Generator["QdrantClient", None, None]:
+        from qdrant_client import QdrantClient
+
+        client_kwargs = self.get_client_kwargs()
+        client = QdrantClient(**client_kwargs)
+        try:
+            yield client
+        finally:
+            client.close()
+
 
 class QdrantUploadStagerConfig(UploadStagerConfig):
     pass
@@ -63,9 +74,9 @@ class QdrantUploadStager(UploadStager, ABC):
         default_factory=lambda: QdrantUploadStagerConfig()
     )
 
-    @staticmethod
-    def conform_dict(data: dict, file_data: FileData) -> dict:
+    def conform_dict(self, element_dict: dict, file_data: FileData) -> dict:
         """Prepares dictionary in the format that Chroma requires"""
+        data = element_dict.copy()
         return {
             "id": get_enhanced_element_id(element_dict=data, file_data=file_data),
             "vector": data.pop("embeddings", {}),
@@ -80,26 +91,6 @@ class QdrantUploadStager(UploadStager, ABC):
             },
         }
 
-    def run(
-        self,
-        elements_filepath: Path,
-        file_data: FileData,
-        output_dir: Path,
-        output_filename: str,
-        **kwargs: Any,
-    ) -> Path:
-        with open(elements_filepath) as elements_file:
-            elements_contents = json.load(elements_file)
-
-        conformed_elements = [
-            self.conform_dict(data=element, file_data=file_data) for element in elements_contents
-        ]
-        output_path = Path(output_dir) / Path(f"{output_filename}.json")
-
-        with open(output_path, "w") as output_file:
-            json.dump(conformed_elements, output_file)
-        return output_path
-
 
 class QdrantUploaderConfig(UploaderConfig):
     collection_name: str = Field(description="Name of the collection.")
@@ -118,27 +109,27 @@ class QdrantUploader(Uploader, ABC):
 
     @DestinationConnectionError.wrap
     def precheck(self) -> None:
-        async def check_connection():
-            async with self.connection_config.get_client() as async_client:
-                await async_client.get_collections()
-
-        asyncio.run(check_connection())
+        with self.connection_config.get_client() as client:
+            collections_response = client.get_collections()
+            collection_names = [c.name for c in collections_response.collections]
+            if self.upload_config.collection_name not in collection_names:
+                raise DestinationConnectionError(
+                    "collection '{}' not found: {}".format(
+                        self.upload_config.collection_name, ", ".join(collection_names)
+                    )
+                )
 
     def is_async(self):
         return True
 
-    async def run_async(
+    async def run_data_async(
         self,
-        path: Path,
+        data: list[dict],
         file_data: FileData,
         **kwargs: Any,
     ) -> None:
-        with path.open("r") as file:
-            elements: list[dict] = json.load(file)
-
-        logger.debug("Loaded %i elements from %s", len(elements), path)
 
-        batches = list(batch_generator(elements, batch_size=self.upload_config.batch_size))
+        batches = list(batch_generator(data, batch_size=self.upload_config.batch_size))
         logger.debug(
             "Elements split into %i batches of size %i.",
             len(batches),
@@ -156,7 +147,7 @@ class QdrantUploader(Uploader, ABC):
                 len(points),
                 self.upload_config.collection_name,
             )
-            async with self.connection_config.get_client() as async_client:
+            async with self.connection_config.get_async_client() as async_client:
                 await async_client.upsert(
                     self.upload_config.collection_name, points=points, wait=True
                 )

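The Qdrant hunks rename the async client factory to get_async_client() and add a synchronous get_client(), which precheck() now uses to confirm the target collection exists before any upload. Below is a self-contained sketch of the same paired sync/async context-manager pattern; DummyClient stands in for qdrant_client's QdrantClient/AsyncQdrantClient and deliberately simplifies their return types:

import asyncio
from contextlib import asynccontextmanager, contextmanager
from typing import AsyncGenerator, Generator


class DummyClient:
    """Stand-in client; the real ones come from qdrant_client."""

    def get_collections(self) -> list[str]:
        return ["elements"]

    def close(self) -> None:
        pass

    async def aclose(self) -> None:
        pass


@contextmanager
def get_client() -> Generator[DummyClient, None, None]:
    client = DummyClient()
    try:
        yield client
    finally:
        client.close()


@asynccontextmanager
async def get_async_client() -> AsyncGenerator[DummyClient, None]:
    client = DummyClient()
    try:
        yield client
    finally:
        await client.aclose()


def precheck(collection_name: str) -> None:
    # Synchronous check, mirroring the new QdrantUploader.precheck().
    with get_client() as client:
        if collection_name not in client.get_collections():
            raise ValueError(f"collection '{collection_name}' not found")


async def upload() -> None:
    # Async path, as in run_data_async(), still uses the async client.
    async with get_async_client() as client:
        _ = client  # upsert batches here


precheck("elements")
asyncio.run(upload())
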
unstructured_ingest/v2/processes/connectors/sql/postgres.py:

@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Generator, Optional
 from pydantic import Field, Secret
 
 from unstructured_ingest.utils.dep_check import requires_dependencies
-from unstructured_ingest.v2.interfaces import FileData
 from unstructured_ingest.v2.logger import logger
 from unstructured_ingest.v2.processes.connector_registry import (
     DestinationRegistryEntry,
@@ -13,6 +12,7 @@ from unstructured_ingest.v2.processes.connector_registry import (
 )
 from unstructured_ingest.v2.processes.connectors.sql.sql import (
     SQLAccessConfig,
+    SqlBatchFileData,
     SQLConnectionConfig,
     SQLDownloader,
     SQLDownloaderConfig,
@@ -99,12 +99,12 @@ class PostgresDownloader(SQLDownloader):
     connector_type: str = CONNECTOR_TYPE
 
     @requires_dependencies(["psycopg2"], extras="postgres")
-    def query_db(self, file_data: FileData) -> tuple[list[tuple], list[str]]:
+    def query_db(self, file_data: SqlBatchFileData) -> tuple[list[tuple], list[str]]:
         from psycopg2 import sql
 
-        table_name = file_data.additional_metadata["table_name"]
-        id_column = file_data.additional_metadata["id_column"]
-        ids = tuple(file_data.additional_metadata["ids"])
+        table_name = file_data.additional_metadata.table_name
+        id_column = file_data.additional_metadata.id_column
+        ids = tuple([item.identifier for item in file_data.batch_items])
 
         with self.connection_config.get_cursor() as cursor:
             fields = (

unstructured_ingest/v2/processes/connectors/sql/singlestore.py:

@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, Any, Generator, Optional
 
 from pydantic import Field, Secret
 
-from unstructured_ingest.v2.interfaces import FileData
 from unstructured_ingest.v2.logger import logger
 from unstructured_ingest.v2.processes.connector_registry import (
     DestinationRegistryEntry,
@@ -14,6 +13,7 @@ from unstructured_ingest.v2.processes.connector_registry import (
 )
 from unstructured_ingest.v2.processes.connectors.sql.sql import (
     _DATE_COLUMNS,
     SQLAccessConfig,
+    SqlBatchFileData,
     SQLConnectionConfig,
     SQLDownloader,
     SQLDownloaderConfig,
@@ -93,10 +93,10 @@ class SingleStoreDownloader(SQLDownloader):
     connector_type: str = CONNECTOR_TYPE
     values_delimiter: str = "%s"
 
-    def query_db(self, file_data: FileData) -> tuple[list[tuple], list[str]]:
-        table_name = file_data.additional_metadata["table_name"]
-        id_column = file_data.additional_metadata["id_column"]
-        ids = tuple(file_data.additional_metadata["ids"])
+    def query_db(self, file_data: SqlBatchFileData) -> tuple[list[tuple], list[str]]:
+        table_name = file_data.additional_metadata.table_name
+        id_column = file_data.additional_metadata.id_column
+        ids = tuple([item.identifier for item in file_data.batch_items])
         with self.connection_config.get_connection() as sqlite_connection:
             cursor = sqlite_connection.cursor()
             fields = ",".join(self.download_config.fields) if self.download_config.fields else "*"

unstructured_ingest/v2/processes/connectors/sql/snowflake.py:

@@ -9,7 +9,6 @@ from pydantic import Field, Secret
 
 from unstructured_ingest.utils.data_prep import split_dataframe
 from unstructured_ingest.utils.dep_check import requires_dependencies
-from unstructured_ingest.v2.interfaces.file_data import FileData
 from unstructured_ingest.v2.logger import logger
 from unstructured_ingest.v2.processes.connector_registry import (
     DestinationRegistryEntry,
@@ -17,6 +16,7 @@ from unstructured_ingest.v2.processes.connector_registry import (
 )
 from unstructured_ingest.v2.processes.connectors.sql.sql import (
     SQLAccessConfig,
+    SqlBatchFileData,
     SQLConnectionConfig,
     SQLDownloader,
     SQLDownloaderConfig,
@@ -118,10 +118,10 @@ class SnowflakeDownloader(SQLDownloader):
 
     # The actual snowflake module package name is: snowflake-connector-python
     @requires_dependencies(["snowflake"], extras="snowflake")
-    def query_db(self, file_data: FileData) -> tuple[list[tuple], list[str]]:
-        table_name = file_data.additional_metadata["table_name"]
-        id_column = file_data.additional_metadata["id_column"]
-        ids = file_data.additional_metadata["ids"]
+    def query_db(self, file_data: SqlBatchFileData) -> tuple[list[tuple], list[str]]:
+        table_name = file_data.additional_metadata.table_name
+        id_column = file_data.additional_metadata.id_column
+        ids = [item.identifier for item in file_data.batch_items]
 
         with self.connection_config.get_cursor() as cursor:
             query = "SELECT {fields} FROM {table_name} WHERE {id_column} IN ({values})".format(

unstructured_ingest/v2/processes/connectors/sql/sql.py:

@@ -1,9 +1,8 @@
 import hashlib
 import json
-import sys
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from dataclasses import dataclass, field, replace
+from dataclasses import dataclass, field
 from datetime import date, datetime
 from pathlib import Path
 from time import time
@@ -12,13 +11,15 @@ from typing import Any, Generator, Union
 import numpy as np
 import pandas as pd
 from dateutil import parser
-from pydantic import Field, Secret
+from pydantic import BaseModel, Field, Secret
 
 from unstructured_ingest.error import DestinationConnectionError, SourceConnectionError
-from unstructured_ingest.utils.data_prep import split_dataframe
+from unstructured_ingest.utils.data_prep import get_data_df, split_dataframe
 from unstructured_ingest.v2.constants import RECORD_ID_LABEL
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
+    BatchFileData,
+    BatchItem,
     ConnectionConfig,
     Downloader,
     DownloaderConfig,
@@ -81,6 +82,15 @@ _COLUMNS = (
 _DATE_COLUMNS = ("date_created", "date_modified", "date_processed", "last_modified")
 
 
+class SqlAdditionalMetadata(BaseModel):
+    table_name: str
+    id_column: str
+
+
+class SqlBatchFileData(BatchFileData):
+    additional_metadata: SqlAdditionalMetadata
+
+
 def parse_date_string(date_value: Union[str, int]) -> date:
     try:
         timestamp = float(date_value) / 1000 if isinstance(date_value, int) else float(date_value)
@@ -124,7 +134,7 @@ class SQLIndexer(Indexer, ABC):
                 f"SELECT {self.index_config.id_column} FROM {self.index_config.table_name}"
             )
             results = cursor.fetchall()
-            ids = [result[0] for result in results]
+            ids = sorted([result[0] for result in results])
             return ids
 
     def precheck(self) -> None:
@@ -135,7 +145,7 @@ class SQLIndexer(Indexer, ABC):
             logger.error(f"failed to validate connection: {e}", exc_info=True)
             raise SourceConnectionError(f"failed to validate connection: {e}")
 
-    def run(self, **kwargs: Any) -> Generator[FileData, None, None]:
+    def run(self, **kwargs: Any) -> Generator[SqlBatchFileData, None, None]:
         ids = self._get_doc_ids()
         id_batches: list[frozenset[str]] = [
             frozenset(
@@ -151,19 +161,15 @@ class SQLIndexer(Indexer, ABC):
         ]
         for batch in id_batches:
             # Make sure the hash is always a positive number to create identified
-            identified = str(hash(batch) + sys.maxsize + 1)
-            yield FileData(
-                identifier=identified,
+            yield SqlBatchFileData(
                 connector_type=self.connector_type,
                 metadata=FileDataSourceMetadata(
                     date_processed=str(time()),
                 ),
-                doc_type="batch",
-                additional_metadata={
-                    "ids": list(batch),
-                    "table_name": self.index_config.table_name,
-                    "id_column": self.index_config.id_column,
-                },
+                additional_metadata=SqlAdditionalMetadata(
+                    table_name=self.index_config.table_name, id_column=self.index_config.id_column
+                ),
+                batch_items=[BatchItem(identifier=str(b)) for b in batch],
             )
 
 
@@ -176,7 +182,7 @@ class SQLDownloader(Downloader, ABC):
     download_config: SQLDownloaderConfig
 
     @abstractmethod
-    def query_db(self, file_data: FileData) -> tuple[list[tuple], list[str]]:
+    def query_db(self, file_data: SqlBatchFileData) -> tuple[list[tuple], list[str]]:
         pass
 
     def sql_to_df(self, rows: list[tuple], columns: list[str]) -> list[pd.DataFrame]:
@@ -185,7 +191,7 @@ class SQLDownloader(Downloader, ABC):
         dfs = [pd.DataFrame([row.values], columns=df.columns) for index, row in df.iterrows()]
         return dfs
 
-    def get_data(self, file_data: FileData) -> list[pd.DataFrame]:
+    def get_data(self, file_data: SqlBatchFileData) -> list[pd.DataFrame]:
         rows, columns = self.query_db(file_data=file_data)
         return self.sql_to_df(rows=rows, columns=columns)
 
@@ -199,10 +205,10 @@ class SQLDownloader(Downloader, ABC):
         return f
 
     def generate_download_response(
-        self, result: pd.DataFrame, file_data: FileData
+        self, result: pd.DataFrame, file_data: SqlBatchFileData
     ) -> DownloadResponse:
-        id_column = file_data.additional_metadata["id_column"]
-        table_name = file_data.additional_metadata["table_name"]
+        id_column = file_data.additional_metadata.id_column
+        table_name = file_data.additional_metadata.table_name
         record_id = result.iloc[0][id_column]
         filename_id = self.get_identifier(table_name=table_name, record_id=record_id)
         filename = f"{filename_id}.csv"
@@ -212,20 +218,19 @@ class SQLDownloader(Downloader, ABC):
         )
         download_path.parent.mkdir(parents=True, exist_ok=True)
         result.to_csv(download_path, index=False)
-        copied_file_data = replace(file_data)
-        copied_file_data.identifier = filename_id
-        copied_file_data.doc_type = "file"
-        copied_file_data.additional_metadata.pop("ids", None)
+        cast_file_data = FileData.cast(file_data=file_data)
+        cast_file_data.identifier = filename_id
         return super().generate_download_response(
-            file_data=copied_file_data, download_path=download_path
+            file_data=cast_file_data, download_path=download_path
         )
 
     def run(self, file_data: FileData, **kwargs: Any) -> download_responses:
-        data_dfs = self.get_data(file_data=file_data)
+        sql_filedata = SqlBatchFileData.cast(file_data=file_data)
+        data_dfs = self.get_data(file_data=sql_filedata)
         download_responses = []
         for df in data_dfs:
             download_responses.append(
-                self.generate_download_response(result=df, file_data=file_data)
+                self.generate_download_response(result=df, file_data=sql_filedata)
             )
         return download_responses
 
@@ -238,27 +243,24 @@ class SQLUploadStagerConfig(UploadStagerConfig):
 class SQLUploadStager(UploadStager):
     upload_stager_config: SQLUploadStagerConfig = field(default_factory=SQLUploadStagerConfig)
 
-    @staticmethod
-    def conform_dict(data: dict, file_data: FileData) -> pd.DataFrame:
-        working_data = data.copy()
-        output = []
-        for element in working_data:
-            metadata: dict[str, Any] = element.pop("metadata", {})
-            data_source = metadata.pop("data_source", {})
-            coordinates = metadata.pop("coordinates", {})
+    def conform_dict(self, element_dict: dict, file_data: FileData) -> dict:
+        data = element_dict.copy()
+        metadata: dict[str, Any] = data.pop("metadata", {})
+        data_source = metadata.pop("data_source", {})
+        coordinates = metadata.pop("coordinates", {})
 
-            element.update(metadata)
-            element.update(data_source)
-            element.update(coordinates)
+        data.update(metadata)
+        data.update(data_source)
+        data.update(coordinates)
 
-            element["id"] = get_enhanced_element_id(element_dict=element, file_data=file_data)
+        data["id"] = get_enhanced_element_id(element_dict=data, file_data=file_data)
 
-            # remove extraneous, not supported columns
-            element = {k: v for k, v in element.items() if k in _COLUMNS}
-            element[RECORD_ID_LABEL] = file_data.identifier
-            output.append(element)
+        # remove extraneous, not supported columns
+        element = {k: v for k, v in data.items() if k in _COLUMNS}
+        element[RECORD_ID_LABEL] = file_data.identifier
+        return element
 
-        df = pd.DataFrame.from_dict(output)
+    def conform_dataframe(self, df: pd.DataFrame) -> pd.DataFrame:
         for column in filter(lambda x: x in df.columns, _DATE_COLUMNS):
             df[column] = df[column].apply(parse_date_string)
         for column in filter(
@@ -283,19 +285,19 @@ class SQLUploadStager(UploadStager):
         output_filename: str,
         **kwargs: Any,
     ) -> Path:
-        with open(elements_filepath) as elements_file:
-            elements_contents: list[dict] = json.load(elements_file)
+        elements_contents = self.get_data(elements_filepath=elements_filepath)
 
-        df = self.conform_dict(data=elements_contents, file_data=file_data)
-        if Path(output_filename).suffix != ".json":
-            output_filename = f"{output_filename}.json"
-        else:
-            output_filename = f"{Path(output_filename).stem}.json"
-        output_path = Path(output_dir) / Path(f"{output_filename}")
-        output_path.parent.mkdir(parents=True, exist_ok=True)
+        df = pd.DataFrame(
+            data=[
+                self.conform_dict(element_dict=element_dict, file_data=file_data)
+                for element_dict in elements_contents
+            ]
+        )
+        df = self.conform_dataframe(df=df)
 
-        with output_path.open("w") as output_file:
-            df.to_json(output_file, orient="records", lines=True)
+        output_path = self.get_output_path(output_filename=output_filename, output_dir=output_dir)
+
+        self.write_output(output_path=output_path, data=df.to_dict(orient="records"))
         return output_path
 
 
@@ -361,8 +363,15 @@ class SQLUploader(Uploader):
         for column in missing_columns:
             df[column] = pd.Series()
 
-    def upload_contents(self, path: Path) -> None:
-        df = pd.read_json(path, orient="records", lines=True)
+    def upload_dataframe(self, df: pd.DataFrame, file_data: FileData) -> None:
+        if self.can_delete():
+            self.delete_by_record_id(file_data=file_data)
+        else:
+            logger.warning(
+                f"table doesn't contain expected "
+                f"record id column "
+                f"{self.upload_config.record_id_key}, skipping delete"
+            )
         df.replace({np.nan: None}, inplace=True)
         self._fit_to_schema(df=df, columns=self.get_table_columns())
 
@@ -411,13 +420,10 @@ class SQLUploader(Uploader):
             rowcount = cursor.rowcount
         logger.info(f"deleted {rowcount} rows from table {self.upload_config.table_name}")
 
+    def run_data(self, data: list[dict], file_data: FileData, **kwargs: Any) -> None:
+        df = pd.DataFrame(data)
+        self.upload_dataframe(df=df, file_data=file_data)
+
     def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
-        if self.can_delete():
-            self.delete_by_record_id(file_data=file_data)
-        else:
-            logger.warning(
-                f"table doesn't contain expected "
-                f"record id column "
-                f"{self.upload_config.record_id_key}, skipping delete"
-            )
-        self.upload_contents(path=path)
+        df = get_data_df(path=path)
+        self.upload_dataframe(df=df, file_data=file_data)

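sql.py replaces the ad-hoc additional_metadata dict (and the sys.maxsize-based batch identifier) with typed pydantic models: SqlAdditionalMetadata plus SqlBatchFileData, whose batch_items carry the record IDs. A self-contained sketch of that data shape, using local stand-in models rather than importing the real BatchFileData/BatchItem classes from the package:

from pydantic import BaseModel


class BatchItem(BaseModel):  # stand-in for unstructured_ingest.v2.interfaces.BatchItem
    identifier: str


class SqlAdditionalMetadata(BaseModel):
    table_name: str
    id_column: str


class SqlBatchFileData(BaseModel):  # stand-in; the real class extends BatchFileData
    additional_metadata: SqlAdditionalMetadata
    batch_items: list[BatchItem]


# What the indexer now yields per batch of record IDs (values are made up):
batch = SqlBatchFileData(
    additional_metadata=SqlAdditionalMetadata(table_name="elements", id_column="id"),
    batch_items=[BatchItem(identifier=i) for i in ("1", "2", "3")],
)

# Downloaders read typed attributes instead of additional_metadata["ids"] and friends:
table_name = batch.additional_metadata.table_name
ids = [item.identifier for item in batch.batch_items]
print(table_name, ids)
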
unstructured_ingest/v2/processes/connectors/sql/sqlite.py:

@@ -6,7 +6,6 @@ from typing import TYPE_CHECKING, Any, Generator
 
 from pydantic import Field, Secret, model_validator
 
-from unstructured_ingest.v2.interfaces import FileData
 from unstructured_ingest.v2.logger import logger
 from unstructured_ingest.v2.processes.connector_registry import (
     DestinationRegistryEntry,
@@ -15,6 +14,7 @@ from unstructured_ingest.v2.processes.connector_registry import (
 )
 from unstructured_ingest.v2.processes.connectors.sql.sql import (
     _DATE_COLUMNS,
     SQLAccessConfig,
+    SqlBatchFileData,
     SQLConnectionConfig,
     SQLDownloader,
     SQLDownloaderConfig,
@@ -97,10 +97,10 @@ class SQLiteDownloader(SQLDownloader):
     connector_type: str = CONNECTOR_TYPE
     values_delimiter: str = "?"
 
-    def query_db(self, file_data: FileData) -> tuple[list[tuple], list[str]]:
-        table_name = file_data.additional_metadata["table_name"]
-        id_column = file_data.additional_metadata["id_column"]
-        ids = file_data.additional_metadata["ids"]
+    def query_db(self, file_data: SqlBatchFileData) -> tuple[list[tuple], list[str]]:
+        table_name = file_data.additional_metadata.table_name
+        id_column = file_data.additional_metadata.id_column
+        ids = [item.identifier for item in file_data.batch_items]
         with self.connection_config.get_connection() as sqlite_connection:
             cursor = sqlite_connection.cursor()
             fields = ",".join(self.download_config.fields) if self.download_config.fields else "*"
