acryl-datahub 1.0.0.4rc7__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of acryl-datahub might be problematic. Click here for more details.

@@ -15,10 +15,10 @@ import pathlib
15
15
def _load_schema(schema_name: str) -> str:
    """Read the Avro schema file ``<schema_name>.avsc`` that ships next to this module."""
    schema_path = pathlib.Path(__file__).parent / f"{schema_name}.avsc"
    return schema_path.read_text()
17
17
 
18
- def getMetadataChangeEventSchema() -> str:
19
- return _load_schema("MetadataChangeEvent")
20
-
21
18
def getMetadataChangeProposalSchema() -> str:
    """Return the MetadataChangeProposal Avro schema as a string."""
    return _load_schema("MetadataChangeProposal")
23
20
 
21
def getMetadataChangeEventSchema() -> str:
    """Return the MetadataChangeEvent Avro schema as a string."""
    return _load_schema("MetadataChangeEvent")
23
+
24
24
  # fmt: on
@@ -1366,6 +1366,13 @@ class SqlParsingAggregator(Closeable):
1366
1366
  ):
1367
1367
  upstream_columns = [x[0] for x in upstream_columns_for_query]
1368
1368
  required_queries.add(query_id)
1369
+ query = queries_map[query_id]
1370
+
1371
+ column_logic = None
1372
+ for lineage_info in query.column_lineage:
1373
+ if lineage_info.downstream.column == downstream_column:
1374
+ column_logic = lineage_info.logic
1375
+ break
1369
1376
 
1370
1377
  upstream_aspect.fineGrainedLineages.append(
1371
1378
  models.FineGrainedLineageClass(
@@ -1383,7 +1390,16 @@ class SqlParsingAggregator(Closeable):
1383
1390
  if self.can_generate_query(query_id)
1384
1391
  else None
1385
1392
  ),
1386
- confidenceScore=queries_map[query_id].confidence_score,
1393
+ confidenceScore=query.confidence_score,
1394
+ transformOperation=(
1395
+ (
1396
+ f"COPY: {column_logic.column_logic}"
1397
+ if column_logic.is_direct_copy
1398
+ else f"SQL: {column_logic.column_logic}"
1399
+ )
1400
+ if column_logic
1401
+ else None
1402
+ ),
1387
1403
  )
1388
1404
  )
1389
1405
 
@@ -5,7 +5,18 @@ import functools
5
5
  import logging
6
6
  import traceback
7
7
  from collections import defaultdict
8
- from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union
8
+ from typing import (
9
+ AbstractSet,
10
+ Any,
11
+ Dict,
12
+ Iterable,
13
+ List,
14
+ Optional,
15
+ Set,
16
+ Tuple,
17
+ TypeVar,
18
+ Union,
19
+ )
9
20
 
10
21
  import pydantic.dataclasses
11
22
  import sqlglot
@@ -54,6 +65,7 @@ from datahub.utilities.cooperative_timeout import (
54
65
  CooperativeTimeoutError,
55
66
  cooperative_timeout,
56
67
  )
68
+ from datahub.utilities.ordered_set import OrderedSet
57
69
 
58
70
  assert SQLGLOT_PATCHED
59
71
 
@@ -128,19 +140,39 @@ class DownstreamColumnRef(_ParserBaseModel):
128
140
  return SchemaFieldDataTypeClass.from_obj(v)
129
141
 
130
142
 
143
class ColumnTransformation(_ParserBaseModel):
    # True when the downstream column is a bare copy of a single upstream
    # column; False when any SQL expression (function call, arithmetic, etc.)
    # is applied.
    is_direct_copy: bool
    # The SQL expression text that produces the column.
    column_logic: str
146
+
147
+
131
148
class _ColumnLineageInfo(_ParserBaseModel):
    # Internal (pre-urn-resolution) column lineage entry: one downstream
    # column plus the upstream columns it derives from.
    downstream: _DownstreamColumnRef
    upstreams: List[_ColumnRef]

    # Transformation applied to produce this column, when it could be derived.
    logic: Optional[ColumnTransformation] = None
136
153
 
137
154
 
138
155
class ColumnLineageInfo(_ParserBaseModel):
    # Public (urn-resolved) column lineage entry: one downstream column plus
    # the upstream columns it derives from.
    downstream: DownstreamColumnRef
    upstreams: List[ColumnRef]

    # Transformation applied to produce this column, when it could be derived.
    logic: Optional[ColumnTransformation] = pydantic.Field(default=None)
160
+
161
+
162
class _JoinInfo(_ParserBaseModel):
    # Internal (table names, not urns) description of a single JOIN clause.
    join_type: str  # stringified join type, e.g. "LEFT OUTER JOIN"
    left_tables: List[_TableName]
    right_tables: List[_TableName]
    on_clause: Optional[str]  # rendered ON condition; None for e.g. cross joins
    columns_involved: List[_ColumnRef]  # columns referenced by the ON condition
168
+
169
+
170
class JoinInfo(_ParserBaseModel):
    # Public (urn-resolved) description of a single JOIN clause.
    join_type: str  # stringified join type, e.g. "LEFT OUTER JOIN"
    left_tables: List[Urn]
    right_tables: List[Urn]
    on_clause: Optional[str]  # rendered ON condition; None for e.g. cross joins
    columns_involved: List[ColumnRef]  # columns referenced by the ON condition
144
176
 
145
177
 
146
178
  class SqlParsingDebugInfo(_ParserBaseModel):
@@ -178,6 +210,7 @@ class SqlParsingResult(_ParserBaseModel):
178
210
  out_tables: List[Urn]
179
211
 
180
212
  column_lineage: Optional[List[ColumnLineageInfo]] = None
213
+ joins: Optional[List[JoinInfo]] = None
181
214
 
182
215
  # TODO include formatted original sql logic
183
216
  # TODO include list of referenced columns
@@ -197,13 +230,19 @@ class SqlParsingResult(_ParserBaseModel):
197
230
  )
198
231
 
199
232
 
233
def _extract_table_names(
    iterable: Iterable[sqlglot.exp.Table],
) -> OrderedSet[_TableName]:
    """Convert sqlglot table expressions into _TableNames, deduplicating while
    preserving first-seen order."""
    names: OrderedSet[_TableName] = OrderedSet()
    for table in iterable:
        names.add(_TableName.from_sqlglot_table(table))
    return names
237
+
238
+
200
239
  def _table_level_lineage(
201
240
  statement: sqlglot.Expression, dialect: sqlglot.Dialect
202
- ) -> Tuple[Set[_TableName], Set[_TableName]]:
241
+ ) -> Tuple[AbstractSet[_TableName], AbstractSet[_TableName]]:
203
242
  # Generate table-level lineage.
204
243
  modified = (
205
- {
206
- _TableName.from_sqlglot_table(expr.this)
244
+ _extract_table_names(
245
+ expr.this
207
246
  for expr in statement.find_all(
208
247
  sqlglot.exp.Create,
209
248
  sqlglot.exp.Insert,
@@ -215,36 +254,36 @@ def _table_level_lineage(
215
254
  # In some cases like "MERGE ... then INSERT (col1, col2) VALUES (col1, col2)",
216
255
  # the `this` on the INSERT part isn't a table.
217
256
  if isinstance(expr.this, sqlglot.exp.Table)
218
- }
219
- | {
257
+ )
258
+ | _extract_table_names(
220
259
  # For statements that include a column list, like
221
260
  # CREATE DDL statements and `INSERT INTO table (col1, col2) SELECT ...`
222
261
  # the table name is nested inside a Schema object.
223
- _TableName.from_sqlglot_table(expr.this.this)
262
+ expr.this.this
224
263
  for expr in statement.find_all(
225
264
  sqlglot.exp.Create,
226
265
  sqlglot.exp.Insert,
227
266
  )
228
267
  if isinstance(expr.this, sqlglot.exp.Schema)
229
268
  and isinstance(expr.this.this, sqlglot.exp.Table)
230
- }
231
- | {
269
+ )
270
+ | _extract_table_names(
232
271
  # For drop statements, we only want it if a table/view is being dropped.
233
272
  # Other "kinds" will not have table.name populated.
234
- _TableName.from_sqlglot_table(expr.this)
273
+ expr.this
235
274
  for expr in ([statement] if isinstance(statement, sqlglot.exp.Drop) else [])
236
275
  if isinstance(expr.this, sqlglot.exp.Table)
237
276
  and expr.this.this
238
277
  and expr.this.name
239
- }
278
+ )
240
279
  )
241
280
 
242
281
  tables = (
243
- {
244
- _TableName.from_sqlglot_table(table)
282
+ _extract_table_names(
283
+ table
245
284
  for table in statement.find_all(sqlglot.exp.Table)
246
285
  if not isinstance(table.parent, sqlglot.exp.Drop)
247
- }
286
+ )
248
287
  # ignore references created in this query
249
288
  - modified
250
289
  # ignore CTEs created in this statement
@@ -520,8 +559,6 @@ def _select_statement_cll(
520
559
  # Generate SELECT lineage.
521
560
  direct_raw_col_upstreams = _get_direct_raw_col_upstreams(lineage_node)
522
561
 
523
- # column_logic = lineage_node.source
524
-
525
562
  # Fuzzy resolve the output column.
526
563
  original_col_expression = lineage_node.expression
527
564
  if output_col.startswith("_col_"):
@@ -560,7 +597,7 @@ def _select_statement_cll(
560
597
  column_type=output_col_type,
561
598
  ),
562
599
  upstreams=sorted(direct_resolved_col_upstreams),
563
- # logic=column_logic.sql(pretty=True, dialect=dialect),
600
+ logic=_get_column_transformation(lineage_node, dialect),
564
601
  )
565
602
  )
566
603
 
@@ -575,6 +612,7 @@ def _select_statement_cll(
575
612
 
576
613
  class _ColumnLineageWithDebugInfo(_ParserBaseModel):
577
614
  column_lineage: List[_ColumnLineageInfo]
615
+ joins: Optional[List[_JoinInfo]] = None
578
616
 
579
617
  select_statement: Optional[sqlglot.exp.Expression] = None
580
618
  # TODO: Add column exceptions here.
@@ -645,17 +683,27 @@ def _column_level_lineage(
645
683
  output_table=downstream_table,
646
684
  )
647
685
 
686
+ joins: Optional[List[_JoinInfo]] = None
687
+ try:
688
+ # List join clauses.
689
+ joins = _list_joins(dialect=dialect, root_scope=root_scope)
690
+ logger.debug("Joins: %s", joins)
691
+ except Exception as e:
692
+ # This is a non-fatal error, so we can continue.
693
+ logger.debug("Failed to list joins: %s", e)
694
+
648
695
  return _ColumnLineageWithDebugInfo(
649
696
  column_lineage=column_lineage,
697
+ joins=joins,
650
698
  select_statement=select_statement,
651
699
  )
652
700
 
653
701
 
654
702
  def _get_direct_raw_col_upstreams(
655
703
  lineage_node: sqlglot.lineage.Node,
656
- ) -> Set[_ColumnRef]:
657
- # Using a set here to deduplicate upstreams.
658
- direct_raw_col_upstreams: Set[_ColumnRef] = set()
704
+ ) -> OrderedSet[_ColumnRef]:
705
+ # Using an OrderedSet here to deduplicate upstreams while preserving "discovery" order.
706
+ direct_raw_col_upstreams: OrderedSet[_ColumnRef] = OrderedSet()
659
707
 
660
708
  for node in lineage_node.walk():
661
709
  if node.downstream:
@@ -690,6 +738,237 @@ def _get_direct_raw_col_upstreams(
690
738
  return direct_raw_col_upstreams
691
739
 
692
740
 
741
def _is_single_column_expression(
    expression: sqlglot.exp.Expression,
) -> bool:
    """Return True if the expression is trivially a single column reference.

    An alias wrapping a bare column still counts; anything with real logic
    (e.g. count(*) or coalesce(col, 0)) does not.
    """
    target = (
        expression.this
        if isinstance(expression, sqlglot.exp.Alias)
        else expression
    )
    return isinstance(target, sqlglot.exp.Column)
750
+
751
+
752
def _get_column_transformation(
    lineage_node: sqlglot.lineage.Node,
    dialect: sqlglot.Dialect,
    parent: Optional[sqlglot.lineage.Node] = None,
) -> ColumnTransformation:
    """Summarize how a downstream column is computed from its upstreams.

    Recursively descends the sqlglot lineage tree. The column counts as a
    "direct copy" when the chain bottoms out in a bare column reference; a
    non-trivial expression or multiple downstream nodes anywhere along the
    way makes it a transformation.

    Args:
        lineage_node: Lineage node for the column being summarized.
        dialect: Dialect used to render the expression back to SQL text.
        parent: The previous node in the descent; used to recover the
            expression that wraps a leaf node.

    Returns:
        A ColumnTransformation with the rendered SQL and a direct-copy flag.
    """
    if not lineage_node.downstream:
        # Leaf node: the expression that matters is the parent's, since the
        # leaf itself is just the resolved column reference.
        if parent:
            expression = parent.expression
            is_copy = _is_single_column_expression(expression)
        else:
            # Called directly on a leaf with no parent context.
            # This case should rarely happen.
            is_copy = True
            expression = lineage_node.expression
        return ColumnTransformation(
            is_direct_copy=is_copy,
            column_logic=expression.sql(dialect=dialect),
        )

    elif len(lineage_node.downstream) > 1 or not _is_single_column_expression(
        lineage_node.expression
    ):
        # Multiple upstream nodes or a non-trivial expression: this is a real
        # SQL transformation, not a copy.
        return ColumnTransformation(
            is_direct_copy=False,
            column_logic=lineage_node.expression.sql(dialect=dialect),
        )

    else:
        # Exactly one downstream node and a trivial expression: keep descending.
        return _get_column_transformation(
            lineage_node=lineage_node.downstream[0],
            dialect=dialect,
            parent=lineage_node,
        )
787
+
788
+
789
def _get_join_side_tables(
    target: sqlglot.exp.Expression,
    dialect: sqlglot.Dialect,
    scope: sqlglot.optimizer.Scope,
) -> OrderedSet[_TableName]:
    """Resolve the underlying tables for one side of a join.

    If the join target maps directly to a Table in the scope's sources, that
    single table is returned. Otherwise the source is a Scope (subquery/CTE),
    so we resolve ``<target>.*`` through the column-lineage machinery and
    collect the tables of the columns it expands to.
    """
    target_alias_or_name = target.alias_or_name
    if (source := scope.sources.get(target_alias_or_name)) and isinstance(
        source, sqlglot.exp.Table
    ):
        # Fast path: the target is a plain table reference.
        return OrderedSet([_TableName.from_sqlglot_table(source)])

    # If the source is a Scope, we need to do some resolution work: expand
    # `<target>.*` and see which upstream tables its columns come from.
    column = sqlglot.exp.Column(
        this=sqlglot.exp.Star(),
        table=sqlglot.exp.Identifier(this=target_alias_or_name),
    )
    columns_used = _get_raw_col_upstreams_for_expression(
        select=column,
        dialect=dialect,
        scope=scope,
    )
    return OrderedSet(col.table for col in columns_used)
811
+
812
+
813
def _get_raw_col_upstreams_for_expression(
    select: sqlglot.exp.Expression,
    dialect: sqlglot.Dialect,
    scope: sqlglot.optimizer.Scope,
) -> OrderedSet[_ColumnRef]:
    """Compute the raw upstream columns referenced by an arbitrary expression.

    Works by temporarily replacing the scope's SELECT list with the given
    expression and running sqlglot's lineage machinery on the single
    resulting projection. The scope's original query is always restored.
    """
    if not isinstance(scope.expression, sqlglot.exp.Query):
        # Note that Select, Subquery, SetOperation, etc. are all subclasses of Query.
        # So this line should basically never happen.
        return OrderedSet()

    original_expression = scope.expression
    # Build a copy of the query whose only projection is `select`.
    updated_expression = scope.expression.select(select, append=False, copy=True)

    try:
        # Mutate the scope in place: to_node reads scope.expression directly.
        scope.expression = updated_expression
        node = sqlglot.lineage.to_node(
            column=0,
            scope=scope,
            dialect=dialect,
            trim_selects=False,
        )

        return _get_direct_raw_col_upstreams(node)
    finally:
        # Restore the original query even if lineage resolution raises.
        scope.expression = original_expression
838
+
839
+
840
def _list_joins(
    dialect: sqlglot.Dialect,
    root_scope: sqlglot.optimizer.Scope,
) -> List[_JoinInfo]:
    """Extract join information from every scope of a parsed statement.

    Best-effort: joins whose operand tables cannot be resolved are skipped
    (with a debug log) rather than raising.
    """
    # TODO: Add a confidence tracker here.

    joins: List[_JoinInfo] = []

    scope: sqlglot.optimizer.Scope
    for scope in root_scope.traverse():
        join: sqlglot.exp.Join
        for join in scope.find_all(sqlglot.exp.Join):
            # The left side is everything in the FROM clause(s) of this scope.
            left_side_tables: OrderedSet[_TableName] = OrderedSet()
            from_clause: sqlglot.exp.From
            for from_clause in scope.find_all(sqlglot.exp.From):
                left_side_tables.update(
                    _get_join_side_tables(
                        target=from_clause.this,
                        dialect=dialect,
                        scope=scope,
                    )
                )

            # The right side is the join's own target expression.
            right_side_tables: OrderedSet[_TableName] = OrderedSet()
            if join_target := join.this:
                right_side_tables = _get_join_side_tables(
                    target=join_target,
                    dialect=dialect,
                    scope=scope,
                )

            # We don't need to check for `using` here because it's normalized to `on`
            # by the sqlglot optimizer.
            on_clause: Optional[sqlglot.exp.Expression] = join.args.get("on")
            if on_clause:
                joined_columns = _get_raw_col_upstreams_for_expression(
                    select=on_clause, dialect=dialect, scope=scope
                )

                unique_tables = OrderedSet(col.table for col in joined_columns)
                if not unique_tables:
                    logger.debug(
                        "Skipping join because we couldn't resolve the tables from the join condition: %s",
                        join.sql(dialect=dialect),
                    )
                    continue

                # When we have an `on` clause, we only want to include tables whose columns are
                # involved in the join condition. Without this, a statement like this:
                # WITH cte_alias AS (select t1.id, t1.user_id, t2.other_col from t1 join t2 on t1.id = t2.id)
                # SELECT * FROM users
                # JOIN cte_alias ON users.id = cte_alias.user_id
                # would incorrectly include t2 as part of the left side tables.
                left_side_tables = OrderedSet(left_side_tables & unique_tables)
                right_side_tables = OrderedSet(right_side_tables & unique_tables)
            else:
                # Some joins (cross join, lateral join, etc.) don't have an ON clause.
                # In those cases, we have some best-effort logic to at least
                # extract the tables involved.
                joined_columns = OrderedSet()

            if not left_side_tables and not right_side_tables:
                logger.debug(
                    "Skipping join because we couldn't resolve any tables from the join operands: %s",
                    join.sql(dialect=dialect),
                )
                continue
            elif len(left_side_tables | right_side_tables) == 1:
                # When we don't have an ON clause, we're more strict about the
                # minimum number of tables we need to resolve to avoid false positives.
                # On the off chance someone is doing a self-cross-join, we'll miss it.
                logger.debug(
                    "Skipping join because we couldn't resolve enough tables from the join operands: %s",
                    join.sql(dialect=dialect),
                )
                continue

            joins.append(
                _JoinInfo(
                    join_type=_get_join_type(join),
                    left_tables=list(left_side_tables),
                    right_tables=list(right_side_tables),
                    on_clause=on_clause.sql(dialect=dialect) if on_clause else None,
                    columns_involved=list(sorted(joined_columns)),
                )
            )

    return joins
928
+
929
+
930
def _get_join_type(join: sqlglot.exp.Join) -> str:
    """Stringify the type of a sqlglot Join expression.

    Args:
        join: A sqlglot Join expression.

    Returns:
        Stringified join type e.g. "LEFT JOIN", "RIGHT OUTER JOIN",
        "LATERAL JOIN", "CROSS APPLY", etc.
    """
    # This logic was derived from the sqlglot join_sql method.
    # https://github.com/tobymao/sqlglot/blob/07bf71bae5d2a5c381104a86bb52c06809c21174/sqlglot/generator.py#L2248

    # Lateral joins are rendered specially.
    if isinstance(join.this, sqlglot.exp.Lateral):
        return (
            "CROSS APPLY"
            if join.this.args.get("cross_apply") is not None
            else "LATERAL JOIN"
        )

    # MySQL's STRAIGHT_JOIN takes no modifiers.
    if join.args.get("kind") == "STRAIGHT":
        return "STRAIGHT_JOIN"

    # General form: <method> <global> <side> <kind> JOIN, where
    # - method = "HASH", "MERGE"
    # - global = "GLOBAL"
    # - side = "LEFT", "RIGHT" (optional for SEMI/ANTI joins)
    # - kind = "INNER", "OUTER", "SEMI", "ANTI"
    args = join.args
    parts = [
        part
        for part in (
            args.get("method"),
            "GLOBAL" if args.get("global") else None,
            args.get("side"),
            args.get("kind"),
        )
        if part
    ]
    parts.append("JOIN")
    return " ".join(parts)
970
+
971
+
693
972
  def _extract_select_from_create(
694
973
  statement: sqlglot.exp.Create,
695
974
  ) -> sqlglot.exp.Expression:
@@ -875,6 +1154,35 @@ def _translate_internal_column_lineage(
875
1154
  )
876
1155
 
877
1156
 
1157
def _translate_internal_joins(
    table_name_urn_mapping: Dict[_TableName, str],
    raw_joins: List[_JoinInfo],
    dialect: sqlglot.Dialect,
) -> List[JoinInfo]:
    """Translate internal _JoinInfo objects (table names) into JoinInfo (urns).

    Raises KeyError when a referenced table is missing from the mapping.
    NOTE(review): `dialect` is currently unused here; kept for interface
    stability with the caller.
    """
    translated: List[JoinInfo] = []
    for raw in raw_joins:
        left_urns = [table_name_urn_mapping[t] for t in raw.left_tables]
        right_urns = [table_name_urn_mapping[t] for t in raw.right_tables]
        column_refs = [
            ColumnRef(
                table=table_name_urn_mapping[col.table],
                column=col.column,
            )
            for col in raw.columns_involved
        ]
        translated.append(
            JoinInfo(
                join_type=raw.join_type,
                left_tables=left_urns,
                right_tables=right_urns,
                on_clause=raw.on_clause,
                columns_involved=column_refs,
            )
        )
    return translated
1184
+
1185
+
878
1186
  _StrOrNone = TypeVar("_StrOrNone", str, Optional[str])
879
1187
 
880
1188
 
@@ -1034,6 +1342,7 @@ def _sqlglot_lineage_inner(
1034
1342
  )
1035
1343
 
1036
1344
  column_lineage: Optional[List[_ColumnLineageInfo]] = None
1345
+ joins = None
1037
1346
  try:
1038
1347
  with cooperative_timeout(
1039
1348
  timeout=(
@@ -1049,6 +1358,7 @@ def _sqlglot_lineage_inner(
1049
1358
  default_schema=default_schema,
1050
1359
  )
1051
1360
  column_lineage = column_lineage_debug_info.column_lineage
1361
+ joins = column_lineage_debug_info.joins
1052
1362
  except CooperativeTimeoutError as e:
1053
1363
  logger.debug(f"Timed out while generating column-level lineage: {e}")
1054
1364
  debug_info.column_error = e
@@ -1081,6 +1391,14 @@ def _sqlglot_lineage_inner(
1081
1391
  f"Failed to translate column lineage to urns: {e}", exc_info=True
1082
1392
  )
1083
1393
  debug_info.column_error = e
1394
+ joins_urns = None
1395
+ if joins is not None:
1396
+ try:
1397
+ joins_urns = _translate_internal_joins(
1398
+ table_name_urn_mapping, raw_joins=joins, dialect=dialect
1399
+ )
1400
+ except KeyError as e:
1401
+ logger.debug(f"Failed to translate joins to urns: {e}", exc_info=True)
1084
1402
 
1085
1403
  query_type, query_type_props = get_query_type_of_sql(
1086
1404
  original_statement, dialect=dialect
@@ -1095,6 +1413,7 @@ def _sqlglot_lineage_inner(
1095
1413
  in_tables=in_urns,
1096
1414
  out_tables=out_urns,
1097
1415
  column_lineage=column_lineage_urns,
1416
+ joins=joins_urns,
1098
1417
  debug_info=debug_info,
1099
1418
  )
1100
1419
 
@@ -377,7 +377,8 @@ def _maybe_print_upgrade_message(
377
377
  + click.style(
378
378
  f"➡️ Upgrade via \"pip install 'acryl-datahub=={version_stats.server.current.version}'\"",
379
379
  fg="cyan",
380
- )
380
+ ),
381
+ err=True,
381
382
  )
382
383
  elif client_server_compat == 0 and encourage_cli_upgrade:
383
384
  with contextlib.suppress(Exception):
@@ -387,7 +388,8 @@ def _maybe_print_upgrade_message(
387
388
  + click.style(
388
389
  f"You seem to be running an old version of datahub cli: {current_version} {get_days(current_release_date)}. Latest version is {latest_version} {get_days(latest_release_date)}.\nUpgrade via \"pip install -U 'acryl-datahub'\"",
389
390
  fg="cyan",
390
- )
391
+ ),
392
+ err=True,
391
393
  )
392
394
  elif encourage_quickstart_upgrade:
393
395
  try: