acryl-datahub 0.15.0.2rc6__py3-none-any.whl → 0.15.0.2rc7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


@@ -174,6 +174,8 @@ from datahub.utilities.perf_timer import PerfTimer
  from datahub.utilities.stats_collections import TopKDict
  from datahub.utilities.urns.dataset_urn import DatasetUrn

+ DEFAULT_PAGE_SIZE = 10
+
  try:
      # On earlier versions of the tableauserverclient, the NonXMLResponseError
      # was thrown when reauthentication was necessary. We'll keep both exceptions
@@ -342,11 +344,140 @@ class PermissionIngestionConfig(ConfigModel):
      )


+ class TableauPageSizeConfig(ConfigModel):
+     """
+     Configuration for setting page sizes for different Tableau metadata objects.
+
+     Some considerations:
+     - All have default values, so no setting is mandatory.
+     - In general, with the `effective_` methods, if not specifically set fine-grained metrics fallback to `page_size`
+       or correlate with `page_size`.
+
+     Measuring the impact of changing these values can be done by looking at the
+     `num_(filter_|paginated_)?queries_by_connection_type` metrics in the report.
+     """
+
+     page_size: int = Field(
+         default=DEFAULT_PAGE_SIZE,
+         description="[advanced] Number of metadata objects (e.g. CustomSQLTable, PublishedDatasource, etc) to query at a time using the Tableau API.",
+     )
+
+     database_server_page_size: Optional[int] = Field(
+         default=None,
+         description="[advanced] Number of database servers to query at a time using the Tableau API; fallbacks to `page_size` if not set.",
+     )
+
+     @property
+     def effective_database_server_page_size(self) -> int:
+         return self.database_server_page_size or self.page_size
+
+     # We've found that even with a small workbook page size (e.g. 10), the Tableau API often
+     # returns warnings like this:
+     # {
+     #     'message': 'Showing partial results. The request exceeded the 20000 node limit. Use pagination, additional filtering, or both in the query to adjust results.',
+     #     'extensions': {
+     #         'severity': 'WARNING',
+     #         'code': 'NODE_LIMIT_EXCEEDED',
+     #         'properties': {
+     #             'nodeLimit': 20000
+     #         }
+     #     }
+     # }
+     # Reducing the page size for the workbook queries helps to avoid this.
+     workbook_page_size: Optional[int] = Field(
+         default=1,
+         description="[advanced] Number of workbooks to query at a time using the Tableau API; defaults to `1` and fallbacks to `page_size` if not set.",
+     )
+
+     @property
+     def effective_workbook_page_size(self) -> int:
+         return self.workbook_page_size or self.page_size
+
+     sheet_page_size: Optional[int] = Field(
+         default=None,
+         description="[advanced] Number of sheets to query at a time using the Tableau API; fallbacks to `page_size` if not set.",
+     )
+
+     @property
+     def effective_sheet_page_size(self) -> int:
+         return self.sheet_page_size or self.page_size
+
+     dashboard_page_size: Optional[int] = Field(
+         default=None,
+         description="[advanced] Number of dashboards to query at a time using the Tableau API; fallbacks to `page_size` if not set.",
+     )
+
+     @property
+     def effective_dashboard_page_size(self) -> int:
+         return self.dashboard_page_size or self.page_size
+
+     embedded_datasource_page_size: Optional[int] = Field(
+         default=None,
+         description="[advanced] Number of embedded datasources to query at a time using the Tableau API; fallbacks to `page_size` if not set.",
+     )
+
+     @property
+     def effective_embedded_datasource_page_size(self) -> int:
+         return self.embedded_datasource_page_size or self.page_size
+
+     # Since the field upstream query was separated from the embedded datasource queries into an independent query,
+     # the number of queries increased significantly and so the execution time.
+     # To increase the batching and so reduce the number of queries, we can increase the page size for that
+     # particular case.
+     #
+     # That's why unless specifically set, we will effectively use 10 times the page size as the default page size.
+     embedded_datasource_field_upstream_page_size: Optional[int] = Field(
+         default=None,
+         description="[advanced] Number of upstream fields to query at a time for embedded datasources using the Tableau API; fallbacks to `page_size` * 10 if not set.",
+     )
+
+     @property
+     def effective_embedded_datasource_field_upstream_page_size(self) -> int:
+         return self.embedded_datasource_field_upstream_page_size or self.page_size * 10
+
+     published_datasource_page_size: Optional[int] = Field(
+         default=None,
+         description="[advanced] Number of published datasources to query at a time using the Tableau API; fallbacks to `page_size` if not set.",
+     )
+
+     @property
+     def effective_published_datasource_page_size(self) -> int:
+         return self.published_datasource_page_size or self.page_size
+
+     published_datasource_field_upstream_page_size: Optional[int] = Field(
+         default=None,
+         description="[advanced] Number of upstream fields to query at a time for published datasources using the Tableau API; fallbacks to `page_size` * 10 if not set.",
+     )
+
+     @property
+     def effective_published_datasource_field_upstream_page_size(self) -> int:
+         return self.published_datasource_field_upstream_page_size or self.page_size * 10
+
+     custom_sql_table_page_size: Optional[int] = Field(
+         default=None,
+         description="[advanced] Number of custom sql datasources to query at a time using the Tableau API; fallbacks to `page_size` if not set.",
+     )
+
+     @property
+     def effective_custom_sql_table_page_size(self) -> int:
+         return self.custom_sql_table_page_size or self.page_size
+
+     database_table_page_size: Optional[int] = Field(
+         default=None,
+         description="[advanced] Number of database tables to query at a time using the Tableau API; fallbacks to `page_size` if not set.",
+     )
+
+     @property
+     def effective_database_table_page_size(self) -> int:
+         return self.database_table_page_size or self.page_size
+
+
  class TableauConfig(
      DatasetLineageProviderConfigBase,
      StatefulIngestionConfigBase,
      DatasetSourceConfigMixin,
      TableauConnectionConfig,
+     TableauPageSizeConfig,
  ):
      projects: Optional[List[str]] = Field(
          default=["default"],
@@ -396,29 +527,6 @@ class TableauConfig(
          description="Ingest details for tables external to (not embedded in) tableau as entities.",
      )

-     page_size: int = Field(
-         default=10,
-         description="[advanced] Number of metadata objects (e.g. CustomSQLTable, PublishedDatasource, etc) to query at a time using the Tableau API.",
-     )
-
-     # We've found that even with a small workbook page size (e.g. 10), the Tableau API often
-     # returns warnings like this:
-     # {
-     #     'message': 'Showing partial results. The request exceeded the 20000 node limit. Use pagination, additional filtering, or both in the query to adjust results.',
-     #     'extensions': {
-     #         'severity': 'WARNING',
-     #         'code': 'NODE_LIMIT_EXCEEDED',
-     #         'properties': {
-     #             'nodeLimit': 20000
-     #         }
-     #     }
-     # }
-     # Reducing the page size for the workbook queries helps to avoid this.
-     workbook_page_size: int = Field(
-         default=1,
-         description="[advanced] Number of workbooks to query at a time using the Tableau API.",
-     )
-
      env: str = Field(
          default=builder.DEFAULT_ENV,
          description="Environment to use in namespace when constructing URNs.",
@@ -700,6 +808,23 @@ class TableauSourceReport(
          default_factory=(lambda: defaultdict(int))
      )

+     # Counters for tracking the number of queries made to get_connection_objects method
+     # by connection type (static and short set of keys):
+     # - num_queries_by_connection_type: total number of queries
+     # - num_filter_queries_by_connection_type: number of paginated queries due to splitting query filters
+     # - num_paginated_queries_by_connection_type: total number of queries due to Tableau pagination
+     # These counters are useful to understand the impact of changing the page size.
+
+     num_queries_by_connection_type: Dict[str, int] = dataclass_field(
+         default_factory=(lambda: defaultdict(int))
+     )
+     num_filter_queries_by_connection_type: Dict[str, int] = dataclass_field(
+         default_factory=(lambda: defaultdict(int))
+     )
+     num_paginated_queries_by_connection_type: Dict[str, int] = dataclass_field(
+         default_factory=(lambda: defaultdict(int))
+     )
+

  def report_user_role(report: TableauSourceReport, server: Server) -> None:
      title: str = "Insufficient Permissions"
@@ -994,7 +1119,9 @@ class TableauSiteSource:
              return server_connection

          for database_server in self.get_connection_objects(
-             database_servers_graphql_query, c.DATABASE_SERVERS_CONNECTION
+             query=database_servers_graphql_query,
+             connection_type=c.DATABASE_SERVERS_CONNECTION,
+             page_size=self.config.effective_database_server_page_size,
          ):
              database_server_id = database_server.get(c.ID)
              server_connection = database_server.get(c.HOST_NAME)
@@ -1420,22 +1547,30 @@ class TableauSiteSource:
          self,
          query: str,
          connection_type: str,
+         page_size: int,
          query_filter: dict = {},
-         page_size_override: Optional[int] = None,
      ) -> Iterable[dict]:
          query_filter = optimize_query_filter(query_filter)

          # Calls the get_connection_object_page function to get the objects,
          # and automatically handles pagination.
-         page_size = page_size_override or self.config.page_size

          filter_pages = get_filter_pages(query_filter, page_size)
+         self.report.num_queries_by_connection_type[connection_type] += 1
+         self.report.num_filter_queries_by_connection_type[connection_type] += len(
+             filter_pages
+         )
+
          for filter_page in filter_pages:
              has_next_page = 1
              current_cursor: Optional[str] = None
              while has_next_page:
                  filter_: str = make_filter(filter_page)

+                 self.report.num_paginated_queries_by_connection_type[
+                     connection_type
+                 ] += 1
+
                  self.report.num_expected_tableau_metadata_queries += 1
                  (
                      connection_objects,
@@ -1463,10 +1598,10 @@ class TableauSiteSource:
          projects = {c.PROJECT_NAME_WITH_IN: project_names}

          for workbook in self.get_connection_objects(
-             workbook_graphql_query,
-             c.WORKBOOKS_CONNECTION,
-             projects,
-             page_size_override=self.config.workbook_page_size,
+             query=workbook_graphql_query,
+             connection_type=c.WORKBOOKS_CONNECTION,
+             query_filter=projects,
+             page_size=self.config.effective_workbook_page_size,
          ):
              # This check is needed as we are using projectNameWithin which return project as per project name so if
              # user want to ingest only nested project C from A->B->C then tableau might return more than one Project
@@ -1921,9 +2056,10 @@ class TableauSiteSource:

          custom_sql_connection = list(
              self.get_connection_objects(
-                 custom_sql_graphql_query,
-                 c.CUSTOM_SQL_TABLE_CONNECTION,
-                 custom_sql_filter,
+                 query=custom_sql_graphql_query,
+                 connection_type=c.CUSTOM_SQL_TABLE_CONNECTION,
+                 query_filter=custom_sql_filter,
+                 page_size=self.config.effective_custom_sql_table_page_size,
              )
          )

@@ -2632,6 +2768,7 @@ class TableauSiteSource:
          self,
          datasource: dict,
          field_upstream_query: str,
+         page_size: int,
      ) -> dict:
          # Collect field ids to fetch field upstreams
          field_ids: List[str] = []
@@ -2642,9 +2779,10 @@ class TableauSiteSource:
          # Fetch field upstreams and arrange them in map
          field_vs_upstream: Dict[str, dict] = {}
          for field_upstream in self.get_connection_objects(
-             field_upstream_query,
-             c.FIELDS_CONNECTION,
-             {c.ID_WITH_IN: field_ids},
+             query=field_upstream_query,
+             connection_type=c.FIELDS_CONNECTION,
+             query_filter={c.ID_WITH_IN: field_ids},
+             page_size=page_size,
          ):
              if field_upstream.get(c.ID):
                  field_id = field_upstream[c.ID]
@@ -2667,13 +2805,15 @@ class TableauSiteSource:
          datasource_filter = {c.ID_WITH_IN: self.datasource_ids_being_used}

          for datasource in self.get_connection_objects(
-             published_datasource_graphql_query,
-             c.PUBLISHED_DATA_SOURCES_CONNECTION,
-             datasource_filter,
+             query=published_datasource_graphql_query,
+             connection_type=c.PUBLISHED_DATA_SOURCES_CONNECTION,
+             query_filter=datasource_filter,
+             page_size=self.config.effective_published_datasource_page_size,
          ):
              datasource = self.update_datasource_for_field_upstream(
                  datasource=datasource,
                  field_upstream_query=datasource_upstream_fields_graphql_query,
+                 page_size=self.config.effective_published_datasource_field_upstream_page_size,
              )

              yield from self.emit_datasource(datasource)
@@ -2689,11 +2829,12 @@ class TableauSiteSource:
              c.ID_WITH_IN: list(tableau_database_table_id_to_urn_map.keys())
          }

-         # Emmitting tables that came from Tableau metadata
+         # Emitting tables that came from Tableau metadata
          for tableau_table in self.get_connection_objects(
-             database_tables_graphql_query,
-             c.DATABASE_TABLES_CONNECTION,
-             tables_filter,
+             query=database_tables_graphql_query,
+             connection_type=c.DATABASE_TABLES_CONNECTION,
+             query_filter=tables_filter,
+             page_size=self.config.effective_database_table_page_size,
          ):
              database_table = self.database_tables[
                  tableau_database_table_id_to_urn_map[tableau_table[c.ID]]
@@ -2882,9 +3023,10 @@ class TableauSiteSource:
          sheets_filter = {c.ID_WITH_IN: self.sheet_ids}

          for sheet in self.get_connection_objects(
-             sheet_graphql_query,
-             c.SHEETS_CONNECTION,
-             sheets_filter,
+             query=sheet_graphql_query,
+             connection_type=c.SHEETS_CONNECTION,
+             query_filter=sheets_filter,
+             page_size=self.config.effective_sheet_page_size,
          ):
              if self.config.ingest_hidden_assets or not self._is_hidden_view(sheet):
                  yield from self.emit_sheets_as_charts(sheet, sheet.get(c.WORKBOOK))
@@ -3202,9 +3344,10 @@ class TableauSiteSource:
          dashboards_filter = {c.ID_WITH_IN: self.dashboard_ids}

          for dashboard in self.get_connection_objects(
-             dashboard_graphql_query,
-             c.DASHBOARDS_CONNECTION,
-             dashboards_filter,
+             query=dashboard_graphql_query,
+             connection_type=c.DASHBOARDS_CONNECTION,
+             query_filter=dashboards_filter,
+             page_size=self.config.effective_dashboard_page_size,
          ):
              if self.config.ingest_hidden_assets or not self._is_hidden_view(dashboard):
                  yield from self.emit_dashboard(dashboard, dashboard.get(c.WORKBOOK))
@@ -3349,13 +3492,15 @@ class TableauSiteSource:
          datasource_filter = {c.ID_WITH_IN: self.embedded_datasource_ids_being_used}

          for datasource in self.get_connection_objects(
-             embedded_datasource_graphql_query,
-             c.EMBEDDED_DATA_SOURCES_CONNECTION,
-             datasource_filter,
+             query=embedded_datasource_graphql_query,
+             connection_type=c.EMBEDDED_DATA_SOURCES_CONNECTION,
+             query_filter=datasource_filter,
+             page_size=self.config.effective_embedded_datasource_page_size,
          ):
              datasource = self.update_datasource_for_field_upstream(
                  datasource=datasource,
                  field_upstream_query=datasource_upstream_fields_graphql_query,
+                 page_size=self.config.effective_embedded_datasource_field_upstream_page_size,
              )
              yield from self.emit_datasource(
                  datasource,
@@ -642,8 +642,11 @@ class TableauUpstreamReference:

      @classmethod
      def create(
-         cls, d: dict, default_schema_map: Optional[Dict[str, str]] = None
+         cls, d: Dict, default_schema_map: Optional[Dict[str, str]] = None
      ) -> "TableauUpstreamReference":
+         if d is None:
+             raise ValueError("TableauUpstreamReference.create: d is None")
+
          # Values directly from `table` object from Tableau
          database_dict = (
              d.get(c.DATABASE) or {}
@@ -717,7 +720,7 @@ class TableauUpstreamReference:
          # schema

          # TODO: Validate the startswith check. Currently required for our integration tests
-         if full_name is None or not full_name.startswith("["):
+         if full_name is None:
              return None

          return full_name.replace("[", "").replace("]", "").split(".")
@@ -11,34 +11,25 @@ class DataHubSecretsClient:
      def __init__(self, graph: DataHubGraph):
          self.graph = graph

+     def _cleanup_secret_name(self, secret_names: List[str]) -> List[str]:
+         """Remove empty strings from the list of secret names."""
+         return [secret_name for secret_name in secret_names if secret_name]
+
      def get_secret_values(self, secret_names: List[str]) -> Dict[str, Optional[str]]:
          if len(secret_names) == 0:
              return {}

-         request_json = {
-             "query": """query getSecretValues($input: GetSecretValuesInput!) {\n
-                 getSecretValues(input: $input) {\n
-                     name\n
-                     value\n
-                 }\n
+         res_data = self.graph.execute_graphql(
+             query="""query getSecretValues($input: GetSecretValuesInput!) {
+                 getSecretValues(input: $input) {
+                     name
+                     value
+                 }
              }""",
-             "variables": {"input": {"secrets": secret_names}},
-         }
-         # TODO: Use graph.execute_graphql() instead.
-
-         # Fetch secrets using GraphQL API f
-         response = self.graph._session.post(
-             f"{self.graph.config.server}/api/graphql", json=request_json
+             variables={"input": {"secrets": self._cleanup_secret_name(secret_names)}},
          )
-         response.raise_for_status()
-
-         # Verify response
-         res_data = response.json()
-         if "errors" in res_data:
-             raise Exception("Failed to retrieve secrets from DataHub.")
-
          # Convert list of name, value secret pairs into a dict and return
-         secret_value_list = res_data["data"]["getSecretValues"]
+         secret_value_list = res_data["getSecretValues"]
          secret_value_dict = dict()
          for secret_value in secret_value_list:
              secret_value_dict[secret_value["name"]] = secret_value["value"]
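
A hedged usage sketch for the updated secrets client; the module path for `DataHubSecretsClient` and the server URL are assumptions, and the call now goes through `graph.execute_graphql()` with empty secret names dropped by `_cleanup_secret_name`:

    from datahub.ingestion.graph.client import DataHubGraph, DatahubClientConfig

    # Assumed module path for the client shown in the hunk above.
    from datahub.secret.datahub_secrets_client import DataHubSecretsClient

    graph = DataHubGraph(DatahubClientConfig(server="http://localhost:8080"))
    secrets_client = DataHubSecretsClient(graph=graph)

    # "" is filtered out before the GraphQL request is issued.
    values = secrets_client.get_secret_values(["MY_SECRET", ""])
    print(values)  # e.g. {"MY_SECRET": "..."} for secrets defined on the server
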
@@ -284,6 +284,7 @@ class SqlAggregatorReport(Report):

      # Queries.
      num_queries_entities_generated: int = 0
+     num_queries_used_in_lineage: Optional[int] = None
      num_queries_skipped_due_to_filters: int = 0

      # Usage-related.
@@ -1200,6 +1201,7 @@ class SqlParsingAggregator(Closeable):
          queries_generated: Set[QueryId] = set()

          yield from self._gen_lineage_mcps(queries_generated)
+         self.report.num_queries_used_in_lineage = len(queries_generated)
          yield from self._gen_usage_statistics_mcps()
          yield from self._gen_operation_mcps(queries_generated)
          yield from self._gen_remaining_queries(queries_generated)
@@ -1,10 +1,9 @@
- from datahub.sql_parsing._sqlglot_patch import SQLGLOT_PATCHED
-
  import dataclasses
  import functools
  import logging
  import traceback
  from collections import defaultdict
+ from datahub.sql_parsing._sqlglot_patch import SQLGLOT_PATCHED
  from typing import Any, Dict, List, Optional, Set, Tuple, TypeVar, Union

  import pydantic.dataclasses
@@ -1,9 +1,8 @@
- from datahub.sql_parsing._sqlglot_patch import SQLGLOT_PATCHED
-
  import functools
  import hashlib
  import logging
  import re
+ from datahub.sql_parsing._sqlglot_patch import SQLGLOT_PATCHED
  from typing import Dict, Iterable, Optional, Tuple, Union

  import sqlglot
@@ -1,7 +1,7 @@
  from collections import deque
  from itertools import chain
  from sys import getsizeof
- from typing import Any, Callable
+ from typing import Any, Iterator


  def total_size(o: Any, handlers: Any = {}) -> int:
@@ -15,7 +15,8 @@ def total_size(o: Any, handlers: Any = {}) -> int:
      Based on https://github.com/ActiveState/recipe-577504-compute-mem-footprint/blob/master/recipe.py
      """

-     dict_handler: Callable[[Any], chain[Any]] = lambda d: chain.from_iterable(d.items())
+     def dict_handler(d: dict) -> Iterator[Any]:
+         return chain.from_iterable(d.items())

      all_handlers = {
          tuple: iter,
@@ -1,7 +1,7 @@
  import functools
  import urllib.parse
  from abc import abstractmethod
- from typing import ClassVar, Dict, List, Optional, Type
+ from typing import ClassVar, Dict, List, Optional, Type, Union

  from deprecated import deprecated
  from typing_extensions import Self
@@ -86,12 +86,24 @@ class Urn:
          return self._entity_ids

      @classmethod
-     def from_string(cls, urn_str: str) -> Self:
-         """
-         Creates an Urn from its string representation.
+     def from_string(cls, urn_str: Union[str, "Urn"], /) -> Self:
+         """Create an Urn from its string representation.
+
+         When called against the base Urn class, this method will return a more specific Urn type where possible.
+
+         >>> from datahub.metadata.urns import DatasetUrn, Urn
+         >>> urn_str = 'urn:li:dataset:(urn:li:dataPlatform:snowflake,my_db.my_schema.my_table,PROD)'
+         >>> urn = Urn.from_string(urn_str)
+         >>> assert isinstance(urn, DatasetUrn)
+
+         When called against a specific Urn type (e.g. DatasetUrn.from_string), this method can
+         also be used for type narrowing.
+
+         >>> urn_str = 'urn:li:dataset:(urn:li:dataPlatform:snowflake,my_db.my_schema.my_table,PROD)'
+         >>> assert DatasetUrn.from_string(urn_str)

          Args:
-             urn_str: The string representation of the Urn.
+             urn_str: The string representation of the urn. Also accepts an existing Urn instance.

          Returns:
              Urn of the given string representation.
@@ -100,6 +112,17 @@ class Urn:
              InvalidUrnError: If the string representation is in invalid format.
          """

+         if isinstance(urn_str, Urn):
+             if issubclass(cls, _SpecificUrn) and isinstance(urn_str, cls):
+                 # Fast path - we're already the right type.
+
+                 # I'm not really sure why we need a type ignore here, but mypy doesn't really
+                 # understand the isinstance check above.
+                 return urn_str  # type: ignore
+
+             # Fall through, so that we can convert a generic Urn to a specific Urn type.
+             urn_str = urn_str.urn()
+
          # TODO: Add handling for url encoded urns e.g. urn%3A ...

          if not urn_str.startswith("urn:li:"):
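
A small sketch exercising the behaviour added above, based on the doctest examples in the new docstring (not part of the diff):

    from datahub.metadata.urns import DatasetUrn, Urn

    urn = Urn.from_string(
        "urn:li:dataset:(urn:li:dataPlatform:snowflake,my_db.my_schema.my_table,PROD)"
    )
    assert isinstance(urn, DatasetUrn)  # the base-class call returns the specific subtype

    same = DatasetUrn.from_string(urn)  # an existing Urn instance is now accepted
    assert same is urn                  # the fast path returns the instance unchanged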