soda-databricks 4.0.5 (tar.gz)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- soda_databricks-4.0.5/PKG-INFO +6 -0
- soda_databricks-4.0.5/setup.cfg +4 -0
- soda_databricks-4.0.5/setup.py +24 -0
- soda_databricks-4.0.5/src/soda_databricks/common/data_sources/databricks_data_source.py +298 -0
- soda_databricks-4.0.5/src/soda_databricks/common/data_sources/databricks_data_source_connection.py +40 -0
- soda_databricks-4.0.5/src/soda_databricks/common/statements/hive_metadata_tables_query.py +143 -0
- soda_databricks-4.0.5/src/soda_databricks/model/data_source/databricks_connection_properties.py +41 -0
- soda_databricks-4.0.5/src/soda_databricks/model/data_source/databricks_data_source.py +23 -0
- soda_databricks-4.0.5/src/soda_databricks/test_helpers/databricks_data_source_test_helper.py +25 -0
- soda_databricks-4.0.5/src/soda_databricks.egg-info/PKG-INFO +6 -0
- soda_databricks-4.0.5/src/soda_databricks.egg-info/SOURCES.txt +13 -0
- soda_databricks-4.0.5/src/soda_databricks.egg-info/dependency_links.txt +1 -0
- soda_databricks-4.0.5/src/soda_databricks.egg-info/entry_points.txt +2 -0
- soda_databricks-4.0.5/src/soda_databricks.egg-info/requires.txt +2 -0
- soda_databricks-4.0.5/src/soda_databricks.egg-info/top_level.txt +1 -0
soda_databricks-4.0.5/setup.py
ADDED

@@ -0,0 +1,24 @@
+#!/usr/bin/env python
+
+from setuptools import setup
+
+package_name = "soda-databricks"
+package_version = "4.0.5"
+description = "Soda Databricks V4"
+
+requires = [
+    f"soda-core=={package_version}",
+    "databricks-sql-connector",
+]
+
+setup(
+    name=package_name,
+    version=package_version,
+    install_requires=requires,
+    package_dir={"": "src"},
+    entry_points={
+        "soda.plugins.data_source.databricks": [
+            "DatabricksDataSourceImpl = soda_databricks.common.data_sources.databricks_data_source:DatabricksDataSourceImpl",
+        ],
+    },
+)
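The entry_points block above registers DatabricksDataSourceImpl under the "soda.plugins.data_source.databricks" group; this is the standard setuptools mechanism by which soda-core can discover the plugin at runtime. As a minimal sketch (soda-core's actual discovery code is not part of this diff and may differ), the group can be enumerated with the standard library:

from importlib.metadata import entry_points

# Illustration only: list the plugin classes registered by setup.py above.
# Requires Python 3.10+ for the group= keyword.
for ep in entry_points(group="soda.plugins.data_source.databricks"):
    plugin_class = ep.load()  # resolves to DatabricksDataSourceImpl
    print(ep.name, "->", plugin_class)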
soda_databricks-4.0.5/src/soda_databricks/common/data_sources/databricks_data_source.py
ADDED

@@ -0,0 +1,298 @@
+from logging import Logger
+from typing import Optional
+
+from soda_core.common.data_source_connection import DataSourceConnection
+from soda_core.common.data_source_impl import DataSourceImpl
+from soda_core.common.data_source_results import QueryResult
+from soda_core.common.logging_constants import soda_logger
+from soda_core.common.metadata_types import (
+    ColumnMetadata,
+    DataSourceNamespace,
+    SodaDataTypeName,
+)
+from soda_core.common.sql_ast import (
+    ALTER_TABLE_ADD_COLUMN,
+    ALTER_TABLE_DROP_COLUMN,
+    CREATE_TABLE_COLUMN,
+)
+from soda_core.common.sql_dialect import SqlDialect
+from soda_core.common.statements.metadata_tables_query import MetadataTablesQuery
+from soda_core.common.statements.table_types import TableType
+from soda_databricks.common.data_sources.databricks_data_source_connection import (
+    DatabricksDataSourceConnection,
+)
+from soda_databricks.common.statements.hive_metadata_tables_query import (
+    HiveMetadataTablesQuery,
+)
+from soda_databricks.model.data_source.databricks_data_source import (
+    DatabricksDataSource as DatabricksDataSourceModel,
+)
+
+logger: Logger = soda_logger
+
+
+class DatabricksDataSourceImpl(DataSourceImpl, model_class=DatabricksDataSourceModel):
+    def __init__(self, data_source_model: DatabricksDataSourceModel, connection: Optional[DataSourceConnection] = None):
+        super().__init__(data_source_model=data_source_model, connection=connection)
+
+    def _create_sql_dialect(self) -> SqlDialect:
+        if self.__is_hive_catalog():
+            return DatabricksHiveSqlDialect(data_source_impl=self)
+        return DatabricksSqlDialect(data_source_impl=self)
+
+    def _create_data_source_connection(self) -> DataSourceConnection:
+        return DatabricksDataSourceConnection(
+            name=self.data_source_model.name, connection_properties=self.data_source_model.connection_properties
+        )
+
+    def create_metadata_tables_query(self) -> MetadataTablesQuery:
+        if self.__is_hive_catalog():
+            return HiveMetadataTablesQuery(
+                sql_dialect=self.sql_dialect, data_source_connection=self.data_source_connection
+            )
+        else:
+            return super().create_metadata_tables_query()
+
+    def __is_hive_catalog(self):
+        # Check the connection "catalog"
+        catalog: Optional[str] = self.data_source_model.connection_properties.catalog
+        if catalog and catalog.lower() == "hive_metastore":
+            return True
+        # All other catalogs should be treated as "unity catalogs"
+        return False
+
+    def get_columns_metadata(self, dataset_prefixes: list[str], dataset_name: str) -> list[ColumnMetadata]:
+        try:
+            return super().get_columns_metadata(dataset_prefixes, dataset_name)
+        except Exception as e:
+            logger.warning(f"Error getting columns metadata for {dataset_name}: {e}\n\nReturning empty list.")
+            return []
+
+    def test_schema_exists(self, prefixes: list[str]) -> bool:
+        if not self.__is_hive_catalog():
+            return super().test_schema_exists(prefixes)
+
+        schema_name: str = prefixes[1]
+
+        result = self.execute_query(
+            f"SHOW SCHEMAS LIKE '{schema_name}'"
+        ).rows  # We only need to check the schema name, as the catalog name is always the same
+        for row in result:
+            if row[0] and row[0].lower() == schema_name.lower():
+                return True
+        return False
+
+
+class DatabricksSqlDialect(SqlDialect):
+    DEFAULT_QUOTE_CHAR = "`"
+
+    SODA_DATA_TYPE_SYNONYMS = (
+        (SodaDataTypeName.TEXT, SodaDataTypeName.VARCHAR, SodaDataTypeName.CHAR),
+        (SodaDataTypeName.NUMERIC, SodaDataTypeName.DECIMAL),
+        (SodaDataTypeName.TIMESTAMP_TZ, SodaDataTypeName.TIMESTAMP),
+    )
+
+    def _get_data_type_name_synonyms(self) -> list[list[str]]:
+        return [
+            ["int", "integer"],
+        ]
+
+    def column_data_type(self) -> str:
+        return self.default_casify("data_type")
+
+    def supports_data_type_character_maximum_length(self) -> bool:
+        return False
+
+    def supports_data_type_numeric_precision(self) -> bool:
+        return False
+
+    def supports_data_type_numeric_scale(self) -> bool:
+        return False
+
+    def supports_data_type_datetime_precision(self) -> bool:
+        return False
+
+    def get_data_source_data_type_name_by_soda_data_type_names(self) -> dict:
+        return {
+            SodaDataTypeName.CHAR: "string",
+            SodaDataTypeName.VARCHAR: "string",
+            SodaDataTypeName.TEXT: "string",
+            SodaDataTypeName.SMALLINT: "smallint",
+            SodaDataTypeName.INTEGER: "int",
+            SodaDataTypeName.BIGINT: "bigint",
+            SodaDataTypeName.NUMERIC: "decimal",
+            SodaDataTypeName.DECIMAL: "decimal",
+            SodaDataTypeName.FLOAT: "float",
+            SodaDataTypeName.DOUBLE: "double",
+            SodaDataTypeName.TIMESTAMP: "timestamp",
+            SodaDataTypeName.TIMESTAMP_TZ: "timestamp",
+            SodaDataTypeName.DATE: "date",
+            SodaDataTypeName.TIME: "string",  # no native TIME type in Databricks
+            SodaDataTypeName.BOOLEAN: "boolean",
+        }
+
+    def get_soda_data_type_name_by_data_source_data_type_names(self) -> dict[str, SodaDataTypeName]:
+        return {
+            "string": SodaDataTypeName.TEXT,
+            "varchar": SodaDataTypeName.VARCHAR,
+            "char": SodaDataTypeName.CHAR,
+            "tinyint": SodaDataTypeName.SMALLINT,
+            "short": SodaDataTypeName.SMALLINT,
+            "smallint": SodaDataTypeName.SMALLINT,
+            "int": SodaDataTypeName.INTEGER,
+            "integer": SodaDataTypeName.INTEGER,
+            "bigint": SodaDataTypeName.BIGINT,
+            "long": SodaDataTypeName.BIGINT,
+            "decimal": SodaDataTypeName.DECIMAL,
+            "numeric": SodaDataTypeName.NUMERIC,
+            "float": SodaDataTypeName.FLOAT,
+            "real": SodaDataTypeName.FLOAT,
+            "float4": SodaDataTypeName.FLOAT,
+            "double": SodaDataTypeName.DOUBLE,
+            "double precision": SodaDataTypeName.DOUBLE,
+            "float8": SodaDataTypeName.DOUBLE,
+            "timestamp": SodaDataTypeName.TIMESTAMP,
+            "timestamp without time zone": SodaDataTypeName.TIMESTAMP,
+            "timestamp_ntz": SodaDataTypeName.TIMESTAMP,  # If a timestamp is explicitly stated to be without time zone, we treat it the same as TIMESTAMP
+            "timestamptz": SodaDataTypeName.TIMESTAMP_TZ,
+            "timestamp with time zone": SodaDataTypeName.TIMESTAMP_TZ,
+            "date": SodaDataTypeName.DATE,
+            "boolean": SodaDataTypeName.BOOLEAN,
+            # Not supported -> will be converted to varchar:
+            # "binary",
+            # "interval",
+            # "array",
+            # "map",
+            # "struct"
+        }
+
+    def escape_string(self, value: str):
+        raw_string = rf"{value}"
+        string_literal: str = raw_string.replace(r"'", r"\'")
+        return string_literal
+
+    def encode_string_for_sql(self, string: str) -> str:
+        """This escapes values that contain newlines correctly."""
+        # For Databricks, we don't need to encode the string; it handles newlines correctly.
+        # In fact, when we encode the string, we run into issues with the escape characters for the quotes.
+        return string
+
+    def column_data_type_numeric_scale(self) -> str:
+        return self.default_casify("numeric_scale")
+
+    def column_data_type_datetime_precision(self) -> str:
+        return self.default_casify("datetime_precision")
+
+    def _build_create_table_column_type(self, create_table_column: CREATE_TABLE_COLUMN) -> str:
+        # Databricks will complain if string lengths or datetime precisions are passed in, so strip them if provided
+        if create_table_column.type.name == "string":
+            create_table_column.type.character_maximum_length = None
+        if create_table_column.type.name in ["timestamp_ntz", "timestamp"]:
+            create_table_column.type.datetime_precision = None
+        return super()._build_create_table_column_type(create_table_column=create_table_column)
+
+    def _get_data_type_name_synonyms(self) -> list[list[str]]:
+        return [
+            ["varchar", "char", "string"],
+            ["smallint", "int2"],
+            ["integer", "int", "int4"],
+            ["bigint", "int8"],
+            ["real", "float4", "float"],
+            ["double precision", "float8", "double"],
+            ["timestamp", "timestamp without time zone"],
+            ["timestamptz", "timestamp with time zone"],
+            ["time", "time without time zone"],
+        ]
+
+    # Explicitly left here for reference. See the comment below for why we moved to DESCRIBE TABLE.
+    # def build_columns_metadata_query_str(self, table_namespace: DataSourceNamespace, table_name: str) -> str:
+    #     # Unity Catalog only stores names in lower case,
+    #     # even though CREATE TABLE may have been quoted and mixed case
+    #     table_name_lower: str = table_name.lower()
+    #     return super().build_columns_metadata_query_str(table_namespace, table_name_lower)
+
+    # We moved to DESCRIBE TABLE, which is a more up-to-date way to get the columns metadata.
+    # (information_schema sometimes lags behind and does not always return the correct columns.)
+    def build_columns_metadata_query_str(self, table_namespace: DataSourceNamespace, table_name: str) -> str:
+        database_name: str | None = table_namespace.get_database_for_metadata_query()
+        schema_name: str = table_namespace.get_schema_for_metadata_query()
+        return f"DESCRIBE {database_name}.{schema_name}.{table_name}"
+
+    def build_column_metadatas_from_query_result(self, query_result: QueryResult) -> list[ColumnMetadata]:
+        # Filter out dataset description rows (the first such line starts with #; ignore the rest) and empty rows
+        filtered_rows = []
+        for row in query_result.rows:
+            if row[0].startswith("#"):  # ignore all description rows
+                break
+            if not row[0] and not row[1]:  # empty row
+                continue
+
+            # Trim data type details, e.g. decimal(10,0) -> decimal. Only decimal supports them anyway.
+            data_type = row[1]
+            if "(" in data_type:
+                data_type = data_type[: data_type.index("(")].strip()
+                row = (row[0], data_type) + row[2:]
+            filtered_rows.append(row)
+
+        return super().build_column_metadatas_from_query_result(
+            QueryResult(rows=filtered_rows, columns=query_result.columns)
+        )
+
+    def post_schema_create_sql(self, prefixes: list[str]) -> Optional[list[str]]:
+        assert len(prefixes) == 2, f"Expected 2 prefixes, got {len(prefixes)}"
+        catalog_name: str = self.quote_default(prefixes[0])
+        schema_name: str = self.quote_default(prefixes[1])
+        return [f"GRANT SELECT, USAGE, CREATE, MANAGE ON SCHEMA {catalog_name}.{schema_name} TO `account users`;"]
+        # f"GRANT SELECT ON FUTURE TABLES IN SCHEMA {catalog_name}.{schema_name} TO `account users`;"]
+
+    @classmethod
+    def is_same_soda_data_type_with_synonyms(cls, expected: SodaDataTypeName, actual: SodaDataTypeName) -> bool:
+        # Special case of a 1-way synonym: TEXT is allowed where TIME is expected
+        if expected == SodaDataTypeName.TIME and actual == SodaDataTypeName.TEXT:
+            logger.debug(
+                f"In is_same_soda_data_type_with_synonyms, expected {expected} and actual {actual} are the same"
+            )
+            return True
+        return super().is_same_soda_data_type_with_synonyms(expected, actual)
+
+    def _build_alter_table_add_column_sql(
+        self, alter_table: ALTER_TABLE_ADD_COLUMN, add_semicolon: bool = True, add_parenthesis: bool = False
+    ) -> str:
+        return super()._build_alter_table_add_column_sql(alter_table, add_semicolon=add_semicolon, add_parenthesis=True)
+
+    def _get_add_column_sql_expr(self) -> str:
+        return "ADD COLUMNS"
+
+    def _build_alter_table_drop_column_sql(
+        self, alter_table: ALTER_TABLE_DROP_COLUMN, add_semicolon: bool = True
+    ) -> str:
+        column_name_quoted: str = self._quote_column_for_create_table(alter_table.column_name)
+        return f"ALTER TABLE {alter_table.fully_qualified_table_name} DROP COLUMNS ({column_name_quoted})" + (
+            ";" if add_semicolon else ""
+        )
+
+    def drop_column_supported(self) -> bool:
+        return False  # Note: this is technically supported, but we would need to change the delta table mapping mode for it (out of scope at the time of writing)
+
+    def convert_table_type_to_enum(self, table_type: str) -> TableType:
+        if table_type == "MANAGED":
+            return TableType.TABLE
+        elif table_type == "VIEW":
+            return TableType.VIEW
+        elif table_type == "MATERIALIZED_VIEW":
+            return TableType.VIEW  # For now, a materialized view is treated as a view.
+        else:
+            # Default to TABLE if the table type is not recognized (so we're backwards compatible with existing code)
+            logger.warning(f"Invalid table type: {table_type}, defaulting to TABLE")
+            return TableType.TABLE
+
+    def metadata_casify(self, identifier: str) -> str:
+        return identifier.lower()
+
+
+class DatabricksHiveSqlDialect(DatabricksSqlDialect):
+    def post_schema_create_sql(self, prefixes: list[str]) -> Optional[list[str]]:
+        assert len(prefixes) == 2, f"Expected 2 prefixes, got {len(prefixes)}"
+        catalog_name: str = self.quote_default(prefixes[0])
+        schema_name: str = self.quote_default(prefixes[1])
+
+        return [f"GRANT SELECT, USAGE, CREATE ON SCHEMA {catalog_name}.{schema_name} TO `users`;"]
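The DESCRIBE-output filtering in build_column_metadatas_from_query_result is easiest to see on sample rows. A self-contained sketch of just that loop, using invented sample data (a real QueryResult carries more structure):

# Invented sample of DESCRIBE <table> rows: (col_name, data_type, comment)
rows = [
    ("id", "bigint", None),
    ("amount", "decimal(10,0)", None),
    ("", "", None),                         # blank separator row -> skipped
    ("# Partition Information", "", None),  # first '#' row -> stop reading
]

filtered = []
for row in rows:
    if row[0].startswith("#"):  # description block starts; ignore the rest
        break
    if not row[0] and not row[1]:  # empty row
        continue
    data_type = row[1]
    if "(" in data_type:  # trim type details: decimal(10,0) -> decimal
        data_type = data_type[: data_type.index("(")].strip()
        row = (row[0], data_type) + row[2:]
    filtered.append(row)

assert filtered == [("id", "bigint", None), ("amount", "decimal", None)]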
soda_databricks-4.0.5/src/soda_databricks/common/data_sources/databricks_data_source_connection.py
ADDED

@@ -0,0 +1,40 @@
+from __future__ import annotations
+
+import logging
+
+from databricks import sql
+from soda_core.common.data_source_connection import DataSourceConnection
+from soda_core.common.logging_constants import soda_logger
+from soda_core.model.data_source.data_source_connection_properties import (
+    DataSourceConnectionProperties,
+)
+from soda_databricks.model.data_source.databricks_connection_properties import (
+    DatabricksConnectionProperties,
+)
+
+logger: logging.Logger = soda_logger
+
+
+class DatabricksDataSourceConnection(DataSourceConnection):
+    def __init__(self, name: str, connection_properties: DataSourceConnectionProperties):
+        super().__init__(name, connection_properties)
+
+    def _create_connection(
+        self,
+        config: DatabricksConnectionProperties,
+    ):
+        return sql.connect(
+            user_agent_entry="Soda Core",
+            **config.to_connection_kwargs(),
+        )
+
+    def rollback(self) -> None:
+        # We do not start any transactions; the Databricks default is autocommit.
+        pass
+
+    def commit(self) -> None:
+        # We do not start any transactions; the Databricks default is autocommit.
+        pass
+
+    def _execute_query_get_result_row_column_name(self, column) -> str:
+        return column[0]
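_create_connection delegates entirely to databricks.sql.connect, with the keyword arguments produced by config.to_connection_kwargs() (see the connection-properties model further down). For reference, a direct use of the databricks-sql-connector looks roughly like this; all credential values are placeholders and not part of the package:

from databricks import sql

# Standalone sketch; placeholder credentials.
with sql.connect(
    server_hostname="abc.cloud.databricks.com",
    http_path="/sql/1.0/warehouses/abc123",
    access_token="dapi-placeholder",
    user_agent_entry="Soda Core",
) as connection:
    with connection.cursor() as cursor:
        cursor.execute("SELECT 1")
        print(cursor.fetchall())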
soda_databricks-4.0.5/src/soda_databricks/common/statements/hive_metadata_tables_query.py
ADDED

@@ -0,0 +1,143 @@
+from __future__ import annotations
+
+import re
+from typing import Optional
+
+from soda_core.common.data_source_connection import DataSourceConnection
+from soda_core.common.data_source_results import QueryResult
+from soda_core.common.sql_dialect import SqlDialect
+from soda_core.common.statements.metadata_tables_query import (
+    FullyQualifiedTableName,
+    MetadataTablesQuery,
+)
+from soda_core.common.statements.table_types import FullyQualifiedViewName, TableType
+
+
+class HiveMetadataTablesQuery(MetadataTablesQuery):
+    def __init__(
+        self,
+        sql_dialect: SqlDialect,
+        data_source_connection: DataSourceConnection,
+        prefixes: Optional[list[str]] = None,
+    ):
+        self.sql_dialect = sql_dialect
+        self.data_source_connection: DataSourceConnection = data_source_connection
+        self.prefixes = prefixes
+
+    def execute(
+        self,
+        database_name: Optional[str] = None,
+        schema_name: Optional[str] = None,
+        include_table_name_like_filters: Optional[list[str]] = None,
+        exclude_table_name_like_filters: Optional[list[str]] = None,
+        types_to_return: Optional[
+            list[TableType]
+        ] = None,  # For backwards compatibility with the old behavior, None defaults to [TableType.TABLE]
+    ) -> list[FullyQualifiedTableName]:
+        if types_to_return is None:
+            types_to_return = [TableType.TABLE]
+        results: list[FullyQualifiedTableName] = []
+        if TableType.TABLE in types_to_return:
+            sql: str = self.build_sql_statement(
+                database_name=database_name, schema_name=schema_name, object_type_to_fetch=TableType.TABLE
+            )
+            query_result: QueryResult = self.data_source_connection.execute_query(sql)
+            results.extend(
+                self.get_results(
+                    query_result,
+                    object_type_to_fetch=TableType.TABLE,
+                    include_table_name_like_filters=include_table_name_like_filters,
+                    exclude_table_name_like_filters=exclude_table_name_like_filters,
+                )
+            )
+        if TableType.VIEW in types_to_return:
+            sql: str = self.build_sql_statement(
+                database_name=database_name, schema_name=schema_name, object_type_to_fetch=TableType.VIEW
+            )
+            query_result: QueryResult = self.data_source_connection.execute_query(sql)
+            results.extend(
+                self.get_results(
+                    query_result,
+                    object_type_to_fetch=TableType.VIEW,
+                    include_table_name_like_filters=include_table_name_like_filters,
+                    exclude_table_name_like_filters=exclude_table_name_like_filters,
+                )
+            )
+        return results
+
+    def build_sql_statement(
+        self,
+        database_name: Optional[str] = None,
+        schema_name: Optional[str] = None,
+        include_table_name_like_filters: Optional[list[str]] = None,
+        exclude_table_name_like_filters: Optional[list[str]] = None,
+        object_type_to_fetch: TableType = TableType.TABLE,
+    ) -> str:  # The return type here is a string, not a list (for the super method it is a list!)
+        schema_str = ""
+        if schema_name:
+            schema_str = f" FROM {self.sql_dialect.quote_default(schema_name)}"
+        if object_type_to_fetch == TableType.TABLE:
+            return f"SHOW TABLES{schema_str}"
+        elif object_type_to_fetch == TableType.VIEW:
+            return f"SHOW VIEWS{schema_str}"
+        else:
+            raise ValueError(f"Invalid object type to fetch: {object_type_to_fetch}")
+
+    def get_results(
+        self,
+        query_result: QueryResult,
+        object_type_to_fetch: TableType,
+        include_table_name_like_filters: Optional[list[str]] = None,
+        exclude_table_name_like_filters: Optional[list[str]] = None,
+    ) -> list[FullyQualifiedTableName]:
+        if object_type_to_fetch == TableType.TABLE:
+            names_for_filtering = [table_name for _, table_name, _ in query_result.rows]
+        elif object_type_to_fetch == TableType.VIEW:
+            names_for_filtering = [view_name for _, view_name, *_ in query_result.rows]
+        else:
+            raise ValueError(f"Invalid object type to fetch: {object_type_to_fetch}")
+        filtered_names = self._filter_include_exclude(
+            names_for_filtering, include_table_name_like_filters, exclude_table_name_like_filters
+        )
+
+        if object_type_to_fetch == TableType.TABLE:
+            return [
+                FullyQualifiedTableName(database_name="hive_metastore", schema_name=schema_name, table_name=table_name)
+                for schema_name, table_name, _is_temporary in query_result.rows
+                if table_name in filtered_names
+            ]
+        elif object_type_to_fetch == TableType.VIEW:
+            return [
+                FullyQualifiedViewName(database_name="hive_metastore", schema_name=schema_name, view_name=view_name)
+                for schema_name, view_name, *_ in query_result.rows
+                if view_name in filtered_names
+            ]
+        else:
+            raise ValueError(f"Invalid object type to fetch: {object_type_to_fetch}")
+
+    # Copied from soda-library (v3)
+    @staticmethod
+    def _filter_include_exclude(
+        item_names: list[str], included_items: Optional[list[str]] = None, excluded_items: Optional[list[str]] = None
+    ) -> list[str]:
+        filtered_names = item_names
+        if included_items or excluded_items:
+
+            def matches(name, pattern: str) -> bool:
+                pattern_regex = pattern.replace("%", ".*").lower()
+                is_match = re.fullmatch(pattern_regex, name.lower())
+                return bool(is_match)
+
+            if included_items:
+                filtered_names = [
+                    filtered_name
+                    for filtered_name in filtered_names
+                    if any(matches(filtered_name, included_item) for included_item in included_items)
+                ]
+            if excluded_items:
+                filtered_names = [
+                    filtered_name
+                    for filtered_name in filtered_names
+                    if all(not matches(filtered_name, excluded_item) for excluded_item in excluded_items)
+                ]
+        return filtered_names
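_filter_include_exclude interprets the filters as SQL LIKE-style patterns: "%" is translated to the regex ".*" and matching is case-insensitive. A standalone replication of that logic on invented names:

import re

def matches(name: str, pattern: str) -> bool:
    # '%' is the LIKE wildcard; matching is case-insensitive.
    return bool(re.fullmatch(pattern.replace("%", ".*").lower(), name.lower()))

names = ["orders", "orders_staging", "customers"]
included = [n for n in names if matches(n, "orders%")]
result = [n for n in included if not matches(n, "%_staging")]
assert result == ["orders"]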
soda_databricks-4.0.5/src/soda_databricks/model/data_source/databricks_connection_properties.py
ADDED

@@ -0,0 +1,41 @@
+from abc import ABC
+from typing import ClassVar, Dict, Optional
+
+from pydantic import Field, SecretStr
+from soda_core.model.data_source.data_source_connection_properties import (
+    DataSourceConnectionProperties,
+)
+
+
+class DatabricksConnectionProperties(DataSourceConnectionProperties, ABC):
+    ...
+
+
+class DatabricksSharedConnectionProperties(DatabricksConnectionProperties, ABC):
+    host: str = Field(
+        ...,
+        description="Databricks workspace hostname (e.g. 'abc.cloud.databricks.com'). If it starts with https:// or http://, it will be removed.",
+    )
+    http_path: str = Field(..., description="HTTP path for the SQL endpoint or cluster")
+    catalog: str = Field(None, description="Default catalog to use")
+    session_configuration: Optional[Dict[str, str]] = Field(None, description="Optional session configuration dict")
+
+    field_mapping: ClassVar[Dict[str, str]] = {
+        "host": "server_hostname",
+    }
+
+    def to_connection_kwargs(self) -> dict:
+        connection_kwargs = super().to_connection_kwargs()
+        server_hostname: str = connection_kwargs["server_hostname"]
+        # Check whether the server_hostname starts with https:// or http:// and remove it
+        prefixes = ["https://", "http://"]
+        for prefix in prefixes:
+            if server_hostname.startswith(prefix):
+                server_hostname = server_hostname[len(prefix):]
+                break  # Stop looking for prefixes once we find one
+        connection_kwargs["server_hostname"] = server_hostname
+        return connection_kwargs
+
+
+class DatabricksTokenAuth(DatabricksSharedConnectionProperties):
+    access_token: SecretStr = Field(..., description="Personal access token")
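to_connection_kwargs relies on the base class plus field_mapping to rename host to server_hostname, then strips any URL scheme from the value. The scheme stripping in isolation, as a standalone sketch with an invented input:

server_hostname = "https://abc.cloud.databricks.com"  # invented example value
for prefix in ("https://", "http://"):
    if server_hostname.startswith(prefix):
        server_hostname = server_hostname[len(prefix):]
        break  # at most one scheme prefix can be present
assert server_hostname == "abc.cloud.databricks.com"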
soda_databricks-4.0.5/src/soda_databricks/model/data_source/databricks_data_source.py
ADDED

@@ -0,0 +1,23 @@
+import abc
+from typing import Literal
+
+from pydantic import Field, field_validator
+from soda_core.model.data_source.data_source import DataSourceBase
+from soda_databricks.model.data_source.databricks_connection_properties import (
+    DatabricksConnectionProperties,
+    DatabricksTokenAuth,
+)
+
+
+class DatabricksDataSource(DataSourceBase, abc.ABC):
+    type: Literal["databricks"] = Field("databricks")
+    connection_properties: DatabricksConnectionProperties = Field(
+        ..., alias="connection", description="Databricks connection configuration"
+    )
+
+    @field_validator("connection_properties", mode="before")
+    @classmethod
+    def infer_connection_type(cls, value):
+        if "access_token" in value:
+            return DatabricksTokenAuth(**value)
+        raise ValueError("Could not infer Databricks connection type from input")
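The before-mode validator receives the raw connection mapping and selects a concrete properties class from its keys; currently only token auth is recognized. A standalone sketch of the dispatch rule (invented input; the real validator constructs DatabricksTokenAuth(**value)):

raw_connection = {
    "host": "abc.cloud.databricks.com",  # invented placeholder values
    "http_path": "/sql/1.0/warehouses/abc123",
    "access_token": "dapi-placeholder",
}

if "access_token" in raw_connection:
    print("-> DatabricksTokenAuth")
else:
    raise ValueError("Could not infer Databricks connection type from input")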
soda_databricks-4.0.5/src/soda_databricks/test_helpers/databricks_data_source_test_helper.py
ADDED

@@ -0,0 +1,25 @@
+from __future__ import annotations
+
+import os
+
+from helpers.data_source_test_helper import DataSourceTestHelper
+
+
+class DatabricksDataSourceTestHelper(DataSourceTestHelper):
+    def _create_database_name(self) -> str | None:
+        return os.getenv("DATABRICKS_CATALOG", "unity_catalog")
+
+    def _create_data_source_yaml_str(self) -> str:
+        """
+        Called in _create_data_source_impl to initialize self.data_source_impl.
+        self.database_name and self.schema_name are available if appropriate for the data source type.
+        """
+        return f"""
+            type: databricks
+            name: {self.name}
+            connection:
+              host: {os.getenv("DATABRICKS_HOST")}
+              http_path: {os.getenv("DATABRICKS_HTTP_PATH")}
+              access_token: {os.getenv("DATABRICKS_TOKEN")}
+              catalog: {os.getenv("DATABRICKS_CATALOG", "unity_catalog")}
+        """
soda_databricks-4.0.5/src/soda_databricks.egg-info/SOURCES.txt
ADDED

@@ -0,0 +1,13 @@
+setup.py
+src/soda_databricks.egg-info/PKG-INFO
+src/soda_databricks.egg-info/SOURCES.txt
+src/soda_databricks.egg-info/dependency_links.txt
+src/soda_databricks.egg-info/entry_points.txt
+src/soda_databricks.egg-info/requires.txt
+src/soda_databricks.egg-info/top_level.txt
+src/soda_databricks/common/data_sources/databricks_data_source.py
+src/soda_databricks/common/data_sources/databricks_data_source_connection.py
+src/soda_databricks/common/statements/hive_metadata_tables_query.py
+src/soda_databricks/model/data_source/databricks_connection_properties.py
+src/soda_databricks/model/data_source/databricks_data_source.py
+src/soda_databricks/test_helpers/databricks_data_source_test_helper.py
soda_databricks-4.0.5/src/soda_databricks.egg-info/dependency_links.txt
ADDED

@@ -0,0 +1 @@
+
soda_databricks-4.0.5/src/soda_databricks.egg-info/top_level.txt
ADDED

@@ -0,0 +1 @@
+soda_databricks