pydataframer-databricks 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pydataframer_databricks-0.1.0/.github/workflows/python-publish.yml +70 -0
- pydataframer_databricks-0.1.0/.gitignore +11 -0
- pydataframer_databricks-0.1.0/PKG-INFO +43 -0
- pydataframer_databricks-0.1.0/README.md +28 -0
- pydataframer_databricks-0.1.0/pydataframer_databricks/__init__.py +11 -0
- pydataframer_databricks-0.1.0/pydataframer_databricks/connectors.py +251 -0
- pydataframer_databricks-0.1.0/pyproject.toml +25 -0
- pydataframer_databricks-0.1.0/tests/__init__.py +0 -0
- pydataframer_databricks-0.1.0/tests/test_connectors.py +231 -0
|
@@ -0,0 +1,70 @@
|
|
|
1
|
+
# This workflow will upload a Python Package to PyPI when a release is created
|
|
2
|
+
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
|
|
3
|
+
|
|
4
|
+
# This workflow uses actions that are not certified by GitHub.
|
|
5
|
+
# They are provided by a third-party and are governed by
|
|
6
|
+
# separate terms of service, privacy policy, and support
|
|
7
|
+
# documentation.
|
|
8
|
+
|
|
9
|
+
name: Upload Python Package
|
|
10
|
+
|
|
11
|
+
on:
|
|
12
|
+
release:
|
|
13
|
+
types: [published]
|
|
14
|
+
|
|
15
|
+
permissions:
|
|
16
|
+
contents: read
|
|
17
|
+
|
|
18
|
+
jobs:
|
|
19
|
+
release-build:
|
|
20
|
+
runs-on: ubuntu-latest
|
|
21
|
+
|
|
22
|
+
steps:
|
|
23
|
+
- uses: actions/checkout@v4
|
|
24
|
+
|
|
25
|
+
- uses: actions/setup-python@v5
|
|
26
|
+
with:
|
|
27
|
+
python-version: "3.x"
|
|
28
|
+
|
|
29
|
+
- name: Build release distributions
|
|
30
|
+
run: |
|
|
31
|
+
# Install uv, then build the sdist and wheel into dist/ (uploaded by the next step).
|
|
32
|
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
|
33
|
+
uv build
|
|
34
|
+
|
|
35
|
+
- name: Upload distributions
|
|
36
|
+
uses: actions/upload-artifact@v4
|
|
37
|
+
with:
|
|
38
|
+
name: release-dists
|
|
39
|
+
path: dist/
|
|
40
|
+
|
|
41
|
+
pypi-publish:
|
|
42
|
+
runs-on: ubuntu-latest
|
|
43
|
+
needs:
|
|
44
|
+
- release-build
|
|
45
|
+
permissions:
|
|
46
|
+
# IMPORTANT: this permission is mandatory for trusted publishing
|
|
47
|
+
id-token: write
|
|
48
|
+
|
|
49
|
+
# Dedicated environments with protections for publishing are strongly recommended.
|
|
50
|
+
# For more information, see: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment#deployment-protection-rules
|
|
51
|
+
environment:
|
|
52
|
+
name: pypi
|
|
53
|
+
# PyPI project URL shown in the deployment status:
|
|
54
|
+
url: https://pypi.org/p/pydataframer-databricks
|
|
55
|
+
#
|
|
56
|
+
# ALTERNATIVE: if your GitHub Release name is the PyPI project version string
|
|
57
|
+
# ALTERNATIVE: exactly, uncomment the following line instead:
|
|
58
|
+
# url: https://pypi.org/project/YOURPROJECT/${{ github.event.release.name }}
|
|
59
|
+
|
|
60
|
+
steps:
|
|
61
|
+
- name: Retrieve release distributions
|
|
62
|
+
uses: actions/download-artifact@v4
|
|
63
|
+
with:
|
|
64
|
+
name: release-dists
|
|
65
|
+
path: dist/
|
|
66
|
+
|
|
67
|
+
- name: Publish release distributions to PyPI
|
|
68
|
+
uses: pypa/gh-action-pypi-publish@release/v1
|
|
69
|
+
with:
|
|
70
|
+
packages-dir: dist/
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: pydataframer-databricks
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Databricks connector for Dataframer
|
|
5
|
+
Author-email: Dataframer <info@dataframer.ai>
|
|
6
|
+
License: MIT
|
|
7
|
+
Requires-Python: >=3.9
|
|
8
|
+
Requires-Dist: databricks-sdk>=0.81.0
|
|
9
|
+
Requires-Dist: databricks-sql-connector>=4.2.4
|
|
10
|
+
Requires-Dist: pandas>=2.0.0
|
|
11
|
+
Provides-Extra: dev
|
|
12
|
+
Requires-Dist: pytest-cov>=4.1.0; extra == 'dev'
|
|
13
|
+
Requires-Dist: pytest>=7.4.0; extra == 'dev'
|
|
14
|
+
Description-Content-Type: text/markdown
|
|
15
|
+
|
|
16
|
+
# pydataframer-databricks
|
|
17
|
+
|
|
18
|
+
Databricks connector package for Dataframer, providing seamless integration with Databricks SQL and data operations.
|
|
19
|
+
|
|
20
|
+
## Installation
|
|
21
|
+
|
|
22
|
+
```bash
|
|
23
|
+
pip install pydataframer-databricks
|
|
24
|
+
```
|
|
25
|
+
|
|
26
|
+
## Building
|
|
27
|
+
|
|
28
|
+
Requires [uv](https://docs.astral.sh/uv/) installed in your environment.
|
|
29
|
+
|
|
30
|
+
```bash
|
|
31
|
+
uv build
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
## Development
|
|
35
|
+
|
|
36
|
+
```bash
|
|
37
|
+
# Install with dev dependencies
|
|
38
|
+
uv pip install -e ".[dev]"
|
|
39
|
+
|
|
40
|
+
# Run tests
|
|
41
|
+
pytest
|
|
42
|
+
```
|
|
43
|
+
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
# pydataframer-databricks
|
|
2
|
+
|
|
3
|
+
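Databricks connector package for Dataframer: fetch sample rows from Databricks SQL tables and load Dataframer-generated data back into Databricks tables. Connection settings are read from the `dataframer` Databricks secret scope (`DATABRICKS_SERVER_HOSTNAME`, `DATABRICKS_HTTP_PATH`, `DATABRICKS_CLIENT_ID`, `DATABRICKS_CLIENT_SECRET`).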
|
|
4
|
+
|
|
5
|
+
## Installation
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
pip install pydataframer-databricks
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
## Building
|
|
12
|
+
|
|
13
|
+
Requires [uv](https://docs.astral.sh/uv/) to be installed in your environment.
|
|
14
|
+
|
|
15
|
+
```bash
|
|
16
|
+
uv build
|
|
17
|
+
```
|
|
18
|
+
|
|
19
|
+
## Development
|
|
20
|
+
|
|
21
|
+
```bash
|
|
22
|
+
# Install with dev dependencies
|
|
23
|
+
uv pip install -e ".[dev]"
|
|
24
|
+
|
|
25
|
+
# Run tests
|
|
26
|
+
pytest
|
|
27
|
+
```
|
|
28
|
+
|
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class DatasetType(Enum):
    """
    Layout of a generated dataset; values mirror the Dataframer backend.

    Every member's value is the same string as its name. Only
    ``SINGLE_FILE`` datasets are handled by
    ``DatabricksConnector.load_generated_data``.
    """

    SINGLE_FILE = "SINGLE_FILE"  # exactly one data file inside the ZIP
    MULTI_FILE = "MULTI_FILE"
    MULTI_FOLDER = "MULTI_FOLDER"
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class FileType(Enum):
    """
    File format of a generated dataset; values mirror the Dataframer backend.

    Each value is the lowercase file extension, without the leading dot,
    that ``DatabricksConnector.load_generated_data`` matches against the
    entry names inside the downloaded ZIP.
    """

    MD = "md"  # Markdown
    TXT = "txt"  # plain text
    CSV = "csv"
    PDF = "pdf"
    JSON = "json"
    JSONL = "jsonl"  # JSON Lines
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class DatabricksConnector:
    """
    Databricks connector for Dataframer workflows.

    This class provides methods to interact with Databricks SQL, fetch sample data,
    and load generated data into Databricks tables.

    Connection settings are read with ``dbutils.secrets.get`` from the
    ``dataframer`` secret scope: ``DATABRICKS_SERVER_HOSTNAME``,
    ``DATABRICKS_HTTP_PATH``, ``DATABRICKS_CLIENT_ID`` and
    ``DATABRICKS_CLIENT_SECRET``.

    Parameters
    ----------
    dbutils : DBUtils
        The dbutils object from your Databricks notebook context.
        This is automatically available in Databricks notebooks.

    Examples
    --------
    >>> databricks_connector = DatabricksConnector(dbutils)
    >>> df = databricks_connector.fetch_sample_data(
    ...     num_items_to_select=25,
    ...     table_name="samples.bakehouse.media_customer_reviews"
    ... )
    >>> df.head()
    """

    # Secret scope holding the connection settings listed in the class docstring.
    _SECRET_SCOPE = "dataframer"

    def __init__(self, dbutils):
        """
        Initialize the Databricks connector.

        Parameters
        ----------
        dbutils : DBUtils
            The dbutils object from your Databricks notebook context.
        """
        self.dbutils = dbutils

    def get_connection(self):
        """
        Return an authenticated Databricks SQL connection.

        The connection authenticates as a service principal: the client ID and
        secret are read from the secret scope inside the credentials provider
        handed to ``sql.connect``.

        Returns
        -------
        Connection
            A Databricks SQL connection object.
        """
        from databricks import sql
        from databricks.sdk.core import Config, oauth_service_principal

        server_hostname = self.dbutils.secrets.get(self._SECRET_SCOPE, "DATABRICKS_SERVER_HOSTNAME")
        http_path = self.dbutils.secrets.get(self._SECRET_SCOPE, "DATABRICKS_HTTP_PATH")

        def credential_provider():
            # Passed to sql.connect as ``credentials_provider``; builds the
            # service-principal OAuth config from the stored secrets.
            config = Config(
                host=f"https://{server_hostname}",
                client_id=self.dbutils.secrets.get(self._SECRET_SCOPE, "DATABRICKS_CLIENT_ID"),
                client_secret=self.dbutils.secrets.get(self._SECRET_SCOPE, "DATABRICKS_CLIENT_SECRET"),
            )
            return oauth_service_principal(config)

        return sql.connect(
            server_hostname=server_hostname,
            http_path=http_path,
            credentials_provider=credential_provider,
            user_agent_entry="dataframer_user_agent",
        )

    @staticmethod
    def _coerce_limit(num_items_to_select):
        """
        Return ``num_items_to_select`` as a non-negative ``int`` safe to put in ``LIMIT``.

        Integers and integer strings (e.g. ``"25"``) are accepted. Anything
        else (bools, floats, arbitrary text) is rejected, so no foreign text
        can reach the SQL statement.

        Raises
        ------
        ValueError
            If the value is not a non-negative integer.
        """
        try:
            # str() first so that True ("True") and 2.5 ("2.5") are rejected
            # instead of being silently converted to 1 / 2.
            limit = int(str(num_items_to_select))
        except ValueError:
            raise ValueError(
                f"num_items_to_select must be a non-negative integer, got {num_items_to_select!r}"
            ) from None
        if limit < 0:
            raise ValueError(
                f"num_items_to_select must be a non-negative integer, got {num_items_to_select!r}"
            )
        return limit

    @staticmethod
    def _quote_identifier(name):
        """Return ``name`` as a backtick-delimited SQL identifier, escaping embedded backticks."""
        return "`" + str(name).replace("`", "``") + "`"

    @staticmethod
    def _dataframe_to_rows(pandas_df):
        """Return the DataFrame's rows as lists, with missing values (NaN/None) mapped to ``None``."""
        return pandas_df.astype(object).where(pandas_df.notna(), None).values.tolist()

    @staticmethod
    def _extract_single_file(zip_bytes, extension):
        """
        Return ``(name, content)`` of the single ``.<extension>`` entry in a ZIP archive.

        The suffix match is case-insensitive. Entries under ``__MACOSX/``
        (resource-fork files added by macOS archivers) are ignored, so they do
        not count as a second data file.

        Raises
        ------
        ValueError
            If the archive is invalid or does not contain exactly one match.
        RuntimeError
            If reading the archive fails for any other reason.
        """
        import zipfile
        from io import BytesIO

        suffix = f".{extension}"
        try:
            with zipfile.ZipFile(BytesIO(zip_bytes)) as archive:
                file_list = archive.namelist()
                matches = [
                    name for name in file_list
                    if name.lower().endswith(suffix) and not name.startswith("__MACOSX/")
                ]

                if len(matches) != 1:
                    error_msg = f"Expected exactly one {suffix} file in ZIP"
                    print(f"{error_msg}. Available files: {file_list}")
                    raise ValueError(error_msg)

                data_filename = matches[0]
                data_bytes = archive.read(data_filename)
        except zipfile.BadZipFile as e:
            error_msg = "Invalid or corrupted ZIP file"
            print(f"{error_msg}: {str(e)}")
            raise ValueError(f"{error_msg}: {str(e)}") from e
        except ValueError:
            raise
        except Exception as e:
            error_msg = "Failed to extract file from ZIP"
            print(f"{error_msg}: {str(e)}")
            raise RuntimeError(f"{error_msg}: {str(e)}") from e

        print(f"Found {extension} file: {data_filename}")
        return data_filename, data_bytes

    def fetch_sample_data(self, num_items_to_select, table_name):
        """
        Fetch sample data from a Databricks table and return it as a Pandas DataFrame.

        Parameters
        ----------
        num_items_to_select : int
            Number of rows to fetch from the table (a non-negative integer,
            or a string holding one).
        table_name : str
            Fully qualified table name (e.g., "catalog.schema.table").
            NOTE: inserted verbatim into the SQL text; pass trusted names only.

        Returns
        -------
        pd.DataFrame
            A Pandas DataFrame containing the sample data.

        Raises
        ------
        ValueError
            If ``num_items_to_select`` is not a non-negative integer.
        RuntimeError
            If connecting or running the query fails.

        Examples
        --------
        >>> databricks_connector = DatabricksConnector(dbutils)
        >>> df = databricks_connector.fetch_sample_data(
        ...     num_items_to_select=25,
        ...     table_name="samples.bakehouse.media_customer_reviews"
        ... )
        >>> df.head()
        """
        import pandas as pd

        limit = self._coerce_limit(num_items_to_select)

        query = f"""
            SELECT *
            FROM {table_name}
            LIMIT {limit}
        """

        try:
            with self.get_connection() as connection:
                with connection.cursor() as cursor:
                    cursor.execute(query)
                    rows = cursor.fetchall()
                    columns = [desc[0] for desc in cursor.description]
        except Exception as e:
            error_msg = f"Failed to fetch data from table `{table_name}`"
            print(f"{error_msg}: {str(e)}")
            print("Verify table exists, is accessible, and you have SELECT permissions")
            raise RuntimeError(f"{error_msg}: {str(e)}") from e

        return pd.DataFrame(rows, columns=columns)

    def _write_dataframe(self, table_name, pandas_df):
        """
        (Re)create ``table_name`` with one STRING column per DataFrame column and insert all rows.

        Raises
        ------
        RuntimeError
            If the CREATE or the INSERT statement fails. The cursor is closed on every path.
        """
        columns_sql = ", ".join(
            f"{self._quote_identifier(col)} STRING" for col in pandas_df.columns
        )
        placeholders = ", ".join(["?"] * len(pandas_df.columns))
        insert_sql = f"""
            INSERT INTO {table_name}
            VALUES ({placeholders})
        """
        rows = self._dataframe_to_rows(pandas_df)

        with self.get_connection() as connection:
            cursor = connection.cursor()
            try:
                try:
                    cursor.execute(f"""
                        CREATE OR REPLACE TABLE {table_name} (
                            {columns_sql}
                        )
                    """)
                except Exception as e:
                    error_msg = f"Failed to create table `{table_name}`"
                    print(f"{error_msg}: {str(e)}")
                    print("Verify table name format (catalog.schema.table), permissions, and warehouse is running")
                    raise RuntimeError(f"{error_msg}: {str(e)}") from e

                try:
                    cursor.executemany(insert_sql, rows)
                except Exception as e:
                    error_msg = f"Failed to insert data into table `{table_name}`"
                    print(f"{error_msg}: {str(e)} | Rows attempted: {len(pandas_df)}")
                    raise RuntimeError(f"{error_msg}: {str(e)}") from e
            finally:
                cursor.close()

        print(f"✅ Table `{table_name}` saved successfully using Databricks SQL")

    def load_generated_data(self, table_name, downloaded_zip, dataset_type, file_type):
        """
        Load generated samples from a ZIP file into a Databricks table using Databricks SQL.

        The table is created (or replaced) with one STRING column per CSV
        column. Cell text is stored verbatim, and empty cells become NULL.
        Currently only ``DatasetType.SINGLE_FILE`` with ``FileType.CSV`` is
        supported.

        Parameters
        ----------
        table_name : str
            Target table name (e.g., "catalog.schema.table").
            NOTE: inserted verbatim into the SQL text; pass trusted names only.
        downloaded_zip : file-like or bytes
            ZIP file object (anything with ``read()``) or the raw ZIP bytes.
        dataset_type : DatasetType
            Type of dataset structure (a DatasetType member or its string value).
        file_type : FileType
            Type of file in the ZIP (a FileType member or its string value, e.g. "csv").

        Raises
        ------
        ValueError
            Invalid ``dataset_type`` or ``file_type``, a corrupted ZIP, not exactly one
            matching file in the ZIP, or a CSV that pandas cannot parse.
        NotImplementedError
            For MULTI_FILE / MULTI_FOLDER datasets and JSON / JSONL files.
        RuntimeError
            If extraction, table creation or insertion fails.

        Examples
        --------
        >>> databricks_connector = DatabricksConnector(dbutils)
        >>> with open("samples.zip", "rb") as f:
        ...     databricks_connector.load_generated_data(
        ...         table_name="my_catalog.my_schema.my_table",
        ...         downloaded_zip=f,
        ...         dataset_type=DatasetType.SINGLE_FILE,
        ...         file_type=FileType.CSV
        ...     )
        """
        import pandas as pd
        from io import BytesIO

        try:
            dataset_type = DatasetType(dataset_type)
        except ValueError:
            raise ValueError(f"Invalid dataset_type: {dataset_type}. Expected DatasetType enum") from None

        if dataset_type is not DatasetType.SINGLE_FILE:
            # TODO: Implement MULTI_FILE / MULTI_FOLDER handling. Fail loudly
            # rather than returning as if data had been loaded.
            raise NotImplementedError(
                f"dataset_type {dataset_type.value} is not supported yet; only SINGLE_FILE datasets can be loaded"
            )

        try:
            file_type = FileType(file_type)
        except ValueError:
            raise ValueError(
                f"Unsupported file_type: {file_type}. Supported: CSV for SINGLE_FILE datasets"
            ) from None

        if file_type in (FileType.JSON, FileType.JSONL):
            # TODO: Implement JSON / JSONL file handling.
            raise NotImplementedError(
                f"Loading {file_type.value} files is not implemented yet; only CSV is supported for SINGLE_FILE datasets"
            )
        if file_type is not FileType.CSV:
            raise ValueError(f"Unsupported file_type: {file_type}. Supported: CSV for SINGLE_FILE datasets")

        if isinstance(downloaded_zip, (bytes, bytearray)):
            zip_bytes = bytes(downloaded_zip)
        else:
            zip_bytes = downloaded_zip.read()

        _, data_bytes = self._extract_single_file(zip_bytes, file_type.value)

        # Every column is created as STRING, so keep the CSV text verbatim: no
        # numeric/boolean inference (which turns "007" into 7 and gappy integer
        # columns into "1.0"), and only empty cells count as missing (-> NULL).
        pandas_df = pd.read_csv(
            BytesIO(data_bytes),
            dtype=str,
            keep_default_na=False,
            na_values=[""],
        )

        self._write_dataframe(table_name, pandas_df)
|
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "pydataframer-databricks"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Databricks connector for Dataframer"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.9"
|
|
7
|
+
license = { text = "MIT" }
|
|
8
|
+
authors = [
|
|
9
|
+
{ name = "Dataframer", email = "info@dataframer.ai" }
|
|
10
|
+
]
|
|
11
|
+
dependencies = [
|
|
12
|
+
"pandas>=2.0.0",
|
|
13
|
+
"databricks-sdk>=0.81.0",
|
|
14
|
+
"databricks-sql-connector>=4.2.4",
|
|
15
|
+
]
|
|
16
|
+
|
|
17
|
+
[project.optional-dependencies]
|
|
18
|
+
dev = [
|
|
19
|
+
"pytest>=7.4.0",
|
|
20
|
+
"pytest-cov>=4.1.0",
|
|
21
|
+
]
|
|
22
|
+
|
|
23
|
+
[build-system]
|
|
24
|
+
requires = ["hatchling"]
|
|
25
|
+
build-backend = "hatchling.build"
|
|
File without changes
|
|
@@ -0,0 +1,231 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
from unittest.mock import Mock, MagicMock, patch
|
|
3
|
+
import pandas as pd
|
|
4
|
+
from io import BytesIO
|
|
5
|
+
import zipfile
|
|
6
|
+
from pydataframer_databricks import FileType, DatasetType
|
|
7
|
+
from pydataframer_databricks import FileType
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class TestDatabricksConnector:
    """Test suite for DatabricksConnector class"""

    @staticmethod
    def _make_zip(files):
        """Return a rewound BytesIO holding a ZIP with the given {name: content} entries."""
        zip_buffer = BytesIO()
        with zipfile.ZipFile(zip_buffer, 'w') as z:
            for name, content in files.items():
                z.writestr(name, content)
        zip_buffer.seek(0)
        return zip_buffer

    @staticmethod
    def _make_connection(cursor, *, cursor_as_context=False):
        """Return a context-manager connection mock whose cursor() yields ``cursor``.

        fetch_sample_data uses ``with connection.cursor() as cursor``
        (cursor_as_context=True); load_generated_data calls ``connection.cursor()``
        directly.
        """
        connection = MagicMock()
        connection.__enter__.return_value = connection
        if cursor_as_context:
            connection.cursor.return_value.__enter__.return_value = cursor
        else:
            connection.cursor.return_value = cursor
        return connection

    @pytest.fixture
    def mock_dbutils(self):
        """Create a mock dbutils whose secrets.get serves the 'dataframer' scope"""
        dbutils = Mock()
        dbutils.secrets.get = Mock(side_effect=lambda scope, key: {
            ("dataframer", "DATABRICKS_SERVER_HOSTNAME"): "test.databricks.com",
            ("dataframer", "DATABRICKS_HTTP_PATH"): "/sql/1.0/warehouses/abc123",
            ("dataframer", "DATABRICKS_CLIENT_ID"): "test-client-id",
            ("dataframer", "DATABRICKS_CLIENT_SECRET"): "test-client-secret",
        }.get((scope, key), "default-value"))
        return dbutils

    @pytest.fixture
    def connector(self, mock_dbutils):
        """Create a DatabricksConnector instance with mocked dbutils"""
        from pydataframer_databricks import DatabricksConnector
        return DatabricksConnector(mock_dbutils)

    def test_init(self, mock_dbutils):
        """Test connector initialization"""
        from pydataframer_databricks import DatabricksConnector
        connector = DatabricksConnector(mock_dbutils)
        assert connector.dbutils == mock_dbutils

    def test_get_connection(self, connector, mock_dbutils):
        """Test get_connection passes the secrets and a working credentials provider to sql.connect"""
        with patch('databricks.sql.connect') as mock_sql_connect:
            mock_connection = Mock()
            mock_sql_connect.return_value = mock_connection

            with patch('databricks.sdk.core.Config') as mock_config:
                with patch('databricks.sdk.core.oauth_service_principal') as mock_oauth:
                    result = connector.get_connection()

                    mock_dbutils.secrets.get.assert_called()

                    mock_sql_connect.assert_called_once()
                    call_kwargs = mock_sql_connect.call_args.kwargs
                    assert call_kwargs['server_hostname'] == "test.databricks.com"
                    assert call_kwargs['http_path'] == "/sql/1.0/warehouses/abc123"
                    assert call_kwargs['user_agent_entry'] == "dataframer_user_agent"

                    assert result == mock_connection

                    # The SQL connector calls the provider lazily; call it here to
                    # check it builds the service-principal config from the secrets.
                    provider = call_kwargs['credentials_provider']
                    assert provider() is mock_oauth.return_value
                    mock_config.assert_called_once_with(
                        host="https://test.databricks.com",
                        client_id="test-client-id",
                        client_secret="test-client-secret",
                    )
                    mock_oauth.assert_called_once_with(mock_config.return_value)

    @patch('pydataframer_databricks.connectors.DatabricksConnector.get_connection')
    def test_fetch_sample_data_success(self, mock_get_connection, connector):
        """Test fetch_sample_data successfully fetches and returns DataFrame"""
        mock_cursor = Mock()
        mock_cursor.description = [("id",), ("name",), ("value",)]
        mock_cursor.fetchall.return_value = [
            (1, "test1", 100),
            (2, "test2", 200),
        ]
        mock_get_connection.return_value = self._make_connection(mock_cursor, cursor_as_context=True)

        result = connector.fetch_sample_data(
            num_items_to_select=25,
            table_name="test_catalog.test_schema.test_table"
        )

        assert isinstance(result, pd.DataFrame)
        assert len(result) == 2
        assert list(result.columns) == ["id", "name", "value"]
        mock_cursor.execute.assert_called_once()

    @patch('pydataframer_databricks.connectors.DatabricksConnector.get_connection')
    def test_fetch_sample_data_query_failure(self, mock_get_connection, connector):
        """Test fetch_sample_data handles query execution errors"""
        mock_cursor = Mock()
        mock_cursor.execute.side_effect = Exception("Table not found")
        mock_get_connection.return_value = self._make_connection(mock_cursor, cursor_as_context=True)

        with pytest.raises(RuntimeError) as exc_info:
            connector.fetch_sample_data(
                num_items_to_select=10,
                table_name="nonexistent.table"
            )

        assert "Failed to fetch data from table" in str(exc_info.value)

    def test_load_generated_data_single_file_csv_success(self, connector):
        """Test load_generated_data with SINGLE_FILE CSV dataset"""
        zip_buffer = self._make_zip({
            "generated_samples.csv": b"id,name,value\n1,test1,100\n2,test2,200\n",
        })
        mock_cursor = Mock()
        mock_connection = self._make_connection(mock_cursor)

        with patch.object(connector, 'get_connection', return_value=mock_connection):
            connector.load_generated_data(
                table_name="test_catalog.test_schema.test_table",
                downloaded_zip=zip_buffer,
                dataset_type=DatasetType.SINGLE_FILE,
                file_type=FileType.CSV
            )

        assert mock_cursor.execute.call_count == 1
        create_table_call = mock_cursor.execute.call_args[0][0]
        assert "CREATE OR REPLACE TABLE" in create_table_call
        assert "test_catalog.test_schema.test_table" in create_table_call

        assert mock_cursor.executemany.call_count == 1
        insert_sql, rows = mock_cursor.executemany.call_args[0]
        assert "INSERT INTO test_catalog.test_schema.test_table" in insert_sql
        assert insert_sql.count("?") == 3
        assert len(rows) == 2

        assert mock_cursor.close.call_count == 1

    def test_load_generated_data_invalid_zip(self, connector):
        """Test load_generated_data handles invalid ZIP files"""
        invalid_zip = BytesIO(b"not a zip file")

        with pytest.raises(ValueError) as exc_info:
            connector.load_generated_data(
                table_name="test_table",
                downloaded_zip=invalid_zip,
                dataset_type=DatasetType.SINGLE_FILE,
                file_type=FileType.CSV
            )

        assert "Invalid or corrupted ZIP file" in str(exc_info.value)

    def test_load_generated_data_no_csv_file(self, connector):
        """Test load_generated_data when ZIP has no CSV file"""
        zip_buffer = self._make_zip({"data.txt": b"some text"})

        with pytest.raises(ValueError) as exc_info:
            connector.load_generated_data(
                table_name="test_table",
                downloaded_zip=zip_buffer,
                dataset_type=DatasetType.SINGLE_FILE,
                file_type=FileType.CSV
            )

        assert "Expected exactly one .csv file in ZIP" in str(exc_info.value)

    def test_load_generated_data_multiple_csv_files(self, connector):
        """Test load_generated_data when ZIP has multiple CSV files"""
        zip_buffer = self._make_zip({
            "data1.csv": b"id,name\n1,test1",
            "data2.csv": b"id,name\n2,test2",
        })

        with pytest.raises(ValueError) as exc_info:
            connector.load_generated_data(
                table_name="test_table",
                downloaded_zip=zip_buffer,
                dataset_type=DatasetType.SINGLE_FILE,
                file_type=FileType.CSV
            )

        assert "Expected exactly one .csv file in ZIP" in str(exc_info.value)

    def test_load_generated_data_create_table_failure(self, connector):
        """Test load_generated_data handles CREATE TABLE errors"""
        zip_buffer = self._make_zip({"data.csv": b"id,name\n1,test1"})

        mock_cursor = Mock()
        mock_cursor.execute.side_effect = Exception("Permission denied")
        mock_connection = self._make_connection(mock_cursor)

        with patch.object(connector, 'get_connection', return_value=mock_connection):
            with pytest.raises(RuntimeError) as exc_info:
                connector.load_generated_data(
                    table_name="test_table",
                    downloaded_zip=zip_buffer,
                    dataset_type=DatasetType.SINGLE_FILE,
                    file_type=FileType.CSV
                )

        assert "Failed to create table" in str(exc_info.value)
        assert mock_cursor.close.call_count == 1

    def test_load_generated_data_insert_failure(self, connector):
        """Test load_generated_data handles INSERT errors"""
        zip_buffer = self._make_zip({"data.csv": b"id,name\n1,test1"})

        mock_cursor = Mock()
        mock_cursor.execute.return_value = None
        mock_cursor.executemany.side_effect = Exception("Constraint violation")
        mock_connection = self._make_connection(mock_cursor)

        with patch.object(connector, 'get_connection', return_value=mock_connection):
            with pytest.raises(RuntimeError) as exc_info:
                connector.load_generated_data(
                    table_name="test_table",
                    downloaded_zip=zip_buffer,
                    dataset_type=DatasetType.SINGLE_FILE,
                    file_type=FileType.CSV
                )

        assert "Failed to insert data" in str(exc_info.value)
        assert mock_cursor.close.call_count == 1
|