unstructured-ingest 0.0.0__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of unstructured-ingest might be problematic.

Files changed (44)
  1. unstructured_ingest/__version__.py +1 -1
  2. unstructured_ingest/connector/notion/helpers.py +1 -1
  3. unstructured_ingest/logger.py +2 -2
  4. unstructured_ingest/v2/cli/base/cmd.py +10 -0
  5. unstructured_ingest/v2/cli/base/src.py +2 -0
  6. unstructured_ingest/v2/cli/cmds/__init__.py +2 -0
  7. unstructured_ingest/v2/cli/cmds/fsspec/fsspec.py +1 -9
  8. unstructured_ingest/v2/cli/cmds/local.py +0 -8
  9. unstructured_ingest/v2/cli/cmds/milvus.py +72 -0
  10. unstructured_ingest/v2/cli/configs/__init__.py +8 -1
  11. unstructured_ingest/v2/cli/configs/filter.py +28 -0
  12. unstructured_ingest/v2/interfaces/__init__.py +2 -1
  13. unstructured_ingest/v2/interfaces/downloader.py +9 -3
  14. unstructured_ingest/v2/interfaces/file_data.py +6 -1
  15. unstructured_ingest/v2/interfaces/process.py +3 -0
  16. unstructured_ingest/v2/logger.py +1 -1
  17. unstructured_ingest/v2/pipeline/interfaces.py +3 -1
  18. unstructured_ingest/v2/pipeline/pipeline.py +72 -2
  19. unstructured_ingest/v2/pipeline/steps/download.py +77 -13
  20. unstructured_ingest/v2/pipeline/steps/filter.py +40 -0
  21. unstructured_ingest/v2/processes/connectors/__init__.py +4 -2
  22. unstructured_ingest/v2/processes/connectors/astra.py +8 -0
  23. unstructured_ingest/v2/processes/connectors/azure_cognitive_search.py +8 -0
  24. unstructured_ingest/v2/processes/connectors/chroma.py +8 -6
  25. unstructured_ingest/v2/processes/connectors/databricks_volumes.py +9 -0
  26. unstructured_ingest/v2/processes/connectors/elasticsearch.py +23 -9
  27. unstructured_ingest/v2/processes/connectors/fsspec/fsspec.py +22 -31
  28. unstructured_ingest/v2/processes/connectors/fsspec/s3.py +13 -5
  29. unstructured_ingest/v2/processes/connectors/google_drive.py +13 -9
  30. unstructured_ingest/v2/processes/connectors/local.py +15 -15
  31. unstructured_ingest/v2/processes/connectors/milvus.py +200 -0
  32. unstructured_ingest/v2/processes/connectors/mongodb.py +10 -4
  33. unstructured_ingest/v2/processes/connectors/onedrive.py +14 -2
  34. unstructured_ingest/v2/processes/connectors/pinecone.py +10 -7
  35. unstructured_ingest/v2/processes/connectors/salesforce.py +10 -8
  36. unstructured_ingest/v2/processes/connectors/sharepoint.py +14 -8
  37. unstructured_ingest/v2/processes/connectors/sql.py +24 -9
  38. unstructured_ingest/v2/processes/connectors/weaviate.py +13 -5
  39. unstructured_ingest/v2/processes/filter.py +54 -0
  40. {unstructured_ingest-0.0.0.dist-info → unstructured_ingest-0.0.2.dist-info}/METADATA +16 -14
  41. {unstructured_ingest-0.0.0.dist-info → unstructured_ingest-0.0.2.dist-info}/RECORD +44 -39
  42. {unstructured_ingest-0.0.0.dist-info → unstructured_ingest-0.0.2.dist-info}/WHEEL +0 -0
  43. {unstructured_ingest-0.0.0.dist-info → unstructured_ingest-0.0.2.dist-info}/entry_points.txt +0 -0
  44. {unstructured_ingest-0.0.0.dist-info → unstructured_ingest-0.0.2.dist-info}/top_level.txt +0 -0
unstructured_ingest/v2/processes/connectors/milvus.py

@@ -0,0 +1,200 @@
+import json
+import multiprocessing as mp
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+import pandas as pd
+from dateutil import parser
+
+from unstructured_ingest.enhanced_dataclass import enhanced_field
+from unstructured_ingest.error import WriteError
+from unstructured_ingest.utils.data_prep import flatten_dict
+from unstructured_ingest.utils.dep_check import requires_dependencies
+from unstructured_ingest.v2.interfaces import (
+    AccessConfig,
+    ConnectionConfig,
+    FileData,
+    UploadContent,
+    Uploader,
+    UploaderConfig,
+    UploadStager,
+    UploadStagerConfig,
+)
+from unstructured_ingest.v2.logger import logger
+from unstructured_ingest.v2.processes.connector_registry import (
+    DestinationRegistryEntry,
+)
+
+if TYPE_CHECKING:
+    from pymilvus import MilvusClient
+
+CONNECTOR_TYPE = "milvus"
+
+
+@dataclass
+class MilvusAccessConfig(AccessConfig):
+    password: Optional[str] = None
+    token: Optional[str] = None
+
+
+@dataclass
+class MilvusConnectionConfig(ConnectionConfig):
+    access_config: MilvusAccessConfig = enhanced_field(
+        sensitive=True, default_factory=lambda: MilvusAccessConfig()
+    )
+    uri: Optional[str] = None
+    user: Optional[str] = None
+    db_name: Optional[str] = None
+
+    def get_connection_kwargs(self) -> dict[str, Any]:
+        access_config_dict = self.access_config.to_dict()
+        connection_config_dict = self.to_dict()
+        connection_config_dict.pop("access_config", None)
+        connection_config_dict.update(access_config_dict)
+        # Drop any that were not set explicitly
+        connection_config_dict = {k: v for k, v in connection_config_dict.items() if v is not None}
+        return connection_config_dict
+
+    @requires_dependencies(["pymilvus"], extras="milvus")
+    def get_client(self) -> "MilvusClient":
+        from pymilvus import MilvusClient
+
+        return MilvusClient(**self.get_connection_kwargs())
+
+
+@dataclass
+class MilvusUploadStagerConfig(UploadStagerConfig):
+    pass
+
+
+@dataclass
+class MilvusUploadStager(UploadStager):
+    upload_stager_config: MilvusUploadStagerConfig = field(
+        default_factory=lambda: MilvusUploadStagerConfig()
+    )
+
+    @staticmethod
+    def parse_date_string(date_string: str) -> float:
+        try:
+            timestamp = float(date_string)
+            return timestamp
+        except ValueError:
+            pass
+        return parser.parse(date_string).timestamp()
+
+    @classmethod
+    def conform_dict(cls, data: dict) -> None:
+        datetime_columns = [
+            "data_source_date_created",
+            "data_source_date_modified",
+            "data_source_date_processed",
+            "last_modified",
+        ]
+
+        json_dumps_fields = ["languages", "data_source_permissions_data"]
+
+        # TODO: milvus sdk doesn't seem to support defaults via the schema yet,
+        # remove once that gets updated
+        defaults = {"is_continuation": False}
+
+        if metadata := data.pop("metadata", None):
+            data.update(flatten_dict(metadata, keys_to_omit=["data_source_record_locator"]))
+        for datetime_column in datetime_columns:
+            if datetime_column in data:
+                data[datetime_column] = cls.parse_date_string(data[datetime_column])
+        for json_dumps_field in json_dumps_fields:
+            if json_dumps_field in data:
+                data[json_dumps_field] = json.dumps(data[json_dumps_field])
+        for default in defaults:
+            if default not in data:
+                data[default] = defaults[default]
+
+    def run(
+        self,
+        elements_filepath: Path,
+        file_data: FileData,
+        output_dir: Path,
+        output_filename: str,
+        **kwargs: Any,
+    ) -> Path:
+        with open(elements_filepath) as elements_file:
+            elements_contents: list[dict[str, Any]] = json.load(elements_file)
+        for element in elements_contents:
+            self.conform_dict(data=element)
+
+        output_path = Path(output_dir) / Path(f"{output_filename}.json")
+        output_path.parent.mkdir(parents=True, exist_ok=True)
+        with output_path.open("w") as output_file:
+            json.dump(elements_contents, output_file, indent=2)
+        return output_path
+
+
+@dataclass
+class MilvusUploaderConfig(UploaderConfig):
+    collection_name: str
+    num_of_processes: int = 4
+
+
+@dataclass
+class MilvusUploader(Uploader):
+    connection_config: MilvusConnectionConfig
+    upload_config: MilvusUploaderConfig
+    connector_type: str = CONNECTOR_TYPE
+
+    def upload(self, content: UploadContent) -> None:
+        file_extension = content.path.suffix
+        if file_extension == ".json":
+            self.upload_json(content=content)
+        elif file_extension == ".csv":
+            self.upload_csv(content=content)
+        else:
+            raise ValueError(f"Unsupported file extension: {file_extension}")
+
+    @requires_dependencies(["pymilvus"], extras="milvus")
+    def insert_results(self, data: Union[dict, list[dict]]):
+        from pymilvus import MilvusException
+
+        logger.debug(
+            f"uploading {len(data)} entries to {self.connection_config.db_name} "
+            f"db in collection {self.upload_config.collection_name}"
+        )
+        client = self.connection_config.get_client()
+
+        try:
+            res = client.insert(collection_name=self.upload_config.collection_name, data=data)
+        except MilvusException as milvus_exception:
+            raise WriteError("failed to upload records to milvus") from milvus_exception
+        if "err_count" in res and isinstance(res["err_count"], int) and res["err_count"] > 0:
+            err_count = res["err_count"]
+            raise WriteError(f"failed to upload {err_count} docs")
+
+    def upload_csv(self, content: UploadContent) -> None:
+        df = pd.read_csv(content.path)
+        data = df.to_dict(orient="records")
+        self.insert_results(data=data)
+
+    def upload_json(self, content: UploadContent) -> None:
+        with content.path.open("r") as file:
+            data: list[dict] = json.load(file)
+        self.insert_results(data=data)
+
+    def run(self, contents: list[UploadContent], **kwargs: Any) -> None:
+        if self.upload_config.num_of_processes == 1:
+            for content in contents:
+                self.upload(content=content)
+
+        else:
+            with mp.Pool(
+                processes=self.upload_config.num_of_processes,
+            ) as pool:
+                pool.map(self.upload, contents)
+
+
+milvus_destination_entry = DestinationRegistryEntry(
+    connection_config=MilvusConnectionConfig,
+    uploader=MilvusUploader,
+    uploader_config=MilvusUploaderConfig,
+    upload_stager=MilvusUploadStager,
+    upload_stager_config=MilvusUploadStagerConfig,
+)
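The new module registers Milvus as a destination: staged JSON (or CSV) output is conformed by MilvusUploadStager and written with MilvusClient.insert(). A minimal sketch of driving it directly, assuming a local Milvus instance, an existing "elements" collection, and placeholder credentials (none of these values come from the release itself):

from unstructured_ingest.v2.processes.connectors.milvus import (
    MilvusAccessConfig,
    MilvusConnectionConfig,
    MilvusUploader,
    MilvusUploaderConfig,
)

# Hypothetical endpoint and credentials; get_connection_kwargs() only forwards
# the fields that were explicitly set (uri, user, db_name, password, token).
connection_config = MilvusConnectionConfig(
    uri="http://localhost:19530",
    db_name="default",
    access_config=MilvusAccessConfig(token="<api-token>"),
)

uploader = MilvusUploader(
    connection_config=connection_config,
    upload_config=MilvusUploaderConfig(collection_name="elements", num_of_processes=1),
)

# upload_json() and upload_csv() both funnel into insert_results(), which wraps
# MilvusClient.insert() and raises WriteError on failure. The row below is a
# placeholder and must match the target collection's schema.
uploader.insert_results(
    data=[{"element_id": "abc123", "text": "Hello Milvus", "is_continuation": False}]
)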
unstructured_ingest/v2/processes/connectors/mongodb.py

@@ -6,6 +6,7 @@ from typing import TYPE_CHECKING, Any, Optional
 from unstructured.__version__ import __version__ as unstructured_version

 from unstructured_ingest.enhanced_dataclass import enhanced_field
+from unstructured_ingest.error import DestinationConnectionError
 from unstructured_ingest.utils.data_prep import batch_generator
 from unstructured_ingest.utils.dep_check import requires_dependencies
 from unstructured_ingest.v2.interfaces import (
@@ -85,11 +86,15 @@ class MongoDBUploaderConfig(UploaderConfig):
 class MongoDBUploader(Uploader):
     upload_config: MongoDBUploaderConfig
     connection_config: MongoDBConnectionConfig
-    client: Optional["MongoClient"] = field(init=False)
     connector_type: str = CONNECTOR_TYPE

-    def __post_init__(self):
-        self.client = self.create_client()
+    def precheck(self) -> None:
+        try:
+            client = self.create_client()
+            client.admin.command("ping")
+        except Exception as e:
+            logger.error(f"failed to validate connection: {e}", exc_info=True)
+            raise DestinationConnectionError(f"failed to validate connection: {e}")

     @requires_dependencies(["pymongo"], extras="mongodb")
     def create_client(self) -> "MongoClient":
@@ -123,7 +128,8 @@ class MongoDBUploader(Uploader):
             f"collection {self.connection_config.collection} "
             f"at {self.connection_config.host}",
         )
-        db = self.client[self.connection_config.database]
+        client = self.create_client()
+        db = client[self.connection_config.database]
         collection = db[self.connection_config.collection]
         for chunk in batch_generator(elements_dict, self.upload_config.batch_size):
             collection.insert_many(chunk)
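This MongoDB change, like most of the connector diffs that follow, drops the client that used to be created in __post_init__ and adds a precheck() hook that builds a short-lived client and fails fast with a connector-specific error. A hedged sketch of that contract, using hypothetical Demo* names rather than any class shipped in the package:

from dataclasses import dataclass
from typing import Any

from unstructured_ingest.error import DestinationConnectionError
from unstructured_ingest.v2.logger import logger


@dataclass
class DemoUploader:  # hypothetical, for illustration only
    connection_config: Any  # stands in for a config object exposing create_client()

    def precheck(self) -> None:
        # Validate connectivity up front instead of caching a client on the dataclass;
        # each real connector substitutes its own cheap health check here.
        try:
            client = self.connection_config.create_client()
            client.admin.command("ping")  # MongoDB-style ping, per the hunk above
        except Exception as e:
            logger.error(f"failed to validate connection: {e}", exc_info=True)
            raise DestinationConnectionError(f"failed to validate connection: {e}")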
unstructured_ingest/v2/processes/connectors/onedrive.py

@@ -5,7 +5,6 @@ from time import time
 from typing import TYPE_CHECKING, Any, Generator, Optional

 from dateutil import parser
-from unstructured.documents.elements import DataSourceMetadata

 from unstructured_ingest.enhanced_dataclass import enhanced_field
 from unstructured_ingest.error import SourceConnectionError, SourceConnectionNetworkError
@@ -17,6 +16,7 @@ from unstructured_ingest.v2.interfaces import (
     DownloaderConfig,
     DownloadResponse,
     FileData,
+    FileDataSourceMetadata,
     Indexer,
     IndexerConfig,
     SourceIdentifiers,
@@ -87,6 +87,18 @@ class OnedriveIndexer(Indexer):
     connection_config: OnedriveConnectionConfig
     index_config: OnedriveIndexerConfig

+    def precheck(self) -> None:
+        try:
+            token_resp: dict = self.connection_config.get_token()
+            if error := token_resp.get("error"):
+                raise SourceConnectionError(
+                    "{} ({})".format(error, token_resp.get("error_description"))
+                )
+            self.connection_config.get_client()
+        except Exception as e:
+            logger.error(f"failed to validate connection: {e}", exc_info=True)
+            raise SourceConnectionError(f"failed to validate connection: {e}")
+
     def list_objects(self, folder, recursive) -> list["DriveItem"]:
         drive_items = folder.children.get().execute_query()
         files = [d for d in drive_items if d.is_file]
@@ -136,7 +148,7 @@ class OnedriveIndexer(Indexer):
             source_identifiers=SourceIdentifiers(
                 fullpath=server_path, filename=drive_item.name, rel_path=rel_path
             ),
-            metadata=DataSourceMetadata(
+            metadata=FileDataSourceMetadata(
                 url=drive_item.parent_reference.path + "/" + drive_item.name,
                 version=drive_item.etag,
                 date_modified=str(date_modified_dt.timestamp()) if date_modified_dt else None,
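The other recurring substitution, here and in the Salesforce and SharePoint diffs below, swaps unstructured's DataSourceMetadata for FileDataSourceMetadata exported from unstructured_ingest.v2.interfaces. A small sketch with placeholder values, using only keyword arguments that appear in these hunks:

from unstructured_ingest.v2.interfaces import FileDataSourceMetadata

# Placeholder values; the indexers above store timestamps as stringified floats.
metadata = FileDataSourceMetadata(
    url="https://example.test/drive/root:/docs/report.docx",
    version="etag-or-version-string",
    date_modified="1714000000.0",
)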
unstructured_ingest/v2/processes/connectors/pinecone.py

@@ -5,10 +5,6 @@ from dataclasses import dataclass, field
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Optional

-from unstructured.ingest.v2.logger import logger
-from unstructured.ingest.v2.processes.connector_registry import (
-    DestinationRegistryEntry,
-)
 from unstructured.staging.base import flatten_dict
 from unstructured.utils import requires_dependencies

@@ -24,6 +20,10 @@ from unstructured_ingest.v2.interfaces import (
     UploadStager,
     UploadStagerConfig,
 )
+from unstructured_ingest.v2.logger import logger
+from unstructured_ingest.v2.processes.connector_registry import (
+    DestinationRegistryEntry,
+)

 if TYPE_CHECKING:
     from pinecone import Index as PineconeIndex
@@ -123,9 +123,12 @@ class PineconeUploader(Uploader):
     connection_config: PineconeConnectionConfig
     connector_type: str = CONNECTOR_TYPE

-    @DestinationConnectionError.wrap
-    def check_connection(self):
-        _ = self.connection_config.get_index()
+    def precheck(self):
+        try:
+            self.connection_config.get_index()
+        except Exception as e:
+            logger.error(f"failed to validate connection: {e}", exc_info=True)
+            raise DestinationConnectionError(f"failed to validate connection: {e}")

     @requires_dependencies(["pinecone"], extras="pinecone")
     def upsert_batch(self, batch):
unstructured_ingest/v2/processes/connectors/salesforce.py

@@ -18,10 +18,9 @@ from textwrap import dedent
 from typing import TYPE_CHECKING, Any, Generator, Type

 from dateutil import parser
-from unstructured.documents.elements import DataSourceMetadata

 from unstructured_ingest.enhanced_dataclass import enhanced_field
-from unstructured_ingest.error import SourceConnectionNetworkError
+from unstructured_ingest.error import SourceConnectionError, SourceConnectionNetworkError
 from unstructured_ingest.utils.dep_check import requires_dependencies
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
@@ -30,6 +29,7 @@ from unstructured_ingest.v2.interfaces import (
     DownloaderConfig,
     DownloadResponse,
     FileData,
+    FileDataSourceMetadata,
     Indexer,
     IndexerConfig,
     SourceIdentifiers,
@@ -132,6 +132,13 @@ class SalesforceIndexer(Indexer):
         if record_type not in ACCEPTED_CATEGORIES:
             raise ValueError(f"{record_type} not currently an accepted Salesforce category")

+    def precheck(self) -> None:
+        try:
+            self.connection_config.get_client()
+        except Exception as e:
+            logger.error(f"failed to validate connection: {e}", exc_info=True)
+            raise SourceConnectionError(f"failed to validate connection: {e}")
+
     def get_file_extension(self, record_type) -> str:
         if record_type == "EmailMessage":
             extension = ".eml"
@@ -172,7 +179,7 @@ class SalesforceIndexer(Indexer):
                 filename=record_with_extension,
                 fullpath=f"{record['attributes']['type']}/{record_with_extension}",
             ),
-            metadata=DataSourceMetadata(
+            metadata=FileDataSourceMetadata(
                 url=record["attributes"]["url"],
                 version=str(parser.parse(record["SystemModstamp"]).timestamp()),
                 date_created=str(parser.parse(record["CreatedDate"]).timestamp()),
@@ -207,11 +214,6 @@ class SalesforceDownloader(Downloader):
     )
     connector_type: str = CONNECTOR_TYPE

-    def get_download_path(self, file_data: FileData) -> Path:
-        rel_path = file_data.source_identifiers.relative_path
-        rel_path = rel_path[1:] if rel_path.startswith("/") else rel_path
-        return self.download_dir / Path(rel_path)
-
     def _xml_for_record(self, record: OrderedDict) -> str:
         """Creates partitionable xml file from a record"""
         import xml.etree.ElementTree as ET
unstructured_ingest/v2/processes/connectors/sharepoint.py

@@ -6,10 +6,8 @@ from time import time
 from typing import TYPE_CHECKING, Any, Generator, Optional
 from urllib.parse import quote

-from unstructured.documents.elements import DataSourceMetadata
-
 from unstructured_ingest.enhanced_dataclass import EnhancedDataClassJsonMixin, enhanced_field
-from unstructured_ingest.error import SourceConnectionNetworkError
+from unstructured_ingest.error import SourceConnectionError, SourceConnectionNetworkError
 from unstructured_ingest.utils.dep_check import requires_dependencies
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
@@ -18,6 +16,7 @@ from unstructured_ingest.v2.interfaces import (
     DownloaderConfig,
     DownloadResponse,
     FileData,
+    FileDataSourceMetadata,
     Indexer,
     IndexerConfig,
     SourceIdentifiers,
@@ -134,6 +133,14 @@ class SharepointIndexer(Indexer):
     connection_config: SharepointConnectionConfig
     index_config: SharepointIndexerConfig = field(default_factory=lambda: SharepointIndexerConfig())

+    def precheck(self) -> None:
+        try:
+            site_client = self.connection_config.get_client()
+            site_client.site_pages.pages.get().execute_query()
+        except Exception as e:
+            logger.error(f"failed to validate connection: {e}", exc_info=True)
+            raise SourceConnectionError(f"failed to validate connection: {e}")
+
     def list_files(self, folder: "Folder", recursive: bool = False) -> list["File"]:
         if not recursive:
             folder.expand(["Files"]).get().execute_query()
@@ -187,7 +194,7 @@
                 fullpath=file_path,
                 rel_path=file_path.replace(self.index_config.path, ""),
             ),
-            metadata=DataSourceMetadata(
+            metadata=FileDataSourceMetadata(
                 url=url,
                 version=version,
                 date_modified=str(date_modified_dt.timestamp()) if date_modified_dt else None,
@@ -222,7 +229,7 @@
                 fullpath=fullpath,
                 rel_path=rel_path,
             ),
-            metadata=DataSourceMetadata(
+            metadata=FileDataSourceMetadata(
                 url=absolute_url,
                 version=f"{file.major_version}.{file.minor_version}",
                 date_modified=str(date_modified_dt.timestamp()) if date_modified_dt else None,
@@ -340,10 +347,9 @@ class SharepointDownloader(Downloader):
     connector_type: str = CONNECTOR_TYPE

     def get_download_path(self, file_data: FileData) -> Path:
+        download_path = super().get_download_path(file_data=file_data)
+
         content_type = file_data.additional_metadata.get("sharepoint_content_type")
-        rel_path = file_data.source_identifiers.fullpath
-        rel_path = rel_path[1:] if rel_path.startswith("/") else rel_path
-        download_path = self.download_dir / Path(rel_path)
         if content_type == SharepointContentType.SITEPAGE.value:
             # Update output extension to html if site page
             download_path = download_path.with_suffix(".html")
unstructured_ingest/v2/processes/connectors/sql.py

@@ -4,13 +4,14 @@ import uuid
 from dataclasses import dataclass, field
 from datetime import date, datetime
 from pathlib import Path
-from typing import Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union

 import numpy as np
 import pandas as pd
 from dateutil import parser

 from unstructured_ingest.enhanced_dataclass import enhanced_field
+from unstructured_ingest.error import DestinationConnectionError
 from unstructured_ingest.utils.dep_check import requires_dependencies
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
@@ -25,6 +26,11 @@ from unstructured_ingest.v2.interfaces import (
 from unstructured_ingest.v2.logger import logger
 from unstructured_ingest.v2.processes.connector_registry import DestinationRegistryEntry

+if TYPE_CHECKING:
+    from sqlite3 import Connection as SqliteConnection
+
+    from psycopg2.extensions import connection as PostgresConnection
+
 CONNECTOR_TYPE = "sql"
 ELEMENTS_TABLE_NAME = "elements"

@@ -41,7 +47,7 @@ class DatabaseType(str, enum.Enum):


 @dataclass
-class SimpleSqlConfig(ConnectionConfig):
+class SQLConnectionConfig(ConnectionConfig):
     db_type: DatabaseType = (
         # required default value here because of parent class
         DatabaseType.SQLITE
@@ -134,7 +140,7 @@ class SQLUploadStager(UploadStager):
         **kwargs: Any,
     ) -> Path:
         with open(elements_filepath) as elements_file:
-            elements_contents = json.load(elements_file)
+            elements_contents: list[dict] = json.load(elements_file)
         output_path = Path(output_dir) / Path(f"{output_filename}.json")
         output_path.parent.mkdir(parents=True, exist_ok=True)

@@ -151,7 +157,7 @@
            data["id"] = str(uuid.uuid4())

            # remove extraneous, not supported columns
-           [data.pop(column) for column in data if column not in _COLUMNS]
+           data = {k: v for k, v in data.items() if k in _COLUMNS}

            output.append(data)

@@ -185,23 +191,32 @@ class SQLUploaderConfig(UploaderConfig):
 class SQLUploader(Uploader):
     connector_type: str = CONNECTOR_TYPE
     upload_config: SQLUploaderConfig
-    connection_config: SimpleSqlConfig
+    connection_config: SQLConnectionConfig
+
+    def precheck(self) -> None:
+        try:
+            cursor = self.connection().cursor()
+            cursor.execute("SELECT 1;")
+            cursor.close()
+        except Exception as e:
+            logger.error(f"failed to validate connection: {e}", exc_info=True)
+            raise DestinationConnectionError(f"failed to validate connection: {e}")

     @property
-    def connection(self):
+    def connection(self) -> Callable[[], Union["SqliteConnection", "PostgresConnection"]]:
         if self.connection_config.db_type == DatabaseType.POSTGRESQL:
             return self._make_psycopg_connection
         elif self.connection_config.db_type == DatabaseType.SQLITE:
             return self._make_sqlite_connection
         raise ValueError(f"Unsupported database {self.connection_config.db_type} connection.")

-    def _make_sqlite_connection(self):
+    def _make_sqlite_connection(self) -> "SqliteConnection":
         from sqlite3 import connect

         return connect(database=self.connection_config.database)

     @requires_dependencies(["psycopg2"], extras="postgres")
-    def _make_psycopg_connection(self):
+    def _make_psycopg_connection(self) -> "PostgresConnection":
         from psycopg2 import connect

         return connect(
@@ -261,7 +276,7 @@ class SQLUploader(Uploader):


 sql_destination_entry = DestinationRegistryEntry(
-    connection_config=SimpleSqlConfig,
+    connection_config=SQLConnectionConfig,
     uploader=SQLUploader,
     uploader_config=SQLUploaderConfig,
     upload_stager=SQLUploadStager,
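With SimpleSqlConfig renamed to SQLConnectionConfig and precheck() added, wiring up a SQLite destination might look like the sketch below; only db_type and database are visible in this diff, so the remaining constructor arguments are assumed to be optional:

from unstructured_ingest.v2.processes.connectors.sql import (
    DatabaseType,
    SQLConnectionConfig,
    SQLUploader,
    SQLUploaderConfig,
)

connection_config = SQLConnectionConfig(db_type=DatabaseType.SQLITE, database="elements.db")
uploader = SQLUploader(connection_config=connection_config, upload_config=SQLUploaderConfig())

uploader.precheck()                    # opens a cursor and runs "SELECT 1;" against the target
make_connection = uploader.connection  # the property returns a factory function
conn = make_connection()               # which is called to obtain a sqlite3/psycopg2 connection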
unstructured_ingest/v2/processes/connectors/weaviate.py

@@ -7,6 +7,7 @@ from typing import TYPE_CHECKING, Any, Optional
 from dateutil import parser

 from unstructured_ingest.enhanced_dataclass import enhanced_field
+from unstructured_ingest.error import DestinationConnectionError
 from unstructured_ingest.utils.dep_check import requires_dependencies
 from unstructured_ingest.v2.interfaces import (
     AccessConfig,
@@ -156,15 +157,21 @@ class WeaviateUploaderConfig(UploaderConfig):
 class WeaviateUploader(Uploader):
     upload_config: WeaviateUploaderConfig
     connection_config: WeaviateConnectionConfig
-    client: Optional["Client"] = field(init=False)
     connector_type: str = CONNECTOR_TYPE

     @requires_dependencies(["weaviate"], extras="weaviate")
-    def __post_init__(self):
+    def get_client(self) -> "Client":
         from weaviate import Client

         auth = self._resolve_auth_method()
-        self.client = Client(url=self.connection_config.host_url, auth_client_secret=auth)
+        return Client(url=self.connection_config.host_url, auth_client_secret=auth)
+
+    def precheck(self) -> None:
+        try:
+            self.get_client()
+        except Exception as e:
+            logger.error(f"Failed to validate connection {e}", exc_info=True)
+            raise DestinationConnectionError(f"failed to validate connection: {e}")

     @requires_dependencies(["weaviate"], extras="weaviate")
     def _resolve_auth_method(self):
@@ -215,8 +222,9 @@ class WeaviateUploader(Uploader):
             f"at {self.connection_config.host_url}",
         )

-        self.client.batch.configure(batch_size=self.upload_config.batch_size)
-        with self.client.batch as b:
+        client = self.get_client()
+        client.batch.configure(batch_size=self.upload_config.batch_size)
+        with client.batch as b:
             for e in elements_dict:
                 vector = e.pop("embeddings", None)
                 b.add_data_object(
unstructured_ingest/v2/processes/filter.py

@@ -0,0 +1,54 @@
+import fnmatch
+from abc import ABC
+from dataclasses import dataclass, field
+from typing import Any, Callable, Optional
+
+from unstructured_ingest.enhanced_dataclass import EnhancedDataClassJsonMixin
+from unstructured_ingest.v2.interfaces import FileData
+from unstructured_ingest.v2.interfaces.process import BaseProcess
+from unstructured_ingest.v2.logger import logger
+
+
+@dataclass
+class FiltererConfig(EnhancedDataClassJsonMixin):
+    file_glob: Optional[list[str]] = None
+    max_file_size: Optional[int] = None
+
+
+@dataclass
+class Filterer(BaseProcess, ABC):
+    config: FiltererConfig = field(default_factory=lambda: FiltererConfig())
+    filters: list[Callable[[FileData], bool]] = field(init=False, default_factory=list)
+
+    def __post_init__(self):
+        # Populate the filters based on values in config
+        if self.config.file_glob is not None:
+            self.filters.append(self.glob_filter)
+        if self.config.max_file_size:
+            self.filters.append(self.file_size_filter)
+
+    def is_async(self) -> bool:
+        return False
+
+    def file_size_filter(self, file_data: FileData) -> bool:
+        if filesize_bytes := file_data.metadata.filesize_bytes:
+            return filesize_bytes <= self.config.max_file_size
+        return True
+
+    def glob_filter(self, file_data: FileData) -> bool:
+        patterns = self.config.file_glob
+        path = file_data.source_identifiers.fullpath
+        for pattern in patterns:
+            if fnmatch.filter([path], pattern):
+                return True
+        logger.debug(f"The file {path!r} is discarded as it does not match any given glob.")
+        return False
+
+    def run(self, file_data: FileData, **kwargs: Any) -> Optional[FileData]:
+        for filter in self.filters:
+            if not filter(file_data):
+                logger.debug(
+                    f"filtered out file data due to {filter.__name__}: {file_data.identifier}"
+                )
+                return None
+        return file_data
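The Filterer above backs the new filter step in the v2 pipeline (see pipeline/steps/filter.py and cli/configs/filter.py in the file list). Driving it directly might look like this sketch, assuming a FileData instance whose source_identifiers.fullpath and metadata.filesize_bytes are populated:

from unstructured_ingest.v2.processes.filter import Filterer, FiltererConfig

filterer = Filterer(
    config=FiltererConfig(
        file_glob=["*.pdf", "*.docx"],   # keep files matching at least one glob
        max_file_size=10 * 1024 * 1024,  # and no larger than 10 MiB
    )
)

# Given a FileData produced by an indexer, run() returns it untouched when every
# active filter passes and None when any filter rejects it, e.g.:
#     kept = filterer.run(file_data=file_data)
#     if kept is None:
#         ...  # the document is dropped before the download step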