databricks-labs-lakebridge 0.10.7__py3-none-any.whl → 0.10.8__py3-none-any.whl

This diff shows the contents of the two publicly released package versions as published to their public registry. It is provided for informational purposes only and reflects the changes between those versions.
Files changed (30)
  1. databricks/labs/lakebridge/__about__.py +1 -1
  2. databricks/labs/lakebridge/assessments/profiler_validator.py +103 -0
  3. databricks/labs/lakebridge/base_install.py +1 -5
  4. databricks/labs/lakebridge/cli.py +13 -6
  5. databricks/labs/lakebridge/helpers/validation.py +5 -3
  6. databricks/labs/lakebridge/install.py +40 -481
  7. databricks/labs/lakebridge/reconcile/connectors/data_source.py +9 -5
  8. databricks/labs/lakebridge/reconcile/connectors/databricks.py +2 -1
  9. databricks/labs/lakebridge/reconcile/connectors/oracle.py +2 -1
  10. databricks/labs/lakebridge/reconcile/connectors/secrets.py +19 -1
  11. databricks/labs/lakebridge/reconcile/connectors/snowflake.py +50 -29
  12. databricks/labs/lakebridge/reconcile/connectors/tsql.py +2 -1
  13. databricks/labs/lakebridge/reconcile/query_builder/base.py +50 -11
  14. databricks/labs/lakebridge/reconcile/query_builder/expression_generator.py +8 -2
  15. databricks/labs/lakebridge/reconcile/query_builder/hash_query.py +7 -13
  16. databricks/labs/lakebridge/reconcile/query_builder/sampling_query.py +18 -19
  17. databricks/labs/lakebridge/reconcile/query_builder/threshold_query.py +36 -15
  18. databricks/labs/lakebridge/reconcile/recon_config.py +0 -15
  19. databricks/labs/lakebridge/reconcile/reconciliation.py +4 -1
  20. databricks/labs/lakebridge/reconcile/trigger_recon_aggregate_service.py +11 -31
  21. databricks/labs/lakebridge/reconcile/trigger_recon_service.py +4 -1
  22. databricks/labs/lakebridge/transpiler/execute.py +34 -28
  23. databricks/labs/lakebridge/transpiler/installers.py +523 -0
  24. databricks/labs/lakebridge/transpiler/lsp/lsp_engine.py +2 -0
  25. {databricks_labs_lakebridge-0.10.7.dist-info → databricks_labs_lakebridge-0.10.8.dist-info}/METADATA +1 -1
  26. {databricks_labs_lakebridge-0.10.7.dist-info → databricks_labs_lakebridge-0.10.8.dist-info}/RECORD +30 -28
  27. {databricks_labs_lakebridge-0.10.7.dist-info → databricks_labs_lakebridge-0.10.8.dist-info}/WHEEL +0 -0
  28. {databricks_labs_lakebridge-0.10.7.dist-info → databricks_labs_lakebridge-0.10.8.dist-info}/entry_points.txt +0 -0
  29. {databricks_labs_lakebridge-0.10.7.dist-info → databricks_labs_lakebridge-0.10.8.dist-info}/licenses/LICENSE +0 -0
  30. {databricks_labs_lakebridge-0.10.7.dist-info → databricks_labs_lakebridge-0.10.8.dist-info}/licenses/NOTICE +0 -0
databricks/labs/lakebridge/reconcile/connectors/secrets.py
@@ -11,8 +11,26 @@ class SecretsMixin:
     _ws: WorkspaceClient
     _secret_scope: str
 
+    def _get_secret_or_none(self, secret_key: str) -> str | None:
+        """
+        Get the secret value given a secret scope & secret key. Log a warning if secret does not exist
+        Used To ensure backwards compatibility when supporting new secrets
+        """
+        try:
+            # Return the decoded secret value in string format
+            return self._get_secret(secret_key)
+        except NotFound as e:
+            logger.warning(f"Secret not found: key={secret_key}")
+            logger.debug("Secret lookup failed", exc_info=e)
+            return None
+
     def _get_secret(self, secret_key: str) -> str:
-        """Get the secret value given a secret scope & secret key. Log a warning if secret does not exist"""
+        """Get the secret value given a secret scope & secret key.
+
+        Raises:
+            NotFound: The secret could not be found.
+            UnicodeDecodeError: The secret value was not Base64-encoded UTF-8.
+        """
         try:
             # Return the decoded secret value in string format
             secret = self._ws.secrets.get_secret(self._secret_scope, secret_key)
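Note on the hunk above: the new _get_secret_or_none helper lets newly introduced secrets (for example the pem_private_key_password used by the Snowflake connector below) be treated as optional, degrading to None instead of failing when older workspaces never created them. A minimal standalone sketch of the same lookup pattern, assuming a databricks-sdk WorkspaceClient; the scope and key names are placeholders, not values from the package:

import base64
import logging

from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import NotFound

logger = logging.getLogger(__name__)


def get_secret_or_none(ws: WorkspaceClient, scope: str, key: str) -> str | None:
    """Return the decoded secret value, or None when the key is absent."""
    try:
        # get_secret returns the value Base64-encoded; decode it to UTF-8 text.
        raw = ws.secrets.get_secret(scope, key).value
        return base64.b64decode(raw).decode("utf-8")
    except NotFound:
        logger.warning("Secret not found: key=%s", key)
        return None


# Example: an optional passphrase secret that older setups may never have created.
# passphrase = get_secret_or_none(WorkspaceClient(), "my_scope", "pem_private_key_password")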
databricks/labs/lakebridge/reconcile/connectors/snowflake.py
@@ -79,31 +79,6 @@ class SnowflakeDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
             f"&warehouse={self._get_secret('sfWarehouse')}&role={self._get_secret('sfRole')}"
         )
 
-    @staticmethod
-    def get_private_key(pem_private_key: str) -> str:
-        try:
-            private_key_bytes = pem_private_key.encode("UTF-8")
-            p_key = serialization.load_pem_private_key(
-                private_key_bytes,
-                password=None,
-                backend=default_backend(),
-            )
-            pkb = p_key.private_bytes(
-                encoding=serialization.Encoding.PEM,
-                format=serialization.PrivateFormat.PKCS8,
-                encryption_algorithm=serialization.NoEncryption(),
-            )
-            pkb_str = pkb.decode("UTF-8")
-            # Remove the first and last lines (BEGIN/END markers)
-            private_key_pem_lines = pkb_str.strip().split('\n')[1:-1]
-            # Join the lines to form the base64 encoded string
-            private_key_pem_str = ''.join(private_key_pem_lines)
-            return private_key_pem_str
-        except Exception as e:
-            message = f"Failed to load or process the provided PEM private key. --> {e}"
-            logger.error(message)
-            raise InvalidSnowflakePemPrivateKey(message) from e
-
     def read_data(
         self,
         catalog: str | None,
@@ -132,6 +107,7 @@ class SnowflakeDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
         catalog: str | None,
         schema: str,
         table: str,
+        normalize: bool = True,
     ) -> list[Schema]:
         """
         Fetch the Schema from the INFORMATION_SCHEMA.COLUMNS table in Snowflake.
@@ -151,11 +127,17 @@ class SnowflakeDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
             df = self.reader(schema_query).load()
             schema_metadata = df.select([col(c).alias(c.lower()) for c in df.columns]).collect()
             logger.info(f"Schema fetched successfully. Completed at: {datetime.now()}")
-            return [self._map_meta_column(field) for field in schema_metadata]
+            return [self._map_meta_column(field, normalize) for field in schema_metadata]
         except (RuntimeError, PySparkException) as e:
             return self.log_and_throw_exception(e, "schema", schema_query)
 
     def reader(self, query: str) -> DataFrameReader:
+        options = self._get_snowflake_options()
+        return self._spark.read.format("snowflake").option("dbtable", f"({query}) as tmp").options(**options)
+
+    # TODO cache this method using @functools.cache
+    # Pay attention to https://pylint.pycqa.org/en/latest/user_guide/messages/warning/method-cache-max-size-none.html
+    def _get_snowflake_options(self):
         options = {
             "sfUrl": self._get_secret('sfUrl'),
             "sfUser": self._get_secret('sfUser'),
@@ -164,18 +146,57 @@ class SnowflakeDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
             "sfWarehouse": self._get_secret('sfWarehouse'),
             "sfRole": self._get_secret('sfRole'),
         }
+        options = options | self._get_snowflake_auth_options()
+
+        return options
+
+    def _get_snowflake_auth_options(self):
         try:
-            options["pem_private_key"] = SnowflakeDataSource.get_private_key(self._get_secret('pem_private_key'))
+            key = SnowflakeDataSource._get_private_key(
+                self._get_secret('pem_private_key'), self._get_secret_or_none('pem_private_key_password')
+            )
+            return {"pem_private_key": key}
         except (NotFound, KeyError):
             logger.warning("pem_private_key not found. Checking for sfPassword")
             try:
-                options["sfPassword"] = self._get_secret('sfPassword')
+                password = self._get_secret('sfPassword')
+                return {"sfPassword": password}
             except (NotFound, KeyError) as e:
                 message = "sfPassword and pem_private_key not found. Either one is required for snowflake auth."
                 logger.error(message)
                 raise NotFound(message) from e
 
-        return self._spark.read.format("snowflake").option("dbtable", f"({query}) as tmp").options(**options)
+    @staticmethod
+    def _get_private_key(pem_private_key: str, pem_private_key_password: str | None) -> str:
+        try:
+            private_key_bytes = pem_private_key.encode("UTF-8")
+            password_bytes = pem_private_key_password.encode("UTF-8") if pem_private_key_password else None
+        except UnicodeEncodeError as e:
+            message = f"Invalid pem key and/or pem password: unable to encode. --> {e}"
+            logger.error(message)
+            raise ValueError(message) from e
+
+        try:
+            p_key = serialization.load_pem_private_key(
+                private_key_bytes,
+                password_bytes,
+                backend=default_backend(),
+            )
+            pkb = p_key.private_bytes(
+                encoding=serialization.Encoding.PEM,
+                format=serialization.PrivateFormat.PKCS8,
+                encryption_algorithm=serialization.NoEncryption(),
+            )
+            pkb_str = pkb.decode("UTF-8")
+            # Remove the first and last lines (BEGIN/END markers)
+            private_key_pem_lines = pkb_str.strip().split('\n')[1:-1]
+            # Join the lines to form the base64 encoded string
+            private_key_pem_str = ''.join(private_key_pem_lines)
+            return private_key_pem_str
+        except Exception as e:
+            message = f"Failed to load or process the provided PEM private key. --> {e}"
+            logger.error(message)
+            raise InvalidSnowflakePemPrivateKey(message) from e
 
     def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
         return DialectUtils.normalize_identifier(
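Note on the hunks above: the public get_private_key was replaced by the private _get_private_key, which now accepts an optional passphrase so encrypted PEM keys can be used, and credential resolution moved into _get_snowflake_auth_options (key pair first, sfPassword as fallback). A self-contained sketch of the key handling using the cryptography package; the generated key and passphrase below are for illustration only, not values from the package:

from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa

# Generate an encrypted PKCS#8 PEM key purely for demonstration.
passphrase = b"example-passphrase"
pem = rsa.generate_private_key(public_exponent=65537, key_size=2048).private_bytes(
    encoding=serialization.Encoding.PEM,
    format=serialization.PrivateFormat.PKCS8,
    encryption_algorithm=serialization.BestAvailableEncryption(passphrase),
)

# Decrypt, re-serialize without encryption, and strip the BEGIN/END markers so the
# result is the single-line base64 body expected by Snowflake's pem_private_key option.
p_key = serialization.load_pem_private_key(pem, password=passphrase)
pkcs8 = p_key.private_bytes(
    encoding=serialization.Encoding.PEM,
    format=serialization.PrivateFormat.PKCS8,
    encryption_algorithm=serialization.NoEncryption(),
).decode("utf-8")
single_line_key = "".join(pkcs8.strip().split("\n")[1:-1])
print(single_line_key[:40], "...")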
databricks/labs/lakebridge/reconcile/connectors/tsql.py
@@ -109,6 +109,7 @@ class TSQLServerDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
         catalog: str | None,
         schema: str,
         table: str,
+        normalize: bool = True,
     ) -> list[Schema]:
         """
         Fetch the Schema from the INFORMATION_SCHEMA.COLUMNS table in SQL Server.
@@ -128,7 +129,7 @@ class TSQLServerDataSource(DataSource, SecretsMixin, JDBCReaderMixin):
             df = self.reader(schema_query).load()
             schema_metadata = df.select([col(c).alias(c.lower()) for c in df.columns]).collect()
             logger.info(f"Schema fetched successfully. Completed at: {datetime.now()}")
-            return [self._map_meta_column(field) for field in schema_metadata]
+            return [self._map_meta_column(field, normalize) for field in schema_metadata]
         except (RuntimeError, PySparkException) as e:
             return self.log_and_throw_exception(e, "schema", schema_query)
 
databricks/labs/lakebridge/reconcile/query_builder/base.py
@@ -5,10 +5,12 @@ import sqlglot.expressions as exp
 from sqlglot import Dialect, parse_one
 
 from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource
+from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
 from databricks.labs.lakebridge.reconcile.exception import InvalidInputException
 from databricks.labs.lakebridge.reconcile.query_builder.expression_generator import (
     DataType_transform_mapping,
     transform_expression,
+    build_column,
 )
 from databricks.labs.lakebridge.reconcile.recon_config import Schema, Table, Aggregate
 from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect, SQLGLOT_DIALECTS
@@ -26,7 +28,7 @@ class QueryBuilder(ABC):
 
     @property
     def engine(self) -> Dialect:
-        return self._engine
+        return self._engine if self.layer == "source" else get_dialect("databricks")
 
     @property
     def layer(self) -> str:
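Note on the hunk above: the engine property now resolves to the configured source dialect only for the source layer and always to the Databricks dialect for the target layer. A small sqlglot illustration of why the rendering dialect matters; the query text is an invented example, not taken from the package:

from sqlglot import parse_one

# The same parsed expression renders differently per dialect: Snowflake keeps
# double-quoted identifiers and IFF, Databricks uses backticks and IF.
expr = parse_one('SELECT "Order Id", IFF(qty > 0, 1, 0) FROM orders', read="snowflake")
print(expr.sql(dialect="snowflake"))
print(expr.sql(dialect="databricks"))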
@@ -66,7 +68,25 @@ class QueryBuilder(ABC):
 
     @property
     def user_transformations(self) -> dict[str, str]:
-        return self._table_conf.get_transformation_dict(self._layer)
+        if self._table_conf.transformations:
+            if self.layer == "source":
+                return {
+                    trans.column_name: (
+                        trans.source
+                        if trans.source
+                        else self._data_source.normalize_identifier(trans.column_name).source_normalized
+                    )
+                    for trans in self._table_conf.transformations
+                }
+            return {
+                self._table_conf.get_layer_src_to_tgt_col_mapping(trans.column_name, self.layer): (
+                    trans.target
+                    if trans.target
+                    else self._table_conf.get_layer_src_to_tgt_col_mapping(trans.column_name, self.layer)
+                )
+                for trans in self._table_conf.transformations
+            }
+        return {}
 
     @property
     def aggregates(self) -> list[Aggregate] | None:
@@ -89,10 +109,12 @@ class QueryBuilder(ABC):
 
     def _user_transformer(self, node: exp.Expression, user_transformations: dict[str, str]) -> exp.Expression:
         if isinstance(node, exp.Column) and user_transformations:
-            dialect = self.engine if self.layer == "source" else get_dialect("databricks")
-            column_name = node.name
-            if column_name in user_transformations.keys():
-                return parse_one(user_transformations.get(column_name, column_name), read=dialect)
+            normalized_column = self._data_source.normalize_identifier(node.name)
+            ansi_name = normalized_column.ansi_normalized
+            if ansi_name in user_transformations.keys():
+                return parse_one(
+                    user_transformations.get(ansi_name, normalized_column.source_normalized), read=self.engine
+                )
         return node
 
     def _apply_default_transformation(
@@ -103,8 +125,7 @@ class QueryBuilder(ABC):
             with_transform.append(alias.transform(self._default_transformer, schema, source))
         return with_transform
 
-    @staticmethod
-    def _default_transformer(node: exp.Expression, schema: list[Schema], source: Dialect) -> exp.Expression:
+    def _default_transformer(self, node: exp.Expression, schema: list[Schema], source: Dialect) -> exp.Expression:
 
         def _get_transform(datatype: str):
             source_dialects = [source_key for source_key, dialect in SQLGLOT_DIALECTS.items() if dialect == source]
@@ -121,9 +142,10 @@ class QueryBuilder(ABC):
 
         schema_dict = {v.column_name: v.data_type for v in schema}
         if isinstance(node, exp.Column):
-            column_name = node.name
-            if column_name in schema_dict.keys():
-                transform = _get_transform(schema_dict.get(column_name, column_name))
+            normalized_column = self._data_source.normalize_identifier(node.name)
+            ansi_name = normalized_column.ansi_normalized
+            if ansi_name in schema_dict.keys():
+                transform = _get_transform(schema_dict.get(ansi_name, normalized_column.source_normalized))
                 return transform_expression(node, transform)
         return node
 
@@ -132,3 +154,20 @@ class QueryBuilder(ABC):
         message = f"Exception for {self.table_conf.target_name} target table in {self.layer} layer --> {message}"
         logger.error(message)
         raise InvalidInputException(message)
+
+    def _build_column_with_alias(self, column: str):
+        return build_column(
+            this=self._build_column_name_source_normalized(column),
+            alias=DialectUtils.unnormalize_identifier(
+                self.table_conf.get_layer_tgt_to_src_col_mapping(column, self.layer)
+            ),
+            quoted=True,
+        )
+
+    def _build_column_name_source_normalized(self, column: str):
+        return self._data_source.normalize_identifier(column).source_normalized
+
+    def _build_alias_source_normalized(self, column: str):
+        return self._data_source.normalize_identifier(
+            self.table_conf.get_layer_tgt_to_src_col_mapping(column, self.layer)
+        ).source_normalized
databricks/labs/lakebridge/reconcile/query_builder/expression_generator.py
@@ -125,6 +125,7 @@ def anonymous(expr: exp.Column, func: str, is_expr: bool = False, dialect=None)
     return new_expr
 
 
+# TODO Standardize impl and use quoted and Identifier/Column consistently
 def build_column(this: exp.ExpOrStr, table_name="", quoted=False, alias=None) -> exp.Expression:
     if alias:
         if isinstance(this, str):
@@ -135,6 +136,10 @@ def build_column(this: exp.ExpOrStr, table_name="", quoted=False, alias=None) ->
     return exp.Column(this=exp.Identifier(this=this, quoted=quoted), table=table_name)
 
 
+def build_column_no_alias(this: str, table_name="") -> exp.Expression:
+    return exp.Column(this=this, table=table_name)
+
+
 def build_literal(this: exp.ExpOrStr, alias=None, quoted=False, is_string=True, cast=None) -> exp.Expression:
     base_literal = exp.Literal(this=this, is_string=is_string)
     if not cast and not alias:
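Note on the hunk above: build_column_no_alias complements build_column for call sites that want a bare Column without an Identifier wrapper. For the quoted=True call sites introduced elsewhere in this release, sqlglot only escapes an identifier when its Identifier node is marked quoted; a tiny sketch with an invented column name:

import sqlglot.expressions as exp

unquoted = exp.Column(this=exp.Identifier(this="Order Id", quoted=False))
quoted = exp.Column(this=exp.Identifier(this="Order Id", quoted=True))
print(unquoted.sql(dialect="databricks"))  # Order Id  (not escaped, invalid with a space)
print(quoted.sql(dialect="databricks"))    # `Order Id`

aliased = exp.Alias(this=quoted, alias=exp.Identifier(this="order id", quoted=True))
print(aliased.sql(dialect="databricks"))   # `Order Id` AS `order id`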
@@ -207,10 +212,11 @@ def build_sub(
     right_column_name: str,
     left_table_name: str | None = None,
     right_table_name: str | None = None,
+    quoted: bool = False,
 ) -> exp.Sub:
     return exp.Sub(
-        this=build_column(left_column_name, left_table_name),
-        expression=build_column(right_column_name, right_table_name),
+        this=build_column(left_column_name, left_table_name, quoted=quoted),
+        expression=build_column(right_column_name, right_table_name, quoted=quoted),
     )
 
 
databricks/labs/lakebridge/reconcile/query_builder/hash_query.py
@@ -11,8 +11,8 @@ from databricks.labs.lakebridge.reconcile.query_builder.expression_generator imp
     get_hash_transform,
     lower,
     transform_expression,
+    build_column_no_alias,
 )
-from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect
 
 logger = logging.getLogger(__name__)
 
@@ -41,15 +41,12 @@ class HashQueryBuilder(QueryBuilder):
 
         key_cols = hash_cols if report_type == "row" else sorted(_join_columns | self.partition_column)
 
-        cols_with_alias = [
-            build_column(this=col, alias=self.table_conf.get_layer_tgt_to_src_col_mapping(col, self.layer))
-            for col in key_cols
-        ]
+        cols_with_alias = [self._build_column_with_alias(col) for col in key_cols]
 
         # in case if we have column mapping, we need to sort the target columns in the order of source columns to get
         # same hash value
         hash_cols_with_alias = [
-            {"this": col, "alias": self.table_conf.get_layer_tgt_to_src_col_mapping(col, self.layer)}
+            {"this": self._build_column_name_source_normalized(col), "alias": self._build_alias_source_normalized(col)}
             for col in hash_cols
         ]
         sorted_hash_cols_with_alias = sorted(hash_cols_with_alias, key=lambda column: column["alias"])
@@ -60,12 +57,11 @@ class HashQueryBuilder(QueryBuilder):
         )
         hash_col_with_transform = [self._generate_hash_algorithm(hashcols_sorted_as_src_seq, _HASH_COLUMN_NAME)]
 
-        dialect = self.engine if self.layer == "source" else get_dialect("databricks")
         res = (
             exp.select(*hash_col_with_transform + key_cols_with_transform)
             .from_(":tbl")
-            .where(self.filter)
-            .sql(dialect=dialect)
+            .where(self.filter, dialect=self.engine)
+            .sql(dialect=self.engine)
         )
 
         logger.info(f"Hash Query for {self.layer}: {res}")
@@ -76,10 +72,8 @@ class HashQueryBuilder(QueryBuilder):
         cols: list[str],
         column_alias: str,
     ) -> exp.Expression:
-        cols_with_alias = [build_column(this=col, alias=None) for col in cols]
-        cols_with_transform = self.add_transformations(
-            cols_with_alias, self.engine if self.layer == "source" else get_dialect("databricks")
-        )
+        cols_no_alias = [build_column_no_alias(this=col) for col in cols]
+        cols_with_transform = self.add_transformations(cols_no_alias, self.engine)
         col_exprs = exp.select(*cols_with_transform).iter_expressions()
         concat_expr = concat(list(col_exprs))
 
databricks/labs/lakebridge/reconcile/query_builder/sampling_query.py
@@ -4,6 +4,7 @@ import sqlglot.expressions as exp
 from pyspark.sql import DataFrame
 from sqlglot import select
 
+from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
 from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_key_from_dialect
 from databricks.labs.lakebridge.reconcile.query_builder.base import QueryBuilder
 from databricks.labs.lakebridge.reconcile.query_builder.expression_generator import (
@@ -37,12 +38,9 @@ class SamplingQueryBuilder(QueryBuilder):
 
         cols = sorted((join_columns | self.select_columns) - self.threshold_columns - self.drop_columns)
 
-        cols_with_alias = [
-            build_column(this=col, alias=self.table_conf.get_layer_tgt_to_src_col_mapping(col, self.layer))
-            for col in cols
-        ]
+        cols_with_alias = [self._build_column_with_alias(col) for col in cols]
 
-        query = select(*cols_with_alias).from_(":tbl").where(self.filter).sql(dialect=self.engine)
+        query = select(*cols_with_alias).from_(":tbl").where(self.filter, dialect=self.engine).sql(dialect=self.engine)
 
         logger.info(f"Sampling Query with Alias for {self.layer}: {query}")
         return query
@@ -59,22 +57,22 @@ class SamplingQueryBuilder(QueryBuilder):
 
         cols = sorted((join_columns | self.select_columns) - self.threshold_columns - self.drop_columns)
 
-        cols_with_alias = [
-            build_column(this=col, alias=self.table_conf.get_layer_tgt_to_src_col_mapping(col, self.layer))
-            for col in cols
-        ]
+        cols_with_alias = [self._build_column_with_alias(col) for col in cols]
 
         sql_with_transforms = self.add_transformations(cols_with_alias, self.engine)
-        query_sql = select(*sql_with_transforms).from_(":tbl").where(self.filter)
+        query_sql = select(*sql_with_transforms).from_(":tbl").where(self.filter, dialect=self.engine)
         if self.layer == "source":
-            with_select = [build_column(this=col, table_name="src") for col in sorted(cols)]
+            with_select = [
+                build_column(this=DialectUtils.unnormalize_identifier(col), table_name="src", quoted=True)
+                for col in sorted(cols)
+            ]
         else:
             with_select = [
-                build_column(this=col, table_name="src")
+                build_column(this=DialectUtils.unnormalize_identifier(col), table_name="src", quoted=True)
                 for col in sorted(self.table_conf.get_tgt_to_src_col_mapping_list(cols))
             ]
 
-        join_clause = SamplingQueryBuilder._get_join_clause(key_cols)
+        join_clause = self._get_join_clause(key_cols)
 
         query = (
             with_clause.with_(alias="src", as_=query_sql)
@@ -86,10 +84,10 @@ class SamplingQueryBuilder(QueryBuilder):
         logger.info(f"Sampling Query for {self.layer}: {query}")
         return query
 
-    @classmethod
-    def _get_join_clause(cls, key_cols: list):
+    def _get_join_clause(self, key_cols: list):
+        normalized = [self._build_column_name_source_normalized(col) for col in key_cols]
         return build_join_clause(
-            "recon", key_cols, source_table_alias="src", target_table_alias="recon", kind="inner", func=exp.EQ
+            "recon", normalized, source_table_alias="src", target_table_alias="recon", kind="inner", func=exp.EQ
         )
 
     def _get_with_clause(self, df: DataFrame) -> exp.Select:
@@ -106,12 +104,13 @@ class SamplingQueryBuilder(QueryBuilder):
             (
                 build_literal(
                     this=str(value),
-                    alias=col,
+                    alias=DialectUtils.unnormalize_identifier(col),
                     is_string=_get_is_string(column_types_dict, col),
-                    cast=orig_types_dict.get(col),
+                    cast=orig_types_dict.get(DialectUtils.ansi_normalize_identifier(col)),
+                    quoted=True,
                 )
                 if value is not None
-                else exp.Alias(this=exp.Null(), alias=col)
+                else exp.Alias(this=exp.Null(), alias=DialectUtils.unnormalize_identifier(col), quoted=True)
             )
             for col, value in zip(df.columns, row)
         ]
databricks/labs/lakebridge/reconcile/query_builder/threshold_query.py
@@ -3,6 +3,7 @@ import logging
 from sqlglot import expressions as exp
 from sqlglot import select
 
+from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
 from databricks.labs.lakebridge.reconcile.query_builder.base import QueryBuilder
 from databricks.labs.lakebridge.reconcile.query_builder.expression_generator import (
     anonymous,
@@ -54,6 +55,7 @@ class ThresholdQueryBuilder(QueryBuilder):
                 left_table_name="source",
                 right_column_name=column,
                 right_table_name="databricks",
+                quoted=False,
             )
         ).transform(coalesce)
 
@@ -62,7 +64,14 @@ class ThresholdQueryBuilder(QueryBuilder):
             where_clause.append(where)
         # join columns
         for column in sorted(join_columns):
-            select_clause.append(build_column(this=column, alias=f"{column}_source", table_name="source"))
+            select_clause.append(
+                build_column(
+                    this=column,
+                    alias=f"{DialectUtils.unnormalize_identifier(column)}_source",
+                    table_name="source",
+                    quoted=True,
+                )
+            )
         where = build_where_clause(where_clause)
 
         return select_clause, where
@@ -76,10 +85,20 @@ class ThresholdQueryBuilder(QueryBuilder):
         select_clause = []
         column = threshold.column_name
         select_clause.append(
-            build_column(this=column, alias=f"{column}_source", table_name="source").transform(coalesce)
+            build_column(
+                this=column,
+                alias=f"{DialectUtils.unnormalize_identifier(column)}_source",
+                table_name="source",
+                quoted=True,
+            ).transform(coalesce)
         )
         select_clause.append(
-            build_column(this=column, alias=f"{column}_databricks", table_name="databricks").transform(coalesce)
+            build_column(
+                this=column,
+                alias=f"{DialectUtils.unnormalize_identifier(column)}_databricks",
+                table_name="databricks",
+                quoted=True,
+            ).transform(coalesce)
         )
         where_clause = exp.NEQ(this=base, expression=exp.Literal(this="0", is_string=False))
         return select_clause, where_clause
@@ -110,7 +129,13 @@ class ThresholdQueryBuilder(QueryBuilder):
             logger.error(error_message)
             raise ValueError(error_message)
 
-        select_clause.append(build_column(this=func(base=base, threshold=threshold), alias=f"{column}_match"))
+        select_clause.append(
+            build_column(
+                this=func(base=base, threshold=threshold),
+                alias=f"{DialectUtils.unnormalize_identifier(column)}_match",
+                quoted=True,
+            )
+        )
 
         return select_clause, where_clause
 
@@ -170,8 +195,8 @@ class ThresholdQueryBuilder(QueryBuilder):
                 ),
                 expression=exp.Is(
                     this=exp.Column(
-                        this=exp.Identifier(this=threshold.column_name, quoted=False),
-                        table=exp.Identifier(this='databricks'),
+                        this=threshold.column_name,
+                        table="databricks",
                     ),
                     expression=exp.Null(),
                 ),
@@ -211,21 +236,17 @@ class ThresholdQueryBuilder(QueryBuilder):
         self._validate(self.join_columns, "Join Columns are compulsory for threshold query")
         join_columns = self.join_columns if self.join_columns else set()
         keys: list[str] = sorted(self.partition_column.union(join_columns))
-        keys_select_alias = [
-            build_column(this=col, alias=self.table_conf.get_layer_tgt_to_src_col_mapping(col, self.layer))
-            for col in keys
-        ]
+        keys_select_alias = [self._build_column_with_alias(col) for col in keys]
         keys_expr = self._apply_user_transformation(keys_select_alias)
 
         # threshold column expression
-        threshold_alias = [
-            build_column(this=col, alias=self.table_conf.get_layer_tgt_to_src_col_mapping(col, self.layer))
-            for col in sorted(self.threshold_columns)
-        ]
+        threshold_alias = [self._build_column_with_alias(col) for col in sorted(self.threshold_columns)]
         thresholds_expr = threshold_alias
         if self.user_transformations:
            thresholds_expr = self._apply_user_transformation(threshold_alias)
 
-        query = (select(*keys_expr + thresholds_expr).from_(":tbl").where(self.filter)).sql(dialect=self.engine)
+        query = (select(*keys_expr + thresholds_expr).from_(":tbl").where(self.filter, dialect=self.engine)).sql(
+            dialect=self.engine
+        )
         logger.info(f"Threshold Query for {self.layer}: {query}")
         return query
databricks/labs/lakebridge/reconcile/recon_config.py
@@ -257,21 +257,6 @@ class Table:
             return set()
         return {self.get_layer_src_to_tgt_col_mapping(col, layer) for col in self.drop_columns}
 
-    def get_transformation_dict(self, layer: str) -> dict[str, str]:
-        if self.transformations:
-            if layer == "source":
-                return {
-                    trans.column_name: (trans.source if trans.source else trans.column_name)
-                    for trans in self.transformations
-                }
-            return {
-                self.get_layer_src_to_tgt_col_mapping(trans.column_name, layer): (
-                    trans.target if trans.target else self.get_layer_src_to_tgt_col_mapping(trans.column_name, layer)
-                )
-                for trans in self.transformations
-            }
-        return {}
-
     def get_partition_column(self, layer: str) -> set[str]:
         if self.jdbc_reader_options and layer == "source":
             if self.jdbc_reader_options.partition_column:
databricks/labs/lakebridge/reconcile/reconciliation.py
@@ -15,6 +15,7 @@ from databricks.labs.lakebridge.reconcile.compare import (
     reconcile_agg_data_per_rule,
 )
 from databricks.labs.lakebridge.reconcile.connectors.data_source import DataSource
+from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
 from databricks.labs.lakebridge.reconcile.exception import (
     DataSourceRuntimeException,
 )
@@ -455,7 +456,9 @@ class Reconciliation:
             options=table_conf.jdbc_reader_options,
         )
         threshold_columns = table_conf.get_threshold_columns("source")
-        failed_where_cond = " OR ".join([name + "_match = 'Failed'" for name in threshold_columns])
+        failed_where_cond = " OR ".join(
+            ["`" + DialectUtils.unnormalize_identifier(name) + "_match` = 'Failed'" for name in threshold_columns]
+        )
         mismatched_df = threshold_result.filter(failed_where_cond)
         mismatched_count = mismatched_df.count()
         threshold_df = None
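Note on the hunk above: the failed-threshold filter now backtick-quotes each <column>_match name so identifiers that need escaping survive the Spark filter expression. A simplified sketch of the string it builds, with invented column names and without the DialectUtils.unnormalize_identifier step the real code applies to each name:

threshold_columns = ["amount", "Order Total"]
failed_where_cond = " OR ".join(
    "`" + name + "_match` = 'Failed'" for name in threshold_columns
)
print(failed_where_cond)
# `amount_match` = 'Failed' OR `Order Total_match` = 'Failed'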