acryl-datahub 1.1.0rc1__py3-none-any.whl → 1.1.0rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of acryl-datahub might be problematic; see the package's registry page for more details.

@@ -5,7 +5,18 @@ import functools
5
5
  import logging
6
6
  import traceback
7
7
  from collections import defaultdict
8
- from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union
8
+ from typing import (
9
+ AbstractSet,
10
+ Any,
11
+ Dict,
12
+ Iterable,
13
+ List,
14
+ Optional,
15
+ Set,
16
+ Tuple,
17
+ TypeVar,
18
+ Union,
19
+ )
9
20
 
10
21
  import pydantic.dataclasses
11
22
  import sqlglot
@@ -54,7 +65,6 @@ from datahub.utilities.cooperative_timeout import (
54
65
  CooperativeTimeoutError,
55
66
  cooperative_timeout,
56
67
  )
57
- from datahub.utilities.dedup_list import deduplicate_list
58
68
  from datahub.utilities.ordered_set import OrderedSet
59
69
 
60
70
  assert SQLGLOT_PATCHED
@@ -151,14 +161,16 @@ class ColumnLineageInfo(_ParserBaseModel):
151
161
 
152
162
  class _JoinInfo(_ParserBaseModel):
153
163
  join_type: str
154
- tables: List[_TableName]
164
+ left_tables: List[_TableName]
165
+ right_tables: List[_TableName]
155
166
  on_clause: Optional[str]
156
167
  columns_involved: List[_ColumnRef]
157
168
 
158
169
 
159
170
  class JoinInfo(_ParserBaseModel):
160
171
  join_type: str
161
- tables: List[Urn]
172
+ left_tables: List[Urn]
173
+ right_tables: List[Urn]
162
174
  on_clause: Optional[str]
163
175
  columns_involved: List[ColumnRef]
164
176
 
@@ -218,13 +230,19 @@ class SqlParsingResult(_ParserBaseModel):
218
230
  )
219
231
 
220
232
 
233
+ def _extract_table_names(
234
+ iterable: Iterable[sqlglot.exp.Table],
235
+ ) -> OrderedSet[_TableName]:
236
+ return OrderedSet(_TableName.from_sqlglot_table(table) for table in iterable)
237
+
238
+
221
239
  def _table_level_lineage(
222
240
  statement: sqlglot.Expression, dialect: sqlglot.Dialect
223
- ) -> Tuple[Set[_TableName], Set[_TableName]]:
241
+ ) -> Tuple[AbstractSet[_TableName], AbstractSet[_TableName]]:
224
242
  # Generate table-level lineage.
225
243
  modified = (
226
- {
227
- _TableName.from_sqlglot_table(expr.this)
244
+ _extract_table_names(
245
+ expr.this
228
246
  for expr in statement.find_all(
229
247
  sqlglot.exp.Create,
230
248
  sqlglot.exp.Insert,
@@ -236,36 +254,36 @@ def _table_level_lineage(
236
254
  # In some cases like "MERGE ... then INSERT (col1, col2) VALUES (col1, col2)",
237
255
  # the `this` on the INSERT part isn't a table.
238
256
  if isinstance(expr.this, sqlglot.exp.Table)
239
- }
240
- | {
257
+ )
258
+ | _extract_table_names(
241
259
  # For statements that include a column list, like
242
260
  # CREATE DDL statements and `INSERT INTO table (col1, col2) SELECT ...`
243
261
  # the table name is nested inside a Schema object.
244
- _TableName.from_sqlglot_table(expr.this.this)
262
+ expr.this.this
245
263
  for expr in statement.find_all(
246
264
  sqlglot.exp.Create,
247
265
  sqlglot.exp.Insert,
248
266
  )
249
267
  if isinstance(expr.this, sqlglot.exp.Schema)
250
268
  and isinstance(expr.this.this, sqlglot.exp.Table)
251
- }
252
- | {
269
+ )
270
+ | _extract_table_names(
253
271
  # For drop statements, we only want it if a table/view is being dropped.
254
272
  # Other "kinds" will not have table.name populated.
255
- _TableName.from_sqlglot_table(expr.this)
273
+ expr.this
256
274
  for expr in ([statement] if isinstance(statement, sqlglot.exp.Drop) else [])
257
275
  if isinstance(expr.this, sqlglot.exp.Table)
258
276
  and expr.this.this
259
277
  and expr.this.name
260
- }
278
+ )
261
279
  )
262
280
 
263
281
  tables = (
264
- {
265
- _TableName.from_sqlglot_table(table)
282
+ _extract_table_names(
283
+ table
266
284
  for table in statement.find_all(sqlglot.exp.Table)
267
285
  if not isinstance(table.parent, sqlglot.exp.Drop)
268
- }
286
+ )
269
287
  # ignore references created in this query
270
288
  - modified
271
289
  # ignore CTEs created in this statement
@@ -768,6 +786,30 @@ def _get_column_transformation(
768
786
  )
769
787
 
770
788
 
789
+ def _get_join_side_tables(
790
+ target: sqlglot.exp.Expression,
791
+ dialect: sqlglot.Dialect,
792
+ scope: sqlglot.optimizer.Scope,
793
+ ) -> OrderedSet[_TableName]:
794
+ target_alias_or_name = target.alias_or_name
795
+ if (source := scope.sources.get(target_alias_or_name)) and isinstance(
796
+ source, sqlglot.exp.Table
797
+ ):
798
+ # If the source is a Scope, we need to do some resolution work.
799
+ return OrderedSet([_TableName.from_sqlglot_table(source)])
800
+
801
+ column = sqlglot.exp.Column(
802
+ this=sqlglot.exp.Star(),
803
+ table=sqlglot.exp.Identifier(this=target.alias_or_name),
804
+ )
805
+ columns_used = _get_raw_col_upstreams_for_expression(
806
+ select=column,
807
+ dialect=dialect,
808
+ scope=scope,
809
+ )
810
+ return OrderedSet(col.table for col in columns_used)
811
+
812
+
771
813
  def _get_raw_col_upstreams_for_expression(
772
814
  select: sqlglot.exp.Expression,
773
815
  dialect: sqlglot.Dialect,
@@ -803,36 +845,80 @@ def _list_joins(
803
845
 
804
846
  joins: List[_JoinInfo] = []
805
847
 
848
+ scope: sqlglot.optimizer.Scope
806
849
  for scope in root_scope.traverse():
807
850
  join: sqlglot.exp.Join
808
851
  for join in scope.find_all(sqlglot.exp.Join):
809
- on_clause: Optional[sqlglot.exp.Expression] = join.args.get("on")
810
- if not on_clause:
811
- # We don't need to check for `using` here because it's normalized to `on`
812
- # by the sqlglot optimizer.
813
- logger.debug(
814
- "Skipping join without ON clause: %s",
815
- join.sql(dialect=dialect),
852
+ left_side_tables: OrderedSet[_TableName] = OrderedSet()
853
+ from_clause: sqlglot.exp.From
854
+ for from_clause in scope.find_all(sqlglot.exp.From):
855
+ left_side_tables.update(
856
+ _get_join_side_tables(
857
+ target=from_clause.this,
858
+ dialect=dialect,
859
+ scope=scope,
860
+ )
816
861
  )
817
- # TODO: This skips joins that don't have ON clauses, like cross joins, lateral joins, etc.
818
- continue
819
862
 
820
- joined_columns = _get_raw_col_upstreams_for_expression(
821
- select=on_clause, dialect=dialect, scope=scope
822
- )
863
+ right_side_tables: OrderedSet[_TableName] = OrderedSet()
864
+ if join_target := join.this:
865
+ right_side_tables = _get_join_side_tables(
866
+ target=join_target,
867
+ dialect=dialect,
868
+ scope=scope,
869
+ )
823
870
 
824
- unique_tables = deduplicate_list(col.table for col in joined_columns)
825
- if not unique_tables:
826
- logger.debug(
827
- "Skipping join because we couldn't resolve the tables: %s",
828
- join.sql(dialect=dialect),
871
+ # We don't need to check for `using` here because it's normalized to `on`
872
+ # by the sqlglot optimizer.
873
+ on_clause: Optional[sqlglot.exp.Expression] = join.args.get("on")
874
+ if on_clause:
875
+ joined_columns = _get_raw_col_upstreams_for_expression(
876
+ select=on_clause, dialect=dialect, scope=scope
829
877
  )
830
- continue
878
+
879
+ unique_tables = OrderedSet(col.table for col in joined_columns)
880
+ if not unique_tables:
881
+ logger.debug(
882
+ "Skipping join because we couldn't resolve the tables from the join condition: %s",
883
+ join.sql(dialect=dialect),
884
+ )
885
+ continue
886
+
887
+ # When we have an `on` clause, we only want to include tables whose columns are
888
+ # involved in the join condition. Without this, a statement like this:
889
+ # WITH cte_alias AS (select t1.id, t1.user_id, t2.other_col from t1 join t2 on t1.id = t2.id)
890
+ # SELECT * FROM users
891
+ # JOIN cte_alias ON users.id = cte_alias.user_id
892
+ # would incorrectly include t2 as part of the left side tables.
893
+ left_side_tables = OrderedSet(left_side_tables & unique_tables)
894
+ right_side_tables = OrderedSet(right_side_tables & unique_tables)
895
+ else:
896
+ # Some joins (cross join, lateral join, etc.) don't have an ON clause.
897
+ # In those cases, we have some best-effort logic at least extract the
898
+ # tables involved.
899
+ joined_columns = OrderedSet()
900
+
901
+ if not left_side_tables and not right_side_tables:
902
+ logger.debug(
903
+ "Skipping join because we couldn't resolve any tables from the join operands: %s",
904
+ join.sql(dialect=dialect),
905
+ )
906
+ continue
907
+ elif len(left_side_tables | right_side_tables) == 1:
908
+ # When we don't have an ON clause, we're more strict about the
909
+ # minimum number of tables we need to resolve to avoid false positives.
910
+ # On the off chance someone is doing a self-cross-join, we'll miss it.
911
+ logger.debug(
912
+ "Skipping join because we couldn't resolve enough tables from the join operands: %s",
913
+ join.sql(dialect=dialect),
914
+ )
915
+ continue
831
916
 
832
917
  joins.append(
833
918
  _JoinInfo(
834
919
  join_type=_get_join_type(join),
835
- tables=list(unique_tables),
920
+ left_tables=list(left_side_tables),
921
+ right_tables=list(right_side_tables),
836
922
  on_clause=on_clause.sql(dialect=dialect) if on_clause else None,
837
923
  columns_involved=list(sorted(joined_columns)),
838
924
  )
@@ -842,28 +928,45 @@ def _list_joins(
842
928
 
843
929
 
844
930
  def _get_join_type(join: sqlglot.exp.Join) -> str:
845
- # Will return "LEFT JOIN", "RIGHT OUTER JOIN", etc.
846
- # This is not really comprehensive - there's a couple other edge
847
- # cases (e.g. STRAIGHT_JOIN, anti-join) that we don't handle.
931
+ """Returns the type of join as a string.
848
932
 
849
- components = []
933
+ Args:
934
+ join: A sqlglot Join expression.
850
935
 
851
- # Add method if present (e.g. "HASH", "MERGE")
936
+ Returns:
937
+ Stringified join type e.g. "LEFT JOIN", "RIGHT OUTER JOIN", "LATERAL JOIN", etc.
938
+ """
939
+ # This logic was derived from the sqlglot join_sql method.
940
+ # https://github.com/tobymao/sqlglot/blob/07bf71bae5d2a5c381104a86bb52c06809c21174/sqlglot/generator.py#L2248
941
+
942
+ # Special case for lateral joins
943
+ if isinstance(join.this, sqlglot.exp.Lateral):
944
+ if join.this.args.get("cross_apply") is not None:
945
+ return "CROSS APPLY"
946
+ return "LATERAL JOIN"
947
+
948
+ # Special case for STRAIGHT_JOIN (MySQL)
949
+ if join.args.get("kind") == "STRAIGHT":
950
+ return "STRAIGHT_JOIN"
951
+
952
+ # <method> <global> <side> <kind> JOIN
953
+ # - method = "HASH", "MERGE"
954
+ # - global = "GLOBAL"
955
+ # - side = "LEFT", "RIGHT"
956
+ # - kind = "INNER", "OUTER", "SEMI", "ANTI"
957
+ components = []
852
958
  if method := join.args.get("method"):
853
959
  components.append(method)
854
-
855
- # Add side if present (e.g. "LEFT", "RIGHT")
960
+ if join.args.get("global"):
961
+ components.append("GLOBAL")
856
962
  if side := join.args.get("side"):
963
+ # For SEMI/ANTI joins, side is optional
857
964
  components.append(side)
858
-
859
- # Add kind if present (e.g. "INNER", "OUTER", "SEMI", "ANTI")
860
965
  if kind := join.args.get("kind"):
861
966
  components.append(kind)
862
967
 
863
- # Join the components and append "JOIN"
864
- if not components:
865
- return "JOIN"
866
- return f"{' '.join(components)} JOIN"
968
+ components.append("JOIN")
969
+ return " ".join(components)
867
970
 
868
971
 
869
972
  def _extract_select_from_create(
@@ -1061,7 +1164,12 @@ def _translate_internal_joins(
1061
1164
  joins.append(
1062
1165
  JoinInfo(
1063
1166
  join_type=raw_join.join_type,
1064
- tables=[table_name_urn_mapping[table] for table in raw_join.tables],
1167
+ left_tables=[
1168
+ table_name_urn_mapping[table] for table in raw_join.left_tables
1169
+ ],
1170
+ right_tables=[
1171
+ table_name_urn_mapping[table] for table in raw_join.right_tables
1172
+ ],
1065
1173
  on_clause=raw_join.on_clause,
1066
1174
  columns_involved=[
1067
1175
  ColumnRef(
@@ -377,7 +377,8 @@ def _maybe_print_upgrade_message(
377
377
  + click.style(
378
378
  f"➡️ Upgrade via \"pip install 'acryl-datahub=={version_stats.server.current.version}'\"",
379
379
  fg="cyan",
380
- )
380
+ ),
381
+ err=True,
381
382
  )
382
383
  elif client_server_compat == 0 and encourage_cli_upgrade:
383
384
  with contextlib.suppress(Exception):
@@ -387,7 +388,8 @@ def _maybe_print_upgrade_message(
387
388
  + click.style(
388
389
  f"You seem to be running an old version of datahub cli: {current_version} {get_days(current_release_date)}. Latest version is {latest_version} {get_days(latest_release_date)}.\nUpgrade via \"pip install -U 'acryl-datahub'\"",
389
390
  fg="cyan",
390
- )
391
+ ),
392
+ err=True,
391
393
  )
392
394
  elif encourage_quickstart_upgrade:
393
395
  try: