acryl-datahub 1.0.0.4rc7__py3-none-any.whl → 1.1.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -15,10 +15,10 @@ import pathlib
 def _load_schema(schema_name: str) -> str:
     return (pathlib.Path(__file__).parent / f"{schema_name}.avsc").read_text()
 
-def getMetadataChangeEventSchema() -> str:
-    return _load_schema("MetadataChangeEvent")
-
 def getMetadataChangeProposalSchema() -> str:
     return _load_schema("MetadataChangeProposal")
 
+def getMetadataChangeEventSchema() -> str:
+    return _load_schema("MetadataChangeEvent")
+
 # fmt: on
@@ -1366,6 +1366,13 @@ class SqlParsingAggregator(Closeable):
             ):
                 upstream_columns = [x[0] for x in upstream_columns_for_query]
                 required_queries.add(query_id)
+                query = queries_map[query_id]
+
+                column_logic = None
+                for lineage_info in query.column_lineage:
+                    if lineage_info.downstream.column == downstream_column:
+                        column_logic = lineage_info.logic
+                        break
 
                 upstream_aspect.fineGrainedLineages.append(
                     models.FineGrainedLineageClass(
@@ -1383,7 +1390,16 @@ class SqlParsingAggregator(Closeable):
                             if self.can_generate_query(query_id)
                             else None
                         ),
-                        confidenceScore=queries_map[query_id].confidence_score,
+                        confidenceScore=query.confidence_score,
+                        transformOperation=(
+                            (
+                                f"COPY: {column_logic.column_logic}"
+                                if column_logic.is_direct_copy
+                                else f"SQL: {column_logic.column_logic}"
+                            )
+                            if column_logic
+                            else None
+                        ),
                     )
                 )
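
The net effect of this hunk is that each fine-grained lineage edge now carries a human-readable transform string built from the parser's ColumnTransformation. A minimal sketch of the same formatting rule, pulled out as a standalone helper for illustration (the helper name and standalone form are not part of the diff):

from typing import Optional

def format_transform_operation(column_logic) -> Optional[str]:
    # column_logic is a ColumnTransformation (is_direct_copy, column_logic) or None.
    if column_logic is None:
        return None
    prefix = "COPY" if column_logic.is_direct_copy else "SQL"
    return f"{prefix}: {column_logic.column_logic}"

A direct column copy yields a string like "COPY: amount"; a derived column yields something like "SQL: COALESCE(amount, 0)" (illustrative values).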
 
@@ -54,6 +54,8 @@ from datahub.utilities.cooperative_timeout import (
     CooperativeTimeoutError,
     cooperative_timeout,
 )
+from datahub.utilities.dedup_list import deduplicate_list
+from datahub.utilities.ordered_set import OrderedSet
 
 assert SQLGLOT_PATCHED
 
@@ -128,19 +130,37 @@ class DownstreamColumnRef(_ParserBaseModel):
         return SchemaFieldDataTypeClass.from_obj(v)
 
 
+class ColumnTransformation(_ParserBaseModel):
+    is_direct_copy: bool
+    column_logic: str
+
+
 class _ColumnLineageInfo(_ParserBaseModel):
     downstream: _DownstreamColumnRef
     upstreams: List[_ColumnRef]
 
-    logic: Optional[str] = None
+    logic: Optional[ColumnTransformation] = None
 
 
 class ColumnLineageInfo(_ParserBaseModel):
     downstream: DownstreamColumnRef
     upstreams: List[ColumnRef]
 
-    # Logic for this column, as a SQL expression.
-    logic: Optional[str] = pydantic.Field(default=None, exclude=True)
+    logic: Optional[ColumnTransformation] = pydantic.Field(default=None)
+
+
+class _JoinInfo(_ParserBaseModel):
+    join_type: str
+    tables: List[_TableName]
+    on_clause: Optional[str]
+    columns_involved: List[_ColumnRef]
+
+
+class JoinInfo(_ParserBaseModel):
+    join_type: str
+    tables: List[Urn]
+    on_clause: Optional[str]
+    columns_involved: List[ColumnRef]
 
 
 class SqlParsingDebugInfo(_ParserBaseModel):
@@ -178,6 +198,7 @@ class SqlParsingResult(_ParserBaseModel):
     out_tables: List[Urn]
 
     column_lineage: Optional[List[ColumnLineageInfo]] = None
+    joins: Optional[List[JoinInfo]] = None
 
     # TODO include formatted original sql logic
     # TODO include list of referenced columns
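
With joins now on SqlParsingResult and ColumnTransformation attached to each ColumnLineageInfo, a caller can inspect join structure and per-column logic from a single parse result. A rough consumer sketch, assuming a SqlParsingResult named result has already been produced by the parser (describe_result is illustrative and not part of the package):

def describe_result(result) -> None:
    # result is a SqlParsingResult; both fields below may be None.
    for cl in result.column_lineage or []:
        if cl.logic is not None:
            kind = "copy" if cl.logic.is_direct_copy else "derived"
            print(f"{cl.downstream.column}: {kind} via {cl.logic.column_logic}")
    for join in result.joins or []:
        print(f"{join.join_type} on {join.on_clause}: {join.tables}")
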
@@ -520,8 +541,6 @@ def _select_statement_cll(
         # Generate SELECT lineage.
         direct_raw_col_upstreams = _get_direct_raw_col_upstreams(lineage_node)
 
-        # column_logic = lineage_node.source
-
         # Fuzzy resolve the output column.
         original_col_expression = lineage_node.expression
         if output_col.startswith("_col_"):
@@ -560,7 +579,7 @@ def _select_statement_cll(
                     column_type=output_col_type,
                 ),
                 upstreams=sorted(direct_resolved_col_upstreams),
-                # logic=column_logic.sql(pretty=True, dialect=dialect),
+                logic=_get_column_transformation(lineage_node, dialect),
             )
         )
 
@@ -575,6 +594,7 @@ def _select_statement_cll(
 
 class _ColumnLineageWithDebugInfo(_ParserBaseModel):
     column_lineage: List[_ColumnLineageInfo]
+    joins: Optional[List[_JoinInfo]] = None
 
     select_statement: Optional[sqlglot.exp.Expression] = None
     # TODO: Add column exceptions here.
@@ -645,17 +665,27 @@ def _column_level_lineage(
         output_table=downstream_table,
     )
 
+    joins: Optional[List[_JoinInfo]] = None
+    try:
+        # List join clauses.
+        joins = _list_joins(dialect=dialect, root_scope=root_scope)
+        logger.debug("Joins: %s", joins)
+    except Exception as e:
+        # This is a non-fatal error, so we can continue.
+        logger.debug("Failed to list joins: %s", e)
+
     return _ColumnLineageWithDebugInfo(
         column_lineage=column_lineage,
+        joins=joins,
         select_statement=select_statement,
     )
 
 
 def _get_direct_raw_col_upstreams(
     lineage_node: sqlglot.lineage.Node,
-) -> Set[_ColumnRef]:
-    # Using a set here to deduplicate upstreams.
-    direct_raw_col_upstreams: Set[_ColumnRef] = set()
+) -> OrderedSet[_ColumnRef]:
+    # Using an OrderedSet here to deduplicate upstreams while preserving "discovery" order.
+    direct_raw_col_upstreams: OrderedSet[_ColumnRef] = OrderedSet()
 
     for node in lineage_node.walk():
         if node.downstream:
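
The switch from set to OrderedSet only changes ordering: upstreams are still deduplicated, but they are now emitted in discovery order rather than arbitrary set order, which keeps the lineage output deterministic. A small sketch of the expected behavior, assuming datahub's OrderedSet supports add() like a regular set:

from datahub.utilities.ordered_set import OrderedSet

cols = OrderedSet()
for col in ["b.id", "a.id", "b.id", "c.key"]:
    cols.add(col)
print(list(cols))  # ['b.id', 'a.id', 'c.key'] - duplicates dropped, first-seen order kept
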
@@ -690,6 +720,152 @@ def _get_direct_raw_col_upstreams(
     return direct_raw_col_upstreams
 
 
+def _is_single_column_expression(
+    expression: sqlglot.exp.Expression,
+) -> bool:
+    # Check if the expression is trivial, i.e. it's just a single column.
+    # Things like count(*) or coalesce(col, 0) are not single columns.
+    if isinstance(expression, sqlglot.exp.Alias):
+        expression = expression.this
+
+    return isinstance(expression, sqlglot.exp.Column)
+
+
+def _get_column_transformation(
+    lineage_node: sqlglot.lineage.Node,
+    dialect: sqlglot.Dialect,
+    parent: Optional[sqlglot.lineage.Node] = None,
+) -> ColumnTransformation:
+    # expression = lineage_node.expression
+    # is_single_column_expression = _is_single_column_expression(lineage_node.expression)
+    if not lineage_node.downstream:
+        # parent_expression = parent.expression if parent else expression
+        if parent:
+            expression = parent.expression
+            is_copy = _is_single_column_expression(expression)
+        else:
+            # This case should rarely happen.
+            is_copy = True
+            expression = lineage_node.expression
+        return ColumnTransformation(
+            is_direct_copy=is_copy,
+            column_logic=expression.sql(dialect=dialect),
+        )
+
+    elif len(lineage_node.downstream) > 1 or not _is_single_column_expression(
+        lineage_node.expression
+    ):
+        return ColumnTransformation(
+            is_direct_copy=False,
+            column_logic=lineage_node.expression.sql(dialect=dialect),
+        )
+
+    else:
+        return _get_column_transformation(
+            lineage_node=lineage_node.downstream[0],
+            dialect=dialect,
+            parent=lineage_node,
+        )
+
+
+def _get_raw_col_upstreams_for_expression(
+    select: sqlglot.exp.Expression,
+    dialect: sqlglot.Dialect,
+    scope: sqlglot.optimizer.Scope,
+) -> OrderedSet[_ColumnRef]:
+    if not isinstance(scope.expression, sqlglot.exp.Query):
+        # Note that Select, Subquery, SetOperation, etc. are all subclasses of Query.
+        # So this line should basically never happen.
+        return OrderedSet()
+
+    original_expression = scope.expression
+    updated_expression = scope.expression.select(select, append=False, copy=True)
+
+    try:
+        scope.expression = updated_expression
+        node = sqlglot.lineage.to_node(
+            column=0,
+            scope=scope,
+            dialect=dialect,
+            trim_selects=False,
+        )
+
+        return _get_direct_raw_col_upstreams(node)
+    finally:
+        scope.expression = original_expression
+
+
+def _list_joins(
+    dialect: sqlglot.Dialect,
+    root_scope: sqlglot.optimizer.Scope,
+) -> List[_JoinInfo]:
+    # TODO: Add a confidence tracker here.
+
+    joins: List[_JoinInfo] = []
+
+    for scope in root_scope.traverse():
+        join: sqlglot.exp.Join
+        for join in scope.find_all(sqlglot.exp.Join):
+            on_clause: Optional[sqlglot.exp.Expression] = join.args.get("on")
+            if not on_clause:
+                # We don't need to check for `using` here because it's normalized to `on`
+                # by the sqlglot optimizer.
+                logger.debug(
+                    "Skipping join without ON clause: %s",
+                    join.sql(dialect=dialect),
+                )
+                # TODO: This skips joins that don't have ON clauses, like cross joins, lateral joins, etc.
+                continue
+
+            joined_columns = _get_raw_col_upstreams_for_expression(
+                select=on_clause, dialect=dialect, scope=scope
+            )
+
+            unique_tables = deduplicate_list(col.table for col in joined_columns)
+            if not unique_tables:
+                logger.debug(
+                    "Skipping join because we couldn't resolve the tables: %s",
+                    join.sql(dialect=dialect),
+                )
+                continue
+
+            joins.append(
+                _JoinInfo(
+                    join_type=_get_join_type(join),
+                    tables=list(unique_tables),
+                    on_clause=on_clause.sql(dialect=dialect) if on_clause else None,
+                    columns_involved=list(sorted(joined_columns)),
+                )
+            )
+
+    return joins
+
+
+def _get_join_type(join: sqlglot.exp.Join) -> str:
+    # Will return "LEFT JOIN", "RIGHT OUTER JOIN", etc.
+    # This is not really comprehensive - there's a couple other edge
+    # cases (e.g. STRAIGHT_JOIN, anti-join) that we don't handle.
+
+    components = []
+
+    # Add method if present (e.g. "HASH", "MERGE")
+    if method := join.args.get("method"):
+        components.append(method)
+
+    # Add side if present (e.g. "LEFT", "RIGHT")
+    if side := join.args.get("side"):
+        components.append(side)
+
+    # Add kind if present (e.g. "INNER", "OUTER", "SEMI", "ANTI")
+    if kind := join.args.get("kind"):
+        components.append(kind)
+
+    # Join the components and append "JOIN"
+    if not components:
+        return "JOIN"
+    return f"{' '.join(components)} JOIN"
+
+
 def _extract_select_from_create(
     statement: sqlglot.exp.Create,
 ) -> sqlglot.exp.Expression:
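
To make the join-type assembly concrete: _get_join_type reads sqlglot's method, side, and kind args off the Join node and concatenates whichever are present. A quick illustration using plain sqlglot (this snippet is a sketch, not part of the diff):

import sqlglot
from sqlglot import exp

stmt = sqlglot.parse_one(
    "SELECT * FROM orders o LEFT OUTER JOIN customers c ON o.cust_id = c.id"
)
join = stmt.find(exp.Join)
print(join.args.get("side"), join.args.get("kind"))  # LEFT OUTER
# With side="LEFT" and kind="OUTER", _get_join_type would return "LEFT OUTER JOIN";
# a bare "JOIN ... ON ..." with no side, kind, or method falls through to plain "JOIN".
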
@@ -875,6 +1051,30 @@ def _translate_internal_column_lineage(
     )
 
 
+def _translate_internal_joins(
+    table_name_urn_mapping: Dict[_TableName, str],
+    raw_joins: List[_JoinInfo],
+    dialect: sqlglot.Dialect,
+) -> List[JoinInfo]:
+    joins = []
+    for raw_join in raw_joins:
+        joins.append(
+            JoinInfo(
+                join_type=raw_join.join_type,
+                tables=[table_name_urn_mapping[table] for table in raw_join.tables],
+                on_clause=raw_join.on_clause,
+                columns_involved=[
+                    ColumnRef(
+                        table=table_name_urn_mapping[col.table],
+                        column=col.column,
+                    )
+                    for col in raw_join.columns_involved
+                ],
+            )
+        )
+    return joins
+
+
 _StrOrNone = TypeVar("_StrOrNone", str, Optional[str])
 
 
@@ -1034,6 +1234,7 @@ def _sqlglot_lineage_inner(
     )
 
     column_lineage: Optional[List[_ColumnLineageInfo]] = None
+    joins = None
     try:
         with cooperative_timeout(
             timeout=(
@@ -1049,6 +1250,7 @@ def _sqlglot_lineage_inner(
                 default_schema=default_schema,
             )
             column_lineage = column_lineage_debug_info.column_lineage
+            joins = column_lineage_debug_info.joins
     except CooperativeTimeoutError as e:
         logger.debug(f"Timed out while generating column-level lineage: {e}")
         debug_info.column_error = e
@@ -1081,6 +1283,14 @@ def _sqlglot_lineage_inner(
                 f"Failed to translate column lineage to urns: {e}", exc_info=True
             )
             debug_info.column_error = e
+    joins_urns = None
+    if joins is not None:
+        try:
+            joins_urns = _translate_internal_joins(
+                table_name_urn_mapping, raw_joins=joins, dialect=dialect
+            )
+        except KeyError as e:
+            logger.debug(f"Failed to translate joins to urns: {e}", exc_info=True)
 
     query_type, query_type_props = get_query_type_of_sql(
         original_statement, dialect=dialect
@@ -1095,6 +1305,7 @@ def _sqlglot_lineage_inner(
         in_tables=in_urns,
         out_tables=out_urns,
         column_lineage=column_lineage_urns,
+        joins=joins_urns,
         debug_info=debug_info,
     )