acryl-datahub 1.2.0.7rc1__py3-none-any.whl → 1.2.0.7rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of acryl-datahub might be problematic.

Files changed (27)
  1. {acryl_datahub-1.2.0.7rc1.dist-info → acryl_datahub-1.2.0.7rc3.dist-info}/METADATA +2485 -2463
  2. {acryl_datahub-1.2.0.7rc1.dist-info → acryl_datahub-1.2.0.7rc3.dist-info}/RECORD +26 -27
  3. datahub/_version.py +1 -1
  4. datahub/ingestion/source/redshift/config.py +9 -6
  5. datahub/ingestion/source/redshift/lineage.py +386 -687
  6. datahub/ingestion/source/redshift/redshift.py +19 -106
  7. datahub/ingestion/source/snowflake/constants.py +2 -0
  8. datahub/ingestion/source/snowflake/snowflake_connection.py +15 -4
  9. datahub/ingestion/source/snowflake/snowflake_schema_gen.py +4 -1
  10. datahub/ingestion/source/snowflake/snowflake_utils.py +18 -5
  11. datahub/ingestion/source/snowflake/snowflake_v2.py +2 -0
  12. datahub/ingestion/source/sql/mssql/job_models.py +3 -1
  13. datahub/ingestion/source/sql/mssql/source.py +62 -3
  14. datahub/ingestion/source/unity/config.py +11 -0
  15. datahub/ingestion/source/unity/proxy.py +77 -0
  16. datahub/ingestion/source/unity/proxy_types.py +24 -0
  17. datahub/ingestion/source/unity/report.py +5 -0
  18. datahub/ingestion/source/unity/source.py +99 -1
  19. datahub/metadata/_internal_schema_classes.py +5 -5
  20. datahub/metadata/schema.avsc +66 -60
  21. datahub/metadata/schemas/LogicalParent.avsc +104 -100
  22. datahub/metadata/schemas/SchemaFieldKey.avsc +3 -1
  23. datahub/ingestion/source/redshift/lineage_v2.py +0 -466
  24. {acryl_datahub-1.2.0.7rc1.dist-info → acryl_datahub-1.2.0.7rc3.dist-info}/WHEEL +0 -0
  25. {acryl_datahub-1.2.0.7rc1.dist-info → acryl_datahub-1.2.0.7rc3.dist-info}/entry_points.txt +0 -0
  26. {acryl_datahub-1.2.0.7rc1.dist-info → acryl_datahub-1.2.0.7rc3.dist-info}/licenses/LICENSE +0 -0
  27. {acryl_datahub-1.2.0.7rc1.dist-info → acryl_datahub-1.2.0.7rc3.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,4 @@
 import functools
-import itertools
 import logging
 from collections import defaultdict
 from typing import Dict, Iterable, List, Optional, Type, Union
@@ -52,8 +51,7 @@ from datahub.ingestion.source.common.subtypes import (
 from datahub.ingestion.source.redshift.config import RedshiftConfig
 from datahub.ingestion.source.redshift.datashares import RedshiftDatasharesHelper
 from datahub.ingestion.source.redshift.exception import handle_redshift_exceptions_yield
-from datahub.ingestion.source.redshift.lineage import RedshiftLineageExtractor
-from datahub.ingestion.source.redshift.lineage_v2 import RedshiftSqlLineageV2
+from datahub.ingestion.source.redshift.lineage import RedshiftSqlLineage
 from datahub.ingestion.source.redshift.profile import RedshiftProfiler
 from datahub.ingestion.source.redshift.redshift_data_reader import RedshiftDataReader
 from datahub.ingestion.source.redshift.redshift_schema import (
@@ -72,7 +70,6 @@ from datahub.ingestion.source.sql.sql_utils import (
     add_table_to_schema_container,
     gen_database_container,
     gen_database_key,
-    gen_lineage,
     gen_schema_container,
     gen_schema_key,
     get_dataplatform_instance_aspect,
@@ -116,7 +113,6 @@ from datahub.metadata.com.linkedin.pegasus2avro.schema import (
 )
 from datahub.metadata.schema_classes import GlobalTagsClass, TagAssociationClass
 from datahub.utilities import memory_footprint
-from datahub.utilities.dedup_list import deduplicate_list
 from datahub.utilities.mapping import Constants
 from datahub.utilities.perf_timer import PerfTimer
 from datahub.utilities.registries.domain_registry import DomainRegistry
@@ -423,40 +419,25 @@ class RedshiftSource(StatefulIngestionSourceBase, TestableSource):
                 memory_footprint.total_size(self.db_views)
             )
 
-        if self.config.use_lineage_v2:
-            with RedshiftSqlLineageV2(
-                config=self.config,
-                report=self.report,
-                context=self.ctx,
-                database=database,
-                redundant_run_skip_handler=self.redundant_lineage_run_skip_handler,
-            ) as lineage_extractor:
-                yield from lineage_extractor.aggregator.register_schemas_from_stream(
-                    self.process_schemas(connection, database)
-                )
-
-                with self.report.new_stage(LINEAGE_EXTRACTION):
-                    yield from self.extract_lineage_v2(
-                        connection=connection,
-                        database=database,
-                        lineage_extractor=lineage_extractor,
-                    )
-
-                all_tables = self.get_all_tables()
-        else:
-            yield from self.process_schemas(connection, database)
+        with RedshiftSqlLineage(
+            config=self.config,
+            report=self.report,
+            context=self.ctx,
+            database=database,
+            redundant_run_skip_handler=self.redundant_lineage_run_skip_handler,
+        ) as lineage_extractor:
+            yield from lineage_extractor.aggregator.register_schemas_from_stream(
+                self.process_schemas(connection, database)
+            )
 
-            all_tables = self.get_all_tables()
+            with self.report.new_stage(LINEAGE_EXTRACTION):
+                yield from self.extract_lineage_v2(
+                    connection=connection,
+                    database=database,
+                    lineage_extractor=lineage_extractor,
+                )
 
-            if (
-                self.config.include_table_lineage
-                or self.config.include_view_lineage
-                or self.config.include_copy_lineage
-            ):
-                with self.report.new_stage(LINEAGE_EXTRACTION):
-                    yield from self.extract_lineage(
-                        connection=connection, all_tables=all_tables, database=database
-                    )
+            all_tables = self.get_all_tables()
 
         if self.config.include_usage_statistics:
             with self.report.new_stage(USAGE_EXTRACTION_INGESTION):
@@ -968,45 +949,11 @@ class RedshiftSource(StatefulIngestionSourceBase, TestableSource):
 
         self.report.usage_extraction_sec[database] = timer.elapsed_seconds(digits=2)
 
-    def extract_lineage(
-        self,
-        connection: redshift_connector.Connection,
-        database: str,
-        all_tables: Dict[str, Dict[str, List[Union[RedshiftView, RedshiftTable]]]],
-    ) -> Iterable[MetadataWorkUnit]:
-        if not self._should_ingest_lineage():
-            return
-
-        lineage_extractor = RedshiftLineageExtractor(
-            config=self.config,
-            report=self.report,
-            context=self.ctx,
-            redundant_run_skip_handler=self.redundant_lineage_run_skip_handler,
-        )
-
-        with PerfTimer() as timer:
-            lineage_extractor.populate_lineage(
-                database=database, connection=connection, all_tables=all_tables
-            )
-
-        self.report.lineage_extraction_sec[f"{database}"] = timer.elapsed_seconds(
-            digits=2
-        )
-        yield from self.generate_lineage(
-            database, lineage_extractor=lineage_extractor
-        )
-
-        if self.redundant_lineage_run_skip_handler:
-            # Update the checkpoint state for this run.
-            self.redundant_lineage_run_skip_handler.update_state(
-                self.config.start_time, self.config.end_time
-            )
-
     def extract_lineage_v2(
         self,
         connection: redshift_connector.Connection,
         database: str,
-        lineage_extractor: RedshiftSqlLineageV2,
+        lineage_extractor: RedshiftSqlLineage,
     ) -> Iterable[MetadataWorkUnit]:
         if self.config.include_share_lineage:
             outbound_shares = self.data_dictionary.get_outbound_datashares(connection)
@@ -1069,40 +1016,6 @@ class RedshiftSource(StatefulIngestionSourceBase, TestableSource):
 
         return True
 
-    def generate_lineage(
-        self, database: str, lineage_extractor: RedshiftLineageExtractor
-    ) -> Iterable[MetadataWorkUnit]:
-        logger.info(f"Generate lineage for {database}")
-        for schema in deduplicate_list(
-            itertools.chain(self.db_tables[database], self.db_views[database])
-        ):
-            if (
-                database not in self.db_schemas
-                or schema not in self.db_schemas[database]
-            ):
-                logger.warning(
-                    f"Either database {database} or {schema} exists in the lineage but was not discovered earlier. Something went wrong."
-                )
-                continue
-
-            table_or_view: Union[RedshiftTable, RedshiftView]
-            for table_or_view in (
-                []
-                + self.db_tables[database].get(schema, [])
-                + self.db_views[database].get(schema, [])
-            ):
-                datahub_dataset_name = f"{database}.{schema}.{table_or_view.name}"
-                dataset_urn = self.gen_dataset_urn(datahub_dataset_name)
-
-                lineage_info = lineage_extractor.get_lineage(
-                    table_or_view,
-                    dataset_urn,
-                    self.db_schemas[database][schema],
-                )
-                if lineage_info:
-                    # incremental lineage generation is taken care by auto_incremental_lineage
-                    yield from gen_lineage(dataset_urn, lineage_info)
-
     def add_config_to_report(self):
         self.report.stateful_lineage_ingestion_enabled = (
             self.config.enable_stateful_lineage_ingestion
@@ -9,6 +9,8 @@ class SnowflakeCloudProvider(StrEnum):
 
 SNOWFLAKE_DEFAULT_CLOUD = SnowflakeCloudProvider.AWS
 
+DEFAULT_SNOWFLAKE_DOMAIN = "snowflakecomputing.com"
+
 
 class SnowflakeEdition(StrEnum):
     STANDARD = "Standard"
@@ -22,6 +22,7 @@ from datahub.ingestion.api.closeable import Closeable
 from datahub.ingestion.source.snowflake.constants import (
     CLIENT_PREFETCH_THREADS,
     CLIENT_SESSION_KEEP_ALIVE,
+    DEFAULT_SNOWFLAKE_DOMAIN,
 )
 from datahub.ingestion.source.snowflake.oauth_config import (
     OAuthConfiguration,
@@ -47,8 +48,6 @@ _VALID_AUTH_TYPES: Dict[str, str] = {
     "OAUTH_AUTHENTICATOR_TOKEN": OAUTH_AUTHENTICATOR,
 }
 
-_SNOWFLAKE_HOST_SUFFIX = ".snowflakecomputing.com"
-
 
 class SnowflakePermissionError(MetaError):
     """A permission error has happened"""
@@ -110,6 +109,10 @@ class SnowflakeConnectionConfig(ConfigModel):
         default=None,
         description="OAuth token from external identity provider. Not recommended for most use cases because it will not be able to refresh once expired.",
     )
+    snowflake_domain: str = pydantic.Field(
+        default=DEFAULT_SNOWFLAKE_DOMAIN,
+        description="Snowflake domain. Use 'snowflakecomputing.com' for most regions or 'snowflakecomputing.cn' for China (cn-northwest-1) region.",
+    )
 
     def get_account(self) -> str:
         assert self.account_id
@@ -118,10 +121,13 @@ class SnowflakeConnectionConfig(ConfigModel):
     rename_host_port_to_account_id = pydantic_renamed_field("host_port", "account_id")
 
     @pydantic.validator("account_id")
-    def validate_account_id(cls, account_id: str) -> str:
+    def validate_account_id(cls, account_id: str, values: Dict) -> str:
         account_id = remove_protocol(account_id)
         account_id = remove_trailing_slashes(account_id)
-        account_id = remove_suffix(account_id, _SNOWFLAKE_HOST_SUFFIX)
+        # Get the domain from config, fallback to default
+        domain = values.get("snowflake_domain", DEFAULT_SNOWFLAKE_DOMAIN)
+        snowflake_host_suffix = f".{domain}"
+        account_id = remove_suffix(account_id, snowflake_host_suffix)
         return account_id
 
     @pydantic.validator("authentication_type", always=True)
@@ -311,6 +317,7 @@ class SnowflakeConnectionConfig(ConfigModel):
             warehouse=self.warehouse,
             authenticator=_VALID_AUTH_TYPES.get(self.authentication_type),
             application=_APPLICATION_NAME,
+            host=f"{self.account_id}.{self.snowflake_domain}",
             **connect_args,
         )
 
@@ -324,6 +331,7 @@ class SnowflakeConnectionConfig(ConfigModel):
             role=self.role,
             authenticator=_VALID_AUTH_TYPES.get(self.authentication_type),
             application=_APPLICATION_NAME,
+            host=f"{self.account_id}.{self.snowflake_domain}",
             **connect_args,
         )
 
@@ -337,6 +345,7 @@ class SnowflakeConnectionConfig(ConfigModel):
                 warehouse=self.warehouse,
                 role=self.role,
                 application=_APPLICATION_NAME,
+                host=f"{self.account_id}.{self.snowflake_domain}",
                 **connect_args,
             )
         elif self.authentication_type == "OAUTH_AUTHENTICATOR_TOKEN":
@@ -348,6 +357,7 @@ class SnowflakeConnectionConfig(ConfigModel):
                 warehouse=self.warehouse,
                 role=self.role,
                 application=_APPLICATION_NAME,
+                host=f"{self.account_id}.{self.snowflake_domain}",
                 **connect_args,
             )
         elif self.authentication_type == "OAUTH_AUTHENTICATOR":
@@ -363,6 +373,7 @@ class SnowflakeConnectionConfig(ConfigModel):
                 role=self.role,
                 authenticator=_VALID_AUTH_TYPES.get(self.authentication_type),
                 application=_APPLICATION_NAME,
+                host=f"{self.account_id}.{self.snowflake_domain}",
                 **connect_args,
             )
         else:
@@ -441,13 +441,16 @@ class SnowflakeSchemaGenerator(SnowflakeStructuredReportMixin):
             tables = self.fetch_tables_for_schema(
                 snowflake_schema, db_name, schema_name
             )
+        if self.config.include_views:
+            views = self.fetch_views_for_schema(snowflake_schema, db_name, schema_name)
+
+        if self.config.include_tables:
             db_tables[schema_name] = tables
             yield from self._process_tables(
                 tables, snowflake_schema, db_name, schema_name
             )
 
         if self.config.include_views:
-            views = self.fetch_views_for_schema(snowflake_schema, db_name, schema_name)
             yield from self._process_views(
                 views, snowflake_schema, db_name, schema_name
             )
@@ -9,6 +9,7 @@ from datahub.emitter.mce_builder import (
 from datahub.emitter.mcp_builder import DatabaseKey, SchemaKey
 from datahub.ingestion.api.source import SourceReport
 from datahub.ingestion.source.snowflake.constants import (
+    DEFAULT_SNOWFLAKE_DOMAIN,
     SNOWFLAKE_REGION_CLOUD_REGION_MAPPING,
     SnowflakeCloudProvider,
     SnowflakeObjectDomain,
@@ -34,16 +35,21 @@ class SnowsightUrlBuilder:
         "us-east-1",
         "eu-west-1",
         "eu-central-1",
-        "ap-southeast-1",
         "ap-southeast-2",
     ]
 
     snowsight_base_url: str
 
-    def __init__(self, account_locator: str, region: str, privatelink: bool = False):
+    def __init__(
+        self,
+        account_locator: str,
+        region: str,
+        privatelink: bool = False,
+        snowflake_domain: str = DEFAULT_SNOWFLAKE_DOMAIN,
+    ):
         cloud, cloud_region_id = self.get_cloud_region_from_snowflake_region_id(region)
         self.snowsight_base_url = self.create_snowsight_base_url(
-            account_locator, cloud_region_id, cloud, privatelink
+            account_locator, cloud_region_id, cloud, privatelink, snowflake_domain
         )
 
     @staticmethod
@@ -52,6 +58,7 @@ class SnowsightUrlBuilder:
         cloud_region_id: str,
         cloud: str,
         privatelink: bool = False,
+        snowflake_domain: str = DEFAULT_SNOWFLAKE_DOMAIN,
     ) -> str:
         if cloud:
             url_cloud_provider_suffix = f".{cloud}"
@@ -67,9 +74,15 @@ class SnowsightUrlBuilder:
         else:
             url_cloud_provider_suffix = f".{cloud}"
         if privatelink:
-            url = f"https://app.{account_locator}.{cloud_region_id}.privatelink.snowflakecomputing.com/"
+            url = f"https://app.{account_locator}.{cloud_region_id}.privatelink.{snowflake_domain}/"
         else:
-            url = f"https://app.snowflake.com/{cloud_region_id}{url_cloud_provider_suffix}/{account_locator}/"
+            # Standard Snowsight URL format - works for most regions
+            # China region may use app.snowflake.cn instead of app.snowflake.com. This is not documented, just
+            # guessing Based on existence of snowflake.cn domain (https://domainindex.com/domains/snowflake.cn)
+            if snowflake_domain == "snowflakecomputing.cn":
+                url = f"https://app.snowflake.cn/{cloud_region_id}{url_cloud_provider_suffix}/{account_locator}/"
+            else:
+                url = f"https://app.snowflake.com/{cloud_region_id}{url_cloud_provider_suffix}/{account_locator}/"
         return url
 
     @staticmethod
@@ -199,6 +199,7 @@ class SnowflakeV2Source(
             ),
             generate_usage_statistics=False,
             generate_operations=False,
+            generate_queries=self.config.include_queries,
             format_queries=self.config.format_sql_queries,
             is_temp_table=self._is_temp_table,
             is_allowed_table=self._is_allowed_table,
@@ -750,6 +751,7 @@ class SnowflakeV2Source(
                 # For privatelink, account identifier ends with .privatelink
                 # See https://docs.snowflake.com/en/user-guide/organizations-connect.html#private-connectivity-urls
                 privatelink=self.config.account_id.endswith(".privatelink"),
+                snowflake_domain=self.config.snowflake_domain,
             )
 
         except Exception as e:
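For illustration (not part of the diff): a hedged sketch of a Snowflake source config fragment, written as a Python dict, that exercises the two options wired up above. The account_id is a placeholder and other required connection settings are omitted.

# Hedged recipe fragment (placeholder values; required connection fields omitted).
snowflake_source_config = {
    "account_id": "myorg-myaccount",              # placeholder
    "snowflake_domain": "snowflakecomputing.cn",  # only needed for the China (cn-northwest-1) region
    "include_queries": True,                      # now forwarded to the SQL aggregator as generate_queries
}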
@@ -134,7 +134,9 @@ class StoredProcedure:
 
     @property
     def escape_full_name(self) -> str:
-        return f"[{self.db}].[{self.schema}].[{self.formatted_name}]"
+        return f"[{self.db}].[{self.schema}].[{self.formatted_name}]".replace(
+            "'", r"''"
+        )
 
     def to_base_procedure(self) -> BaseProcedure:
         return BaseProcedure(
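For illustration (not part of the diff): the reworked escape_full_name doubles single quotes so the bracketed name can be embedded safely in a T-SQL string literal; the database and procedure names below are made up.

# Made-up example of the quote doubling shown above.
escaped = "[DemoDB].[dbo].[usp_o'brien]".replace("'", r"''")
assert escaped == "[DemoDB].[dbo].[usp_o''brien]"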
@@ -10,6 +10,7 @@ from sqlalchemy import create_engine, inspect
 from sqlalchemy.engine.base import Connection
 from sqlalchemy.engine.reflection import Inspector
 from sqlalchemy.exc import ProgrammingError, ResourceClosedError
+from sqlalchemy.sql import quoted_name
 
 import datahub.metadata.schema_classes as models
 from datahub.configuration.common import AllowDenyPattern
@@ -130,10 +131,14 @@ class SQLServerConfig(BasicSQLAlchemyConfig):
         "match the entire table name in database.schema.table format. Defaults are to set in such a way "
         "to ignore the temporary staging tables created by known ETL tools.",
     )
+    quote_schemas: bool = Field(
+        default=False,
+        description="Represent a schema identifiers combined with quoting preferences. See [sqlalchemy quoted_name docs](https://docs.sqlalchemy.org/en/20/core/sqlelement.html#sqlalchemy.sql.expression.quoted_name).",
+    )
 
     @pydantic.validator("uri_args")
     def passwords_match(cls, v, values, **kwargs):
-        if values["use_odbc"] and "driver" not in v:
+        if values["use_odbc"] and not values["sqlalchemy_uri"] and "driver" not in v:
             raise ValueError("uri_args must contain a 'driver' option")
         elif not values["use_odbc"] and v:
             raise ValueError("uri_args is not supported when ODBC is disabled")
@@ -159,7 +164,15 @@ class SQLServerConfig(BasicSQLAlchemyConfig):
             uri_opts=uri_opts,
         )
         if self.use_odbc:
-            uri = f"{uri}?{urllib.parse.urlencode(self.uri_args)}"
+            final_uri_args = self.uri_args.copy()
+            if final_uri_args and current_db:
+                final_uri_args.update({"database": current_db})
+
+            uri = (
+                f"{uri}?{urllib.parse.urlencode(final_uri_args)}"
+                if final_uri_args
+                else uri
+            )
         return uri
 
     @property
@@ -923,7 +936,11 @@ class SQLServerSource(SQLAlchemySource):
         logger.debug(f"sql_alchemy_url={url}")
         engine = create_engine(url, **self.config.options)
 
-        if self.config.database and self.config.database != "":
+        if (
+            self.config.database
+            and self.config.database != ""
+            or (self.config.sqlalchemy_uri and self.config.sqlalchemy_uri != "")
+        ):
             inspector = inspect(engine)
             yield inspector
         else:
@@ -1020,3 +1037,45 @@ class SQLServerSource(SQLAlchemySource):
             if self.config.convert_urns_to_lowercase
             else table_ref_str
         )
+
+    def get_allowed_schemas(self, inspector: Inspector, db_name: str) -> Iterable[str]:
+        for schema in super().get_allowed_schemas(inspector, db_name):
+            if self.config.quote_schemas:
+                yield quoted_name(schema, True)
+            else:
+                yield schema
+
+    def get_db_name(self, inspector: Inspector) -> str:
+        engine = inspector.engine
+
+        try:
+            if (
+                engine
+                and hasattr(engine, "url")
+                and hasattr(engine.url, "database")
+                and engine.url.database
+            ):
+                return str(engine.url.database).strip('"')
+
+            if (
+                engine
+                and hasattr(engine, "url")
+                and hasattr(engine.url, "query")
+                and "odbc_connect" in engine.url.query
+            ):
+                # According to the ODBC connection keywords: https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute?view=sql-server-ver17#supported-dsnconnection-string-keywords-and-connection-attributes
+                database = re.search(
+                    r"DATABASE=([^;]*);",
+                    urllib.parse.unquote_plus(str(engine.url.query["odbc_connect"])),
+                    flags=re.IGNORECASE,
+                )
+
+                if database and database.group(1):
+                    return database.group(1)
+
+            return ""
+
+        except Exception as e:
+            raise RuntimeError(
+                "Unable to get database name from Sqlalchemy inspector"
+            ) from e
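For illustration (not part of the diff): a standalone, hedged sketch of the DATABASE= lookup that the new get_db_name performs on an odbc_connect query string; the connection string below is a made-up example.

import re
import urllib.parse

# Made-up odbc_connect value, URL-encoded the way SQLAlchemy carries it.
odbc_connect = urllib.parse.quote_plus(
    "DRIVER={ODBC Driver 18 for SQL Server};SERVER=myhost;DATABASE=DemoDB;UID=user;PWD=secret;"
)
match = re.search(
    r"DATABASE=([^;]*);",
    urllib.parse.unquote_plus(odbc_connect),
    flags=re.IGNORECASE,
)
assert match is not None and match.group(1) == "DemoDB"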
@@ -312,6 +312,17 @@ class UnityCatalogSourceConfig(
 
     scheme: str = DATABRICKS
 
+    include_ml_model_aliases: bool = pydantic.Field(
+        default=False,
+        description="Whether to include ML model aliases in the ingestion.",
+    )
+
+    ml_model_max_results: int = pydantic.Field(
+        default=1000,
+        ge=0,
+        description="Maximum number of ML models to ingest.",
+    )
+
     def get_sql_alchemy_url(self, database: Optional[str] = None) -> str:
         uri_opts = {"http_path": f"/sql/1.0/warehouses/{self.warehouse_id}"}
         if database:
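For illustration (not part of the diff): a hedged sketch of the two new options as a recipe fragment in Python dict form; connection settings such as workspace_url and token are omitted.

# Hedged recipe fragment (connection settings omitted).
unity_catalog_source_config = {
    "include_ml_model_aliases": True,  # resolve alias info per model version (one extra GET each)
    "ml_model_max_results": 500,       # cap passed to the registered-models listing call
}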
@@ -17,6 +17,8 @@ from databricks.sdk.service.catalog import (
     ColumnInfo,
     GetMetastoreSummaryResponse,
     MetastoreInfo,
+    ModelVersionInfo,
+    RegisteredModelInfo,
     SchemaInfo,
     TableInfo,
 )
@@ -49,6 +51,8 @@ from datahub.ingestion.source.unity.proxy_types import (
     CustomCatalogType,
     ExternalTableReference,
     Metastore,
+    Model,
+    ModelVersion,
     Notebook,
     NotebookReference,
     Query,
@@ -251,6 +255,40 @@ class UnityCatalogApiProxy(UnityCatalogProxyProfilingMixin):
                 logger.warning(f"Error parsing table: {e}")
                 self.report.report_warning("table-parse", str(e))
 
+    def ml_models(
+        self, schema: Schema, max_results: Optional[int] = None
+    ) -> Iterable[Model]:
+        response = self._workspace_client.registered_models.list(
+            catalog_name=schema.catalog.name,
+            schema_name=schema.name,
+            max_results=max_results,
+        )
+        for ml_model in response:
+            optional_ml_model = self._create_ml_model(schema, ml_model)
+            if optional_ml_model:
+                yield optional_ml_model
+
+    def ml_model_versions(
+        self, ml_model: Model, include_aliases: bool = False
+    ) -> Iterable[ModelVersion]:
+        response = self._workspace_client.model_versions.list(
+            full_name=ml_model.id,
+            include_browse=True,
+            max_results=self.databricks_api_page_size,
+        )
+        for version in response:
+            if version.version is not None:
+                if include_aliases:
+                    # to get aliases info, use GET
+                    version = self._workspace_client.model_versions.get(
+                        ml_model.id, version.version, include_aliases=True
+                    )
+                optional_ml_model_version = self._create_ml_model_version(
+                    ml_model, version
+                )
+                if optional_ml_model_version:
+                    yield optional_ml_model_version
+
     def service_principals(self) -> Iterable[ServicePrincipal]:
         for principal in self._workspace_client.service_principals.list():
             optional_sp = self._create_service_principal(principal)
@@ -862,6 +900,45 @@ class UnityCatalogApiProxy(UnityCatalogProxyProfilingMixin):
             if optional_column:
                 yield optional_column
 
+    def _create_ml_model(
+        self, schema: Schema, obj: RegisteredModelInfo
+    ) -> Optional[Model]:
+        if not obj.name or not obj.full_name:
+            self.report.num_ml_models_missing_name += 1
+            return None
+        return Model(
+            id=obj.full_name,
+            name=obj.name,
+            description=obj.comment,
+            schema_name=schema.name,
+            catalog_name=schema.catalog.name,
+            created_at=parse_ts_millis(obj.created_at),
+            updated_at=parse_ts_millis(obj.updated_at),
+        )
+
+    def _create_ml_model_version(
+        self, model: Model, obj: ModelVersionInfo
+    ) -> Optional[ModelVersion]:
+        if obj.version is None:
+            return None
+
+        aliases = []
+        if obj.aliases:
+            for alias in obj.aliases:
+                if alias.alias_name:
+                    aliases.append(alias.alias_name)
+        return ModelVersion(
+            id=f"{model.id}_{obj.version}",
+            name=f"{model.name}_{obj.version}",
+            model=model,
+            version=str(obj.version),
+            aliases=aliases,
+            description=obj.comment,
+            created_at=parse_ts_millis(obj.created_at),
+            updated_at=parse_ts_millis(obj.updated_at),
+            created_by=obj.created_by,
+        )
+
     def _create_service_principal(
         self, obj: DatabricksServicePrincipal
     ) -> Optional[ServicePrincipal]:
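For illustration (not part of the diff): a hedged usage sketch of the new proxy methods, assuming proxy, schema, and config objects already exist in scope.

# Hedged usage sketch; `proxy`, `schema`, and `config` are assumed to exist.
for model in proxy.ml_models(schema, max_results=config.ml_model_max_results):
    for version in proxy.ml_model_versions(model, include_aliases=config.include_ml_model_aliases):
        print(model.id, version.version, version.aliases)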
@@ -337,3 +337,27 @@ class Notebook:
                 "upstreams": frozenset([*notebook.upstreams, upstream]),
             }
         )
+
+
+@dataclass
+class Model:
+    id: str
+    name: str
+    schema_name: str
+    catalog_name: str
+    description: Optional[str]
+    created_at: Optional[datetime]
+    updated_at: Optional[datetime]
+
+
+@dataclass
+class ModelVersion:
+    id: str
+    name: str
+    model: Model
+    version: str
+    aliases: Optional[List[str]]
+    description: Optional[str]
+    created_at: Optional[datetime]
+    updated_at: Optional[datetime]
+    created_by: Optional[str]
31
31
  tables: EntityFilterReport = EntityFilterReport.field(type="table/view")
32
32
  table_profiles: EntityFilterReport = EntityFilterReport.field(type="table profile")
33
33
  notebooks: EntityFilterReport = EntityFilterReport.field(type="notebook")
34
+ ml_models: EntityFilterReport = EntityFilterReport.field(type="ml_model")
35
+ ml_model_versions: EntityFilterReport = EntityFilterReport.field(
36
+ type="ml_model_version"
37
+ )
34
38
 
35
39
  hive_metastore_catalog_found: Optional[bool] = None
36
40
 
@@ -64,6 +68,7 @@ class UnityCatalogReport(IngestionStageReport, SQLSourceReport):
64
68
  num_catalogs_missing_name: int = 0
65
69
  num_schemas_missing_name: int = 0
66
70
  num_tables_missing_name: int = 0
71
+ num_ml_models_missing_name: int = 0
67
72
  num_columns_missing_name: int = 0
68
73
  num_queries_missing_info: int = 0
69
74