acryl-datahub 1.0.0rc13__py3-none-any.whl → 1.0.0rc14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of acryl-datahub might be problematic.

Files changed (43)
  1. {acryl_datahub-1.0.0rc13.dist-info → acryl_datahub-1.0.0rc14.dist-info}/METADATA +2524 -2524
  2. {acryl_datahub-1.0.0rc13.dist-info → acryl_datahub-1.0.0rc14.dist-info}/RECORD +43 -43
  3. datahub/_version.py +1 -1
  4. datahub/configuration/common.py +1 -1
  5. datahub/emitter/rest_emitter.py +165 -10
  6. datahub/ingestion/glossary/classification_mixin.py +1 -5
  7. datahub/ingestion/graph/client.py +6 -3
  8. datahub/ingestion/reporting/datahub_ingestion_run_summary_provider.py +1 -1
  9. datahub/ingestion/run/pipeline.py +2 -4
  10. datahub/ingestion/sink/datahub_rest.py +4 -0
  11. datahub/ingestion/source/common/subtypes.py +5 -0
  12. datahub/ingestion/source/data_lake_common/path_spec.py +1 -3
  13. datahub/ingestion/source/dbt/dbt_common.py +2 -4
  14. datahub/ingestion/source/dbt/dbt_tests.py +4 -8
  15. datahub/ingestion/source/dremio/dremio_api.py +1 -5
  16. datahub/ingestion/source/dremio/dremio_aspects.py +1 -4
  17. datahub/ingestion/source/dynamodb/dynamodb.py +1 -0
  18. datahub/ingestion/source/kafka_connect/common.py +1 -6
  19. datahub/ingestion/source/mlflow.py +338 -31
  20. datahub/ingestion/source/redshift/lineage.py +2 -2
  21. datahub/ingestion/source/redshift/lineage_v2.py +19 -7
  22. datahub/ingestion/source/redshift/profile.py +1 -1
  23. datahub/ingestion/source/redshift/query.py +14 -6
  24. datahub/ingestion/source/redshift/redshift.py +9 -5
  25. datahub/ingestion/source/redshift/redshift_schema.py +27 -7
  26. datahub/ingestion/source/sql/athena.py +6 -12
  27. datahub/ingestion/source/sql/hive.py +2 -6
  28. datahub/ingestion/source/sql/hive_metastore.py +2 -1
  29. datahub/ingestion/source/sql/sql_common.py +3 -9
  30. datahub/ingestion/source/state/stale_entity_removal_handler.py +4 -8
  31. datahub/ingestion/source/superset.py +1 -3
  32. datahub/ingestion/source/tableau/tableau_common.py +1 -1
  33. datahub/lite/duckdb_lite.py +1 -3
  34. datahub/metadata/_schema_classes.py +31 -1
  35. datahub/metadata/schema.avsc +56 -4
  36. datahub/metadata/schemas/DataProcessInstanceInput.avsc +129 -1
  37. datahub/metadata/schemas/DataProcessInstanceOutput.avsc +131 -3
  38. datahub/sdk/dataset.py +2 -2
  39. datahub/sql_parsing/sqlglot_utils.py +1 -4
  40. {acryl_datahub-1.0.0rc13.dist-info → acryl_datahub-1.0.0rc14.dist-info}/LICENSE +0 -0
  41. {acryl_datahub-1.0.0rc13.dist-info → acryl_datahub-1.0.0rc14.dist-info}/WHEEL +0 -0
  42. {acryl_datahub-1.0.0rc13.dist-info → acryl_datahub-1.0.0rc14.dist-info}/entry_points.txt +0 -0
  43. {acryl_datahub-1.0.0rc13.dist-info → acryl_datahub-1.0.0rc14.dist-info}/top_level.txt +0 -0
datahub/ingestion/sink/datahub_rest.py
@@ -20,7 +20,9 @@ from datahub.emitter.mcp import MetadataChangeProposalWrapper
 from datahub.emitter.mcp_builder import mcps_from_mce
 from datahub.emitter.rest_emitter import (
     BATCH_INGEST_MAX_PAYLOAD_LENGTH,
+    DEFAULT_REST_SINK_ENDPOINT,
     DataHubRestEmitter,
+    RestSinkEndpoint,
 )
 from datahub.ingestion.api.common import RecordEnvelope, WorkUnit
 from datahub.ingestion.api.sink import (
@@ -66,6 +68,7 @@ _DEFAULT_REST_SINK_MODE = pydantic.parse_obj_as(

 class DatahubRestSinkConfig(DatahubClientConfig):
     mode: RestSinkMode = _DEFAULT_REST_SINK_MODE
+    endpoint: RestSinkEndpoint = DEFAULT_REST_SINK_ENDPOINT

     # These only apply in async modes.
     max_threads: pydantic.PositiveInt = _DEFAULT_REST_SINK_MAX_THREADS
@@ -172,6 +175,7 @@ class DatahubRestSink(Sink[DatahubRestSinkConfig, DataHubRestSinkReport]):
             ca_certificate_path=config.ca_certificate_path,
             client_certificate_path=config.client_certificate_path,
             disable_ssl_verification=config.disable_ssl_verification,
+            openapi_ingestion=config.endpoint == RestSinkEndpoint.OPENAPI,
         )

     @property
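
The net effect: the REST sink grows an `endpoint` option next to `mode`, and selecting `RestSinkEndpoint.OPENAPI` flips the emitter into OpenAPI ingestion. A minimal sketch of exercising the new field; the string form "openapi" is an assumption about how the enum deserializes and is not shown in this diff:

    # Hedged sketch: "openapi" as the wire value for RestSinkEndpoint.OPENAPI is assumed.
    from datahub.ingestion.sink.datahub_rest import DatahubRestSinkConfig

    config = DatahubRestSinkConfig.parse_obj(
        {
            "server": "http://localhost:8080",
            "endpoint": "openapi",
        }
    )
    # Per the hunk above, the sink then builds its emitter with
    # openapi_ingestion=(config.endpoint == RestSinkEndpoint.OPENAPI).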
datahub/ingestion/source/common/subtypes.py
@@ -92,3 +92,8 @@ class BIAssetSubTypes(StrEnum):
     # SAP Analytics Cloud
     SAC_STORY = "Story"
     SAC_APPLICATION = "Application"
+
+
+class MLAssetSubTypes(StrEnum):
+    MLFLOW_TRAINING_RUN = "ML Training Run"
+    MLFLOW_EXPERIMENT = "ML Experiment"
datahub/ingestion/source/data_lake_common/path_spec.py
@@ -454,10 +454,8 @@ class PathSpec(ConfigModel):
                 return None
             partition = partition_split[0]
             # If partition is in the form of /value1/value2/value3 we infer it from the path and assign partition_0, partition_1, partition_2 etc
-            num = 0
-            for partition_value in partition.split("/"):
+            for num, partition_value in enumerate(partition.split("/")):
                 partition_keys.append((f"partition_{num}", partition_value))
-                num += 1
             return partition_keys

         return None
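
The `enumerate()` form is behavior-preserving: for a path-inferred partition it still yields `partition_0`, `partition_1`, and so on. A standalone sketch with a hypothetical path:

    # Standalone sketch of the inference above; the sample path is hypothetical.
    partition = "2024/01/15"
    partition_keys = [
        (f"partition_{num}", value)
        for num, value in enumerate(partition.split("/"))
    ]
    assert partition_keys == [
        ("partition_0", "2024"),
        ("partition_1", "01"),
        ("partition_2", "15"),
    ]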
datahub/ingestion/source/dbt/dbt_common.py
@@ -1774,10 +1774,8 @@ class DBTSourceBase(StatefulIngestionSourceBase):
        logger.debug(
            f"Owner after applying owner extraction pattern:'{self.config.owner_extraction_pattern}' is '{owner}'."
        )
-       if isinstance(owner, list):
-           owners = owner
-       else:
-           owners = [owner]
+       owners = owner if isinstance(owner, list) else [owner]
+
        for owner in owners:
            if self.config.strip_user_ids_from_email:
                owner = owner.split("@")[0]
datahub/ingestion/source/dbt/dbt_tests.py
@@ -57,15 +57,11 @@ def _get_name_for_relationship_test(kw_args: Dict[str, str]) -> Optional[str]:
         # base assertions are violated, bail early
         return None
     m = re.match(r"^ref\(\'(.*)\'\)$", destination_ref)
-    if m:
-        destination_table = m.group(1)
-    else:
-        destination_table = destination_ref
+    destination_table = m.group(1) if m else destination_ref
+
     m = re.search(r"ref\(\'(.*)\'\)", source_ref)
-    if m:
-        source_table = m.group(1)
-    else:
-        source_table = source_ref
+    source_table = m.group(1) if m else source_ref
+
     return f"{source_table}.{column_name} referential integrity to {destination_table}.{dest_field_name}"


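Same logic, condensed to conditional expressions. With hypothetical inputs, the generated test name looks like this:

    import re

    # Hypothetical kw_args for a dbt relationships test.
    destination_ref, source_ref = "ref('orders')", "ref('customers')"
    column_name, dest_field_name = "customer_id", "id"

    m = re.match(r"^ref\(\'(.*)\'\)$", destination_ref)
    destination_table = m.group(1) if m else destination_ref
    m = re.search(r"ref\(\'(.*)\'\)", source_ref)
    source_table = m.group(1) if m else source_ref

    name = f"{source_table}.{column_name} referential integrity to {destination_table}.{dest_field_name}"
    assert name == "customers.customer_id referential integrity to orders.id"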
datahub/ingestion/source/dremio/dremio_api.py
@@ -683,11 +683,7 @@ class DremioAPIOperations:
         # Add end anchor for exact matching
         regex_pattern = regex_pattern + "$"

-        for path in paths:
-            if re.match(regex_pattern, path, re.IGNORECASE):
-                return True
-
-        return False
+        return any(re.match(regex_pattern, path, re.IGNORECASE) for path in paths)

     def should_include_container(self, path: List[str], name: str) -> bool:
         """
datahub/ingestion/source/dremio/dremio_aspects.py
@@ -116,10 +116,7 @@ class SchemaFieldTypeMapper:
         data_type = data_type.lower()
         type_class = cls.FIELD_TYPE_MAPPING.get(data_type, NullTypeClass)

-        if data_size:
-            native_data_type = f"{data_type}({data_size})"
-        else:
-            native_data_type = data_type
+        native_data_type = f"{data_type}({data_size})" if data_size else data_type

         try:
             schema_field_type = SchemaFieldDataTypeClass(type=type_class())
datahub/ingestion/source/dynamodb/dynamodb.py
@@ -246,6 +246,7 @@ class DynamoDBSource(StatefulIngestionSourceBase):
             platform=self.platform,
             platform_instance=platform_instance,
             name=dataset_name,
+            env=self.config.env,
         )
         dataset_properties = DatasetPropertiesClass(
             name=table_name,
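
Previously the DynamoDB dataset URN was always minted with the default env; threading `self.config.env` through means a configured env now shows up in the URN. A sketch with the standard builder (the dataset name is a hypothetical example):

    import datahub.emitter.mce_builder as builder

    # Hypothetical dataset name; shows where env lands in the URN.
    urn = builder.make_dataset_urn_with_platform_instance(
        platform="dynamodb",
        name="us-west-2.orders",
        platform_instance=None,
        env="DEV",
    )
    # urn:li:dataset:(urn:li:dataPlatform:dynamodb,us-west-2.orders,DEV)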
datahub/ingestion/source/kafka_connect/common.py
@@ -141,12 +141,7 @@ def get_dataset_name(
     database_name: Optional[str],
     source_table: str,
 ) -> str:
-    if database_name:
-        dataset_name = database_name + "." + source_table
-    else:
-        dataset_name = source_table
-
-    return dataset_name
+    return database_name + "." + source_table if database_name else source_table


 def get_platform_instance(
datahub/ingestion/source/mlflow.py
@@ -1,17 +1,20 @@
+import time
 from dataclasses import dataclass
 from typing import Any, Callable, Iterable, List, Optional, TypeVar, Union

 from mlflow import MlflowClient
-from mlflow.entities import Run
+from mlflow.entities import Experiment, Run
 from mlflow.entities.model_registry import ModelVersion, RegisteredModel
 from mlflow.store.entities import PagedList
 from pydantic.fields import Field

 import datahub.emitter.mce_builder as builder
-from datahub.configuration.source_common import (
-    EnvConfigMixin,
+from datahub.api.entities.dataprocess.dataprocess_instance import (
+    DataProcessInstance,
 )
+from datahub.configuration.source_common import EnvConfigMixin
 from datahub.emitter.mcp import MetadataChangeProposalWrapper
+from datahub.emitter.mcp_builder import ContainerKey
 from datahub.ingestion.api.common import PipelineContext
 from datahub.ingestion.api.decorators import (
     SupportStatus,
@@ -26,6 +29,7 @@ from datahub.ingestion.api.source import (
     SourceReport,
 )
 from datahub.ingestion.api.workunit import MetadataWorkUnit
+from datahub.ingestion.source.common.subtypes import MLAssetSubTypes
 from datahub.ingestion.source.state.stale_entity_removal_handler import (
     StaleEntityRemovalHandler,
     StaleEntityRemovalSourceReport,
@@ -35,20 +39,45 @@ from datahub.ingestion.source.state.stateful_ingestion_base import (
     StatefulIngestionSourceBase,
 )
 from datahub.metadata.schema_classes import (
+    AuditStampClass,
+    ContainerClass,
+    DataPlatformInstanceClass,
+    DataProcessInstanceOutputClass,
+    DataProcessInstancePropertiesClass,
+    DataProcessInstanceRunEventClass,
+    DataProcessInstanceRunResultClass,
+    DataProcessRunStatusClass,
+    EdgeClass,
     GlobalTagsClass,
+    MetadataAttributionClass,
     MLHyperParamClass,
     MLMetricClass,
     MLModelGroupPropertiesClass,
     MLModelPropertiesClass,
+    MLTrainingRunPropertiesClass,
+    PlatformResourceInfoClass,
+    SubTypesClass,
     TagAssociationClass,
     TagPropertiesClass,
+    TimeStampClass,
+    VersionPropertiesClass,
     VersionTagClass,
     _Aspect,
 )
+from datahub.metadata.urns import (
+    DataPlatformUrn,
+    MlModelUrn,
+    VersionSetUrn,
+)
+from datahub.sdk.container import Container

 T = TypeVar("T")


+class ContainerKeyWithId(ContainerKey):
+    id: str
+
+
 class MLflowConfig(StatefulIngestionConfigBase, EnvConfigMixin):
     tracking_uri: Optional[str] = Field(
         default=None,
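
`ContainerKeyWithId` is a small `ContainerKey` subclass whose guid-derived URN the source uses for experiment containers. A sketch of how it is exercised further down in this diff (the experiment name is hypothetical):

    # Hypothetical experiment name; ContainerKey derives a guid-based container urn
    # from its fields, and as_urn() is the accessor used later in this diff.
    key = ContainerKeyWithId(
        platform=str(DataPlatformUrn(platform_name="mlflow")),
        id="my_experiment",
    )
    container_urn = key.as_urn()  # urn:li:container:<guid of the key's fields>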
@@ -141,6 +170,7 @@ class MLflowSource(StatefulIngestionSourceBase):

     def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
         yield from self._get_tags_workunits()
+        yield from self._get_experiment_workunits()
         yield from self._get_ml_model_workunits()

     def _get_tags_workunits(self) -> Iterable[MetadataWorkUnit]:
@@ -174,22 +204,162 @@ class MLflowSource(StatefulIngestionSourceBase):
             aspect=aspect,
         ).as_workunit()

-    def _get_ml_model_workunits(self) -> Iterable[MetadataWorkUnit]:
-        """
-        Traverse each Registered Model in Model Registry and generate a corresponding workunit.
-        """
-        registered_models = self._get_mlflow_registered_models()
-        for registered_model in registered_models:
-            yield self._get_ml_group_workunit(registered_model)
-            model_versions = self._get_mlflow_model_versions(registered_model)
-            for model_version in model_versions:
-                run = self._get_mlflow_run(model_version)
-                yield self._get_ml_model_properties_workunit(
-                    registered_model=registered_model,
-                    model_version=model_version,
-                    run=run,
-                )
-                yield self._get_global_tags_workunit(model_version=model_version)
+    def _get_experiment_workunits(self) -> Iterable[MetadataWorkUnit]:
+        experiments = self._get_mlflow_experiments()
+        for experiment in experiments:
+            yield from self._get_experiment_container_workunit(experiment)
+
+            runs = self._get_mlflow_runs_from_experiment(experiment)
+            if runs:
+                for run in runs:
+                    yield from self._get_run_workunits(experiment, run)
+
+    def _get_experiment_custom_properties(self, experiment):
+        experiment_custom_props = getattr(experiment, "tags", {}) or {}
+        experiment_custom_props.pop("mlflow.note.content", None)
+        experiment_custom_props["artifacts_location"] = experiment.artifact_location
+        return experiment_custom_props
+
+    def _get_experiment_container_workunit(
+        self, experiment: Experiment
+    ) -> Iterable[MetadataWorkUnit]:
+        experiment_container = Container(
+            container_key=ContainerKeyWithId(
+                platform=str(DataPlatformUrn(platform_name=self.platform)),
+                id=experiment.name,
+            ),
+            subtype=MLAssetSubTypes.MLFLOW_EXPERIMENT,
+            display_name=experiment.name,
+            description=experiment.tags.get("mlflow.note.content"),
+            extra_properties=self._get_experiment_custom_properties(experiment),
+        )
+
+        yield from experiment_container.as_workunits()
+
+    def _get_run_metrics(self, run: Run) -> List[MLMetricClass]:
+        return [
+            MLMetricClass(name=k, value=str(v)) for k, v in run.data.metrics.items()
+        ]
+
+    def _get_run_params(self, run: Run) -> List[MLHyperParamClass]:
+        return [
+            MLHyperParamClass(name=k, value=str(v)) for k, v in run.data.params.items()
+        ]
+
+    def _convert_run_result_type(
+        self, status: str
+    ) -> DataProcessInstanceRunResultClass:
+        if status == "FINISHED":
+            return DataProcessInstanceRunResultClass(
+                type="SUCCESS", nativeResultType=self.platform
+            )
+        elif status == "FAILED":
+            return DataProcessInstanceRunResultClass(
+                type="FAILURE", nativeResultType=self.platform
+            )
+        else:
+            return DataProcessInstanceRunResultClass(
+                type="SKIPPED", nativeResultType=self.platform
+            )
+
+    def _get_run_workunits(
+        self, experiment: Experiment, run: Run
+    ) -> Iterable[MetadataWorkUnit]:
+        experiment_key = ContainerKeyWithId(
+            platform=str(DataPlatformUrn(self.platform)), id=experiment.name
+        )
+
+        data_process_instance = DataProcessInstance(
+            id=run.info.run_id,
+            orchestrator=self.platform,
+            template_urn=None,
+        )
+
+        created_time = run.info.start_time or int(time.time() * 1000)
+        user_id = run.info.user_id if run.info.user_id else "mlflow"
+        guid_dict_user = {"platform": self.platform, "user": user_id}
+        platform_user_urn = (
+            f"urn:li:platformResource:{builder.datahub_guid(guid_dict_user)}"
+        )
+
+        yield MetadataChangeProposalWrapper(
+            entityUrn=platform_user_urn,
+            aspect=PlatformResourceInfoClass(
+                resourceType="user",
+                primaryKey=user_id,
+            ),
+        ).as_workunit()
+
+        yield MetadataChangeProposalWrapper(
+            entityUrn=str(data_process_instance.urn),
+            aspect=DataProcessInstancePropertiesClass(
+                name=run.info.run_name or run.info.run_id,
+                created=AuditStampClass(
+                    time=created_time,
+                    actor=platform_user_urn,
+                ),
+                externalUrl=self._make_external_url_from_run(experiment, run),
+                customProperties=getattr(run, "tags", {}) or {},
+            ),
+        ).as_workunit()
+
+        yield MetadataChangeProposalWrapper(
+            entityUrn=str(data_process_instance.urn),
+            aspect=ContainerClass(container=experiment_key.as_urn()),
+        ).as_workunit()
+
+        model_versions = self.get_mlflow_model_versions_from_run(run.info.run_id)
+        if model_versions:
+            model_version_urn = self._make_ml_model_urn(model_versions[0])
+            yield MetadataChangeProposalWrapper(
+                entityUrn=str(data_process_instance.urn),
+                aspect=DataProcessInstanceOutputClass(
+                    outputs=[],
+                    outputEdges=[
+                        EdgeClass(destinationUrn=model_version_urn),
+                    ],
+                ),
+            ).as_workunit()
+
+        metrics = self._get_run_metrics(run)
+        hyperparams = self._get_run_params(run)
+        yield MetadataChangeProposalWrapper(
+            entityUrn=str(data_process_instance.urn),
+            aspect=MLTrainingRunPropertiesClass(
+                hyperParams=hyperparams,
+                trainingMetrics=metrics,
+                outputUrls=[run.info.artifact_uri],
+                id=run.info.run_id,
+            ),
+        ).as_workunit()
+
+        if run.info.end_time:
+            duration_millis = run.info.end_time - run.info.start_time
+
+            yield MetadataChangeProposalWrapper(
+                entityUrn=str(data_process_instance.urn),
+                aspect=DataProcessInstanceRunEventClass(
+                    status=DataProcessRunStatusClass.COMPLETE,
+                    timestampMillis=run.info.end_time,
+                    result=DataProcessInstanceRunResultClass(
+                        type=self._convert_run_result_type(run.info.status).type,
+                        nativeResultType=self.platform,
+                    ),
+                    durationMillis=duration_millis,
+                ),
+            ).as_workunit()
+
+        yield MetadataChangeProposalWrapper(
+            entityUrn=str(data_process_instance.urn),
+            aspect=DataPlatformInstanceClass(
+                platform=str(DataPlatformUrn(self.platform))
+            ),
+        ).as_workunit()
+
+        yield MetadataChangeProposalWrapper(
+            entityUrn=str(data_process_instance.urn),
+            aspect=SubTypesClass(typeNames=[MLAssetSubTypes.MLFLOW_TRAINING_RUN]),
+        ).as_workunit()

     def _get_mlflow_registered_models(self) -> Iterable[RegisteredModel]:
         """
@@ -202,6 +372,19 @@ class MLflowSource(StatefulIngestionSourceBase):
         )
         return registered_models

+    def _get_mlflow_experiments(self) -> Iterable[Experiment]:
+        experiments: Iterable[Experiment] = self._traverse_mlflow_search_func(
+            search_func=self.client.search_experiments,
+        )
+        return experiments
+
+    def _get_mlflow_runs_from_experiment(self, experiment: Experiment) -> Iterable[Run]:
+        runs: Iterable[Run] = self._traverse_mlflow_search_func(
+            search_func=self.client.search_runs,
+            experiment_ids=[experiment.experiment_id],
+        )
+        return runs
+
     @staticmethod
     def _traverse_mlflow_search_func(
         search_func: Callable[..., PagedList[T]],
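
Both new lookups reuse the existing `_traverse_mlflow_search_func` pager, whose body is unchanged and mostly outside this diff. A generic sketch of that paging pattern, under the assumption that it simply follows `PagedList.token` until exhaustion:

    from typing import Callable, Iterator, TypeVar

    from mlflow.store.entities import PagedList

    T = TypeVar("T")

    def traverse(search_func: Callable[..., PagedList[T]], **kwargs) -> Iterator[T]:
        # Follow page tokens until the server stops returning one.
        next_page_token = None
        while True:
            paged_list = search_func(page_token=next_page_token, **kwargs)
            yield from paged_list.to_list()
            next_page_token = paged_list.token
            if not next_page_token:
                return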
@@ -218,6 +401,13 @@ class MLflowSource(StatefulIngestionSourceBase):
             if not next_page_token:
                 return

+    def _get_latest_version(self, registered_model: RegisteredModel) -> Optional[str]:
+        return (
+            str(registered_model.latest_versions[0].version)
+            if registered_model.latest_versions
+            else None
+        )
+
     def _get_ml_group_workunit(
         self,
         registered_model: RegisteredModel,
@@ -229,7 +419,20 @@ class MLflowSource(StatefulIngestionSourceBase):
         ml_model_group_properties = MLModelGroupPropertiesClass(
             customProperties=registered_model.tags,
             description=registered_model.description,
-            createdAt=registered_model.creation_timestamp,
+            created=TimeStampClass(
+                time=registered_model.creation_timestamp, actor=None
+            ),
+            lastModified=TimeStampClass(
+                time=registered_model.last_updated_timestamp,
+                actor=None,
+            ),
+            version=VersionTagClass(
+                versionTag=self._get_latest_version(registered_model),
+                metadataAttribution=MetadataAttributionClass(
+                    time=registered_model.last_updated_timestamp,
+                    actor="urn:li:corpuser:datahub",
+                ),
+            ),
         )
         wu = self._create_workunit(
             urn=ml_model_group_urn,
@@ -259,6 +462,16 @@ class MLflowSource(StatefulIngestionSourceBase):
         )
         return model_versions

+    def get_mlflow_model_versions_from_run(self, run_id):
+        filter_string = f"run_id = '{run_id}'"
+
+        model_versions: Iterable[ModelVersion] = self._traverse_mlflow_search_func(
+            search_func=self.client.search_model_versions,
+            filter_string=filter_string,
+        )
+
+        return list(model_versions)
+
     def _get_mlflow_run(self, model_version: ModelVersion) -> Union[None, Run]:
         """
         Get a Run associated with a Model Version. Some MVs may exist without Run.
@@ -269,6 +482,67 @@ class MLflowSource(StatefulIngestionSourceBase):
         else:
             return None

+    def _get_ml_model_workunits(self) -> Iterable[MetadataWorkUnit]:
+        """
+        Traverse each Registered Model in Model Registry and generate a corresponding workunit.
+        """
+        registered_models = self._get_mlflow_registered_models()
+        for registered_model in registered_models:
+            version_set_urn = self._get_version_set_urn(registered_model)
+            yield self._get_ml_group_workunit(registered_model)
+            model_versions = self._get_mlflow_model_versions(registered_model)
+            for model_version in model_versions:
+                run = self._get_mlflow_run(model_version)
+                yield self._get_ml_model_properties_workunit(
+                    registered_model=registered_model,
+                    model_version=model_version,
+                    run=run,
+                )
+                yield self._get_ml_model_version_properties_workunit(
+                    model_version=model_version,
+                    version_set_urn=version_set_urn,
+                )
+                yield self._get_global_tags_workunit(model_version=model_version)
+
+    def _get_version_set_urn(self, registered_model: RegisteredModel) -> VersionSetUrn:
+        guid_dict = {"platform": self.platform, "name": registered_model.name}
+        version_set_urn = VersionSetUrn(
+            id=builder.datahub_guid(guid_dict),
+            entity_type=MlModelUrn.ENTITY_TYPE,
+        )
+
+        return version_set_urn
+
+    def _get_ml_model_version_properties_workunit(
+        self,
+        model_version: ModelVersion,
+        version_set_urn: VersionSetUrn,
+    ) -> MetadataWorkUnit:
+        ml_model_urn = self._make_ml_model_urn(model_version)
+
+        # get mlmodel name from ml model urn
+        ml_model_version_properties = VersionPropertiesClass(
+            version=VersionTagClass(
+                versionTag=str(model_version.version),
+                metadataAttribution=MetadataAttributionClass(
+                    time=model_version.creation_timestamp,
+                    actor="urn:li:corpuser:datahub",
+                ),
+            ),
+            versionSet=str(version_set_urn),
+            sortId=str(model_version.version).zfill(10),
+            aliases=[
+                VersionTagClass(versionTag=alias) for alias in model_version.aliases
+            ],
+        )
+
+        wu = MetadataChangeProposalWrapper(
+            entityUrn=str(ml_model_urn),
+            aspect=ml_model_version_properties,
+        ).as_workunit()
+
+        return wu
+
     def _get_ml_model_properties_workunit(
         self,
         registered_model: RegisteredModel,
@@ -282,28 +556,47 @@ class MLflowSource(StatefulIngestionSourceBase):
         """
         ml_model_group_urn = self._make_ml_model_group_urn(registered_model)
         ml_model_urn = self._make_ml_model_urn(model_version)
+
         if run:
-            hyperparams = [
-                MLHyperParamClass(name=k, value=str(v))
-                for k, v in run.data.params.items()
-            ]
-            training_metrics = [
-                MLMetricClass(name=k, value=str(v)) for k, v in run.data.metrics.items()
-            ]
+            # Use the same metrics and hyperparams from the run
+            hyperparams = self._get_run_params(run)
+            training_metrics = self._get_run_metrics(run)
+            run_urn = DataProcessInstance(
+                id=run.info.run_id,
+                orchestrator=self.platform,
+            ).urn
+
+            training_jobs = [str(run_urn)] if run_urn else []
         else:
             hyperparams = None
             training_metrics = None
+            training_jobs = []
+
+        created_time = model_version.creation_timestamp
+        created_actor = (
+            f"urn:li:platformResource:{model_version.user_id}"
+            if model_version.user_id
+            else None
+        )
+        model_version_tags = [f"{k}:{v}" for k, v in model_version.tags.items()]
+
         ml_model_properties = MLModelPropertiesClass(
             customProperties=model_version.tags,
             externalUrl=self._make_external_url(model_version),
+            lastModified=TimeStampClass(
+                time=model_version.last_updated_timestamp,
+                actor=None,
+            ),
             description=model_version.description,
-            date=model_version.creation_timestamp,
-            version=VersionTagClass(versionTag=str(model_version.version)),
+            created=TimeStampClass(
+                time=created_time,
+                actor=created_actor,
+            ),
             hyperParams=hyperparams,
             trainingMetrics=training_metrics,
-            # mlflow tags are dicts, but datahub tags are lists. currently use only keys from mlflow tags
-            tags=list(model_version.tags.keys()),
+            tags=model_version_tags,
             groups=[ml_model_group_urn],
+            trainingJobs=training_jobs,
         )
         wu = self._create_workunit(urn=ml_model_urn, aspect=ml_model_properties)
         return wu
@@ -337,6 +630,15 @@ class MLflowSource(StatefulIngestionSourceBase):
         else:
             return None

+    def _make_external_url_from_run(
+        self, experiment: Experiment, run: Run
+    ) -> Union[None, str]:
+        base_uri = self.client.tracking_uri
+        if base_uri.startswith("http"):
+            return f"{base_uri.rstrip('/')}/#/experiments/{experiment.experiment_id}/runs/{run.info.run_id}"
+        else:
+            return None
+
     def _get_global_tags_workunit(
         self,
         model_version: ModelVersion,
@@ -356,3 +658,8 @@ class MLflowSource(StatefulIngestionSourceBase):
             aspect=global_tags,
         )
         return wu
+
+    @classmethod
+    def create(cls, config_dict: dict, ctx: PipelineContext) -> "MLflowSource":
+        config = MLflowConfig.parse_obj(config_dict)
+        return cls(ctx, config)
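
The new `create` classmethod gives `MLflowSource` the standard factory hook used when sources are built from a config dict. A minimal sketch (the tracking URI is a placeholder):

    from datahub.ingestion.api.common import PipelineContext
    from datahub.ingestion.source.mlflow import MLflowSource

    ctx = PipelineContext(run_id="mlflow-demo")  # placeholder run id
    source = MLflowSource.create({"tracking_uri": "http://localhost:5000"}, ctx)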
datahub/ingestion/source/redshift/lineage.py
@@ -814,8 +814,8 @@ class RedshiftLineageExtractor:

        tablename = table.name
        if (
-           table.is_external_table
-           and schema.is_external_schema
+           table.is_external_table()
+           and schema.is_external_schema()
            and schema.external_platform
        ):
            # external_db_params = schema.option
datahub/ingestion/source/redshift/lineage_v2.py
@@ -403,8 +403,8 @@ class RedshiftSqlLineageV2(Closeable):
        for table in tables:
            schema = db_schemas[self.database][schema_name]
            if (
-               table.is_external_table
-               and schema.is_external_schema
+               table.is_external_table()
+               and schema.is_external_schema()
                and schema.external_platform
            ):
                # external_db_params = schema.option
@@ -416,14 +416,26 @@ class RedshiftSqlLineageV2(Closeable):
                    platform_instance=self.config.platform_instance,
                    env=self.config.env,
                )
-               upstream_urn = mce_builder.make_dataset_urn_with_platform_instance(
-                   upstream_platform,
-                   f"{schema.external_database}.{table.name}",
-                   platform_instance=(
+               if upstream_platform == self.platform:
+                   upstream_schema = schema.get_upstream_schema_name() or "public"
+                   upstream_dataset_name = (
+                       f"{schema.external_database}.{upstream_schema}.{table.name}"
+                   )
+                   upstream_platform_instance = self.config.platform_instance
+               else:
+                   upstream_dataset_name = (
+                       f"{schema.external_database}.{table.name}"
+                   )
+                   upstream_platform_instance = (
                        self.config.platform_instance_map.get(upstream_platform)
                        if self.config.platform_instance_map
                        else None
-                   ),
+                   )
+
+               upstream_urn = mce_builder.make_dataset_urn_with_platform_instance(
+                   upstream_platform,
+                   upstream_dataset_name,
+                   platform_instance=upstream_platform_instance,
                    env=self.config.env,
                )
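
The practical consequence: when the external schema points back at Redshift itself (a datashare), the upstream URN now carries the three-part `database.schema.table` name, with `public` as the fallback schema, so it lines up with how Redshift datasets are named elsewhere. A sketch with hypothetical names:

    import datahub.emitter.mce_builder as mce_builder

    # Hypothetical producer-side names for a datashare upstream.
    external_database, upstream_schema, table_name = "producer_db", "public", "orders"

    upstream_urn = mce_builder.make_dataset_urn_with_platform_instance(
        "redshift",
        f"{external_database}.{upstream_schema}.{table_name}",  # was f"{external_database}.{table_name}"
        platform_instance=None,
        env="PROD",
    )
    # urn:li:dataset:(urn:li:dataPlatform:redshift,producer_db.public.orders,PROD)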