unstructured-ingest 0.3.14__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of unstructured-ingest might be problematic.

Files changed (23)
  1. test/integration/connectors/sql/test_databricks_delta_tables.py +142 -0
  2. test/integration/connectors/test_confluence.py +4 -4
  3. test/integration/connectors/test_pinecone.py +68 -2
  4. test/unit/v2/connectors/sql/__init__.py +0 -0
  5. test/unit/v2/connectors/sql/test_sql.py +72 -0
  6. test/unit/v2/connectors/test_confluence.py +6 -6
  7. unstructured_ingest/__version__.py +1 -1
  8. unstructured_ingest/v2/interfaces/upload_stager.py +3 -3
  9. unstructured_ingest/v2/processes/connectors/confluence.py +30 -10
  10. unstructured_ingest/v2/processes/connectors/databricks/__init__.py +6 -0
  11. unstructured_ingest/v2/processes/connectors/databricks/volumes.py +6 -3
  12. unstructured_ingest/v2/processes/connectors/databricks/volumes_table.py +106 -0
  13. unstructured_ingest/v2/processes/connectors/pinecone.py +18 -11
  14. unstructured_ingest/v2/processes/connectors/sql/__init__.py +6 -0
  15. unstructured_ingest/v2/processes/connectors/sql/databricks_delta_tables.py +213 -0
  16. unstructured_ingest/v2/processes/connectors/sql/snowflake.py +1 -1
  17. unstructured_ingest/v2/processes/connectors/sql/sql.py +28 -9
  18. {unstructured_ingest-0.3.14.dist-info → unstructured_ingest-0.4.0.dist-info}/METADATA +22 -20
  19. {unstructured_ingest-0.3.14.dist-info → unstructured_ingest-0.4.0.dist-info}/RECORD +23 -18
  20. {unstructured_ingest-0.3.14.dist-info → unstructured_ingest-0.4.0.dist-info}/LICENSE.md +0 -0
  21. {unstructured_ingest-0.3.14.dist-info → unstructured_ingest-0.4.0.dist-info}/WHEEL +0 -0
  22. {unstructured_ingest-0.3.14.dist-info → unstructured_ingest-0.4.0.dist-info}/entry_points.txt +0 -0
  23. {unstructured_ingest-0.3.14.dist-info → unstructured_ingest-0.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,142 @@
+import json
+import os
+import time
+from contextlib import contextmanager
+from pathlib import Path
+from uuid import uuid4
+
+import pytest
+from databricks.sql import connect
+from databricks.sql.client import Connection as DeltaTableConnection
+from databricks.sql.client import Cursor as DeltaTableCursor
+from pydantic import BaseModel, SecretStr
+
+from test.integration.connectors.utils.constants import DESTINATION_TAG, SQL_TAG, env_setup_path
+from test.integration.utils import requires_env
+from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
+from unstructured_ingest.v2.logger import logger
+from unstructured_ingest.v2.processes.connectors.sql.databricks_delta_tables import (
+    CONNECTOR_TYPE,
+    DatabrickDeltaTablesAccessConfig,
+    DatabrickDeltaTablesConnectionConfig,
+    DatabrickDeltaTablesUploader,
+    DatabrickDeltaTablesUploaderConfig,
+    DatabrickDeltaTablesUploadStager,
+)
+
+CATALOG = "utic-dev-tech-fixtures"
+
+
+class EnvData(BaseModel):
+    server_hostname: str
+    http_path: str
+    access_token: SecretStr
+
+
+def get_env_data() -> EnvData:
+    return EnvData(
+        server_hostname=os.environ["DATABRICKS_SERVER_HOSTNAME"],
+        http_path=os.environ["DATABRICKS_HTTP_PATH"],
+        access_token=os.environ["DATABRICKS_ACCESS_TOKEN"],
+    )
+
+
+def get_destination_schema(new_table_name: str) -> str:
+    p = Path(env_setup_path / "sql" / "databricks_delta_tables" / "destination" / "schema.sql")
+    with p.open() as f:
+        data_lines = f.readlines()
+    data_lines[0] = data_lines[0].replace("elements", new_table_name)
+    data = "".join([line.strip() for line in data_lines])
+    return data
+
+
+@contextmanager
+def get_connection() -> DeltaTableConnection:
+    env_data = get_env_data()
+    with connect(
+        server_hostname=env_data.server_hostname,
+        http_path=env_data.http_path,
+        access_token=env_data.access_token.get_secret_value(),
+    ) as connection:
+        yield connection
+
+
+@contextmanager
+def get_cursor() -> DeltaTableCursor:
+    with get_connection() as connection:
+        with connection.cursor() as cursor:
+            cursor.execute(f"USE CATALOG '{CATALOG}'")
+            yield cursor
+
+
+@pytest.fixture
+def destination_table() -> str:
+    random_id = str(uuid4())[:8]
+    table_name = f"elements_{random_id}"
+    destination_schema = get_destination_schema(new_table_name=table_name)
+    with get_cursor() as cursor:
+        logger.info(f"creating table: {table_name}")
+        cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
+        cursor.execute(destination_schema)
+
+    yield table_name
+    with get_cursor() as cursor:
+        logger.info(f"dropping table: {table_name}")
+        cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
+
+
+def validate_destination(expected_num_elements: int, table_name: str, retries=30, interval=1):
+    with get_cursor() as cursor:
+        for i in range(retries):
+            cursor.execute(f"SELECT COUNT(*) FROM {table_name}")
+            count = cursor.fetchone()[0]
+            if count == expected_num_elements:
+                break
+            logger.info(f"retry attempt {i}: expected {expected_num_elements} != count {count}")
+            time.sleep(interval)
+    assert (
+        count == expected_num_elements
+    ), f"dest check failed: got {count}, expected {expected_num_elements}"
+
+
+@pytest.mark.asyncio
+@pytest.mark.skip("Resources take too long to spin up to run in CI")
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG, SQL_TAG)
+@requires_env("DATABRICKS_SERVER_HOSTNAME", "DATABRICKS_HTTP_PATH", "DATABRICKS_ACCESS_TOKEN")
+async def test_databricks_delta_tables_destination(
+    upload_file: Path, temp_dir: Path, destination_table: str
+):
+    env_data = get_env_data()
+    mock_file_data = FileData(
+        identifier="mock file data",
+        connector_type=CONNECTOR_TYPE,
+        source_identifiers=SourceIdentifiers(filename=upload_file.name, fullpath=upload_file.name),
+    )
+    stager = DatabrickDeltaTablesUploadStager()
+    staged_path = stager.run(
+        elements_filepath=upload_file,
+        file_data=mock_file_data,
+        output_dir=temp_dir,
+        output_filename=upload_file.name,
+    )
+
+    assert staged_path.suffix == upload_file.suffix
+
+    uploader = DatabrickDeltaTablesUploader(
+        connection_config=DatabrickDeltaTablesConnectionConfig(
+            access_config=DatabrickDeltaTablesAccessConfig(
+                token=env_data.access_token.get_secret_value()
+            ),
+            http_path=env_data.http_path,
+            server_hostname=env_data.server_hostname,
+        ),
+        upload_config=DatabrickDeltaTablesUploaderConfig(
+            catalog=CATALOG, database="default", table_name=destination_table
+        ),
+    )
+    with staged_path.open("r") as f:
+        staged_data = json.load(f)
+    expected_num_elements = len(staged_data)
+    uploader.precheck()
+    uploader.run(path=staged_path, file_data=mock_file_data)
+    validate_destination(expected_num_elements=expected_num_elements, table_name=destination_table)
@@ -30,10 +30,10 @@ async def test_confluence_source(temp_dir):
     spaces = ["testteamsp", "MFS"]

     # Create connection and indexer configurations
-    access_config = ConfluenceAccessConfig(api_token=api_token)
+    access_config = ConfluenceAccessConfig(password=api_token)
     connection_config = ConfluenceConnectionConfig(
         url=confluence_url,
-        user_email=user_email,
+        username=user_email,
         access_config=access_config,
     )
     index_config = ConfluenceIndexerConfig(
@@ -77,10 +77,10 @@ async def test_confluence_source_large(temp_dir):
     spaces = ["testteamsp1"]

     # Create connection and indexer configurations
-    access_config = ConfluenceAccessConfig(api_token=api_token)
+    access_config = ConfluenceAccessConfig(password=api_token)
     connection_config = ConfluenceConnectionConfig(
         url=confluence_url,
-        user_email=user_email,
+        username=user_email,
         access_config=access_config,
     )
     index_config = ConfluenceIndexerConfig(
@@ -107,11 +107,15 @@ def pinecone_index() -> Generator[str, None, None]:


 def validate_pinecone_index(
-    index_name: str, expected_num_of_vectors: int, retries=30, interval=1
+    index_name: str,
+    expected_num_of_vectors: int,
+    retries=30,
+    interval=1,
+    namespace: str = "default",
 ) -> None:
     # Because there's a delay for the index to catch up to the recent writes, add in a retry
     pinecone = Pinecone(api_key=get_api_key())
-    index = pinecone.Index(name=index_name)
+    index = pinecone.Index(name=index_name, namespace=namespace)
     vector_count = -1
     for i in range(retries):
         index_stats = index.describe_index_stats()
@@ -133,11 +137,13 @@ def validate_pinecone_index(
 @pytest.mark.asyncio
 @pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG, VECTOR_DB_TAG)
 async def test_pinecone_destination(pinecone_index: str, upload_file: Path, temp_dir: Path):
+
     file_data = FileData(
         source_identifiers=SourceIdentifiers(fullpath=upload_file.name, filename=upload_file.name),
         connector_type=CONNECTOR_TYPE,
         identifier="pinecone_mock_id",
     )
+
     connection_config = PineconeConnectionConfig(
         index_name=pinecone_index,
         access_config=PineconeAccessConfig(api_key=get_api_key()),
@@ -224,6 +230,66 @@ async def test_pinecone_destination_large_index(
     )


+@requires_env(API_KEY)
+@pytest.mark.asyncio
+@pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG, VECTOR_DB_TAG)
+async def test_pinecone_destination_namespace(
+    pinecone_index: str, upload_file: Path, temp_dir: Path
+):
+    """
+    tests namespace functionality of destination connector.
+    """
+
+    # creates a file data structure.
+    file_data = FileData(
+        source_identifiers=SourceIdentifiers(fullpath=upload_file.name, filename=upload_file.name),
+        connector_type=CONNECTOR_TYPE,
+        identifier="pinecone_mock_id",
+    )
+
+    connection_config = PineconeConnectionConfig(
+        index_name=pinecone_index,
+        access_config=PineconeAccessConfig(api_key=get_api_key()),
+    )
+
+    stager_config = PineconeUploadStagerConfig()
+
+    stager = PineconeUploadStager(upload_stager_config=stager_config)
+    new_upload_file = stager.run(
+        elements_filepath=upload_file,
+        output_dir=temp_dir,
+        output_filename=upload_file.name,
+        file_data=file_data,
+    )
+
+    # here add namespace defintion
+    upload_config = PineconeUploaderConfig()
+    namespace_test_name = "user-1"
+    upload_config.namespace = namespace_test_name
+    uploader = PineconeUploader(connection_config=connection_config, upload_config=upload_config)
+    uploader.precheck()
+
+    uploader.run(path=new_upload_file, file_data=file_data)
+    with new_upload_file.open() as f:
+        staged_content = json.load(f)
+    expected_num_of_vectors = len(staged_content)
+    logger.info("validating first upload")
+    validate_pinecone_index(
+        index_name=pinecone_index,
+        expected_num_of_vectors=expected_num_of_vectors,
+        namespace=namespace_test_name,
+    )
+
+    # Rerun uploader and make sure no duplicates exist
+    uploader.run(path=new_upload_file, file_data=file_data)
+    logger.info("validating second upload")
+    validate_pinecone_index(
+        index_name=pinecone_index,
+        expected_num_of_vectors=expected_num_of_vectors,
+        namespace=namespace_test_name,
+    )
+
+
 @requires_env(API_KEY)
 @pytest.mark.tags(CONNECTOR_TYPE, DESTINATION_TAG, VECTOR_DB_TAG)
 def test_large_metadata(pinecone_index: str, tmp_path: Path, upload_file: Path):
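
The namespace test above exercises the `namespace` option added to the Pinecone destination in this release. For reference, a minimal sketch of the same flow outside the test harness, assuming `namespace` is a declared field on `PineconeUploaderConfig` (the test assigns it after construction) and that a staged elements file already exists; identifiers, paths, and credentials are placeholders:

from pathlib import Path

from unstructured_ingest.v2.interfaces import FileData, SourceIdentifiers
from unstructured_ingest.v2.processes.connectors.pinecone import (
    PineconeAccessConfig,
    PineconeConnectionConfig,
    PineconeUploader,
    PineconeUploaderConfig,
)

# A file produced by PineconeUploadStager (a JSON list of element dicts).
staged_file = Path("staged-elements.json")
file_data = FileData(
    identifier="example-doc",
    connector_type="pinecone",  # illustrative; use the connector's CONNECTOR_TYPE constant
    source_identifiers=SourceIdentifiers(filename=staged_file.name, fullpath=staged_file.name),
)

# New in 0.4.0: direct the vectors into a specific namespace.
upload_config = PineconeUploaderConfig()
upload_config.namespace = "user-1"

uploader = PineconeUploader(
    connection_config=PineconeConnectionConfig(
        index_name="my-index",
        access_config=PineconeAccessConfig(api_key="<pinecone-api-key>"),
    ),
    upload_config=upload_config,
)
uploader.precheck()
uploader.run(path=staged_file, file_data=file_data)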
File without changes: test/unit/v2/connectors/sql/__init__.py
@@ -0,0 +1,72 @@
+from pathlib import Path
+
+import pytest
+from pytest_mock import MockerFixture
+
+from unstructured_ingest.v2.interfaces.file_data import FileData, SourceIdentifiers
+from unstructured_ingest.v2.processes.connectors.sql.sql import SQLUploadStager
+
+
+@pytest.fixture
+def mock_instance() -> SQLUploadStager:
+    return SQLUploadStager()
+
+
+@pytest.mark.parametrize(
+    ("input_filepath", "output_filename", "expected"),
+    [
+        (
+            "/path/to/input_file.ndjson",
+            "output_file.ndjson",
+            "output_file.ndjson",
+        ),
+        ("input_file.txt", "output_file.json", "output_file.txt"),
+        ("/path/to/input_file.json", "output_file", "output_file.json"),
+    ],
+)
+def test_run_output_filename_suffix(
+    mocker: MockerFixture,
+    mock_instance: SQLUploadStager,
+    input_filepath: str,
+    output_filename: str,
+    expected: str,
+):
+    output_dir = Path("/tmp/test/output_dir")
+
+    # Mocks
+    mock_get_data = mocker.patch(
+        "unstructured_ingest.v2.processes.connectors.sql.sql.get_data",
+        return_value=[{"key": "value"}, {"key": "value2"}],
+    )
+    mock_conform_dict = mocker.patch.object(
+        SQLUploadStager, "conform_dict", side_effect=lambda element_dict, file_data: element_dict
+    )
+    mock_conform_dataframe = mocker.patch.object(
+        SQLUploadStager, "conform_dataframe", side_effect=lambda df: df
+    )
+    mock_get_output_path = mocker.patch.object(
+        SQLUploadStager, "get_output_path", return_value=output_dir / expected
+    )
+    mock_write_output = mocker.patch.object(SQLUploadStager, "write_output")
+
+    # Act
+    result = mock_instance.run(
+        elements_filepath=Path(input_filepath),
+        file_data=FileData(
+            identifier="test",
+            connector_type="test",
+            source_identifiers=SourceIdentifiers(filename=input_filepath, fullpath=input_filepath),
+        ),
+        output_dir=output_dir,
+        output_filename=output_filename,
+    )
+
+    # Assert
+    mock_get_data.assert_called_once_with(path=Path(input_filepath))
+    assert mock_conform_dict.call_count == 2
+    mock_conform_dataframe.assert_called_once()
+    mock_get_output_path.assert_called_once_with(output_filename=expected, output_dir=output_dir)
+    mock_write_output.assert_called_once_with(
+        output_path=output_dir / expected, data=[{"key": "value"}, {"key": "value2"}]
+    )
+    assert result.name == expected
@@ -11,10 +11,10 @@ def test_connection_config_multiple_auth():
     with pytest.raises(ValidationError):
         ConfluenceConnectionConfig(
             access_config=ConfluenceAccessConfig(
-                api_token="api_token",
-                access_token="access_token",
+                password="api_token",
+                token="access_token",
             ),
-            user_email="user_email",
+            username="user_email",
             url="url",
         )

@@ -26,14 +26,14 @@ def test_connection_config_no_auth():

 def test_connection_config_basic_auth():
     ConfluenceConnectionConfig(
-        access_config=ConfluenceAccessConfig(api_token="api_token"),
+        access_config=ConfluenceAccessConfig(password="api_token"),
         url="url",
-        user_email="user_email",
+        username="user_email",
     )


 def test_connection_config_pat_auth():
     ConfluenceConnectionConfig(
-        access_config=ConfluenceAccessConfig(access_token="access_token"),
+        access_config=ConfluenceAccessConfig(token="access_token"),
         url="url",
     )
@@ -1 +1 @@
-__version__ = "0.3.14"  # pragma: no cover
+__version__ = "0.4.0"  # pragma: no cover
@@ -2,7 +2,7 @@ import json
 from abc import ABC
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, TypeVar
+from typing import Any, Optional, TypeVar

 import ndjson
 from pydantic import BaseModel
@@ -22,10 +22,10 @@ UploadStagerConfigT = TypeVar("UploadStagerConfigT", bound=UploadStagerConfig)
 class UploadStager(BaseProcess, ABC):
     upload_stager_config: UploadStagerConfigT

-    def write_output(self, output_path: Path, data: list[dict]) -> None:
+    def write_output(self, output_path: Path, data: list[dict], indent: Optional[int] = 2) -> None:
         if output_path.suffix == ".json":
             with output_path.open("w") as f:
-                json.dump(data, f, indent=2)
+                json.dump(data, f, indent=indent)
         elif output_path.suffix == ".ndjson":
             with output_path.open("w") as f:
                 ndjson.dump(data, f)
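
The new `indent` parameter on `UploadStager.write_output` lets subclasses control JSON serialization without re-implementing the method. A minimal sketch of one way a custom stager might use it, assuming only the base-class behavior shown above (`CompactJsonStager` is a hypothetical name, not part of this release):

from pathlib import Path
from typing import Optional

from unstructured_ingest.v2.interfaces.upload_stager import UploadStager


class CompactJsonStager(UploadStager):
    # Hypothetical subclass: write staged .json output without indentation,
    # e.g. to avoid embedded newlines when a downstream loader re-reads the file.
    def write_output(
        self, output_path: Path, data: list[dict], indent: Optional[int] = None
    ) -> None:
        # indent=None is forwarded to json.dump(), producing compact output;
        # the .ndjson branch in the base class is unaffected.
        super().write_output(output_path=output_path, data=data, indent=indent)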
@@ -30,27 +30,45 @@ CONNECTOR_TYPE = "confluence"


 class ConfluenceAccessConfig(AccessConfig):
-    api_token: Optional[str] = Field(description="Confluence API token", default=None)
-    access_token: Optional[str] = Field(
-        description="Confluence Personal Access Token", default=None
+    password: Optional[str] = Field(
+        description="Confluence password or Cloud API token",
+        default=None,
+    )
+    token: Optional[str] = Field(
+        description="Confluence Personal Access Token",
+        default=None,
     )


 class ConfluenceConnectionConfig(ConnectionConfig):
     url: str = Field(description="URL of the Confluence instance")
-    user_email: Optional[str] = Field(description="User email for authentication", default=None)
+    username: Optional[str] = Field(
+        description="Username or email for authentication",
+        default=None,
+    )
+    cloud: bool = Field(description="Authenticate to Confluence Cloud", default=False)
     access_config: Secret[ConfluenceAccessConfig] = Field(
         description="Access configuration for Confluence"
     )

     def model_post_init(self, __context):
         access_configs = self.access_config.get_secret_value()
-        basic_auth = self.user_email and access_configs.api_token
-        pat_auth = access_configs.access_token
+        basic_auth = self.username and access_configs.password
+        pat_auth = access_configs.token
+        if self.cloud and not basic_auth:
+            raise ValueError(
+                "cloud authentication requires username and API token (--password), "
+                "see: https://atlassian-python-api.readthedocs.io/"
+            )
         if basic_auth and pat_auth:
-            raise ValueError("both forms of auth provided, only one allowed")
+            raise ValueError(
+                "both password and token provided, only one allowed, "
+                "see: https://atlassian-python-api.readthedocs.io/"
+            )
         if not (basic_auth or pat_auth):
-            raise ValueError("neither forms of auth provided")
+            raise ValueError(
+                "no form of auth provided, see: https://atlassian-python-api.readthedocs.io/"
+            )

     @requires_dependencies(["atlassian"], extras="confluence")
     def get_client(self) -> "Confluence":
@@ -59,8 +77,10 @@ class ConfluenceConnectionConfig(ConnectionConfig):
         access_configs = self.access_config.get_secret_value()
         return Confluence(
             url=self.url,
-            username=self.user_email,
-            password=access_configs.api_token,
+            username=self.username,
+            password=access_configs.password,
+            token=access_configs.token,
+            cloud=self.cloud,
         )


@@ -25,6 +25,8 @@ from .volumes_native import (
     databricks_native_volumes_destination_entry,
     databricks_native_volumes_source_entry,
 )
+from .volumes_table import CONNECTOR_TYPE as VOLUMES_TABLE_CONNECTOR_TYPE
+from .volumes_table import databricks_volumes_delta_tables_destination_entry

 add_source_entry(source_type=VOLUMES_AWS_CONNECTOR_TYPE, entry=databricks_aws_volumes_source_entry)
 add_destination_entry(
@@ -50,3 +52,7 @@ add_source_entry(
 add_destination_entry(
     destination_type=VOLUMES_AZURE_CONNECTOR_TYPE, entry=databricks_azure_volumes_destination_entry
 )
+add_destination_entry(
+    destination_type=VOLUMES_TABLE_CONNECTOR_TYPE,
+    entry=databricks_volumes_delta_tables_destination_entry,
+)
@@ -187,6 +187,11 @@ class DatabricksVolumesUploader(Uploader, ABC):
     upload_config: DatabricksVolumesUploaderConfig
     connection_config: DatabricksVolumesConnectionConfig

+    def get_output_path(self, file_data: FileData) -> str:
+        return os.path.join(
+            self.upload_config.path, f"{file_data.source_identifiers.filename}.json"
+        )
+
     def precheck(self) -> None:
         try:
             assert self.connection_config.get_client().current_user.me().active
@@ -194,9 +199,7 @@ class DatabricksVolumesUploader(Uploader, ABC):
             raise self.connection_config.wrap_error(e=e)

     def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
-        output_path = os.path.join(
-            self.upload_config.path, f"{file_data.source_identifiers.filename}.json"
-        )
+        output_path = self.get_output_path(file_data=file_data)
         with open(path, "rb") as elements_file:
             try:
                 self.connection_config.get_client().files.upload(
@@ -0,0 +1,106 @@
+import json
+import os
+from contextlib import contextmanager
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Generator, Optional
+
+from pydantic import Field
+
+from unstructured_ingest.v2.interfaces import FileData, Uploader, UploaderConfig
+from unstructured_ingest.v2.logger import logger
+from unstructured_ingest.v2.processes.connector_registry import (
+    DestinationRegistryEntry,
+)
+from unstructured_ingest.v2.processes.connectors.databricks.volumes import DatabricksPathMixin
+from unstructured_ingest.v2.processes.connectors.sql.databricks_delta_tables import (
+    DatabrickDeltaTablesConnectionConfig,
+    DatabrickDeltaTablesUploadStager,
+    DatabrickDeltaTablesUploadStagerConfig,
+)
+
+CONNECTOR_TYPE = "databricks_volume_delta_tables"
+
+
+class DatabricksVolumeDeltaTableUploaderConfig(UploaderConfig, DatabricksPathMixin):
+    database: str = Field(description="Database name", default="default")
+    table_name: str = Field(description="Table name")
+
+
+@dataclass
+class DatabricksVolumeDeltaTableStager(DatabrickDeltaTablesUploadStager):
+    def write_output(self, output_path: Path, data: list[dict], indent: Optional[int] = 2) -> None:
+        # To avoid new line issues when migrating from volumes into delta tables, omit indenting
+        # and always write it as a json file
+        with output_path.with_suffix(".json").open("w") as f:
+            json.dump(data, f)
+
+
+@dataclass
+class DatabricksVolumeDeltaTableUploader(Uploader):
+    connection_config: DatabrickDeltaTablesConnectionConfig
+    upload_config: DatabricksVolumeDeltaTableUploaderConfig
+    connector_type: str = CONNECTOR_TYPE
+
+    def precheck(self) -> None:
+        with self.connection_config.get_cursor() as cursor:
+            cursor.execute("SHOW CATALOGS")
+            catalogs = [r[0] for r in cursor.fetchall()]
+            if self.upload_config.catalog not in catalogs:
+                raise ValueError(
+                    "Catalog {} not found in {}".format(
+                        self.upload_config.catalog, ", ".join(catalogs)
+                    )
+                )
+            cursor.execute(f"USE CATALOG '{self.upload_config.catalog}'")
+            cursor.execute("SHOW DATABASES")
+            databases = [r[0] for r in cursor.fetchall()]
+            if self.upload_config.database not in databases:
+                raise ValueError(
+                    "Database {} not found in {}".format(
+                        self.upload_config.database, ", ".join(databases)
+                    )
+                )
+            cursor.execute("SHOW TABLES")
+            table_names = [r[1] for r in cursor.fetchall()]
+            if self.upload_config.table_name not in table_names:
+                raise ValueError(
+                    "Table {} not found in {}".format(
+                        self.upload_config.table_name, ", ".join(table_names)
+                    )
+                )
+
+    def get_output_path(self, file_data: FileData, suffix: str = ".json") -> str:
+        filename = Path(file_data.source_identifiers.filename)
+        adjusted_filename = filename if filename.suffix == suffix else f"{filename}{suffix}"
+        return os.path.join(self.upload_config.path, f"{adjusted_filename}")
+
+    @contextmanager
+    def get_cursor(self, **connect_kwargs) -> Generator[Any, None, None]:
+        with self.connection_config.get_cursor(**connect_kwargs) as cursor:
+            cursor.execute(f"USE CATALOG '{self.upload_config.catalog}'")
+            yield cursor
+
+    def run(self, path: Path, file_data: FileData, **kwargs: Any) -> None:
+        with self.get_cursor(staging_allowed_local_path=str(path.parent)) as cursor:
+            catalog_path = self.get_output_path(file_data=file_data)
+            logger.debug(f"uploading {path.as_posix()} to {catalog_path}")
+            cursor.execute(f"PUT '{path.as_posix()}' INTO '{catalog_path}' OVERWRITE")
+            logger.debug(
+                f"migrating content from {catalog_path} to table {self.upload_config.table_name}"
+            )
+            with path.open() as f:
+                data = json.load(f)
+                columns = data[0].keys()
+                column_str = ", ".join(columns)
+                sql_statment = f"INSERT INTO `{self.upload_config.table_name}` ({column_str}) SELECT {column_str} FROM json.`{catalog_path}`"  # noqa: E501
+                cursor.execute(sql_statment)
+
+
+databricks_volumes_delta_tables_destination_entry = DestinationRegistryEntry(
+    connection_config=DatabrickDeltaTablesConnectionConfig,
+    uploader=DatabricksVolumeDeltaTableUploader,
+    uploader_config=DatabricksVolumeDeltaTableUploaderConfig,
+    upload_stager=DatabricksVolumeDeltaTableStager,
+    upload_stager_config=DatabrickDeltaTablesUploadStagerConfig,
+)
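
The new volumes_table.py module registers a `databricks_volume_delta_tables` destination that stages compact JSON, PUTs it into a Databricks volume, and then INSERTs it into a Delta table from that staged file. A minimal sketch of configuring it directly, assuming the connection fields shown earlier in this diff; the `volume` field is an assumption about what `DatabricksPathMixin` provides, since only `database` and `table_name` are declared in this file:

from unstructured_ingest.v2.processes.connectors.databricks.volumes_table import (
    DatabricksVolumeDeltaTableUploader,
    DatabricksVolumeDeltaTableUploaderConfig,
)
from unstructured_ingest.v2.processes.connectors.sql.databricks_delta_tables import (
    DatabrickDeltaTablesAccessConfig,
    DatabrickDeltaTablesConnectionConfig,
)

connection_config = DatabrickDeltaTablesConnectionConfig(
    server_hostname="<workspace>.cloud.databricks.com",
    http_path="/sql/1.0/warehouses/<warehouse-id>",
    access_config=DatabrickDeltaTablesAccessConfig(token="<databricks-token>"),
)

upload_config = DatabricksVolumeDeltaTableUploaderConfig(
    catalog="<catalog>",   # checked against SHOW CATALOGS in precheck()
    volume="<volume>",     # assumed DatabricksPathMixin field; not shown in this diff
    database="default",
    table_name="elements",
)

uploader = DatabricksVolumeDeltaTableUploader(
    connection_config=connection_config,
    upload_config=upload_config,
)
uploader.precheck()  # verifies the catalog, database, and table exist before uploading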