zenml-nightly 0.68.1.dev20241107__py3-none-any.whl → 0.68.1.dev20241109__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. zenml/VERSION +1 -1
  2. zenml/artifacts/external_artifact.py +2 -1
  3. zenml/artifacts/utils.py +13 -20
  4. zenml/cli/base.py +4 -4
  5. zenml/cli/model.py +1 -6
  6. zenml/cli/stack.py +1 -0
  7. zenml/client.py +21 -73
  8. zenml/data_validators/base_data_validator.py +2 -2
  9. zenml/enums.py +12 -4
  10. zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +1 -1
  11. zenml/integrations/azure/orchestrators/azureml_orchestrator.py +1 -1
  12. zenml/integrations/deepchecks/data_validators/deepchecks_data_validator.py +1 -1
  13. zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +1 -1
  14. zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +60 -54
  15. zenml/metadata/lazy_load.py +20 -7
  16. zenml/model/model.py +1 -2
  17. zenml/models/__init__.py +0 -12
  18. zenml/models/v2/core/artifact_version.py +19 -7
  19. zenml/models/v2/core/model_version.py +3 -5
  20. zenml/models/v2/core/pipeline_run.py +3 -5
  21. zenml/models/v2/core/run_metadata.py +2 -217
  22. zenml/models/v2/core/step_run.py +40 -24
  23. zenml/orchestrators/input_utils.py +44 -19
  24. zenml/orchestrators/step_launcher.py +2 -2
  25. zenml/orchestrators/step_run_utils.py +19 -15
  26. zenml/orchestrators/step_runner.py +8 -3
  27. zenml/steps/base_step.py +1 -1
  28. zenml/steps/entrypoint_function_utils.py +3 -5
  29. zenml/steps/step_context.py +3 -2
  30. zenml/steps/utils.py +8 -2
  31. zenml/zen_server/rbac/utils.py +0 -2
  32. zenml/zen_server/routers/workspaces_endpoints.py +3 -4
  33. zenml/zen_server/zen_server_api.py +0 -2
  34. zenml/zen_stores/migrations/versions/1cb6477f72d6_move_artifact_save_type.py +99 -0
  35. zenml/zen_stores/migrations/versions/b557b2871693_update_step_run_input_types.py +33 -0
  36. zenml/zen_stores/rest_zen_store.py +3 -54
  37. zenml/zen_stores/schemas/artifact_schemas.py +8 -1
  38. zenml/zen_stores/schemas/model_schemas.py +2 -2
  39. zenml/zen_stores/schemas/pipeline_run_schemas.py +1 -1
  40. zenml/zen_stores/schemas/run_metadata_schemas.py +1 -48
  41. zenml/zen_stores/schemas/step_run_schemas.py +18 -10
  42. zenml/zen_stores/sql_zen_store.py +52 -98
  43. zenml/zen_stores/zen_store_interface.py +2 -42
  44. {zenml_nightly-0.68.1.dev20241107.dist-info → zenml_nightly-0.68.1.dev20241109.dist-info}/METADATA +1 -1
  45. {zenml_nightly-0.68.1.dev20241107.dist-info → zenml_nightly-0.68.1.dev20241109.dist-info}/RECORD +48 -47
  46. zenml/zen_server/routers/run_metadata_endpoints.py +0 -96
  47. {zenml_nightly-0.68.1.dev20241107.dist-info → zenml_nightly-0.68.1.dev20241109.dist-info}/LICENSE +0 -0
  48. {zenml_nightly-0.68.1.dev20241107.dist-info → zenml_nightly-0.68.1.dev20241109.dist-info}/WHEEL +0 -0
  49. {zenml_nightly-0.68.1.dev20241107.dist-info → zenml_nightly-0.68.1.dev20241109.dist-info}/entry_points.txt +0 -0
zenml/models/__init__.py CHANGED
@@ -239,12 +239,7 @@ from zenml.models.v2.core.run_template import (
239
239
  )
240
240
  from zenml.models.v2.base.base_plugin_flavor import BasePluginFlavorResponse
241
241
  from zenml.models.v2.core.run_metadata import (
242
- LazyRunMetadataResponse,
243
242
  RunMetadataRequest,
244
- RunMetadataFilter,
245
- RunMetadataResponse,
246
- RunMetadataResponseBody,
247
- RunMetadataResponseMetadata,
248
243
  )
249
244
  from zenml.models.v2.core.schedule import (
250
245
  ScheduleRequest,
@@ -418,7 +413,6 @@ EventSourceResponseResources.model_rebuild()
418
413
  FlavorResponseBody.model_rebuild()
419
414
  FlavorResponseMetadata.model_rebuild()
420
415
  LazyArtifactVersionResponse.model_rebuild()
421
- LazyRunMetadataResponse.model_rebuild()
422
416
  ModelResponseBody.model_rebuild()
423
417
  ModelResponseMetadata.model_rebuild()
424
418
  ModelVersionResponseBody.model_rebuild()
@@ -444,8 +438,6 @@ RunTemplateResponseBody.model_rebuild()
444
438
  RunTemplateResponseMetadata.model_rebuild()
445
439
  RunTemplateResponseResources.model_rebuild()
446
440
  RunTemplateResponseBody.model_rebuild()
447
- RunMetadataResponseBody.model_rebuild()
448
- RunMetadataResponseMetadata.model_rebuild()
449
441
  ScheduleResponseBody.model_rebuild()
450
442
  ScheduleResponseMetadata.model_rebuild()
451
443
  SecretResponseBody.model_rebuild()
@@ -637,10 +629,6 @@ __all__ = [
637
629
  "RunTemplateResponseResources",
638
630
  "RunTemplateFilter",
639
631
  "RunMetadataRequest",
640
- "RunMetadataFilter",
641
- "RunMetadataResponse",
642
- "RunMetadataResponseBody",
643
- "RunMetadataResponseMetadata",
644
632
  "ScheduleRequest",
645
633
  "ScheduleUpdate",
646
634
  "ScheduleFilter",
@@ -34,7 +34,7 @@ from pydantic import (
34
34
 
35
35
  from zenml.config.source import Source, SourceWithValidator
36
36
  from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
37
- from zenml.enums import ArtifactType, GenericFilterOps
37
+ from zenml.enums import ArtifactSaveType, ArtifactType, GenericFilterOps
38
38
  from zenml.logger import get_logger
39
39
  from zenml.metadata.metadata_types import MetadataType
40
40
  from zenml.models.v2.base.filter import StrFilter
@@ -57,9 +57,6 @@ if TYPE_CHECKING:
57
57
  ArtifactVisualizationResponse,
58
58
  )
59
59
  from zenml.models.v2.core.pipeline_run import PipelineRunResponse
60
- from zenml.models.v2.core.run_metadata import (
61
- RunMetadataResponse,
62
- )
63
60
  from zenml.models.v2.core.step_run import StepRunResponse
64
61
 
65
62
  logger = get_logger(__name__)
@@ -107,6 +104,9 @@ class ArtifactVersionRequest(WorkspaceScopedRequest):
107
104
  visualizations: Optional[List["ArtifactVisualizationRequest"]] = Field(
108
105
  default=None, title="Visualizations of the artifact."
109
106
  )
107
+ save_type: ArtifactSaveType = Field(
108
+ title="The save type of the artifact version.",
109
+ )
110
110
  metadata: Optional[Dict[str, MetadataType]] = Field(
111
111
  default=None, title="Metadata of the artifact version."
112
112
  )
@@ -193,6 +193,9 @@ class ArtifactVersionResponseBody(WorkspaceScopedResponseBody):
193
193
  title="The ID of the pipeline run that generated this artifact version.",
194
194
  default=None,
195
195
  )
196
+ save_type: ArtifactSaveType = Field(
197
+ title="The save type of the artifact version.",
198
+ )
196
199
  artifact_store_id: Optional[UUID] = Field(
197
200
  title="ID of the artifact store in which this artifact is stored.",
198
201
  default=None,
@@ -230,7 +233,7 @@ class ArtifactVersionResponseMetadata(WorkspaceScopedResponseMetadata):
230
233
  visualizations: Optional[List["ArtifactVisualizationResponse"]] = Field(
231
234
  default=None, title="Visualizations of the artifact."
232
235
  )
233
- run_metadata: Dict[str, "RunMetadataResponse"] = Field(
236
+ run_metadata: Dict[str, MetadataType] = Field(
234
237
  default={}, title="Metadata of the artifact."
235
238
  )
236
239
 
@@ -313,6 +316,15 @@ class ArtifactVersionResponse(
313
316
  """
314
317
  return self.get_body().producer_pipeline_run_id
315
318
 
319
+ @property
320
+ def save_type(self) -> ArtifactSaveType:
321
+ """The `save_type` property.
322
+
323
+ Returns:
324
+ the value of the property.
325
+ """
326
+ return self.get_body().save_type
327
+
316
328
  @property
317
329
  def artifact_store_id(self) -> Optional[UUID]:
318
330
  """The `artifact_store_id` property.
@@ -343,7 +355,7 @@ class ArtifactVersionResponse(
343
355
  return self.get_metadata().visualizations
344
356
 
345
357
  @property
346
- def run_metadata(self) -> Dict[str, "RunMetadataResponse"]:
358
+ def run_metadata(self) -> Dict[str, MetadataType]:
347
359
  """The `metadata` property.
348
360
 
349
361
  Returns:
@@ -671,7 +683,7 @@ class LazyArtifactVersionResponse(ArtifactVersionResponse):
671
683
  )
672
684
 
673
685
  @property
674
- def run_metadata(self) -> Dict[str, "RunMetadataResponse"]:
686
+ def run_metadata(self) -> Dict[str, MetadataType]:
675
687
  """The `metadata` property in lazy loading mode.
676
688
 
677
689
  Returns:
@@ -29,6 +29,7 @@ from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, field_validator
29
29
 
30
30
  from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
31
31
  from zenml.enums import ModelStages
32
+ from zenml.metadata.metadata_types import MetadataType
32
33
  from zenml.models.v2.base.filter import AnyQuery
33
34
  from zenml.models.v2.base.page import Page
34
35
  from zenml.models.v2.base.scoped import (
@@ -49,9 +50,6 @@ if TYPE_CHECKING:
49
50
  from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
50
51
  from zenml.models.v2.core.model import ModelResponse
51
52
  from zenml.models.v2.core.pipeline_run import PipelineRunResponse
52
- from zenml.models.v2.core.run_metadata import (
53
- RunMetadataResponse,
54
- )
55
53
  from zenml.zen_stores.schemas import BaseSchema
56
54
 
57
55
  AnySchema = TypeVar("AnySchema", bound=BaseSchema)
@@ -193,7 +191,7 @@ class ModelVersionResponseMetadata(WorkspaceScopedResponseMetadata):
193
191
  max_length=TEXT_FIELD_MAX_LENGTH,
194
192
  default=None,
195
193
  )
196
- run_metadata: Dict[str, "RunMetadataResponse"] = Field(
194
+ run_metadata: Dict[str, MetadataType] = Field(
197
195
  description="Metadata linked to the model version",
198
196
  default={},
199
197
  )
@@ -304,7 +302,7 @@ class ModelVersionResponse(
304
302
  return self.get_metadata().description
305
303
 
306
304
  @property
307
- def run_metadata(self) -> Optional[Dict[str, "RunMetadataResponse"]]:
305
+ def run_metadata(self) -> Dict[str, MetadataType]:
308
306
  """The `run_metadata` property.
309
307
 
310
308
  Returns:
@@ -30,6 +30,7 @@ from pydantic import BaseModel, ConfigDict, Field
30
30
  from zenml.config.pipeline_configurations import PipelineConfiguration
31
31
  from zenml.constants import STR_FIELD_MAX_LENGTH
32
32
  from zenml.enums import ExecutionStatus
33
+ from zenml.metadata.metadata_types import MetadataType
33
34
  from zenml.models.v2.base.scoped import (
34
35
  WorkspaceScopedRequest,
35
36
  WorkspaceScopedResponse,
@@ -51,9 +52,6 @@ if TYPE_CHECKING:
51
52
  from zenml.models.v2.core.pipeline_build import (
52
53
  PipelineBuildResponse,
53
54
  )
54
- from zenml.models.v2.core.run_metadata import (
55
- RunMetadataResponse,
56
- )
57
55
  from zenml.models.v2.core.schedule import ScheduleResponse
58
56
  from zenml.models.v2.core.stack import StackResponse
59
57
  from zenml.models.v2.core.step_run import StepRunResponse
@@ -190,7 +188,7 @@ class PipelineRunResponseBody(WorkspaceScopedResponseBody):
190
188
  class PipelineRunResponseMetadata(WorkspaceScopedResponseMetadata):
191
189
  """Response metadata for pipeline runs."""
192
190
 
193
- run_metadata: Dict[str, "RunMetadataResponse"] = Field(
191
+ run_metadata: Dict[str, MetadataType] = Field(
194
192
  default={},
195
193
  title="Metadata associated with this pipeline run.",
196
194
  )
@@ -450,7 +448,7 @@ class PipelineRunResponse(
450
448
  return self.get_body().model_version_id
451
449
 
452
450
  @property
453
- def run_metadata(self) -> Dict[str, "RunMetadataResponse"]:
451
+ def run_metadata(self) -> Dict[str, MetadataType]:
454
452
  """The `run_metadata` property.
455
453
 
456
454
  Returns:
@@ -13,21 +13,15 @@
13
13
  # permissions and limitations under the License.
14
14
  """Models representing run metadata."""
15
15
 
16
- from typing import Any, Dict, Optional, Union
16
+ from typing import Dict, Optional
17
17
  from uuid import UUID
18
18
 
19
- from pydantic import Field, field_validator
19
+ from pydantic import Field
20
20
 
21
- from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
22
21
  from zenml.enums import MetadataResourceTypes
23
22
  from zenml.metadata.metadata_types import MetadataType, MetadataTypeEnum
24
23
  from zenml.models.v2.base.scoped import (
25
- WorkspaceScopedFilter,
26
24
  WorkspaceScopedRequest,
27
- WorkspaceScopedResponse,
28
- WorkspaceScopedResponseBody,
29
- WorkspaceScopedResponseMetadata,
30
- WorkspaceScopedResponseResources,
31
25
  )
32
26
 
33
27
  # ------------------ Request Model ------------------
@@ -51,212 +45,3 @@ class RunMetadataRequest(WorkspaceScopedRequest):
51
45
  types: Dict[str, "MetadataTypeEnum"] = Field(
52
46
  title="The types of the metadata to be created.",
53
47
  )
54
-
55
-
56
- # ------------------ Update Model ------------------
57
-
58
- # There is no update model for run metadata.
59
-
60
- # ------------------ Response Model ------------------
61
-
62
-
63
- class RunMetadataResponseBody(WorkspaceScopedResponseBody):
64
- """Response body for run metadata."""
65
-
66
- key: str = Field(title="The key of the metadata.")
67
- value: MetadataType = Field(
68
- title="The value of the metadata.", union_mode="smart"
69
- )
70
- type: MetadataTypeEnum = Field(title="The type of the metadata.")
71
-
72
- @field_validator("key", "type")
73
- @classmethod
74
- def str_field_max_length_check(cls, value: Any) -> Any:
75
- """Checks if the length of the value exceeds the maximum str length.
76
-
77
- Args:
78
- value: the value set in the field
79
-
80
- Returns:
81
- the value itself.
82
-
83
- Raises:
84
- AssertionError: if the length of the field is longer than the
85
- maximum threshold.
86
- """
87
- assert len(str(value)) < STR_FIELD_MAX_LENGTH, (
88
- "The length of the value for this field can not "
89
- f"exceed {STR_FIELD_MAX_LENGTH}"
90
- )
91
- return value
92
-
93
- @field_validator("value")
94
- @classmethod
95
- def text_field_max_length_check(cls, value: Any) -> Any:
96
- """Checks if the length of the value exceeds the maximum text length.
97
-
98
- Args:
99
- value: the value set in the field
100
-
101
- Returns:
102
- the value itself.
103
-
104
- Raises:
105
- AssertionError: if the length of the field is longer than the
106
- maximum threshold.
107
- """
108
- assert len(str(value)) < TEXT_FIELD_MAX_LENGTH, (
109
- "The length of the value for this field can not "
110
- f"exceed {TEXT_FIELD_MAX_LENGTH}"
111
- )
112
- return value
113
-
114
-
115
- class RunMetadataResponseMetadata(WorkspaceScopedResponseMetadata):
116
- """Response metadata for run metadata."""
117
-
118
- resource_id: UUID = Field(
119
- title="The ID of the resource that this metadata belongs to.",
120
- )
121
- resource_type: MetadataResourceTypes = Field(
122
- title="The type of the resource that this metadata belongs to.",
123
- )
124
- stack_component_id: Optional[UUID] = Field(
125
- title="The ID of the stack component that this metadata belongs to."
126
- )
127
-
128
-
129
- class RunMetadataResponseResources(WorkspaceScopedResponseResources):
130
- """Class for all resource models associated with the run metadata entity."""
131
-
132
-
133
- class RunMetadataResponse(
134
- WorkspaceScopedResponse[
135
- RunMetadataResponseBody,
136
- RunMetadataResponseMetadata,
137
- RunMetadataResponseResources,
138
- ]
139
- ):
140
- """Response model for run metadata."""
141
-
142
- def get_hydrated_version(self) -> "RunMetadataResponse":
143
- """Get the hydrated version of this run metadata.
144
-
145
- Returns:
146
- an instance of the same entity with the metadata field attached.
147
- """
148
- from zenml.client import Client
149
-
150
- return Client().zen_store.get_run_metadata(self.id)
151
-
152
- # Body and metadata properties
153
- @property
154
- def key(self) -> str:
155
- """The `key` property.
156
-
157
- Returns:
158
- the value of the property.
159
- """
160
- return self.get_body().key
161
-
162
- @property
163
- def value(self) -> MetadataType:
164
- """The `value` property.
165
-
166
- Returns:
167
- the value of the property.
168
- """
169
- return self.get_body().value
170
-
171
- @property
172
- def type(self) -> MetadataTypeEnum:
173
- """The `type` property.
174
-
175
- Returns:
176
- the value of the property.
177
- """
178
- return self.get_body().type
179
-
180
- @property
181
- def resource_id(self) -> UUID:
182
- """The `resource_id` property.
183
-
184
- Returns:
185
- the value of the property.
186
- """
187
- return self.get_metadata().resource_id
188
-
189
- @property
190
- def resource_type(self) -> MetadataResourceTypes:
191
- """The `resource_type` property.
192
-
193
- Returns:
194
- the value of the property.
195
- """
196
- return MetadataResourceTypes(self.get_metadata().resource_type)
197
-
198
- @property
199
- def stack_component_id(self) -> Optional[UUID]:
200
- """The `stack_component_id` property.
201
-
202
- Returns:
203
- the value of the property.
204
- """
205
- return self.get_metadata().stack_component_id
206
-
207
-
208
- # ------------------ Filter Model ------------------
209
-
210
-
211
- class RunMetadataFilter(WorkspaceScopedFilter):
212
- """Model to enable advanced filtering of run metadata."""
213
-
214
- resource_id: Optional[Union[str, UUID]] = Field(
215
- default=None, union_mode="left_to_right"
216
- )
217
- resource_type: Optional[MetadataResourceTypes] = None
218
- stack_component_id: Optional[Union[str, UUID]] = Field(
219
- default=None, union_mode="left_to_right"
220
- )
221
- key: Optional[str] = None
222
- type: Optional[Union[str, MetadataTypeEnum]] = Field(
223
- default=None, union_mode="left_to_right"
224
- )
225
-
226
-
227
- # -------------------- Lazy Loader --------------------
228
-
229
-
230
- class LazyRunMetadataResponse(RunMetadataResponse):
231
- """Lazy run metadata response.
232
-
233
- Used if the run metadata is accessed from the model in
234
- a pipeline context available only during pipeline compilation.
235
- """
236
-
237
- id: Optional[UUID] = None # type: ignore[assignment]
238
- lazy_load_artifact_name: Optional[str] = None
239
- lazy_load_artifact_version: Optional[str] = None
240
- lazy_load_metadata_name: Optional[str] = None
241
- lazy_load_model_name: str
242
- lazy_load_model_version: Optional[str] = None
243
-
244
- def get_body(self) -> None: # type: ignore[override]
245
- """Protects from misuse of the lazy loader.
246
-
247
- Raises:
248
- RuntimeError: always
249
- """
250
- raise RuntimeError(
251
- "Cannot access run metadata body before pipeline runs."
252
- )
253
-
254
- def get_metadata(self) -> None: # type: ignore[override]
255
- """Protects from misuse of the lazy loader.
256
-
257
- Raises:
258
- RuntimeError: always
259
- """
260
- raise RuntimeError(
261
- "Cannot access run metadata metadata before pipeline runs."
262
- )
@@ -21,7 +21,8 @@ from pydantic import BaseModel, ConfigDict, Field
21
21
 
22
22
  from zenml.config.step_configurations import StepConfiguration, StepSpec
23
23
  from zenml.constants import STR_FIELD_MAX_LENGTH, TEXT_FIELD_MAX_LENGTH
24
- from zenml.enums import ExecutionStatus
24
+ from zenml.enums import ExecutionStatus, StepRunInputArtifactType
25
+ from zenml.metadata.metadata_types import MetadataType
25
26
  from zenml.models.v2.base.scoped import (
26
27
  WorkspaceScopedFilter,
27
28
  WorkspaceScopedRequest,
@@ -30,19 +31,35 @@ from zenml.models.v2.base.scoped import (
30
31
  WorkspaceScopedResponseMetadata,
31
32
  WorkspaceScopedResponseResources,
32
33
  )
34
+ from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
33
35
  from zenml.models.v2.core.model_version import ModelVersionResponse
34
36
 
35
37
  if TYPE_CHECKING:
36
38
  from sqlalchemy.sql.elements import ColumnElement
37
39
 
38
- from zenml.models.v2.core.artifact_version import ArtifactVersionResponse
39
40
  from zenml.models.v2.core.logs import (
40
41
  LogsRequest,
41
42
  LogsResponse,
42
43
  )
43
- from zenml.models.v2.core.run_metadata import (
44
- RunMetadataResponse,
45
- )
44
+
45
+
46
+ class StepRunInputResponse(ArtifactVersionResponse):
47
+ """Response model for step run inputs."""
48
+
49
+ input_type: StepRunInputArtifactType
50
+
51
+ def get_hydrated_version(self) -> "StepRunInputResponse":
52
+ """Get the hydrated version of this step run input.
53
+
54
+ Returns:
55
+ an instance of the same entity with the metadata field attached.
56
+ """
57
+ from zenml.client import Client
58
+
59
+ return StepRunInputResponse(
60
+ input_type=self.input_type,
61
+ **Client().zen_store.get_artifact_version(self.id).model_dump(),
62
+ )
46
63
 
47
64
 
48
65
  # ------------------ Request Model ------------------
@@ -97,11 +114,11 @@ class StepRunRequest(WorkspaceScopedRequest):
97
114
  )
98
115
  inputs: Dict[str, UUID] = Field(
99
116
  title="The IDs of the input artifact versions of the step run.",
100
- default={},
117
+ default_factory=dict,
101
118
  )
102
- outputs: Dict[str, UUID] = Field(
119
+ outputs: Dict[str, List[UUID]] = Field(
103
120
  title="The IDs of the output artifact versions of the step run.",
104
- default={},
121
+ default_factory=dict,
105
122
  )
106
123
  logs: Optional["LogsRequest"] = Field(
107
124
  title="Logs associated with this step run.",
@@ -129,10 +146,6 @@ class StepRunUpdate(BaseModel):
129
146
  title="The IDs of the output artifact versions of the step run.",
130
147
  default={},
131
148
  )
132
- saved_artifact_versions: Dict[str, UUID] = Field(
133
- title="The IDs of artifact versions that were saved by this step run.",
134
- default={},
135
- )
136
149
  loaded_artifact_versions: Dict[str, UUID] = Field(
137
150
  title="The IDs of artifact versions that were loaded by this step run.",
138
151
  default={},
@@ -166,13 +179,13 @@ class StepRunResponseBody(WorkspaceScopedResponseBody):
166
179
  title="The end time of the step run.",
167
180
  default=None,
168
181
  )
169
- inputs: Dict[str, "ArtifactVersionResponse"] = Field(
182
+ inputs: Dict[str, StepRunInputResponse] = Field(
170
183
  title="The input artifact versions of the step run.",
171
- default={},
184
+ default_factory=dict,
172
185
  )
173
- outputs: Dict[str, "ArtifactVersionResponse"] = Field(
186
+ outputs: Dict[str, List[ArtifactVersionResponse]] = Field(
174
187
  title="The output artifact versions of the step run.",
175
- default={},
188
+ default_factory=dict,
176
189
  )
177
190
  model_version_id: Optional[UUID] = Field(
178
191
  title="The ID of the model version that was "
@@ -230,7 +243,7 @@ class StepRunResponseMetadata(WorkspaceScopedResponseMetadata):
230
243
  title="The IDs of the parent steps of this step run.",
231
244
  default_factory=list,
232
245
  )
233
- run_metadata: Dict[str, "RunMetadataResponse"] = Field(
246
+ run_metadata: Dict[str, MetadataType] = Field(
234
247
  title="Metadata associated with this step run.",
235
248
  default={},
236
249
  )
@@ -274,7 +287,7 @@ class StepRunResponse(
274
287
 
275
288
  # Helper properties
276
289
  @property
277
- def input(self) -> "ArtifactVersionResponse":
290
+ def input(self) -> ArtifactVersionResponse:
278
291
  """Returns the input artifact that was used to run this step.
279
292
 
280
293
  Returns:
@@ -293,7 +306,7 @@ class StepRunResponse(
293
306
  return next(iter(self.inputs.values()))
294
307
 
295
308
  @property
296
- def output(self) -> "ArtifactVersionResponse":
309
+ def output(self) -> ArtifactVersionResponse:
297
310
  """Returns the output artifact that was written by this step.
298
311
 
299
312
  Returns:
@@ -304,12 +317,15 @@ class StepRunResponse(
304
317
  """
305
318
  if not self.outputs:
306
319
  raise ValueError(f"Step {self.name} has no outputs.")
307
- if len(self.outputs) > 1:
320
+ if len(self.outputs) > 1 or (
321
+ len(self.outputs) == 1
322
+ and len(next(iter(self.outputs.values()))) > 1
323
+ ):
308
324
  raise ValueError(
309
325
  f"Step {self.name} has multiple outputs, so `Step.output` is "
310
326
  "ambiguous. Please use `Step.outputs` instead."
311
327
  )
312
- return next(iter(self.outputs.values()))
328
+ return next(iter(self.outputs.values()))[0]
313
329
 
314
330
  # Body and metadata properties
315
331
  @property
@@ -322,7 +338,7 @@ class StepRunResponse(
322
338
  return self.get_body().status
323
339
 
324
340
  @property
325
- def inputs(self) -> Dict[str, "ArtifactVersionResponse"]:
341
+ def inputs(self) -> Dict[str, StepRunInputResponse]:
326
342
  """The `inputs` property.
327
343
 
328
344
  Returns:
@@ -331,7 +347,7 @@ class StepRunResponse(
331
347
  return self.get_body().inputs
332
348
 
333
349
  @property
334
- def outputs(self) -> Dict[str, "ArtifactVersionResponse"]:
350
+ def outputs(self) -> Dict[str, List[ArtifactVersionResponse]]:
335
351
  """The `outputs` property.
336
352
 
337
353
  Returns:
@@ -466,7 +482,7 @@ class StepRunResponse(
466
482
  return self.get_metadata().parent_step_ids
467
483
 
468
484
  @property
469
- def run_metadata(self) -> Dict[str, "RunMetadataResponse"]:
485
+ def run_metadata(self) -> Dict[str, MetadataType]:
470
486
  """The `run_metadata` property.
471
487
 
472
488
  Returns: