acryl-datahub 1.2.0.7rc1__py3-none-any.whl → 1.2.0.7rc3__py3-none-any.whl
This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Potentially problematic release.
This version of acryl-datahub might be problematic.
- {acryl_datahub-1.2.0.7rc1.dist-info → acryl_datahub-1.2.0.7rc3.dist-info}/METADATA +2485 -2463
- {acryl_datahub-1.2.0.7rc1.dist-info → acryl_datahub-1.2.0.7rc3.dist-info}/RECORD +26 -27
- datahub/_version.py +1 -1
- datahub/ingestion/source/redshift/config.py +9 -6
- datahub/ingestion/source/redshift/lineage.py +386 -687
- datahub/ingestion/source/redshift/redshift.py +19 -106
- datahub/ingestion/source/snowflake/constants.py +2 -0
- datahub/ingestion/source/snowflake/snowflake_connection.py +15 -4
- datahub/ingestion/source/snowflake/snowflake_schema_gen.py +4 -1
- datahub/ingestion/source/snowflake/snowflake_utils.py +18 -5
- datahub/ingestion/source/snowflake/snowflake_v2.py +2 -0
- datahub/ingestion/source/sql/mssql/job_models.py +3 -1
- datahub/ingestion/source/sql/mssql/source.py +62 -3
- datahub/ingestion/source/unity/config.py +11 -0
- datahub/ingestion/source/unity/proxy.py +77 -0
- datahub/ingestion/source/unity/proxy_types.py +24 -0
- datahub/ingestion/source/unity/report.py +5 -0
- datahub/ingestion/source/unity/source.py +99 -1
- datahub/metadata/_internal_schema_classes.py +5 -5
- datahub/metadata/schema.avsc +66 -60
- datahub/metadata/schemas/LogicalParent.avsc +104 -100
- datahub/metadata/schemas/SchemaFieldKey.avsc +3 -1
- datahub/ingestion/source/redshift/lineage_v2.py +0 -466
- {acryl_datahub-1.2.0.7rc1.dist-info → acryl_datahub-1.2.0.7rc3.dist-info}/WHEEL +0 -0
- {acryl_datahub-1.2.0.7rc1.dist-info → acryl_datahub-1.2.0.7rc3.dist-info}/entry_points.txt +0 -0
- {acryl_datahub-1.2.0.7rc1.dist-info → acryl_datahub-1.2.0.7rc3.dist-info}/licenses/LICENSE +0 -0
- {acryl_datahub-1.2.0.7rc1.dist-info → acryl_datahub-1.2.0.7rc3.dist-info}/top_level.txt +0 -0
datahub/ingestion/source/redshift/redshift.py

@@ -1,5 +1,4 @@
 import functools
-import itertools
 import logging
 from collections import defaultdict
 from typing import Dict, Iterable, List, Optional, Type, Union
@@ -52,8 +51,7 @@ from datahub.ingestion.source.common.subtypes import (
 from datahub.ingestion.source.redshift.config import RedshiftConfig
 from datahub.ingestion.source.redshift.datashares import RedshiftDatasharesHelper
 from datahub.ingestion.source.redshift.exception import handle_redshift_exceptions_yield
-from datahub.ingestion.source.redshift.lineage import
-from datahub.ingestion.source.redshift.lineage_v2 import RedshiftSqlLineageV2
+from datahub.ingestion.source.redshift.lineage import RedshiftSqlLineage
 from datahub.ingestion.source.redshift.profile import RedshiftProfiler
 from datahub.ingestion.source.redshift.redshift_data_reader import RedshiftDataReader
 from datahub.ingestion.source.redshift.redshift_schema import (
@@ -72,7 +70,6 @@ from datahub.ingestion.source.sql.sql_utils import (
     add_table_to_schema_container,
     gen_database_container,
     gen_database_key,
-    gen_lineage,
     gen_schema_container,
     gen_schema_key,
     get_dataplatform_instance_aspect,
@@ -116,7 +113,6 @@ from datahub.metadata.com.linkedin.pegasus2avro.schema import (
 )
 from datahub.metadata.schema_classes import GlobalTagsClass, TagAssociationClass
 from datahub.utilities import memory_footprint
-from datahub.utilities.dedup_list import deduplicate_list
 from datahub.utilities.mapping import Constants
 from datahub.utilities.perf_timer import PerfTimer
 from datahub.utilities.registries.domain_registry import DomainRegistry
@@ -423,40 +419,25 @@ class RedshiftSource(StatefulIngestionSourceBase, TestableSource):
             memory_footprint.total_size(self.db_views)
         )

-
-
-
-
-
-
-
-
-
-
-            )
-
-            with self.report.new_stage(LINEAGE_EXTRACTION):
-                yield from self.extract_lineage_v2(
-                    connection=connection,
-                    database=database,
-                    lineage_extractor=lineage_extractor,
-                )
-
-            all_tables = self.get_all_tables()
-        else:
-            yield from self.process_schemas(connection, database)
+        with RedshiftSqlLineage(
+            config=self.config,
+            report=self.report,
+            context=self.ctx,
+            database=database,
+            redundant_run_skip_handler=self.redundant_lineage_run_skip_handler,
+        ) as lineage_extractor:
+            yield from lineage_extractor.aggregator.register_schemas_from_stream(
+                self.process_schemas(connection, database)
+            )

-
+            with self.report.new_stage(LINEAGE_EXTRACTION):
+                yield from self.extract_lineage_v2(
+                    connection=connection,
+                    database=database,
+                    lineage_extractor=lineage_extractor,
+                )

-
-                self.config.include_table_lineage
-                or self.config.include_view_lineage
-                or self.config.include_copy_lineage
-            ):
-                with self.report.new_stage(LINEAGE_EXTRACTION):
-                    yield from self.extract_lineage(
-                        connection=connection, all_tables=all_tables, database=database
-                    )
+        all_tables = self.get_all_tables()

         if self.config.include_usage_statistics:
             with self.report.new_stage(USAGE_EXTRACTION_INGESTION):
@@ -968,45 +949,11 @@ class RedshiftSource(StatefulIngestionSourceBase, TestableSource):

         self.report.usage_extraction_sec[database] = timer.elapsed_seconds(digits=2)

-    def extract_lineage(
-        self,
-        connection: redshift_connector.Connection,
-        database: str,
-        all_tables: Dict[str, Dict[str, List[Union[RedshiftView, RedshiftTable]]]],
-    ) -> Iterable[MetadataWorkUnit]:
-        if not self._should_ingest_lineage():
-            return
-
-        lineage_extractor = RedshiftLineageExtractor(
-            config=self.config,
-            report=self.report,
-            context=self.ctx,
-            redundant_run_skip_handler=self.redundant_lineage_run_skip_handler,
-        )
-
-        with PerfTimer() as timer:
-            lineage_extractor.populate_lineage(
-                database=database, connection=connection, all_tables=all_tables
-            )
-
-        self.report.lineage_extraction_sec[f"{database}"] = timer.elapsed_seconds(
-            digits=2
-        )
-        yield from self.generate_lineage(
-            database, lineage_extractor=lineage_extractor
-        )
-
-        if self.redundant_lineage_run_skip_handler:
-            # Update the checkpoint state for this run.
-            self.redundant_lineage_run_skip_handler.update_state(
-                self.config.start_time, self.config.end_time
-            )
-
     def extract_lineage_v2(
         self,
         connection: redshift_connector.Connection,
         database: str,
-        lineage_extractor:
+        lineage_extractor: RedshiftSqlLineage,
     ) -> Iterable[MetadataWorkUnit]:
         if self.config.include_share_lineage:
             outbound_shares = self.data_dictionary.get_outbound_datashares(connection)
@@ -1069,40 +1016,6 @@ class RedshiftSource(StatefulIngestionSourceBase, TestableSource):

         return True

-    def generate_lineage(
-        self, database: str, lineage_extractor: RedshiftLineageExtractor
-    ) -> Iterable[MetadataWorkUnit]:
-        logger.info(f"Generate lineage for {database}")
-        for schema in deduplicate_list(
-            itertools.chain(self.db_tables[database], self.db_views[database])
-        ):
-            if (
-                database not in self.db_schemas
-                or schema not in self.db_schemas[database]
-            ):
-                logger.warning(
-                    f"Either database {database} or {schema} exists in the lineage but was not discovered earlier. Something went wrong."
-                )
-                continue
-
-            table_or_view: Union[RedshiftTable, RedshiftView]
-            for table_or_view in (
-                []
-                + self.db_tables[database].get(schema, [])
-                + self.db_views[database].get(schema, [])
-            ):
-                datahub_dataset_name = f"{database}.{schema}.{table_or_view.name}"
-                dataset_urn = self.gen_dataset_urn(datahub_dataset_name)
-
-                lineage_info = lineage_extractor.get_lineage(
-                    table_or_view,
-                    dataset_urn,
-                    self.db_schemas[database][schema],
-                )
-                if lineage_info:
-                    # incremental lineage generation is taken care by auto_incremental_lineage
-                    yield from gen_lineage(dataset_urn, lineage_info)
-
     def add_config_to_report(self):
         self.report.stateful_lineage_ingestion_enabled = (
             self.config.enable_stateful_lineage_ingestion
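The replacement collapses the old dual code path (the legacy extract_lineage plus the separate v2 extractor) into a single context-managed RedshiftSqlLineage pass, and schema work units are now streamed through the extractor's aggregator so schemas are registered as they are emitted. A minimal sketch of that pattern, assuming nothing beyond the shape visible in the hunk; _ToyAggregator and _ToyLineage are hypothetical stand-ins, not the real classes:

    from contextlib import AbstractContextManager
    from typing import Iterable, Iterator, List


    class _ToyAggregator:
        """Hypothetical stand-in for the aggregator on RedshiftSqlLineage."""

        def __init__(self) -> None:
            self.registered: List[str] = []

        def register_schemas_from_stream(self, stream: Iterable[str]) -> Iterator[str]:
            # Pass work units through unchanged while recording them as a side effect.
            for unit in stream:
                self.registered.append(unit)
                yield unit


    class _ToyLineage(AbstractContextManager):
        """Hypothetical stand-in showing the context-manager shape of the extractor."""

        def __init__(self) -> None:
            self.aggregator = _ToyAggregator()

        def __exit__(self, exc_type, exc_value, traceback) -> None:
            # The real extractor would flush and close its resources here.
            return None


    with _ToyLineage() as lineage_extractor:
        emitted = list(
            lineage_extractor.aggregator.register_schemas_from_stream(
                iter(["schema_workunit_1", "schema_workunit_2"])
            )
        )
    print(emitted)                                   # work units still flow to the caller
    print(lineage_extractor.aggregator.registered)   # and are recorded along the way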
datahub/ingestion/source/snowflake/snowflake_connection.py

@@ -22,6 +22,7 @@ from datahub.ingestion.api.closeable import Closeable
 from datahub.ingestion.source.snowflake.constants import (
     CLIENT_PREFETCH_THREADS,
     CLIENT_SESSION_KEEP_ALIVE,
+    DEFAULT_SNOWFLAKE_DOMAIN,
 )
 from datahub.ingestion.source.snowflake.oauth_config import (
     OAuthConfiguration,
@@ -47,8 +48,6 @@ _VALID_AUTH_TYPES: Dict[str, str] = {
     "OAUTH_AUTHENTICATOR_TOKEN": OAUTH_AUTHENTICATOR,
 }

-_SNOWFLAKE_HOST_SUFFIX = ".snowflakecomputing.com"
-

 class SnowflakePermissionError(MetaError):
     """A permission error has happened"""
@@ -110,6 +109,10 @@ class SnowflakeConnectionConfig(ConfigModel):
         default=None,
         description="OAuth token from external identity provider. Not recommended for most use cases because it will not be able to refresh once expired.",
     )
+    snowflake_domain: str = pydantic.Field(
+        default=DEFAULT_SNOWFLAKE_DOMAIN,
+        description="Snowflake domain. Use 'snowflakecomputing.com' for most regions or 'snowflakecomputing.cn' for China (cn-northwest-1) region.",
+    )

     def get_account(self) -> str:
         assert self.account_id
@@ -118,10 +121,13 @@ class SnowflakeConnectionConfig(ConfigModel):
     rename_host_port_to_account_id = pydantic_renamed_field("host_port", "account_id")

     @pydantic.validator("account_id")
-    def validate_account_id(cls, account_id: str) -> str:
+    def validate_account_id(cls, account_id: str, values: Dict) -> str:
         account_id = remove_protocol(account_id)
         account_id = remove_trailing_slashes(account_id)
-
+        # Get the domain from config, fallback to default
+        domain = values.get("snowflake_domain", DEFAULT_SNOWFLAKE_DOMAIN)
+        snowflake_host_suffix = f".{domain}"
+        account_id = remove_suffix(account_id, snowflake_host_suffix)
         return account_id

     @pydantic.validator("authentication_type", always=True)
@@ -311,6 +317,7 @@ class SnowflakeConnectionConfig(ConfigModel):
                 warehouse=self.warehouse,
                 authenticator=_VALID_AUTH_TYPES.get(self.authentication_type),
                 application=_APPLICATION_NAME,
+                host=f"{self.account_id}.{self.snowflake_domain}",
                 **connect_args,
             )

@@ -324,6 +331,7 @@ class SnowflakeConnectionConfig(ConfigModel):
                 role=self.role,
                 authenticator=_VALID_AUTH_TYPES.get(self.authentication_type),
                 application=_APPLICATION_NAME,
+                host=f"{self.account_id}.{self.snowflake_domain}",
                 **connect_args,
             )

@@ -337,6 +345,7 @@ class SnowflakeConnectionConfig(ConfigModel):
                 warehouse=self.warehouse,
                 role=self.role,
                 application=_APPLICATION_NAME,
+                host=f"{self.account_id}.{self.snowflake_domain}",
                 **connect_args,
             )
         elif self.authentication_type == "OAUTH_AUTHENTICATOR_TOKEN":
@@ -348,6 +357,7 @@ class SnowflakeConnectionConfig(ConfigModel):
                 warehouse=self.warehouse,
                 role=self.role,
                 application=_APPLICATION_NAME,
+                host=f"{self.account_id}.{self.snowflake_domain}",
                 **connect_args,
             )
         elif self.authentication_type == "OAUTH_AUTHENTICATOR":
@@ -363,6 +373,7 @@ class SnowflakeConnectionConfig(ConfigModel):
                 role=self.role,
                 authenticator=_VALID_AUTH_TYPES.get(self.authentication_type),
                 application=_APPLICATION_NAME,
+                host=f"{self.account_id}.{self.snowflake_domain}",
                 **connect_args,
             )
         else:
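With the host suffix no longer hard-coded, validate_account_id strips whichever domain is configured. A small standalone sketch of that normalization, assuming only what the hunk shows and that DEFAULT_SNOWFLAKE_DOMAIN defaults to "snowflakecomputing.com" (consistent with the removed _SNOWFLAKE_HOST_SUFFIX); remove_suffix here is an illustrative stand-in for the helper the real code imports:

    DEFAULT_SNOWFLAKE_DOMAIN = "snowflakecomputing.com"


    def remove_suffix(value: str, suffix: str) -> str:
        # Illustrative stand-in for the utility used by the real validator.
        return value[: -len(suffix)] if suffix and value.endswith(suffix) else value


    def normalize_account_id(account_id: str, snowflake_domain: str = DEFAULT_SNOWFLAKE_DOMAIN) -> str:
        # Mirrors the new validator logic: drop ".<domain>" so a full host
        # collapses to the bare account identifier.
        return remove_suffix(account_id, f".{snowflake_domain}")


    print(normalize_account_id("myorg-myacct.snowflakecomputing.com"))
    # myorg-myacct
    print(normalize_account_id("myorg-myacct.snowflakecomputing.cn", "snowflakecomputing.cn"))
    # myorg-myacct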
datahub/ingestion/source/snowflake/snowflake_schema_gen.py

@@ -441,13 +441,16 @@ class SnowflakeSchemaGenerator(SnowflakeStructuredReportMixin):
         tables = self.fetch_tables_for_schema(
             snowflake_schema, db_name, schema_name
         )
+        if self.config.include_views:
+            views = self.fetch_views_for_schema(snowflake_schema, db_name, schema_name)
+
+        if self.config.include_tables:
             db_tables[schema_name] = tables
             yield from self._process_tables(
                 tables, snowflake_schema, db_name, schema_name
             )

         if self.config.include_views:
-            views = self.fetch_views_for_schema(snowflake_schema, db_name, schema_name)
             yield from self._process_views(
                 views, snowflake_schema, db_name, schema_name
             )
datahub/ingestion/source/snowflake/snowflake_utils.py

@@ -9,6 +9,7 @@ from datahub.emitter.mce_builder import (
 from datahub.emitter.mcp_builder import DatabaseKey, SchemaKey
 from datahub.ingestion.api.source import SourceReport
 from datahub.ingestion.source.snowflake.constants import (
+    DEFAULT_SNOWFLAKE_DOMAIN,
     SNOWFLAKE_REGION_CLOUD_REGION_MAPPING,
     SnowflakeCloudProvider,
     SnowflakeObjectDomain,
@@ -34,16 +35,21 @@ class SnowsightUrlBuilder:
         "us-east-1",
         "eu-west-1",
         "eu-central-1",
-        "ap-southeast-1",
         "ap-southeast-2",
     ]

     snowsight_base_url: str

-    def __init__(
+    def __init__(
+        self,
+        account_locator: str,
+        region: str,
+        privatelink: bool = False,
+        snowflake_domain: str = DEFAULT_SNOWFLAKE_DOMAIN,
+    ):
         cloud, cloud_region_id = self.get_cloud_region_from_snowflake_region_id(region)
         self.snowsight_base_url = self.create_snowsight_base_url(
-            account_locator, cloud_region_id, cloud, privatelink
+            account_locator, cloud_region_id, cloud, privatelink, snowflake_domain
         )

     @staticmethod
@@ -52,6 +58,7 @@ class SnowsightUrlBuilder:
         cloud_region_id: str,
         cloud: str,
         privatelink: bool = False,
+        snowflake_domain: str = DEFAULT_SNOWFLAKE_DOMAIN,
     ) -> str:
         if cloud:
             url_cloud_provider_suffix = f".{cloud}"
@@ -67,9 +74,15 @@ class SnowsightUrlBuilder:
         else:
             url_cloud_provider_suffix = f".{cloud}"
         if privatelink:
-            url = f"https://app.{account_locator}.{cloud_region_id}.privatelink.
+            url = f"https://app.{account_locator}.{cloud_region_id}.privatelink.{snowflake_domain}/"
         else:
-
+            # Standard Snowsight URL format - works for most regions
+            # China region may use app.snowflake.cn instead of app.snowflake.com. This is not documented, just
+            # guessing Based on existence of snowflake.cn domain (https://domainindex.com/domains/snowflake.cn)
+            if snowflake_domain == "snowflakecomputing.cn":
+                url = f"https://app.snowflake.cn/{cloud_region_id}{url_cloud_provider_suffix}/{account_locator}/"
+            else:
+                url = f"https://app.snowflake.com/{cloud_region_id}{url_cloud_provider_suffix}/{account_locator}/"
         return url

     @staticmethod
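The Snowsight URL builder now takes the domain as well; the branch added above picks app.snowflake.cn when the configured domain is snowflakecomputing.cn and app.snowflake.com otherwise. A simplified, self-contained restatement of that selection (the account locator and region below are made-up examples, and the real method additionally special-cases some AWS regions when deriving the cloud suffix):

    def snowsight_base_url(
        account_locator: str,
        cloud_region_id: str,
        cloud: str,
        privatelink: bool = False,
        snowflake_domain: str = "snowflakecomputing.com",
    ) -> str:
        # Simplified version of create_snowsight_base_url from the hunk above.
        url_cloud_provider_suffix = f".{cloud}" if cloud else ""
        if privatelink:
            return f"https://app.{account_locator}.{cloud_region_id}.privatelink.{snowflake_domain}/"
        app_host = (
            "app.snowflake.cn"
            if snowflake_domain == "snowflakecomputing.cn"
            else "app.snowflake.com"
        )
        return f"https://{app_host}/{cloud_region_id}{url_cloud_provider_suffix}/{account_locator}/"


    print(snowsight_base_url("xy12345", "cn-northwest-1", "aws", snowflake_domain="snowflakecomputing.cn"))
    # https://app.snowflake.cn/cn-northwest-1.aws/xy12345/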
datahub/ingestion/source/snowflake/snowflake_v2.py

@@ -199,6 +199,7 @@ class SnowflakeV2Source(
             ),
             generate_usage_statistics=False,
             generate_operations=False,
+            generate_queries=self.config.include_queries,
             format_queries=self.config.format_sql_queries,
             is_temp_table=self._is_temp_table,
             is_allowed_table=self._is_allowed_table,
@@ -750,6 +751,7 @@ class SnowflakeV2Source(
                 # For privatelink, account identifier ends with .privatelink
                 # See https://docs.snowflake.com/en/user-guide/organizations-connect.html#private-connectivity-urls
                 privatelink=self.config.account_id.endswith(".privatelink"),
+                snowflake_domain=self.config.snowflake_domain,
             )

         except Exception as e:
datahub/ingestion/source/sql/mssql/job_models.py

@@ -134,7 +134,9 @@ class StoredProcedure:

     @property
     def escape_full_name(self) -> str:
-        return f"[{self.db}].[{self.schema}].[{self.formatted_name}]"
+        return f"[{self.db}].[{self.schema}].[{self.formatted_name}]".replace(
+            "'", r"''"
+        )

     def to_base_procedure(self) -> BaseProcedure:
         return BaseProcedure(
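The escape_full_name change doubles any single quote in the bracketed name, which is the usual way to keep an embedded quote from terminating a T-SQL string literal. A standalone illustration of the same expression (the procedure name below is made up):

    def escape_full_name(db: str, schema: str, name: str) -> str:
        # Same expression as the updated property: bracket-quote the parts, then double single quotes.
        return f"[{db}].[{schema}].[{name}]".replace("'", r"''")


    print(escape_full_name("sales", "dbo", "customer's_report"))
    # [sales].[dbo].[customer''s_report]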
datahub/ingestion/source/sql/mssql/source.py

@@ -10,6 +10,7 @@ from sqlalchemy import create_engine, inspect
 from sqlalchemy.engine.base import Connection
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.exc import ProgrammingError, ResourceClosedError
+from sqlalchemy.sql import quoted_name

 import datahub.metadata.schema_classes as models
 from datahub.configuration.common import AllowDenyPattern
@@ -130,10 +131,14 @@ class SQLServerConfig(BasicSQLAlchemyConfig):
         "match the entire table name in database.schema.table format. Defaults are to set in such a way "
         "to ignore the temporary staging tables created by known ETL tools.",
     )
+    quote_schemas: bool = Field(
+        default=False,
+        description="Represent a schema identifiers combined with quoting preferences. See [sqlalchemy quoted_name docs](https://docs.sqlalchemy.org/en/20/core/sqlelement.html#sqlalchemy.sql.expression.quoted_name).",
+    )

     @pydantic.validator("uri_args")
     def passwords_match(cls, v, values, **kwargs):
-        if values["use_odbc"] and "driver" not in v:
+        if values["use_odbc"] and not values["sqlalchemy_uri"] and "driver" not in v:
             raise ValueError("uri_args must contain a 'driver' option")
         elif not values["use_odbc"] and v:
             raise ValueError("uri_args is not supported when ODBC is disabled")
@@ -159,7 +164,15 @@ class SQLServerConfig(BasicSQLAlchemyConfig):
             uri_opts=uri_opts,
         )
         if self.use_odbc:
-
+            final_uri_args = self.uri_args.copy()
+            if final_uri_args and current_db:
+                final_uri_args.update({"database": current_db})
+
+            uri = (
+                f"{uri}?{urllib.parse.urlencode(final_uri_args)}"
+                if final_uri_args
+                else uri
+            )
         return uri

     @property
@@ -923,7 +936,11 @@ class SQLServerSource(SQLAlchemySource):
         logger.debug(f"sql_alchemy_url={url}")
         engine = create_engine(url, **self.config.options)

-        if
+        if (
+            self.config.database
+            and self.config.database != ""
+            or (self.config.sqlalchemy_uri and self.config.sqlalchemy_uri != "")
+        ):
             inspector = inspect(engine)
             yield inspector
         else:
@@ -1020,3 +1037,45 @@ class SQLServerSource(SQLAlchemySource):
             if self.config.convert_urns_to_lowercase
             else table_ref_str
         )
+
+    def get_allowed_schemas(self, inspector: Inspector, db_name: str) -> Iterable[str]:
+        for schema in super().get_allowed_schemas(inspector, db_name):
+            if self.config.quote_schemas:
+                yield quoted_name(schema, True)
+            else:
+                yield schema
+
+    def get_db_name(self, inspector: Inspector) -> str:
+        engine = inspector.engine
+
+        try:
+            if (
+                engine
+                and hasattr(engine, "url")
+                and hasattr(engine.url, "database")
+                and engine.url.database
+            ):
+                return str(engine.url.database).strip('"')
+
+            if (
+                engine
+                and hasattr(engine, "url")
+                and hasattr(engine.url, "query")
+                and "odbc_connect" in engine.url.query
+            ):
+                # According to the ODBC connection keywords: https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute?view=sql-server-ver17#supported-dsnconnection-string-keywords-and-connection-attributes
+                database = re.search(
+                    r"DATABASE=([^;]*);",
+                    urllib.parse.unquote_plus(str(engine.url.query["odbc_connect"])),
+                    flags=re.IGNORECASE,
+                )
+
+                if database and database.group(1):
+                    return database.group(1)
+
+            return ""
+
+        except Exception as e:
+            raise RuntimeError(
+                "Unable to get database name from Sqlalchemy inspector"
+            ) from e
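The new get_db_name fallback parses the database out of an ODBC connect string when the engine URL itself carries no database attribute. A standalone sketch of that regex, using a made-up connection string; note that, as in the diff, the pattern expects a semicolon after the DATABASE value:

    import re
    import urllib.parse


    def database_from_odbc_connect(odbc_connect: str) -> str:
        # Same pattern as the new fallback: URL-decode the odbc_connect value,
        # then pull DATABASE=<name>; out of it (case-insensitive).
        decoded = urllib.parse.unquote_plus(odbc_connect)
        match = re.search(r"DATABASE=([^;]*);", decoded, flags=re.IGNORECASE)
        return match.group(1) if match else ""


    odbc_connect = urllib.parse.quote_plus(
        "DRIVER={ODBC Driver 18 for SQL Server};SERVER=localhost;DATABASE=DemoData;TrustServerCertificate=yes;"
    )
    print(database_from_odbc_connect(odbc_connect))
    # DemoData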
datahub/ingestion/source/unity/config.py

@@ -312,6 +312,17 @@ class UnityCatalogSourceConfig(

     scheme: str = DATABRICKS

+    include_ml_model_aliases: bool = pydantic.Field(
+        default=False,
+        description="Whether to include ML model aliases in the ingestion.",
+    )
+
+    ml_model_max_results: int = pydantic.Field(
+        default=1000,
+        ge=0,
+        description="Maximum number of ML models to ingest.",
+    )
+
     def get_sql_alchemy_url(self, database: Optional[str] = None) -> str:
         uri_opts = {"http_path": f"/sql/1.0/warehouses/{self.warehouse_id}"}
         if database:
datahub/ingestion/source/unity/proxy.py

@@ -17,6 +17,8 @@ from databricks.sdk.service.catalog import (
     ColumnInfo,
     GetMetastoreSummaryResponse,
     MetastoreInfo,
+    ModelVersionInfo,
+    RegisteredModelInfo,
     SchemaInfo,
     TableInfo,
 )
@@ -49,6 +51,8 @@ from datahub.ingestion.source.unity.proxy_types import (
     CustomCatalogType,
     ExternalTableReference,
     Metastore,
+    Model,
+    ModelVersion,
     Notebook,
     NotebookReference,
     Query,
@@ -251,6 +255,40 @@ class UnityCatalogApiProxy(UnityCatalogProxyProfilingMixin):
             logger.warning(f"Error parsing table: {e}")
             self.report.report_warning("table-parse", str(e))

+    def ml_models(
+        self, schema: Schema, max_results: Optional[int] = None
+    ) -> Iterable[Model]:
+        response = self._workspace_client.registered_models.list(
+            catalog_name=schema.catalog.name,
+            schema_name=schema.name,
+            max_results=max_results,
+        )
+        for ml_model in response:
+            optional_ml_model = self._create_ml_model(schema, ml_model)
+            if optional_ml_model:
+                yield optional_ml_model
+
+    def ml_model_versions(
+        self, ml_model: Model, include_aliases: bool = False
+    ) -> Iterable[ModelVersion]:
+        response = self._workspace_client.model_versions.list(
+            full_name=ml_model.id,
+            include_browse=True,
+            max_results=self.databricks_api_page_size,
+        )
+        for version in response:
+            if version.version is not None:
+                if include_aliases:
+                    # to get aliases info, use GET
+                    version = self._workspace_client.model_versions.get(
+                        ml_model.id, version.version, include_aliases=True
+                    )
+                optional_ml_model_version = self._create_ml_model_version(
+                    ml_model, version
+                )
+                if optional_ml_model_version:
+                    yield optional_ml_model_version
+
     def service_principals(self) -> Iterable[ServicePrincipal]:
         for principal in self._workspace_client.service_principals.list():
             optional_sp = self._create_service_principal(principal)
@@ -862,6 +900,45 @@ class UnityCatalogApiProxy(UnityCatalogProxyProfilingMixin):
             if optional_column:
                 yield optional_column

+    def _create_ml_model(
+        self, schema: Schema, obj: RegisteredModelInfo
+    ) -> Optional[Model]:
+        if not obj.name or not obj.full_name:
+            self.report.num_ml_models_missing_name += 1
+            return None
+        return Model(
+            id=obj.full_name,
+            name=obj.name,
+            description=obj.comment,
+            schema_name=schema.name,
+            catalog_name=schema.catalog.name,
+            created_at=parse_ts_millis(obj.created_at),
+            updated_at=parse_ts_millis(obj.updated_at),
+        )
+
+    def _create_ml_model_version(
+        self, model: Model, obj: ModelVersionInfo
+    ) -> Optional[ModelVersion]:
+        if obj.version is None:
+            return None
+
+        aliases = []
+        if obj.aliases:
+            for alias in obj.aliases:
+                if alias.alias_name:
+                    aliases.append(alias.alias_name)
+        return ModelVersion(
+            id=f"{model.id}_{obj.version}",
+            name=f"{model.name}_{obj.version}",
+            model=model,
+            version=str(obj.version),
+            aliases=aliases,
+            description=obj.comment,
+            created_at=parse_ts_millis(obj.created_at),
+            updated_at=parse_ts_millis(obj.updated_at),
+            created_by=obj.created_by,
+        )
+
     def _create_service_principal(
         self, obj: DatabricksServicePrincipal
     ) -> Optional[ServicePrincipal]:
datahub/ingestion/source/unity/proxy_types.py

@@ -337,3 +337,27 @@ class Notebook:
             "upstreams": frozenset([*notebook.upstreams, upstream]),
         }
     )
+
+
+@dataclass
+class Model:
+    id: str
+    name: str
+    schema_name: str
+    catalog_name: str
+    description: Optional[str]
+    created_at: Optional[datetime]
+    updated_at: Optional[datetime]
+
+
+@dataclass
+class ModelVersion:
+    id: str
+    name: str
+    model: Model
+    version: str
+    aliases: Optional[List[str]]
+    description: Optional[str]
+    created_at: Optional[datetime]
+    updated_at: Optional[datetime]
+    created_by: Optional[str]
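The two new dataclasses nest as model → version, with aliases carried on the version. A small self-contained instantiation for orientation; the class bodies are copied from the hunk above and the values are invented:

    from dataclasses import dataclass
    from datetime import datetime, timezone
    from typing import List, Optional


    @dataclass
    class Model:
        id: str
        name: str
        schema_name: str
        catalog_name: str
        description: Optional[str]
        created_at: Optional[datetime]
        updated_at: Optional[datetime]


    @dataclass
    class ModelVersion:
        id: str
        name: str
        model: Model
        version: str
        aliases: Optional[List[str]]
        description: Optional[str]
        created_at: Optional[datetime]
        updated_at: Optional[datetime]
        created_by: Optional[str]


    now = datetime.now(timezone.utc)
    model = Model("main.ml.churn", "churn", "ml", "main", None, now, now)
    version = ModelVersion(
        f"{model.id}_1", f"{model.name}_1", model, "1", ["champion"], None, now, now, "someone@example.com"
    )
    print(version.id, version.aliases)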
datahub/ingestion/source/unity/report.py

@@ -31,6 +31,10 @@ class UnityCatalogReport(IngestionStageReport, SQLSourceReport):
     tables: EntityFilterReport = EntityFilterReport.field(type="table/view")
     table_profiles: EntityFilterReport = EntityFilterReport.field(type="table profile")
     notebooks: EntityFilterReport = EntityFilterReport.field(type="notebook")
+    ml_models: EntityFilterReport = EntityFilterReport.field(type="ml_model")
+    ml_model_versions: EntityFilterReport = EntityFilterReport.field(
+        type="ml_model_version"
+    )

     hive_metastore_catalog_found: Optional[bool] = None

@@ -64,6 +68,7 @@ class UnityCatalogReport(IngestionStageReport, SQLSourceReport):
     num_catalogs_missing_name: int = 0
     num_schemas_missing_name: int = 0
     num_tables_missing_name: int = 0
+    num_ml_models_missing_name: int = 0
     num_columns_missing_name: int = 0
     num_queries_missing_info: int = 0
