acryl-datahub 1.2.0.7rc2__py3-none-any.whl → 1.2.0.7rc3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

The registry flags this release of acryl-datahub as potentially problematic.

Files changed (24)
  1. {acryl_datahub-1.2.0.7rc2.dist-info → acryl_datahub-1.2.0.7rc3.dist-info}/METADATA +2525 -2521
  2. {acryl_datahub-1.2.0.7rc2.dist-info → acryl_datahub-1.2.0.7rc3.dist-info}/RECORD +23 -24
  3. datahub/_version.py +1 -1
  4. datahub/ingestion/source/redshift/config.py +9 -6
  5. datahub/ingestion/source/redshift/lineage.py +386 -687
  6. datahub/ingestion/source/redshift/redshift.py +19 -106
  7. datahub/ingestion/source/snowflake/snowflake_schema_gen.py +4 -1
  8. datahub/ingestion/source/snowflake/snowflake_v2.py +1 -0
  9. datahub/ingestion/source/sql/mssql/job_models.py +3 -1
  10. datahub/ingestion/source/sql/mssql/source.py +62 -3
  11. datahub/ingestion/source/unity/config.py +11 -0
  12. datahub/ingestion/source/unity/proxy.py +77 -0
  13. datahub/ingestion/source/unity/proxy_types.py +24 -0
  14. datahub/ingestion/source/unity/report.py +5 -0
  15. datahub/ingestion/source/unity/source.py +99 -1
  16. datahub/metadata/_internal_schema_classes.py +5 -5
  17. datahub/metadata/schema.avsc +66 -60
  18. datahub/metadata/schemas/LogicalParent.avsc +104 -100
  19. datahub/metadata/schemas/SchemaFieldKey.avsc +3 -1
  20. datahub/ingestion/source/redshift/lineage_v2.py +0 -466
  21. {acryl_datahub-1.2.0.7rc2.dist-info → acryl_datahub-1.2.0.7rc3.dist-info}/WHEEL +0 -0
  22. {acryl_datahub-1.2.0.7rc2.dist-info → acryl_datahub-1.2.0.7rc3.dist-info}/entry_points.txt +0 -0
  23. {acryl_datahub-1.2.0.7rc2.dist-info → acryl_datahub-1.2.0.7rc3.dist-info}/licenses/LICENSE +0 -0
  24. {acryl_datahub-1.2.0.7rc2.dist-info → acryl_datahub-1.2.0.7rc3.dist-info}/top_level.txt +0 -0
datahub/ingestion/source/redshift/redshift.py

@@ -1,5 +1,4 @@
  import functools
- import itertools
  import logging
  from collections import defaultdict
  from typing import Dict, Iterable, List, Optional, Type, Union
@@ -52,8 +51,7 @@ from datahub.ingestion.source.common.subtypes import (
  from datahub.ingestion.source.redshift.config import RedshiftConfig
  from datahub.ingestion.source.redshift.datashares import RedshiftDatasharesHelper
  from datahub.ingestion.source.redshift.exception import handle_redshift_exceptions_yield
- from datahub.ingestion.source.redshift.lineage import RedshiftLineageExtractor
- from datahub.ingestion.source.redshift.lineage_v2 import RedshiftSqlLineageV2
+ from datahub.ingestion.source.redshift.lineage import RedshiftSqlLineage
  from datahub.ingestion.source.redshift.profile import RedshiftProfiler
  from datahub.ingestion.source.redshift.redshift_data_reader import RedshiftDataReader
  from datahub.ingestion.source.redshift.redshift_schema import (
@@ -72,7 +70,6 @@ from datahub.ingestion.source.sql.sql_utils import (
      add_table_to_schema_container,
      gen_database_container,
      gen_database_key,
-     gen_lineage,
      gen_schema_container,
      gen_schema_key,
      get_dataplatform_instance_aspect,
@@ -116,7 +113,6 @@ from datahub.metadata.com.linkedin.pegasus2avro.schema import (
  )
  from datahub.metadata.schema_classes import GlobalTagsClass, TagAssociationClass
  from datahub.utilities import memory_footprint
- from datahub.utilities.dedup_list import deduplicate_list
  from datahub.utilities.mapping import Constants
  from datahub.utilities.perf_timer import PerfTimer
  from datahub.utilities.registries.domain_registry import DomainRegistry
@@ -423,40 +419,25 @@ class RedshiftSource(StatefulIngestionSourceBase, TestableSource):
              memory_footprint.total_size(self.db_views)
          )

-         if self.config.use_lineage_v2:
-             with RedshiftSqlLineageV2(
-                 config=self.config,
-                 report=self.report,
-                 context=self.ctx,
-                 database=database,
-                 redundant_run_skip_handler=self.redundant_lineage_run_skip_handler,
-             ) as lineage_extractor:
-                 yield from lineage_extractor.aggregator.register_schemas_from_stream(
-                     self.process_schemas(connection, database)
-                 )
-
-                 with self.report.new_stage(LINEAGE_EXTRACTION):
-                     yield from self.extract_lineage_v2(
-                         connection=connection,
-                         database=database,
-                         lineage_extractor=lineage_extractor,
-                     )
-
-             all_tables = self.get_all_tables()
-         else:
-             yield from self.process_schemas(connection, database)
+         with RedshiftSqlLineage(
+             config=self.config,
+             report=self.report,
+             context=self.ctx,
+             database=database,
+             redundant_run_skip_handler=self.redundant_lineage_run_skip_handler,
+         ) as lineage_extractor:
+             yield from lineage_extractor.aggregator.register_schemas_from_stream(
+                 self.process_schemas(connection, database)
+             )

-             all_tables = self.get_all_tables()
+             with self.report.new_stage(LINEAGE_EXTRACTION):
+                 yield from self.extract_lineage_v2(
+                     connection=connection,
+                     database=database,
+                     lineage_extractor=lineage_extractor,
+                 )

-         if (
-             self.config.include_table_lineage
-             or self.config.include_view_lineage
-             or self.config.include_copy_lineage
-         ):
-             with self.report.new_stage(LINEAGE_EXTRACTION):
-                 yield from self.extract_lineage(
-                     connection=connection, all_tables=all_tables, database=database
-                 )
+         all_tables = self.get_all_tables()

          if self.config.include_usage_statistics:
              with self.report.new_stage(USAGE_EXTRACTION_INGESTION):
@@ -968,45 +949,11 @@ class RedshiftSource(StatefulIngestionSourceBase, TestableSource):

          self.report.usage_extraction_sec[database] = timer.elapsed_seconds(digits=2)

-     def extract_lineage(
-         self,
-         connection: redshift_connector.Connection,
-         database: str,
-         all_tables: Dict[str, Dict[str, List[Union[RedshiftView, RedshiftTable]]]],
-     ) -> Iterable[MetadataWorkUnit]:
-         if not self._should_ingest_lineage():
-             return
-
-         lineage_extractor = RedshiftLineageExtractor(
-             config=self.config,
-             report=self.report,
-             context=self.ctx,
-             redundant_run_skip_handler=self.redundant_lineage_run_skip_handler,
-         )
-
-         with PerfTimer() as timer:
-             lineage_extractor.populate_lineage(
-                 database=database, connection=connection, all_tables=all_tables
-             )
-
-         self.report.lineage_extraction_sec[f"{database}"] = timer.elapsed_seconds(
-             digits=2
-         )
-         yield from self.generate_lineage(
-             database, lineage_extractor=lineage_extractor
-         )
-
-         if self.redundant_lineage_run_skip_handler:
-             # Update the checkpoint state for this run.
-             self.redundant_lineage_run_skip_handler.update_state(
-                 self.config.start_time, self.config.end_time
-             )
-
      def extract_lineage_v2(
          self,
          connection: redshift_connector.Connection,
          database: str,
-         lineage_extractor: RedshiftSqlLineageV2,
+         lineage_extractor: RedshiftSqlLineage,
      ) -> Iterable[MetadataWorkUnit]:
          if self.config.include_share_lineage:
              outbound_shares = self.data_dictionary.get_outbound_datashares(connection)
@@ -1069,40 +1016,6 @@ class RedshiftSource(StatefulIngestionSourceBase, TestableSource):

          return True

-     def generate_lineage(
-         self, database: str, lineage_extractor: RedshiftLineageExtractor
-     ) -> Iterable[MetadataWorkUnit]:
-         logger.info(f"Generate lineage for {database}")
-         for schema in deduplicate_list(
-             itertools.chain(self.db_tables[database], self.db_views[database])
-         ):
-             if (
-                 database not in self.db_schemas
-                 or schema not in self.db_schemas[database]
-             ):
-                 logger.warning(
-                     f"Either database {database} or {schema} exists in the lineage but was not discovered earlier. Something went wrong."
-                 )
-                 continue
-
-             table_or_view: Union[RedshiftTable, RedshiftView]
-             for table_or_view in (
-                 []
-                 + self.db_tables[database].get(schema, [])
-                 + self.db_views[database].get(schema, [])
-             ):
-                 datahub_dataset_name = f"{database}.{schema}.{table_or_view.name}"
-                 dataset_urn = self.gen_dataset_urn(datahub_dataset_name)
-
-                 lineage_info = lineage_extractor.get_lineage(
-                     table_or_view,
-                     dataset_urn,
-                     self.db_schemas[database][schema],
-                 )
-                 if lineage_info:
-                     # incremental lineage generation is taken care by auto_incremental_lineage
-                     yield from gen_lineage(dataset_urn, lineage_info)
-
      def add_config_to_report(self):
          self.report.stateful_lineage_ingestion_enabled = (
              self.config.enable_stateful_lineage_ingestion
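The visible API change from these Redshift hunks is the consolidation onto a single SQL-parsing lineage extractor: the `use_lineage_v2` branch, the legacy `extract_lineage`/`generate_lineage` methods, and `RedshiftLineageExtractor` are all gone, and the files-changed list above shows lineage_v2.py deleted outright (+0 -466). A minimal before/after import sketch, for illustration only, for code that imported these classes directly:

# rc2: two lineage implementations, selected at runtime via config.use_lineage_v2
# from datahub.ingestion.source.redshift.lineage import RedshiftLineageExtractor
# from datahub.ingestion.source.redshift.lineage_v2 import RedshiftSqlLineageV2

# rc3: a single SQL-parsing implementation; lineage_v2.py no longer exists
from datahub.ingestion.source.redshift.lineage import RedshiftSqlLineage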
datahub/ingestion/source/snowflake/snowflake_schema_gen.py

@@ -441,13 +441,16 @@ class SnowflakeSchemaGenerator(SnowflakeStructuredReportMixin):
              tables = self.fetch_tables_for_schema(
                  snowflake_schema, db_name, schema_name
              )
+         if self.config.include_views:
+             views = self.fetch_views_for_schema(snowflake_schema, db_name, schema_name)
+
+         if self.config.include_tables:
              db_tables[schema_name] = tables
              yield from self._process_tables(
                  tables, snowflake_schema, db_name, schema_name
              )

          if self.config.include_views:
-             views = self.fetch_views_for_schema(snowflake_schema, db_name, schema_name)
              yield from self._process_views(
                  views, snowflake_schema, db_name, schema_name
              )
datahub/ingestion/source/snowflake/snowflake_v2.py

@@ -199,6 +199,7 @@ class SnowflakeV2Source(
              ),
              generate_usage_statistics=False,
              generate_operations=False,
+             generate_queries=self.config.include_queries,
              format_queries=self.config.format_sql_queries,
              is_temp_table=self._is_temp_table,
              is_allowed_table=self._is_allowed_table,
datahub/ingestion/source/sql/mssql/job_models.py

@@ -134,7 +134,9 @@ class StoredProcedure:

      @property
      def escape_full_name(self) -> str:
-         return f"[{self.db}].[{self.schema}].[{self.formatted_name}]"
+         return f"[{self.db}].[{self.schema}].[{self.formatted_name}]".replace(
+             "'", r"''"
+         )

      def to_base_procedure(self) -> BaseProcedure:
          return BaseProcedure(
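The change above doubles any single quote inside the bracketed full name, the usual T-SQL string-literal escaping. A quick standalone illustration with a made-up procedure name:

# Illustration only; the db/schema/procedure names are hypothetical.
db, schema, name = "Sales", "dbo", "Customer's_Proc"
escaped = f"[{db}].[{schema}].[{name}]".replace("'", "''")
print(escaped)  # [Sales].[dbo].[Customer''s_Proc]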
datahub/ingestion/source/sql/mssql/source.py

@@ -10,6 +10,7 @@ from sqlalchemy import create_engine, inspect
  from sqlalchemy.engine.base import Connection
  from sqlalchemy.engine.reflection import Inspector
  from sqlalchemy.exc import ProgrammingError, ResourceClosedError
+ from sqlalchemy.sql import quoted_name

  import datahub.metadata.schema_classes as models
  from datahub.configuration.common import AllowDenyPattern
@@ -130,10 +131,14 @@ class SQLServerConfig(BasicSQLAlchemyConfig):
          "match the entire table name in database.schema.table format. Defaults are to set in such a way "
          "to ignore the temporary staging tables created by known ETL tools.",
      )
+     quote_schemas: bool = Field(
+         default=False,
+         description="Represent a schema identifiers combined with quoting preferences. See [sqlalchemy quoted_name docs](https://docs.sqlalchemy.org/en/20/core/sqlelement.html#sqlalchemy.sql.expression.quoted_name).",
+     )

      @pydantic.validator("uri_args")
      def passwords_match(cls, v, values, **kwargs):
-         if values["use_odbc"] and "driver" not in v:
+         if values["use_odbc"] and not values["sqlalchemy_uri"] and "driver" not in v:
              raise ValueError("uri_args must contain a 'driver' option")
          elif not values["use_odbc"] and v:
              raise ValueError("uri_args is not supported when ODBC is disabled")
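The new `quote_schemas` option is consumed by the `get_allowed_schemas` override further down: when enabled, schema names are handed back to SQLAlchemy wrapped in `quoted_name`, so their quoting preference is preserved for reserved or unusual identifiers. A small sketch of what that wrapper does (the schema name is made up):

from sqlalchemy.sql import quoted_name

# quoted_name is a str subclass carrying an explicit quoting preference.
schema = quoted_name("Sales History", True)  # hypothetical schema name
print(schema)        # Sales History
print(schema.quote)  # True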
@@ -159,7 +164,15 @@ class SQLServerConfig(BasicSQLAlchemyConfig):
              uri_opts=uri_opts,
          )
          if self.use_odbc:
-             uri = f"{uri}?{urllib.parse.urlencode(self.uri_args)}"
+             final_uri_args = self.uri_args.copy()
+             if final_uri_args and current_db:
+                 final_uri_args.update({"database": current_db})
+
+             uri = (
+                 f"{uri}?{urllib.parse.urlencode(final_uri_args)}"
+                 if final_uri_args
+                 else uri
+             )
          return uri

      @property
@@ -923,7 +936,11 @@ class SQLServerSource(SQLAlchemySource):
          logger.debug(f"sql_alchemy_url={url}")
          engine = create_engine(url, **self.config.options)

-         if self.config.database and self.config.database != "":
+         if (
+             self.config.database
+             and self.config.database != ""
+             or (self.config.sqlalchemy_uri and self.config.sqlalchemy_uri != "")
+         ):
              inspector = inspect(engine)
              yield inspector
          else:
@@ -1020,3 +1037,45 @@ class SQLServerSource(SQLAlchemySource):
              if self.config.convert_urns_to_lowercase
              else table_ref_str
          )
+
+     def get_allowed_schemas(self, inspector: Inspector, db_name: str) -> Iterable[str]:
+         for schema in super().get_allowed_schemas(inspector, db_name):
+             if self.config.quote_schemas:
+                 yield quoted_name(schema, True)
+             else:
+                 yield schema
+
+     def get_db_name(self, inspector: Inspector) -> str:
+         engine = inspector.engine
+
+         try:
+             if (
+                 engine
+                 and hasattr(engine, "url")
+                 and hasattr(engine.url, "database")
+                 and engine.url.database
+             ):
+                 return str(engine.url.database).strip('"')
+
+             if (
+                 engine
+                 and hasattr(engine, "url")
+                 and hasattr(engine.url, "query")
+                 and "odbc_connect" in engine.url.query
+             ):
+                 # According to the ODBC connection keywords: https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute?view=sql-server-ver17#supported-dsnconnection-string-keywords-and-connection-attributes
+                 database = re.search(
+                     r"DATABASE=([^;]*);",
+                     urllib.parse.unquote_plus(str(engine.url.query["odbc_connect"])),
+                     flags=re.IGNORECASE,
+                 )
+
+                 if database and database.group(1):
+                     return database.group(1)
+
+             return ""
+
+         except Exception as e:
+             raise RuntimeError(
+                 "Unable to get database name from Sqlalchemy inspector"
+             ) from e
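For ODBC-style URLs, the new `get_db_name` falls back to parsing the `odbc_connect` query parameter with the regex shown above. A standalone illustration with a made-up connection string:

import re
import urllib.parse

# Hypothetical odbc_connect value, URL-encoded the way SQLAlchemy receives it.
odbc_connect = urllib.parse.quote_plus(
    "DRIVER={ODBC Driver 18 for SQL Server};SERVER=myhost;DATABASE=DemoDB;UID=user;PWD=secret;"
)
match = re.search(
    r"DATABASE=([^;]*);",
    urllib.parse.unquote_plus(odbc_connect),
    flags=re.IGNORECASE,
)
print(match.group(1) if match else "")  # DemoDB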
datahub/ingestion/source/unity/config.py

@@ -312,6 +312,17 @@ class UnityCatalogSourceConfig(

      scheme: str = DATABRICKS

+     include_ml_model_aliases: bool = pydantic.Field(
+         default=False,
+         description="Whether to include ML model aliases in the ingestion.",
+     )
+
+     ml_model_max_results: int = pydantic.Field(
+         default=1000,
+         ge=0,
+         description="Maximum number of ML models to ingest.",
+     )
+
      def get_sql_alchemy_url(self, database: Optional[str] = None) -> str:
          uri_opts = {"http_path": f"/sql/1.0/warehouses/{self.warehouse_id}"}
          if database:
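The two new Unity Catalog options default to `include_ml_model_aliases=False` and `ml_model_max_results=1000` (validated as >= 0). A minimal sketch of setting them through the config class; the connection fields shown are placeholders and other required fields may apply to a real setup:

from datahub.ingestion.source.unity.config import UnityCatalogSourceConfig

# Illustration only; workspace_url and token are hypothetical values.
config = UnityCatalogSourceConfig.parse_obj(
    {
        "workspace_url": "https://example.cloud.databricks.com",
        "token": "dapiXXXX",
        "include_ml_model_aliases": True,  # default False
        "ml_model_max_results": 200,       # default 1000
    }
)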
datahub/ingestion/source/unity/proxy.py

@@ -17,6 +17,8 @@ from databricks.sdk.service.catalog import (
      ColumnInfo,
      GetMetastoreSummaryResponse,
      MetastoreInfo,
+     ModelVersionInfo,
+     RegisteredModelInfo,
      SchemaInfo,
      TableInfo,
  )
@@ -49,6 +51,8 @@ from datahub.ingestion.source.unity.proxy_types import (
      CustomCatalogType,
      ExternalTableReference,
      Metastore,
+     Model,
+     ModelVersion,
      Notebook,
      NotebookReference,
      Query,
@@ -251,6 +255,40 @@ class UnityCatalogApiProxy(UnityCatalogProxyProfilingMixin):
              logger.warning(f"Error parsing table: {e}")
              self.report.report_warning("table-parse", str(e))

+     def ml_models(
+         self, schema: Schema, max_results: Optional[int] = None
+     ) -> Iterable[Model]:
+         response = self._workspace_client.registered_models.list(
+             catalog_name=schema.catalog.name,
+             schema_name=schema.name,
+             max_results=max_results,
+         )
+         for ml_model in response:
+             optional_ml_model = self._create_ml_model(schema, ml_model)
+             if optional_ml_model:
+                 yield optional_ml_model
+
+     def ml_model_versions(
+         self, ml_model: Model, include_aliases: bool = False
+     ) -> Iterable[ModelVersion]:
+         response = self._workspace_client.model_versions.list(
+             full_name=ml_model.id,
+             include_browse=True,
+             max_results=self.databricks_api_page_size,
+         )
+         for version in response:
+             if version.version is not None:
+                 if include_aliases:
+                     # to get aliases info, use GET
+                     version = self._workspace_client.model_versions.get(
+                         ml_model.id, version.version, include_aliases=True
+                     )
+                 optional_ml_model_version = self._create_ml_model_version(
+                     ml_model, version
+                 )
+                 if optional_ml_model_version:
+                     yield optional_ml_model_version
+
      def service_principals(self) -> Iterable[ServicePrincipal]:
          for principal in self._workspace_client.service_principals.list():
              optional_sp = self._create_service_principal(principal)
@@ -862,6 +900,45 @@ class UnityCatalogApiProxy(UnityCatalogProxyProfilingMixin):
              if optional_column:
                  yield optional_column

+     def _create_ml_model(
+         self, schema: Schema, obj: RegisteredModelInfo
+     ) -> Optional[Model]:
+         if not obj.name or not obj.full_name:
+             self.report.num_ml_models_missing_name += 1
+             return None
+         return Model(
+             id=obj.full_name,
+             name=obj.name,
+             description=obj.comment,
+             schema_name=schema.name,
+             catalog_name=schema.catalog.name,
+             created_at=parse_ts_millis(obj.created_at),
+             updated_at=parse_ts_millis(obj.updated_at),
+         )
+
+     def _create_ml_model_version(
+         self, model: Model, obj: ModelVersionInfo
+     ) -> Optional[ModelVersion]:
+         if obj.version is None:
+             return None
+
+         aliases = []
+         if obj.aliases:
+             for alias in obj.aliases:
+                 if alias.alias_name:
+                     aliases.append(alias.alias_name)
+         return ModelVersion(
+             id=f"{model.id}_{obj.version}",
+             name=f"{model.name}_{obj.version}",
+             model=model,
+             version=str(obj.version),
+             aliases=aliases,
+             description=obj.comment,
+             created_at=parse_ts_millis(obj.created_at),
+             updated_at=parse_ts_millis(obj.updated_at),
+             created_by=obj.created_by,
+         )
+
      def _create_service_principal(
          self, obj: DatabricksServicePrincipal
      ) -> Optional[ServicePrincipal]:
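Together, the new proxy methods page through registered models per schema and then through each model's versions, optionally re-fetching a version with `include_aliases=True` to pick up alias names. A simplified sketch of how they compose, assuming an already-constructed proxy and schema object:

# Illustration only: `proxy` is a UnityCatalogApiProxy and `schema` a Schema object
# obtained from the usual catalog/schema enumeration.
for model in proxy.ml_models(schema=schema, max_results=1000):
    for version in proxy.ml_model_versions(model, include_aliases=True):
        print(model.id, version.version, version.aliases)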
datahub/ingestion/source/unity/proxy_types.py

@@ -337,3 +337,27 @@ class Notebook:
                  "upstreams": frozenset([*notebook.upstreams, upstream]),
              }
          )
+
+
+ @dataclass
+ class Model:
+     id: str
+     name: str
+     schema_name: str
+     catalog_name: str
+     description: Optional[str]
+     created_at: Optional[datetime]
+     updated_at: Optional[datetime]
+
+
+ @dataclass
+ class ModelVersion:
+     id: str
+     name: str
+     model: Model
+     version: str
+     aliases: Optional[List[str]]
+     description: Optional[str]
+     created_at: Optional[datetime]
+     updated_at: Optional[datetime]
+     created_by: Optional[str]
datahub/ingestion/source/unity/report.py

@@ -31,6 +31,10 @@ class UnityCatalogReport(IngestionStageReport, SQLSourceReport):
      tables: EntityFilterReport = EntityFilterReport.field(type="table/view")
      table_profiles: EntityFilterReport = EntityFilterReport.field(type="table profile")
      notebooks: EntityFilterReport = EntityFilterReport.field(type="notebook")
+     ml_models: EntityFilterReport = EntityFilterReport.field(type="ml_model")
+     ml_model_versions: EntityFilterReport = EntityFilterReport.field(
+         type="ml_model_version"
+     )

      hive_metastore_catalog_found: Optional[bool] = None
@@ -64,6 +68,7 @@ class UnityCatalogReport(IngestionStageReport, SQLSourceReport):
      num_catalogs_missing_name: int = 0
      num_schemas_missing_name: int = 0
      num_tables_missing_name: int = 0
+     num_ml_models_missing_name: int = 0
      num_columns_missing_name: int = 0
      num_queries_missing_info: int = 0

datahub/ingestion/source/unity/source.py

@@ -12,6 +12,7 @@ from datahub.emitter.mce_builder import (
      make_dataset_urn_with_platform_instance,
      make_domain_urn,
      make_group_urn,
+     make_ml_model_group_urn,
      make_schema_field_urn,
      make_ts_millis,
      make_user_urn,
@@ -26,6 +27,7 @@ from datahub.emitter.mcp_builder import (
      UnitySchemaKey,
      UnitySchemaKeyWithMetastore,
      add_dataset_to_container,
+     add_entity_to_container,
      gen_containers,
  )
  from datahub.emitter.sql_parsing_builder import SqlParsingBuilder
@@ -87,6 +89,8 @@ from datahub.ingestion.source.unity.proxy_types import (
      CustomCatalogType,
      HiveTableType,
      Metastore,
+     Model,
+     ModelVersion,
      Notebook,
      NotebookId,
      Schema,
@@ -121,6 +125,7 @@ from datahub.metadata.schema_classes import (
      DatasetLineageTypeClass,
      DatasetPropertiesClass,
      DomainsClass,
+     MLModelPropertiesClass,
      MySqlDDLClass,
      NullTypeClass,
      OwnerClass,
@@ -134,7 +139,8 @@ from datahub.metadata.schema_classes import (
      UpstreamClass,
      UpstreamLineageClass,
  )
- from datahub.metadata.urns import TagUrn
+ from datahub.metadata.urns import MlModelGroupUrn, MlModelUrn, TagUrn
+ from datahub.sdk import MLModel, MLModelGroup
  from datahub.sql_parsing.schema_resolver import SchemaResolver
  from datahub.sql_parsing.sqlglot_lineage import (
      SqlParsingResult,
@@ -182,6 +188,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
      - metastores
      - schemas
      - tables and column lineage
+     - model and model versions
      """

      config: UnityCatalogSourceConfig
@@ -512,6 +519,7 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
          yield from self.gen_schema_containers(schema)
          try:
              yield from self.process_tables(schema)
+             yield from self.process_ml_models(schema)
          except Exception as e:
              logger.exception(f"Error parsing schema {schema}")
              self.report.report_warning(
@@ -665,6 +673,69 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
              )
          ]

+     def process_ml_models(self, schema: Schema) -> Iterable[MetadataWorkUnit]:
+         for ml_model in self.unity_catalog_api_proxy.ml_models(
+             schema=schema, max_results=self.config.ml_model_max_results
+         ):
+             yield from self.process_ml_model(ml_model, schema)
+             ml_model_urn = self.gen_ml_model_urn(ml_model.id)
+             for ml_model_version in self.unity_catalog_api_proxy.ml_model_versions(
+                 ml_model, include_aliases=self.config.include_ml_model_aliases
+             ):
+                 yield from self.process_ml_model_version(
+                     ml_model_urn, ml_model_version, schema
+                 )
+
+     def process_ml_model(
+         self, ml_model: Model, schema: Schema
+     ) -> Iterable[MetadataWorkUnit]:
+         ml_model_group = MLModelGroup(
+             id=ml_model.id,
+             name=ml_model.name,
+             platform=self.platform,
+             platform_instance=schema.name,
+             env=self.config.env,
+             description=ml_model.description,
+             created=ml_model.created_at,
+             last_modified=ml_model.updated_at,
+         )
+         yield from ml_model_group.as_workunits()
+         yield from self.add_model_to_schema_container(str(ml_model_group.urn), schema)
+         self.report.ml_models.processed(ml_model.id)
+
+     def process_ml_model_version(
+         self, ml_model_urn: str, ml_model_version: ModelVersion, schema: Schema
+     ) -> Iterable[MetadataWorkUnit]:
+         extra_aspects = []
+         if ml_model_version.created_at is not None:
+             created_time = int(ml_model_version.created_at.timestamp() * 1000)
+             created_actor = (
+                 f"urn:li:platformResource:{ml_model_version.created_by}"
+                 if ml_model_version.created_by
+                 else None
+             )
+             extra_aspects.append(
+                 MLModelPropertiesClass(
+                     created=TimeStampClass(time=created_time, actor=created_actor),
+                 )
+             )
+
+         ml_model = MLModel(
+             id=ml_model_version.id,
+             name=ml_model_version.name,
+             version=str(ml_model_version.version),
+             aliases=ml_model_version.aliases,
+             description=ml_model_version.description,
+             model_group=ml_model_urn,
+             platform=self.platform,
+             last_modified=ml_model_version.updated_at,
+             extra_aspects=extra_aspects,
+         )
+
+         yield from ml_model.as_workunits()
+         yield from self.add_model_version_to_schema_container(str(ml_model.urn), schema)
+         self.report.ml_model_versions.processed(ml_model_version.id)
+
      def ingest_lineage(self, table: Table) -> Optional[UpstreamLineageClass]:
          # Calculate datetime filters for lineage
          lineage_start_time = None
@@ -802,6 +873,13 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
              env=self.config.env,
          )

+     def gen_ml_model_urn(self, name: str) -> str:
+         return make_ml_model_group_urn(
+             platform=self.platform,
+             group_name=name,
+             env=self.config.env,
+         )
+
      def gen_notebook_urn(self, notebook: Union[Notebook, NotebookId]) -> str:
          notebook_id = notebook.id if isinstance(notebook, Notebook) else notebook
          return NotebookKey(
@@ -973,6 +1051,26 @@ class UnityCatalogSource(StatefulIngestionSourceBase, TestableSource):
              dataset_urn=dataset_urn,
          )

+     def add_model_to_schema_container(
+         self, model_urn: str, schema: Schema
+     ) -> Iterable[MetadataWorkUnit]:
+         schema_container_key = self.gen_schema_key(schema)
+         yield from add_entity_to_container(
+             container_key=schema_container_key,
+             entity_type=MlModelGroupUrn.ENTITY_TYPE,
+             entity_urn=model_urn,
+         )
+
+     def add_model_version_to_schema_container(
+         self, model_version_urn: str, schema: Schema
+     ) -> Iterable[MetadataWorkUnit]:
+         schema_container_key = self.gen_schema_key(schema)
+         yield from add_entity_to_container(
+             container_key=schema_container_key,
+             entity_type=MlModelUrn.ENTITY_TYPE,
+             entity_urn=model_version_urn,
+         )
+
      def _get_catalog_tags(
          self, catalog: str, schema: str, table: str
      ) -> List[UnityCatalogTag]:
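One detail worth noting in `process_ml_model_version`: Unity Catalog hands back a creation datetime, but the emitted `TimeStampClass` carries epoch milliseconds, hence the `int(created_at.timestamp() * 1000)` conversion shown above. A standalone check of that arithmetic with a hypothetical timestamp:

from datetime import datetime, timezone

# Hypothetical creation time; the conversion mirrors the one in the hunk above.
created_at = datetime(2024, 5, 1, 12, 30, tzinfo=timezone.utc)
created_time = int(created_at.timestamp() * 1000)
print(created_time)  # 1714566600000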