acryl-datahub 1.0.0.3rc2__py3-none-any.whl → 1.0.0.3rc5__py3-none-any.whl

This diff shows the content changes between publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.

Potentially problematic release: this version of acryl-datahub might be problematic.

@@ -1,5 +1,6 @@
  import dataclasses
  import logging
+ from datetime import datetime, timedelta
  from typing import Dict, Iterable, List, Optional, Tuple, TypeVar, Union

  from google.api_core.exceptions import GoogleAPICallError
@@ -12,15 +13,22 @@ from google.cloud.aiplatform import (
      AutoMLVideoTrainingJob,
      Endpoint,
      ExperimentRun,
+     PipelineJob,
  )
  from google.cloud.aiplatform.base import VertexAiResourceNoun
  from google.cloud.aiplatform.metadata.execution import Execution
  from google.cloud.aiplatform.metadata.experiment_resources import Experiment
  from google.cloud.aiplatform.models import Model, VersionInfo
  from google.cloud.aiplatform.training_jobs import _TrainingJob
+ from google.cloud.aiplatform_v1.types import (
+     PipelineJob as PipelineJobType,
+     PipelineTaskDetail,
+ )
  from google.oauth2 import service_account
+ from google.protobuf import timestamp_pb2

  import datahub.emitter.mce_builder as builder
+ from datahub.api.entities.datajob import DataFlow, DataJob
  from datahub.emitter.mcp import MetadataChangeProposalWrapper
  from datahub.emitter.mcp_builder import (
      ExperimentKey,
@@ -43,6 +51,7 @@ from datahub.ingestion.source.vertexai.vertexai_config import VertexAIConfig
  from datahub.ingestion.source.vertexai.vertexai_result_type_utils import (
      get_execution_result_status,
      get_job_result_status,
+     get_pipeline_task_result_status,
      is_status_for_run_event_class,
  )
  from datahub.metadata.com.linkedin.pegasus2avro.dataprocess import (
@@ -76,7 +85,13 @@ from datahub.metadata.schema_classes import (
      VersionPropertiesClass,
      VersionTagClass,
  )
- from datahub.metadata.urns import DataPlatformUrn, MlModelUrn, VersionSetUrn
+ from datahub.metadata.urns import (
+     DataFlowUrn,
+     DataJobUrn,
+     DataPlatformUrn,
+     MlModelUrn,
+     VersionSetUrn,
+ )
  from datahub.utilities.time import datetime_to_ts_millis

  T = TypeVar("T")
@@ -100,6 +115,34 @@ class ModelMetadata:
      endpoints: Optional[List[Endpoint]] = None


+ @dataclasses.dataclass
+ class PipelineTaskMetadata:
+     name: str
+     urn: DataJobUrn
+     id: Optional[int] = None
+     type: Optional[str] = None
+     state: Optional[PipelineTaskDetail.State] = None
+     start_time: Optional[timestamp_pb2.Timestamp] = None
+     create_time: Optional[timestamp_pb2.Timestamp] = None
+     end_time: Optional[timestamp_pb2.Timestamp] = None
+     upstreams: Optional[List[DataJobUrn]] = None
+     duration: Optional[int] = None
+
+
+ @dataclasses.dataclass
+ class PipelineMetadata:
+     name: str
+     resource_name: str
+     tasks: List[PipelineTaskMetadata]
+     urn: DataFlowUrn
+     id: Optional[str] = None
+     labels: Optional[Dict[str, str]] = None
+     create_time: Optional[datetime] = None
+     update_time: Optional[datetime] = None
+     duration: Optional[timedelta] = None
+     region: Optional[str] = None
+
+
  @platform_name("Vertex AI", id="vertexai")
  @config_class(VertexAIConfig)
  @support_status(SupportStatus.TESTING)
@@ -150,6 +193,255 @@ class VertexAISource(Source):
          yield from self._get_experiments_workunits()
          # Fetch and Ingest Experiment Runs
          yield from auto_workunit(self._get_experiment_runs_mcps())
+         # Fetch Pipelines and Tasks
+         yield from auto_workunit(self._get_pipelines_mcps())
+
+     def _get_pipelines_mcps(self) -> Iterable[MetadataChangeProposalWrapper]:
+         """
+         Fetches pipelines from Vertex AI and generates corresponding mcps.
+         """
+
+         pipeline_jobs = self.client.PipelineJob.list()
+
+         for pipeline in pipeline_jobs:
+             logger.info(f"fetching pipeline ({pipeline.name})")
+             pipeline_meta = self._get_pipeline_metadata(pipeline)
+             yield from self._get_pipeline_mcps(pipeline_meta)
+             yield from self._gen_pipeline_task_mcps(pipeline_meta)
+
+     def _get_pipeline_tasks_metadata(
+         self, pipeline: PipelineJob, pipeline_urn: DataFlowUrn
+     ) -> List[PipelineTaskMetadata]:
+         tasks: List[PipelineTaskMetadata] = list()
+         task_map: Dict[str, PipelineTaskDetail] = dict()
+         for task in pipeline.task_details:
+             task_map[task.task_name] = task
+
+         resource = pipeline.gca_resource
+         if isinstance(resource, PipelineJobType):
+             for task_name in resource.pipeline_spec["root"]["dag"]["tasks"]:
+                 logger.debug(
+                     f"fetching pipeline task ({task_name}) in pipeline ({pipeline.name})"
+                 )
+                 task_urn = DataJobUrn.create_from_ids(
+                     data_flow_urn=str(pipeline_urn),
+                     job_id=self._make_vertexai_pipeline_task_id(task_name),
+                 )
+                 task_meta = PipelineTaskMetadata(name=task_name, urn=task_urn)
+                 if (
+                     "dependentTasks"
+                     in resource.pipeline_spec["root"]["dag"]["tasks"][task_name]
+                 ):
+                     upstream_tasks = resource.pipeline_spec["root"]["dag"]["tasks"][
+                         task_name
+                     ]["dependentTasks"]
+                     upstream_urls = [
+                         DataJobUrn.create_from_ids(
+                             data_flow_urn=str(pipeline_urn),
+                             job_id=self._make_vertexai_pipeline_task_id(upstream_task),
+                         )
+                         for upstream_task in upstream_tasks
+                     ]
+                     task_meta.upstreams = upstream_urls
+
+                 task_detail = task_map.get(task_name)
+                 if task_detail:
+                     task_meta.id = task_detail.task_id
+                     task_meta.state = task_detail.state
+                     task_meta.start_time = task_detail.start_time
+                     task_meta.create_time = task_detail.create_time
+                     if task_detail.end_time:
+                         task_meta.end_time = task_detail.end_time
+                         task_meta.duration = int(
+                             (
+                                 task_meta.end_time.timestamp()
+                                 - task_meta.start_time.timestamp()
+                             )
+                             * 1000
+                         )
+
+                 tasks.append(task_meta)
+         return tasks
+
+     def _get_pipeline_metadata(self, pipeline: PipelineJob) -> PipelineMetadata:
+         dataflow_urn = DataFlowUrn.create_from_ids(
+             orchestrator=self.platform,
+             env=self.config.env,
+             flow_id=self._make_vertexai_pipeline_id(pipeline.name),
+             platform_instance=self.platform,
+         )
+         tasks = self._get_pipeline_tasks_metadata(
+             pipeline=pipeline, pipeline_urn=dataflow_urn
+         )
+
+         pipeline_meta = PipelineMetadata(
+             name=pipeline.name,
+             resource_name=pipeline.resource_name,
+             urn=dataflow_urn,
+             tasks=tasks,
+         )
+         pipeline_meta.resource_name = pipeline.resource_name
+         pipeline_meta.labels = pipeline.labels
+         pipeline_meta.create_time = pipeline.create_time
+         pipeline_meta.region = pipeline.location
+         if pipeline.update_time:
+             pipeline_meta.update_time = pipeline.update_time
+             pipeline_meta.duration = timedelta(
+                 milliseconds=datetime_to_ts_millis(pipeline.update_time)
+                 - datetime_to_ts_millis(pipeline.create_time)
+             )
+         return pipeline_meta
+
+     def _gen_pipeline_task_run_mcps(
+         self, task: PipelineTaskMetadata, datajob: DataJob, pipeline: PipelineMetadata
+     ) -> (Iterable)[MetadataChangeProposalWrapper]:
+         dpi_urn = builder.make_data_process_instance_urn(
+             self._make_vertexai_pipeline_task_run_id(entity_id=task.name)
+         )
+         result_status: Union[str, RunResultTypeClass] = get_pipeline_task_result_status(
+             task.state
+         )
+
+         yield from MetadataChangeProposalWrapper.construct_many(
+             dpi_urn,
+             aspects=[
+                 DataProcessInstancePropertiesClass(
+                     name=task.name,
+                     created=AuditStampClass(
+                         time=(
+                             int(task.create_time.timestamp() * 1000)
+                             if task.create_time
+                             else 0
+                         ),
+                         actor="urn:li:corpuser:datahub",
+                     ),
+                     externalUrl=self._make_pipeline_external_url(pipeline.name),
+                     customProperties={},
+                 ),
+                 SubTypesClass(typeNames=[MLAssetSubTypes.VERTEX_PIPELINE_TASK_RUN]),
+                 ContainerClass(container=self._get_project_container().as_urn()),
+                 DataPlatformInstanceClass(platform=str(DataPlatformUrn(self.platform))),
+                 DataProcessInstanceRelationships(
+                     upstreamInstances=[], parentTemplate=str(datajob.urn)
+                 ),
+                 (
+                     DataProcessInstanceRunEventClass(
+                         status=DataProcessRunStatusClass.COMPLETE,
+                         timestampMillis=(
+                             int(task.create_time.timestamp() * 1000)
+                             if task.create_time
+                             else 0
+                         ),
+                         result=DataProcessInstanceRunResultClass(
+                             type=result_status,
+                             nativeResultType=self.platform,
+                         ),
+                         durationMillis=task.duration,
+                     )
+                     if is_status_for_run_event_class(result_status) and task.duration
+                     else None
+                 ),
+             ],
+         )
+
+     def _gen_pipeline_task_mcps(
+         self, pipeline: PipelineMetadata
+     ) -> Iterable[MetadataChangeProposalWrapper]:
+         dataflow_urn = pipeline.urn
+
+         for task in pipeline.tasks:
+             datajob = DataJob(
+                 id=self._make_vertexai_pipeline_task_id(task.name),
+                 flow_urn=dataflow_urn,
+                 name=task.name,
+                 properties={},
+                 owners={"urn:li:corpuser:datahub"},
+                 upstream_urns=task.upstreams if task.upstreams else [],
+                 url=self._make_pipeline_external_url(pipeline.name),
+             )
+             yield from MetadataChangeProposalWrapper.construct_many(
+                 entityUrn=str(datajob.urn),
+                 aspects=[
+                     ContainerClass(container=self._get_project_container().as_urn()),
+                     SubTypesClass(typeNames=[MLAssetSubTypes.VERTEX_PIPELINE_TASK]),
+                 ],
+             )
+             yield from datajob.generate_mcp()
+             yield from self._gen_pipeline_task_run_mcps(task, datajob, pipeline)
+
+     def _format_pipeline_duration(self, td: timedelta) -> str:
+         days = td.days
+         hours, remainder = divmod(td.seconds, 3600)
+         minutes, seconds = divmod(remainder, 60)
+         milliseconds = td.microseconds // 1000
+
+         parts = []
+         if days:
+             parts.append(f"{days}d")
+         if hours:
+             parts.append(f"{hours}h")
+         if minutes:
+             parts.append(f"{minutes}m")
+         if seconds:
+             parts.append(f"{seconds}s")
+         if milliseconds:
+             parts.append(f"{milliseconds}ms")
+         return " ".join(parts) if parts else "0s"
+
+     def _get_pipeline_task_properties(
+         self, task: PipelineTaskMetadata
+     ) -> Dict[str, str]:
+         return {
+             "created_time": (
+                 task.create_time.strftime("%Y-%m-%d %H:%M:%S")
+                 if task.create_time
+                 else ""
+             )
+         }
+
+     def _get_pipeline_properties(self, pipeline: PipelineMetadata) -> Dict[str, str]:
+         return {
+             "resource_name": pipeline.resource_name if pipeline.resource_name else "",
+             "create_time": (
+                 pipeline.create_time.isoformat() if pipeline.create_time else ""
+             ),
+             "update_time": (
+                 pipeline.update_time.isoformat() if pipeline.update_time else ""
+             ),
+             "duration": (
+                 self._format_pipeline_duration(pipeline.duration)
+                 if pipeline.duration
+                 else ""
+             ),
+             "location": (pipeline.region if pipeline.region else ""),
+             "labels": ",".join([f"{k}:{v}" for k, v in pipeline.labels.items()])
+             if pipeline.labels
+             else "",
+         }
+
+     def _get_pipeline_mcps(
+         self, pipeline: PipelineMetadata
+     ) -> Iterable[MetadataChangeProposalWrapper]:
+         dataflow = DataFlow(
+             orchestrator=self.platform,
+             id=self._make_vertexai_pipeline_id(pipeline.name),
+             env=self.config.env,
+             name=pipeline.name,
+             platform_instance=self.platform,
+             properties=self._get_pipeline_properties(pipeline),
+             owners={"urn:li:corpuser:datahub"},
+             url=self._make_pipeline_external_url(pipeline_name=pipeline.name),
+         )
+
+         yield from dataflow.generate_mcp()
+
+         yield from MetadataChangeProposalWrapper.construct_many(
+             entityUrn=str(dataflow.urn),
+             aspects=[
+                 ContainerClass(container=self._get_project_container().as_urn()),
+                 SubTypesClass(typeNames=[MLAssetSubTypes.VERTEX_PIPELINE]),
+             ],
+         )

      def _get_experiments_workunits(self) -> Iterable[MetadataWorkUnit]:
          # List all experiments
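The ingestion added above models each Vertex AI pipeline as a DataHub DataFlow and each task in its DAG as a DataJob, using the `{project_id}.pipeline.{name}` and `{project_id}.pipeline_task.{name}` id helpers introduced later in this diff. A minimal sketch of the resulting URNs, with made-up project, pipeline, and task names:

```python
from datahub.metadata.urns import DataFlowUrn, DataJobUrn

# Made-up project, pipeline and task names for illustration only.
project_id = "my-gcp-project"
pipeline_name = "pipeline-example"
task_name = "train-model"

flow_urn = DataFlowUrn.create_from_ids(
    orchestrator="vertexai",
    env="PROD",
    flow_id=f"{project_id}.pipeline.{pipeline_name}",
    platform_instance="vertexai",
)
task_urn = DataJobUrn.create_from_ids(
    data_flow_urn=str(flow_urn),
    job_id=f"{project_id}.pipeline_task.{task_name}",
)
print(flow_urn)  # a dataFlow URN scoped to the vertexai orchestrator
print(task_urn)  # a dataJob URN nested under that flow
```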
@@ -175,7 +467,7 @@ class VertexAISource(Source):
              parent_container_key=self._get_project_container(),
              container_key=ExperimentKey(
                  platform=self.platform,
-                 id=self._make_vertexai_experiment_name(experiment.name),
+                 id=self._make_vertexai_experiment_id(experiment.name),
              ),
              name=experiment.name,
              sub_types=[MLAssetSubTypes.VERTEX_EXPERIMENT],
@@ -311,7 +603,7 @@ class VertexAISource(Source):
      ) -> Iterable[MetadataChangeProposalWrapper]:
          experiment_key = ExperimentKey(
              platform=self.platform,
-             id=self._make_vertexai_experiment_name(experiment.name),
+             id=self._make_vertexai_experiment_id(experiment.name),
          )
          run_urn = self._make_experiment_run_urn(experiment, run)
          created_time, duration = self._get_run_timestamps(run)
@@ -968,7 +1260,7 @@ class VertexAISource(Source):
      ) -> str:
          return f"{self.config.project_id}.job.{entity_id}"

-     def _make_vertexai_experiment_name(self, entity_id: Optional[str]) -> str:
+     def _make_vertexai_experiment_id(self, entity_id: Optional[str]) -> str:
          return f"{self.config.project_id}.experiment.{entity_id}"

      def _make_vertexai_experiment_run_name(self, entity_id: Optional[str]) -> str:
@@ -977,6 +1269,15 @@ class VertexAISource(Source):
      def _make_vertexai_run_execution_name(self, entity_id: Optional[str]) -> str:
          return f"{self.config.project_id}.execution.{entity_id}"

+     def _make_vertexai_pipeline_id(self, entity_id: Optional[str]) -> str:
+         return f"{self.config.project_id}.pipeline.{entity_id}"
+
+     def _make_vertexai_pipeline_task_id(self, entity_id: Optional[str]) -> str:
+         return f"{self.config.project_id}.pipeline_task.{entity_id}"
+
+     def _make_vertexai_pipeline_task_run_id(self, entity_id: Optional[str]) -> str:
+         return f"{self.config.project_id}.pipeline_task_run.{entity_id}"
+
      def _make_artifact_external_url(
          self, experiment: Experiment, run: ExperimentRun
      ) -> str:
@@ -1053,3 +1354,14 @@ class VertexAISource(Source):
              f"/runs/{experiment.name}-{run.name}/charts?project={self.config.project_id}"
          )
          return external_url
+
+     def _make_pipeline_external_url(self, pipeline_name: str) -> str:
+         """
+         Pipeline Run external URL in Vertex AI
+         https://console.cloud.google.com/vertex-ai/pipelines/locations/us-west2/runs/pipeline-example-more-tasks-3-20250320210739?project=acryl-poc
+         """
+         external_url: str = (
+             f"{self.config.vertexai_url}/pipelines/locations/{self.config.region}/runs/{pipeline_name}"
+             f"?project={self.config.project_id}"
+         )
+         return external_url
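`_make_pipeline_external_url` links each flow and task run back to the pipeline run page in the Google Cloud console. A standalone sketch of the URL it produces, assuming `config.vertexai_url` points at the console base shown in the docstring (all values here are placeholders):

```python
# Placeholder configuration values, for illustration only.
vertexai_url = "https://console.cloud.google.com/vertex-ai"
region = "us-west2"
project_id = "my-gcp-project"
pipeline_name = "pipeline-example-20250320210739"

external_url = (
    f"{vertexai_url}/pipelines/locations/{region}/runs/{pipeline_name}"
    f"?project={project_id}"
)
print(external_url)
# https://console.cloud.google.com/vertex-ai/pipelines/locations/us-west2/runs/pipeline-example-20250320210739?project=my-gcp-project
```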
@@ -1,9 +1,9 @@
- from typing import Union
+ from typing import Optional, Union

  from google.cloud.aiplatform.base import VertexAiResourceNoun
  from google.cloud.aiplatform.jobs import _RunnableJob
  from google.cloud.aiplatform.training_jobs import _TrainingJob
- from google.cloud.aiplatform_v1.types import JobState, PipelineState
+ from google.cloud.aiplatform_v1.types import JobState, PipelineState, PipelineTaskDetail

  from datahub.metadata.schema_classes import RunResultTypeClass

@@ -64,5 +64,26 @@ def get_execution_result_status(status: int) -> Union[str, RunResultTypeClass]:
      return status_mapping.get(status, "UNKNOWN")


+ def get_pipeline_task_result_status(
+     status: Optional[PipelineTaskDetail.State],
+ ) -> Union[str, RunResultTypeClass]:
+     # TODO: DataProcessInstanceRunResultClass fails with status string except for SUCCESS, FAILURE, SKIPPED,
+     # which will be fixed in the future
+     status_mapping = {
+         # PipelineTaskDetail.State.STATE_UNSPECIFIED: "STATE_UNSPECIFIED",
+         # PipelineTaskDetail.State.PENDING: "PENDING",
+         # PipelineTaskDetail.State.RUNNING: "RUNNING",
+         # PipelineTaskDetail.State.CANCEL_PENDING: "CANCEL_PENDING",
+         # PipelineTaskDetail.State.CANCELLING: "CANCELLING",
+         # PipelineTaskDetail.State.NOT_TRIGGERED: "NOT_TRIGGERED",
+         PipelineTaskDetail.State.SUCCEEDED: RunResultTypeClass.SUCCESS,
+         PipelineTaskDetail.State.FAILED: RunResultTypeClass.FAILURE,
+         PipelineTaskDetail.State.SKIPPED: RunResultTypeClass.SKIPPED,
+     }
+     if status is None:
+         return "UNKNOWN"
+     return status_mapping.get(status, "UNKNOWN")
+
+
  def is_status_for_run_event_class(status: Union[str, RunResultTypeClass]) -> bool:
      return status in [RunResultTypeClass.SUCCESS, RunResultTypeClass.FAILURE]
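A short usage sketch of the new helper: only terminal SUCCEEDED/FAILED states translate into result types that `is_status_for_run_event_class` accepts for a run event; other states, and a missing state, fall back to `SKIPPED` or `"UNKNOWN"` and no run event is emitted.

```python
from google.cloud.aiplatform_v1.types import PipelineTaskDetail

from datahub.ingestion.source.vertexai.vertexai_result_type_utils import (
    get_pipeline_task_result_status,
    is_status_for_run_event_class,
)

succeeded = get_pipeline_task_result_status(PipelineTaskDetail.State.SUCCEEDED)
assert is_status_for_run_event_class(succeeded)  # maps to RunResultTypeClass.SUCCESS

running = get_pipeline_task_result_status(PipelineTaskDetail.State.RUNNING)
assert running == "UNKNOWN"                      # non-terminal states fall through
assert not is_status_for_run_event_class(running)

assert get_pipeline_task_result_status(None) == "UNKNOWN"
```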
@@ -15247,7 +15247,7 @@ class DataContractKeyClass(_Aspect):


      ASPECT_NAME = 'dataContractKey'
-     ASPECT_INFO = {'keyForEntity': 'dataContract', 'entityCategory': 'core', 'entityAspects': ['dataContractProperties', 'dataContractStatus', 'status']}
+     ASPECT_INFO = {'keyForEntity': 'dataContract', 'entityCategory': 'core', 'entityAspects': ['dataContractProperties', 'dataContractStatus', 'status', 'structuredProperties']}
      RECORD_SCHEMA = get_schema_type("com.linkedin.pegasus2avro.metadata.key.DataContractKey")

      def __init__(self,
@@ -16024,7 +16024,8 @@
      "entityAspects": [
          "dataContractProperties",
          "dataContractStatus",
-         "status"
+         "status",
+         "structuredProperties"
      ]
  },
  "name": "DataContractKey",
@@ -7,7 +7,8 @@
      "entityAspects": [
          "dataContractProperties",
          "dataContractStatus",
-         "status"
+         "status",
+         "structuredProperties"
      ]
  },
  "name": "DataContractKey",
@@ -3,11 +3,15 @@ from typing import Dict, List, Type
  from datahub.sdk.container import Container
  from datahub.sdk.dataset import Dataset
  from datahub.sdk.entity import Entity
+ from datahub.sdk.mlmodel import MLModel
+ from datahub.sdk.mlmodelgroup import MLModelGroup

  # TODO: Is there a better way to declare this?
  ENTITY_CLASSES_LIST: List[Type[Entity]] = [
      Container,
      Dataset,
+     MLModel,
+     MLModelGroup,
  ]

  ENTITY_CLASSES: Dict[str, Type[Entity]] = {
datahub/sdk/_shared.py CHANGED
@@ -5,6 +5,7 @@ from datetime import datetime
  from typing import (
      TYPE_CHECKING,
      Callable,
+     Dict,
      List,
      Optional,
      Sequence,
@@ -14,6 +15,7 @@ from typing import (

  from typing_extensions import TypeAlias, assert_never

+ import datahub.emitter.mce_builder as builder
  import datahub.metadata.schema_classes as models
  from datahub.emitter.mce_builder import (
      make_ts_millis,
@@ -30,12 +32,14 @@ from datahub.metadata.urns import (
      DataJobUrn,
      DataPlatformInstanceUrn,
      DataPlatformUrn,
+     DataProcessInstanceUrn,
      DatasetUrn,
      DomainUrn,
      GlossaryTermUrn,
      OwnershipTypeUrn,
      TagUrn,
      Urn,
+     VersionSetUrn,
  )
  from datahub.sdk._utils import add_list_unique, remove_list_unique
  from datahub.sdk.entity import Entity
@@ -52,6 +56,36 @@ ActorUrn: TypeAlias = Union[CorpUserUrn, CorpGroupUrn]

  _DEFAULT_ACTOR_URN = CorpUserUrn("__ingestion").urn()

+ TrainingMetricsInputType: TypeAlias = Union[
+     List[models.MLMetricClass], Dict[str, Optional[str]]
+ ]
+ HyperParamsInputType: TypeAlias = Union[
+     List[models.MLHyperParamClass], Dict[str, Optional[str]]
+ ]
+ MLTrainingJobInputType: TypeAlias = Union[Sequence[Union[str, DataProcessInstanceUrn]]]
+
+
+ def convert_training_metrics(
+     metrics: TrainingMetricsInputType,
+ ) -> List[models.MLMetricClass]:
+     if isinstance(metrics, dict):
+         return [
+             models.MLMetricClass(name=name, value=str(value))
+             for name, value in metrics.items()
+         ]
+     return metrics
+
+
+ def convert_hyper_params(
+     params: HyperParamsInputType,
+ ) -> List[models.MLHyperParamClass]:
+     if isinstance(params, dict):
+         return [
+             models.MLHyperParamClass(name=name, value=str(value))
+             for name, value in params.items()
+         ]
+     return params
+

  def make_time_stamp(ts: Optional[datetime]) -> Optional[models.TimeStampClass]:
      if ts is None:
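These helpers normalize user-friendly dicts into the typed schema classes used by the new ML SDK entities; already-typed lists pass through unchanged. A small sketch (note that `datahub.sdk._shared` is an internal module, so these are normally reached through the `MLModel`/`MLModelGroup` wrappers rather than imported directly):

```python
import datahub.metadata.schema_classes as models
from datahub.sdk._shared import convert_hyper_params, convert_training_metrics

metrics = convert_training_metrics({"accuracy": "0.92", "loss": "0.08"})
params = convert_hyper_params({"learning_rate": "0.001", "epochs": "10"})

assert all(isinstance(m, models.MLMetricClass) for m in metrics)
assert all(isinstance(p, models.MLHyperParamClass) for p in params)

# Already-typed lists are returned as-is.
assert convert_training_metrics(metrics) is metrics
```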
@@ -578,3 +612,109 @@ class HasInstitutionalMemory(Entity):
              self._link_key,
              self._parse_link_association_class(link),
          )
+
+
+ class HasVersion(Entity):
+     """Mixin for entities that have version properties."""
+
+     def _get_version_props(self) -> Optional[models.VersionPropertiesClass]:
+         return self._get_aspect(models.VersionPropertiesClass)
+
+     def _ensure_version_props(self) -> models.VersionPropertiesClass:
+         version_props = self._get_version_props()
+         if version_props is None:
+             guid_dict = {"urn": str(self.urn)}
+             version_set_urn = VersionSetUrn(
+                 id=builder.datahub_guid(guid_dict), entity_type=self.urn.ENTITY_TYPE
+             )
+
+             version_props = models.VersionPropertiesClass(
+                 versionSet=str(version_set_urn),
+                 version=models.VersionTagClass(versionTag="0.1.0"),
+                 sortId="0000000.1.0",
+             )
+             self._set_aspect(version_props)
+         return version_props
+
+     @property
+     def version(self) -> Optional[str]:
+         version_props = self._get_version_props()
+         if version_props and version_props.version:
+             return version_props.version.versionTag
+         return None
+
+     def set_version(self, version: str) -> None:
+         """Set the version of the entity."""
+         guid_dict = {"urn": str(self.urn)}
+         version_set_urn = VersionSetUrn(
+             id=builder.datahub_guid(guid_dict), entity_type=self.urn.ENTITY_TYPE
+         )
+
+         version_props = self._get_version_props()
+         if version_props is None:
+             # If no version properties exist, create a new one
+             version_props = models.VersionPropertiesClass(
+                 version=models.VersionTagClass(versionTag=version),
+                 versionSet=str(version_set_urn),
+                 sortId=version.zfill(10),  # Pad with zeros for sorting
+             )
+         else:
+             # Update existing version properties
+             version_props.version = models.VersionTagClass(versionTag=version)
+             version_props.versionSet = str(version_set_urn)
+             version_props.sortId = version.zfill(10)
+
+         self._set_aspect(version_props)
+
+     @property
+     def version_aliases(self) -> List[str]:
+         version_props = self._get_version_props()
+         if version_props and version_props.aliases:
+             return [
+                 alias.versionTag
+                 for alias in version_props.aliases
+                 if alias.versionTag is not None
+             ]
+         return []  # Return empty list instead of None
+
+     def set_version_aliases(self, aliases: List[str]) -> None:
+         version_props = self._get_aspect(models.VersionPropertiesClass)
+         if version_props:
+             version_props.aliases = [
+                 models.VersionTagClass(versionTag=alias) for alias in aliases
+             ]
+         else:
+             # If no version properties exist, we need to create one with a default version
+             guid_dict = {"urn": str(self.urn)}
+             version_set_urn = VersionSetUrn(
+                 id=builder.datahub_guid(guid_dict), entity_type=self.urn.ENTITY_TYPE
+             )
+             self._set_aspect(
+                 models.VersionPropertiesClass(
+                     version=models.VersionTagClass(
+                         versionTag="0.1.0"
+                     ),  # Default version
+                     versionSet=str(version_set_urn),
+                     sortId="0000000.1.0",
+                     aliases=[
+                         models.VersionTagClass(versionTag=alias) for alias in aliases
+                     ],
+                 )
+             )
+
+     def add_version_alias(self, alias: str) -> None:
+         if not alias:
+             raise ValueError("Alias cannot be empty")
+         version_props = self._ensure_version_props()
+         if version_props.aliases is None:
+             version_props.aliases = []
+         version_props.aliases.append(models.VersionTagClass(versionTag=alias))
+         self._set_aspect(version_props)
+
+     def remove_version_alias(self, alias: str) -> None:
+         version_props = self._get_version_props()
+         if version_props and version_props.aliases:
+             version_props.aliases = [
+                 a for a in version_props.aliases if a.versionTag != alias
+             ]
+         self._set_aspect(version_props)
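A hedged usage sketch of the new `HasVersion` mixin. It assumes the new `MLModel` SDK class mixes this in and accepts `id` and `platform` keyword arguments; the exact constructor signature is not shown in this diff:

```python
from datahub.sdk.mlmodel import MLModel

# Assumed constructor arguments; check the MLModel class for the real signature.
model = MLModel(id="my_model", platform="vertexai")

model.set_version("2")             # creates or updates VersionPropertiesClass
model.add_version_alias("latest")  # appends a VersionTagClass alias

print(model.version)          # "2"
print(model.version_aliases)  # ["latest"]
```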
@@ -11,6 +11,8 @@ from datahub.ingestion.graph.client import DataHubGraph
  from datahub.metadata.urns import (
      ContainerUrn,
      DatasetUrn,
+     MlModelGroupUrn,
+     MlModelUrn,
      Urn,
  )
  from datahub.sdk._all_entities import ENTITY_CLASSES
@@ -18,6 +20,8 @@ from datahub.sdk._shared import UrnOrStr
  from datahub.sdk.container import Container
  from datahub.sdk.dataset import Dataset
  from datahub.sdk.entity import Entity
+ from datahub.sdk.mlmodel import MLModel
+ from datahub.sdk.mlmodelgroup import MLModelGroup

  if TYPE_CHECKING:
      from datahub.sdk.main_client import DataHubClient
@@ -49,6 +53,10 @@ class EntityClient:
      @overload
      def get(self, urn: DatasetUrn) -> Dataset: ...
      @overload
+     def get(self, urn: MlModelUrn) -> MLModel: ...
+     @overload
+     def get(self, urn: MlModelGroupUrn) -> MLModelGroup: ...
+     @overload
      def get(self, urn: Union[Urn, str]) -> Entity: ...
      def get(self, urn: UrnOrStr) -> Entity:
          """Retrieve an entity by its urn.