acryl-datahub 1.0.0.3rc4__py3-none-any.whl → 1.0.0.3rc5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of acryl-datahub might be problematic. Click here for more details.

@@ -1,7 +1,7 @@
1
- acryl_datahub-1.0.0.3rc4.dist-info/licenses/LICENSE,sha256=9xNHpsD0uYF5ONzXsKDCuHHB-xbiCrSbueWXqrTNsxk,11365
1
+ acryl_datahub-1.0.0.3rc5.dist-info/licenses/LICENSE,sha256=9xNHpsD0uYF5ONzXsKDCuHHB-xbiCrSbueWXqrTNsxk,11365
2
2
  datahub/__init__.py,sha256=aq_i5lVREmoLfYIqcx_pEQicO855YlhD19tWc1eZZNI,59
3
3
  datahub/__main__.py,sha256=pegIvQ9hzK7IhqVeUi1MeADSZ2QlP-D3K0OQdEg55RU,106
4
- datahub/_version.py,sha256=911KmAC1s4iZRnTeHRY8SXqgCBef_UX6nB2GAa1FqkE,323
4
+ datahub/_version.py,sha256=QrGnvYk4Mo48c9uZMV5srbywzBX4Fcxp27GXbDul8F4,323
5
5
  datahub/entrypoints.py,sha256=2TYgHhs3sCxJlojIHjqfxzt3_ImPwPzq4vBtsUuMqu4,8885
6
6
  datahub/errors.py,sha256=BzKdcmYseHOt36zfjJXc17WNutFhp9Y23cU_L6cIkxc,612
7
7
  datahub/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -899,15 +899,17 @@ datahub/metadata/schemas/__init__.py,sha256=uvLNC3VyCkWA_v8e9FdA1leFf46NFKDD0Aaj
899
899
  datahub/pydantic/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
900
900
  datahub/pydantic/compat.py,sha256=TUEo4kSEeOWVAhV6LQtst1phrpVgGtK4uif4OI5vQ2M,1937
901
901
  datahub/sdk/__init__.py,sha256=QeutS6Th8K4E4ZxXuoGrmvahN6zA9Oh9asKk5mw9AIk,1670
902
- datahub/sdk/_all_entities.py,sha256=lCa2oeudmrKN7suOBIT5tidGJB8pgL3xD80DQlsYUKI,399
902
+ datahub/sdk/_all_entities.py,sha256=inbLFv2T7dhZpGfBY5FPhCWbyE0P0G8umOt0Bc7V4XA,520
903
903
  datahub/sdk/_attribution.py,sha256=0Trh8steVd27GOr9MKCZeawbuDD2_q3GIsZlCtHqEUg,1321
904
- datahub/sdk/_shared.py,sha256=pHVKEJ50BoLw0fLLAm9zYsynNDN_bPI26qlj8nk2iyY,19582
904
+ datahub/sdk/_shared.py,sha256=5L1IkihLc7Pd2x0ypDs96kZ8ecm6o0-UZEn0J1Sffqw,24808
905
905
  datahub/sdk/_utils.py,sha256=aGE665Su8SGtj2CRDiTaXNYrJ8ADBsS0m4DmaXw79b8,1027
906
906
  datahub/sdk/container.py,sha256=yw_vw9Jl1wOYNwMHxQHLz5ZvVQVDWWHi9CWBR3hOCd8,7547
907
907
  datahub/sdk/dataset.py,sha256=5LG4c_8bHeSPYrW88KNXRgiPD8frBjR0OBVrrwdquU4,29152
908
908
  datahub/sdk/entity.py,sha256=Q29AbpS58L4gD8ETwoNIwG-ouytz4c0MSSFi6-jLl_4,6742
909
- datahub/sdk/entity_client.py,sha256=Sxe6H6Vr_tqLJu5KW7MJfLWJ6mgh4mbsx7u7MOBpM64,5052
909
+ datahub/sdk/entity_client.py,sha256=1AC9J7-jv3rD-MFEPz2PnFrT8nFkj_WO0M-4nyVOtQk,5319
910
910
  datahub/sdk/main_client.py,sha256=h2MKRhR-BO0zGCMhF7z2bTncX4hagKrAYwR3wTNTtzA,3666
911
+ datahub/sdk/mlmodel.py,sha256=amS-hHg5tT7zAqEHG17kSA60Q7td2DFtO-W2rEfb2rY,10206
912
+ datahub/sdk/mlmodelgroup.py,sha256=_7IkqkLVeyqYVEUHTVePSDLQyESsnwht5ca1lcMODAg,7842
911
913
  datahub/sdk/resolver_client.py,sha256=nKMAZJt2tRSGfKSzoREIh43PXqjM3umLiYkYHJjo1io,3243
912
914
  datahub/sdk/search_client.py,sha256=BJR5t7Ff2oDNOGLcSCp9YHzrGKbgOQr7T8XQKGEpucw,3437
913
915
  datahub/sdk/search_filters.py,sha256=BcMhvG5hGYAATtLPLz4WLRjKApX2oLYrrcGn-CG__ek,12901
@@ -1046,8 +1048,8 @@ datahub_provider/operators/datahub_assertion_operator.py,sha256=uvTQ-jk2F0sbqqxp
1046
1048
  datahub_provider/operators/datahub_assertion_sensor.py,sha256=lCBj_3x1cf5GMNpHdfkpHuyHfVxsm6ff5x2Z5iizcAo,140
1047
1049
  datahub_provider/operators/datahub_operation_operator.py,sha256=aevDp2FzX7FxGlXrR0khoHNbxbhKR2qPEX5e8O2Jyzw,174
1048
1050
  datahub_provider/operators/datahub_operation_sensor.py,sha256=8fcdVBCEPgqy1etTXgLoiHoJrRt_nzFZQMdSzHqSG7M,168
1049
- acryl_datahub-1.0.0.3rc4.dist-info/METADATA,sha256=66uBY8gH_YmKeBmwbECGBpJhT7JiyrJXfFJv5wcFgQQ,176965
1050
- acryl_datahub-1.0.0.3rc4.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
1051
- acryl_datahub-1.0.0.3rc4.dist-info/entry_points.txt,sha256=o3mDeJXSKhsy7XLkuogihraiabBdLn9HaizYXPrxmk0,9710
1052
- acryl_datahub-1.0.0.3rc4.dist-info/top_level.txt,sha256=iLjSrLK5ox1YVYcglRUkcvfZPvKlobBWx7CTUXx8_GI,25
1053
- acryl_datahub-1.0.0.3rc4.dist-info/RECORD,,
1051
+ acryl_datahub-1.0.0.3rc5.dist-info/METADATA,sha256=4yRDLa5bvQ1tAjJ2uzc4mKVmegeKOASmLpseMx9A1K8,176989
1052
+ acryl_datahub-1.0.0.3rc5.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
1053
+ acryl_datahub-1.0.0.3rc5.dist-info/entry_points.txt,sha256=o3mDeJXSKhsy7XLkuogihraiabBdLn9HaizYXPrxmk0,9710
1054
+ acryl_datahub-1.0.0.3rc5.dist-info/top_level.txt,sha256=iLjSrLK5ox1YVYcglRUkcvfZPvKlobBWx7CTUXx8_GI,25
1055
+ acryl_datahub-1.0.0.3rc5.dist-info/RECORD,,
datahub/_version.py CHANGED
@@ -1,6 +1,6 @@
1
1
  # Published at https://pypi.org/project/acryl-datahub/.
2
2
  __package_name__ = "acryl-datahub"
3
- __version__ = "1.0.0.3rc4"
3
+ __version__ = "1.0.0.3rc5"
4
4
 
5
5
 
6
6
  def is_dev_mode() -> bool:
@@ -3,11 +3,15 @@ from typing import Dict, List, Type
3
3
  from datahub.sdk.container import Container
4
4
  from datahub.sdk.dataset import Dataset
5
5
  from datahub.sdk.entity import Entity
6
+ from datahub.sdk.mlmodel import MLModel
7
+ from datahub.sdk.mlmodelgroup import MLModelGroup
6
8
 
7
9
  # TODO: Is there a better way to declare this?
8
10
  ENTITY_CLASSES_LIST: List[Type[Entity]] = [
9
11
  Container,
10
12
  Dataset,
13
+ MLModel,
14
+ MLModelGroup,
11
15
  ]
12
16
 
13
17
  ENTITY_CLASSES: Dict[str, Type[Entity]] = {
datahub/sdk/_shared.py CHANGED
@@ -5,6 +5,7 @@ from datetime import datetime
5
5
  from typing import (
6
6
  TYPE_CHECKING,
7
7
  Callable,
8
+ Dict,
8
9
  List,
9
10
  Optional,
10
11
  Sequence,
@@ -14,6 +15,7 @@ from typing import (
14
15
 
15
16
  from typing_extensions import TypeAlias, assert_never
16
17
 
18
+ import datahub.emitter.mce_builder as builder
17
19
  import datahub.metadata.schema_classes as models
18
20
  from datahub.emitter.mce_builder import (
19
21
  make_ts_millis,
@@ -30,12 +32,14 @@ from datahub.metadata.urns import (
30
32
  DataJobUrn,
31
33
  DataPlatformInstanceUrn,
32
34
  DataPlatformUrn,
35
+ DataProcessInstanceUrn,
33
36
  DatasetUrn,
34
37
  DomainUrn,
35
38
  GlossaryTermUrn,
36
39
  OwnershipTypeUrn,
37
40
  TagUrn,
38
41
  Urn,
42
+ VersionSetUrn,
39
43
  )
40
44
  from datahub.sdk._utils import add_list_unique, remove_list_unique
41
45
  from datahub.sdk.entity import Entity
@@ -52,6 +56,36 @@ ActorUrn: TypeAlias = Union[CorpUserUrn, CorpGroupUrn]
52
56
 
53
57
  _DEFAULT_ACTOR_URN = CorpUserUrn("__ingestion").urn()
54
58
 
59
# Training metrics may be supplied either as ready-made MLMetricClass objects
# or as a plain mapping of metric name -> value.
TrainingMetricsInputType: TypeAlias = Union[
    List[models.MLMetricClass],
    Dict[str, Optional[str]],
]

# Hyperparameters follow the same two shapes as training metrics.
HyperParamsInputType: TypeAlias = Union[
    List[models.MLHyperParamClass],
    Dict[str, Optional[str]],
]

# Jobs attached to a model, each given as a URN string or a
# DataProcessInstanceUrn object.
MLTrainingJobInputType: TypeAlias = Sequence[Union[str, DataProcessInstanceUrn]]
66
+
67
+
68
def convert_training_metrics(
    metrics: TrainingMetricsInputType,
) -> List[models.MLMetricClass]:
    """Normalize training metrics into a list of ``MLMetricClass`` objects.

    Args:
        metrics: Either a list of ``MLMetricClass`` objects or a mapping of
            metric name to metric value.

    Returns:
        For a mapping, one new ``MLMetricClass`` per entry with the value
        converted to a string. For any other input, ``metrics`` itself is
        returned unchanged (the same object, not a copy).
    """
    if isinstance(metrics, dict):
        return [
            models.MLMetricClass(
                name=name,
                # A missing value stays missing; str(None) would otherwise
                # record the literal text "None" as the metric value.
                value=None if value is None else str(value),
            )
            for name, value in metrics.items()
        ]
    return metrics
77
+
78
+
79
def convert_hyper_params(
    params: HyperParamsInputType,
) -> List[models.MLHyperParamClass]:
    """Normalize hyperparameters into a list of ``MLHyperParamClass`` objects.

    Args:
        params: Either a list of ``MLHyperParamClass`` objects or a mapping of
            hyperparameter name to value.

    Returns:
        For a mapping, one new ``MLHyperParamClass`` per entry with the value
        converted to a string. For any other input, ``params`` itself is
        returned unchanged (the same object, not a copy).
    """
    if isinstance(params, dict):
        return [
            models.MLHyperParamClass(
                name=name,
                # A missing value stays missing; str(None) would otherwise
                # record the literal text "None" as the parameter value.
                value=None if value is None else str(value),
            )
            for name, value in params.items()
        ]
    return params
88
+
55
89
 
56
90
  def make_time_stamp(ts: Optional[datetime]) -> Optional[models.TimeStampClass]:
57
91
  if ts is None:
@@ -578,3 +612,109 @@ class HasInstitutionalMemory(Entity):
578
612
  self._link_key,
579
613
  self._parse_link_association_class(link),
580
614
  )
615
+
616
+
617
class HasVersion(Entity):
    """Mixin for entities that have version properties.

    Version information is stored in the ``VersionPropertiesClass`` aspect,
    which carries a version tag, the URN of the owning version set, a sort id
    and an optional list of alias tags.
    """

    def _get_version_props(self) -> Optional[models.VersionPropertiesClass]:
        """Return the version properties aspect, or ``None`` if it is not set."""
        return self._get_aspect(models.VersionPropertiesClass)

    def _make_version_set_urn(self) -> VersionSetUrn:
        """Build the version set URN from a GUID of this entity's own URN."""
        return VersionSetUrn(
            id=builder.datahub_guid({"urn": str(self.urn)}),
            entity_type=self.urn.ENTITY_TYPE,
        )

    def _make_default_version_props(
        self, aliases: Optional[List[models.VersionTagClass]] = None
    ) -> models.VersionPropertiesClass:
        """Create, without attaching, version properties for version "0.1.0".

        Args:
            aliases: Alias tags to store on the new aspect; left at the
                aspect's own default when ``None``.
        """
        # NOTE(review): this sort id is 11 characters wide, while set_version()
        # derives a 10-character id with zfill(10). Kept byte-for-byte so the
        # emitted metadata does not change; confirm which width is intended.
        version_props = models.VersionPropertiesClass(
            versionSet=str(self._make_version_set_urn()),
            version=models.VersionTagClass(versionTag="0.1.0"),
            sortId="0000000.1.0",
        )
        if aliases is not None:
            version_props.aliases = aliases
        return version_props

    def _ensure_version_props(self) -> models.VersionPropertiesClass:
        """Return the version properties, creating the default ones if absent."""
        version_props = self._get_version_props()
        if version_props is None:
            version_props = self._make_default_version_props()
            self._set_aspect(version_props)
        return version_props

    @property
    def version(self) -> Optional[str]:
        """The version tag of the entity, or ``None`` if no version is set."""
        version_props = self._get_version_props()
        if version_props and version_props.version:
            return version_props.version.versionTag
        return None

    def set_version(self, version: str) -> None:
        """Set the version of the entity.

        The version set URN is re-derived from the entity URN and the sort id
        is the version string left-padded with zeros to ten characters. This
        applies both when the aspect is created and when it is updated.
        """
        version_set_urn = self._make_version_set_urn()
        sort_id = version.zfill(10)  # Pad with zeros for sorting

        version_props = self._get_version_props()
        if version_props is None:
            # If no version properties exist, create a new one
            version_props = models.VersionPropertiesClass(
                version=models.VersionTagClass(versionTag=version),
                versionSet=str(version_set_urn),
                sortId=sort_id,
            )
        else:
            # Update existing version properties
            version_props.version = models.VersionTagClass(versionTag=version)
            version_props.versionSet = str(version_set_urn)
            version_props.sortId = sort_id

        self._set_aspect(version_props)

    @property
    def version_aliases(self) -> List[str]:
        """Alias tags of the entity's version; an empty list when there are none."""
        version_props = self._get_version_props()
        if version_props and version_props.aliases:
            return [
                alias.versionTag
                for alias in version_props.aliases
                if alias.versionTag is not None
            ]
        return []  # Return empty list instead of None

    def set_version_aliases(self, aliases: List[str]) -> None:
        """Replace all version aliases.

        If the entity has no version properties yet, they are created with
        the default version "0.1.0" so the aliases have somewhere to live.
        """
        alias_tags = [models.VersionTagClass(versionTag=alias) for alias in aliases]
        version_props = self._get_version_props()
        if version_props:
            version_props.aliases = alias_tags
        else:
            self._set_aspect(self._make_default_version_props(aliases=alias_tags))

    def add_version_alias(self, alias: str) -> None:
        """Add one version alias, creating default version properties if needed.

        Adding an alias that is already present is a no-op, so repeated calls
        do not pile up duplicate tags.

        Raises:
            ValueError: If ``alias`` is empty.
        """
        if not alias:
            raise ValueError("Alias cannot be empty")
        version_props = self._ensure_version_props()
        if version_props.aliases is None:
            version_props.aliases = []
        if all(existing.versionTag != alias for existing in version_props.aliases):
            version_props.aliases.append(models.VersionTagClass(versionTag=alias))
        self._set_aspect(version_props)

    def remove_version_alias(self, alias: str) -> None:
        """Remove every alias tag equal to ``alias``; no-op when none exist."""
        version_props = self._get_version_props()
        if version_props and version_props.aliases:
            version_props.aliases = [
                a for a in version_props.aliases if a.versionTag != alias
            ]
            self._set_aspect(version_props)
@@ -11,6 +11,8 @@ from datahub.ingestion.graph.client import DataHubGraph
11
11
  from datahub.metadata.urns import (
12
12
  ContainerUrn,
13
13
  DatasetUrn,
14
+ MlModelGroupUrn,
15
+ MlModelUrn,
14
16
  Urn,
15
17
  )
16
18
  from datahub.sdk._all_entities import ENTITY_CLASSES
@@ -18,6 +20,8 @@ from datahub.sdk._shared import UrnOrStr
18
20
  from datahub.sdk.container import Container
19
21
  from datahub.sdk.dataset import Dataset
20
22
  from datahub.sdk.entity import Entity
23
+ from datahub.sdk.mlmodel import MLModel
24
+ from datahub.sdk.mlmodelgroup import MLModelGroup
21
25
 
22
26
  if TYPE_CHECKING:
23
27
  from datahub.sdk.main_client import DataHubClient
@@ -49,6 +53,10 @@ class EntityClient:
49
53
  @overload
50
54
  def get(self, urn: DatasetUrn) -> Dataset: ...
51
55
  @overload
56
+ def get(self, urn: MlModelUrn) -> MLModel: ...
57
+ @overload
58
+ def get(self, urn: MlModelGroupUrn) -> MLModelGroup: ...
59
+ @overload
52
60
  def get(self, urn: Union[Urn, str]) -> Entity: ...
53
61
  def get(self, urn: UrnOrStr) -> Entity:
54
62
  """Retrieve an entity by its urn.
datahub/sdk/mlmodel.py ADDED
@@ -0,0 +1,301 @@
1
+ from __future__ import annotations
2
+
3
+ from datetime import datetime
4
+ from typing import Dict, List, Optional, Sequence, Type, Union
5
+
6
+ from typing_extensions import Self
7
+
8
+ from datahub.emitter.mce_builder import DEFAULT_ENV
9
+ from datahub.metadata.schema_classes import (
10
+ AspectBag,
11
+ MLHyperParamClass,
12
+ MLMetricClass,
13
+ MLModelPropertiesClass,
14
+ )
15
+ from datahub.metadata.urns import (
16
+ DataProcessInstanceUrn,
17
+ MlModelGroupUrn,
18
+ MlModelUrn,
19
+ Urn,
20
+ )
21
+ from datahub.sdk._shared import (
22
+ DomainInputType,
23
+ HasDomain,
24
+ HasInstitutionalMemory,
25
+ HasOwnership,
26
+ HasPlatformInstance,
27
+ HasTags,
28
+ HasTerms,
29
+ HasVersion,
30
+ HyperParamsInputType,
31
+ LinksInputType,
32
+ MLTrainingJobInputType,
33
+ OwnersInputType,
34
+ TagsInputType,
35
+ TermsInputType,
36
+ TrainingMetricsInputType,
37
+ convert_hyper_params,
38
+ convert_training_metrics,
39
+ make_time_stamp,
40
+ parse_time_stamp,
41
+ )
42
+ from datahub.sdk.entity import Entity, ExtraAspectsType
43
+
44
+
45
class MLModel(
    HasPlatformInstance,
    HasOwnership,
    HasInstitutionalMemory,
    HasTags,
    HasTerms,
    HasDomain,
    HasVersion,
    Entity,
):
    """SDK entity for an ML model.

    Model-specific fields (name, description, metrics, hyperparameters, group
    membership, training and downstream jobs, timestamps) are stored in the
    ``MLModelPropertiesClass`` aspect. Ownership, tags, terms, domain, links,
    platform instance and version handling come from the mixin base classes.
    """

    __slots__ = ()

    @classmethod
    def get_urn_type(cls) -> Type[MlModelUrn]:
        """Return the URN class that identifies this entity type."""
        return MlModelUrn

    def __init__(
        self,
        id: str,
        platform: str,
        version: Optional[str] = None,
        aliases: Optional[List[str]] = None,
        platform_instance: Optional[str] = None,
        env: str = DEFAULT_ENV,
        name: Optional[str] = None,
        description: Optional[str] = None,
        training_metrics: Optional[TrainingMetricsInputType] = None,
        hyper_params: Optional[HyperParamsInputType] = None,
        external_url: Optional[str] = None,
        custom_properties: Optional[Dict[str, str]] = None,
        created: Optional[datetime] = None,
        last_modified: Optional[datetime] = None,
        owners: Optional[OwnersInputType] = None,
        links: Optional[LinksInputType] = None,
        tags: Optional[TagsInputType] = None,
        terms: Optional[TermsInputType] = None,
        domain: Optional[DomainInputType] = None,
        model_group: Optional[Union[str, MlModelGroupUrn]] = None,
        training_jobs: Optional[MLTrainingJobInputType] = None,
        downstream_jobs: Optional[MLTrainingJobInputType] = None,
        extra_aspects: ExtraAspectsType = None,
    ):
        """Create an ML model entity.

        The URN is built from ``platform``, ``id`` (used as the URN name) and
        ``env``. Every other argument is optional and, when not ``None``, is
        applied through the matching ``set_*`` method.
        """
        urn = MlModelUrn(platform=platform, name=id, env=env)
        super().__init__(urn)
        self._set_extra_aspects(extra_aspects)

        self._set_platform_instance(urn.platform, platform_instance)

        self._ensure_model_props()

        # The version is applied before the aliases: setting aliases on an
        # entity without version properties creates a default version.
        if version is not None:
            self.set_version(version)
        if name is not None:
            self.set_name(name)
        if aliases is not None:
            self.set_version_aliases(aliases)
        if description is not None:
            self.set_description(description)
        if training_metrics is not None:
            self.set_training_metrics(training_metrics)
        if hyper_params is not None:
            self.set_hyper_params(hyper_params)
        if external_url is not None:
            self.set_external_url(external_url)
        if custom_properties is not None:
            self.set_custom_properties(custom_properties)
        if created is not None:
            self.set_created(created)
        if last_modified is not None:
            self.set_last_modified(last_modified)

        if owners is not None:
            self.set_owners(owners)
        if links is not None:
            self.set_links(links)
        if tags is not None:
            self.set_tags(tags)
        if terms is not None:
            self.set_terms(terms)
        if domain is not None:
            self.set_domain(domain)
        if model_group is not None:
            self.set_model_group(model_group)
        if training_jobs is not None:
            self.set_training_jobs(training_jobs)
        if downstream_jobs is not None:
            self.set_downstream_jobs(downstream_jobs)

    @classmethod
    def _new_from_graph(cls, urn: Urn, current_aspects: AspectBag) -> Self:
        """Rebuild an entity from its URN and the given aspects."""
        assert isinstance(urn, MlModelUrn)
        entity = cls(
            id=urn.name,
            platform=urn.platform,
            env=urn.env,
        )
        return entity._init_from_graph(current_aspects)

    @property
    def urn(self) -> MlModelUrn:
        """The model's URN."""
        return self._urn  # type: ignore

    def _ensure_model_props(
        self,
    ) -> MLModelPropertiesClass:
        """Return the model properties aspect, creating an empty one if absent."""
        return self._setdefault_aspect(MLModelPropertiesClass())

    @property
    def name(self) -> Optional[str]:
        """Display name of the model."""
        return self._ensure_model_props().name

    def set_name(self, name: str) -> None:
        """Set the display name of the model."""
        self._ensure_model_props().name = name

    @property
    def description(self) -> Optional[str]:
        """Description of the model."""
        return self._ensure_model_props().description

    def set_description(self, description: str) -> None:
        """Set the description of the model."""
        self._ensure_model_props().description = description

    @property
    def external_url(self) -> Optional[str]:
        """External URL stored on the model properties."""
        return self._ensure_model_props().externalUrl

    def set_external_url(self, external_url: str) -> None:
        """Set the external URL stored on the model properties."""
        self._ensure_model_props().externalUrl = external_url

    @property
    def custom_properties(self) -> Optional[Dict[str, str]]:
        """Custom key/value properties of the model."""
        return self._ensure_model_props().customProperties

    def set_custom_properties(self, custom_properties: Dict[str, str]) -> None:
        """Replace the custom key/value properties of the model."""
        self._ensure_model_props().customProperties = custom_properties

    @property
    def created(self) -> Optional[datetime]:
        """Creation time of the model, if recorded."""
        return parse_time_stamp(self._ensure_model_props().created)

    def set_created(self, created: datetime) -> None:
        """Record the creation time of the model."""
        self._ensure_model_props().created = make_time_stamp(created)

    @property
    def last_modified(self) -> Optional[datetime]:
        """Last modification time of the model, if recorded."""
        return parse_time_stamp(self._ensure_model_props().lastModified)

    def set_last_modified(self, last_modified: datetime) -> None:
        """Record the last modification time of the model."""
        self._ensure_model_props().lastModified = make_time_stamp(last_modified)

    @property
    def training_metrics(self) -> Optional[List[MLMetricClass]]:
        """Training metrics recorded for the model."""
        return self._ensure_model_props().trainingMetrics

    def set_training_metrics(self, metrics: TrainingMetricsInputType) -> None:
        """Replace the training metrics (list of metric objects or name -> value)."""
        self._ensure_model_props().trainingMetrics = convert_training_metrics(metrics)

    def add_training_metrics(self, metrics: TrainingMetricsInputType) -> None:
        """Append training metrics to those already recorded.

        Input is normalized by the same converter as ``set_training_metrics``,
        so both methods treat mappings and metric objects the same way. Metric
        objects are appended as given, keeping every field they carry.
        """
        props = self._ensure_model_props()
        if props.trainingMetrics is None:
            props.trainingMetrics = []
        props.trainingMetrics.extend(convert_training_metrics(metrics))

    @property
    def hyper_params(self) -> Optional[List[MLHyperParamClass]]:
        """Hyperparameters recorded for the model."""
        return self._ensure_model_props().hyperParams

    def set_hyper_params(self, params: HyperParamsInputType) -> None:
        """Replace the hyperparameters (list of param objects or name -> value)."""
        self._ensure_model_props().hyperParams = convert_hyper_params(params)

    def add_hyper_params(self, params: HyperParamsInputType) -> None:
        """Append hyperparameters to those already recorded.

        Input is normalized by the same converter as ``set_hyper_params``, so
        both methods treat mappings and parameter objects the same way.
        Parameter objects are appended as given, keeping every field they carry.
        """
        props = self._ensure_model_props()
        if props.hyperParams is None:
            props.hyperParams = []
        props.hyperParams.extend(convert_hyper_params(params))

    @property
    def model_group(self) -> Optional[str]:
        """URN string of the first group the model belongs to, or ``None``."""
        groups = self._ensure_model_props().groups
        if not groups:
            return None
        return groups[0]

    def set_model_group(self, group: Union[str, MlModelGroupUrn]) -> None:
        """Make ``group`` the only model group this model belongs to."""
        self._ensure_model_props().groups = [str(group)]

    @property
    def training_jobs(self) -> Optional[List[str]]:
        """URN strings of the model's training jobs."""
        return self._ensure_model_props().trainingJobs

    def set_training_jobs(self, training_jobs: MLTrainingJobInputType) -> None:
        """Replace the training jobs; each entry is stored as a string."""
        self._ensure_model_props().trainingJobs = [str(job) for job in training_jobs]

    def add_training_job(
        self, training_job: Union[str, DataProcessInstanceUrn]
    ) -> None:
        """Append one training job, stored as a string."""
        props = self._ensure_model_props()
        if props.trainingJobs is None:
            props.trainingJobs = []
        props.trainingJobs.append(str(training_job))

    def remove_training_job(
        self, training_job: Union[str, DataProcessInstanceUrn]
    ) -> None:
        """Remove every training job equal to ``training_job``; no-op if none."""
        props = self._ensure_model_props()
        if props.trainingJobs is not None:
            job_str = str(training_job)
            props.trainingJobs = [job for job in props.trainingJobs if job != job_str]

    @property
    def downstream_jobs(self) -> Optional[List[str]]:
        """URN strings of the model's downstream jobs."""
        return self._ensure_model_props().downstreamJobs

    def set_downstream_jobs(self, downstream_jobs: MLTrainingJobInputType) -> None:
        """Replace the downstream jobs; each entry is stored as a string."""
        self._ensure_model_props().downstreamJobs = [
            str(job) for job in downstream_jobs
        ]

    def add_downstream_job(
        self, downstream_job: Union[str, DataProcessInstanceUrn]
    ) -> None:
        """Append one downstream job, stored as a string."""
        props = self._ensure_model_props()
        if props.downstreamJobs is None:
            props.downstreamJobs = []
        props.downstreamJobs.append(str(downstream_job))

    def remove_downstream_job(
        self, downstream_job: Union[str, DataProcessInstanceUrn]
    ) -> None:
        """Remove every downstream job equal to ``downstream_job``; no-op if none."""
        props = self._ensure_model_props()
        if props.downstreamJobs is not None:
            job_str = str(downstream_job)
            props.downstreamJobs = [
                job for job in props.downstreamJobs if job != job_str
            ]