zenml-nightly 0.68.1.dev20241106__py3-none-any.whl → 0.68.1.dev20241108__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/artifacts/external_artifact.py +2 -1
- zenml/artifacts/utils.py +133 -78
- zenml/cli/base.py +4 -4
- zenml/cli/model.py +1 -6
- zenml/cli/stack.py +1 -0
- zenml/client.py +21 -73
- zenml/constants.py +1 -0
- zenml/enums.py +12 -4
- zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +1 -1
- zenml/integrations/azure/orchestrators/azureml_orchestrator.py +1 -1
- zenml/integrations/evidently/__init__.py +1 -1
- zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +1 -1
- zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +60 -54
- zenml/integrations/vllm/services/vllm_deployment.py +16 -7
- zenml/metadata/lazy_load.py +20 -7
- zenml/model/model.py +1 -2
- zenml/models/__init__.py +0 -12
- zenml/models/v2/core/artifact_version.py +19 -7
- zenml/models/v2/core/model_version.py +3 -5
- zenml/models/v2/core/pipeline_run.py +3 -5
- zenml/models/v2/core/run_metadata.py +2 -217
- zenml/models/v2/core/step_run.py +40 -24
- zenml/orchestrators/input_utils.py +44 -19
- zenml/orchestrators/step_launcher.py +2 -2
- zenml/orchestrators/step_run_utils.py +19 -15
- zenml/orchestrators/step_runner.py +21 -13
- zenml/steps/base_step.py +1 -1
- zenml/steps/entrypoint_function_utils.py +3 -5
- zenml/steps/step_context.py +3 -2
- zenml/steps/utils.py +8 -2
- zenml/zen_server/rbac/endpoint_utils.py +43 -1
- zenml/zen_server/rbac/utils.py +0 -2
- zenml/zen_server/routers/artifact_version_endpoints.py +27 -1
- zenml/zen_server/routers/workspaces_endpoints.py +3 -4
- zenml/zen_server/zen_server_api.py +0 -2
- zenml/zen_stores/migrations/versions/1cb6477f72d6_move_artifact_save_type.py +99 -0
- zenml/zen_stores/migrations/versions/b557b2871693_update_step_run_input_types.py +33 -0
- zenml/zen_stores/rest_zen_store.py +55 -54
- zenml/zen_stores/schemas/artifact_schemas.py +8 -1
- zenml/zen_stores/schemas/model_schemas.py +2 -2
- zenml/zen_stores/schemas/pipeline_run_schemas.py +1 -1
- zenml/zen_stores/schemas/run_metadata_schemas.py +1 -48
- zenml/zen_stores/schemas/step_run_schemas.py +18 -10
- zenml/zen_stores/sql_zen_store.py +68 -98
- zenml/zen_stores/zen_store_interface.py +15 -42
- {zenml_nightly-0.68.1.dev20241106.dist-info → zenml_nightly-0.68.1.dev20241108.dist-info}/METADATA +1 -1
- {zenml_nightly-0.68.1.dev20241106.dist-info → zenml_nightly-0.68.1.dev20241108.dist-info}/RECORD +51 -50
- zenml/zen_server/routers/run_metadata_endpoints.py +0 -96
- {zenml_nightly-0.68.1.dev20241106.dist-info → zenml_nightly-0.68.1.dev20241108.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.68.1.dev20241106.dist-info → zenml_nightly-0.68.1.dev20241108.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.68.1.dev20241106.dist-info → zenml_nightly-0.68.1.dev20241108.dist-info}/entry_points.txt +0 -0
zenml/models/v2/core/step_run.py
CHANGED
@@ -21,7 +21,8 @@ from pydantic import BaseModel, ConfigDict, Field
|
|
21
21
|
|
22
22
|
from zenml.config.step_configurations import StepConfiguration, StepSpec
|
23
23
|
from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
|
24
|
-
from zenml.enums import ExecutionStatus
|
24
|
+
from zenml.enums import ExecutionStatus, StepRunInputArtifactType
|
25
|
+
from zenml.metadata.metadata_types import MetadataType
|
25
26
|
from zenml.models.v2.base.scoped import (
|
26
27
|
WorkspaceScopedFilter,
|
27
28
|
WorkspaceScopedRequest,
|
@@ -30,19 +31,35 @@ from zenml.models.v2.base.scoped import (
|
|
30
31
|
WorkspaceScopedResponseMetadata,
|
31
32
|
WorkspaceScopedResponseResources,
|
32
33
|
)
|
34
|
+
from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
|
33
35
|
from zenml.models.v2.core.model_version import ModelVersionResponse
|
34
36
|
|
35
37
|
if TYPE_CHECKING:
|
36
38
|
from sqlalchemy.sql.elements import ColumnElement
|
37
39
|
|
38
|
-
from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
|
39
40
|
from zenml.models.v2.core.logs import (
|
40
41
|
LogsRequest,
|
41
42
|
LogsResponse,
|
42
43
|
)
|
43
|
-
|
44
|
-
|
45
|
-
|
44
|
+
|
45
|
+
|
46
|
+
class StepRunInputResponse(ArtifactVersionResponse):
|
47
|
+
"""Response model for step run inputs."""
|
48
|
+
|
49
|
+
input_type: StepRunInputArtifactType
|
50
|
+
|
51
|
+
def get_hydrated_version(self) -> "StepRunInputResponse":
|
52
|
+
"""Get the hydrated version of this step run input.
|
53
|
+
|
54
|
+
Returns:
|
55
|
+
an instance of the same entity with the metadata field attached.
|
56
|
+
"""
|
57
|
+
from zenml.client import Client
|
58
|
+
|
59
|
+
return StepRunInputResponse(
|
60
|
+
input_type=self.input_type,
|
61
|
+
**Client().zen_store.get_artifact_version(self.id).model_dump(),
|
62
|
+
)
|
46
63
|
|
47
64
|
|
48
65
|
# ------------------ Request Model ------------------
|
@@ -97,11 +114,11 @@ class StepRunRequest(WorkspaceScopedRequest):
|
|
97
114
|
)
|
98
115
|
inputs: Dict[str, UUID] = Field(
|
99
116
|
title="The IDs of the input artifact versions of the step run.",
|
100
|
-
|
117
|
+
default_factory=dict,
|
101
118
|
)
|
102
|
-
outputs: Dict[str, UUID] = Field(
|
119
|
+
outputs: Dict[str, List[UUID]] = Field(
|
103
120
|
title="The IDs of the output artifact versions of the step run.",
|
104
|
-
|
121
|
+
default_factory=dict,
|
105
122
|
)
|
106
123
|
logs: Optional["LogsRequest"] = Field(
|
107
124
|
title="Logs associated with this step run.",
|
@@ -129,10 +146,6 @@ class StepRunUpdate(BaseModel):
|
|
129
146
|
title="The IDs of the output artifact versions of the step run.",
|
130
147
|
default={},
|
131
148
|
)
|
132
|
-
saved_artifact_versions: Dict[str, UUID] = Field(
|
133
|
-
title="The IDs of artifact versions that were saved by this step run.",
|
134
|
-
default={},
|
135
|
-
)
|
136
149
|
loaded_artifact_versions: Dict[str, UUID] = Field(
|
137
150
|
title="The IDs of artifact versions that were loaded by this step run.",
|
138
151
|
default={},
|
@@ -166,13 +179,13 @@ class StepRunResponseBody(WorkspaceScopedResponseBody):
|
|
166
179
|
title="The end time of the step run.",
|
167
180
|
default=None,
|
168
181
|
)
|
169
|
-
inputs: Dict[str,
|
182
|
+
inputs: Dict[str, StepRunInputResponse] = Field(
|
170
183
|
title="The input artifact versions of the step run.",
|
171
|
-
|
184
|
+
default_factory=dict,
|
172
185
|
)
|
173
|
-
outputs: Dict[str,
|
186
|
+
outputs: Dict[str, List[ArtifactVersionResponse]] = Field(
|
174
187
|
title="The output artifact versions of the step run.",
|
175
|
-
|
188
|
+
default_factory=dict,
|
176
189
|
)
|
177
190
|
model_version_id: Optional[UUID] = Field(
|
178
191
|
title="The ID of the model version that was "
|
@@ -230,7 +243,7 @@ class StepRunResponseMetadata(WorkspaceScopedResponseMetadata):
|
|
230
243
|
title="The IDs of the parent steps of this step run.",
|
231
244
|
default_factory=list,
|
232
245
|
)
|
233
|
-
run_metadata: Dict[str,
|
246
|
+
run_metadata: Dict[str, MetadataType] = Field(
|
234
247
|
title="Metadata associated with this step run.",
|
235
248
|
default={},
|
236
249
|
)
|
@@ -274,7 +287,7 @@ class StepRunResponse(
|
|
274
287
|
|
275
288
|
# Helper properties
|
276
289
|
@property
|
277
|
-
def input(self) ->
|
290
|
+
def input(self) -> ArtifactVersionResponse:
|
278
291
|
"""Returns the input artifact that was used to run this step.
|
279
292
|
|
280
293
|
Returns:
|
@@ -293,7 +306,7 @@ class StepRunResponse(
|
|
293
306
|
return next(iter(self.inputs.values()))
|
294
307
|
|
295
308
|
@property
|
296
|
-
def output(self) ->
|
309
|
+
def output(self) -> ArtifactVersionResponse:
|
297
310
|
"""Returns the output artifact that was written by this step.
|
298
311
|
|
299
312
|
Returns:
|
@@ -304,12 +317,15 @@ class StepRunResponse(
|
|
304
317
|
"""
|
305
318
|
if not self.outputs:
|
306
319
|
raise ValueError(f"Step {self.name} has no outputs.")
|
307
|
-
if len(self.outputs) > 1
|
320
|
+
if len(self.outputs) > 1 or (
|
321
|
+
len(self.outputs) == 1
|
322
|
+
and len(next(iter(self.outputs.values()))) > 1
|
323
|
+
):
|
308
324
|
raise ValueError(
|
309
325
|
f"Step {self.name} has multiple outputs, so `Step.output` is "
|
310
326
|
"ambiguous. Please use `Step.outputs` instead."
|
311
327
|
)
|
312
|
-
return next(iter(self.outputs.values()))
|
328
|
+
return next(iter(self.outputs.values()))[0]
|
313
329
|
|
314
330
|
# Body and metadata properties
|
315
331
|
@property
|
@@ -322,7 +338,7 @@ class StepRunResponse(
|
|
322
338
|
return self.get_body().status
|
323
339
|
|
324
340
|
@property
|
325
|
-
def inputs(self) -> Dict[str,
|
341
|
+
def inputs(self) -> Dict[str, StepRunInputResponse]:
|
326
342
|
"""The `inputs` property.
|
327
343
|
|
328
344
|
Returns:
|
@@ -331,7 +347,7 @@ class StepRunResponse(
|
|
331
347
|
return self.get_body().inputs
|
332
348
|
|
333
349
|
@property
|
334
|
-
def outputs(self) -> Dict[str,
|
350
|
+
def outputs(self) -> Dict[str, List[ArtifactVersionResponse]]:
|
335
351
|
"""The `outputs` property.
|
336
352
|
|
337
353
|
Returns:
|
@@ -466,7 +482,7 @@ class StepRunResponse(
|
|
466
482
|
return self.get_metadata().parent_step_ids
|
467
483
|
|
468
484
|
@property
|
469
|
-
def run_metadata(self) -> Dict[str,
|
485
|
+
def run_metadata(self) -> Dict[str, MetadataType]:
|
470
486
|
"""The `run_metadata` property.
|
471
487
|
|
472
488
|
Returns:
|
@@ -18,17 +18,19 @@ from uuid import UUID
|
|
18
18
|
|
19
19
|
from zenml.client import Client
|
20
20
|
from zenml.config.step_configurations import Step
|
21
|
+
from zenml.enums import ArtifactSaveType, StepRunInputArtifactType
|
21
22
|
from zenml.exceptions import InputResolutionError
|
22
23
|
from zenml.utils import pagination_utils
|
23
24
|
|
24
25
|
if TYPE_CHECKING:
|
25
|
-
from zenml.models import
|
26
|
+
from zenml.models import PipelineRunResponse
|
27
|
+
from zenml.models.v2.core.step_run import StepRunInputResponse
|
26
28
|
|
27
29
|
|
28
30
|
def resolve_step_inputs(
|
29
31
|
step: "Step",
|
30
32
|
pipeline_run: "PipelineRunResponse",
|
31
|
-
) -> Tuple[Dict[str, "
|
33
|
+
) -> Tuple[Dict[str, "StepRunInputResponse"], List[UUID]]:
|
32
34
|
"""Resolves inputs for the current step.
|
33
35
|
|
34
36
|
Args:
|
@@ -45,7 +47,8 @@ def resolve_step_inputs(
|
|
45
47
|
The IDs of the input artifact versions and the IDs of parent steps of
|
46
48
|
the current step.
|
47
49
|
"""
|
48
|
-
from zenml.models import ArtifactVersionResponse
|
50
|
+
from zenml.models import ArtifactVersionResponse
|
51
|
+
from zenml.models.v2.core.step_run import StepRunInputResponse
|
49
52
|
|
50
53
|
current_run_steps = {
|
51
54
|
run_step.name: run_step
|
@@ -54,7 +57,7 @@ def resolve_step_inputs(
|
|
54
57
|
)
|
55
58
|
}
|
56
59
|
|
57
|
-
input_artifacts: Dict[str,
|
60
|
+
input_artifacts: Dict[str, StepRunInputResponse] = {}
|
58
61
|
for name, input_ in step.spec.inputs.items():
|
59
62
|
try:
|
60
63
|
step_run = current_run_steps[input_.step_name]
|
@@ -64,22 +67,44 @@ def resolve_step_inputs(
|
|
64
67
|
)
|
65
68
|
|
66
69
|
try:
|
67
|
-
|
70
|
+
outputs = step_run.outputs[input_.output_name]
|
68
71
|
except KeyError:
|
69
72
|
raise InputResolutionError(
|
70
|
-
f"No output `{input_.output_name}` found for step "
|
73
|
+
f"No step output `{input_.output_name}` found for step "
|
71
74
|
f"`{input_.step_name}`."
|
72
75
|
)
|
73
76
|
|
74
|
-
|
77
|
+
step_outputs = [
|
78
|
+
output
|
79
|
+
for output in outputs
|
80
|
+
if output.save_type == ArtifactSaveType.STEP_OUTPUT
|
81
|
+
]
|
82
|
+
if len(step_outputs) > 2:
|
83
|
+
# This should never happen, there can only be a single regular step
|
84
|
+
# output for a name
|
85
|
+
raise InputResolutionError(
|
86
|
+
f"Too many step outputs for output `{input_.output_name}` of "
|
87
|
+
f"step `{input_.step_name}`."
|
88
|
+
)
|
89
|
+
elif len(step_outputs) == 0:
|
90
|
+
raise InputResolutionError(
|
91
|
+
f"No step output `{input_.output_name}` found for step "
|
92
|
+
f"`{input_.step_name}`."
|
93
|
+
)
|
94
|
+
|
95
|
+
input_artifacts[name] = StepRunInputResponse(
|
96
|
+
input_type=StepRunInputArtifactType.STEP_OUTPUT,
|
97
|
+
**step_outputs[0].model_dump(),
|
98
|
+
)
|
75
99
|
|
76
100
|
for (
|
77
101
|
name,
|
78
102
|
external_artifact,
|
79
103
|
) in step.config.external_input_artifacts.items():
|
80
104
|
artifact_version_id = external_artifact.get_artifact_version_id()
|
81
|
-
input_artifacts[name] =
|
82
|
-
|
105
|
+
input_artifacts[name] = StepRunInputResponse(
|
106
|
+
input_type=StepRunInputArtifactType.EXTERNAL,
|
107
|
+
**Client().get_artifact_version(artifact_version_id).model_dump(),
|
83
108
|
)
|
84
109
|
|
85
110
|
for name, config_ in step.config.model_artifacts_or_metadata.items():
|
@@ -98,9 +123,7 @@ def resolve_step_inputs(
|
|
98
123
|
):
|
99
124
|
# metadata values should go directly in parameters, as primitive types
|
100
125
|
step.config.parameters[name] = (
|
101
|
-
context_model_version.run_metadata[
|
102
|
-
config_.metadata_name
|
103
|
-
].value
|
126
|
+
context_model_version.run_metadata[config_.metadata_name]
|
104
127
|
)
|
105
128
|
elif config_.artifact_name is None:
|
106
129
|
err_msg = (
|
@@ -112,14 +135,15 @@ def resolve_step_inputs(
|
|
112
135
|
config_.artifact_name, config_.artifact_version
|
113
136
|
):
|
114
137
|
if config_.metadata_name is None:
|
115
|
-
input_artifacts[name] =
|
138
|
+
input_artifacts[name] = StepRunInputResponse(
|
139
|
+
input_type=StepRunInputArtifactType.LAZY_LOADED,
|
140
|
+
**artifact_.model_dump(),
|
141
|
+
)
|
116
142
|
elif config_.metadata_name:
|
117
143
|
# metadata values should go directly in parameters, as primitive types
|
118
144
|
try:
|
119
145
|
step.config.parameters[name] = (
|
120
|
-
artifact_.run_metadata[
|
121
|
-
config_.metadata_name
|
122
|
-
].value
|
146
|
+
artifact_.run_metadata[config_.metadata_name]
|
123
147
|
)
|
124
148
|
except KeyError:
|
125
149
|
err_msg = (
|
@@ -141,9 +165,10 @@ def resolve_step_inputs(
|
|
141
165
|
for name, cll_ in step.config.client_lazy_loaders.items():
|
142
166
|
value_ = cll_.evaluate()
|
143
167
|
if isinstance(value_, ArtifactVersionResponse):
|
144
|
-
input_artifacts[name] =
|
145
|
-
|
146
|
-
|
168
|
+
input_artifacts[name] = StepRunInputResponse(
|
169
|
+
input_type=StepRunInputArtifactType.LAZY_LOADED,
|
170
|
+
**value_.model_dump(),
|
171
|
+
)
|
147
172
|
else:
|
148
173
|
step.config.parameters[name] = value_
|
149
174
|
|
@@ -33,13 +33,13 @@ from zenml.environment import get_run_environment_dict
|
|
33
33
|
from zenml.logger import get_logger
|
34
34
|
from zenml.logging import step_logging
|
35
35
|
from zenml.models import (
|
36
|
-
ArtifactVersionResponse,
|
37
36
|
LogsRequest,
|
38
37
|
PipelineDeploymentResponse,
|
39
38
|
PipelineRunRequest,
|
40
39
|
PipelineRunResponse,
|
41
40
|
StepRunResponse,
|
42
41
|
)
|
42
|
+
from zenml.models.v2.core.step_run import StepRunInputResponse
|
43
43
|
from zenml.orchestrators import output_utils, publish_utils, step_run_utils
|
44
44
|
from zenml.orchestrators import utils as orchestrator_utils
|
45
45
|
from zenml.orchestrators.step_runner import StepRunner
|
@@ -442,7 +442,7 @@ class StepLauncher:
|
|
442
442
|
pipeline_run: PipelineRunResponse,
|
443
443
|
step_run: StepRunResponse,
|
444
444
|
step_run_info: StepRunInfo,
|
445
|
-
input_artifacts: Dict[str,
|
445
|
+
input_artifacts: Dict[str, StepRunInputResponse],
|
446
446
|
output_artifact_uris: Dict[str, str],
|
447
447
|
last_retry: bool,
|
448
448
|
) -> None:
|
@@ -14,12 +14,12 @@
|
|
14
14
|
"""Utilities for creating step runs."""
|
15
15
|
|
16
16
|
from datetime import datetime
|
17
|
-
from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple
|
17
|
+
from typing import TYPE_CHECKING, Dict, List, Mapping, Optional, Set, Tuple
|
18
18
|
|
19
19
|
from zenml.client import Client
|
20
20
|
from zenml.config.step_configurations import ArtifactConfiguration, Step
|
21
21
|
from zenml.constants import CODE_HASH_PARAMETER_NAME, TEXT_FIELD_MAX_LENGTH
|
22
|
-
from zenml.enums import ExecutionStatus
|
22
|
+
from zenml.enums import ArtifactSaveType, ExecutionStatus
|
23
23
|
from zenml.logger import get_logger
|
24
24
|
from zenml.model.utils import link_artifact_version_to_model_version
|
25
25
|
from zenml.models import (
|
@@ -104,6 +104,7 @@ class StepRunRequestFactory:
|
|
104
104
|
input_name: artifact.id
|
105
105
|
for input_name, artifact in input_artifacts.items()
|
106
106
|
}
|
107
|
+
|
107
108
|
request.inputs = input_artifact_ids
|
108
109
|
request.parent_step_ids = parent_step_ids
|
109
110
|
|
@@ -142,8 +143,8 @@ class StepRunRequestFactory:
|
|
142
143
|
|
143
144
|
request.original_step_run_id = cached_step_run.id
|
144
145
|
request.outputs = {
|
145
|
-
output_name: artifact.id
|
146
|
-
for output_name,
|
146
|
+
output_name: [artifact.id for artifact in artifacts]
|
147
|
+
for output_name, artifacts in cached_step_run.outputs.items()
|
147
148
|
}
|
148
149
|
|
149
150
|
request.status = ExecutionStatus.CACHED
|
@@ -551,7 +552,7 @@ def link_pipeline_run_to_model_version(
|
|
551
552
|
|
552
553
|
|
553
554
|
def link_output_artifacts_to_model_version(
|
554
|
-
artifacts: Dict[str, ArtifactVersionResponse],
|
555
|
+
artifacts: Dict[str, List[ArtifactVersionResponse]],
|
555
556
|
output_configurations: Mapping[str, ArtifactConfiguration],
|
556
557
|
model_version: ModelVersionResponse,
|
557
558
|
) -> None:
|
@@ -562,13 +563,16 @@ def link_output_artifacts_to_model_version(
|
|
562
563
|
output_configurations: The output configurations for the step.
|
563
564
|
model_version: The model version to link.
|
564
565
|
"""
|
565
|
-
for output_name,
|
566
|
-
|
567
|
-
|
568
|
-
|
569
|
-
|
570
|
-
|
571
|
-
|
572
|
-
|
573
|
-
|
574
|
-
|
566
|
+
for output_name, output_artifacts in artifacts.items():
|
567
|
+
for output_artifact in output_artifacts:
|
568
|
+
artifact_config = None
|
569
|
+
if output_artifact.save_type == ArtifactSaveType.STEP_OUTPUT and (
|
570
|
+
output_config := output_configurations.get(output_name, None)
|
571
|
+
):
|
572
|
+
artifact_config = output_config.artifact_config
|
573
|
+
|
574
|
+
link_artifact_version_to_model_version(
|
575
|
+
artifact_version=output_artifact,
|
576
|
+
model_version=model_version,
|
577
|
+
artifact_config=artifact_config,
|
578
|
+
)
|
@@ -28,7 +28,8 @@ from typing import (
|
|
28
28
|
)
|
29
29
|
|
30
30
|
from zenml.artifacts.unmaterialized_artifact import UnmaterializedArtifact
|
31
|
-
from zenml.artifacts.utils import
|
31
|
+
from zenml.artifacts.utils import _store_artifact_data_and_prepare_request
|
32
|
+
from zenml.client import Client
|
32
33
|
from zenml.config.step_configurations import StepConfiguration
|
33
34
|
from zenml.config.step_run_info import StepRunInfo
|
34
35
|
from zenml.constants import (
|
@@ -36,10 +37,12 @@ from zenml.constants import (
|
|
36
37
|
ENV_ZENML_IGNORE_FAILURE_HOOK,
|
37
38
|
handle_bool_env_var,
|
38
39
|
)
|
40
|
+
from zenml.enums import ArtifactSaveType
|
39
41
|
from zenml.exceptions import StepInterfaceError
|
40
42
|
from zenml.logger import get_logger
|
41
43
|
from zenml.logging.step_logging import StepLogsStorageContext, redirected
|
42
44
|
from zenml.materializers.base_materializer import BaseMaterializer
|
45
|
+
from zenml.models.v2.core.step_run import StepRunInputResponse
|
43
46
|
from zenml.orchestrators.publish_utils import (
|
44
47
|
publish_step_run_metadata,
|
45
48
|
publish_successful_step_run,
|
@@ -98,7 +101,7 @@ class StepRunner:
|
|
98
101
|
self,
|
99
102
|
pipeline_run: "PipelineRunResponse",
|
100
103
|
step_run: "StepRunResponse",
|
101
|
-
input_artifacts: Dict[str,
|
104
|
+
input_artifacts: Dict[str, StepRunInputResponse],
|
102
105
|
output_artifact_uris: Dict[str, str],
|
103
106
|
step_run_info: StepRunInfo,
|
104
107
|
) -> None:
|
@@ -241,7 +244,9 @@ class StepRunner:
|
|
241
244
|
from zenml.orchestrators import step_run_utils
|
242
245
|
|
243
246
|
step_run_utils.link_output_artifacts_to_model_version(
|
244
|
-
artifacts=
|
247
|
+
artifacts={
|
248
|
+
k: [v] for k, v in output_artifacts.items()
|
249
|
+
},
|
245
250
|
output_configurations=step_run.config.outputs,
|
246
251
|
model_version=model_version,
|
247
252
|
)
|
@@ -302,7 +307,7 @@ class StepRunner:
|
|
302
307
|
self,
|
303
308
|
args: List[str],
|
304
309
|
annotations: Dict[str, Any],
|
305
|
-
input_artifacts: Dict[str,
|
310
|
+
input_artifacts: Dict[str, StepRunInputResponse],
|
306
311
|
) -> Dict[str, Any]:
|
307
312
|
"""Parses the inputs for a step entrypoint function.
|
308
313
|
|
@@ -534,7 +539,7 @@ class StepRunner:
|
|
534
539
|
The IDs of the published output artifacts.
|
535
540
|
"""
|
536
541
|
step_context = get_step_context()
|
537
|
-
|
542
|
+
artifact_requests = []
|
538
543
|
|
539
544
|
for output_name, return_value in output_data.items():
|
540
545
|
data_type = type(return_value)
|
@@ -595,22 +600,25 @@ class StepRunner:
|
|
595
600
|
# Get full set of tags
|
596
601
|
tags = step_context.get_output_tags(output_name)
|
597
602
|
|
598
|
-
|
603
|
+
artifact_request = _store_artifact_data_and_prepare_request(
|
599
604
|
name=artifact_name,
|
600
605
|
data=return_value,
|
601
|
-
|
606
|
+
materializer_class=materializer_class,
|
602
607
|
uri=uri,
|
603
|
-
|
604
|
-
|
608
|
+
store_metadata=artifact_metadata_enabled,
|
609
|
+
store_visualizations=artifact_visualization_enabled,
|
605
610
|
has_custom_name=has_custom_name,
|
606
611
|
version=version,
|
607
612
|
tags=tags,
|
608
|
-
|
609
|
-
|
613
|
+
save_type=ArtifactSaveType.STEP_OUTPUT,
|
614
|
+
metadata=user_metadata,
|
610
615
|
)
|
611
|
-
|
616
|
+
artifact_requests.append(artifact_request)
|
612
617
|
|
613
|
-
|
618
|
+
responses = Client().zen_store.batch_create_artifact_versions(
|
619
|
+
artifact_requests
|
620
|
+
)
|
621
|
+
return dict(zip(output_data.keys(), responses))
|
614
622
|
|
615
623
|
def load_and_run_hook(
|
616
624
|
self,
|
zenml/steps/base_step.py
CHANGED
@@ -327,12 +327,12 @@ class BaseStep:
|
|
327
327
|
The artifacts, external artifacts, model version artifacts/metadata and parameters for the step.
|
328
328
|
"""
|
329
329
|
from zenml.artifacts.external_artifact import ExternalArtifact
|
330
|
+
from zenml.metadata.lazy_load import LazyRunMetadataResponse
|
330
331
|
from zenml.model.lazy_load import ModelVersionDataLazyLoader
|
331
332
|
from zenml.models.v2.core.artifact_version import (
|
332
333
|
ArtifactVersionResponse,
|
333
334
|
LazyArtifactVersionResponse,
|
334
335
|
)
|
335
|
-
from zenml.models.v2.core.run_metadata import LazyRunMetadataResponse
|
336
336
|
|
337
337
|
signature = inspect.signature(self.entrypoint, follow_wrapped=True)
|
338
338
|
|
@@ -32,6 +32,7 @@ from zenml.constants import ENFORCE_TYPE_ANNOTATIONS
|
|
32
32
|
from zenml.exceptions import StepInterfaceError
|
33
33
|
from zenml.logger import get_logger
|
34
34
|
from zenml.materializers.base_materializer import BaseMaterializer
|
35
|
+
from zenml.metadata.lazy_load import LazyRunMetadataResponse
|
35
36
|
from zenml.steps.utils import (
|
36
37
|
OutputSignature,
|
37
38
|
parse_return_type_annotations,
|
@@ -136,10 +137,7 @@ class EntrypointFunctionDefinition(NamedTuple):
|
|
136
137
|
UnmaterializedArtifact,
|
137
138
|
)
|
138
139
|
from zenml.client_lazy_loader import ClientLazyLoader
|
139
|
-
from zenml.models import
|
140
|
-
ArtifactVersionResponse,
|
141
|
-
RunMetadataResponse,
|
142
|
-
)
|
140
|
+
from zenml.models import ArtifactVersionResponse
|
143
141
|
|
144
142
|
if key not in self.inputs:
|
145
143
|
raise KeyError(
|
@@ -154,8 +152,8 @@ class EntrypointFunctionDefinition(NamedTuple):
|
|
154
152
|
StepArtifact,
|
155
153
|
ExternalArtifact,
|
156
154
|
ArtifactVersionResponse,
|
157
|
-
RunMetadataResponse,
|
158
155
|
ClientLazyLoader,
|
156
|
+
LazyRunMetadataResponse,
|
159
157
|
),
|
160
158
|
):
|
161
159
|
# If we were to do any type validation for artifacts here, we
|
zenml/steps/step_context.py
CHANGED
@@ -35,11 +35,12 @@ if TYPE_CHECKING:
|
|
35
35
|
from zenml.metadata.metadata_types import MetadataType
|
36
36
|
from zenml.model.model import Model
|
37
37
|
from zenml.models import (
|
38
|
-
ArtifactVersionResponse,
|
39
38
|
PipelineResponse,
|
40
39
|
PipelineRunResponse,
|
41
40
|
StepRunResponse,
|
42
41
|
)
|
42
|
+
from zenml.models.v2.core.step_run import StepRunInputResponse
|
43
|
+
|
43
44
|
|
44
45
|
logger = get_logger(__name__)
|
45
46
|
|
@@ -191,7 +192,7 @@ class StepContext(metaclass=SingletonMetaClass):
|
|
191
192
|
return self.model_version.to_model_class()
|
192
193
|
|
193
194
|
@property
|
194
|
-
def inputs(self) -> Dict[str, "
|
195
|
+
def inputs(self) -> Dict[str, "StepRunInputResponse"]:
|
195
196
|
"""Returns the input artifacts of the current step.
|
196
197
|
|
197
198
|
Returns:
|
zenml/steps/utils.py
CHANGED
@@ -26,7 +26,11 @@ from typing_extensions import Annotated
|
|
26
26
|
|
27
27
|
from zenml.artifacts.artifact_config import ArtifactConfig
|
28
28
|
from zenml.client import Client
|
29
|
-
from zenml.enums import
|
29
|
+
from zenml.enums import (
|
30
|
+
ArtifactSaveType,
|
31
|
+
ExecutionStatus,
|
32
|
+
MetadataResourceTypes,
|
33
|
+
)
|
30
34
|
from zenml.exceptions import StepInterfaceError
|
31
35
|
from zenml.logger import get_logger
|
32
36
|
from zenml.metadata.metadata_types import MetadataType
|
@@ -547,8 +551,10 @@ def run_as_single_step_pipeline(
|
|
547
551
|
# 4. Load output artifacts
|
548
552
|
step_run = next(iter(run.steps.values()))
|
549
553
|
outputs = [
|
550
|
-
|
554
|
+
artifact_version.load()
|
551
555
|
for output_name in step_run.config.outputs.keys()
|
556
|
+
for artifact_version in step_run.outputs[output_name]
|
557
|
+
if artifact_version.save_type == ArtifactSaveType.STEP_OUTPUT
|
552
558
|
]
|
553
559
|
|
554
560
|
if len(outputs) == 0:
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""High-level helper functions to write endpoints with RBAC."""
|
15
15
|
|
16
|
-
from typing import Any, Callable, TypeVar, Union
|
16
|
+
from typing import Any, Callable, List, TypeVar, Union
|
17
17
|
from uuid import UUID
|
18
18
|
|
19
19
|
from pydantic import BaseModel
|
@@ -96,6 +96,48 @@ def verify_permissions_and_create_entity(
|
|
96
96
|
return created
|
97
97
|
|
98
98
|
|
99
|
+
def verify_permissions_and_batch_create_entity(
|
100
|
+
batch: List[AnyRequest],
|
101
|
+
resource_type: ResourceType,
|
102
|
+
create_method: Callable[[List[AnyRequest]], List[AnyResponse]],
|
103
|
+
) -> List[AnyResponse]:
|
104
|
+
"""Verify permissions and create a batch of entities if authorized.
|
105
|
+
|
106
|
+
Args:
|
107
|
+
batch: The batch to create.
|
108
|
+
resource_type: The resource type of the entities to create.
|
109
|
+
create_method: The method to create the entities.
|
110
|
+
|
111
|
+
Raises:
|
112
|
+
IllegalOperationError: If the request model has a different owner then
|
113
|
+
the currently authenticated user.
|
114
|
+
RuntimeError: If the resource type is usage-tracked.
|
115
|
+
|
116
|
+
Returns:
|
117
|
+
The created entities.
|
118
|
+
"""
|
119
|
+
auth_context = get_auth_context()
|
120
|
+
assert auth_context
|
121
|
+
|
122
|
+
for request_model in batch:
|
123
|
+
if isinstance(request_model, UserScopedRequest):
|
124
|
+
if request_model.user != auth_context.user.id:
|
125
|
+
raise IllegalOperationError(
|
126
|
+
f"Not allowed to create resource '{resource_type}' for a "
|
127
|
+
"different user."
|
128
|
+
)
|
129
|
+
|
130
|
+
verify_permission(resource_type=resource_type, action=Action.CREATE)
|
131
|
+
|
132
|
+
if resource_type in REPORTABLE_RESOURCES:
|
133
|
+
raise RuntimeError(
|
134
|
+
"Batch requests are currently not possible with usage-tracked features."
|
135
|
+
)
|
136
|
+
|
137
|
+
created = create_method(batch)
|
138
|
+
return created
|
139
|
+
|
140
|
+
|
99
141
|
def verify_permissions_and_get_entity(
|
100
142
|
id: UUIDOrStr,
|
101
143
|
get_method: Callable[[UUIDOrStr], AnyResponse],
|
zenml/zen_server/rbac/utils.py
CHANGED
@@ -404,7 +404,6 @@ def get_resource_type_for_model(
|
|
404
404
|
PipelineDeploymentResponse,
|
405
405
|
PipelineResponse,
|
406
406
|
PipelineRunResponse,
|
407
|
-
RunMetadataResponse,
|
408
407
|
RunTemplateResponse,
|
409
408
|
SecretResponse,
|
410
409
|
ServiceAccountResponse,
|
@@ -437,7 +436,6 @@ def get_resource_type_for_model(
|
|
437
436
|
ArtifactVersionResponse: ResourceType.ARTIFACT_VERSION,
|
438
437
|
WorkspaceResponse: ResourceType.WORKSPACE,
|
439
438
|
UserResponse: ResourceType.USER,
|
440
|
-
RunMetadataResponse: ResourceType.RUN_METADATA,
|
441
439
|
PipelineDeploymentResponse: ResourceType.PIPELINE_DEPLOYMENT,
|
442
440
|
PipelineBuildResponse: ResourceType.PIPELINE_BUILD,
|
443
441
|
PipelineRunResponse: ResourceType.PIPELINE_RUN,
|