wren-engine 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wren/__init__.py +9 -0
- wren/cli.py +382 -0
- wren/connector/__init__.py +4 -0
- wren/connector/base.py +87 -0
- wren/connector/bigquery.py +57 -0
- wren/connector/canner.py +79 -0
- wren/connector/databricks.py +65 -0
- wren/connector/duckdb.py +125 -0
- wren/connector/factory.py +68 -0
- wren/connector/ibis.py +103 -0
- wren/connector/mssql.py +113 -0
- wren/connector/mysql.py +57 -0
- wren/connector/oracle.py +171 -0
- wren/connector/postgres.py +76 -0
- wren/connector/redshift.py +68 -0
- wren/connector/spark.py +52 -0
- wren/engine.py +199 -0
- wren/mdl/__init__.py +44 -0
- wren/mdl/cte_rewriter.py +234 -0
- wren/mdl/wren_dialect.py +19 -0
- wren/memory/__init__.py +101 -0
- wren/memory/cli.py +263 -0
- wren/memory/embeddings.py +26 -0
- wren/memory/schema_indexer.py +262 -0
- wren/memory/store.py +304 -0
- wren/model/__init__.py +296 -0
- wren/model/data_source.py +477 -0
- wren/model/error.py +86 -0
- wren_engine-0.1.0.dist-info/METADATA +222 -0
- wren_engine-0.1.0.dist-info/RECORD +32 -0
- wren_engine-0.1.0.dist-info/WHEEL +4 -0
- wren_engine-0.1.0.dist-info/entry_points.txt +2 -0
wren/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
"""Wren — semantic SQL layer for 20+ data sources."""
|
|
2
|
+
|
|
3
|
+
__version__ = "0.1.0"
|
|
4
|
+
|
|
5
|
+
from wren.engine import WrenEngine
|
|
6
|
+
from wren.model.data_source import DataSource
|
|
7
|
+
from wren.model.error import WrenError
|
|
8
|
+
|
|
9
|
+
__all__ = ["WrenEngine", "DataSource", "WrenError", "__version__"]
|
wren/cli.py
ADDED
|
@@ -0,0 +1,382 @@
|
|
|
1
|
+
"""Wren CLI — SQL transform and execution via the Wren semantic layer."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Annotated, Optional
|
|
8
|
+
|
|
9
|
+
import typer
|
|
10
|
+
|
|
11
|
+
# Root Typer application. Automatic help on a bare invocation is disabled
# (no_args_is_help=False).
app = typer.Typer(
    name="wren",
    help="Wren Engine CLI",
    no_args_is_help=False,
)

# Per-user configuration directory and the two files auto-discovered in it.
_WREN_HOME = Path.home() / ".wren"
_DEFAULT_MDL = _WREN_HOME / "mdl.json"
_DEFAULT_CONN = _WREN_HOME / "connection_info.json"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# ── File discovery helpers ─────────────────────────────────────────────────
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _require_mdl(mdl: str | None) -> str:
|
|
22
|
+
"""Return mdl arg if given, else auto-discover mdl.json from ~/.wren."""
|
|
23
|
+
if mdl is not None:
|
|
24
|
+
return mdl
|
|
25
|
+
if _DEFAULT_MDL.exists():
|
|
26
|
+
return str(_DEFAULT_MDL)
|
|
27
|
+
typer.echo(
|
|
28
|
+
f"Error: --mdl not specified and '{_DEFAULT_MDL}' not found.",
|
|
29
|
+
err=True,
|
|
30
|
+
)
|
|
31
|
+
raise typer.Exit(1)
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _load_manifest(mdl: str) -> str:
|
|
35
|
+
"""Load MDL from a file path or treat as base64 string directly."""
|
|
36
|
+
path = Path(mdl).expanduser()
|
|
37
|
+
if path.exists():
|
|
38
|
+
import base64 # noqa: PLC0415
|
|
39
|
+
|
|
40
|
+
content = path.read_bytes()
|
|
41
|
+
if path.suffix.lower() == ".json":
|
|
42
|
+
# Raw JSON file — base64-encode it for WrenEngine
|
|
43
|
+
return base64.b64encode(content).decode()
|
|
44
|
+
# Non-.json file — assume it already contains a base64-encoded MDL string
|
|
45
|
+
return content.decode().strip()
|
|
46
|
+
# Not a file path — treat as a raw base64 string passed directly
|
|
47
|
+
return mdl
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def _load_conn(
|
|
51
|
+
connection_info: str | None,
|
|
52
|
+
connection_file: str | None,
|
|
53
|
+
*,
|
|
54
|
+
required: bool = True,
|
|
55
|
+
) -> dict:
|
|
56
|
+
"""Load connection dict from inline JSON or file, with ~/.wren auto-discovery.
|
|
57
|
+
|
|
58
|
+
If neither --connection-info nor --connection-file is given, looks for
|
|
59
|
+
connection_info.json in ~/.wren. Raises typer.Exit(1) if required=True and nothing
|
|
60
|
+
is found.
|
|
61
|
+
"""
|
|
62
|
+
if connection_info:
|
|
63
|
+
try:
|
|
64
|
+
conn = json.loads(connection_info)
|
|
65
|
+
except json.JSONDecodeError as e:
|
|
66
|
+
typer.echo(f"Error: invalid JSON in --connection-info: {e}", err=True)
|
|
67
|
+
raise typer.Exit(1)
|
|
68
|
+
if not isinstance(conn, dict):
|
|
69
|
+
typer.echo(
|
|
70
|
+
"Error: --connection-info must decode to a JSON object.", err=True
|
|
71
|
+
)
|
|
72
|
+
raise typer.Exit(1)
|
|
73
|
+
return conn
|
|
74
|
+
|
|
75
|
+
path_str = connection_file or (
|
|
76
|
+
str(_DEFAULT_CONN) if _DEFAULT_CONN.exists() else None
|
|
77
|
+
)
|
|
78
|
+
if path_str:
|
|
79
|
+
path = Path(path_str).expanduser()
|
|
80
|
+
if not path.exists():
|
|
81
|
+
typer.echo(f"Error: connection file not found: {path_str}", err=True)
|
|
82
|
+
raise typer.Exit(1)
|
|
83
|
+
try:
|
|
84
|
+
conn = json.loads(path.read_text())
|
|
85
|
+
except json.JSONDecodeError as e:
|
|
86
|
+
typer.echo(f"Error: invalid JSON in {path_str}: {e}", err=True)
|
|
87
|
+
raise typer.Exit(1)
|
|
88
|
+
if not isinstance(conn, dict):
|
|
89
|
+
typer.echo(f"Error: {path_str} must contain a JSON object.", err=True)
|
|
90
|
+
raise typer.Exit(1)
|
|
91
|
+
return conn
|
|
92
|
+
|
|
93
|
+
if required:
|
|
94
|
+
typer.echo(
|
|
95
|
+
f"Error: --connection-file not specified and '{_DEFAULT_CONN}' not found.",
|
|
96
|
+
err=True,
|
|
97
|
+
)
|
|
98
|
+
raise typer.Exit(1)
|
|
99
|
+
return {}
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _resolve_datasource(explicit: str | None, conn_dict: dict) -> str:
|
|
103
|
+
"""Return datasource: use explicit --datasource arg first, then pop from conn dict.
|
|
104
|
+
|
|
105
|
+
Note: mutates conn_dict by removing the 'datasource' key so it is not
|
|
106
|
+
forwarded as an unknown field to WrenEngine / the connector.
|
|
107
|
+
"""
|
|
108
|
+
if explicit:
|
|
109
|
+
conn_dict.pop("datasource", None)
|
|
110
|
+
return explicit
|
|
111
|
+
ds = conn_dict.pop("datasource", None)
|
|
112
|
+
if ds:
|
|
113
|
+
return ds
|
|
114
|
+
typer.echo(
|
|
115
|
+
"Error: --datasource not specified and 'datasource' key not found in connection info.",
|
|
116
|
+
err=True,
|
|
117
|
+
)
|
|
118
|
+
raise typer.Exit(1)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def _build_engine(
    datasource: str | None,
    mdl: str | None,
    connection_info: str | None,
    connection_file: str | None,
    *,
    conn_required: bool = True,
):
    """Assemble a ``WrenEngine`` from the raw CLI option values.

    Loads the manifest, then the connection dict, then resolves the
    datasource name (explicit option first, else the ``datasource`` key of
    the connection dict) and converts it, lower-cased, to a ``DataSource``.

    Raises:
        typer.Exit: exit code 1 when the datasource name is unknown, or
            propagated from any of the loading helpers.
    """
    from wren.engine import WrenEngine  # noqa: PLC0415
    from wren.model.data_source import DataSource  # noqa: PLC0415

    manifest = _load_manifest(_require_mdl(mdl))
    conn = _load_conn(connection_info, connection_file, required=conn_required)
    ds_name = _resolve_datasource(datasource, conn)

    try:
        source = DataSource(ds_name.lower())
    except ValueError:
        typer.echo(f"Error: unknown datasource '{ds_name}'", err=True)
        raise typer.Exit(1)
    else:
        return WrenEngine(
            manifest_str=manifest,
            data_source=source,
            connection_info=conn,
        )
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
# ── Shared option types ────────────────────────────────────────────────────
|
|
148
|
+
|
|
149
|
+
# Reusable ``Annotated`` aliases so every command declares its options the
# same way (same flags, same help text).

# -d / --datasource: optional name of the data source.
DatasourceOpt = Annotated[
    Optional[str],
    typer.Option(
        "--datasource",
        "-d",
        help="Data source (e.g. mysql, postgres). Defaults to 'datasource' field in connection_info.json.",
    ),
]

# -m / --mdl: manifest file path or inline base64 manifest.
MdlOpt = Annotated[
    Optional[str],
    typer.Option(
        "--mdl",
        "-m",
        help=f"Path to MDL JSON file or base64 string. Defaults to {_DEFAULT_MDL}.",
    ),
]

# --connection-info: connection settings given inline as JSON text.
ConnInfoOpt = Annotated[
    Optional[str],
    typer.Option(
        "--connection-info",
        help="Inline JSON connection string",
    ),
]

# --connection-file: connection settings read from a JSON file.
ConnFileOpt = Annotated[
    Optional[str],
    typer.Option(
        "--connection-file",
        help=f"Path to JSON connection file. Defaults to {_DEFAULT_CONN}.",
    ),
]

# -l / --limit: optional row cap.
LimitOpt = Annotated[
    Optional[int],
    typer.Option(
        "--limit",
        "-l",
        help="Max rows to return",
    ),
]

# -o / --output: result rendering format.
OutputOpt = Annotated[
    str,
    typer.Option(
        "--output",
        "-o",
        help="Output format: json|csv|table",
    ),
]
|
|
182
|
+
|
|
183
|
+
|
|
184
|
+
# ── Default command (no subcommand = query) ────────────────────────────────
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
@app.callback(invoke_without_command=True)
def main(
    ctx: typer.Context,
    sql: Annotated[
        Optional[str],
        typer.Option(
            "--sql", "-s", help="SQL query to execute (runs query by default)"
        ),
    ] = None,
    datasource: DatasourceOpt = None,
    mdl: MdlOpt = None,
    connection_info: ConnInfoOpt = None,
    connection_file: ConnFileOpt = None,
    limit: LimitOpt = None,
    output: OutputOpt = "table",
) -> None:
    """Wren Engine CLI.

    Run with --sql to execute a query using mdl.json and connection_info.json from
    ~/.wren. Use a subcommand (query / dry-run / dry-plan / validate)
    for explicit control.

    connection_info.json format:

    \b
      {
        "datasource": "mysql",
        "host": "localhost",
        "port": 3306,
        "database": "mydb",
        "user": "root",
        "password": "secret"
      }
    """
    # NOTE: the docstring above is user-facing help text; the lone "\b" line
    # is a backspace character marking the JSON sample as a no-rewrap block.

    # A subcommand was named: it does the work, nothing to do here.
    if ctx.invoked_subcommand is not None:
        return
    # Bare `wren` with no --sql: show help instead of failing.
    if sql is None:
        typer.echo(ctx.get_help())
        return
    # `wren --sql ...` is shorthand for `wren query --sql ...`. Delegate to
    # the `query` command function rather than duplicating its body, so the
    # two entry points cannot drift apart.
    query(
        sql=sql,
        datasource=datasource,
        mdl=mdl,
        connection_info=connection_info,
        connection_file=connection_file,
        limit=limit,
        output=output,
    )
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
# ── Subcommands ────────────────────────────────────────────────────────────
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
@app.command()
def query(
    sql: Annotated[str, typer.Option("--sql", "-s", help="SQL query to execute")],
    datasource: DatasourceOpt = None,
    mdl: MdlOpt = None,
    connection_info: ConnInfoOpt = None,
    connection_file: ConnFileOpt = None,
    limit: LimitOpt = None,
    output: OutputOpt = "table",
):
    """Execute a SQL query through the Wren semantic layer."""
    engine_cm = _build_engine(datasource, mdl, connection_info, connection_file)
    with engine_cm as engine:
        try:
            table = engine.query(sql, limit=limit)
        except Exception as exc:
            # Any engine/connector failure becomes a one-line error and exit 1.
            typer.echo(f"Error: {exc}", err=True)
            raise typer.Exit(1)
        else:
            _print_result(table, output)
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
@app.command(name="dry-run")
def dry_run(
    sql: Annotated[str, typer.Option("--sql", "-s", help="SQL query to validate")],
    datasource: DatasourceOpt = None,
    mdl: MdlOpt = None,
    connection_info: ConnInfoOpt = None,
    connection_file: ConnFileOpt = None,
):
    """Dry-run a SQL query (parse + validate, no results returned)."""
    engine_cm = _build_engine(datasource, mdl, connection_info, connection_file)
    with engine_cm as engine:
        try:
            engine.dry_run(sql)
            # Reaching this line means the dry run raised nothing.
            typer.echo("OK")
        except Exception as exc:
            typer.echo(f"Error: {exc}", err=True)
            raise typer.Exit(1)
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
@app.command(name="dry-plan")
def dry_plan(
    sql: Annotated[str, typer.Option("--sql", "-s", help="SQL query to plan")],
    datasource: DatasourceOpt = None,
    mdl: MdlOpt = None,
    connection_file: ConnFileOpt = None,
):
    """Plan SQL through MDL and print the expanded SQL (no DB required)."""
    from wren.engine import WrenEngine  # noqa: PLC0415
    from wren.model.data_source import DataSource  # noqa: PLC0415

    manifest = _load_manifest(_require_mdl(mdl))

    # The connection file is consulted only to learn the datasource name, so
    # it is skipped when --datasource is given and no file was named.
    if datasource is not None and connection_file is None:
        conn = {}
    else:
        conn = _load_conn(None, connection_file, required=False)
    ds_name = _resolve_datasource(datasource, conn)

    try:
        source = DataSource(ds_name.lower())
    except ValueError:
        typer.echo(f"Error: unknown datasource '{ds_name}'", err=True)
        raise typer.Exit(1)

    # Planning uses no connection settings: the engine gets an empty dict.
    planner = WrenEngine(
        manifest_str=manifest,
        data_source=source,
        connection_info={},
    )
    with planner as engine:
        try:
            typer.echo(engine.dry_plan(sql))
        except Exception as exc:
            typer.echo(f"Error: {exc}", err=True)
            raise typer.Exit(1)
|
|
311
|
+
|
|
312
|
+
|
|
313
|
+
@app.command()
def validate(
    sql: Annotated[str, typer.Option("--sql", "-s", help="SQL query to validate")],
    datasource: DatasourceOpt = None,
    mdl: MdlOpt = None,
    connection_info: ConnInfoOpt = None,
    connection_file: ConnFileOpt = None,
):
    """Validate SQL can be planned and dry-run against the data source."""
    engine_cm = _build_engine(datasource, mdl, connection_info, connection_file)
    with engine_cm as engine:
        try:
            engine.dry_run(sql)
            # Same check as the dry-run command, reported as Valid / Invalid.
            typer.echo("Valid")
        except Exception as exc:
            typer.echo(f"Invalid: {exc}", err=True)
            raise typer.Exit(1)
|
|
329
|
+
|
|
330
|
+
|
|
331
|
+
# ── Output formatting ──────────────────────────────────────────────────────
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def _print_result(table, output: str) -> None:
    """Write ``table`` to stdout in the requested format.

    Args:
        table: Query result; the code calls ``to_pandas()`` on it and, for
            the JSON fallback, ``to_pydict()``.
        output: ``json``, ``csv`` or ``table`` (case-insensitive).

    Raises:
        typer.Exit: exit code 1 when ``output`` names an unsupported format.
    """
    fmt = output.lower()
    if fmt not in {"json", "csv", "table"}:
        typer.echo(
            f"Error: unsupported output format '{fmt}'. Use json, csv, or table.",
            err=True,
        )
        raise typer.Exit(1)

    if fmt == "json":
        try:
            # One JSON object per row, one row per line.
            typer.echo(table.to_pandas().to_json(orient="records", lines=True))
        except Exception:
            # Fallback when the pandas path fails: dump the column dict.
            typer.echo(json.dumps(table.to_pydict()))
        return

    # csv and table differ only in the pandas renderer; both fall back to the
    # table's own string form when the pandas path fails.
    try:
        frame = table.to_pandas()
        if fmt == "csv":
            rendered = frame.to_csv(index=False)
        else:
            rendered = frame.to_string(index=False)
        typer.echo(rendered)
    except Exception:
        typer.echo(str(table))
|
|
360
|
+
|
|
361
|
+
|
|
362
|
+
@app.command()
def version():
    """Print the wren-engine version."""
    # Function-scope import, as elsewhere in this module.
    from wren import __version__ as installed_version  # noqa: PLC0415

    typer.echo(f"wren-engine {installed_version}")
|
|
368
|
+
|
|
369
|
+
|
|
370
|
+
# The `memory` sub-app is optional (wren[memory] extra). Probe for its
# dependencies with find_spec instead of importing them, so every CLI
# start-up does not pay for loading them just to learn whether they exist.
try:
    from importlib.util import find_spec  # noqa: PLC0415

    if all(
        find_spec(_pkg) is not None for _pkg in ("lancedb", "sentence_transformers")
    ):
        from wren.memory.cli import memory_app  # noqa: PLC0415

        app.add_typer(memory_app)
except (ImportError, ValueError):
    # ImportError: wren[memory] not installed (or wren.memory.cli failed to
    # import). ValueError: find_spec raises it for a module whose __spec__
    # is unset. Either way the CLI simply runs without the memory commands.
    pass


if __name__ == "__main__":
    app()
|
wren/connector/base.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
|
|
3
|
+
import pyarrow as pa
|
|
4
|
+
from ibis.expr.datatypes import Decimal
|
|
5
|
+
from ibis.expr.datatypes.core import UUID
|
|
6
|
+
from ibis.expr.types import Table
|
|
7
|
+
from loguru import logger
|
|
8
|
+
|
|
9
|
+
from wren.model.data_source import DataSource
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class ConnectorABC(ABC):
    """Abstract interface for connectors: query, dry-run and close."""

    @abstractmethod
    def query(self, sql: str, limit: int | None = None) -> pa.Table:
        """Run ``sql`` and return the result as a PyArrow table.

        ``limit`` is an optional row cap; ``None`` means no cap.
        """

    @abstractmethod
    def dry_run(self, sql: str) -> None:
        """Check ``sql`` against the data source without returning rows."""

    @abstractmethod
    def close(self) -> None:
        """Release the underlying connection."""
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
class IbisConnector(ConnectorABC):
    """Connector that runs SQL through the backend returned by
    ``DataSource.get_connection``.

    Before results are converted to PyArrow, decimal columns are cast to
    ``Decimal(38, scale)`` and rounded, and UUID columns are cast to string
    (see ``_handle_pyarrow_unsupported_type``).
    """

    def __init__(self, data_source: DataSource, connection_info):
        """Open a backend connection for ``data_source`` using ``connection_info``."""
        self.data_source = data_source
        self.connection = self.data_source.get_connection(connection_info)
        self._closed = False

    def query(self, sql: str, limit: int | None = None) -> pa.Table:
        """Run ``sql`` and return the rows as a PyArrow table.

        When ``limit`` is not ``None`` it is applied to the table expression
        before execution.
        """
        ibis_table = self.connection.sql(sql)
        if limit is not None:
            ibis_table = ibis_table.limit(limit)
        ibis_table = self._handle_pyarrow_unsupported_type(ibis_table)
        return ibis_table.to_pyarrow()

    def _handle_pyarrow_unsupported_type(self, ibis_table: Table, **kwargs) -> Table:
        """Rewrite Decimal and UUID columns via the two helper methods.

        ``kwargs`` are forwarded to ``_round_decimal_columns`` only.
        """
        result_table = ibis_table
        for name, dtype in ibis_table.schema().items():
            if isinstance(dtype, Decimal):
                result_table = self._round_decimal_columns(
                    result_table=result_table, col_name=name, **kwargs
                )
            elif isinstance(dtype, UUID):
                result_table = self._cast_uuid_columns(
                    result_table=result_table, col_name=name
                )
        return result_table

    def _cast_uuid_columns(self, result_table: Table, col_name: str) -> Table:
        """Return ``result_table`` with ``col_name`` cast to string."""
        return result_table.mutate(**{col_name: result_table[col_name].cast("string")})

    def _round_decimal_columns(
        self, result_table: Table, col_name: str, scale: int = 9
    ) -> Table:
        """Return ``result_table`` with ``col_name`` cast to
        ``Decimal(38, scale)`` and rounded to ``scale`` places."""
        col = result_table[col_name]
        decimal_type = Decimal(precision=38, scale=scale)
        rounded_col = col.cast(decimal_type).round(scale)
        return result_table.mutate(**{col_name: rounded_col})

    def dry_run(self, sql: str) -> None:
        """Build the table expression for ``sql`` without fetching rows."""
        self.connection.sql(sql)

    def close(self) -> None:
        """Close the backend connection; safe to call more than once.

        Tries, in order, ``connection.con.close()``, ``connection.close()``
        and ``connection.disconnect()``, using the first that exists. Errors
        while closing are logged as warnings, never raised. Afterwards the
        connector is marked closed and ``self.connection`` is ``None``.
        """
        if getattr(self, "_closed", False) or getattr(self, "connection", None) is None:
            return
        try:
            raw = getattr(self.connection, "con", None)
            # Fall through the candidates instead of stopping at `con`: a
            # backend that exposes `con` without a `close` method must still
            # be closed through the backend object itself.
            candidates = (
                (raw, "close"),
                (self.connection, "close"),
                (self.connection, "disconnect"),
            )
            closer = None
            for owner, method_name in candidates:
                if owner is None:
                    continue
                method = getattr(owner, method_name, None)
                if callable(method):
                    closer = method
                    break
            if closer is None:
                logger.warning(
                    f"Closing connection for {self.data_source.value} is not implemented."
                )
            else:
                closer()
        except Exception as e:
            logger.warning(
                f"Error closing connection for {self.data_source.value}: {e}"
            )
        finally:
            self._closed = True
            self.connection = None
|
wren/connector/bigquery.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
import base64
|
|
2
|
+
from json import loads
|
|
3
|
+
|
|
4
|
+
import pyarrow as pa
|
|
5
|
+
from loguru import logger
|
|
6
|
+
|
|
7
|
+
from wren.connector.base import ConnectorABC
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BigQueryConnector(ConnectorABC):
    """Connector built on ``google.cloud.bigquery.Client``."""

    def __init__(self, connection_info):
        """Create a BigQuery client from ``connection_info``.

        Reads ``credentials`` (a secret holding base64-encoded
        service-account JSON), ``get_billing_project_id()`` and
        ``job_timeout_ms`` from ``connection_info``.
        """
        from google.cloud import bigquery  # noqa: PLC0415
        from google.oauth2 import service_account  # noqa: PLC0415

        self.connection_info = connection_info

        # base64 -> UTF-8 text -> service-account dict.
        encoded_key = connection_info.credentials.get_secret_value()
        account_info = loads(base64.b64decode(encoded_key).decode("utf-8"))

        scopes = [
            "https://www.googleapis.com/auth/drive",
            "https://www.googleapis.com/auth/cloud-platform",
        ]
        credentials = service_account.Credentials.from_service_account_info(
            account_info
        ).with_scopes(scopes)

        client = bigquery.Client(
            credentials=credentials,
            project=connection_info.get_billing_project_id(),
        )

        # Set as the client's default job config, carrying the timeout.
        default_config = bigquery.QueryJobConfig()
        default_config.job_timeout_ms = connection_info.job_timeout_ms
        client.default_query_job_config = default_config

        self.connection = client

    def query(self, sql: str, limit: int | None = None) -> pa.Table:
        """Run ``sql`` and return the rows as a PyArrow table.

        ``limit`` is passed as ``max_results`` when fetching the result.
        """
        job = self.connection.query(sql)
        rows = job.result(max_results=limit)
        return rows.to_arrow()

    def dry_run(self, sql: str) -> None:
        """Submit ``sql`` as a dry-run job with the query cache disabled."""
        from google.cloud import bigquery  # noqa: PLC0415

        dry_config = bigquery.QueryJobConfig(dry_run=True, use_query_cache=False)
        self.connection.query(sql, job_config=dry_config)

    def close(self) -> None:
        """Close the client; failures are logged as warnings, not raised."""
        try:
            self.connection.close()
        except Exception as exc:
            logger.warning(f"Error closing BigQuery connection: {exc}")
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def create_connector(connection_info) -> BigQueryConnector:
    """Build a ``BigQueryConnector`` for ``connection_info``."""
    connector = BigQueryConnector(connection_info)
    return connector
|
wren/connector/canner.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
1
|
+
from contextlib import closing
|
|
2
|
+
from functools import cache
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import ibis
|
|
6
|
+
import ibis.expr.schema as sch
|
|
7
|
+
import pyarrow as pa
|
|
8
|
+
from ibis import BaseBackend
|
|
9
|
+
from ibis.backends.sql.compilers.postgres import compiler as postgres_compiler
|
|
10
|
+
from ibis.expr.datatypes import Decimal
|
|
11
|
+
from ibis.expr.datatypes.core import UUID
|
|
12
|
+
from ibis.expr.types import Table
|
|
13
|
+
from loguru import logger
|
|
14
|
+
|
|
15
|
+
from wren.connector.base import ConnectorABC
|
|
16
|
+
from wren.model.data_source import DataSource
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@cache
def _get_pg_type_names(connection: BaseBackend) -> dict[int, str]:
    """Return the ``oid -> typname`` mapping read from ``pg_type``.

    Memoised by ``functools.cache`` with the connection object as the key,
    so the catalog is queried once per connection.
    """
    cursor = connection.raw_sql("SELECT oid, typname FROM pg_type")
    with closing(cursor):
        return {oid: typname for oid, typname in cursor.fetchall()}
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class CannerConnector(ConnectorABC):
    """Connector for Canner, using the connection from
    ``DataSource.canner.get_connection``.

    The result schema is supplied explicitly: ``query`` first executes a
    zero-row probe to read the cursor description, then maps each column's
    type code to a type via the PostgreSQL type mapper.
    """

    def __init__(self, connection_info):
        """Open the Canner connection described by ``connection_info``."""
        self.connection = DataSource.canner.get_connection(connection_info)

    def query(self, sql: str, limit: int | None = None) -> pa.Table:
        """Run ``sql`` and return the rows as a PyArrow table.

        When ``limit`` is not ``None`` it is applied to the table expression
        before execution.
        """
        schema = self._get_schema(sql)
        ibis_table = self.connection.sql(sql, schema=schema)
        if limit is not None:
            ibis_table = ibis_table.limit(limit)
        ibis_table = self._handle_pyarrow_unsupported_type(ibis_table)
        return ibis_table.to_pyarrow()

    def _handle_pyarrow_unsupported_type(self, ibis_table: Table, **kwargs) -> Table:
        """Cast Decimal columns to ``Decimal(38, 9)`` (rounded to 9 places)
        and UUID columns to string. ``kwargs`` is accepted but unused."""
        result_table = ibis_table
        for name, dtype in ibis_table.schema().items():
            if isinstance(dtype, Decimal):
                col = result_table[name]
                decimal_type = Decimal(precision=38, scale=9)
                rounded_col = col.cast(decimal_type).round(9)
                result_table = result_table.mutate(**{name: rounded_col})
            elif isinstance(dtype, UUID):
                result_table = result_table.mutate(
                    **{name: result_table[name].cast("string")}
                )
        return result_table

    def dry_run(self, sql: str) -> Any:
        """Execute ``sql`` wrapped in ``SELECT * FROM (...) LIMIT 0``.

        Returns whatever ``raw_sql`` returns (used as a cursor by
        ``_get_schema``); a caller that keeps it is responsible for closing it.
        """
        return self.connection.raw_sql(f"SELECT * FROM ({sql}) LIMIT 0")

    def close(self) -> None:
        """Close the connection; failures are logged as warnings, not raised."""
        try:
            if hasattr(self.connection, "con") and hasattr(
                self.connection.con, "close"
            ):
                self.connection.con.close()
            elif hasattr(self.connection, "close"):
                self.connection.close()
        except Exception as e:
            logger.warning(f"Error closing Canner connection: {e}")

    def _get_schema(self, sql: str) -> sch.Schema:
        """Derive the result schema of ``sql`` from a zero-row probe."""
        # Copy what is needed from the cursor description, then close the
        # cursor; previously it was left open after every query.
        with closing(self.dry_run(sql)) as cur:
            columns = [(desc.name, desc.type_code) for desc in cur.description]
        type_names = _get_pg_type_names(self.connection)
        return ibis.schema(
            {
                name: postgres_compiler.type_mapper.from_string(type_names[type_code])
                for name, type_code in columns
            }
        )
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def create_connector(connection_info) -> CannerConnector:
    """Build a ``CannerConnector`` for ``connection_info``."""
    connector = CannerConnector(connection_info)
    return connector
|
wren/connector/databricks.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
from contextlib import closing
|
|
2
|
+
|
|
3
|
+
import pyarrow as pa
|
|
4
|
+
from loguru import logger
|
|
5
|
+
|
|
6
|
+
from wren.connector.base import ConnectorABC
|
|
7
|
+
from wren.model import (
|
|
8
|
+
DatabricksConnectionUnion,
|
|
9
|
+
DatabricksServicePrincipalConnectionInfo,
|
|
10
|
+
DatabricksTokenConnectionInfo,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class DatabricksConnector(ConnectorABC):
    """Connector built on ``databricks.sql`` connections.

    Supports access-token and OAuth service-principal connection info.
    """

    def __init__(self, connection_info: DatabricksConnectionUnion):
        """Open a Databricks SQL connection.

        Raises:
            TypeError: if ``connection_info`` is neither a
                ``DatabricksTokenConnectionInfo`` nor a
                ``DatabricksServicePrincipalConnectionInfo``.
        """
        from databricks import sql as dbsql  # noqa: PLC0415
        from databricks.sdk.core import Config as DbConfig  # noqa: PLC0415
        from databricks.sdk.core import oauth_service_principal  # noqa: PLC0415

        if isinstance(connection_info, DatabricksTokenConnectionInfo):
            self.connection = dbsql.connect(
                server_hostname=connection_info.server_hostname.get_secret_value(),
                http_path=connection_info.http_path.get_secret_value(),
                access_token=connection_info.access_token.get_secret_value(),
            )
        elif isinstance(connection_info, DatabricksServicePrincipalConnectionInfo):
            kwargs = {
                "host": connection_info.server_hostname.get_secret_value(),
                "client_id": connection_info.client_id.get_secret_value(),
                "client_secret": connection_info.client_secret.get_secret_value(),
            }
            # The tenant id is added to the config only when it is set.
            if connection_info.azure_tenant_id is not None:
                kwargs["azure_tenant_id"] = (
                    connection_info.azure_tenant_id.get_secret_value()
                )

            def credential_provider():
                return oauth_service_principal(DbConfig(**kwargs))

            self.connection = dbsql.connect(
                server_hostname=connection_info.server_hostname.get_secret_value(),
                http_path=connection_info.http_path.get_secret_value(),
                credentials_provider=credential_provider,
            )
        else:
            # Fail here rather than leaving `self.connection` unset and
            # surfacing an unrelated AttributeError on first use. Only the
            # type name is reported, so no secret values can leak.
            raise TypeError(
                "Unsupported Databricks connection info type: "
                f"{type(connection_info).__name__}"
            )

    def query(self, sql: str, limit: int | None = None) -> pa.Table:
        """Run ``sql`` and return the rows as a PyArrow table.

        With a ``limit`` at most that many rows are fetched; otherwise all
        rows are fetched.
        """
        with closing(self.connection.cursor()) as cursor:
            cursor.execute(sql)
            if limit is not None:
                return cursor.fetchmany_arrow(limit)
            return cursor.fetchall_arrow()

    def dry_run(self, sql: str) -> None:
        """Execute ``sql`` wrapped in a ``LIMIT 0`` subquery; no rows are read."""
        with closing(self.connection.cursor()) as cursor:
            cursor.execute(f"SELECT * FROM ({sql}) AS sub LIMIT 0")

    def close(self) -> None:
        """Close the connection; failures are logged as warnings, not raised."""
        try:
            self.connection.close()
        except Exception as e:
            logger.warning(f"Error closing Databricks connection: {e}")
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def create_connector(connection_info) -> DatabricksConnector:
    """Build a ``DatabricksConnector`` for ``connection_info``."""
    connector = DatabricksConnector(connection_info)
    return connector
|