databricks-labs-lakebridge 0.10.7__py3-none-any.whl → 0.10.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- databricks/labs/lakebridge/__about__.py +1 -1
- databricks/labs/lakebridge/assessments/profiler_validator.py +103 -0
- databricks/labs/lakebridge/base_install.py +1 -5
- databricks/labs/lakebridge/cli.py +13 -6
- databricks/labs/lakebridge/helpers/validation.py +5 -3
- databricks/labs/lakebridge/install.py +40 -481
- databricks/labs/lakebridge/reconcile/connectors/data_source.py +9 -5
- databricks/labs/lakebridge/reconcile/connectors/databricks.py +2 -1
- databricks/labs/lakebridge/reconcile/connectors/oracle.py +2 -1
- databricks/labs/lakebridge/reconcile/connectors/secrets.py +19 -1
- databricks/labs/lakebridge/reconcile/connectors/snowflake.py +50 -29
- databricks/labs/lakebridge/reconcile/connectors/tsql.py +2 -1
- databricks/labs/lakebridge/reconcile/query_builder/base.py +50 -11
- databricks/labs/lakebridge/reconcile/query_builder/expression_generator.py +8 -2
- databricks/labs/lakebridge/reconcile/query_builder/hash_query.py +7 -13
- databricks/labs/lakebridge/reconcile/query_builder/sampling_query.py +18 -19
- databricks/labs/lakebridge/reconcile/query_builder/threshold_query.py +36 -15
- databricks/labs/lakebridge/reconcile/recon_config.py +0 -15
- databricks/labs/lakebridge/reconcile/reconciliation.py +4 -1
- databricks/labs/lakebridge/reconcile/trigger_recon_aggregate_service.py +11 -31
- databricks/labs/lakebridge/reconcile/trigger_recon_service.py +4 -1
- databricks/labs/lakebridge/transpiler/execute.py +34 -28
- databricks/labs/lakebridge/transpiler/installers.py +523 -0
- databricks/labs/lakebridge/transpiler/lsp/lsp_engine.py +2 -0
- {databricks_labs_lakebridge-0.10.7.dist-info → databricks_labs_lakebridge-0.10.8.dist-info}/METADATA +1 -1
- {databricks_labs_lakebridge-0.10.7.dist-info → databricks_labs_lakebridge-0.10.8.dist-info}/RECORD +30 -28
- {databricks_labs_lakebridge-0.10.7.dist-info → databricks_labs_lakebridge-0.10.8.dist-info}/WHEEL +0 -0
- {databricks_labs_lakebridge-0.10.7.dist-info → databricks_labs_lakebridge-0.10.8.dist-info}/entry_points.txt +0 -0
- {databricks_labs_lakebridge-0.10.7.dist-info → databricks_labs_lakebridge-0.10.8.dist-info}/licenses/LICENSE +0 -0
- {databricks_labs_lakebridge-0.10.7.dist-info → databricks_labs_lakebridge-0.10.8.dist-info}/licenses/NOTICE +0 -0
databricks/labs/lakebridge/reconcile/connectors/secrets.py
@@ -11,8 +11,26 @@ class SecretsMixin:
     _ws: WorkspaceClient
     _secret_scope: str
 
+    def _get_secret_or_none(self, secret_key: str) -> str | None:
+        """
+        Get the secret value given a secret scope & secret key. Log a warning if secret does not exist
+        Used To ensure backwards compatibility when supporting new secrets
+        """
+        try:
+            # Return the decoded secret value in string format
+            return self._get_secret(secret_key)
+        except NotFound as e:
+            logger.warning(f"Secret not found: key={secret_key}")
+            logger.debug("Secret lookup failed", exc_info=e)
+            return None
+
     def _get_secret(self, secret_key: str) -> str:
-        """Get the secret value given a secret scope & secret key.
+        """Get the secret value given a secret scope & secret key.
+
+        Raises:
+            NotFound: The secret could not be found.
+            UnicodeDecodeError: The secret value was not Base64-encoded UTF-8.
+        """
         try:
             # Return the decoded secret value in string format
             secret = self._ws.secrets.get_secret(self._secret_scope, secret_key)
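Note: `_get_secret_or_none` exists so that newly introduced secrets (such as an optional private-key passphrase) can be treated as optional without breaking existing installations. A minimal, self-contained sketch of how the two helpers fit together, assuming the Databricks SDK; the class name and the `lakebridge` scope below are illustrative, not from the package:

    import base64

    from databricks.sdk import WorkspaceClient
    from databricks.sdk.errors import NotFound


    class ScopedSecrets:
        """Illustrative stand-in for a SecretsMixin consumer."""

        def __init__(self, ws: WorkspaceClient, scope: str):
            self._ws = ws
            self._scope = scope

        def get(self, key: str) -> str:
            # The secrets API returns the value base64-encoded; decode it to UTF-8.
            raw = self._ws.secrets.get_secret(self._scope, key).value
            return base64.b64decode(raw).decode("utf-8")

        def get_or_none(self, key: str) -> str | None:
            # Optional secret: absence is expected for older configurations.
            try:
                return self.get(key)
            except NotFound:
                return None


    ws = WorkspaceClient()
    secrets = ScopedSecrets(ws, "lakebridge")
    pem_key = secrets.get("pem_private_key")                        # required
    pem_password = secrets.get_or_none("pem_private_key_password")  # optional, new in 0.10.8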
databricks/labs/lakebridge/reconcile/connectors/snowflake.py
@@ -79,31 +79,6 @@ class SnowflakeDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
             f"&warehouse={self._get_secret('sfWarehouse')}&role={self._get_secret('sfRole')}"
         )
 
-    @staticmethod
-    def get_private_key(pem_private_key: str) -> str:
-        try:
-            private_key_bytes = pem_private_key.encode("UTF-8")
-            p_key = serialization.load_pem_private_key(
-                private_key_bytes,
-                password=None,
-                backend=default_backend(),
-            )
-            pkb = p_key.private_bytes(
-                encoding=serialization.Encoding.PEM,
-                format=serialization.PrivateFormat.PKCS8,
-                encryption_algorithm=serialization.NoEncryption(),
-            )
-            pkb_str = pkb.decode("UTF-8")
-            # Remove the first and last lines (BEGIN/END markers)
-            private_key_pem_lines = pkb_str.strip().split('\n')[1:-1]
-            # Join the lines to form the base64 encoded string
-            private_key_pem_str = ''.join(private_key_pem_lines)
-            return private_key_pem_str
-        except Exception as e:
-            message = f"Failed to load or process the provided PEM private key. --> {e}"
-            logger.error(message)
-            raise InvalidSnowflakePemPrivateKey(message) from e
-
     def read_data(
         self,
         catalog: str | None,
@@ -132,6 +107,7 @@ class SnowflakeDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
         catalog: str | None,
         schema: str,
         table: str,
+        normalize: bool = True,
     ) -> list[Schema]:
         """
         Fetch the Schema from the INFORMATION_SCHEMA.COLUMNS table in Snowflake.
@@ -151,11 +127,17 @@ class SnowflakeDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
             df = self.reader(schema_query).load()
             schema_metadata = df.select([col(c).alias(c.lower()) for c in df.columns]).collect()
             logger.info(f"Schema fetched successfully. Completed at: {datetime.now()}")
-            return [self._map_meta_column(field) for field in schema_metadata]
+            return [self._map_meta_column(field, normalize) for field in schema_metadata]
         except (RuntimeError, PySparkException) as e:
             return self.log_and_throw_exception(e, "schema", schema_query)
 
     def reader(self, query: str) -> DataFrameReader:
+        options = self._get_snowflake_options()
+        return self._spark.read.format("snowflake").option("dbtable", f"({query}) as tmp").options(**options)
+
+    # TODO cache this method using @functools.cache
+    # Pay attention to https://pylint.pycqa.org/en/latest/user_guide/messages/warning/method-cache-max-size-none.html
+    def _get_snowflake_options(self):
         options = {
             "sfUrl": self._get_secret('sfUrl'),
             "sfUser": self._get_secret('sfUser'),
@@ -164,18 +146,57 @@ class SnowflakeDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
             "sfWarehouse": self._get_secret('sfWarehouse'),
             "sfRole": self._get_secret('sfRole'),
         }
+        options = options | self._get_snowflake_auth_options()
+
+        return options
+
+    def _get_snowflake_auth_options(self):
         try:
-
+            key = SnowflakeDataSource._get_private_key(
+                self._get_secret('pem_private_key'), self._get_secret_or_none('pem_private_key_password')
+            )
+            return {"pem_private_key": key}
         except (NotFound, KeyError):
             logger.warning("pem_private_key not found. Checking for sfPassword")
             try:
-
+                password = self._get_secret('sfPassword')
+                return {"sfPassword": password}
             except (NotFound, KeyError) as e:
                 message = "sfPassword and pem_private_key not found. Either one is required for snowflake auth."
                 logger.error(message)
                 raise NotFound(message) from e
 
-
+    @staticmethod
+    def _get_private_key(pem_private_key: str, pem_private_key_password: str | None) -> str:
+        try:
+            private_key_bytes = pem_private_key.encode("UTF-8")
+            password_bytes = pem_private_key_password.encode("UTF-8") if pem_private_key_password else None
+        except UnicodeEncodeError as e:
+            message = f"Invalid pem key and/or pem password: unable to encode. --> {e}"
+            logger.error(message)
+            raise ValueError(message) from e
+
+        try:
+            p_key = serialization.load_pem_private_key(
+                private_key_bytes,
+                password_bytes,
+                backend=default_backend(),
+            )
+            pkb = p_key.private_bytes(
+                encoding=serialization.Encoding.PEM,
+                format=serialization.PrivateFormat.PKCS8,
+                encryption_algorithm=serialization.NoEncryption(),
+            )
+            pkb_str = pkb.decode("UTF-8")
+            # Remove the first and last lines (BEGIN/END markers)
+            private_key_pem_lines = pkb_str.strip().split('\n')[1:-1]
+            # Join the lines to form the base64 encoded string
+            private_key_pem_str = ''.join(private_key_pem_lines)
+            return private_key_pem_str
+        except Exception as e:
+            message = f"Failed to load or process the provided PEM private key. --> {e}"
+            logger.error(message)
+            raise InvalidSnowflakePemPrivateKey(message) from e
 
     def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
         return DialectUtils.normalize_identifier(
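The relocated `_get_private_key` now accepts an optional passphrase and re-encodes the PEM into the bare base64 body that the Snowflake Spark connector expects in its `pem_private_key` option. A standalone sketch of the same conversion with the `cryptography` package; the function name is illustrative:

    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization


    def pem_to_connector_key(pem: str, passphrase: str | None) -> str:
        """Turn a (possibly encrypted) PEM private key into a single-line base64 payload."""
        key = serialization.load_pem_private_key(
            pem.encode("utf-8"),
            passphrase.encode("utf-8") if passphrase else None,
            backend=default_backend(),
        )
        # Re-serialize as unencrypted PKCS#8 PEM ...
        pkcs8 = key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        ).decode("utf-8")
        # ... then drop the BEGIN/END marker lines, keeping only the base64 body.
        return "".join(pkcs8.strip().split("\n")[1:-1])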
databricks/labs/lakebridge/reconcile/connectors/tsql.py
@@ -109,6 +109,7 @@ class TSQLServerDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
         catalog: str | None,
         schema: str,
         table: str,
+        normalize: bool = True,
     ) -> list[Schema]:
         """
         Fetch the Schema from the INFORMATION_SCHEMA.COLUMNS table in SQL Server.
@@ -128,7 +129,7 @@ class TSQLServerDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
             df = self.reader(schema_query).load()
             schema_metadata = df.select([col(c).alias(c.lower()) for c in df.columns]).collect()
             logger.info(f"Schema fetched successfully. Completed at: {datetime.now()}")
-            return [self._map_meta_column(field) for field in schema_metadata]
+            return [self._map_meta_column(field, normalize) for field in schema_metadata]
         except (RuntimeError, PySparkException) as e:
             return self.log_and_throw_exception(e, "schema", schema_query)
 
databricks/labs/lakebridge/reconcile/query_builder/base.py
@@ -5,10 +5,12 @@ import sqlglot.expressions as exp
 from sqlglot import Dialect, parse_one
 
 from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource
+from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
 from databricks.labs.lakebridge.reconcile.exception import InvalidInputException
 from databricks.labs.lakebridge.reconcile.query_builder.expression_generator import (
     DataType_transform_mapping,
     transform_expression,
+    build_column,
 )
 from databricks.labs.lakebridge.reconcile.recon_config import Schema, Table, Aggregate
 from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect, SQLGLOT_DIALECTS
@@ -26,7 +28,7 @@ class QueryBuilder(ABC):
 
     @property
     def engine(self) -> Dialect:
-        return self._engine
+        return self._engine if self.layer == "source" else get_dialect("databricks")
 
     @property
     def layer(self) -> str:
@@ -66,7 +68,25 @@ class QueryBuilder(ABC):
 
     @property
     def user_transformations(self) -> dict[str, str]:
-
+        if self._table_conf.transformations:
+            if self.layer == "source":
+                return {
+                    trans.column_name: (
+                        trans.source
+                        if trans.source
+                        else self._data_source.normalize_identifier(trans.column_name).source_normalized
+                    )
+                    for trans in self._table_conf.transformations
+                }
+            return {
+                self._table_conf.get_layer_src_to_tgt_col_mapping(trans.column_name, self.layer): (
+                    trans.target
+                    if trans.target
+                    else self._table_conf.get_layer_src_to_tgt_col_mapping(trans.column_name, self.layer)
+                )
+                for trans in self._table_conf.transformations
+            }
+        return {}
 
     @property
     def aggregates(self) -> list[Aggregate] | None:
@@ -89,10 +109,12 @@ class QueryBuilder(ABC):
 
     def _user_transformer(self, node: exp.Expression, user_transformations: dict[str, str]) -> exp.Expression:
         if isinstance(node, exp.Column) and user_transformations:
-
-
-            if
-                return parse_one(
+            normalized_column = self._data_source.normalize_identifier(node.name)
+            ansi_name = normalized_column.ansi_normalized
+            if ansi_name in user_transformations.keys():
+                return parse_one(
+                    user_transformations.get(ansi_name, normalized_column.source_normalized), read=self.engine
+                )
         return node
 
     def _apply_default_transformation(
@@ -103,8 +125,7 @@ class QueryBuilder(ABC):
         with_transform.append(alias.transform(self._default_transformer, schema, source))
         return with_transform
 
-
-    def _default_transformer(node: exp.Expression, schema: list[Schema], source: Dialect) -> exp.Expression:
+    def _default_transformer(self, node: exp.Expression, schema: list[Schema], source: Dialect) -> exp.Expression:
 
         def _get_transform(datatype: str):
             source_dialects = [source_key for source_key, dialect in SQLGLOT_DIALECTS.items() if dialect == source]
@@ -121,9 +142,10 @@ class QueryBuilder(ABC):
 
         schema_dict = {v.column_name: v.data_type for v in schema}
         if isinstance(node, exp.Column):
-
-
-
+            normalized_column = self._data_source.normalize_identifier(node.name)
+            ansi_name = normalized_column.ansi_normalized
+            if ansi_name in schema_dict.keys():
+                transform = _get_transform(schema_dict.get(ansi_name, normalized_column.source_normalized))
             return transform_expression(node, transform)
         return node
 
@@ -132,3 +154,20 @@ class QueryBuilder(ABC):
         message = f"Exception for {self.table_conf.target_name} target table in {self.layer} layer --> {message}"
         logger.error(message)
         raise InvalidInputException(message)
+
+    def _build_column_with_alias(self, column: str):
+        return build_column(
+            this=self._build_column_name_source_normalized(column),
+            alias=DialectUtils.unnormalize_identifier(
+                self.table_conf.get_layer_tgt_to_src_col_mapping(column, self.layer)
+            ),
+            quoted=True,
+        )
+
+    def _build_column_name_source_normalized(self, column: str):
+        return self._data_source.normalize_identifier(column).source_normalized
+
+    def _build_alias_source_normalized(self, column: str):
+        return self._data_source.normalize_identifier(
+            self.table_conf.get_layer_tgt_to_src_col_mapping(column, self.layer)
+        ).source_normalized
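The reworked `_user_transformer` keys user transformations by the ANSI-normalized column name and re-parses the replacement SQL in the active dialect. A simplified stand-in (a plain dict instead of the table config, and no identifier normalization) showing the underlying sqlglot pattern:

    import sqlglot.expressions as exp
    from sqlglot import parse_one

    # User-supplied per-column SQL snippets, keyed by column name.
    transformations = {"sal": "TRIM(sal)", "name": "UPPER(name)"}


    def apply_user_transformations(node: exp.Expression) -> exp.Expression:
        # Replace a plain column reference with its parsed transformation, if one exists.
        if isinstance(node, exp.Column) and node.name in transformations:
            return parse_one(transformations[node.name], read="snowflake")
        return node


    query = parse_one("SELECT id, sal, name FROM employees", read="snowflake")
    print(query.transform(apply_user_transformations).sql(dialect="snowflake"))
    # SELECT id, TRIM(sal), UPPER(name) FROM employees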
databricks/labs/lakebridge/reconcile/query_builder/expression_generator.py
@@ -125,6 +125,7 @@ def anonymous(expr: exp.Column, func: str, is_expr: bool = False, dialect=None)
     return new_expr
 
 
+# TODO Standardize impl and use quoted and Identifier/Column consistently
 def build_column(this: exp.ExpOrStr, table_name="", quoted=False, alias=None) -> exp.Expression:
     if alias:
         if isinstance(this, str):
@@ -135,6 +136,10 @@ def build_column(this: exp.ExpOrStr, table_name="", quoted=False, alias=None) -> exp.Expression:
     return exp.Column(this=exp.Identifier(this=this, quoted=quoted), table=table_name)
 
 
+def build_column_no_alias(this: str, table_name="") -> exp.Expression:
+    return exp.Column(this=this, table=table_name)
+
+
 def build_literal(this: exp.ExpOrStr, alias=None, quoted=False, is_string=True, cast=None) -> exp.Expression:
     base_literal = exp.Literal(this=this, is_string=is_string)
     if not cast and not alias:
@@ -207,10 +212,11 @@ def build_sub(
     right_column_name: str,
     left_table_name: str | None = None,
     right_table_name: str | None = None,
+    quoted: bool = False,
 ) -> exp.Sub:
     return exp.Sub(
-        this=build_column(left_column_name, left_table_name),
-        expression=build_column(right_column_name, right_table_name),
+        this=build_column(left_column_name, left_table_name, quoted=quoted),
+        expression=build_column(right_column_name, right_table_name, quoted=quoted),
     )
 
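The `quoted=` plumbing matters because `build_column` wraps the name in an `exp.Identifier`, and only quoted identifiers pick up dialect-specific quoting when the expression is rendered. A small sqlglot illustration; the column and table names are made up:

    import sqlglot.expressions as exp

    quoted = exp.Column(this=exp.Identifier(this="Order Id", quoted=True), table="src")
    print(quoted.sql(dialect="snowflake"))   # src."Order Id"
    print(quoted.sql(dialect="databricks"))  # src.`Order Id`

    unquoted = exp.Column(this=exp.Identifier(this="Order Id", quoted=False), table="src")
    print(unquoted.sql(dialect="snowflake"))  # src.Order Id  (emitted verbatim, not a valid identifier)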
databricks/labs/lakebridge/reconcile/query_builder/hash_query.py
@@ -11,8 +11,8 @@ from databricks.labs.lakebridge.reconcile.query_builder.expression_generator import (
     get_hash_transform,
     lower,
     transform_expression,
+    build_column_no_alias,
 )
-from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect
 
 logger = logging.getLogger(__name__)
 
@@ -41,15 +41,12 @@ class HashQueryBuilder(QueryBuilder):
 
         key_cols = hash_cols if report_type == "row" else sorted(_join_columns | self.partition_column)
 
-        cols_with_alias = [
-            build_column(this=col, alias=self.table_conf.get_layer_tgt_to_src_col_mapping(col, self.layer))
-            for col in key_cols
-        ]
+        cols_with_alias = [self._build_column_with_alias(col) for col in key_cols]
 
         # in case if we have column mapping, we need to sort the target columns in the order of source columns to get
         # same hash value
         hash_cols_with_alias = [
-            {"this": col, "alias": self.
+            {"this": self._build_column_name_source_normalized(col), "alias": self._build_alias_source_normalized(col)}
             for col in hash_cols
         ]
         sorted_hash_cols_with_alias = sorted(hash_cols_with_alias, key=lambda column: column["alias"])
@@ -60,12 +57,11 @@ class HashQueryBuilder(QueryBuilder):
         )
         hash_col_with_transform = [self._generate_hash_algorithm(hashcols_sorted_as_src_seq, _HASH_COLUMN_NAME)]
 
-        dialect = self.engine if self.layer == "source" else get_dialect("databricks")
         res = (
             exp.select(*hash_col_with_transform + key_cols_with_transform)
             .from_(":tbl")
-            .where(self.filter)
-            .sql(dialect=
+            .where(self.filter, dialect=self.engine)
+            .sql(dialect=self.engine)
         )
 
         logger.info(f"Hash Query for {self.layer}: {res}")
@@ -76,10 +72,8 @@ class HashQueryBuilder(QueryBuilder):
         cols: list[str],
         column_alias: str,
     ) -> exp.Expression:
-
-        cols_with_transform = self.add_transformations(
-            cols_with_alias, self.engine if self.layer == "source" else get_dialect("databricks")
-        )
+        cols_no_alias = [build_column_no_alias(this=col) for col in cols]
+        cols_with_transform = self.add_transformations(cols_no_alias, self.engine)
         col_exprs = exp.select(*cols_with_transform).iter_expressions()
         concat_expr = concat(list(col_exprs))
 
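With the local `dialect` variable gone, the hash builder relies on the reworked `QueryBuilder.engine` property (source dialect on the source layer, Databricks on the target layer) for both the WHERE filter and the final render. A reduced sketch of the select/where/sql chain, with a concrete table standing in for the `:tbl` placeholder and illustrative column names:

    from sqlglot import select

    # Filter expressed in the source dialect; parsed and re-emitted with the same dialect.
    filter_cond = "category = 'retail'"

    query = (
        select("SHA2(CONCAT(col1, col2), 256) AS hash_value_recon", "col1")
        .from_("orders")
        .where(filter_cond, dialect="snowflake")
        .sql(dialect="snowflake")
    )
    print(query)
    # roughly: SELECT SHA2(CONCAT(col1, col2), 256) AS hash_value_recon, col1 FROM orders WHERE category = 'retail'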
databricks/labs/lakebridge/reconcile/query_builder/sampling_query.py
@@ -4,6 +4,7 @@ import sqlglot.expressions as exp
 from pyspark.sql import DataFrame
 from sqlglot import select
 
+from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
 from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_key_from_dialect
 from databricks.labs.lakebridge.reconcile.query_builder.base import QueryBuilder
 from databricks.labs.lakebridge.reconcile.query_builder.expression_generator import (
@@ -37,12 +38,9 @@ class SamplingQueryBuilder(QueryBuilder):
 
         cols = sorted((join_columns | self.select_columns) - self.threshold_columns - self.drop_columns)
 
-        cols_with_alias = [
-            build_column(this=col, alias=self.table_conf.get_layer_tgt_to_src_col_mapping(col, self.layer))
-            for col in cols
-        ]
+        cols_with_alias = [self._build_column_with_alias(col) for col in cols]
 
-        query = select(*cols_with_alias).from_(":tbl").where(self.filter).sql(dialect=self.engine)
+        query = select(*cols_with_alias).from_(":tbl").where(self.filter, dialect=self.engine).sql(dialect=self.engine)
 
         logger.info(f"Sampling Query with Alias for {self.layer}: {query}")
         return query
@@ -59,22 +57,22 @@ class SamplingQueryBuilder(QueryBuilder):
 
         cols = sorted((join_columns | self.select_columns) - self.threshold_columns - self.drop_columns)
 
-        cols_with_alias = [
-            build_column(this=col, alias=self.table_conf.get_layer_tgt_to_src_col_mapping(col, self.layer))
-            for col in cols
-        ]
+        cols_with_alias = [self._build_column_with_alias(col) for col in cols]
 
         sql_with_transforms = self.add_transformations(cols_with_alias, self.engine)
-        query_sql = select(*sql_with_transforms).from_(":tbl").where(self.filter)
+        query_sql = select(*sql_with_transforms).from_(":tbl").where(self.filter, dialect=self.engine)
         if self.layer == "source":
-            with_select = [
+            with_select = [
+                build_column(this=DialectUtils.unnormalize_identifier(col), table_name="src", quoted=True)
+                for col in sorted(cols)
+            ]
         else:
             with_select = [
-                build_column(this=col, table_name="src")
+                build_column(this=DialectUtils.unnormalize_identifier(col), table_name="src", quoted=True)
                 for col in sorted(self.table_conf.get_tgt_to_src_col_mapping_list(cols))
             ]
 
-        join_clause =
+        join_clause = self._get_join_clause(key_cols)
 
         query = (
             with_clause.with_(alias="src", as_=query_sql)
@@ -86,10 +84,10 @@ class SamplingQueryBuilder(QueryBuilder):
         logger.info(f"Sampling Query for {self.layer}: {query}")
         return query
 
-
-
+    def _get_join_clause(self, key_cols: list):
+        normalized = [self._build_column_name_source_normalized(col) for col in key_cols]
         return build_join_clause(
-            "recon",
+            "recon", normalized, source_table_alias="src", target_table_alias="recon", kind="inner", func=exp.EQ
         )
 
     def _get_with_clause(self, df: DataFrame) -> exp.Select:
@@ -106,12 +104,13 @@ class SamplingQueryBuilder(QueryBuilder):
             (
                 build_literal(
                     this=str(value),
-                    alias=col,
+                    alias=DialectUtils.unnormalize_identifier(col),
                     is_string=_get_is_string(column_types_dict, col),
-                    cast=orig_types_dict.get(col),
+                    cast=orig_types_dict.get(DialectUtils.ansi_normalize_identifier(col)),
+                    quoted=True,
                 )
                 if value is not None
-                else exp.Alias(this=exp.Null(), alias=col)
+                else exp.Alias(this=exp.Null(), alias=DialectUtils.unnormalize_identifier(col), quoted=True)
             )
             for col, value in zip(df.columns, row)
         ]
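The sampling query wraps the transformed source select in a `src` CTE, builds a `recon` CTE from the sampled rows, and joins the two on the normalized key columns. A reduced sqlglot sketch of that assembly; table, column, and literal values are illustrative:

    import sqlglot.expressions as exp
    from sqlglot import select

    source_query = "SELECT order_id, amount FROM orders WHERE region = 'EU'"
    recon_rows = "SELECT 42 AS order_id"  # stand-in for the literals built from the sampled DataFrame

    join_on = exp.EQ(
        this=exp.column("order_id", table="src"),
        expression=exp.column("order_id", table="recon"),
    )

    query = (
        select(exp.column("order_id", table="src"), exp.column("amount", table="src"))
        .with_("src", as_=source_query)
        .with_("recon", as_=recon_rows)
        .from_("src")
        .join("recon", on=join_on, join_type="inner")
        .sql(dialect="databricks")
    )
    print(query)
    # WITH src AS (...), recon AS (...) SELECT src.order_id, src.amount
    # FROM src INNER JOIN recon ON src.order_id = recon.order_id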
databricks/labs/lakebridge/reconcile/query_builder/threshold_query.py
@@ -3,6 +3,7 @@ import logging
 from sqlglot import expressions as exp
 from sqlglot import select
 
+from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
 from databricks.labs.lakebridge.reconcile.query_builder.base import QueryBuilder
 from databricks.labs.lakebridge.reconcile.query_builder.expression_generator import (
     anonymous,
@@ -54,6 +55,7 @@ class ThresholdQueryBuilder(QueryBuilder):
                 left_table_name="source",
                 right_column_name=column,
                 right_table_name="databricks",
+                quoted=False,
             )
         ).transform(coalesce)
 
@@ -62,7 +64,14 @@ class ThresholdQueryBuilder(QueryBuilder):
             where_clause.append(where)
         # join columns
         for column in sorted(join_columns):
-            select_clause.append(
+            select_clause.append(
+                build_column(
+                    this=column,
+                    alias=f"{DialectUtils.unnormalize_identifier(column)}_source",
+                    table_name="source",
+                    quoted=True,
+                )
+            )
         where = build_where_clause(where_clause)
 
         return select_clause, where
@@ -76,10 +85,20 @@ class ThresholdQueryBuilder(QueryBuilder):
         select_clause = []
         column = threshold.column_name
         select_clause.append(
-            build_column(
+            build_column(
+                this=column,
+                alias=f"{DialectUtils.unnormalize_identifier(column)}_source",
+                table_name="source",
+                quoted=True,
+            ).transform(coalesce)
         )
         select_clause.append(
-            build_column(
+            build_column(
+                this=column,
+                alias=f"{DialectUtils.unnormalize_identifier(column)}_databricks",
+                table_name="databricks",
+                quoted=True,
+            ).transform(coalesce)
         )
         where_clause = exp.NEQ(this=base, expression=exp.Literal(this="0", is_string=False))
         return select_clause, where_clause
@@ -110,7 +129,13 @@ class ThresholdQueryBuilder(QueryBuilder):
             logger.error(error_message)
             raise ValueError(error_message)
 
-        select_clause.append(
+        select_clause.append(
+            build_column(
+                this=func(base=base, threshold=threshold),
+                alias=f"{DialectUtils.unnormalize_identifier(column)}_match",
+                quoted=True,
+            )
+        )
 
         return select_clause, where_clause
 
@@ -170,8 +195,8 @@ class ThresholdQueryBuilder(QueryBuilder):
             ),
             expression=exp.Is(
                 this=exp.Column(
-                    this=
-                    table=
+                    this=threshold.column_name,
+                    table="databricks",
                 ),
                 expression=exp.Null(),
             ),
@@ -211,21 +236,17 @@ class ThresholdQueryBuilder(QueryBuilder):
         self._validate(self.join_columns, "Join Columns are compulsory for threshold query")
         join_columns = self.join_columns if self.join_columns else set()
         keys: list[str] = sorted(self.partition_column.union(join_columns))
-        keys_select_alias = [
-            build_column(this=col, alias=self.table_conf.get_layer_tgt_to_src_col_mapping(col, self.layer))
-            for col in keys
-        ]
+        keys_select_alias = [self._build_column_with_alias(col) for col in keys]
         keys_expr = self._apply_user_transformation(keys_select_alias)
 
         # threshold column expression
-        threshold_alias = [
-            build_column(this=col, alias=self.table_conf.get_layer_tgt_to_src_col_mapping(col, self.layer))
-            for col in sorted(self.threshold_columns)
-        ]
+        threshold_alias = [self._build_column_with_alias(col) for col in sorted(self.threshold_columns)]
         thresholds_expr = threshold_alias
         if self.user_transformations:
            thresholds_expr = self._apply_user_transformation(threshold_alias)
 
-        query = (select(*keys_expr + thresholds_expr).from_(":tbl").where(self.filter
+        query = (select(*keys_expr + thresholds_expr).from_(":tbl").where(self.filter, dialect=self.engine)).sql(
+            dialect=self.engine
+        )
         logger.info(f"Threshold Query for {self.layer}: {query}")
         return query
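Each threshold column now yields explicitly quoted, suffixed projections (`<col>_source`, `<col>_databricks`, `<col>_match`), which the reconciliation step later filters on. A compact sketch of the alias pattern, using a simplified variant of `build_column` rather than the package's exact helper:

    import sqlglot.expressions as exp


    def build_column(this: str, table_name: str = "", quoted: bool = False, alias: str | None = None) -> exp.Expression:
        # Simplified variant of the expression_generator helper shown above.
        col = exp.Column(this=exp.Identifier(this=this, quoted=quoted), table=table_name)
        return exp.alias_(col, alias, quoted=quoted) if alias else col


    column = "unit price"
    src = build_column(column, table_name="source", alias=f"{column}_source", quoted=True)
    dbx = build_column(column, table_name="databricks", alias=f"{column}_databricks", quoted=True)

    print(src.sql(dialect="databricks"))  # source.`unit price` AS `unit price_source`
    print(dbx.sql(dialect="databricks"))  # databricks.`unit price` AS `unit price_databricks`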
databricks/labs/lakebridge/reconcile/recon_config.py
@@ -257,21 +257,6 @@ class Table:
             return set()
         return {self.get_layer_src_to_tgt_col_mapping(col, layer) for col in self.drop_columns}
 
-    def get_transformation_dict(self, layer: str) -> dict[str, str]:
-        if self.transformations:
-            if layer == "source":
-                return {
-                    trans.column_name: (trans.source if trans.source else trans.column_name)
-                    for trans in self.transformations
-                }
-            return {
-                self.get_layer_src_to_tgt_col_mapping(trans.column_name, layer): (
-                    trans.target if trans.target else self.get_layer_src_to_tgt_col_mapping(trans.column_name, layer)
-                )
-                for trans in self.transformations
-            }
-        return {}
-
     def get_partition_column(self, layer: str) -> set[str]:
         if self.jdbc_reader_options and layer == "source":
             if self.jdbc_reader_options.partition_column:
databricks/labs/lakebridge/reconcile/reconciliation.py
@@ -15,6 +15,7 @@ from databricks.labs.lakebridge.reconcile.compare import (
     reconcile_agg_data_per_rule,
 )
 from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource
+from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
 from databricks.labs.lakebridge.reconcile.exception import (
     DataSourceRuntimeException,
 )
@@ -455,7 +456,9 @@ class Reconciliation:
             options=table_conf.jdbc_reader_options,
         )
         threshold_columns = table_conf.get_threshold_columns("source")
-        failed_where_cond = " OR ".join(
+        failed_where_cond = " OR ".join(
+            ["`" + DialectUtils.unnormalize_identifier(name) + "_match` = 'Failed'" for name in threshold_columns]
+        )
         mismatched_df = threshold_result.filter(failed_where_cond)
         mismatched_count = mismatched_df.count()
         threshold_df = None
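The backtick-quoted `_match` aliases emitted by the threshold query are what the new failure condition targets; quoting is needed because the unnormalized name may contain spaces or mixed case. A short PySpark sketch of the condition and filter, with illustrative column names and rows:

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    threshold_columns = ["amount", "unit price"]
    failed_where_cond = " OR ".join(f"`{name}_match` = 'Failed'" for name in threshold_columns)
    # `amount_match` = 'Failed' OR `unit price_match` = 'Failed'

    threshold_result = spark.createDataFrame(
        [("Match", "Failed"), ("Match", "Match")],
        ["amount_match", "unit price_match"],
    )
    mismatched_count = threshold_result.filter(failed_where_cond).count()  # 1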