databricks-labs-lakebridge 0.10.6__py3-none-any.whl → 0.10.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- databricks/labs/lakebridge/__about__.py +1 -1
- databricks/labs/lakebridge/analyzer/__init__.py +0 -0
- databricks/labs/lakebridge/analyzer/lakebridge_analyzer.py +95 -0
- databricks/labs/lakebridge/assessments/profiler_validator.py +103 -0
- databricks/labs/lakebridge/base_install.py +20 -3
- databricks/labs/lakebridge/cli.py +32 -59
- databricks/labs/lakebridge/contexts/application.py +7 -0
- databricks/labs/lakebridge/deployment/job.py +2 -2
- databricks/labs/lakebridge/helpers/file_utils.py +36 -0
- databricks/labs/lakebridge/helpers/validation.py +5 -3
- databricks/labs/lakebridge/install.py +73 -484
- databricks/labs/lakebridge/reconcile/compare.py +70 -33
- databricks/labs/lakebridge/reconcile/connectors/data_source.py +24 -1
- databricks/labs/lakebridge/reconcile/connectors/databricks.py +12 -1
- databricks/labs/lakebridge/reconcile/connectors/dialect_utils.py +126 -0
- databricks/labs/lakebridge/reconcile/connectors/models.py +7 -0
- databricks/labs/lakebridge/reconcile/connectors/oracle.py +12 -1
- databricks/labs/lakebridge/reconcile/connectors/secrets.py +19 -1
- databricks/labs/lakebridge/reconcile/connectors/snowflake.py +63 -30
- databricks/labs/lakebridge/reconcile/connectors/tsql.py +28 -2
- databricks/labs/lakebridge/reconcile/constants.py +4 -3
- databricks/labs/lakebridge/reconcile/execute.py +9 -810
- databricks/labs/lakebridge/reconcile/normalize_recon_config_service.py +133 -0
- databricks/labs/lakebridge/reconcile/query_builder/base.py +53 -18
- databricks/labs/lakebridge/reconcile/query_builder/expression_generator.py +8 -2
- databricks/labs/lakebridge/reconcile/query_builder/hash_query.py +7 -13
- databricks/labs/lakebridge/reconcile/query_builder/sampling_query.py +18 -19
- databricks/labs/lakebridge/reconcile/query_builder/threshold_query.py +36 -15
- databricks/labs/lakebridge/reconcile/recon_config.py +3 -15
- databricks/labs/lakebridge/reconcile/recon_output_config.py +2 -1
- databricks/labs/lakebridge/reconcile/reconciliation.py +511 -0
- databricks/labs/lakebridge/reconcile/schema_compare.py +26 -19
- databricks/labs/lakebridge/reconcile/trigger_recon_aggregate_service.py +78 -0
- databricks/labs/lakebridge/reconcile/trigger_recon_service.py +256 -0
- databricks/labs/lakebridge/reconcile/utils.py +38 -0
- databricks/labs/lakebridge/transpiler/execute.py +34 -28
- databricks/labs/lakebridge/transpiler/installers.py +523 -0
- databricks/labs/lakebridge/transpiler/lsp/lsp_engine.py +47 -60
- databricks/labs/lakebridge/transpiler/sqlglot/dialect_utils.py +2 -0
- databricks/labs/lakebridge/transpiler/transpile_engine.py +0 -18
- {databricks_labs_lakebridge-0.10.6.dist-info → databricks_labs_lakebridge-0.10.8.dist-info}/METADATA +1 -1
- {databricks_labs_lakebridge-0.10.6.dist-info → databricks_labs_lakebridge-0.10.8.dist-info}/RECORD +46 -35
- {databricks_labs_lakebridge-0.10.6.dist-info → databricks_labs_lakebridge-0.10.8.dist-info}/WHEEL +0 -0
- {databricks_labs_lakebridge-0.10.6.dist-info → databricks_labs_lakebridge-0.10.8.dist-info}/entry_points.txt +0 -0
- {databricks_labs_lakebridge-0.10.6.dist-info → databricks_labs_lakebridge-0.10.8.dist-info}/licenses/LICENSE +0 -0
- {databricks_labs_lakebridge-0.10.6.dist-info → databricks_labs_lakebridge-0.10.8.dist-info}/licenses/NOTICE +0 -0
databricks/labs/lakebridge/reconcile/compare.py

@@ -3,6 +3,7 @@ from functools import reduce
 from pyspark.sql import DataFrame, SparkSession
 from pyspark.sql.functions import col, expr, lit
 
+from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
 from databricks.labs.lakebridge.reconcile.exception import ColumnMismatchException
 from databricks.labs.lakebridge.reconcile.recon_capture import (
     ReconIntermediatePersist,
@@ -22,7 +23,7 @@ _HASH_COLUMN_NAME = "hash_value_recon"
 _SAMPLE_ROWS = 50
 
 
-def raise_column_mismatch_exception(msg: str, source_missing: list[str], target_missing: list[str]) -> Exception:
+def _raise_column_mismatch_exception(msg: str, source_missing: list[str], target_missing: list[str]) -> Exception:
     error_msg = (
         f"{msg}\n"
         f"columns missing in source: {','.join(source_missing) if source_missing else None}\n"
@@ -33,12 +34,25 @@ def raise_column_mismatch_exception(msg: str, source_missing: list[str], target_
 
 def _generate_join_condition(source_alias, target_alias, key_columns):
     conditions = [
-        col(f"{source_alias}.{key_column}").eqNullSafe(col(f"{target_alias}.{key_column}"))
+        col(f"{source_alias}.{DialectUtils.ansi_normalize_identifier(key_column)}").eqNullSafe(
+            col(f"{target_alias}.{DialectUtils.ansi_normalize_identifier(key_column)}")
+        )
         for key_column in key_columns
     ]
     return reduce(lambda a, b: a & b, conditions)
 
 
+def _build_column_selector(table_name, column_name):
+    alias = DialectUtils.ansi_normalize_identifier(f"{table_name}_{DialectUtils.unnormalize_identifier(column_name)}")
+    return f'{table_name}.{DialectUtils.ansi_normalize_identifier(column_name)} as {alias}'
+
+
+def _build_mismatch_column(table, column):
+    return col(DialectUtils.ansi_normalize_identifier(column)).alias(
+        DialectUtils.unnormalize_identifier(column.replace(f'{table}_', '').lower())
+    )
+
+
 def reconcile_data(
     source: DataFrame,
     target: DataFrame,
@@ -59,14 +73,14 @@ def reconcile_data(
             how="full",
         )
         .selectExpr(
-            *[f'{source_alias
-            *[f'{target_alias
+            *[f'{_build_column_selector(source_alias, col_name)}' for col_name in source.columns],
+            *[f'{_build_column_selector(target_alias, col_name)}' for col_name in target.columns],
         )
     )
 
     # Write unmatched df to volume
     df = ReconIntermediatePersist(spark, path).write_and_read_unmatched_df_with_volumes(df)
-    logger.warning(f"Unmatched data
+    logger.warning(f"Unmatched data was written to {path} successfully")
 
     mismatch = _get_mismatch_data(df, source_alias, target_alias) if report_type in {"all", "data"} else None
 
@@ -74,24 +88,24 @@ def reconcile_data(
         df.filter(col(f"{source_alias}_{_HASH_COLUMN_NAME}").isNull())
         .select(
             *[
-
+                _build_mismatch_column(target_alias, col_name)
                 for col_name in df.columns
                 if col_name.startswith(f'{target_alias}_')
             ]
         )
-        .drop(_HASH_COLUMN_NAME)
+        .drop(f"{_HASH_COLUMN_NAME}")
     )
 
     missing_in_tgt = (
         df.filter(col(f"{target_alias}_{_HASH_COLUMN_NAME}").isNull())
         .select(
             *[
-
+                _build_mismatch_column(source_alias, col_name)
                 for col_name in df.columns
                 if col_name.startswith(f'{source_alias}_')
             ]
         )
-        .drop(_HASH_COLUMN_NAME)
+        .drop(f"{_HASH_COLUMN_NAME}")
     )
     mismatch_count = 0
     if mismatch:
@@ -123,23 +137,27 @@ def _get_mismatch_data(df: DataFrame, src_alias: str, tgt_alias: str) -> DataFra
         .filter(col("hash_match") == lit(False))
         .select(
             *[
-
+                _build_mismatch_column(src_alias, col_name)
                 for col_name in df.columns
                 if col_name.startswith(f'{src_alias}_')
             ]
         )
-        .drop(_HASH_COLUMN_NAME)
+        .drop(f"{_HASH_COLUMN_NAME}")
     )
 
 
-def
-
-
+def _build_capture_df(df: DataFrame) -> DataFrame:
+    columns = [
+        col(DialectUtils.ansi_normalize_identifier(column)).alias(DialectUtils.unnormalize_identifier(column))
+        for column in df.columns
+    ]
+    return df.select(*columns)
 
 
 def capture_mismatch_data_and_columns(source: DataFrame, target: DataFrame, key_columns: list[str]) -> MismatchOutput:
-    source_df =
-    target_df =
+    source_df = _build_capture_df(source)
+    target_df = _build_capture_df(target)
+    unnormalized_key_columns = [DialectUtils.unnormalize_identifier(column) for column in key_columns]
 
     source_columns = source_df.columns
     target_columns = target_df.columns
@@ -148,10 +166,10 @@ def capture_mismatch_data_and_columns(source: DataFrame, target: DataFrame, key_
         message = "source and target should have same columns for capturing the mismatch data"
         source_missing = [column for column in target_columns if column not in source_columns]
         target_missing = [column for column in source_columns if column not in target_columns]
-        raise raise_column_mismatch_exception(message, source_missing, target_missing)
+        raise _raise_column_mismatch_exception(message, source_missing, target_missing)
 
-    check_columns = [column for column in source_columns if column not in key_columns]
-    mismatch_df = _get_mismatch_df(source_df, target_df, key_columns, check_columns)
+    check_columns = [column for column in source_columns if column not in unnormalized_key_columns]
+    mismatch_df = _get_mismatch_df(source_df, target_df, unnormalized_key_columns, check_columns)
     mismatch_columns = _get_mismatch_columns(mismatch_df, check_columns)
     return MismatchOutput(mismatch_df, mismatch_columns)
 
@@ -167,31 +185,50 @@ def _get_mismatch_columns(df: DataFrame, columns: list[str]):
     return mismatch_columns
 
 
+def _normalize_mismatch_df_col(column, suffix):
+    unnormalized = DialectUtils.unnormalize_identifier(column) + suffix
+    return DialectUtils.ansi_normalize_identifier(unnormalized)
+
+
+def _unnormalize_mismatch_df_col(column, suffix):
+    unnormalized = DialectUtils.unnormalize_identifier(column) + suffix
+    return unnormalized
+
+
 def _get_mismatch_df(source: DataFrame, target: DataFrame, key_columns: list[str], column_list: list[str]):
-    source_aliased = [
-
+    source_aliased = [
+        col('base.' + DialectUtils.ansi_normalize_identifier(column)).alias(
+            _unnormalize_mismatch_df_col(column, '_base')
+        )
+        for column in column_list
+    ]
+    target_aliased = [
+        col('compare.' + DialectUtils.ansi_normalize_identifier(column)).alias(
+            _unnormalize_mismatch_df_col(column, '_compare')
+        )
+        for column in column_list
+    ]
 
-    match_expr = [
-
+    match_expr = [
+        expr(f"{_normalize_mismatch_df_col(column,'_base')}=={_normalize_mismatch_df_col(column,'_compare')}").alias(
+            _unnormalize_mismatch_df_col(column, '_match')
+        )
+        for column in column_list
+    ]
+    key_cols = [col(DialectUtils.ansi_normalize_identifier(column)) for column in key_columns]
     select_expr = key_cols + source_aliased + target_aliased + match_expr
 
-    filter_columns = " and ".join([column + "_match" for column in column_list])
-    filter_expr = ~expr(filter_columns)
-
     logger.info(f"KEY COLUMNS: {key_columns}")
-    logger.info(f"FILTER COLUMNS: {filter_expr}")
     logger.info(f"SELECT COLUMNS: {select_expr}")
 
     mismatch_df = (
         source.alias('base').join(other=target.alias('compare'), on=key_columns, how="inner").select(*select_expr)
     )
 
-    compare_columns = [
-
-
-
-def alias_column_str(alias: str, columns: list[str]) -> list[str]:
-    return [f"{alias}.{column}" for column in columns]
+    compare_columns = [
+        DialectUtils.ansi_normalize_identifier(column) for column in mismatch_df.columns if column not in key_columns
+    ]
+    return mismatch_df.select(*key_cols + sorted(compare_columns))
 
 
 def _generate_agg_join_condition(source_alias: str, target_alias: str, key_columns: list[str]):
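The net effect of the compare.py changes is that every column reference is routed through `DialectUtils` before it reaches Spark, so identifiers with spaces or special characters survive the join, the select, and the mismatch report. A minimal sketch of what the new helpers evaluate to, derived from the `DialectUtils` implementation shown further down (column names are illustrative, not taken from the package's tests):

```python
from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils

# Key columns are backtick-quoted before being used in the eqNullSafe join condition.
assert DialectUtils.ansi_normalize_identifier("order id") == "`order id`"

# _build_column_selector("source", "`order id`") produces the aliased select expression
# 'source.`order id` as `source_order id`':
alias = DialectUtils.ansi_normalize_identifier("source_" + DialectUtils.unnormalize_identifier("`order id`"))
assert alias == "`source_order id`"

# _build_mismatch_column later strips the "source_" / "target_" prefix again, so the
# mismatch report exposes the plain column name ("order id") to the caller.
```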
databricks/labs/lakebridge/reconcile/connectors/data_source.py

@@ -3,6 +3,7 @@ from abc import ABC, abstractmethod
 
 from pyspark.sql import DataFrame
 
+from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
 from databricks.labs.lakebridge.reconcile.exception import DataSourceRuntimeException
 from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema
 
@@ -28,15 +29,34 @@ class DataSource(ABC):
         catalog: str | None,
         schema: str,
         table: str,
+        normalize: bool = True,
     ) -> list[Schema]:
         return NotImplemented
 
+    @abstractmethod
+    def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
+        pass
+
     @classmethod
     def log_and_throw_exception(cls, exception: Exception, fetch_type: str, query: str):
         error_msg = f"Runtime exception occurred while fetching {fetch_type} using {query} : {exception}"
         logger.warning(error_msg)
         raise DataSourceRuntimeException(error_msg) from exception
 
+    def _map_meta_column(self, meta_column, normalize: bool) -> Schema:
+        """Create a normalized Schema DTO from the database metadata
+
+        Used in the implementations of get_schema to build a Schema DTO from the `INFORMATION_SCHEMA` query result.
+        The returned Schema is normalized in case the database is having columns with special characters and standardize
+        """
+        name = meta_column.col_name.lower()
+        dtype = meta_column.data_type.strip().lower()
+        if normalize:
+            normalized = self.normalize_identifier(name)
+            return Schema(normalized.ansi_normalized, dtype, normalized.ansi_normalized, normalized.source_normalized)
+
+        return Schema(name, dtype, name, name)
+
 
 class MockDataSource(DataSource):
 
@@ -64,9 +84,12 @@ class MockDataSource(DataSource):
             return self.log_and_throw_exception(self._exception, "data", f"({catalog}, {schema}, {query})")
         return mock_df
 
-    def get_schema(self, catalog: str | None, schema: str, table: str) -> list[Schema]:
+    def get_schema(self, catalog: str | None, schema: str, table: str, normalize: bool = True) -> list[Schema]:
         catalog_str = catalog if catalog else ""
         mock_schema = self._schema_repository.get((catalog_str, schema, table))
         if not mock_schema:
             return self.log_and_throw_exception(self._exception, "schema", f"({catalog}, {schema}, {table})")
         return mock_schema
+
+    def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
+        return NormalizedIdentifier(identifier, identifier)
databricks/labs/lakebridge/reconcile/connectors/databricks.py

@@ -8,7 +8,9 @@ from pyspark.sql.functions import col
 from sqlglot import Dialect
 
 from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource
+from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
 from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin
+from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
 from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema
 from databricks.sdk import WorkspaceClient
 
@@ -35,6 +37,7 @@ def _get_schema_query(catalog: str, schema: str, table: str):
 
 
 class DatabricksDataSource(DataSource, SecretsMixin):
+    _IDENTIFIER_DELIMITER = "`"
 
     def __init__(
         self,
@@ -74,6 +77,7 @@ class DatabricksDataSource(DataSource, SecretsMixin):
         catalog: str | None,
         schema: str,
         table: str,
+        normalize: bool = True,
     ) -> list[Schema]:
         catalog_str = catalog if catalog else "hive_metastore"
         schema_query = _get_schema_query(catalog_str, schema, table)
@@ -82,6 +86,13 @@ class DatabricksDataSource(DataSource, SecretsMixin):
             logger.info(f"Fetching Schema: Started at: {datetime.now()}")
             schema_metadata = self._spark.sql(schema_query).where("col_name not like '#%'").distinct().collect()
             logger.info(f"Schema fetched successfully. Completed at: {datetime.now()}")
-            return [
+            return [self._map_meta_column(field, normalize) for field in schema_metadata]
         except (RuntimeError, PySparkException) as e:
             return self.log_and_throw_exception(e, "schema", schema_query)
+
+    def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
+        return DialectUtils.normalize_identifier(
+            identifier,
+            source_start_delimiter=DatabricksDataSource._IDENTIFIER_DELIMITER,
+            source_end_delimiter=DatabricksDataSource._IDENTIFIER_DELIMITER,
+        )
databricks/labs/lakebridge/reconcile/connectors/dialect_utils.py (new file)

@@ -0,0 +1,126 @@
+from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
+
+
+class DialectUtils:
+    _ANSI_IDENTIFIER_DELIMITER = "`"
+
+    @staticmethod
+    def unnormalize_identifier(identifier: str) -> str:
+        """Return an ansi identifier without the outer backticks.
+
+        Use this at your own risk as the missing outer backticks will result in bugs.
+        E.g. <`mary's lamb`> is returned <mary's lamb> so the outer backticks are needed.
+        This is useful for scenarios where the returned identifier will be part of another delimited identifier.
+
+        :param identifier: a database identifier
+        :return: ansi identifier without the outer backticks
+        """
+        ansi = DialectUtils.ansi_normalize_identifier(identifier)
+        unescape = (
+            DialectUtils._unescape_source_end_delimiter(ansi[1:-1], DialectUtils._ANSI_IDENTIFIER_DELIMITER)
+            if ansi
+            else ansi
+        )
+        return unescape
+
+    @staticmethod
+    def ansi_normalize_identifier(identifier: str) -> str:
+        return DialectUtils.normalize_identifier(
+            identifier, DialectUtils._ANSI_IDENTIFIER_DELIMITER, DialectUtils._ANSI_IDENTIFIER_DELIMITER
+        ).ansi_normalized
+
+    @staticmethod
+    def normalize_identifier(
+        identifier: str, source_start_delimiter: str, source_end_delimiter: str
+    ) -> NormalizedIdentifier:
+        identifier = identifier.strip().lower()
+
+        ansi = DialectUtils._normalize_identifier_source_agnostic(
+            identifier,
+            source_start_delimiter,
+            source_end_delimiter,
+            DialectUtils._ANSI_IDENTIFIER_DELIMITER,
+            DialectUtils._ANSI_IDENTIFIER_DELIMITER,
+        )
+
+        # Input was already ansi normalized
+        if ansi == identifier:
+            source = DialectUtils._normalize_identifier_source_agnostic(
+                identifier,
+                DialectUtils._ANSI_IDENTIFIER_DELIMITER,
+                DialectUtils._ANSI_IDENTIFIER_DELIMITER,
+                source_start_delimiter,
+                source_end_delimiter,
+            )
+
+            # Ansi has backticks escaped which has to be unescaped for other delimiters and escape source end delimiters
+            if source != ansi:
+                source = DialectUtils._unescape_source_end_delimiter(source, DialectUtils._ANSI_IDENTIFIER_DELIMITER)
+                source = (
+                    DialectUtils._escape_source_end_delimiter(source, source_start_delimiter, source_end_delimiter)
+                    if source
+                    else source
+                )
+        else:
+            # Make sure backticks are escaped properly for ansi and source end delimiters are unescaped
+            ansi = DialectUtils._unescape_source_end_delimiter(ansi, source_end_delimiter)
+            ansi = DialectUtils._escape_backticks(ansi) if ansi else ansi
+
+            if source_end_delimiter != DialectUtils._ANSI_IDENTIFIER_DELIMITER:
+                ansi = DialectUtils._unescape_source_end_delimiter(ansi, source_end_delimiter)
+
+            source = DialectUtils._normalize_identifier_source_agnostic(
+                identifier, source_start_delimiter, source_end_delimiter, source_start_delimiter, source_end_delimiter
+            )
+
+            # Make sure source end delimiter is escaped else nothing as it was already normalized
+            if source != identifier:
+                source = (
+                    DialectUtils._escape_source_end_delimiter(source, source_start_delimiter, source_end_delimiter)
+                    if source
+                    else source
+                )
+
+        return NormalizedIdentifier(ansi, source)
+
+    @staticmethod
+    def _normalize_identifier_source_agnostic(
+        identifier: str,
+        source_start_delimiter: str,
+        source_end_delimiter: str,
+        expected_source_start_delimiter: str,
+        expected_source_end_delimiter: str,
+    ) -> str:
+        if identifier == "" or identifier is None:
+            return ""
+
+        if DialectUtils.is_already_delimited(
+            identifier, expected_source_start_delimiter, expected_source_end_delimiter
+        ):
+            return identifier
+
+        if DialectUtils.is_already_delimited(identifier, source_start_delimiter, source_end_delimiter):
+            stripped_identifier = identifier.removeprefix(source_start_delimiter).removesuffix(source_end_delimiter)
+        else:
+            stripped_identifier = identifier
+        return f"{expected_source_start_delimiter}{stripped_identifier}{expected_source_end_delimiter}"
+
+    @staticmethod
+    def is_already_delimited(identifier: str, start_delimiter: str, end_delimiter: str) -> bool:
+        return identifier.startswith(start_delimiter) and identifier.endswith(end_delimiter)
+
+    @staticmethod
+    def _escape_backticks(identifier: str) -> str:
+        identifier = identifier[1:-1]
+        identifier = identifier.replace("`", "``")
+        return f"`{identifier}`"
+
+    @staticmethod
+    def _unescape_source_end_delimiter(identifier: str, source_end_delimiter: str) -> str:
+        return identifier.replace(f"{source_end_delimiter}{source_end_delimiter}", source_end_delimiter)
+
+    @staticmethod
+    def _escape_source_end_delimiter(identifier: str, start_end_delimiter, source_end_delimiter: str) -> str:
+        identifier = identifier[1:-1]
+        identifier = identifier.replace(source_end_delimiter, f"{source_end_delimiter}{source_end_delimiter}")
+        return f"{start_end_delimiter}{identifier}{source_end_delimiter}"
databricks/labs/lakebridge/reconcile/connectors/oracle.py

@@ -9,7 +9,9 @@ from sqlglot import Dialect
 
 from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource
 from databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin
+from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
 from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin
+from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
 from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema
 from databricks.sdk import WorkspaceClient
 
@@ -18,6 +20,7 @@ logger = logging.getLogger(__name__)
 
 class OracleDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
     _DRIVER = "oracle"
+    _IDENTIFIER_DELIMITER = "\""
     _SCHEMA_QUERY = """select column_name, case when (data_precision is not null
                                                   and data_scale <> 0)
                                               then data_type || '(' || data_precision || ',' || data_scale || ')'
@@ -78,6 +81,7 @@ class OracleDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
         catalog: str | None,
         schema: str,
         table: str,
+        normalize: bool = True,
     ) -> list[Schema]:
         schema_query = re.sub(
             r'\s+',
@@ -91,7 +95,7 @@ class OracleDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
             schema_metadata = df.select([col(c).alias(c.lower()) for c in df.columns]).collect()
             logger.info(f"Schema fetched successfully. Completed at: {datetime.now()}")
             logger.debug(f"schema_metadata: ${schema_metadata}")
-            return [
+            return [self._map_meta_column(field, normalize) for field in schema_metadata]
         except (RuntimeError, PySparkException) as e:
             return self.log_and_throw_exception(e, "schema", schema_query)
 
@@ -106,3 +110,10 @@ class OracleDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
 
     def reader(self, query: str) -> DataFrameReader:
         return self._get_jdbc_reader(query, self.get_jdbc_url, OracleDataSource._DRIVER)
+
+    def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
+        return DialectUtils.normalize_identifier(
+            identifier,
+            source_start_delimiter=OracleDataSource._IDENTIFIER_DELIMITER,
+            source_end_delimiter=OracleDataSource._IDENTIFIER_DELIMITER,
+        )
databricks/labs/lakebridge/reconcile/connectors/secrets.py

@@ -11,8 +11,26 @@ class SecretsMixin:
     _ws: WorkspaceClient
     _secret_scope: str
 
+    def _get_secret_or_none(self, secret_key: str) -> str | None:
+        """
+        Get the secret value given a secret scope & secret key. Log a warning if secret does not exist
+        Used To ensure backwards compatibility when supporting new secrets
+        """
+        try:
+            # Return the decoded secret value in string format
+            return self._get_secret(secret_key)
+        except NotFound as e:
+            logger.warning(f"Secret not found: key={secret_key}")
+            logger.debug("Secret lookup failed", exc_info=e)
+            return None
+
     def _get_secret(self, secret_key: str) -> str:
-        """Get the secret value given a secret scope & secret key.
+        """Get the secret value given a secret scope & secret key.
+
+        Raises:
+            NotFound: The secret could not be found.
+            UnicodeDecodeError: The secret value was not Base64-encoded UTF-8.
+        """
         try:
             # Return the decoded secret value in string format
             secret = self._ws.secrets.get_secret(self._secret_scope, secret_key)
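`_get_secret_or_none` keeps the strict `_get_secret` behaviour for required secrets while giving callers a non-raising lookup for optional ones. A small sketch of the intended usage (the consumer class and method name are illustrative; the Snowflake connector below uses the same pattern for the optional `pem_private_key_password` secret):

```python
from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin


class MyConnector(SecretsMixin):  # illustrative consumer; _ws and _secret_scope are set elsewhere
    def _optional_passphrase(self) -> str | None:
        # Returns None (after logging a warning) when the secret is missing,
        # instead of propagating NotFound the way _get_secret does.
        return self._get_secret_or_none("pem_private_key_password")
```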
databricks/labs/lakebridge/reconcile/connectors/snowflake.py

@@ -11,7 +11,9 @@ from cryptography.hazmat.primitives import serialization
 
 from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource
 from databricks.labs.lakebridge.reconcile.connectors.jdbc_reader import JDBCReaderMixin
+from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
 from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin
+from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
 from databricks.labs.lakebridge.reconcile.exception import InvalidSnowflakePemPrivateKey
 from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema
 from databricks.sdk import WorkspaceClient
@@ -22,6 +24,8 @@ logger = logging.getLogger(__name__)
 
 class SnowflakeDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
     _DRIVER = "snowflake"
+    _IDENTIFIER_DELIMITER = "\""
+
     """
     * INFORMATION_SCHEMA:
       - see https://docs.snowflake.com/en/sql-reference/info-schema#considerations-for-replacing-show-commands-with-information-schema-views
@@ -75,31 +79,6 @@ class SnowflakeDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
             f"&warehouse={self._get_secret('sfWarehouse')}&role={self._get_secret('sfRole')}"
         )
 
-    @staticmethod
-    def get_private_key(pem_private_key: str) -> str:
-        try:
-            private_key_bytes = pem_private_key.encode("UTF-8")
-            p_key = serialization.load_pem_private_key(
-                private_key_bytes,
-                password=None,
-                backend=default_backend(),
-            )
-            pkb = p_key.private_bytes(
-                encoding=serialization.Encoding.PEM,
-                format=serialization.PrivateFormat.PKCS8,
-                encryption_algorithm=serialization.NoEncryption(),
-            )
-            pkb_str = pkb.decode("UTF-8")
-            # Remove the first and last lines (BEGIN/END markers)
-            private_key_pem_lines = pkb_str.strip().split('\n')[1:-1]
-            # Join the lines to form the base64 encoded string
-            private_key_pem_str = ''.join(private_key_pem_lines)
-            return private_key_pem_str
-        except Exception as e:
-            message = f"Failed to load or process the provided PEM private key. --> {e}"
-            logger.error(message)
-            raise InvalidSnowflakePemPrivateKey(message) from e
-
     def read_data(
         self,
         catalog: str | None,
@@ -128,6 +107,7 @@ class SnowflakeDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
         catalog: str | None,
         schema: str,
         table: str,
+        normalize: bool = True,
     ) -> list[Schema]:
         """
         Fetch the Schema from the INFORMATION_SCHEMA.COLUMNS table in Snowflake.
@@ -144,13 +124,20 @@ class SnowflakeDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
         try:
             logger.debug(f"Fetching schema using query: \n`{schema_query}`")
             logger.info(f"Fetching Schema: Started at: {datetime.now()}")
-
+            df = self.reader(schema_query).load()
+            schema_metadata = df.select([col(c).alias(c.lower()) for c in df.columns]).collect()
             logger.info(f"Schema fetched successfully. Completed at: {datetime.now()}")
-            return [
+            return [self._map_meta_column(field, normalize) for field in schema_metadata]
         except (RuntimeError, PySparkException) as e:
             return self.log_and_throw_exception(e, "schema", schema_query)
 
     def reader(self, query: str) -> DataFrameReader:
+        options = self._get_snowflake_options()
+        return self._spark.read.format("snowflake").option("dbtable", f"({query}) as tmp").options(**options)
+
+    # TODO cache this method using @functools.cache
+    # Pay attention to https://pylint.pycqa.org/en/latest/user_guide/messages/warning/method-cache-max-size-none.html
+    def _get_snowflake_options(self):
         options = {
             "sfUrl": self._get_secret('sfUrl'),
             "sfUser": self._get_secret('sfUser'),
@@ -159,15 +146,61 @@ class SnowflakeDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
             "sfWarehouse": self._get_secret('sfWarehouse'),
             "sfRole": self._get_secret('sfRole'),
         }
+        options = options | self._get_snowflake_auth_options()
+
+        return options
+
+    def _get_snowflake_auth_options(self):
         try:
-
+            key = SnowflakeDataSource._get_private_key(
+                self._get_secret('pem_private_key'), self._get_secret_or_none('pem_private_key_password')
+            )
+            return {"pem_private_key": key}
         except (NotFound, KeyError):
             logger.warning("pem_private_key not found. Checking for sfPassword")
             try:
-
+                password = self._get_secret('sfPassword')
+                return {"sfPassword": password}
             except (NotFound, KeyError) as e:
                 message = "sfPassword and pem_private_key not found. Either one is required for snowflake auth."
                 logger.error(message)
                 raise NotFound(message) from e
 
-
+    @staticmethod
+    def _get_private_key(pem_private_key: str, pem_private_key_password: str | None) -> str:
+        try:
+            private_key_bytes = pem_private_key.encode("UTF-8")
+            password_bytes = pem_private_key_password.encode("UTF-8") if pem_private_key_password else None
+        except UnicodeEncodeError as e:
+            message = f"Invalid pem key and/or pem password: unable to encode. --> {e}"
+            logger.error(message)
+            raise ValueError(message) from e
+
+        try:
+            p_key = serialization.load_pem_private_key(
+                private_key_bytes,
+                password_bytes,
+                backend=default_backend(),
+            )
+            pkb = p_key.private_bytes(
+                encoding=serialization.Encoding.PEM,
+                format=serialization.PrivateFormat.PKCS8,
+                encryption_algorithm=serialization.NoEncryption(),
+            )
+            pkb_str = pkb.decode("UTF-8")
+            # Remove the first and last lines (BEGIN/END markers)
+            private_key_pem_lines = pkb_str.strip().split('\n')[1:-1]
+            # Join the lines to form the base64 encoded string
+            private_key_pem_str = ''.join(private_key_pem_lines)
+            return private_key_pem_str
+        except Exception as e:
+            message = f"Failed to load or process the provided PEM private key. --> {e}"
+            logger.error(message)
+            raise InvalidSnowflakePemPrivateKey(message) from e
+
+    def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
+        return DialectUtils.normalize_identifier(
+            identifier,
+            source_start_delimiter=SnowflakeDataSource._IDENTIFIER_DELIMITER,
+            source_end_delimiter=SnowflakeDataSource._IDENTIFIER_DELIMITER,
+        )
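With the new passphrase parameter, the `pem_private_key` secret may now hold an encrypted PKCS#8 key as long as `pem_private_key_password` carries the passphrase; `_get_private_key` decrypts it and returns the bare base64 body that `_get_snowflake_auth_options` passes to the Spark Snowflake reader as the `pem_private_key` option. A self-contained sketch of that round trip, generating throwaway key material locally and calling the private helper directly purely for illustration:

```python
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa

from databricks.labs.lakebridge.reconcile.connectors.snowflake import SnowflakeDataSource

# Stand-ins for the values stored in the pem_private_key / pem_private_key_password secrets.
passphrase = "correct horse battery staple"
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
encrypted_pem = key.private_bytes(
    encoding=serialization.Encoding.PEM,
    format=serialization.PrivateFormat.PKCS8,
    encryption_algorithm=serialization.BestAvailableEncryption(passphrase.encode("utf-8")),
).decode("utf-8")

# _get_private_key loads the encrypted key and returns the base64 body of an
# unencrypted PKCS#8 PEM (the BEGIN/END markers are stripped).
decoded = SnowflakeDataSource._get_private_key(encrypted_pem, passphrase)
assert "BEGIN" not in decoded
```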