rivetsql-databricks 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rivetsql_databricks-0.1.0/.gitignore +25 -0
- rivetsql_databricks-0.1.0/PKG-INFO +19 -0
- rivetsql_databricks-0.1.0/__init__.py +55 -0
- rivetsql_databricks-0.1.0/adapters/__init__.py +1 -0
- rivetsql_databricks-0.1.0/adapters/duckdb.py +446 -0
- rivetsql_databricks-0.1.0/adapters/unity.py +341 -0
- rivetsql_databricks-0.1.0/auth.py +294 -0
- rivetsql_databricks-0.1.0/client.py +194 -0
- rivetsql_databricks-0.1.0/databricks_catalog.py +279 -0
- rivetsql_databricks-0.1.0/databricks_cross_joint.py +60 -0
- rivetsql_databricks-0.1.0/databricks_sink.py +457 -0
- rivetsql_databricks-0.1.0/databricks_source.py +223 -0
- rivetsql_databricks-0.1.0/engine.py +570 -0
- rivetsql_databricks-0.1.0/py.typed +0 -0
- rivetsql_databricks-0.1.0/pyproject.toml +38 -0
- rivetsql_databricks-0.1.0/unity_catalog.py +490 -0
- rivetsql_databricks-0.1.0/unity_sink.py +315 -0
- rivetsql_databricks-0.1.0/unity_source.py +198 -0
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
.DS_Store
|
|
2
|
+
.env
|
|
3
|
+
.env.*
|
|
4
|
+
*.pem
|
|
5
|
+
*.key
|
|
6
|
+
secrets/
|
|
7
|
+
target/
|
|
8
|
+
build/
|
|
9
|
+
dist/
|
|
10
|
+
out/
|
|
11
|
+
.next/
|
|
12
|
+
node_modules/
|
|
13
|
+
__pycache__/
|
|
14
|
+
*.pyc
|
|
15
|
+
.pytest_cache/
|
|
16
|
+
.hypothesis/
|
|
17
|
+
.cargo/
|
|
18
|
+
*.class
|
|
19
|
+
bin/
|
|
20
|
+
obj/
|
|
21
|
+
.ralph-logs/
|
|
22
|
+
mcp.json
|
|
23
|
+
docs/reference/
|
|
24
|
+
docs/_build/
|
|
25
|
+
*.zip
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: rivetsql-databricks
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Databricks and Unity Catalog plugin for Rivet SQL
|
|
5
|
+
Project-URL: Homepage, https://github.com/rivetsql/rivet
|
|
6
|
+
Project-URL: Repository, https://github.com/rivetsql/rivet
|
|
7
|
+
Author: Rivet Contributors
|
|
8
|
+
License-Expression: MIT
|
|
9
|
+
Keywords: data-pipeline,databricks,rivet,sql
|
|
10
|
+
Classifier: Development Status :: 4 - Beta
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
15
|
+
Classifier: Topic :: Database
|
|
16
|
+
Requires-Python: >=3.11
|
|
17
|
+
Requires-Dist: pyarrow>=14.0
|
|
18
|
+
Requires-Dist: requests>=2.28
|
|
19
|
+
Requires-Dist: rivetsql-core>=0.1.0
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
"""rivet_databricks — Databricks plugin for Rivet."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from rivet_core.plugins import PluginRegistry
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def DatabricksPlugin(registry: PluginRegistry) -> None:
    """Register every rivet_databricks component with *registry*.

    Catalogs, sources, sinks, and the cross-joint adapter are mandatory and
    always registered.  The compute engine and the cross-catalog adapters
    rely on optional third-party packages, so each of those is registered
    independently and skipped silently when its dependency is missing.
    """
    # Mandatory components — import failures here should propagate.
    from rivet_databricks.databricks_catalog import DatabricksCatalogPlugin
    from rivet_databricks.databricks_cross_joint import DatabricksCrossJointAdapter
    from rivet_databricks.databricks_sink import DatabricksSink
    from rivet_databricks.databricks_source import DatabricksSource
    from rivet_databricks.unity_catalog import UnityCatalogPlugin
    from rivet_databricks.unity_sink import UnitySink
    from rivet_databricks.unity_source import UnitySource

    for catalog_plugin in (UnityCatalogPlugin(), DatabricksCatalogPlugin()):
        registry.register_catalog_plugin(catalog_plugin)
    for source in (UnitySource(), DatabricksSource()):
        registry.register_source(source)
    for sink in (UnitySink(), DatabricksSink()):
        registry.register_sink(sink)
    registry.register_cross_joint_adapter(DatabricksCrossJointAdapter())

    # Optional: engine needs databricks-sql-connector.
    try:
        from rivet_databricks.engine import DatabricksComputeEnginePlugin
        registry.register_engine_plugin(DatabricksComputeEnginePlugin())
    except ImportError:
        pass

    # Optional: cross-catalog adapters need extra packages.
    try:
        from rivet_databricks.adapters.unity import DatabricksUnityAdapter
        registry.register_adapter(DatabricksUnityAdapter())
    except ImportError:
        pass
    try:
        from rivet_databricks.adapters.duckdb import DatabricksDuckDBAdapter
        registry.register_adapter(DatabricksDuckDBAdapter())
    except ImportError:
        pass
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Databricks engine adapters for external catalog types."""
|
|
@@ -0,0 +1,446 @@
|
|
|
1
|
+
"""DatabricksDuckDBAdapter: credential vending, httpfs + storage read/write via DuckDB."""

from __future__ import annotations

import logging
import warnings
from typing import Any

import pyarrow

from rivet_core.errors import ExecutionError, plugin_error
from rivet_core.models import Material
from rivet_core.optimizer import AdapterPushdownResult, Cast, PushdownPlan, ResidualPlan
from rivet_core.plugins import ComputeEngineAdapter
from rivet_core.strategies import MaterializedRef

_logger = logging.getLogger(__name__)

# Pushdown operations this adapter can apply on the read path.
_READ_CAPABILITIES = [
    "projection_pushdown",
    "predicate_pushdown",
    "limit_pushdown",
]

# Write strategies this adapter supports on the write path.
_WRITE_CAPABILITIES = [
    "write_append",
    "write_replace",
    "write_partition",
]

# Maps a Databricks data-source format to the DuckDB table function used to
# read it.  NOTE(review): AVRO is mapped to read_parquet — presumably a
# placeholder until a real Avro reader is wired up; confirm before relying
# on Avro-backed tables.
_FORMAT_TO_READER: dict[str, str] = {
    "PARQUET": "read_parquet",
    "DELTA": "delta_scan",
    "CSV": "read_csv_auto",
    "JSON": "read_json_auto",
    "AVRO": "read_parquet",
}

# Readers that take a table/directory root rather than a file glob.
_DIRECTORY_READERS = frozenset({"delta_scan"})

# Suffixes indicating the storage location already names a single data file.
_FILE_EXTENSIONS = frozenset({".parquet", ".csv", ".json", ".avro", ".orc", ".gz", ".snappy", ".zst"})

# Residual plan returned when there is no pushdown to apply.
_EMPTY_RESIDUAL = ResidualPlan(predicates=[], limit=None, casts=[])
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _ensure_duckdb_extension(conn: Any, ext: str) -> None:
|
|
47
|
+
"""Check-then-load a DuckDB extension."""
|
|
48
|
+
try:
|
|
49
|
+
row = conn.execute(
|
|
50
|
+
"SELECT installed, loaded FROM duckdb_extensions() WHERE extension_name = ?",
|
|
51
|
+
[ext],
|
|
52
|
+
).fetchone()
|
|
53
|
+
installed = row[0] if row is not None else False
|
|
54
|
+
loaded = row[1] if row is not None else False
|
|
55
|
+
if loaded:
|
|
56
|
+
return
|
|
57
|
+
if not installed:
|
|
58
|
+
conn.execute(f"INSTALL {ext}")
|
|
59
|
+
conn.execute(f"LOAD {ext}")
|
|
60
|
+
except ExecutionError:
|
|
61
|
+
raise
|
|
62
|
+
except Exception as exc:
|
|
63
|
+
raise ExecutionError(
|
|
64
|
+
plugin_error(
|
|
65
|
+
"RVT-502",
|
|
66
|
+
f"Failed to load DuckDB extension '{ext}': {exc}",
|
|
67
|
+
plugin_name="rivet_databricks",
|
|
68
|
+
plugin_type="adapter",
|
|
69
|
+
remediation=(
|
|
70
|
+
f"Run: INSTALL {ext}; LOAD {ext}; "
|
|
71
|
+
"or set the DuckDB extension directory for offline environments."
|
|
72
|
+
),
|
|
73
|
+
extension=ext,
|
|
74
|
+
)
|
|
75
|
+
) from exc
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _apply_duckdb_pushdown(
    base_sql: str,
    pushdown: PushdownPlan | None,
) -> tuple[str, ResidualPlan]:
    """Apply pushdown operations to a DuckDB SQL query.

    Rewrites *base_sql* with whatever projections, predicates, limit, and
    casts the optimizer marked as pushable, and returns the rewritten SQL
    together with a ResidualPlan of the operations the engine must still
    apply on the result.
    """
    if pushdown is None:
        return base_sql, _EMPTY_RESIDUAL

    sql = base_sql
    residual_predicates: list[Any] = list(pushdown.predicates.residual)
    residual_casts: list[Cast] = list(pushdown.casts.residual)
    residual_limit: int | None = pushdown.limit.residual_limit

    # Projection: narrow the SELECT list.  Best-effort — if the column list
    # cannot be joined (e.g. non-string entries) the wildcard is kept and
    # projection happens post-read.
    if pushdown.projections.pushed_columns is not None:
        try:
            cols = ", ".join(pushdown.projections.pushed_columns)
            sql = sql.replace("SELECT *", f"SELECT {cols}", 1)
        except Exception:
            pass

    # Predicates: wrap in a subquery so expressions see the projected columns.
    # A predicate whose expression cannot be read is demoted to the residual.
    if pushdown.predicates.pushed:
        where_parts: list[str] = []
        for pred in pushdown.predicates.pushed:
            try:
                where_parts.append(pred.expression)
            except Exception:
                residual_predicates.append(pred)
        if where_parts:
            where_clause = " AND ".join(where_parts)
            sql = f"SELECT * FROM ({sql}) AS __pd WHERE {where_clause}"

    # Limit: formatting an integer literal cannot fail, so no fallback is
    # needed (the previous try/except fallback branch was unreachable).
    if pushdown.limit.pushed_limit is not None:
        sql = f"{sql} LIMIT {pushdown.limit.pushed_limit}"

    # Casts: textual substitution; a cast that cannot be applied is demoted.
    # NOTE(review): str.replace can also hit the column name inside other
    # identifiers — presumably acceptable upstream; confirm quoting rules.
    for cast in pushdown.casts.pushed:
        try:
            sql = sql.replace(cast.column, f"CAST({cast.column} AS {cast.to_type})")
        except Exception:
            residual_casts.append(cast)

    return sql, ResidualPlan(predicates=residual_predicates, limit=residual_limit, casts=residual_casts)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _resolve_full_name(joint: Any, catalog: Any) -> str:
|
|
125
|
+
"""Build the three-part Databricks table name from joint and catalog options."""
|
|
126
|
+
table = getattr(joint, "table", None) or joint.name
|
|
127
|
+
catalog_name = catalog.options.get("catalog", "")
|
|
128
|
+
schema = catalog.options.get("schema", "default")
|
|
129
|
+
dot_count = table.count(".")
|
|
130
|
+
if dot_count >= 2:
|
|
131
|
+
return table # type: ignore[no-any-return]
|
|
132
|
+
if dot_count == 1:
|
|
133
|
+
return f"{catalog_name}.{table}"
|
|
134
|
+
return f"{catalog_name}.{schema}.{table}"
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def _resolve_storage_path(storage_location: str, reader_func: str) -> str:
    """Resolve the storage path for DuckDB reader functions."""
    from pathlib import PurePosixPath

    trimmed = storage_location.rstrip("/")

    # Directory-oriented readers (e.g. delta_scan) take the table root.
    if reader_func in _DIRECTORY_READERS:
        return trimmed

    # A location that already names a single data file is used verbatim.
    if PurePosixPath(trimmed).suffix.lower() in _FILE_EXTENSIONS:
        return storage_location

    # Otherwise glob recursively for files matching the reader's format.
    format_extension = {
        "read_parquet": "parquet",
        "read_csv_auto": "csv",
        "read_json_auto": "json",
    }.get(reader_func, "parquet")
    return f"{trimmed}/**/*.{format_extension}"
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
def _configure_duckdb_credentials(
|
|
158
|
+
conn: Any,
|
|
159
|
+
storage_location: str,
|
|
160
|
+
credentials: dict[str, Any] | None,
|
|
161
|
+
catalog_options: dict[str, Any] | None = None,
|
|
162
|
+
) -> None:
|
|
163
|
+
"""Configure DuckDB secret manager with vended or ambient credentials."""
|
|
164
|
+
if credentials is None:
|
|
165
|
+
warnings.warn(
|
|
166
|
+
"No vended credentials available; using ambient cloud credentials.",
|
|
167
|
+
stacklevel=3,
|
|
168
|
+
)
|
|
169
|
+
return
|
|
170
|
+
|
|
171
|
+
catalog_options = catalog_options or {}
|
|
172
|
+
|
|
173
|
+
aws_creds = credentials.get("aws_temp_credentials")
|
|
174
|
+
if aws_creds:
|
|
175
|
+
access_key = aws_creds.get("access_key_id", "")
|
|
176
|
+
secret_key = aws_creds.get("secret_access_key", "")
|
|
177
|
+
session_token = aws_creds.get("session_token", "")
|
|
178
|
+
region = (
|
|
179
|
+
aws_creds.get("region")
|
|
180
|
+
or catalog_options.get("region")
|
|
181
|
+
or "us-east-1"
|
|
182
|
+
)
|
|
183
|
+
conn.execute(f"""
|
|
184
|
+
CREATE OR REPLACE SECRET databricks_s3 (
|
|
185
|
+
TYPE S3,
|
|
186
|
+
KEY_ID '{access_key}',
|
|
187
|
+
SECRET '{secret_key}',
|
|
188
|
+
SESSION_TOKEN '{session_token}',
|
|
189
|
+
REGION '{region}'
|
|
190
|
+
)
|
|
191
|
+
""")
|
|
192
|
+
return
|
|
193
|
+
|
|
194
|
+
azure_creds = credentials.get("azure_user_delegation_sas")
|
|
195
|
+
if azure_creds:
|
|
196
|
+
sas_token = azure_creds.get("sas_token", "")
|
|
197
|
+
conn.execute(f"""
|
|
198
|
+
CREATE OR REPLACE SECRET databricks_azure (
|
|
199
|
+
TYPE AZURE,
|
|
200
|
+
CONNECTION_STRING 'BlobEndpoint=https://placeholder.blob.core.windows.net;SharedAccessSignature={sas_token}'
|
|
201
|
+
)
|
|
202
|
+
""")
|
|
203
|
+
return
|
|
204
|
+
|
|
205
|
+
gcs_creds = credentials.get("gcp_oauth_token")
|
|
206
|
+
if gcs_creds:
|
|
207
|
+
oauth_token = gcs_creds.get("oauth_token", "")
|
|
208
|
+
conn.execute(f"""
|
|
209
|
+
CREATE OR REPLACE SECRET databricks_gcs (
|
|
210
|
+
TYPE GCS,
|
|
211
|
+
TOKEN '{oauth_token}'
|
|
212
|
+
)
|
|
213
|
+
""")
|
|
214
|
+
return
|
|
215
|
+
|
|
216
|
+
warnings.warn(
|
|
217
|
+
f"Unrecognized credential format from Databricks vending: {list(credentials.keys())}. "
|
|
218
|
+
"Falling back to ambient cloud credentials.",
|
|
219
|
+
stacklevel=3,
|
|
220
|
+
)
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
class _DatabricksDuckDBMaterializedRef(MaterializedRef):
    """Deferred ref that reads from Databricks-managed storage via DuckDB + httpfs.

    The first call to :meth:`to_arrow` executes the read and caches the
    resulting table, so ``schema``/``row_count`` (and repeated ``to_arrow``
    calls) do not each re-scan remote storage.
    """

    def __init__(
        self,
        storage_location: str,
        reader_func: str,
        credentials: dict[str, Any] | None,
        catalog_options: dict[str, Any] | None = None,
        sql_override: str | None = None,
    ) -> None:
        self._storage_location = storage_location
        self._reader_func = reader_func
        self._credentials = credentials
        self._catalog_options = catalog_options or {}
        self._sql_override = sql_override
        # Populated lazily by the first to_arrow() call.
        self._arrow_cache: pyarrow.Table | None = None

    def to_arrow(self) -> pyarrow.Table:
        """Execute the deferred read (once) and return the Arrow table."""
        if self._arrow_cache is not None:
            return self._arrow_cache

        import duckdb

        conn = duckdb.connect(":memory:")
        try:
            _ensure_duckdb_extension(conn, "httpfs")
            if self._reader_func == "delta_scan":
                _ensure_duckdb_extension(conn, "delta")
            _configure_duckdb_credentials(conn, self._storage_location, self._credentials, self._catalog_options)
            location = _resolve_storage_path(self._storage_location, self._reader_func)
            sql = self._sql_override or f"SELECT * FROM {self._reader_func}('{location}')"
            self._arrow_cache = conn.execute(sql).arrow()
            return self._arrow_cache
        except ExecutionError:
            raise
        except Exception as exc:
            raise ExecutionError(
                plugin_error(
                    "RVT-501",
                    f"DuckDB Databricks read failed: {exc}",
                    plugin_name="rivet_databricks",
                    plugin_type="adapter",
                    adapter="DatabricksDuckDBAdapter",
                    remediation="Check storage location accessibility and credential validity.",
                    storage_location=self._storage_location,
                )
            ) from exc
        finally:
            conn.close()

    @property
    def schema(self) -> Any:
        """Rivet Schema derived from the Arrow table's schema."""
        from rivet_core.models import Column, Schema

        table = self.to_arrow()
        return Schema(
            columns=[
                Column(name=f.name, type=str(f.type), nullable=f.nullable)
                for f in table.schema
            ]
        )

    @property
    def row_count(self) -> int:
        """Number of rows in the materialized table."""
        return self.to_arrow().num_rows  # type: ignore[no-any-return]

    @property
    def size_bytes(self) -> int | None:
        """Unknown for remote storage reads."""
        return None

    @property
    def storage_type(self) -> str:
        return "databricks_storage"
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
class DatabricksDuckDBAdapter(ComputeEngineAdapter):
    """DuckDB adapter for Databricks catalog: REST API metadata + credential vending + httpfs."""

    target_engine_type = "duckdb"
    catalog_type = "databricks"
    capabilities = _READ_CAPABILITIES + _WRITE_CAPABILITIES
    source = "catalog_plugin"
    source_plugin = "rivet_databricks"

    @staticmethod
    def _require_storage_location(table_meta: dict[str, Any], full_name: str) -> str:
        """Return the table's storage location or raise RVT-503 if absent.

        Shared by the read and write paths so both raise the identical error.
        """
        storage_location = table_meta.get("storage_location")
        if not storage_location:
            raise ExecutionError(
                plugin_error(
                    "RVT-503",
                    f"No storage_location for Databricks table '{full_name}'.",
                    plugin_name="rivet_databricks",
                    plugin_type="adapter",
                    adapter="DatabricksDuckDBAdapter",
                    remediation="Verify the table exists and has a storage location.",
                    table=full_name,
                )
            )
        return storage_location

    @staticmethod
    def _build_copy_sql(storage_location: str, strategy: str, joint: Any) -> str:
        """Compose the DuckDB COPY statement for the chosen write strategy."""
        if strategy == "append":
            return f"COPY __write_data TO '{storage_location}' (FORMAT PARQUET, APPEND)"
        if strategy == "partition":
            partition_by = getattr(joint, "partition_by", None)
            if partition_by:
                cols = ", ".join(partition_by) if isinstance(partition_by, list) else partition_by
                return f"COPY __write_data TO '{storage_location}' (FORMAT PARQUET, PARTITION_BY ({cols}))"
            # No partition columns configured — fall back to a plain write.
            return f"COPY __write_data TO '{storage_location}' (FORMAT PARQUET)"
        # "replace" and any unknown strategy write a full overwrite.
        return f"COPY __write_data TO '{storage_location}' (FORMAT PARQUET)"

    def read_dispatch(self, engine: Any, catalog: Any, joint: Any, pushdown: PushdownPlan | None = None) -> AdapterPushdownResult:
        """Read from Databricks-managed storage via DuckDB httpfs with vended credentials."""
        from rivet_databricks.databricks_catalog import DatabricksCatalogPlugin

        plugin = DatabricksCatalogPlugin()
        full_name = _resolve_full_name(joint, catalog)

        table_meta = _get_table_metadata(plugin, full_name, catalog)
        storage_location = self._require_storage_location(table_meta, full_name)

        file_format = (table_meta.get("file_format") or "PARQUET").upper()
        reader_func = _FORMAT_TO_READER.get(file_format, "read_parquet")
        credentials = table_meta.get("temporary_credentials")

        # Build the base scan, then fold in whatever the optimizer can push down.
        location = _resolve_storage_path(storage_location, reader_func)
        base_sql = f"SELECT * FROM {reader_func}('{location}')"
        sql, residual = _apply_duckdb_pushdown(base_sql, pushdown)

        ref = _DatabricksDuckDBMaterializedRef(
            storage_location=storage_location,
            reader_func=reader_func,
            credentials=credentials,
            catalog_options=catalog.options,
            sql_override=sql,
        )
        material = Material(
            name=joint.name,
            catalog=catalog.name,
            materialized_ref=ref,
            state="deferred",
        )
        return AdapterPushdownResult(material=material, residual=residual)

    def write_dispatch(self, engine: Any, catalog: Any, joint: Any, material: Any) -> Any:
        """Write to Databricks-managed storage via DuckDB httpfs with vended credentials."""
        import duckdb

        from rivet_databricks.databricks_catalog import DatabricksCatalogPlugin

        plugin = DatabricksCatalogPlugin()
        full_name = _resolve_full_name(joint, catalog)

        table_meta = _get_table_metadata(plugin, full_name, catalog)
        storage_location = self._require_storage_location(table_meta, full_name)

        credentials = table_meta.get("temporary_credentials")
        arrow_table = material.to_arrow()
        strategy = getattr(joint, "write_strategy", None) or "replace"

        conn = duckdb.connect(":memory:")
        try:
            _ensure_duckdb_extension(conn, "httpfs")
            _configure_duckdb_credentials(conn, storage_location, credentials, catalog.options)
            conn.register("__write_data", arrow_table)
            conn.execute(self._build_copy_sql(storage_location, strategy, joint))
        except ExecutionError:
            raise
        except Exception as exc:
            raise ExecutionError(
                plugin_error(
                    "RVT-501",
                    f"DuckDB Databricks write failed: {exc}",
                    plugin_name="rivet_databricks",
                    plugin_type="adapter",
                    adapter="DatabricksDuckDBAdapter",
                    remediation="Check storage location write permissions and credential validity.",
                    storage_location=storage_location,
                    strategy=strategy,
                )
            ) from exc
        finally:
            conn.close()

        return material
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
def _get_table_metadata(plugin: Any, full_name: str, catalog: Any) -> dict[str, Any]:
    """Get table metadata via the DatabricksCatalogPlugin, with graceful fallback on HTTP 403."""
    from rivet_databricks.auth import resolve_credentials
    from rivet_databricks.client import UnityCatalogClient

    host = catalog.options["workspace_url"]
    client = UnityCatalogClient(
        host=host,
        credential=resolve_credentials(catalog.options, host=host),
    )
    try:
        raw = client.get_table(full_name)
        table_id = raw.get("table_id") or raw.get("full_name") or full_name
        try:
            vended = client.vend_credentials(table_id, operation="READ")
        except ExecutionError as exc:
            # RVT-508 signals credential vending was denied; degrade to
            # ambient cloud credentials instead of failing the whole read.
            if exc.error.code != "RVT-508":
                raise
            warnings.warn(
                f"Credential vending unavailable for '{full_name}': {exc.error.message} "
                "Falling back to ambient cloud credentials.",
                stacklevel=4,
            )
            vended = None
    finally:
        client.close()

    return {
        "storage_location": raw.get("storage_location"),
        "file_format": raw.get("data_source_format"),
        "temporary_credentials": vended,
    }
|