zenml-nightly 0.83.1.dev20250626__py3-none-any.whl → 0.83.1.dev20250628__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (25)
  1. zenml/VERSION +1 -1
  2. zenml/client.py +8 -2
  3. zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +1 -1
  4. zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +43 -8
  5. zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +88 -64
  6. zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py +0 -12
  7. zenml/integrations/kubernetes/orchestrators/manifest_utils.py +6 -20
  8. zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +4 -2
  9. zenml/integrations/vllm/services/vllm_deployment.py +1 -1
  10. zenml/models/v2/core/pipeline_run.py +10 -0
  11. zenml/orchestrators/dag_runner.py +12 -3
  12. zenml/orchestrators/input_utils.py +6 -35
  13. zenml/orchestrators/step_run_utils.py +89 -15
  14. zenml/pipelines/pipeline_definition.py +6 -2
  15. zenml/pipelines/run_utils.py +5 -9
  16. zenml/stack/stack_component.py +1 -1
  17. zenml/zen_server/template_execution/utils.py +0 -1
  18. zenml/zen_stores/schemas/pipeline_run_schemas.py +38 -19
  19. zenml/zen_stores/schemas/step_run_schemas.py +44 -14
  20. zenml/zen_stores/sql_zen_store.py +75 -49
  21. {zenml_nightly-0.83.1.dev20250626.dist-info → zenml_nightly-0.83.1.dev20250628.dist-info}/METADATA +1 -1
  22. {zenml_nightly-0.83.1.dev20250626.dist-info → zenml_nightly-0.83.1.dev20250628.dist-info}/RECORD +25 -25
  23. {zenml_nightly-0.83.1.dev20250626.dist-info → zenml_nightly-0.83.1.dev20250628.dist-info}/LICENSE +0 -0
  24. {zenml_nightly-0.83.1.dev20250626.dist-info → zenml_nightly-0.83.1.dev20250628.dist-info}/WHEEL +0 -0
  25. {zenml_nightly-0.83.1.dev20250626.dist-info → zenml_nightly-0.83.1.dev20250628.dist-info}/entry_points.txt +0 -0
@@ -72,6 +72,7 @@ class ThreadedDagRunner:
72
72
  self,
73
73
  dag: Dict[str, List[str]],
74
74
  run_fn: Callable[[str], Any],
75
+ preparation_fn: Optional[Callable[[str], bool]] = None,
75
76
  finalize_fn: Optional[Callable[[Dict[str, NodeStatus]], None]] = None,
76
77
  parallel_node_startup_waiting_period: float = 0.0,
77
78
  max_parallelism: Optional[int] = None,
@@ -83,6 +84,9 @@ class ThreadedDagRunner:
83
84
  E.g.: [(1->2), (1->3), (2->4), (3->4)] should be represented as
84
85
  `dag={2: [1], 3: [1], 4: [2, 3]}`
85
86
  run_fn: A function `run_fn(node)` that runs a single node
87
+ preparation_fn: A function that is called before the node is run.
88
+ If provided, the function return value determines whether the
89
+ node should be run or can be skipped.
86
90
  finalize_fn: A function `finalize_fn(node_states)` that is called
87
91
  when all nodes have completed.
88
92
  parallel_node_startup_waiting_period: Delay in seconds to wait in
@@ -102,6 +106,7 @@ class ThreadedDagRunner:
102
106
  self.dag = dag
103
107
  self.reversed_dag = reverse_dag(dag)
104
108
  self.run_fn = run_fn
109
+ self.preparation_fn = preparation_fn
105
110
  self.finalize_fn = finalize_fn
106
111
  self.nodes = dag.keys()
107
112
  self.node_states = {
@@ -156,7 +161,7 @@ class ThreadedDagRunner:
156
161
  break
157
162
 
158
163
  logger.debug(f"Waiting for {running_nodes} nodes to finish.")
159
- time.sleep(10)
164
+ time.sleep(1)
160
165
 
161
166
  def _run_node(self, node: str) -> None:
162
167
  """Run a single node.
@@ -168,6 +173,12 @@ class ThreadedDagRunner:
168
173
  """
169
174
  self._prepare_node_run(node)
170
175
 
176
+ if self.preparation_fn:
177
+ run_required = self.preparation_fn(node)
178
+ if not run_required:
179
+ self._finish_node(node)
180
+ return
181
+
171
182
  try:
172
183
  self.run_fn(node)
173
184
  self._finish_node(node)
@@ -203,8 +214,6 @@ class ThreadedDagRunner:
203
214
  node: The node.
204
215
  failed: Whether the node failed.
205
216
  """
206
- # Update node status to completed.
207
- assert self.node_states[node] == NodeStatus.RUNNING
208
217
  with self._lock:
209
218
  if failed:
210
219
  self.node_states[node] = NodeStatus.FAILED
@@ -13,14 +13,13 @@
13
13
  # permissions and limitations under the License.
14
14
  """Utilities for inputs."""
15
15
 
16
- import json
17
16
  from typing import TYPE_CHECKING, Dict, Optional
18
17
 
19
18
  from zenml.client import Client
20
19
  from zenml.config.step_configurations import Step
21
20
  from zenml.enums import StepRunInputArtifactType
22
21
  from zenml.exceptions import InputResolutionError
23
- from zenml.utils import pagination_utils, string_utils
22
+ from zenml.utils import string_utils
24
23
 
25
24
  if TYPE_CHECKING:
26
25
  from zenml.models import PipelineRunResponse, StepRunResponse
@@ -52,6 +51,7 @@ def resolve_step_inputs(
52
51
  """
53
52
  from zenml.models import ArtifactVersionResponse
54
53
  from zenml.models.v2.core.step_run import StepRunInputResponse
54
+ from zenml.orchestrators.step_run_utils import fetch_step_runs_by_names
55
55
 
56
56
  step_runs = step_runs or {}
57
57
 
@@ -62,40 +62,11 @@ def resolve_step_inputs(
62
62
  steps_to_fetch.difference_update(step_runs.keys())
63
63
 
64
64
  if steps_to_fetch:
65
- # The list of steps might be too big to fit in the default max URL
66
- # length of 8KB supported by most servers. So we need to split it into
67
- # smaller chunks.
68
- steps_list = list(steps_to_fetch)
69
- chunks = []
70
- current_chunk = []
71
- current_length = 0
72
- # stay under 6KB for good measure.
73
- max_chunk_length = 6000
74
-
75
- for step_name in steps_list:
76
- current_chunk.append(step_name)
77
- current_length += len(step_name) + 5 # 5 is for the JSON encoding
78
-
79
- if current_length > max_chunk_length:
80
- chunks.append(current_chunk)
81
- current_chunk = []
82
- current_length = 0
83
-
84
- if current_chunk:
85
- chunks.append(current_chunk)
86
-
87
- for chunk in chunks:
88
- step_runs.update(
89
- {
90
- run_step.name: run_step
91
- for run_step in pagination_utils.depaginate(
92
- Client().list_run_steps,
93
- pipeline_run_id=pipeline_run.id,
94
- project=pipeline_run.project_id,
95
- name="oneof:" + json.dumps(chunk),
96
- )
97
- }
65
+ step_runs.update(
66
+ fetch_step_runs_by_names(
67
+ step_run_names=list(steps_to_fetch), pipeline_run=pipeline_run
98
68
  )
69
+ )
99
70
 
100
71
  input_artifacts: Dict[str, StepRunInputResponse] = {}
101
72
  for name, input_ in step.spec.inputs.items():
@@ -13,6 +13,7 @@
13
13
  # permissions and limitations under the License.
14
14
  """Utilities for creating step runs."""
15
15
 
16
+ import json
16
17
  from typing import Dict, List, Optional, Set, Tuple, Union
17
18
 
18
19
  from zenml import Tag, add_tags
@@ -32,6 +33,7 @@ from zenml.models import (
32
33
  )
33
34
  from zenml.orchestrators import cache_utils, input_utils, utils
34
35
  from zenml.stack import Stack
36
+ from zenml.utils import pagination_utils
35
37
  from zenml.utils.time_utils import utc_now
36
38
 
37
39
  logger = get_logger(__name__)
@@ -151,6 +153,15 @@ class StepRunRequestFactory:
151
153
  request.status = ExecutionStatus.CACHED
152
154
  request.end_time = request.start_time
153
155
 
156
+ # As a last resort, we try to reuse the docstring/source code
157
+ # from the cached step run. This is part of the cache key
158
+ # computation, so it must be identical to the one we would have
159
+ # computed ourselves.
160
+ if request.source_code is None:
161
+ request.source_code = cached_step_run.source_code
162
+ if request.docstring is None:
163
+ request.docstring = cached_step_run.docstring
164
+
154
165
  def _get_docstring_and_source_code(
155
166
  self, invocation_id: str
156
167
  ) -> Tuple[Optional[str], Optional[str]]:
@@ -333,27 +344,15 @@ def create_cached_step_runs(
333
344
  # -> We don't need to do anything here
334
345
  continue
335
346
 
336
- step_run = Client().zen_store.create_run_step(step_run_request)
347
+ step_run = publish_cached_step_run(
348
+ step_run_request, pipeline_run=pipeline_run
349
+ )
337
350
 
338
351
  # Include the newly created step run in the step runs dictionary to
339
352
  # avoid fetching it again later when downstream steps need it for
340
353
  # input resolution.
341
354
  step_runs[invocation_id] = step_run
342
355
 
343
- if (
344
- model_version := step_run.model_version
345
- or pipeline_run.model_version
346
- ):
347
- link_output_artifacts_to_model_version(
348
- artifacts=step_run.outputs,
349
- model_version=model_version,
350
- )
351
-
352
- cascade_tags_for_output_artifacts(
353
- artifacts=step_run.outputs,
354
- tags=pipeline_run.config.tags,
355
- )
356
-
357
356
  logger.info("Using cached version of step `%s`.", invocation_id)
358
357
  cached_invocations.add(invocation_id)
359
358
 
@@ -426,3 +425,78 @@ def cascade_tags_for_output_artifacts(
426
425
  tags=[t.name for t in cascade_tags],
427
426
  artifact_version_id=output_artifact.id,
428
427
  )
428
+
429
+
430
+ def publish_cached_step_run(
431
+ request: "StepRunRequest", pipeline_run: "PipelineRunResponse"
432
+ ) -> "StepRunResponse":
433
+ """Create a cached step run and link to model version and tags.
434
+
435
+ Args:
436
+ request: The request for the step run.
437
+ pipeline_run: The pipeline run of the step.
438
+
439
+ Returns:
440
+ The createdstep run.
441
+ """
442
+ step_run = Client().zen_store.create_run_step(request)
443
+
444
+ if model_version := step_run.model_version or pipeline_run.model_version:
445
+ link_output_artifacts_to_model_version(
446
+ artifacts=step_run.outputs,
447
+ model_version=model_version,
448
+ )
449
+
450
+ cascade_tags_for_output_artifacts(
451
+ artifacts=step_run.outputs,
452
+ tags=pipeline_run.config.tags,
453
+ )
454
+
455
+ return step_run
456
+
457
+
458
+ def fetch_step_runs_by_names(
459
+ step_run_names: List[str], pipeline_run: "PipelineRunResponse"
460
+ ) -> Dict[str, "StepRunResponse"]:
461
+ """Fetch step runs by names.
462
+
463
+ Args:
464
+ step_run_names: The names of the step runs to fetch.
465
+ pipeline_run: The pipeline run of the step runs.
466
+
467
+ Returns:
468
+ A dictionary of step runs by name.
469
+ """
470
+ step_runs = {}
471
+
472
+ chunks = []
473
+ current_chunk = []
474
+ current_length = 0
475
+ # stay under 6KB for good measure.
476
+ max_chunk_length = 6000
477
+
478
+ for step_name in step_run_names:
479
+ current_chunk.append(step_name)
480
+ current_length += len(step_name) + 5 # 5 is for the JSON encoding
481
+
482
+ if current_length > max_chunk_length:
483
+ chunks.append(current_chunk)
484
+ current_chunk = []
485
+ current_length = 0
486
+
487
+ if current_chunk:
488
+ chunks.append(current_chunk)
489
+
490
+ for chunk in chunks:
491
+ step_runs.update(
492
+ {
493
+ run_step.name: run_step
494
+ for run_step in pagination_utils.depaginate(
495
+ Client().list_run_steps,
496
+ pipeline_run_id=pipeline_run.id,
497
+ project=pipeline_run.project_id,
498
+ name="oneof:" + json.dumps(chunk),
499
+ )
500
+ }
501
+ )
502
+ return step_runs
@@ -863,8 +863,12 @@ To avoid this consider setting pipeline parameters only in one place (config or
863
863
  deployment = self._create_deployment(**self._run_args)
864
864
 
865
865
  self.log_pipeline_deployment_metadata(deployment)
866
- run = create_placeholder_run(
867
- deployment=deployment, logs=logs_model
866
+ run = (
867
+ create_placeholder_run(
868
+ deployment=deployment, logs=logs_model
869
+ )
870
+ if not deployment.schedule
871
+ else None
868
872
  )
869
873
 
870
874
  analytics_handler.metadata = (
@@ -51,23 +51,19 @@ def get_default_run_name(pipeline_name: str) -> str:
51
51
 
52
52
  def create_placeholder_run(
53
53
  deployment: "PipelineDeploymentResponse",
54
+ orchestrator_run_id: Optional[str] = None,
54
55
  logs: Optional["LogsRequest"] = None,
55
- ) -> Optional["PipelineRunResponse"]:
56
+ ) -> "PipelineRunResponse":
56
57
  """Create a placeholder run for the deployment.
57
58
 
58
- If the deployment contains a schedule, no placeholder run will be
59
- created.
60
-
61
59
  Args:
62
60
  deployment: The deployment for which to create the placeholder run.
61
+ orchestrator_run_id: The orchestrator run ID for the run.
63
62
  logs: The logs for the run.
64
63
 
65
64
  Returns:
66
- The placeholder run or `None` if no run was created.
65
+ The placeholder run.
67
66
  """
68
- if deployment.schedule:
69
- return None
70
-
71
67
  start_time = utc_now()
72
68
  run_request = PipelineRunRequest(
73
69
  name=string_utils.format_name_template(
@@ -83,7 +79,7 @@ def create_placeholder_run(
83
79
  # the start_time is only set once the first step starts
84
80
  # running.
85
81
  start_time=start_time,
86
- orchestrator_run_id=None,
82
+ orchestrator_run_id=orchestrator_run_id,
87
83
  project=deployment.project_id,
88
84
  deployment=deployment.id,
89
85
  pipeline=deployment.pipeline.id if deployment.pipeline else None,
@@ -527,7 +527,7 @@ class StackComponent:
527
527
  )
528
528
 
529
529
  # Use the current config as a base
530
- settings_dict = self.config.model_dump()
530
+ settings_dict = self.config.model_dump(exclude_unset=True)
531
531
 
532
532
  if key in all_settings:
533
533
  settings_dict.update(dict(all_settings[key]))
@@ -193,7 +193,6 @@ def run_template(
193
193
  zenml_version = build.zenml_version
194
194
 
195
195
  placeholder_run = create_placeholder_run(deployment=new_deployment)
196
- assert placeholder_run
197
196
 
198
197
  report_usage(
199
198
  feature=RUN_TEMPLATE_TRIGGERS_FEATURE_NAME,
@@ -20,7 +20,7 @@ from uuid import UUID
20
20
 
21
21
  from pydantic import ConfigDict
22
22
  from sqlalchemy import UniqueConstraint
23
- from sqlalchemy.orm import joinedload
23
+ from sqlalchemy.orm import selectinload
24
24
  from sqlalchemy.sql.base import ExecutableOption
25
25
  from sqlmodel import TEXT, Column, Field, Relationship
26
26
 
@@ -51,7 +51,9 @@ from zenml.zen_stores.schemas.pipeline_deployment_schemas import (
51
51
  from zenml.zen_stores.schemas.pipeline_schemas import PipelineSchema
52
52
  from zenml.zen_stores.schemas.project_schemas import ProjectSchema
53
53
  from zenml.zen_stores.schemas.schedule_schema import ScheduleSchema
54
- from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field
54
+ from zenml.zen_stores.schemas.schema_utils import (
55
+ build_foreign_key_field,
56
+ )
55
57
  from zenml.zen_stores.schemas.stack_schemas import StackSchema
56
58
  from zenml.zen_stores.schemas.trigger_schemas import TriggerExecutionSchema
57
59
  from zenml.zen_stores.schemas.user_schemas import UserSchema
@@ -259,19 +261,19 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
259
261
  from zenml.zen_stores.schemas import ModelVersionSchema
260
262
 
261
263
  options = [
262
- joinedload(jl_arg(PipelineRunSchema.deployment)).joinedload(
264
+ selectinload(jl_arg(PipelineRunSchema.deployment)).joinedload(
263
265
  jl_arg(PipelineDeploymentSchema.pipeline)
264
266
  ),
265
- joinedload(jl_arg(PipelineRunSchema.deployment)).joinedload(
267
+ selectinload(jl_arg(PipelineRunSchema.deployment)).joinedload(
266
268
  jl_arg(PipelineDeploymentSchema.stack)
267
269
  ),
268
- joinedload(jl_arg(PipelineRunSchema.deployment)).joinedload(
270
+ selectinload(jl_arg(PipelineRunSchema.deployment)).joinedload(
269
271
  jl_arg(PipelineDeploymentSchema.build)
270
272
  ),
271
- joinedload(jl_arg(PipelineRunSchema.deployment)).joinedload(
273
+ selectinload(jl_arg(PipelineRunSchema.deployment)).joinedload(
272
274
  jl_arg(PipelineDeploymentSchema.schedule)
273
275
  ),
274
- joinedload(jl_arg(PipelineRunSchema.deployment)).joinedload(
276
+ selectinload(jl_arg(PipelineRunSchema.deployment)).joinedload(
275
277
  jl_arg(PipelineDeploymentSchema.code_reference)
276
278
  ),
277
279
  ]
@@ -286,14 +288,14 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
286
288
  if include_resources:
287
289
  options.extend(
288
290
  [
289
- joinedload(
291
+ selectinload(
290
292
  jl_arg(PipelineRunSchema.model_version)
291
293
  ).joinedload(
292
294
  jl_arg(ModelVersionSchema.model), innerjoin=True
293
295
  ),
294
- joinedload(jl_arg(PipelineRunSchema.logs)),
295
- joinedload(jl_arg(PipelineRunSchema.user)),
296
- # joinedload(jl_arg(PipelineRunSchema.tags)),
296
+ selectinload(jl_arg(PipelineRunSchema.logs)),
297
+ selectinload(jl_arg(PipelineRunSchema.user)),
298
+ selectinload(jl_arg(PipelineRunSchema.tags)),
297
299
  ]
298
300
  )
299
301
 
@@ -550,8 +552,8 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
550
552
 
551
553
  Raises:
552
554
  RuntimeError: If the DB entry does not represent a placeholder run.
553
- ValueError: If the run request does not match the deployment or
554
- pipeline ID of the placeholder run.
555
+ ValueError: If the run request is not a valid request to replace the
556
+ placeholder run.
555
557
 
556
558
  Returns:
557
559
  The updated `PipelineRunSchema`.
@@ -562,13 +564,33 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
562
564
  "placeholder run."
563
565
  )
564
566
 
567
+ if request.is_placeholder_request:
568
+ raise ValueError(
569
+ "Cannot replace a placeholder run with another placeholder run."
570
+ )
571
+
565
572
  if (
566
573
  self.deployment_id != request.deployment
567
574
  or self.pipeline_id != request.pipeline
575
+ or self.project_id != request.project
568
576
  ):
569
577
  raise ValueError(
570
- "Deployment or orchestrator run ID of placeholder run do not "
571
- "match the IDs of the run request."
578
+ "Deployment, project or pipeline ID of placeholder run "
579
+ "do not match the IDs of the run request."
580
+ )
581
+
582
+ if not request.orchestrator_run_id:
583
+ raise ValueError(
584
+ "Orchestrator run ID is required to replace a placeholder run."
585
+ )
586
+
587
+ if (
588
+ self.orchestrator_run_id
589
+ and self.orchestrator_run_id != request.orchestrator_run_id
590
+ ):
591
+ raise ValueError(
592
+ "Orchestrator run ID of placeholder run does not match the "
593
+ "ID of the run request."
572
594
  )
573
595
 
574
596
  orchestrator_environment = json.dumps(request.orchestrator_environment)
@@ -587,7 +609,4 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
587
609
  Returns:
588
610
  Whether the pipeline run is a placeholder run.
589
611
  """
590
- return (
591
- self.orchestrator_run_id is None
592
- and self.status == ExecutionStatus.INITIALIZING
593
- )
612
+ return self.status == ExecutionStatus.INITIALIZING.value
@@ -21,7 +21,7 @@ from uuid import UUID
21
21
  from pydantic import ConfigDict
22
22
  from sqlalchemy import TEXT, Column, String, UniqueConstraint
23
23
  from sqlalchemy.dialects.mysql import MEDIUMTEXT
24
- from sqlalchemy.orm import joinedload
24
+ from sqlalchemy.orm import joinedload, selectinload
25
25
  from sqlalchemy.sql.base import ExecutableOption
26
26
  from sqlmodel import Field, Relationship, SQLModel
27
27
 
@@ -50,6 +50,7 @@ from zenml.zen_stores.schemas.base_schemas import NamedSchema
50
50
  from zenml.zen_stores.schemas.constants import MODEL_VERSION_TABLENAME
51
51
  from zenml.zen_stores.schemas.pipeline_deployment_schemas import (
52
52
  PipelineDeploymentSchema,
53
+ StepConfigurationSchema,
53
54
  )
54
55
  from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema
55
56
  from zenml.zen_stores.schemas.project_schemas import ProjectSchema
@@ -187,6 +188,14 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True):
187
188
  original_step_run: Optional["StepRunSchema"] = Relationship(
188
189
  sa_relationship_kwargs={"remote_side": "StepRunSchema.id"}
189
190
  )
191
+ step_configuration_schema: Optional["StepConfigurationSchema"] = (
192
+ Relationship(
193
+ sa_relationship_kwargs=dict(
194
+ viewonly=True,
195
+ primaryjoin="and_(foreign(StepConfigurationSchema.name) == StepRunSchema.name, foreign(StepConfigurationSchema.deployment_id) == StepRunSchema.deployment_id)",
196
+ ),
197
+ )
198
+ )
190
199
 
191
200
  model_config = ConfigDict(protected_namespaces=()) # type: ignore[assignment]
192
201
 
@@ -209,17 +218,25 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True):
209
218
  Returns:
210
219
  A list of query options.
211
220
  """
212
- from zenml.zen_stores.schemas import ModelVersionSchema
221
+ from zenml.zen_stores.schemas import (
222
+ ArtifactVersionSchema,
223
+ ModelVersionSchema,
224
+ )
213
225
 
214
226
  options = [
215
- joinedload(jl_arg(StepRunSchema.deployment)),
216
- joinedload(jl_arg(StepRunSchema.pipeline_run)),
227
+ selectinload(jl_arg(StepRunSchema.deployment)).load_only(
228
+ jl_arg(PipelineDeploymentSchema.pipeline_configuration)
229
+ ),
230
+ selectinload(jl_arg(StepRunSchema.pipeline_run)).load_only(
231
+ jl_arg(PipelineRunSchema.start_time)
232
+ ),
233
+ joinedload(jl_arg(StepRunSchema.step_configuration_schema)),
217
234
  ]
218
235
 
219
236
  if include_metadata:
220
237
  options.extend(
221
238
  [
222
- joinedload(jl_arg(StepRunSchema.logs)),
239
+ selectinload(jl_arg(StepRunSchema.logs)),
223
240
  # joinedload(jl_arg(StepRunSchema.parents)),
224
241
  # joinedload(jl_arg(StepRunSchema.run_metadata)),
225
242
  ]
@@ -228,12 +245,28 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True):
228
245
  if include_resources:
229
246
  options.extend(
230
247
  [
231
- joinedload(jl_arg(StepRunSchema.model_version)).joinedload(
248
+ selectinload(
249
+ jl_arg(StepRunSchema.model_version)
250
+ ).joinedload(
232
251
  jl_arg(ModelVersionSchema.model), innerjoin=True
233
252
  ),
234
- joinedload(jl_arg(StepRunSchema.user)),
235
- # joinedload(jl_arg(StepRunSchema.input_artifacts)),
236
- # joinedload(jl_arg(StepRunSchema.output_artifacts)),
253
+ selectinload(jl_arg(StepRunSchema.user)),
254
+ selectinload(jl_arg(StepRunSchema.input_artifacts))
255
+ .joinedload(
256
+ jl_arg(StepRunInputArtifactSchema.artifact_version),
257
+ innerjoin=True,
258
+ )
259
+ .joinedload(
260
+ jl_arg(ArtifactVersionSchema.artifact), innerjoin=True
261
+ ),
262
+ selectinload(jl_arg(StepRunSchema.output_artifacts))
263
+ .joinedload(
264
+ jl_arg(StepRunOutputArtifactSchema.artifact_version),
265
+ innerjoin=True,
266
+ )
267
+ .joinedload(
268
+ jl_arg(ArtifactVersionSchema.artifact), innerjoin=True
269
+ ),
237
270
  ]
238
271
  )
239
272
 
@@ -290,10 +323,7 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True):
290
323
  """
291
324
  step = None
292
325
  if self.deployment is not None:
293
- step_configurations = self.deployment.get_step_configurations(
294
- include=[self.name]
295
- )
296
- if step_configurations:
326
+ if self.step_configuration_schema:
297
327
  pipeline_configuration = (
298
328
  PipelineConfiguration.model_validate_json(
299
329
  self.deployment.pipeline_configuration
@@ -304,7 +334,7 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True):
304
334
  inplace=True,
305
335
  )
306
336
  step = Step.from_dict(
307
- json.loads(step_configurations[0].config),
337
+ json.loads(self.step_configuration_schema.config),
308
338
  pipeline_configuration=pipeline_configuration,
309
339
  )
310
340
  if not step and self.step_configuration: