wren-engine 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
wren/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ """Wren — semantic SQL layer for 20+ data sources."""
2
+
3
# Package version string; the ``wren version`` CLI command imports and prints it.
# Defined before the submodule imports below — presumably so it is already
# available while those imports run. NOTE(review): confirm before reordering.
__version__ = "0.1.0"

# Re-export the primary public classes at package level.
from wren.engine import WrenEngine
from wren.model.data_source import DataSource
from wren.model.error import WrenError

# Public API of the ``wren`` package (also what ``from wren import *`` exports).
__all__ = ["WrenEngine", "DataSource", "WrenError", "__version__"]
wren/cli.py ADDED
@@ -0,0 +1,382 @@
1
+ """Wren CLI — SQL transform and execution via the Wren semantic layer."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from pathlib import Path
7
+ from typing import Annotated, Optional
8
+
9
+ import typer
10
+
11
# Root Typer application. Subcommands register on it via @app.command(); the
# @app.callback(invoke_without_command=True) ``main`` below handles a bare
# ``wren`` invocation (it prints the help text itself when --sql is absent).
app = typer.Typer(name="wren", help="Wren Engine CLI", no_args_is_help=False)

# Directory searched for default configuration files.
_WREN_HOME = Path.home() / ".wren"
# MDL file used when --mdl is not given (see _require_mdl).
_DEFAULT_MDL = _WREN_HOME / "mdl.json"
# Connection file used when neither --connection-info nor --connection-file
# is given (see _load_conn).
_DEFAULT_CONN = _WREN_HOME / "connection_info.json"
16
+
17
+
18
+ # ── File discovery helpers ─────────────────────────────────────────────────
19
+
20
+
21
def _require_mdl(mdl: str | None) -> str:
    """Resolve the --mdl argument, falling back to ``~/.wren/mdl.json``.

    Args:
        mdl: The value passed on the command line, or None when omitted.

    Returns:
        *mdl* unchanged when it was given; otherwise the default MDL path
        as a string, provided that file exists.

    Raises:
        typer.Exit: with status 1 (after printing to stderr) when no MDL
            was given and the default file is missing.
    """
    if mdl is None:
        if not _DEFAULT_MDL.exists():
            typer.echo(
                f"Error: --mdl not specified and '{_DEFAULT_MDL}' not found.",
                err=True,
            )
            raise typer.Exit(1)
        mdl = str(_DEFAULT_MDL)
    return mdl
32
+
33
+
34
def _load_manifest(mdl: str) -> str:
    """Return the MDL as a base64 string, reading it from disk when *mdl* is a file.

    * An existing ``*.json`` file is read and base64-encoded for WrenEngine.
    * Any other existing regular file is assumed to already contain a
      base64-encoded MDL; its content is decoded and surrounding whitespace
      is stripped.
    * Anything else (including a directory) is treated as a raw base64 MDL
      string and returned unchanged.

    Args:
        mdl: A file path (``~`` is expanded) or an inline base64 MDL string.

    Returns:
        The base64-encoded MDL.
    """
    import base64  # noqa: PLC0415

    path = Path(mdl).expanduser()
    try:
        # is_file() rather than exists(): a directory must not reach read_bytes().
        is_file = path.is_file()
    except (OSError, ValueError):
        # An inline base64 MDL is usually far longer than the OS path limit, and
        # Python <= 3.12 raises (ENAMETOOLONG) instead of returning False.
        # Such a string cannot name a file, so treat it as inline content.
        is_file = False
    if not is_file:
        return mdl

    content = path.read_bytes()
    if path.suffix.lower() == ".json":
        # Raw JSON file — base64-encode it for WrenEngine.
        return base64.b64encode(content).decode()
    # Non-.json file — assume it already contains a base64-encoded MDL string.
    return content.decode().strip()
48
+
49
+
50
def _load_conn(
    connection_info: str | None,
    connection_file: str | None,
    *,
    required: bool = True,
) -> dict:
    """Load the connection dict from inline JSON or a file, with ~/.wren auto-discovery.

    Precedence: non-empty ``connection_info`` (inline JSON), then
    ``connection_file``, then ``~/.wren/connection_info.json`` if it exists.
    Whatever is loaded must decode to a JSON object.

    Args:
        connection_info: Inline JSON from --connection-info, or None.
        connection_file: Path from --connection-file (``~`` is expanded), or None.
        required: When True, finding no connection info is an error; when
            False, an empty dict is returned instead.

    Returns:
        The decoded connection dict.

    Raises:
        typer.Exit: with status 1 on invalid JSON, a non-object value, an
            unreadable or missing file, or (when ``required``) no source at all.
    """
    if connection_info:
        try:
            conn = json.loads(connection_info)
        except json.JSONDecodeError as e:
            typer.echo(f"Error: invalid JSON in --connection-info: {e}", err=True)
            raise typer.Exit(1)
        if not isinstance(conn, dict):
            typer.echo(
                "Error: --connection-info must decode to a JSON object.", err=True
            )
            raise typer.Exit(1)
        return conn

    path_str = connection_file or (
        str(_DEFAULT_CONN) if _DEFAULT_CONN.exists() else None
    )
    if path_str:
        path = Path(path_str).expanduser()
        if not path.exists():
            typer.echo(f"Error: connection file not found: {path_str}", err=True)
            raise typer.Exit(1)
        try:
            # Decode as UTF-8 explicitly instead of the platform locale encoding
            # (which mangles non-ASCII credentials on e.g. Windows). "utf-8-sig"
            # also strips a leading BOM, which json.loads would otherwise reject.
            text = path.read_text(encoding="utf-8-sig")
        except (OSError, UnicodeDecodeError) as e:
            # e.g. the path is a directory, permission denied, or not UTF-8.
            typer.echo(f"Error: cannot read connection file {path_str}: {e}", err=True)
            raise typer.Exit(1)
        try:
            conn = json.loads(text)
        except json.JSONDecodeError as e:
            typer.echo(f"Error: invalid JSON in {path_str}: {e}", err=True)
            raise typer.Exit(1)
        if not isinstance(conn, dict):
            typer.echo(f"Error: {path_str} must contain a JSON object.", err=True)
            raise typer.Exit(1)
        return conn

    if required:
        typer.echo(
            f"Error: --connection-file not specified and '{_DEFAULT_CONN}' not found.",
            err=True,
        )
        raise typer.Exit(1)
    return {}
100
+
101
+
102
def _resolve_datasource(explicit: str | None, conn_dict: dict) -> str:
    """Choose the datasource name, preferring the explicit --datasource value.

    Side effect: the ``datasource`` key is always removed from *conn_dict*
    (in place) so it is never forwarded to WrenEngine / the connector as an
    unknown connection field.

    Raises:
        typer.Exit: with status 1 when neither source yields a datasource.
    """
    from_conn = conn_dict.pop("datasource", None)
    chosen = explicit or from_conn
    if not chosen:
        typer.echo(
            "Error: --datasource not specified and 'datasource' key not found in connection info.",
            err=True,
        )
        raise typer.Exit(1)
    return chosen
119
+
120
+
121
def _build_engine(
    datasource: str | None,
    mdl: str | None,
    connection_info: str | None,
    connection_file: str | None,
    *,
    conn_required: bool = True,
):
    """Assemble a WrenEngine from CLI options plus ~/.wren defaults.

    Steps, in order: resolve and load the MDL, load the connection dict
    (``conn_required`` controls whether it may be absent), then resolve the
    datasource name, which also strips it from the connection dict.

    Raises:
        typer.Exit: with status 1 if any step fails or the datasource name
            is not a valid ``DataSource`` value (matched case-insensitively).
    """
    from wren.engine import WrenEngine  # noqa: PLC0415
    from wren.model.data_source import DataSource  # noqa: PLC0415

    resolved_mdl = _require_mdl(mdl)
    manifest_str = _load_manifest(resolved_mdl)
    conn_dict = _load_conn(connection_info, connection_file, required=conn_required)
    ds_name = _resolve_datasource(datasource, conn_dict)

    try:
        data_source = DataSource(ds_name.lower())
    except ValueError:
        typer.echo(f"Error: unknown datasource '{ds_name}'", err=True)
        raise typer.Exit(1)

    return WrenEngine(
        manifest_str=manifest_str,
        data_source=data_source,
        connection_info=conn_dict,
    )
145
+
146
+
147
+ # ── Shared option types ────────────────────────────────────────────────────
148
+
149
# Reusable Annotated option types shared by the callback and the subcommands,
# so common flags keep identical names, short aliases and help text.
# The f-string help texts embed the default paths resolved at import time.

# Optional; when omitted, _resolve_datasource reads the connection dict's
# "datasource" key.
DatasourceOpt = Annotated[
    Optional[str],
    typer.Option(
        "--datasource",
        "-d",
        help="Data source (e.g. mysql, postgres). Defaults to 'datasource' field in connection_info.json.",
    ),
]
# File path or inline base64 MDL (see _load_manifest).
MdlOpt = Annotated[
    Optional[str],
    typer.Option(
        "--mdl",
        "-m",
        help=f"Path to MDL JSON file or base64 string. Defaults to {_DEFAULT_MDL}.",
    ),
]
# Inline JSON object; takes precedence over --connection-file (see _load_conn).
ConnInfoOpt = Annotated[
    Optional[str],
    typer.Option("--connection-info", help="Inline JSON connection string"),
]
ConnFileOpt = Annotated[
    Optional[str],
    typer.Option(
        "--connection-file",
        help=f"Path to JSON connection file. Defaults to {_DEFAULT_CONN}.",
    ),
]
# None means no limit; the value is passed through to engine.query(limit=...).
LimitOpt = Annotated[
    Optional[int], typer.Option("--limit", "-l", help="Max rows to return")
]
# Free-form string; _print_result lower-cases it and rejects anything other
# than json / csv / table.
OutputOpt = Annotated[
    str, typer.Option("--output", "-o", help="Output format: json|csv|table")
]
182
+
183
+
184
+ # ── Default command (no subcommand = query) ────────────────────────────────
185
+
186
+
187
@app.callback(invoke_without_command=True)
def main(
    ctx: typer.Context,
    sql: Annotated[
        Optional[str],
        typer.Option(
            "--sql", "-s", help="SQL query to execute (runs query by default)"
        ),
    ] = None,
    datasource: DatasourceOpt = None,
    mdl: MdlOpt = None,
    connection_info: ConnInfoOpt = None,
    connection_file: ConnFileOpt = None,
    limit: LimitOpt = None,
    output: OutputOpt = "table",
) -> None:
    """Wren Engine CLI.

    Run with --sql to execute a query using mdl.json and connection_info.json from
    ~/.wren. Use a subcommand (query / dry-run / dry-plan / validate)
    for explicit control.

    connection_info.json format:

    \b
    {
      "datasource": "mysql",
      "host": "localhost",
      "port": 3306,
      "database": "mydb",
      "user": "root",
      "password": "secret"
    }
    """
    # The docstring above is the CLI's --help text (Typer/Click render it).
    # A subcommand handles its own work; this callback then does nothing.
    if ctx.invoked_subcommand is not None:
        return
    # Bare ``wren`` with no --sql: show help instead of erroring.
    if sql is None:
        typer.echo(ctx.get_help())
        return
    # Reject an unsupported --output up front. Checking only at print time
    # would connect to the data source and run the (possibly expensive) query
    # for nothing. The message matches the one _print_result emits.
    fmt = output.lower()
    if fmt not in {"json", "csv", "table"}:
        typer.echo(
            f"Error: unsupported output format '{fmt}'. Use json, csv, or table.",
            err=True,
        )
        raise typer.Exit(1)
    with _build_engine(datasource, mdl, connection_info, connection_file) as engine:
        try:
            result = engine.query(sql, limit=limit)
        except Exception as e:
            typer.echo(f"Error: {e}", err=True)
            raise typer.Exit(1)
        _print_result(result, output)
233
+
234
+
235
+ # ── Subcommands ────────────────────────────────────────────────────────────
236
+
237
+
238
@app.command()
def query(
    sql: Annotated[str, typer.Option("--sql", "-s", help="SQL query to execute")],
    datasource: DatasourceOpt = None,
    mdl: MdlOpt = None,
    connection_info: ConnInfoOpt = None,
    connection_file: ConnFileOpt = None,
    limit: LimitOpt = None,
    output: OutputOpt = "table",
):
    """Execute a SQL query through the Wren semantic layer."""
    # (The docstring above is this command's --help text.)
    # Fail fast on an unsupported --output, before connecting and running the
    # query; the message matches the one _print_result emits.
    fmt = output.lower()
    if fmt not in {"json", "csv", "table"}:
        typer.echo(
            f"Error: unsupported output format '{fmt}'. Use json, csv, or table.",
            err=True,
        )
        raise typer.Exit(1)
    with _build_engine(datasource, mdl, connection_info, connection_file) as engine:
        try:
            result = engine.query(sql, limit=limit)
        except Exception as e:
            typer.echo(f"Error: {e}", err=True)
            raise typer.Exit(1)
        _print_result(result, output)
256
+
257
+
258
@app.command(name="dry-run")
def dry_run(
    sql: Annotated[str, typer.Option("--sql", "-s", help="SQL query to validate")],
    datasource: DatasourceOpt = None,
    mdl: MdlOpt = None,
    connection_info: ConnInfoOpt = None,
    connection_file: ConnFileOpt = None,
):
    """Dry-run a SQL query (parse + validate, no results returned)."""
    # The docstring is the --help text. Prints "OK" on success; on any
    # engine error, prints the error to stderr and exits with status 1.
    with _build_engine(
        datasource, mdl, connection_info, connection_file
    ) as engine:
        try:
            engine.dry_run(sql)
            typer.echo("OK")
        except Exception as err:
            message = f"Error: {err}"
            typer.echo(message, err=True)
            raise typer.Exit(1)
274
+
275
+
276
@app.command(name="dry-plan")
def dry_plan(
    sql: Annotated[str, typer.Option("--sql", "-s", help="SQL query to plan")],
    datasource: DatasourceOpt = None,
    mdl: MdlOpt = None,
    connection_file: ConnFileOpt = None,
    connection_info: ConnInfoOpt = None,
):
    """Plan SQL through MDL and print the expanded SQL (no DB required)."""
    # The docstring is the --help text. Connection details (inline or from a
    # file) are used only to discover the datasource name; the engine always
    # receives an empty connection dict. --connection-info is accepted like
    # in the other commands and is appended last, so existing calls still work.
    from wren.engine import WrenEngine  # noqa: PLC0415
    from wren.model.data_source import DataSource  # noqa: PLC0415

    manifest_str = _load_manifest(_require_mdl(mdl))
    # Read connection info when it is given explicitly, or when --datasource is
    # missing and must be discovered (including from ~/.wren). It is never
    # required here.
    needs_conn = (
        connection_info is not None
        or connection_file is not None
        or datasource is None
    )
    conn_dict = (
        _load_conn(connection_info, connection_file, required=False)
        if needs_conn
        else {}
    )
    ds_str = _resolve_datasource(datasource, conn_dict)

    try:
        ds = DataSource(ds_str.lower())
    except ValueError:
        typer.echo(f"Error: unknown datasource '{ds_str}'", err=True)
        raise typer.Exit(1)

    with WrenEngine(
        manifest_str=manifest_str, data_source=ds, connection_info={}
    ) as engine:
        try:
            result = engine.dry_plan(sql)
            typer.echo(result)
        except Exception as e:
            typer.echo(f"Error: {e}", err=True)
            raise typer.Exit(1)
311
+
312
+
313
@app.command()
def validate(
    sql: Annotated[str, typer.Option("--sql", "-s", help="SQL query to validate")],
    datasource: DatasourceOpt = None,
    mdl: MdlOpt = None,
    connection_info: ConnInfoOpt = None,
    connection_file: ConnFileOpt = None,
):
    """Validate SQL can be planned and dry-run against the data source."""
    # The docstring is the --help text. This performs the same engine.dry_run
    # check as the dry-run command; only the reported wording differs.
    with _build_engine(
        datasource, mdl, connection_info, connection_file
    ) as engine:
        try:
            engine.dry_run(sql)
            typer.echo("Valid")
        except Exception as exc:
            reason = f"Invalid: {exc}"
            typer.echo(reason, err=True)
            raise typer.Exit(1)
329
+
330
+
331
+ # ── Output formatting ──────────────────────────────────────────────────────
332
+
333
+
334
def _print_result(table, output: str) -> None:
    """Write a query result to stdout as JSON lines, CSV, or an aligned table.

    Args:
        table: The query result — presumably a PyArrow table (the connectors
            return ``pa.Table``). Uses ``to_pandas`` and, for the pandas-free
            fallbacks, ``to_pylist`` / ``column_names`` / ``columns``.
        output: ``json``, ``csv`` or ``table`` (case-insensitive).

    The pandas rendering is tried first. If it fails (e.g. pandas is not
    installed), a fallback emits the same layout: newline-delimited JSON
    records, real CSV, or the table's string form.

    Raises:
        typer.Exit: with status 1 for an unsupported ``output`` value.
    """
    import csv  # noqa: PLC0415
    import io  # noqa: PLC0415

    output = output.lower()
    if output not in {"json", "csv", "table"}:
        typer.echo(
            f"Error: unsupported output format '{output}'. Use json, csv, or table.",
            err=True,
        )
        raise typer.Exit(1)
    if output == "json":
        try:
            df = table.to_pandas()
            typer.echo(df.to_json(orient="records", lines=True))
        except Exception:
            # Keep the pandas layout (one JSON object per row) rather than
            # switching to a column-oriented dict. default=str renders values
            # the json module cannot encode natively (dates, decimals, ...).
            typer.echo(
                "\n".join(json.dumps(row, default=str) for row in table.to_pylist())
            )
    elif output == "csv":
        try:
            df = table.to_pandas()
            typer.echo(df.to_csv(index=False))
        except Exception:
            # Emit genuine CSV instead of the table's repr. Rows are assembled
            # column-wise so that duplicate column names are preserved.
            buf = io.StringIO()
            writer = csv.writer(buf, lineterminator="\n")
            writer.writerow(table.column_names)
            writer.writerows(zip(*(col.to_pylist() for col in table.columns)))
            typer.echo(buf.getvalue())
    else:
        try:
            df = table.to_pandas()
            typer.echo(df.to_string(index=False))
        except Exception:
            typer.echo(str(table))
360
+
361
+
362
@app.command()
def version():
    """Print the wren-engine version."""
    # Imported lazily (the docstring above doubles as the --help text).
    from wren import __version__ as wren_version  # noqa: PLC0415

    typer.echo(f"wren-engine {wren_version}")
368
+
369
+
370
# Register the optional ``memory`` sub-app only when the wren[memory] extra is
# installed. The lancedb / sentence_transformers imports act as an
# availability probe. An ImportError from them or from wren.memory.cli
# silently skips registration.
# NOTE(review): this probe runs on every ``wren`` invocation, and importing
# sentence_transformers may be slow at start-up. Consider
# importlib.util.find_spec for the check — confirm whether wren.memory.cli
# imports these eagerly anyway.
try:
    import lancedb  # noqa: PLC0415, F401
    import sentence_transformers  # noqa: PLC0415, F401

    from wren.memory.cli import memory_app  # noqa: PLC0415

    app.add_typer(memory_app)
except ImportError:
    pass  # wren[memory] not installed


# Allow ``python -m wren.cli`` / direct execution.
if __name__ == "__main__":
    app()
@@ -0,0 +1,4 @@
1
# Public connector API: the abstract interface, the generic Ibis-based
# implementation, and ``get_connector`` from the factory module.
from wren.connector.base import ConnectorABC, IbisConnector
from wren.connector.factory import get_connector

__all__ = ["ConnectorABC", "IbisConnector", "get_connector"]
wren/connector/base.py ADDED
@@ -0,0 +1,87 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ import pyarrow as pa
4
+ from ibis.expr.datatypes import Decimal
5
+ from ibis.expr.datatypes.core import UUID
6
+ from ibis.expr.types import Table
7
+ from loguru import logger
8
+
9
+ from wren.model.data_source import DataSource
10
+
11
+
12
class ConnectorABC(ABC):
    """Interface implemented by every Wren data-source connector."""

    @abstractmethod
    def query(self, sql: str, limit: int | None = None) -> pa.Table:
        """Execute *sql* and return the result as a PyArrow table.

        Args:
            sql: SQL text to run against the data source.
            limit: Maximum number of rows to return, or None for all rows.
        """

    @abstractmethod
    def dry_run(self, sql: str) -> None:
        """Validate *sql* against the data source without fetching result rows."""

    @abstractmethod
    def close(self) -> None:
        """Release the underlying connection."""
24
+
25
+
26
class IbisConnector(ConnectorABC):
    """Generic connector that executes SQL through an Ibis backend.

    The backend is obtained from ``DataSource.get_connection``. Before a
    result is converted to PyArrow, Decimal columns are cast to
    ``Decimal(38, scale)`` and rounded, and UUID columns are cast to strings
    (see ``_handle_pyarrow_unsupported_type``).
    """

    def __init__(self, data_source: DataSource, connection_info):
        self.data_source = data_source
        self.connection = self.data_source.get_connection(connection_info)
        self._closed = False

    def _require_connection(self):
        """Return the open backend, raising a clear error once close() has run."""
        if self.connection is None:
            raise RuntimeError(
                f"Connection for {self.data_source.value} has already been closed."
            )
        return self.connection

    def query(self, sql: str, limit: int | None = None) -> pa.Table:
        """Run *sql* and return at most *limit* rows (all when None) as PyArrow."""
        ibis_table = self._require_connection().sql(sql)
        if limit is not None:
            ibis_table = ibis_table.limit(limit)
        ibis_table = self._handle_pyarrow_unsupported_type(ibis_table)
        return ibis_table.to_pyarrow()

    def _handle_pyarrow_unsupported_type(self, ibis_table: Table, **kwargs) -> Table:
        """Return *ibis_table* with Decimal and UUID columns cast for conversion.

        ``kwargs`` are forwarded to ``_round_decimal_columns`` (e.g. ``scale``).
        """
        result_table = ibis_table
        for name, dtype in ibis_table.schema().items():
            if isinstance(dtype, Decimal):
                result_table = self._round_decimal_columns(
                    result_table=result_table, col_name=name, **kwargs
                )
            elif isinstance(dtype, UUID):
                result_table = self._cast_uuid_columns(
                    result_table=result_table, col_name=name
                )
        return result_table

    def _cast_uuid_columns(self, result_table: Table, col_name: str) -> Table:
        """Cast the UUID column *col_name* to string."""
        return result_table.mutate(**{col_name: result_table[col_name].cast("string")})

    def _round_decimal_columns(
        self, result_table: Table, col_name: str, scale: int = 9
    ) -> Table:
        """Cast *col_name* to ``Decimal(38, scale)`` and round it to *scale* digits.

        NOTE(review): values needing more than ``38 - scale`` integer digits or
        more than ``scale`` fractional digits cannot be represented exactly —
        confirm this precision trade-off is intended.
        """
        col = result_table[col_name]
        decimal_type = Decimal(precision=38, scale=scale)
        rounded_col = col.cast(decimal_type).round(scale)
        return result_table.mutate(**{col_name: rounded_col})

    def dry_run(self, sql: str) -> None:
        """Build an Ibis expression for *sql* without fetching any rows.

        NOTE(review): whether this contacts the database (e.g. to infer the
        result schema) depends on the Ibis backend.
        """
        self._require_connection().sql(sql)

    def close(self) -> None:
        """Close the backend connection once; later calls are no-ops.

        Closing prefers the backend's inner ``con.close()``, then the backend's
        own ``close()``, then ``disconnect()``. Errors are logged as warnings,
        not raised, and the connector is marked closed either way.
        """
        # getattr: tolerate an instance whose __init__ did not complete.
        if getattr(self, "_closed", False) or getattr(self, "connection", None) is None:
            return
        try:
            inner = getattr(self.connection, "con", None)
            # A backend exposing ``con`` without a usable close() (or con=None)
            # must still fall through to the backend-level close/disconnect.
            if inner is not None and hasattr(inner, "close"):
                inner.close()
            elif hasattr(self.connection, "close"):
                self.connection.close()
            elif hasattr(self.connection, "disconnect"):
                self.connection.disconnect()
            else:
                logger.warning(
                    f"Closing connection for {self.data_source.value} is not implemented."
                )
        except Exception as e:
            logger.warning(
                f"Error closing connection for {self.data_source.value}: {e}"
            )
        finally:
            self._closed = True
            self.connection = None
@@ -0,0 +1,57 @@
1
+ import base64
2
+ from json import loads
3
+
4
+ import pyarrow as pa
5
+ from loguru import logger
6
+
7
+ from wren.connector.base import ConnectorABC
8
+
9
+
10
class BigQueryConnector(ConnectorABC):
    """Connector built on the native ``google.cloud.bigquery`` client.

    Authenticates with a service account whose JSON key is supplied
    base64-encoded in ``connection_info.credentials``.
    """

    def __init__(self, connection_info):
        from google.cloud import bigquery  # noqa: PLC0415
        from google.oauth2 import service_account  # noqa: PLC0415

        self.connection_info = connection_info
        encoded_key = connection_info.credentials.get_secret_value()
        key_info = loads(base64.b64decode(encoded_key).decode("utf-8"))
        # NOTE(review): the Drive scope presumably enables Drive-backed external
        # tables (e.g. Google Sheets) — confirm it is still required.
        scoped_credentials = service_account.Credentials.from_service_account_info(
            key_info
        ).with_scopes(
            [
                "https://www.googleapis.com/auth/drive",
                "https://www.googleapis.com/auth/cloud-platform",
            ]
        )
        bq_client = bigquery.Client(
            credentials=scoped_credentials,
            project=connection_info.get_billing_project_id(),
        )
        # Installed as the client's default job config so queries pick up the
        # configured job timeout.
        default_config = bigquery.QueryJobConfig()
        default_config.job_timeout_ms = connection_info.job_timeout_ms
        bq_client.default_query_job_config = default_config
        self.connection = bq_client

    def query(self, sql: str, limit: int | None = None) -> pa.Table:
        """Run *sql* and return up to *limit* rows (all when None) as PyArrow."""
        job = self.connection.query(sql)
        rows = job.result(max_results=limit)
        return rows.to_arrow()

    def dry_run(self, sql: str) -> None:
        """Submit *sql* as a dry-run job with the query cache disabled."""
        from google.cloud import bigquery  # noqa: PLC0415

        dry_config = bigquery.QueryJobConfig(dry_run=True, use_query_cache=False)
        self.connection.query(sql, job_config=dry_config)

    def close(self) -> None:
        """Close the client; failures are logged as warnings, not raised."""
        try:
            self.connection.close()
        except Exception as exc:
            logger.warning(f"Error closing BigQuery connection: {exc}")
54
+
55
+
56
def create_connector(connection_info) -> BigQueryConnector:
    """Build a BigQueryConnector for *connection_info*."""
    connector = BigQueryConnector(connection_info)
    return connector
@@ -0,0 +1,79 @@
1
+ from contextlib import closing
2
+ from functools import cache
3
+ from typing import Any
4
+
5
+ import ibis
6
+ import ibis.expr.schema as sch
7
+ import pyarrow as pa
8
+ from ibis import BaseBackend
9
+ from ibis.backends.sql.compilers.postgres import compiler as postgres_compiler
10
+ from ibis.expr.datatypes import Decimal
11
+ from ibis.expr.datatypes.core import UUID
12
+ from ibis.expr.types import Table
13
+ from loguru import logger
14
+
15
+ from wren.connector.base import ConnectorABC
16
+ from wren.model.data_source import DataSource
17
+
18
+
19
@cache
def _get_pg_type_names(connection: BaseBackend) -> dict[int, str]:
    """Map each ``pg_type`` OID to its type name for *connection*.

    Memoised with ``functools.cache`` keyed on the backend object, so the
    catalog is read at most once per distinct key for the life of the
    process. Cache entries also keep that key object alive.
    """
    query_text = "SELECT oid, typname FROM pg_type"
    with closing(connection.raw_sql(query_text)) as cursor:
        rows = cursor.fetchall()
        return {oid: typname for oid, typname in rows}
23
+
24
+
25
class CannerConnector(ConnectorABC):
    """Connector for Canner via its Ibis connection.

    Before running a query, its result schema is derived from a zero-row probe
    (``dry_run``). The cursor's type OIDs are mapped through ``pg_type`` and the
    Postgres type mapper, and the schema is passed to ``connection.sql``.
    """

    def __init__(self, connection_info):
        self.connection = DataSource.canner.get_connection(connection_info)

    def query(self, sql: str, limit: int | None = None) -> pa.Table:
        """Run *sql* and return at most *limit* rows (all when None) as PyArrow."""
        schema = self._get_schema(sql)
        ibis_table = self.connection.sql(sql, schema=schema)
        if limit is not None:
            ibis_table = ibis_table.limit(limit)
        ibis_table = self._handle_pyarrow_unsupported_type(ibis_table)
        return ibis_table.to_pyarrow()

    def _handle_pyarrow_unsupported_type(self, ibis_table: Table, **kwargs) -> Table:
        """Cast Decimal columns to Decimal(38, 9) (rounded) and UUID columns to string."""
        result_table = ibis_table
        for name, dtype in ibis_table.schema().items():
            if isinstance(dtype, Decimal):
                col = result_table[name]
                decimal_type = Decimal(precision=38, scale=9)
                rounded_col = col.cast(decimal_type).round(9)
                result_table = result_table.mutate(**{name: rounded_col})
            elif isinstance(dtype, UUID):
                result_table = result_table.mutate(
                    **{name: result_table[name].cast("string")}
                )
        return result_table

    def dry_run(self, sql: str) -> Any:
        """Execute *sql* wrapped in ``LIMIT 0`` and return the raw cursor.

        NOTE(review): the cursor is returned open; callers that only validate
        the SQL should close it (``_get_schema`` does).
        """
        return self.connection.raw_sql(f"SELECT * FROM ({sql}) LIMIT 0")

    def close(self) -> None:
        """Close the driver connection; failures are logged, not raised."""
        try:
            if hasattr(self.connection, "con") and hasattr(
                self.connection.con, "close"
            ):
                self.connection.con.close()
            elif hasattr(self.connection, "close"):
                self.connection.close()
        except Exception as e:
            logger.warning(f"Error closing Canner connection: {e}")

    def _get_schema(self, sql: str) -> sch.Schema:
        """Build an Ibis schema for *sql* from a zero-row probe's cursor description."""
        # Close the probe cursor once the schema has been read; it used to be
        # left open on every query.
        with closing(self.dry_run(sql)) as cur:
            type_names = _get_pg_type_names(self.connection)
            return ibis.schema(
                {
                    desc.name: postgres_compiler.type_mapper.from_string(
                        type_names[desc.type_code]
                    )
                    for desc in cur.description
                }
            )
76
+
77
+
78
def create_connector(connection_info) -> CannerConnector:
    """Build a CannerConnector for *connection_info*."""
    connector = CannerConnector(connection_info)
    return connector
@@ -0,0 +1,65 @@
1
+ from contextlib import closing
2
+
3
+ import pyarrow as pa
4
+ from loguru import logger
5
+
6
+ from wren.connector.base import ConnectorABC
7
+ from wren.model import (
8
+ DatabricksConnectionUnion,
9
+ DatabricksServicePrincipalConnectionInfo,
10
+ DatabricksTokenConnectionInfo,
11
+ )
12
+
13
+
14
class DatabricksConnector(ConnectorABC):
    """Connector using the ``databricks.sql`` DB-API client.

    Supports personal-access-token auth (``DatabricksTokenConnectionInfo``) and
    OAuth service-principal auth (``DatabricksServicePrincipalConnectionInfo``).
    """

    def __init__(self, connection_info: DatabricksConnectionUnion):
        from databricks import sql as dbsql  # noqa: PLC0415
        from databricks.sdk.core import Config as DbConfig  # noqa: PLC0415
        from databricks.sdk.core import oauth_service_principal  # noqa: PLC0415

        if isinstance(connection_info, DatabricksTokenConnectionInfo):
            self.connection = dbsql.connect(
                server_hostname=connection_info.server_hostname.get_secret_value(),
                http_path=connection_info.http_path.get_secret_value(),
                access_token=connection_info.access_token.get_secret_value(),
            )
        elif isinstance(connection_info, DatabricksServicePrincipalConnectionInfo):
            kwargs = {
                "host": connection_info.server_hostname.get_secret_value(),
                "client_id": connection_info.client_id.get_secret_value(),
                "client_secret": connection_info.client_secret.get_secret_value(),
            }
            if connection_info.azure_tenant_id is not None:
                kwargs["azure_tenant_id"] = (
                    connection_info.azure_tenant_id.get_secret_value()
                )

            # The connector invokes this callable to obtain credentials.
            def credential_provider():
                return oauth_service_principal(DbConfig(**kwargs))

            self.connection = dbsql.connect(
                server_hostname=connection_info.server_hostname.get_secret_value(),
                http_path=connection_info.http_path.get_secret_value(),
                credentials_provider=credential_provider,
            )
        else:
            # Without this, self.connection would stay unset and the failure
            # would surface later as an unrelated AttributeError.
            raise TypeError(
                "Unsupported Databricks connection info type: "
                f"{type(connection_info).__name__}"
            )

    def query(self, sql: str, limit: int | None = None) -> pa.Table:
        """Run *sql*; fetch at most *limit* rows (all when None) as PyArrow."""
        with closing(self.connection.cursor()) as cursor:
            cursor.execute(sql)
            if limit is not None:
                return cursor.fetchmany_arrow(limit)
            return cursor.fetchall_arrow()

    def dry_run(self, sql: str) -> None:
        """Execute *sql* wrapped in ``LIMIT 0`` so it is validated without rows."""
        with closing(self.connection.cursor()) as cursor:
            cursor.execute(f"SELECT * FROM ({sql}) AS sub LIMIT 0")

    def close(self) -> None:
        """Close the connection; failures are logged as warnings, not raised."""
        try:
            self.connection.close()
        except Exception as e:
            logger.warning(f"Error closing Databricks connection: {e}")
62
+
63
+
64
def create_connector(connection_info) -> DatabricksConnector:
    """Build a DatabricksConnector for *connection_info*."""
    connector = DatabricksConnector(connection_info)
    return connector