rivetsql-postgres 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rivetsql_postgres-0.1.0/.gitignore +25 -0
- rivetsql_postgres-0.1.0/PKG-INFO +18 -0
- rivetsql_postgres-0.1.0/__init__.py +42 -0
- rivetsql_postgres-0.1.0/adapters/__init__.py +1 -0
- rivetsql_postgres-0.1.0/adapters/duckdb.py +279 -0
- rivetsql_postgres-0.1.0/adapters/pyspark.py +305 -0
- rivetsql_postgres-0.1.0/catalog.py +343 -0
- rivetsql_postgres-0.1.0/cross_joint.py +48 -0
- rivetsql_postgres-0.1.0/engine.py +325 -0
- rivetsql_postgres-0.1.0/errors.py +84 -0
- rivetsql_postgres-0.1.0/py.typed +0 -0
- rivetsql_postgres-0.1.0/pyproject.toml +37 -0
- rivetsql_postgres-0.1.0/sink.py +525 -0
- rivetsql_postgres-0.1.0/source.py +137 -0
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
.DS_Store
|
|
2
|
+
.env
|
|
3
|
+
.env.*
|
|
4
|
+
*.pem
|
|
5
|
+
*.key
|
|
6
|
+
secrets/
|
|
7
|
+
target/
|
|
8
|
+
build/
|
|
9
|
+
dist/
|
|
10
|
+
out/
|
|
11
|
+
.next/
|
|
12
|
+
node_modules/
|
|
13
|
+
__pycache__/
|
|
14
|
+
*.pyc
|
|
15
|
+
.pytest_cache/
|
|
16
|
+
.hypothesis/
|
|
17
|
+
.cargo/
|
|
18
|
+
*.class
|
|
19
|
+
bin/
|
|
20
|
+
obj/
|
|
21
|
+
.ralph-logs/
|
|
22
|
+
mcp.json
|
|
23
|
+
docs/reference/
|
|
24
|
+
docs/_build/
|
|
25
|
+
*.zip
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: rivetsql-postgres
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: PostgreSQL plugin for Rivet SQL
|
|
5
|
+
Project-URL: Homepage, https://github.com/rivetsql/rivet
|
|
6
|
+
Project-URL: Repository, https://github.com/rivetsql/rivet
|
|
7
|
+
Author: Rivet Contributors
|
|
8
|
+
License-Expression: MIT
|
|
9
|
+
Keywords: data-pipeline,postgres,rivet,sql
|
|
10
|
+
Classifier: Development Status :: 4 - Beta
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
15
|
+
Classifier: Topic :: Database
|
|
16
|
+
Requires-Python: >=3.11
|
|
17
|
+
Requires-Dist: psycopg[binary,pool]>=3.1
|
|
18
|
+
Requires-Dist: rivetsql-core>=0.1.0
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
"""rivet_postgres — PostgreSQL plugin for Rivet."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import TYPE_CHECKING
|
|
6
|
+
|
|
7
|
+
if TYPE_CHECKING:
|
|
8
|
+
from rivet_core.plugins import PluginRegistry
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def PostgresPlugin(registry: PluginRegistry) -> None:
    """Register all rivet_postgres components into the plugin registry.

    The core pieces (catalog, engine, source, sink, cross-joint) are always
    registered. Cross-catalog adapters depend on optional third-party
    packages, so each one is registered best-effort: a failed import simply
    means that adapter is unavailable in the current environment.
    """
    from rivet_postgres.catalog import PostgresCatalogPlugin
    from rivet_postgres.cross_joint import PostgresCrossJointAdapter
    from rivet_postgres.engine import PostgresComputeEnginePlugin
    from rivet_postgres.sink import PostgresSink
    from rivet_postgres.source import PostgresSource

    registry.register_catalog_plugin(PostgresCatalogPlugin())
    registry.register_engine_plugin(PostgresComputeEnginePlugin())
    registry.register_source(PostgresSource())
    registry.register_sink(PostgresSink())
    registry.register_cross_joint_adapter(PostgresCrossJointAdapter())

    def _register_duckdb() -> None:
        from rivet_postgres.adapters.duckdb import PostgresDuckDBAdapter

        registry.register_adapter(PostgresDuckDBAdapter())

    def _register_pyspark() -> None:
        from rivet_postgres.adapters.pyspark import PostgresPySparkAdapter

        registry.register_adapter(PostgresPySparkAdapter())

    # Optional adapters: an ImportError anywhere in registration means the
    # extra dependency is missing, so skip that adapter silently.
    for _register in (_register_duckdb, _register_pyspark):
        try:
            _register()
        except ImportError:
            pass
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""PostgreSQL engine adapters for external compute engines."""
|
|
@@ -0,0 +1,279 @@
|
|
|
1
|
+
"""PostgresDuckDBAdapter: catalog-plugin-contributed adapter for DuckDB ← postgres.
|
|
2
|
+
|
|
3
|
+
Registers with target_engine="duckdb", catalog_type="postgres", source_plugin="rivet_postgres".
|
|
4
|
+
This adapter takes priority over any engine-plugin adapter for the same pair per Core adapter
|
|
5
|
+
precedence (catalog_plugin > engine_plugin).
|
|
6
|
+
|
|
7
|
+
Uses DuckDB's postgres community extension with ATTACH to establish direct
|
|
8
|
+
PostgreSQL access from DuckDB queries.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
from __future__ import annotations
|
|
12
|
+
|
|
13
|
+
from typing import TYPE_CHECKING, Any
|
|
14
|
+
|
|
15
|
+
import pyarrow
|
|
16
|
+
|
|
17
|
+
from rivet_core.errors import ExecutionError, RivetError, plugin_error
|
|
18
|
+
from rivet_core.models import Column, Material, Schema
|
|
19
|
+
from rivet_core.optimizer import AdapterPushdownResult, Cast, PushdownPlan, ResidualPlan
|
|
20
|
+
from rivet_core.plugins import ComputeEngineAdapter
|
|
21
|
+
from rivet_core.strategies import MaterializedRef
|
|
22
|
+
|
|
23
|
+
if TYPE_CHECKING:
|
|
24
|
+
import duckdb
|
|
25
|
+
|
|
26
|
+
from rivet_core.sql_parser import Predicate
|
|
27
|
+
|
|
28
|
+
_EMPTY_RESIDUAL = ResidualPlan(predicates=[], limit=None, casts=[])
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _apply_pushdown(
    base_sql: str,
    pushdown: PushdownPlan | None,
) -> tuple[str, ResidualPlan]:
    """Rewrite *base_sql* (DuckDB dialect) to fold in pushdown operations.

    Returns the rewritten SQL together with a ResidualPlan holding every
    operation that could not be folded in and must run after the read.
    """
    if pushdown is None:
        return base_sql, _EMPTY_RESIDUAL

    query = base_sql
    leftover_preds: list[Predicate] = list(pushdown.predicates.residual)
    leftover_casts: list[Cast] = list(pushdown.casts.residual)
    leftover_limit: int | None = pushdown.limit.residual_limit

    # Projection: narrow the column list by rewriting the first "SELECT *".
    pushed_cols = pushdown.projections.pushed_columns
    if pushed_cols is not None:
        try:
            cols = ", ".join(pushed_cols)
            query = query.replace("SELECT *", f"SELECT {cols}", 1)
        except Exception:
            pass

    # Predicates: collect each pushable expression; anything that fails to
    # render becomes residual. Applied via an outer WHERE wrapper.
    if pushdown.predicates.pushed:
        conditions: list[str] = []
        for pred in pushdown.predicates.pushed:
            try:
                conditions.append(pred.expression)
            except Exception:
                leftover_preds.append(pred)
        if conditions:
            where_clause = " AND ".join(conditions)
            query = f"SELECT * FROM ({query}) AS __pd WHERE {where_clause}"

    # Limit: append directly; on failure fall back to a residual limit.
    if pushdown.limit.pushed_limit is not None:
        try:
            query = f"{query} LIMIT {pushdown.limit.pushed_limit}"
        except Exception:
            leftover_limit = pushdown.limit.pushed_limit

    # Casts: textual substitution of each column with a CAST expression.
    for cast in pushdown.casts.pushed:
        try:
            query = query.replace(cast.column, f"CAST({cast.column} AS {cast.to_type})")
        except Exception:
            leftover_casts.append(cast)

    return query, ResidualPlan(predicates=leftover_preds, limit=leftover_limit, casts=leftover_casts)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _ensure_duckdb_extension(conn: duckdb.DuckDBPyConnection, ext: str, install_from: str | None = None) -> None:
    """Make sure a DuckDB extension is installed and loaded on *conn*.

    Checks duckdb_extensions() first so INSTALL/LOAD only run when needed.
    Any failure is wrapped as an RVT-502 ExecutionError with remediation.
    """
    try:
        status = conn.execute(
            "SELECT installed, loaded FROM duckdb_extensions() WHERE extension_name = ?",
            [ext],
        ).fetchone()
        installed, loaded = (status[0], status[1]) if status is not None else (False, False)
        if loaded:
            # Already active on this connection — nothing to do.
            return
        if not installed:
            if install_from:
                conn.execute(f"INSTALL {ext} FROM {install_from}")
            else:
                conn.execute(f"INSTALL {ext}")
        conn.execute(f"LOAD {ext}")
    except ExecutionError:
        raise
    except Exception as exc:
        raise ExecutionError(
            plugin_error(
                "RVT-502",
                f"Failed to load DuckDB extension '{ext}': {exc}",
                plugin_name="rivet_postgres",
                plugin_type="adapter",
                remediation=(
                    f"Run: INSTALL {ext}; LOAD {ext}; "
                    "or set the DuckDB extension directory for offline environments."
                ),
                extension=ext,
            )
        ) from exc
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def _build_dsn(catalog_options: dict[str, Any]) -> str:
|
|
111
|
+
"""Build a PostgreSQL connection string from catalog options."""
|
|
112
|
+
host = catalog_options["host"]
|
|
113
|
+
port = catalog_options.get("port", 5432)
|
|
114
|
+
database = catalog_options["database"]
|
|
115
|
+
user = catalog_options.get("user", "")
|
|
116
|
+
password = catalog_options.get("password", "")
|
|
117
|
+
parts = [f"host={host}", f"port={port}", f"dbname={database}"]
|
|
118
|
+
if user:
|
|
119
|
+
parts.append(f"user={user}")
|
|
120
|
+
if password:
|
|
121
|
+
parts.append(f"password={password}")
|
|
122
|
+
return " ".join(parts)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
def _attach_alias(catalog_name: str) -> str:
|
|
126
|
+
"""Derive a DuckDB ATTACH alias from the catalog name."""
|
|
127
|
+
return f"pg_{catalog_name}"
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class _PostgresDuckDBMaterializedRef(MaterializedRef):
    """Deferred ref that reads PostgreSQL data via DuckDB postgres extension on to_arrow().

    The read runs lazily on first access and the resulting Arrow table is
    cached, so repeated schema/row_count/to_arrow calls hit PostgreSQL once.
    """

    def __init__(self, catalog_options: dict[str, Any], catalog_name: str, sql: str | None, table: str | None) -> None:
        self._catalog_options = catalog_options
        self._catalog_name = catalog_name
        self._sql = sql
        self._table = table
        self._cached: pyarrow.Table | None = None

    def _execute(self) -> pyarrow.Table:
        # Serve from cache when the query already ran.
        if self._cached is not None:
            return self._cached
        import duckdb

        conn = duckdb.connect(":memory:")
        try:
            _ensure_duckdb_extension(conn, "postgres", install_from="community")
            dsn = _build_dsn(self._catalog_options)
            alias = _attach_alias(self._catalog_name)
            conn.execute(f"ATTACH '{dsn}' AS {alias} (TYPE postgres, READ_ONLY)")

            if self._sql:
                fetched = conn.execute(self._sql).arrow()
            else:
                # No explicit SQL: full-table scan of the attached table.
                pg_schema = self._catalog_options.get("schema", "public")
                table_ref = f"{alias}.{pg_schema}.{self._table}"
                fetched = conn.execute(f"SELECT * FROM {table_ref}").arrow()

            self._cached = fetched
            return fetched
        except ExecutionError:
            raise
        except Exception as exc:
            raise ExecutionError(
                RivetError(
                    code="RVT-504",
                    message=f"DuckDB postgres extension failed: {exc}",
                    context={"host": self._catalog_options.get("host"), "database": self._catalog_options.get("database")},
                    remediation="Check PostgreSQL connectivity, credentials, and that the DuckDB postgres extension is available.",
                )
            ) from exc
        finally:
            conn.close()

    def to_arrow(self) -> pyarrow.Table:
        return self._execute()

    @property
    def schema(self) -> Schema:
        arrow_schema = self._execute().schema
        return Schema(columns=[Column(name=f.name, type=str(f.type), nullable=f.nullable) for f in arrow_schema])

    @property
    def row_count(self) -> int:
        return self._execute().num_rows  # type: ignore[no-any-return]

    @property
    def size_bytes(self) -> int | None:
        # Size is not tracked for deferred PostgreSQL reads.
        return None

    @property
    def storage_type(self) -> str:
        return "postgres"
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
class PostgresDuckDBAdapter(ComputeEngineAdapter):
    """DuckDB adapter for PostgreSQL catalog type, shipped by rivet_postgres.

    Uses DuckDB's postgres community extension with ATTACH for direct
    PostgreSQL access from DuckDB queries.
    """

    target_engine_type = "duckdb"
    catalog_type = "postgres"
    capabilities: list[str] = [
        "projection_pushdown",
        "predicate_pushdown",
        "limit_pushdown",
        "cast_pushdown",
        "join",
        "aggregation",
    ]
    source = "catalog_plugin"
    source_plugin = "rivet_postgres"

    def read_dispatch(self, engine: Any, catalog: Any, joint: Any, pushdown: PushdownPlan | None = None) -> AdapterPushdownResult:
        """Build a deferred DuckDB-backed material for the joint, folding in pushdowns."""
        # Start from the joint's explicit SQL, or synthesize a full-table
        # scan against the ATTACH alias that _execute() will create later.
        if joint.sql:
            base_sql = joint.sql
        else:
            pg_schema = catalog.options.get("schema", "public")
            alias = _attach_alias(catalog.name)
            base_sql = f"SELECT * FROM {alias}.{pg_schema}.{joint.table}"

        sql, residual = _apply_pushdown(base_sql, pushdown)

        ref = _PostgresDuckDBMaterializedRef(
            catalog_options=catalog.options,
            catalog_name=catalog.name,
            sql=sql,
            table=joint.table,
        )
        return AdapterPushdownResult(
            material=Material(name=joint.name, catalog=catalog.name, materialized_ref=ref, state="deferred"),
            residual=residual,
        )

    def write_dispatch(self, engine: Any, catalog: Any, joint: Any, material: Any) -> Any:
        """Write the material into PostgreSQL through an attached DuckDB session."""
        import duckdb

        conn = duckdb.connect(":memory:")
        try:
            _ensure_duckdb_extension(conn, "postgres", install_from="community")
            dsn = _build_dsn(catalog.options)
            alias = _attach_alias(catalog.name)
            conn.execute(f"ATTACH '{dsn}' AS {alias} (TYPE postgres)")

            # Expose the material to DuckDB as a relation it can SELECT from.
            conn.register("__write_data", material.to_arrow())

            pg_schema = catalog.options.get("schema", "public")
            table_ref = f"{alias}.{pg_schema}.{joint.table}"
            strategy = joint.write_strategy or "replace"

            if strategy == "replace":
                conn.execute(f"DROP TABLE IF EXISTS {table_ref}")
                conn.execute(f"CREATE TABLE {table_ref} AS SELECT * FROM __write_data")
            elif strategy == "truncate_insert":
                conn.execute(f"DELETE FROM {table_ref}")
                conn.execute(f"INSERT INTO {table_ref} SELECT * FROM __write_data")
            else:
                # "append" and any unrecognized strategy both plain-insert.
                conn.execute(f"INSERT INTO {table_ref} SELECT * FROM __write_data")
        except ExecutionError:
            raise
        except Exception as exc:
            raise ExecutionError(
                RivetError(
                    code="RVT-504",
                    message=f"DuckDB postgres extension write failed: {exc}",
                    context={"host": catalog.options.get("host"), "database": catalog.options.get("database")},
                    remediation="Check PostgreSQL connectivity, credentials, and write permissions.",
                )
            ) from exc
        finally:
            conn.close()
|
|
@@ -0,0 +1,305 @@
|
|
|
1
|
+
"""PostgresPySparkAdapter: JDBC read/write for PostgreSQL catalogs via PySpark.
|
|
2
|
+
|
|
3
|
+
Simple writes (append, replace) use Spark JDBC. Complex writes (truncate_insert,
|
|
4
|
+
merge, delete_insert, incremental_append, scd2) use a psycopg3 side-channel that
|
|
5
|
+
materializes to Arrow and writes directly to PostgreSQL.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from __future__ import annotations
|
|
9
|
+
|
|
10
|
+
import asyncio
|
|
11
|
+
from typing import TYPE_CHECKING, Any
|
|
12
|
+
|
|
13
|
+
import pyarrow
|
|
14
|
+
|
|
15
|
+
from rivet_core.errors import ExecutionError, RivetError
|
|
16
|
+
from rivet_core.models import Column, Material, Schema
|
|
17
|
+
from rivet_core.optimizer import AdapterPushdownResult, Cast, PushdownPlan, ResidualPlan
|
|
18
|
+
from rivet_core.plugins import ComputeEngineAdapter
|
|
19
|
+
from rivet_core.strategies import MaterializedRef
|
|
20
|
+
|
|
21
|
+
if TYPE_CHECKING:
|
|
22
|
+
from rivet_core.sql_parser import Predicate
|
|
23
|
+
|
|
24
|
+
ALL_6_CAPABILITIES = [
|
|
25
|
+
"projection_pushdown",
|
|
26
|
+
"predicate_pushdown",
|
|
27
|
+
"limit_pushdown",
|
|
28
|
+
"cast_pushdown",
|
|
29
|
+
"join",
|
|
30
|
+
"aggregation",
|
|
31
|
+
]
|
|
32
|
+
|
|
33
|
+
_JDBC_DRIVER = "org.postgresql.Driver"
|
|
34
|
+
_JDBC_MAVEN = "org.postgresql:postgresql:42.7.3"
|
|
35
|
+
_SIDE_CHANNEL_STRATEGIES = frozenset({"truncate_insert", "merge", "delete_insert", "incremental_append", "scd2"})
|
|
36
|
+
|
|
37
|
+
_EMPTY_RESIDUAL = ResidualPlan(predicates=[], limit=None, casts=[])
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _apply_pyspark_pushdown(
    df: Any,
    pushdown: PushdownPlan | None,
) -> tuple[Any, ResidualPlan]:
    """Fold pushdown operations into a PySpark DataFrame.

    Returns the transformed DataFrame plus a ResidualPlan describing every
    operation that failed to apply and must run after the read.
    """
    if pushdown is None:
        return df, _EMPTY_RESIDUAL

    leftover_preds: list[Predicate] = list(pushdown.predicates.residual)
    leftover_casts: list[Cast] = list(pushdown.casts.residual)
    leftover_limit: int | None = pushdown.limit.residual_limit

    # Projection: trim to the pushed column set when one was computed.
    if pushdown.projections.pushed_columns is not None:
        try:
            df = df.select(*pushdown.projections.pushed_columns)
        except Exception:
            pass

    # Predicates: apply each filter independently; failures become residual.
    for pred in pushdown.predicates.pushed:
        try:
            df = df.filter(pred.expression)
        except Exception:
            leftover_preds.append(pred)

    # Limit: applied directly, falling back to a residual limit on failure.
    if pushdown.limit.pushed_limit is not None:
        try:
            df = df.limit(pushdown.limit.pushed_limit)
        except Exception:
            leftover_limit = pushdown.limit.pushed_limit

    # Casts: per-column cast via Spark column expressions.
    if pushdown.casts.pushed:
        from pyspark.sql import functions as F

        for cast in pushdown.casts.pushed:
            try:
                df = df.withColumn(cast.column, F.col(cast.column).cast(cast.to_type))
            except Exception:
                leftover_casts.append(cast)

    return df, ResidualPlan(
        predicates=leftover_preds,
        limit=leftover_limit,
        casts=leftover_casts,
    )
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class _SparkDataFrameMaterializedRef(MaterializedRef):
    """MaterializedRef backed by a Spark DataFrame; collects to Arrow on demand."""

    def __init__(self, df: Any) -> None:
        self._df = df

    def to_arrow(self) -> pyarrow.Table:
        # Prefer the DataFrame's native toArrow() when available; otherwise
        # round-trip through pandas.
        if not hasattr(self._df, "toArrow"):
            return pyarrow.Table.from_pandas(self._df.toPandas())
        arrow_result = self._df.toArrow()
        if isinstance(arrow_result, pyarrow.RecordBatchReader):
            return arrow_result.read_all()
        return arrow_result

    @property
    def schema(self) -> Schema:
        cols = [
            Column(name=field.name, type=str(field.dataType), nullable=field.nullable)
            for field in self._df.schema.fields
        ]
        return Schema(columns=cols)

    @property
    def row_count(self) -> int:
        # Triggers a Spark count job; not cached here.
        return self._df.count()  # type: ignore[no-any-return]

    @property
    def size_bytes(self) -> int | None:
        # Size is not tracked for Spark-backed materials.
        return None

    @property
    def storage_type(self) -> str:
        return "spark_dataframe"
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _check_jdbc_driver(session: Any) -> None:
    """Fail with RVT-505 if the PostgreSQL JDBC driver JAR is missing.

    Probes the Spark JVM classpath via ``Class.forName``; any failure
    (missing class, unreachable JVM) is treated as "driver not present".

    Raises:
        ExecutionError: RVT-505 when the driver class cannot be loaded.
    """
    try:
        session._jvm.java.lang.Class.forName(_JDBC_DRIVER)
    except Exception as exc:
        # Chain the original JVM error so the root cause is preserved
        # (previously the chain was dropped and silenced with `noqa: B904`).
        raise ExecutionError(
            RivetError(
                code="RVT-505",
                message="PostgreSQL JDBC driver JAR not found on Spark classpath.",
                context={"driver": _JDBC_DRIVER, "adapter": "PostgresPySparkAdapter"},
                remediation=(
                    "Add the PostgreSQL JDBC driver to Spark via "
                    f"'spark.jars.packages' option: '{_JDBC_MAVEN}'."
                ),
            )
        ) from exc
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def _build_jdbc_url(options: dict[str, Any]) -> str:
|
|
145
|
+
"""Build a JDBC URL from postgres catalog options."""
|
|
146
|
+
host = options["host"]
|
|
147
|
+
port = options.get("port", 5432)
|
|
148
|
+
database = options["database"]
|
|
149
|
+
ssl_mode = options.get("ssl_mode", "prefer")
|
|
150
|
+
return f"jdbc:postgresql://{host}:{port}/{database}?sslmode={ssl_mode}"
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def _build_jdbc_properties(options: dict[str, Any]) -> dict[str, str]:
    """Assemble JDBC connection properties: driver class plus optional credentials."""
    props: dict[str, str] = {"driver": _JDBC_DRIVER}
    # Credentials are attached only when present and non-empty.
    for key in ("user", "password"):
        if options.get(key):
            props[key] = options[key]
    return props
|
|
161
|
+
|
|
162
|
+
|
|
163
|
+
class PostgresPySparkAdapter(ComputeEngineAdapter):
    """Adapter enabling PySpark engine to read/write PostgreSQL catalogs via JDBC.

    Registered as catalog_plugin-contributed so it takes precedence over any
    engine_plugin adapter for the same (pyspark, postgres) pair.
    """

    target_engine_type = "pyspark"
    catalog_type = "postgres"
    capabilities = ALL_6_CAPABILITIES
    source = "catalog_plugin"
    source_plugin = "rivet_postgres"

    def read_dispatch(self, engine: Any, catalog: Any, joint: Any, pushdown: PushdownPlan | None = None) -> AdapterPushdownResult:
        """Read from PostgreSQL via Spark JDBC with optional parallel partitioned reads."""
        session = engine.get_session()
        _check_jdbc_driver(session)

        url = _build_jdbc_url(catalog.options)
        props = _build_jdbc_properties(catalog.options)

        # JDBC "table" argument: a wrapped subquery when the joint carries
        # SQL, otherwise a schema-qualified table (falling back to the name).
        sql = getattr(joint, "sql", None)
        table = getattr(joint, "table", None)
        if sql:
            dbtable = f"({sql}) AS _rivet_subquery"
        else:
            schema = catalog.options.get("schema", "public")
            dbtable = f"{schema}.{table}" if table else f"{schema}.{joint.name}"

        # Optional partitioned-read knobs carried on the joint.
        partition_column = getattr(joint, "jdbc_partition_column", None)
        lower_bound = getattr(joint, "jdbc_lower_bound", None)
        upper_bound = getattr(joint, "jdbc_upper_bound", None)
        num_partitions = getattr(joint, "jdbc_num_partitions", None)

        try:
            if partition_column and num_partitions:
                df = session.read.jdbc(
                    url=url,
                    table=dbtable,
                    column=partition_column,
                    lowerBound=int(lower_bound) if lower_bound is not None else 0,
                    upperBound=int(upper_bound) if upper_bound is not None else 1000000,
                    numPartitions=int(num_partitions),
                    properties=props,
                )
            else:
                df = session.read.jdbc(url=url, table=dbtable, properties=props)
        except Exception as exc:
            raise ExecutionError(
                RivetError(
                    code="RVT-501",
                    message=f"PostgreSQL JDBC read failed: {exc}",
                    context={"host": catalog.options.get("host"), "database": catalog.options.get("database")},
                    remediation="Check PostgreSQL credentials, host, and network connectivity.",
                )
            ) from exc

        df, residual = _apply_pyspark_pushdown(df, pushdown)

        return AdapterPushdownResult(
            material=Material(
                name=joint.name,
                catalog=catalog.name,
                materialized_ref=_SparkDataFrameMaterializedRef(df),
                state="deferred",
            ),
            residual=residual,
        )

    def write_dispatch(self, engine: Any, catalog: Any, joint: Any, material: Any) -> Any:
        """Write to PostgreSQL via Spark JDBC (append, replace) or psycopg3 side-channel (complex)."""
        strategy = getattr(joint, "write_strategy", None) or "replace"

        # Strategies Spark JDBC cannot express go through the psycopg3 path.
        if strategy in _SIDE_CHANNEL_STRATEGIES:
            return _psycopg3_side_channel(catalog, joint, material, strategy)

        session = engine.get_session()
        _check_jdbc_driver(session)

        url = _build_jdbc_url(catalog.options)
        props = _build_jdbc_properties(catalog.options)

        # Resolve the schema-qualified target table.
        table = getattr(joint, "table", None) or joint.name
        schema = catalog.options.get("schema", "public")
        dbtable = f"{schema}.{table}"

        # Only "append" maps to Spark's append mode; everything else overwrites.
        mode = "append" if strategy == "append" else "overwrite"

        # Arrow -> pandas -> Spark DataFrame for the JDBC writer.
        arrow_table = material.materialized_ref.to_arrow()
        df = session.createDataFrame(arrow_table.to_pandas())

        try:
            df.write.jdbc(url=url, table=dbtable, mode=mode, properties=props)
        except Exception as exc:
            raise ExecutionError(
                RivetError(
                    code="RVT-501",
                    message=f"PostgreSQL JDBC write failed: {exc}",
                    context={"host": catalog.options.get("host"), "database": catalog.options.get("database")},
                    remediation="Check PostgreSQL credentials, host, and write permissions.",
                )
            ) from exc
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
# --- psycopg3 side-channel for complex write strategies ---
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def _psycopg3_side_channel(catalog: Any, joint: Any, material: Any, strategy: str) -> None:
    """Run a complex write strategy through psycopg3 directly, bypassing Spark JDBC."""
    from rivet_postgres.sink import _build_conninfo, _execute_strategy

    # Collect the data first (this may trigger a Spark job), then resolve
    # the connection info and the schema-qualified target table.
    arrow_table = material.materialized_ref.to_arrow()
    conninfo = _build_conninfo(catalog.options)
    target = getattr(joint, "table", None) or joint.name
    schema = catalog.options.get("schema", "public")
    qualified_table = f"{schema}.{target}"

    try:
        # The sink's strategy runner is async; drive it to completion here.
        asyncio.run(_execute_strategy(conninfo, qualified_table, arrow_table, strategy, joint))
    except ExecutionError:
        raise
    except Exception as exc:
        raise ExecutionError(
            RivetError(
                code="RVT-501",
                message=f"PostgreSQL psycopg3 side-channel write failed: {exc}",
                context={
                    "host": catalog.options.get("host"),
                    "database": catalog.options.get("database"),
                    "strategy": strategy,
                },
                remediation="Check PostgreSQL credentials, host, and write permissions.",
            )
        ) from exc
|