zenml-nightly 0.68.1.dev20241107__py3-none-any.whl → 0.68.1.dev20241108__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/artifacts/external_artifact.py +2 -1
- zenml/artifacts/utils.py +13 -20
- zenml/cli/base.py +4 -4
- zenml/cli/model.py +1 -6
- zenml/cli/stack.py +1 -0
- zenml/client.py +21 -73
- zenml/enums.py +12 -4
- zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +1 -1
- zenml/integrations/azure/orchestrators/azureml_orchestrator.py +1 -1
- zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +1 -1
- zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +60 -54
- zenml/metadata/lazy_load.py +20 -7
- zenml/model/model.py +1 -2
- zenml/models/__init__.py +0 -12
- zenml/models/v2/core/artifact_version.py +19 -7
- zenml/models/v2/core/model_version.py +3 -5
- zenml/models/v2/core/pipeline_run.py +3 -5
- zenml/models/v2/core/run_metadata.py +2 -217
- zenml/models/v2/core/step_run.py +40 -24
- zenml/orchestrators/input_utils.py +44 -19
- zenml/orchestrators/step_launcher.py +2 -2
- zenml/orchestrators/step_run_utils.py +19 -15
- zenml/orchestrators/step_runner.py +8 -3
- zenml/steps/base_step.py +1 -1
- zenml/steps/entrypoint_function_utils.py +3 -5
- zenml/steps/step_context.py +3 -2
- zenml/steps/utils.py +8 -2
- zenml/zen_server/rbac/utils.py +0 -2
- zenml/zen_server/routers/workspaces_endpoints.py +3 -4
- zenml/zen_server/zen_server_api.py +0 -2
- zenml/zen_stores/migrations/versions/1cb6477f72d6_move_artifact_save_type.py +99 -0
- zenml/zen_stores/migrations/versions/b557b2871693_update_step_run_input_types.py +33 -0
- zenml/zen_stores/rest_zen_store.py +3 -54
- zenml/zen_stores/schemas/artifact_schemas.py +8 -1
- zenml/zen_stores/schemas/model_schemas.py +2 -2
- zenml/zen_stores/schemas/pipeline_run_schemas.py +1 -1
- zenml/zen_stores/schemas/run_metadata_schemas.py +1 -48
- zenml/zen_stores/schemas/step_run_schemas.py +18 -10
- zenml/zen_stores/sql_zen_store.py +52 -98
- zenml/zen_stores/zen_store_interface.py +2 -42
- {zenml_nightly-0.68.1.dev20241107.dist-info → zenml_nightly-0.68.1.dev20241108.dist-info}/METADATA +1 -1
- {zenml_nightly-0.68.1.dev20241107.dist-info → zenml_nightly-0.68.1.dev20241108.dist-info}/RECORD +46 -45
- zenml/zen_server/routers/run_metadata_endpoints.py +0 -96
- {zenml_nightly-0.68.1.dev20241107.dist-info → zenml_nightly-0.68.1.dev20241108.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.68.1.dev20241107.dist-info → zenml_nightly-0.68.1.dev20241108.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.68.1.dev20241107.dist-info → zenml_nightly-0.68.1.dev20241108.dist-info}/entry_points.txt +0 -0
@@ -34,7 +34,7 @@ from pydantic import (
|
|
34
34
|
|
35
35
|
from zenml.config.source import Source, SourceWithValidator
|
36
36
|
from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
|
37
|
-
from zenml.enums import ArtifactType, GenericFilterOps
|
37
|
+
from zenml.enums import ArtifactSaveType, ArtifactType, GenericFilterOps
|
38
38
|
from zenml.logger import get_logger
|
39
39
|
from zenml.metadata.metadata_types import MetadataType
|
40
40
|
from zenml.models.v2.base.filter import StrFilter
|
@@ -57,9 +57,6 @@ if TYPE_CHECKING:
|
|
57
57
|
ArtifactVisualizationResponse,
|
58
58
|
)
|
59
59
|
from zenml.models.v2.core.pipeline_run import PipelineRunResponse
|
60
|
-
from zenml.models.v2.core.run_metadata import (
|
61
|
-
RunMetadataResponse,
|
62
|
-
)
|
63
60
|
from zenml.models.v2.core.step_run import StepRunResponse
|
64
61
|
|
65
62
|
logger = get_logger(__name__)
|
@@ -107,6 +104,9 @@ class ArtifactVersionRequest(WorkspaceScopedRequest):
|
|
107
104
|
visualizations: Optional[List["ArtifactVisualizationRequest"]] = Field(
|
108
105
|
default=None, title="Visualizations of the artifact."
|
109
106
|
)
|
107
|
+
save_type: ArtifactSaveType = Field(
|
108
|
+
title="The save type of the artifact version.",
|
109
|
+
)
|
110
110
|
metadata: Optional[Dict[str, MetadataType]] = Field(
|
111
111
|
default=None, title="Metadata of the artifact version."
|
112
112
|
)
|
@@ -193,6 +193,9 @@ class ArtifactVersionResponseBody(WorkspaceScopedResponseBody):
|
|
193
193
|
title="The ID of the pipeline run that generated this artifact version.",
|
194
194
|
default=None,
|
195
195
|
)
|
196
|
+
save_type: ArtifactSaveType = Field(
|
197
|
+
title="The save type of the artifact version.",
|
198
|
+
)
|
196
199
|
artifact_store_id: Optional[UUID] = Field(
|
197
200
|
title="ID of the artifact store in which this artifact is stored.",
|
198
201
|
default=None,
|
@@ -230,7 +233,7 @@ class ArtifactVersionResponseMetadata(WorkspaceScopedResponseMetadata):
|
|
230
233
|
visualizations: Optional[List["ArtifactVisualizationResponse"]] = Field(
|
231
234
|
default=None, title="Visualizations of the artifact."
|
232
235
|
)
|
233
|
-
run_metadata: Dict[str, "RunMetadataResponse"] = Field(
|
236
|
+
run_metadata: Dict[str, MetadataType] = Field(
|
234
237
|
default={}, title="Metadata of the artifact."
|
235
238
|
)
|
236
239
|
|
@@ -313,6 +316,15 @@ class ArtifactVersionResponse(
|
|
313
316
|
"""
|
314
317
|
return self.get_body().producer_pipeline_run_id
|
315
318
|
|
319
|
+
@property
|
320
|
+
def save_type(self) -> ArtifactSaveType:
|
321
|
+
"""The `save_type` property.
|
322
|
+
|
323
|
+
Returns:
|
324
|
+
the value of the property.
|
325
|
+
"""
|
326
|
+
return self.get_body().save_type
|
327
|
+
|
316
328
|
@property
|
317
329
|
def artifact_store_id(self) -> Optional[UUID]:
|
318
330
|
"""The `artifact_store_id` property.
|
@@ -343,7 +355,7 @@ class ArtifactVersionResponse(
|
|
343
355
|
return self.get_metadata().visualizations
|
344
356
|
|
345
357
|
@property
|
346
|
-
def run_metadata(self) -> Dict[str, "RunMetadataResponse"]:
|
358
|
+
def run_metadata(self) -> Dict[str, MetadataType]:
|
347
359
|
"""The `metadata` property.
|
348
360
|
|
349
361
|
Returns:
|
@@ -671,7 +683,7 @@ class LazyArtifactVersionResponse(ArtifactVersionResponse):
|
|
671
683
|
)
|
672
684
|
|
673
685
|
@property
|
674
|
-
def run_metadata(self) -> Dict[str, "RunMetadataResponse"]:
|
686
|
+
def run_metadata(self) -> Dict[str, MetadataType]:
|
675
687
|
"""The `metadata` property in lazy loading mode.
|
676
688
|
|
677
689
|
Returns:
|
@@ -29,6 +29,7 @@ from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_validator
|
|
29
29
|
|
30
30
|
from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
|
31
31
|
from zenml.enums import ModelStages
|
32
|
+
from zenml.metadata.metadata_types import MetadataType
|
32
33
|
from zenml.models.v2.base.filter import AnyQuery
|
33
34
|
from zenml.models.v2.base.page import Page
|
34
35
|
from zenml.models.v2.base.scoped import (
|
@@ -49,9 +50,6 @@ if TYPE_CHECKING:
|
|
49
50
|
from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
|
50
51
|
from zenml.models.v2.core.model import ModelResponse
|
51
52
|
from zenml.models.v2.core.pipeline_run import PipelineRunResponse
|
52
|
-
from zenml.models.v2.core.run_metadata import (
|
53
|
-
RunMetadataResponse,
|
54
|
-
)
|
55
53
|
from zenml.zen_stores.schemas import BaseSchema
|
56
54
|
|
57
55
|
AnySchema = TypeVar("AnySchema", bound=BaseSchema)
|
@@ -193,7 +191,7 @@ class ModelVersionResponseMetadata(WorkspaceScopedResponseMetadata):
|
|
193
191
|
max_length=TEXT_FIELD_MAX_LENGTH,
|
194
192
|
default=None,
|
195
193
|
)
|
196
|
-
run_metadata: Dict[str, "RunMetadataResponse"] = Field(
|
194
|
+
run_metadata: Dict[str, MetadataType] = Field(
|
197
195
|
description="Metadata linked to the model version",
|
198
196
|
default={},
|
199
197
|
)
|
@@ -304,7 +302,7 @@ class ModelVersionResponse(
|
|
304
302
|
return self.get_metadata().description
|
305
303
|
|
306
304
|
@property
|
307
|
-
def run_metadata(self) ->
|
305
|
+
def run_metadata(self) -> Dict[str, MetadataType]:
|
308
306
|
"""The `run_metadata` property.
|
309
307
|
|
310
308
|
Returns:
|
@@ -30,6 +30,7 @@ from pydantic import BaseModel, ConfigDict, Field
|
|
30
30
|
from zenml.config.pipeline_configurations import PipelineConfiguration
|
31
31
|
from zenml.constants import STR_FIELD_MAX_LENGTH
|
32
32
|
from zenml.enums import ExecutionStatus
|
33
|
+
from zenml.metadata.metadata_types import MetadataType
|
33
34
|
from zenml.models.v2.base.scoped import (
|
34
35
|
WorkspaceScopedRequest,
|
35
36
|
WorkspaceScopedResponse,
|
@@ -51,9 +52,6 @@ if TYPE_CHECKING:
|
|
51
52
|
from zenml.models.v2.core.pipeline_build import (
|
52
53
|
PipelineBuildResponse,
|
53
54
|
)
|
54
|
-
from zenml.models.v2.core.run_metadata import (
|
55
|
-
RunMetadataResponse,
|
56
|
-
)
|
57
55
|
from zenml.models.v2.core.schedule import ScheduleResponse
|
58
56
|
from zenml.models.v2.core.stack import StackResponse
|
59
57
|
from zenml.models.v2.core.step_run import StepRunResponse
|
@@ -190,7 +188,7 @@ class PipelineRunResponseBody(WorkspaceScopedResponseBody):
|
|
190
188
|
class PipelineRunResponseMetadata(WorkspaceScopedResponseMetadata):
|
191
189
|
"""Response metadata for pipeline runs."""
|
192
190
|
|
193
|
-
run_metadata: Dict[str, "RunMetadataResponse"] = Field(
|
191
|
+
run_metadata: Dict[str, MetadataType] = Field(
|
194
192
|
default={},
|
195
193
|
title="Metadata associated with this pipeline run.",
|
196
194
|
)
|
@@ -450,7 +448,7 @@ class PipelineRunResponse(
|
|
450
448
|
return self.get_body().model_version_id
|
451
449
|
|
452
450
|
@property
|
453
|
-
def run_metadata(self) -> Dict[str, "RunMetadataResponse"]:
|
451
|
+
def run_metadata(self) -> Dict[str, MetadataType]:
|
454
452
|
"""The `run_metadata` property.
|
455
453
|
|
456
454
|
Returns:
|
@@ -13,21 +13,15 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Models representing run metadata."""
|
15
15
|
|
16
|
-
from typing import Any, Dict, Optional, Union
|
16
|
+
from typing import Dict, Optional
|
17
17
|
from uuid import UUID
|
18
18
|
|
19
|
-
from pydantic import Field, field_validator
|
19
|
+
from pydantic import Field
|
20
20
|
|
21
|
-
from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
|
22
21
|
from zenml.enums import MetadataResourceTypes
|
23
22
|
from zenml.metadata.metadata_types import MetadataType, MetadataTypeEnum
|
24
23
|
from zenml.models.v2.base.scoped import (
|
25
|
-
WorkspaceScopedFilter,
|
26
24
|
WorkspaceScopedRequest,
|
27
|
-
WorkspaceScopedResponse,
|
28
|
-
WorkspaceScopedResponseBody,
|
29
|
-
WorkspaceScopedResponseMetadata,
|
30
|
-
WorkspaceScopedResponseResources,
|
31
25
|
)
|
32
26
|
|
33
27
|
# ------------------ Request Model ------------------
|
@@ -51,212 +45,3 @@ class RunMetadataRequest(WorkspaceScopedRequest):
|
|
51
45
|
types: Dict[str, "MetadataTypeEnum"] = Field(
|
52
46
|
title="The types of the metadata to be created.",
|
53
47
|
)
|
54
|
-
|
55
|
-
|
56
|
-
# ------------------ Update Model ------------------
|
57
|
-
|
58
|
-
# There is no update model for run metadata.
|
59
|
-
|
60
|
-
# ------------------ Response Model ------------------
|
61
|
-
|
62
|
-
|
63
|
-
class RunMetadataResponseBody(WorkspaceScopedResponseBody):
|
64
|
-
"""Response body for run metadata."""
|
65
|
-
|
66
|
-
key: str = Field(title="The key of the metadata.")
|
67
|
-
value: MetadataType = Field(
|
68
|
-
title="The value of the metadata.", union_mode="smart"
|
69
|
-
)
|
70
|
-
type: MetadataTypeEnum = Field(title="The type of the metadata.")
|
71
|
-
|
72
|
-
@field_validator("key", "type")
|
73
|
-
@classmethod
|
74
|
-
def str_field_max_length_check(cls, value: Any) -> Any:
|
75
|
-
"""Checks if the length of the value exceeds the maximum str length.
|
76
|
-
|
77
|
-
Args:
|
78
|
-
value: the value set in the field
|
79
|
-
|
80
|
-
Returns:
|
81
|
-
the value itself.
|
82
|
-
|
83
|
-
Raises:
|
84
|
-
AssertionError: if the length of the field is longer than the
|
85
|
-
maximum threshold.
|
86
|
-
"""
|
87
|
-
assert len(str(value)) < STR_FIELD_MAX_LENGTH, (
|
88
|
-
"The length of the value for this field can not "
|
89
|
-
f"exceed {STR_FIELD_MAX_LENGTH}"
|
90
|
-
)
|
91
|
-
return value
|
92
|
-
|
93
|
-
@field_validator("value")
|
94
|
-
@classmethod
|
95
|
-
def text_field_max_length_check(cls, value: Any) -> Any:
|
96
|
-
"""Checks if the length of the value exceeds the maximum text length.
|
97
|
-
|
98
|
-
Args:
|
99
|
-
value: the value set in the field
|
100
|
-
|
101
|
-
Returns:
|
102
|
-
the value itself.
|
103
|
-
|
104
|
-
Raises:
|
105
|
-
AssertionError: if the length of the field is longer than the
|
106
|
-
maximum threshold.
|
107
|
-
"""
|
108
|
-
assert len(str(value)) < TEXT_FIELD_MAX_LENGTH, (
|
109
|
-
"The length of the value for this field can not "
|
110
|
-
f"exceed {TEXT_FIELD_MAX_LENGTH}"
|
111
|
-
)
|
112
|
-
return value
|
113
|
-
|
114
|
-
|
115
|
-
class RunMetadataResponseMetadata(WorkspaceScopedResponseMetadata):
|
116
|
-
"""Response metadata for run metadata."""
|
117
|
-
|
118
|
-
resource_id: UUID = Field(
|
119
|
-
title="The ID of the resource that this metadata belongs to.",
|
120
|
-
)
|
121
|
-
resource_type: MetadataResourceTypes = Field(
|
122
|
-
title="The type of the resource that this metadata belongs to.",
|
123
|
-
)
|
124
|
-
stack_component_id: Optional[UUID] = Field(
|
125
|
-
title="The ID of the stack component that this metadata belongs to."
|
126
|
-
)
|
127
|
-
|
128
|
-
|
129
|
-
class RunMetadataResponseResources(WorkspaceScopedResponseResources):
|
130
|
-
"""Class for all resource models associated with the run metadata entity."""
|
131
|
-
|
132
|
-
|
133
|
-
class RunMetadataResponse(
|
134
|
-
WorkspaceScopedResponse[
|
135
|
-
RunMetadataResponseBody,
|
136
|
-
RunMetadataResponseMetadata,
|
137
|
-
RunMetadataResponseResources,
|
138
|
-
]
|
139
|
-
):
|
140
|
-
"""Response model for run metadata."""
|
141
|
-
|
142
|
-
def get_hydrated_version(self) -> "RunMetadataResponse":
|
143
|
-
"""Get the hydrated version of this run metadata.
|
144
|
-
|
145
|
-
Returns:
|
146
|
-
an instance of the same entity with the metadata field attached.
|
147
|
-
"""
|
148
|
-
from zenml.client import Client
|
149
|
-
|
150
|
-
return Client().zen_store.get_run_metadata(self.id)
|
151
|
-
|
152
|
-
# Body and metadata properties
|
153
|
-
@property
|
154
|
-
def key(self) -> str:
|
155
|
-
"""The `key` property.
|
156
|
-
|
157
|
-
Returns:
|
158
|
-
the value of the property.
|
159
|
-
"""
|
160
|
-
return self.get_body().key
|
161
|
-
|
162
|
-
@property
|
163
|
-
def value(self) -> MetadataType:
|
164
|
-
"""The `value` property.
|
165
|
-
|
166
|
-
Returns:
|
167
|
-
the value of the property.
|
168
|
-
"""
|
169
|
-
return self.get_body().value
|
170
|
-
|
171
|
-
@property
|
172
|
-
def type(self) -> MetadataTypeEnum:
|
173
|
-
"""The `type` property.
|
174
|
-
|
175
|
-
Returns:
|
176
|
-
the value of the property.
|
177
|
-
"""
|
178
|
-
return self.get_body().type
|
179
|
-
|
180
|
-
@property
|
181
|
-
def resource_id(self) -> UUID:
|
182
|
-
"""The `resource_id` property.
|
183
|
-
|
184
|
-
Returns:
|
185
|
-
the value of the property.
|
186
|
-
"""
|
187
|
-
return self.get_metadata().resource_id
|
188
|
-
|
189
|
-
@property
|
190
|
-
def resource_type(self) -> MetadataResourceTypes:
|
191
|
-
"""The `resource_type` property.
|
192
|
-
|
193
|
-
Returns:
|
194
|
-
the value of the property.
|
195
|
-
"""
|
196
|
-
return MetadataResourceTypes(self.get_metadata().resource_type)
|
197
|
-
|
198
|
-
@property
|
199
|
-
def stack_component_id(self) -> Optional[UUID]:
|
200
|
-
"""The `stack_component_id` property.
|
201
|
-
|
202
|
-
Returns:
|
203
|
-
the value of the property.
|
204
|
-
"""
|
205
|
-
return self.get_metadata().stack_component_id
|
206
|
-
|
207
|
-
|
208
|
-
# ------------------ Filter Model ------------------
|
209
|
-
|
210
|
-
|
211
|
-
class RunMetadataFilter(WorkspaceScopedFilter):
|
212
|
-
"""Model to enable advanced filtering of run metadata."""
|
213
|
-
|
214
|
-
resource_id: Optional[Union[str, UUID]] = Field(
|
215
|
-
default=None, union_mode="left_to_right"
|
216
|
-
)
|
217
|
-
resource_type: Optional[MetadataResourceTypes] = None
|
218
|
-
stack_component_id: Optional[Union[str, UUID]] = Field(
|
219
|
-
default=None, union_mode="left_to_right"
|
220
|
-
)
|
221
|
-
key: Optional[str] = None
|
222
|
-
type: Optional[Union[str, MetadataTypeEnum]] = Field(
|
223
|
-
default=None, union_mode="left_to_right"
|
224
|
-
)
|
225
|
-
|
226
|
-
|
227
|
-
# -------------------- Lazy Loader --------------------
|
228
|
-
|
229
|
-
|
230
|
-
class LazyRunMetadataResponse(RunMetadataResponse):
|
231
|
-
"""Lazy run metadata response.
|
232
|
-
|
233
|
-
Used if the run metadata is accessed from the model in
|
234
|
-
a pipeline context available only during pipeline compilation.
|
235
|
-
"""
|
236
|
-
|
237
|
-
id: Optional[UUID] = None # type: ignore[assignment]
|
238
|
-
lazy_load_artifact_name: Optional[str] = None
|
239
|
-
lazy_load_artifact_version: Optional[str] = None
|
240
|
-
lazy_load_metadata_name: Optional[str] = None
|
241
|
-
lazy_load_model_name: str
|
242
|
-
lazy_load_model_version: Optional[str] = None
|
243
|
-
|
244
|
-
def get_body(self) -> None: # type: ignore[override]
|
245
|
-
"""Protects from misuse of the lazy loader.
|
246
|
-
|
247
|
-
Raises:
|
248
|
-
RuntimeError: always
|
249
|
-
"""
|
250
|
-
raise RuntimeError(
|
251
|
-
"Cannot access run metadata body before pipeline runs."
|
252
|
-
)
|
253
|
-
|
254
|
-
def get_metadata(self) -> None: # type: ignore[override]
|
255
|
-
"""Protects from misuse of the lazy loader.
|
256
|
-
|
257
|
-
Raises:
|
258
|
-
RuntimeError: always
|
259
|
-
"""
|
260
|
-
raise RuntimeError(
|
261
|
-
"Cannot access run metadata metadata before pipeline runs."
|
262
|
-
)
|
zenml/models/v2/core/step_run.py
CHANGED
@@ -21,7 +21,8 @@ from pydantic import BaseModel, ConfigDict, Field
|
|
21
21
|
|
22
22
|
from zenml.config.step_configurations import StepConfiguration, StepSpec
|
23
23
|
from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
|
24
|
-
from zenml.enums import ExecutionStatus
|
24
|
+
from zenml.enums import ExecutionStatus, StepRunInputArtifactType
|
25
|
+
from zenml.metadata.metadata_types import MetadataType
|
25
26
|
from zenml.models.v2.base.scoped import (
|
26
27
|
WorkspaceScopedFilter,
|
27
28
|
WorkspaceScopedRequest,
|
@@ -30,19 +31,35 @@ from zenml.models.v2.base.scoped import (
|
|
30
31
|
WorkspaceScopedResponseMetadata,
|
31
32
|
WorkspaceScopedResponseResources,
|
32
33
|
)
|
34
|
+
from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
|
33
35
|
from zenml.models.v2.core.model_version import ModelVersionResponse
|
34
36
|
|
35
37
|
if TYPE_CHECKING:
|
36
38
|
from sqlalchemy.sql.elements import ColumnElement
|
37
39
|
|
38
|
-
from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
|
39
40
|
from zenml.models.v2.core.logs import (
|
40
41
|
LogsRequest,
|
41
42
|
LogsResponse,
|
42
43
|
)
|
43
|
-
|
44
|
-
|
45
|
-
|
44
|
+
|
45
|
+
|
46
|
+
class StepRunInputResponse(ArtifactVersionResponse):
|
47
|
+
"""Response model for step run inputs."""
|
48
|
+
|
49
|
+
input_type: StepRunInputArtifactType
|
50
|
+
|
51
|
+
def get_hydrated_version(self) -> "StepRunInputResponse":
|
52
|
+
"""Get the hydrated version of this step run input.
|
53
|
+
|
54
|
+
Returns:
|
55
|
+
an instance of the same entity with the metadata field attached.
|
56
|
+
"""
|
57
|
+
from zenml.client import Client
|
58
|
+
|
59
|
+
return StepRunInputResponse(
|
60
|
+
input_type=self.input_type,
|
61
|
+
**Client().zen_store.get_artifact_version(self.id).model_dump(),
|
62
|
+
)
|
46
63
|
|
47
64
|
|
48
65
|
# ------------------ Request Model ------------------
|
@@ -97,11 +114,11 @@ class StepRunRequest(WorkspaceScopedRequest):
|
|
97
114
|
)
|
98
115
|
inputs: Dict[str, UUID] = Field(
|
99
116
|
title="The IDs of the input artifact versions of the step run.",
|
100
|
-
|
117
|
+
default_factory=dict,
|
101
118
|
)
|
102
|
-
outputs: Dict[str, UUID] = Field(
|
119
|
+
outputs: Dict[str, List[UUID]] = Field(
|
103
120
|
title="The IDs of the output artifact versions of the step run.",
|
104
|
-
|
121
|
+
default_factory=dict,
|
105
122
|
)
|
106
123
|
logs: Optional["LogsRequest"] = Field(
|
107
124
|
title="Logs associated with this step run.",
|
@@ -129,10 +146,6 @@ class StepRunUpdate(BaseModel):
|
|
129
146
|
title="The IDs of the output artifact versions of the step run.",
|
130
147
|
default={},
|
131
148
|
)
|
132
|
-
saved_artifact_versions: Dict[str, UUID] = Field(
|
133
|
-
title="The IDs of artifact versions that were saved by this step run.",
|
134
|
-
default={},
|
135
|
-
)
|
136
149
|
loaded_artifact_versions: Dict[str, UUID] = Field(
|
137
150
|
title="The IDs of artifact versions that were loaded by this step run.",
|
138
151
|
default={},
|
@@ -166,13 +179,13 @@ class StepRunResponseBody(WorkspaceScopedResponseBody):
|
|
166
179
|
title="The end time of the step run.",
|
167
180
|
default=None,
|
168
181
|
)
|
169
|
-
inputs: Dict[str,
|
182
|
+
inputs: Dict[str, StepRunInputResponse] = Field(
|
170
183
|
title="The input artifact versions of the step run.",
|
171
|
-
|
184
|
+
default_factory=dict,
|
172
185
|
)
|
173
|
-
outputs: Dict[str,
|
186
|
+
outputs: Dict[str, List[ArtifactVersionResponse]] = Field(
|
174
187
|
title="The output artifact versions of the step run.",
|
175
|
-
|
188
|
+
default_factory=dict,
|
176
189
|
)
|
177
190
|
model_version_id: Optional[UUID] = Field(
|
178
191
|
title="The ID of the model version that was "
|
@@ -230,7 +243,7 @@ class StepRunResponseMetadata(WorkspaceScopedResponseMetadata):
|
|
230
243
|
title="The IDs of the parent steps of this step run.",
|
231
244
|
default_factory=list,
|
232
245
|
)
|
233
|
-
run_metadata: Dict[str, "RunMetadataResponse"] = Field(
|
246
|
+
run_metadata: Dict[str, MetadataType] = Field(
|
234
247
|
title="Metadata associated with this step run.",
|
235
248
|
default={},
|
236
249
|
)
|
@@ -274,7 +287,7 @@ class StepRunResponse(
|
|
274
287
|
|
275
288
|
# Helper properties
|
276
289
|
@property
|
277
|
-
def input(self) ->
|
290
|
+
def input(self) -> ArtifactVersionResponse:
|
278
291
|
"""Returns the input artifact that was used to run this step.
|
279
292
|
|
280
293
|
Returns:
|
@@ -293,7 +306,7 @@ class StepRunResponse(
|
|
293
306
|
return next(iter(self.inputs.values()))
|
294
307
|
|
295
308
|
@property
|
296
|
-
def output(self) ->
|
309
|
+
def output(self) -> ArtifactVersionResponse:
|
297
310
|
"""Returns the output artifact that was written by this step.
|
298
311
|
|
299
312
|
Returns:
|
@@ -304,12 +317,15 @@ class StepRunResponse(
|
|
304
317
|
"""
|
305
318
|
if not self.outputs:
|
306
319
|
raise ValueError(f"Step {self.name} has no outputs.")
|
307
|
-
if len(self.outputs) > 1:
|
320
|
+
if len(self.outputs) > 1 or (
|
321
|
+
len(self.outputs) == 1
|
322
|
+
and len(next(iter(self.outputs.values()))) > 1
|
323
|
+
):
|
308
324
|
raise ValueError(
|
309
325
|
f"Step {self.name} has multiple outputs, so `Step.output` is "
|
310
326
|
"ambiguous. Please use `Step.outputs` instead."
|
311
327
|
)
|
312
|
-
return next(iter(self.outputs.values()))
|
328
|
+
return next(iter(self.outputs.values()))[0]
|
313
329
|
|
314
330
|
# Body and metadata properties
|
315
331
|
@property
|
@@ -322,7 +338,7 @@ class StepRunResponse(
|
|
322
338
|
return self.get_body().status
|
323
339
|
|
324
340
|
@property
|
325
|
-
def inputs(self) -> Dict[str,
|
341
|
+
def inputs(self) -> Dict[str, StepRunInputResponse]:
|
326
342
|
"""The `inputs` property.
|
327
343
|
|
328
344
|
Returns:
|
@@ -331,7 +347,7 @@ class StepRunResponse(
|
|
331
347
|
return self.get_body().inputs
|
332
348
|
|
333
349
|
@property
|
334
|
-
def outputs(self) -> Dict[str,
|
350
|
+
def outputs(self) -> Dict[str, List[ArtifactVersionResponse]]:
|
335
351
|
"""The `outputs` property.
|
336
352
|
|
337
353
|
Returns:
|
@@ -466,7 +482,7 @@ class StepRunResponse(
|
|
466
482
|
return self.get_metadata().parent_step_ids
|
467
483
|
|
468
484
|
@property
|
469
|
-
def run_metadata(self) -> Dict[str, "RunMetadataResponse"]:
|
485
|
+
def run_metadata(self) -> Dict[str, MetadataType]:
|
470
486
|
"""The `run_metadata` property.
|
471
487
|
|
472
488
|
Returns:
|
@@ -18,17 +18,19 @@ from uuid import UUID
|
|
18
18
|
|
19
19
|
from zenml.client import Client
|
20
20
|
from zenml.config.step_configurations import Step
|
21
|
+
from zenml.enums import ArtifactSaveType, StepRunInputArtifactType
|
21
22
|
from zenml.exceptions import InputResolutionError
|
22
23
|
from zenml.utils import pagination_utils
|
23
24
|
|
24
25
|
if TYPE_CHECKING:
|
25
|
-
from zenml.models import
|
26
|
+
from zenml.models import PipelineRunResponse
|
27
|
+
from zenml.models.v2.core.step_run import StepRunInputResponse
|
26
28
|
|
27
29
|
|
28
30
|
def resolve_step_inputs(
|
29
31
|
step: "Step",
|
30
32
|
pipeline_run: "PipelineRunResponse",
|
31
|
-
) -> Tuple[Dict[str, "ArtifactVersionResponse"], List[UUID]]:
|
33
|
+
) -> Tuple[Dict[str, "StepRunInputResponse"], List[UUID]]:
|
32
34
|
"""Resolves inputs for the current step.
|
33
35
|
|
34
36
|
Args:
|
@@ -45,7 +47,8 @@ def resolve_step_inputs(
|
|
45
47
|
The IDs of the input artifact versions and the IDs of parent steps of
|
46
48
|
the current step.
|
47
49
|
"""
|
48
|
-
from zenml.models import ArtifactVersionResponse
|
50
|
+
from zenml.models import ArtifactVersionResponse
|
51
|
+
from zenml.models.v2.core.step_run import StepRunInputResponse
|
49
52
|
|
50
53
|
current_run_steps = {
|
51
54
|
run_step.name: run_step
|
@@ -54,7 +57,7 @@ def resolve_step_inputs(
|
|
54
57
|
)
|
55
58
|
}
|
56
59
|
|
57
|
-
input_artifacts: Dict[str,
|
60
|
+
input_artifacts: Dict[str, StepRunInputResponse] = {}
|
58
61
|
for name, input_ in step.spec.inputs.items():
|
59
62
|
try:
|
60
63
|
step_run = current_run_steps[input_.step_name]
|
@@ -64,22 +67,44 @@ def resolve_step_inputs(
|
|
64
67
|
)
|
65
68
|
|
66
69
|
try:
|
67
|
-
|
70
|
+
outputs = step_run.outputs[input_.output_name]
|
68
71
|
except KeyError:
|
69
72
|
raise InputResolutionError(
|
70
|
-
f"No output `{input_.output_name}` found for step "
|
73
|
+
f"No step output `{input_.output_name}` found for step "
|
71
74
|
f"`{input_.step_name}`."
|
72
75
|
)
|
73
76
|
|
74
|
-
|
77
|
+
step_outputs = [
|
78
|
+
output
|
79
|
+
for output in outputs
|
80
|
+
if output.save_type == ArtifactSaveType.STEP_OUTPUT
|
81
|
+
]
|
82
|
+
if len(step_outputs) > 2:
|
83
|
+
# This should never happen, there can only be a single regular step
|
84
|
+
# output for a name
|
85
|
+
raise InputResolutionError(
|
86
|
+
f"Too many step outputs for output `{input_.output_name}` of "
|
87
|
+
f"step `{input_.step_name}`."
|
88
|
+
)
|
89
|
+
elif len(step_outputs) == 0:
|
90
|
+
raise InputResolutionError(
|
91
|
+
f"No step output `{input_.output_name}` found for step "
|
92
|
+
f"`{input_.step_name}`."
|
93
|
+
)
|
94
|
+
|
95
|
+
input_artifacts[name] = StepRunInputResponse(
|
96
|
+
input_type=StepRunInputArtifactType.STEP_OUTPUT,
|
97
|
+
**step_outputs[0].model_dump(),
|
98
|
+
)
|
75
99
|
|
76
100
|
for (
|
77
101
|
name,
|
78
102
|
external_artifact,
|
79
103
|
) in step.config.external_input_artifacts.items():
|
80
104
|
artifact_version_id = external_artifact.get_artifact_version_id()
|
81
|
-
input_artifacts[name] =
|
82
|
-
|
105
|
+
input_artifacts[name] = StepRunInputResponse(
|
106
|
+
input_type=StepRunInputArtifactType.EXTERNAL,
|
107
|
+
**Client().get_artifact_version(artifact_version_id).model_dump(),
|
83
108
|
)
|
84
109
|
|
85
110
|
for name, config_ in step.config.model_artifacts_or_metadata.items():
|
@@ -98,9 +123,7 @@ def resolve_step_inputs(
|
|
98
123
|
):
|
99
124
|
# metadata values should go directly in parameters, as primitive types
|
100
125
|
step.config.parameters[name] = (
|
101
|
-
context_model_version.run_metadata[
|
102
|
-
config_.metadata_name
|
103
|
-
].value
|
126
|
+
context_model_version.run_metadata[config_.metadata_name]
|
104
127
|
)
|
105
128
|
elif config_.artifact_name is None:
|
106
129
|
err_msg = (
|
@@ -112,14 +135,15 @@ def resolve_step_inputs(
|
|
112
135
|
config_.artifact_name, config_.artifact_version
|
113
136
|
):
|
114
137
|
if config_.metadata_name is None:
|
115
|
-
input_artifacts[name] =
|
138
|
+
input_artifacts[name] = StepRunInputResponse(
|
139
|
+
input_type=StepRunInputArtifactType.LAZY_LOADED,
|
140
|
+
**artifact_.model_dump(),
|
141
|
+
)
|
116
142
|
elif config_.metadata_name:
|
117
143
|
# metadata values should go directly in parameters, as primitive types
|
118
144
|
try:
|
119
145
|
step.config.parameters[name] = (
|
120
|
-
artifact_.run_metadata[
|
121
|
-
config_.metadata_name
|
122
|
-
].value
|
146
|
+
artifact_.run_metadata[config_.metadata_name]
|
123
147
|
)
|
124
148
|
except KeyError:
|
125
149
|
err_msg = (
|
@@ -141,9 +165,10 @@ def resolve_step_inputs(
|
|
141
165
|
for name, cll_ in step.config.client_lazy_loaders.items():
|
142
166
|
value_ = cll_.evaluate()
|
143
167
|
if isinstance(value_, ArtifactVersionResponse):
|
144
|
-
input_artifacts[name] =
|
145
|
-
|
146
|
-
|
168
|
+
input_artifacts[name] = StepRunInputResponse(
|
169
|
+
input_type=StepRunInputArtifactType.LAZY_LOADED,
|
170
|
+
**value_.model_dump(),
|
171
|
+
)
|
147
172
|
else:
|
148
173
|
step.config.parameters[name] = value_
|
149
174
|
|