rivetsql-databricks 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,25 @@
1
+ .DS_Store
2
+ .env
3
+ .env.*
4
+ *.pem
5
+ *.key
6
+ secrets/
7
+ target/
8
+ build/
9
+ dist/
10
+ out/
11
+ .next/
12
+ node_modules/
13
+ __pycache__/
14
+ *.pyc
15
+ .pytest_cache/
16
+ .hypothesis/
17
+ .cargo/
18
+ *.class
19
+ bin/
20
+ obj/
21
+ .ralph-logs/
22
+ mcp.json
23
+ docs/reference/
24
+ docs/_build/
25
+ *.zip
@@ -0,0 +1,19 @@
1
+ Metadata-Version: 2.4
2
+ Name: rivetsql-databricks
3
+ Version: 0.1.0
4
+ Summary: Databricks and Unity Catalog plugin for Rivet SQL
5
+ Project-URL: Homepage, https://github.com/rivetsql/rivet
6
+ Project-URL: Repository, https://github.com/rivetsql/rivet
7
+ Author: Rivet Contributors
8
+ License-Expression: MIT
9
+ Keywords: data-pipeline,databricks,rivet,sql
10
+ Classifier: Development Status :: 4 - Beta
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Programming Language :: Python :: 3.11
13
+ Classifier: Programming Language :: Python :: 3.12
14
+ Classifier: Programming Language :: Python :: 3.13
15
+ Classifier: Topic :: Database
16
+ Requires-Python: >=3.11
17
+ Requires-Dist: pyarrow>=14.0
18
+ Requires-Dist: requests>=2.28
19
+ Requires-Dist: rivetsql-core>=0.1.0
@@ -0,0 +1,55 @@
1
+ """rivet_databricks — Databricks plugin for Rivet."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import TYPE_CHECKING
6
+
7
+ if TYPE_CHECKING:
8
+ from rivet_core.plugins import PluginRegistry
9
+
10
+
11
def DatabricksPlugin(registry: PluginRegistry) -> None:
    """Wire every rivet_databricks component into *registry*.

    The catalog plugins, sources, sinks, and the cross-joint adapter are
    unconditional. The compute-engine plugin and the cross-catalog adapters
    depend on optional packages, so their registration is attempted and
    silently skipped when the underlying import fails.
    """
    # Unconditional components — an import failure here is a real error.
    from rivet_databricks.databricks_catalog import DatabricksCatalogPlugin
    from rivet_databricks.databricks_cross_joint import DatabricksCrossJointAdapter
    from rivet_databricks.databricks_sink import DatabricksSink
    from rivet_databricks.databricks_source import DatabricksSource
    from rivet_databricks.unity_catalog import UnityCatalogPlugin
    from rivet_databricks.unity_sink import UnitySink
    from rivet_databricks.unity_source import UnitySource

    unity_catalog = UnityCatalogPlugin()
    registry.register_catalog_plugin(unity_catalog)
    databricks_catalog = DatabricksCatalogPlugin()
    registry.register_catalog_plugin(databricks_catalog)
    unity_source = UnitySource()
    registry.register_source(unity_source)
    databricks_source = DatabricksSource()
    registry.register_source(databricks_source)
    unity_sink = UnitySink()
    registry.register_sink(unity_sink)
    databricks_sink = DatabricksSink()
    registry.register_sink(databricks_sink)
    cross_joint = DatabricksCrossJointAdapter()
    registry.register_cross_joint_adapter(cross_joint)

    # Optional: the engine needs databricks-sql-connector.
    try:
        from rivet_databricks.engine import DatabricksComputeEnginePlugin

        registry.register_engine_plugin(DatabricksComputeEnginePlugin())
    except ImportError:
        pass

    # Optional: each cross-catalog adapter pulls in its own extra dependency.
    try:
        from rivet_databricks.adapters.unity import DatabricksUnityAdapter

        registry.register_adapter(DatabricksUnityAdapter())
    except ImportError:
        pass
    try:
        from rivet_databricks.adapters.duckdb import DatabricksDuckDBAdapter

        registry.register_adapter(DatabricksDuckDBAdapter())
    except ImportError:
        pass
@@ -0,0 +1 @@
1
+ """Databricks engine adapters for external catalog types."""
@@ -0,0 +1,446 @@
1
+ """DatabricksDuckDBAdapter: credential vending, httpfs + storage read/write via DuckDB."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import warnings
7
+ from typing import Any
8
+
9
+ import pyarrow
10
+
11
+ from rivet_core.errors import ExecutionError, plugin_error
12
+ from rivet_core.models import Material
13
+ from rivet_core.optimizer import AdapterPushdownResult, Cast, PushdownPlan, ResidualPlan
14
+ from rivet_core.plugins import ComputeEngineAdapter
15
+ from rivet_core.strategies import MaterializedRef
16
+
17
_logger = logging.getLogger(__name__)

# Pushdown capabilities this adapter advertises for reads.
_READ_CAPABILITIES = [
    "projection_pushdown",
    "predicate_pushdown",
    "limit_pushdown",
]

# Write strategies this adapter advertises.
_WRITE_CAPABILITIES = [
    "write_append",
    "write_replace",
    "write_partition",
]

# Unity Catalog data_source_format -> DuckDB reader table function.
# NOTE(review): AVRO maps to read_parquet — DuckDB has no native Avro reader,
# so this looks like a placeholder; confirm intended behavior for Avro tables.
_FORMAT_TO_READER: dict[str, str] = {
    "PARQUET": "read_parquet",
    "DELTA": "delta_scan",
    "CSV": "read_csv_auto",
    "JSON": "read_json_auto",
    "AVRO": "read_parquet",
}

# Readers that take a bare directory path rather than a file glob.
_DIRECTORY_READERS = frozenset({"delta_scan"})

# Suffixes indicating a storage location already names a single data file.
_FILE_EXTENSIONS = frozenset({".parquet", ".csv", ".json", ".avro", ".orc", ".gz", ".snappy", ".zst"})

# Residual returned when no pushdown plan was supplied (nothing left to apply).
_EMPTY_RESIDUAL = ResidualPlan(predicates=[], limit=None, casts=[])
44
+
45
+
46
+ def _ensure_duckdb_extension(conn: Any, ext: str) -> None:
47
+ """Check-then-load a DuckDB extension."""
48
+ try:
49
+ row = conn.execute(
50
+ "SELECT installed, loaded FROM duckdb_extensions() WHERE extension_name = ?",
51
+ [ext],
52
+ ).fetchone()
53
+ installed = row[0] if row is not None else False
54
+ loaded = row[1] if row is not None else False
55
+ if loaded:
56
+ return
57
+ if not installed:
58
+ conn.execute(f"INSTALL {ext}")
59
+ conn.execute(f"LOAD {ext}")
60
+ except ExecutionError:
61
+ raise
62
+ except Exception as exc:
63
+ raise ExecutionError(
64
+ plugin_error(
65
+ "RVT-502",
66
+ f"Failed to load DuckDB extension '{ext}': {exc}",
67
+ plugin_name="rivet_databricks",
68
+ plugin_type="adapter",
69
+ remediation=(
70
+ f"Run: INSTALL {ext}; LOAD {ext}; "
71
+ "or set the DuckDB extension directory for offline environments."
72
+ ),
73
+ extension=ext,
74
+ )
75
+ ) from exc
76
+
77
+
78
def _apply_duckdb_pushdown(
    base_sql: str,
    pushdown: PushdownPlan | None,
) -> tuple[str, ResidualPlan]:
    """Fold the optimizer's pushdown plan into *base_sql*.

    Returns the rewritten SQL plus a ResidualPlan describing everything the
    engine must still apply itself. Any pushed operation that fails to render
    is demoted back into the residual rather than raised.
    """
    if pushdown is None:
        return base_sql, _EMPTY_RESIDUAL

    sql = base_sql
    leftover_predicates: list[Any] = list(pushdown.predicates.residual)
    leftover_casts: list[Cast] = list(pushdown.casts.residual)
    leftover_limit: int | None = pushdown.limit.residual_limit

    # Projection: narrow the first SELECT list when columns were pushed.
    pushed_columns = pushdown.projections.pushed_columns
    if pushed_columns is not None:
        try:
            sql = sql.replace("SELECT *", "SELECT " + ", ".join(pushed_columns), 1)
        except Exception:
            pass

    # Predicates: render each pushed predicate; failures fall back to residual.
    if pushdown.predicates.pushed:
        rendered: list[str] = []
        for predicate in pushdown.predicates.pushed:
            try:
                rendered.append(predicate.expression)
            except Exception:
                leftover_predicates.append(predicate)
        if rendered:
            sql = f"SELECT * FROM ({sql}) AS __pd WHERE {' AND '.join(rendered)}"

    # Limit: appended textually; on failure the limit becomes residual.
    pushed_limit = pushdown.limit.pushed_limit
    if pushed_limit is not None:
        try:
            sql = f"{sql} LIMIT {pushed_limit}"
        except Exception:
            leftover_limit = pushed_limit

    # Casts: textual substitution per column; failures become residual.
    for pushed_cast in pushdown.casts.pushed:
        try:
            sql = sql.replace(pushed_cast.column, f"CAST({pushed_cast.column} AS {pushed_cast.to_type})")
        except Exception:
            leftover_casts.append(pushed_cast)

    return sql, ResidualPlan(predicates=leftover_predicates, limit=leftover_limit, casts=leftover_casts)
122
+
123
+
124
+ def _resolve_full_name(joint: Any, catalog: Any) -> str:
125
+ """Build the three-part Databricks table name from joint and catalog options."""
126
+ table = getattr(joint, "table", None) or joint.name
127
+ catalog_name = catalog.options.get("catalog", "")
128
+ schema = catalog.options.get("schema", "default")
129
+ dot_count = table.count(".")
130
+ if dot_count >= 2:
131
+ return table # type: ignore[no-any-return]
132
+ if dot_count == 1:
133
+ return f"{catalog_name}.{table}"
134
+ return f"{catalog_name}.{schema}.{table}"
135
+
136
+
137
+ def _resolve_storage_path(storage_location: str, reader_func: str) -> str:
138
+ """Resolve the storage path for DuckDB reader functions."""
139
+ if reader_func in _DIRECTORY_READERS:
140
+ return storage_location.rstrip("/")
141
+
142
+ from pathlib import PurePosixPath
143
+
144
+ suffix = PurePosixPath(storage_location.rstrip("/")).suffix.lower()
145
+ if suffix in _FILE_EXTENSIONS:
146
+ return storage_location
147
+
148
+ ext = {
149
+ "read_parquet": "parquet",
150
+ "read_csv_auto": "csv",
151
+ "read_json_auto": "json",
152
+ }.get(reader_func, "parquet")
153
+
154
+ return f"{storage_location.rstrip('/')}/**/*.{ext}"
155
+
156
+
157
+ def _configure_duckdb_credentials(
158
+ conn: Any,
159
+ storage_location: str,
160
+ credentials: dict[str, Any] | None,
161
+ catalog_options: dict[str, Any] | None = None,
162
+ ) -> None:
163
+ """Configure DuckDB secret manager with vended or ambient credentials."""
164
+ if credentials is None:
165
+ warnings.warn(
166
+ "No vended credentials available; using ambient cloud credentials.",
167
+ stacklevel=3,
168
+ )
169
+ return
170
+
171
+ catalog_options = catalog_options or {}
172
+
173
+ aws_creds = credentials.get("aws_temp_credentials")
174
+ if aws_creds:
175
+ access_key = aws_creds.get("access_key_id", "")
176
+ secret_key = aws_creds.get("secret_access_key", "")
177
+ session_token = aws_creds.get("session_token", "")
178
+ region = (
179
+ aws_creds.get("region")
180
+ or catalog_options.get("region")
181
+ or "us-east-1"
182
+ )
183
+ conn.execute(f"""
184
+ CREATE OR REPLACE SECRET databricks_s3 (
185
+ TYPE S3,
186
+ KEY_ID '{access_key}',
187
+ SECRET '{secret_key}',
188
+ SESSION_TOKEN '{session_token}',
189
+ REGION '{region}'
190
+ )
191
+ """)
192
+ return
193
+
194
+ azure_creds = credentials.get("azure_user_delegation_sas")
195
+ if azure_creds:
196
+ sas_token = azure_creds.get("sas_token", "")
197
+ conn.execute(f"""
198
+ CREATE OR REPLACE SECRET databricks_azure (
199
+ TYPE AZURE,
200
+ CONNECTION_STRING 'BlobEndpoint=https://placeholder.blob.core.windows.net;SharedAccessSignature={sas_token}'
201
+ )
202
+ """)
203
+ return
204
+
205
+ gcs_creds = credentials.get("gcp_oauth_token")
206
+ if gcs_creds:
207
+ oauth_token = gcs_creds.get("oauth_token", "")
208
+ conn.execute(f"""
209
+ CREATE OR REPLACE SECRET databricks_gcs (
210
+ TYPE GCS,
211
+ TOKEN '{oauth_token}'
212
+ )
213
+ """)
214
+ return
215
+
216
+ warnings.warn(
217
+ f"Unrecognized credential format from Databricks vending: {list(credentials.keys())}. "
218
+ "Falling back to ambient cloud credentials.",
219
+ stacklevel=3,
220
+ )
221
+
222
+
223
class _DatabricksDuckDBMaterializedRef(MaterializedRef):
    """Deferred ref that reads from Databricks-managed storage via DuckDB + httpfs."""

    def __init__(
        self,
        storage_location: str,
        reader_func: str,
        credentials: dict[str, Any] | None,
        catalog_options: dict[str, Any] | None = None,
        sql_override: str | None = None,
    ) -> None:
        # Cloud URI of the table's data files.
        self._storage_location = storage_location
        # DuckDB reader table function name (e.g. read_parquet, delta_scan).
        self._reader_func = reader_func
        # Vended temporary credentials, or None for ambient cloud auth.
        self._credentials = credentials
        self._catalog_options = catalog_options or {}
        # Pre-built SQL (with pushdown already folded in) that replaces the
        # default "SELECT * FROM reader('location')" query when present.
        self._sql_override = sql_override

    def to_arrow(self) -> pyarrow.Table:
        """Materialize the referenced data as a pyarrow.Table.

        NOTE(review): opens a fresh in-memory DuckDB connection and re-reads
        remote storage on every call — there is no caching, so consecutive
        uses (e.g. schema then row_count) each pay the full read cost.
        """
        import duckdb

        conn = duckdb.connect(":memory:")
        try:
            _ensure_duckdb_extension(conn, "httpfs")
            if self._reader_func == "delta_scan":
                # Delta tables additionally require DuckDB's delta extension.
                _ensure_duckdb_extension(conn, "delta")
            _configure_duckdb_credentials(conn, self._storage_location, self._credentials, self._catalog_options)
            location = _resolve_storage_path(self._storage_location, self._reader_func)
            sql = self._sql_override or f"SELECT * FROM {self._reader_func}('{location}')"
            return conn.execute(sql).arrow()
        except ExecutionError:
            # Already a structured Rivet error — propagate untouched.
            raise
        except Exception as exc:
            # Wrap any other failure (network, auth, malformed files) as RVT-501.
            raise ExecutionError(
                plugin_error(
                    "RVT-501",
                    f"DuckDB Databricks read failed: {exc}",
                    plugin_name="rivet_databricks",
                    plugin_type="adapter",
                    adapter="DatabricksDuckDBAdapter",
                    remediation="Check storage location accessibility and credential validity.",
                    storage_location=self._storage_location,
                )
            ) from exc
        finally:
            conn.close()

    @property
    def schema(self) -> Any:
        """Rivet Schema derived by fully materializing the table (see to_arrow note)."""
        from rivet_core.models import Column, Schema

        table = self.to_arrow()
        return Schema(
            columns=[
                Column(name=f.name, type=str(f.type), nullable=f.nullable)
                for f in table.schema
            ]
        )

    @property
    def row_count(self) -> int:
        # Full materialization just to count rows — expensive for large tables.
        return self.to_arrow().num_rows  # type: ignore[no-any-return]

    @property
    def size_bytes(self) -> int | None:
        # Unknown without listing remote storage; None signals "unknown" to callers.
        return None

    @property
    def storage_type(self) -> str:
        return "databricks_storage"
292
+
293
+
294
class DatabricksDuckDBAdapter(ComputeEngineAdapter):
    """DuckDB adapter for Databricks catalog: REST API metadata + credential vending + httpfs."""

    # Registry-matching attributes: this adapter serves DuckDB engines against
    # "databricks"-type catalogs, sourced from the rivet_databricks plugin.
    target_engine_type = "duckdb"
    catalog_type = "databricks"
    capabilities = _READ_CAPABILITIES + _WRITE_CAPABILITIES
    source = "catalog_plugin"
    source_plugin = "rivet_databricks"

    def read_dispatch(self, engine: Any, catalog: Any, joint: Any, pushdown: PushdownPlan | None = None) -> AdapterPushdownResult:
        """Read from Databricks-managed storage via DuckDB httpfs with vended credentials.

        Resolves the table's storage location and file format through the
        catalog metadata, folds whatever pushdown can be expressed into SQL,
        and returns a deferred Material — no data is read until the ref's
        to_arrow() is invoked.
        """
        from rivet_databricks.databricks_catalog import DatabricksCatalogPlugin

        # NOTE(review): plugin is passed to _get_table_metadata but unused
        # there — the helper builds its own client; confirm it is still needed.
        plugin = DatabricksCatalogPlugin()
        full_name = _resolve_full_name(joint, catalog)

        table_meta = _get_table_metadata(plugin, full_name, catalog)
        storage_location = table_meta.get("storage_location")
        if not storage_location:
            raise ExecutionError(
                plugin_error(
                    "RVT-503",
                    f"No storage_location for Databricks table '{full_name}'.",
                    plugin_name="rivet_databricks",
                    plugin_type="adapter",
                    adapter="DatabricksDuckDBAdapter",
                    remediation="Verify the table exists and has a storage location.",
                    table=full_name,
                )
            )

        # Unknown formats fall back to read_parquet.
        file_format = (table_meta.get("file_format") or "PARQUET").upper()
        reader_func = _FORMAT_TO_READER.get(file_format, "read_parquet")
        credentials = table_meta.get("temporary_credentials")

        # Build the base SELECT and fold pushdown into it; the residual is
        # returned so the engine can apply what could not be pushed down.
        location = _resolve_storage_path(storage_location, reader_func)
        base_sql = f"SELECT * FROM {reader_func}('{location}')"
        sql, residual = _apply_duckdb_pushdown(base_sql, pushdown)

        ref = _DatabricksDuckDBMaterializedRef(
            storage_location=storage_location,
            reader_func=reader_func,
            credentials=credentials,
            catalog_options=catalog.options,
            sql_override=sql,
        )
        material = Material(
            name=joint.name,
            catalog=catalog.name,
            materialized_ref=ref,
            state="deferred",  # data is fetched lazily via the ref
        )
        return AdapterPushdownResult(material=material, residual=residual)

    def write_dispatch(self, engine: Any, catalog: Any, joint: Any, material: Any) -> Any:
        """Write to Databricks-managed storage via DuckDB httpfs with vended credentials.

        Materializes *material* to Arrow, registers it with DuckDB, and COPYs
        it out as Parquet using the joint's write strategy ("append",
        "partition", or the default "replace").

        NOTE(review): output is always Parquet regardless of the table's
        registered format — copying raw Parquet into a Delta table's location
        would bypass the Delta log; confirm upstream guards against that.
        NOTE(review): credentials come from _get_table_metadata, which vends
        READ-scoped credentials — verify they actually permit writes.
        """
        import duckdb

        from rivet_databricks.databricks_catalog import DatabricksCatalogPlugin

        plugin = DatabricksCatalogPlugin()
        full_name = _resolve_full_name(joint, catalog)

        table_meta = _get_table_metadata(plugin, full_name, catalog)
        storage_location = table_meta.get("storage_location")
        if not storage_location:
            raise ExecutionError(
                plugin_error(
                    "RVT-503",
                    f"No storage_location for Databricks table '{full_name}'.",
                    plugin_name="rivet_databricks",
                    plugin_type="adapter",
                    adapter="DatabricksDuckDBAdapter",
                    remediation="Verify the table exists and has a storage location.",
                    table=full_name,
                )
            )

        credentials = table_meta.get("temporary_credentials")
        arrow_table = material.to_arrow()
        strategy = getattr(joint, "write_strategy", None) or "replace"

        conn = duckdb.connect(":memory:")
        try:
            _ensure_duckdb_extension(conn, "httpfs")
            _configure_duckdb_credentials(conn, storage_location, credentials, catalog.options)
            # Expose the Arrow table to DuckDB under a fixed view name.
            conn.register("__write_data", arrow_table)

            # NOTE(review): storage_location is interpolated unescaped; a
            # single quote in the path would break the COPY statement.
            # NOTE(review): APPEND / PARTITION_BY COPY options require a
            # sufficiently recent DuckDB — confirm the supported versions.
            if strategy == "append":
                sql = f"COPY __write_data TO '{storage_location}' (FORMAT PARQUET, APPEND)"
            elif strategy == "partition":
                partition_by = getattr(joint, "partition_by", None)
                if partition_by:
                    cols = ", ".join(partition_by) if isinstance(partition_by, list) else partition_by
                    sql = f"COPY __write_data TO '{storage_location}' (FORMAT PARQUET, PARTITION_BY ({cols}))"
                else:
                    # "partition" without partition_by degrades to a plain copy.
                    sql = f"COPY __write_data TO '{storage_location}' (FORMAT PARQUET)"
            else:
                # Default / "replace": plain single-destination copy.
                sql = f"COPY __write_data TO '{storage_location}' (FORMAT PARQUET)"

            conn.execute(sql)
        except ExecutionError:
            raise
        except Exception as exc:
            raise ExecutionError(
                plugin_error(
                    "RVT-501",
                    f"DuckDB Databricks write failed: {exc}",
                    plugin_name="rivet_databricks",
                    plugin_type="adapter",
                    adapter="DatabricksDuckDBAdapter",
                    remediation="Check storage location write permissions and credential validity.",
                    storage_location=storage_location,
                    strategy=strategy,
                )
            ) from exc
        finally:
            conn.close()

        return material
414
+
415
+
416
def _get_table_metadata(
    plugin: Any,
    full_name: str,
    catalog: Any,
    operation: str = "READ",
) -> dict[str, Any]:
    """Fetch table metadata and vend temporary credentials for *full_name*.

    Args:
        plugin: Catalog plugin instance. Currently unused — the helper builds
            its own Unity Catalog client; the parameter is retained so
            existing call sites keep working.
        full_name: Fully qualified ``catalog.schema.table`` name.
        catalog: Catalog config whose ``options`` provide ``workspace_url``
            and authentication settings.
        operation: Credential-vending scope (e.g. ``"READ"`` or ``"WRITE"``).
            Defaults to ``"READ"`` for backward compatibility; write paths
            should request ``"WRITE"``-scoped credentials.

    Returns:
        Dict with ``storage_location``, ``file_format`` and
        ``temporary_credentials`` (``None`` when vending is unavailable).

    Raises:
        ExecutionError: propagated from the client for any failure other than
            RVT-508 (credential vending unavailable), which instead warns and
            falls back to ambient credentials.
    """
    from rivet_databricks.auth import resolve_credentials
    from rivet_databricks.client import UnityCatalogClient

    host = catalog.options["workspace_url"]
    credential = resolve_credentials(catalog.options, host=host)
    client = UnityCatalogClient(host=host, credential=credential)
    try:
        raw = client.get_table(full_name)
        table_id = raw.get("table_id") or raw.get("full_name") or full_name
        try:
            temporary_credentials = client.vend_credentials(table_id, operation=operation)
        except ExecutionError as exc:
            if exc.error.code == "RVT-508":
                # Vending endpoint denied/unavailable — degrade gracefully to
                # whatever ambient cloud credentials the environment provides.
                warnings.warn(
                    f"Credential vending unavailable for '{full_name}': {exc.error.message} "
                    "Falling back to ambient cloud credentials.",
                    stacklevel=4,
                )
                temporary_credentials = None
            else:
                raise
    finally:
        # Always release the HTTP client, even when get_table/vending fails.
        client.close()

    return {
        "storage_location": raw.get("storage_location"),
        "file_format": raw.get("data_source_format"),
        "temporary_credentials": temporary_credentials,
    }