zenml-nightly 0.83.1.dev20250626__py3-none-any.whl → 0.83.1.dev20250628__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/client.py +8 -2
- zenml/integrations/aws/flavors/sagemaker_orchestrator_flavor.py +1 -1
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +43 -8
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +88 -64
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint_configuration.py +0 -12
- zenml/integrations/kubernetes/orchestrators/manifest_utils.py +6 -20
- zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +4 -2
- zenml/integrations/vllm/services/vllm_deployment.py +1 -1
- zenml/models/v2/core/pipeline_run.py +10 -0
- zenml/orchestrators/dag_runner.py +12 -3
- zenml/orchestrators/input_utils.py +6 -35
- zenml/orchestrators/step_run_utils.py +89 -15
- zenml/pipelines/pipeline_definition.py +6 -2
- zenml/pipelines/run_utils.py +5 -9
- zenml/stack/stack_component.py +1 -1
- zenml/zen_server/template_execution/utils.py +0 -1
- zenml/zen_stores/schemas/pipeline_run_schemas.py +38 -19
- zenml/zen_stores/schemas/step_run_schemas.py +44 -14
- zenml/zen_stores/sql_zen_store.py +75 -49
- {zenml_nightly-0.83.1.dev20250626.dist-info → zenml_nightly-0.83.1.dev20250628.dist-info}/METADATA +1 -1
- {zenml_nightly-0.83.1.dev20250626.dist-info → zenml_nightly-0.83.1.dev20250628.dist-info}/RECORD +25 -25
- {zenml_nightly-0.83.1.dev20250626.dist-info → zenml_nightly-0.83.1.dev20250628.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.83.1.dev20250626.dist-info → zenml_nightly-0.83.1.dev20250628.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.83.1.dev20250626.dist-info → zenml_nightly-0.83.1.dev20250628.dist-info}/entry_points.txt +0 -0
zenml/orchestrators/dag_runner.py
CHANGED
@@ -72,6 +72,7 @@ class ThreadedDagRunner:
         self,
         dag: Dict[str, List[str]],
         run_fn: Callable[[str], Any],
+        preparation_fn: Optional[Callable[[str], bool]] = None,
         finalize_fn: Optional[Callable[[Dict[str, NodeStatus]], None]] = None,
         parallel_node_startup_waiting_period: float = 0.0,
         max_parallelism: Optional[int] = None,
@@ -83,6 +84,9 @@ class ThreadedDagRunner:
                 E.g.: [(1->2), (1->3), (2->4), (3->4)] should be represented as
                 `dag={2: [1], 3: [1], 4: [2, 3]}`
             run_fn: A function `run_fn(node)` that runs a single node
+            preparation_fn: A function that is called before the node is run.
+                If provided, the function return value determines whether the
+                node should be run or can be skipped.
             finalize_fn: A function `finalize_fn(node_states)` that is called
                 when all nodes have completed.
             parallel_node_startup_waiting_period: Delay in seconds to wait in
@@ -102,6 +106,7 @@ class ThreadedDagRunner:
         self.dag = dag
         self.reversed_dag = reverse_dag(dag)
         self.run_fn = run_fn
+        self.preparation_fn = preparation_fn
         self.finalize_fn = finalize_fn
         self.nodes = dag.keys()
         self.node_states = {
@@ -156,7 +161,7 @@ class ThreadedDagRunner:
                 break

             logger.debug(f"Waiting for {running_nodes} nodes to finish.")
-            time.sleep(
+            time.sleep(1)

     def _run_node(self, node: str) -> None:
         """Run a single node.
@@ -168,6 +173,12 @@ class ThreadedDagRunner:
         """
         self._prepare_node_run(node)

+        if self.preparation_fn:
+            run_required = self.preparation_fn(node)
+            if not run_required:
+                self._finish_node(node)
+                return
+
         try:
             self.run_fn(node)
             self._finish_node(node)
@@ -203,8 +214,6 @@ class ThreadedDagRunner:
             node: The node.
             failed: Whether the node failed.
         """
-        # Update node status to completed.
-        assert self.node_states[node] == NodeStatus.RUNNING
         with self._lock:
             if failed:
                 self.node_states[node] = NodeStatus.FAILED
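Taken together, the `dag_runner.py` hunks above add an optional `preparation_fn` hook that can short-circuit a node before `run_fn` is called (for example, when the node's work turns out to be already done). A minimal standalone sketch of the skip semantics, assuming only what the diff shows; names outside the diff are illustrative:

```python
from typing import Callable, Optional

def run_node(
    node: str,
    run_fn: Callable[[str], None],
    preparation_fn: Optional[Callable[[str], bool]] = None,
) -> None:
    # Mirrors the diff: if the hook exists and returns False,
    # the node is finished without ever calling run_fn.
    if preparation_fn is not None and not preparation_fn(node):
        print(f"{node}: skipped")
        return
    run_fn(node)
    print(f"{node}: ran")

already_done = {"step_a"}  # hypothetical set of cached nodes
for node in ["step_a", "step_b"]:
    run_node(
        node,
        run_fn=lambda n: None,
        preparation_fn=lambda n: n not in already_done,
    )
```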
zenml/orchestrators/input_utils.py
CHANGED
@@ -13,14 +13,13 @@
 # permissions and limitations under the License.
 """Utilities for inputs."""

-import json
 from typing import TYPE_CHECKING, Dict, Optional

 from zenml.client import Client
 from zenml.config.step_configurations import Step
 from zenml.enums import StepRunInputArtifactType
 from zenml.exceptions import InputResolutionError
-from zenml.utils import pagination_utils, string_utils
+from zenml.utils import string_utils

 if TYPE_CHECKING:
     from zenml.models import PipelineRunResponse, StepRunResponse
@@ -52,6 +51,7 @@ def resolve_step_inputs(
     """
     from zenml.models import ArtifactVersionResponse
     from zenml.models.v2.core.step_run import StepRunInputResponse
+    from zenml.orchestrators.step_run_utils import fetch_step_runs_by_names

     step_runs = step_runs or {}

@@ -62,40 +62,11 @@ def resolve_step_inputs(
     steps_to_fetch.difference_update(step_runs.keys())

     if steps_to_fetch:
-
-
-
-        steps_list = list(steps_to_fetch)
-        chunks = []
-        current_chunk = []
-        current_length = 0
-        # stay under 6KB for good measure.
-        max_chunk_length = 6000
-
-        for step_name in steps_list:
-            current_chunk.append(step_name)
-            current_length += len(step_name) + 5  # 5 is for the JSON encoding
-
-            if current_length > max_chunk_length:
-                chunks.append(current_chunk)
-                current_chunk = []
-                current_length = 0
-
-        if current_chunk:
-            chunks.append(current_chunk)
-
-        for chunk in chunks:
-            step_runs.update(
-                {
-                    run_step.name: run_step
-                    for run_step in pagination_utils.depaginate(
-                        Client().list_run_steps,
-                        pipeline_run_id=pipeline_run.id,
-                        project=pipeline_run.project_id,
-                        name="oneof:" + json.dumps(chunk),
-                    )
-                }
+        step_runs.update(
+            fetch_step_runs_by_names(
+                step_run_names=list(steps_to_fetch), pipeline_run=pipeline_run
             )
+        )

     input_artifacts: Dict[str, StepRunInputResponse] = {}
     for name, input_ in step.spec.inputs.items():
zenml/orchestrators/step_run_utils.py
CHANGED
@@ -13,6 +13,7 @@
 # permissions and limitations under the License.
 """Utilities for creating step runs."""

+import json
 from typing import Dict, List, Optional, Set, Tuple, Union

 from zenml import Tag, add_tags
@@ -32,6 +33,7 @@ from zenml.models import (
 )
 from zenml.orchestrators import cache_utils, input_utils, utils
 from zenml.stack import Stack
+from zenml.utils import pagination_utils
 from zenml.utils.time_utils import utc_now

 logger = get_logger(__name__)
@@ -151,6 +153,15 @@ class StepRunRequestFactory:
             request.status = ExecutionStatus.CACHED
             request.end_time = request.start_time

+            # As a last resort, we try to reuse the docstring/source code
+            # from the cached step run. This is part of the cache key
+            # computation, so it must be identical to the one we would have
+            # computed ourselves.
+            if request.source_code is None:
+                request.source_code = cached_step_run.source_code
+            if request.docstring is None:
+                request.docstring = cached_step_run.docstring
+
     def _get_docstring_and_source_code(
         self, invocation_id: str
     ) -> Tuple[Optional[str], Optional[str]]:
@@ -333,27 +344,15 @@ def create_cached_step_runs(
             # -> We don't need to do anything here
             continue

-        step_run = Client().zen_store.create_run_step(step_run_request)
+        step_run = publish_cached_step_run(
+            step_run_request, pipeline_run=pipeline_run
+        )

         # Include the newly created step run in the step runs dictionary to
         # avoid fetching it again later when downstream steps need it for
         # input resolution.
         step_runs[invocation_id] = step_run

-        if (
-            model_version := step_run.model_version
-            or pipeline_run.model_version
-        ):
-            link_output_artifacts_to_model_version(
-                artifacts=step_run.outputs,
-                model_version=model_version,
-            )
-
-        cascade_tags_for_output_artifacts(
-            artifacts=step_run.outputs,
-            tags=pipeline_run.config.tags,
-        )
-
         logger.info("Using cached version of step `%s`.", invocation_id)
         cached_invocations.add(invocation_id)

@@ -426,3 +425,78 @@ def cascade_tags_for_output_artifacts(
         tags=[t.name for t in cascade_tags],
         artifact_version_id=output_artifact.id,
     )
+
+
+def publish_cached_step_run(
+    request: "StepRunRequest", pipeline_run: "PipelineRunResponse"
+) -> "StepRunResponse":
+    """Create a cached step run and link to model version and tags.
+
+    Args:
+        request: The request for the step run.
+        pipeline_run: The pipeline run of the step.
+
+    Returns:
+        The created step run.
+    """
+    step_run = Client().zen_store.create_run_step(request)
+
+    if model_version := step_run.model_version or pipeline_run.model_version:
+        link_output_artifacts_to_model_version(
+            artifacts=step_run.outputs,
+            model_version=model_version,
+        )
+
+    cascade_tags_for_output_artifacts(
+        artifacts=step_run.outputs,
+        tags=pipeline_run.config.tags,
+    )
+
+    return step_run
+
+
+def fetch_step_runs_by_names(
+    step_run_names: List[str], pipeline_run: "PipelineRunResponse"
+) -> Dict[str, "StepRunResponse"]:
+    """Fetch step runs by names.
+
+    Args:
+        step_run_names: The names of the step runs to fetch.
+        pipeline_run: The pipeline run of the step runs.
+
+    Returns:
+        A dictionary of step runs by name.
+    """
+    step_runs = {}
+
+    chunks = []
+    current_chunk = []
+    current_length = 0
+    # stay under 6KB for good measure.
+    max_chunk_length = 6000
+
+    for step_name in step_run_names:
+        current_chunk.append(step_name)
+        current_length += len(step_name) + 5  # 5 is for the JSON encoding
+
+        if current_length > max_chunk_length:
+            chunks.append(current_chunk)
+            current_chunk = []
+            current_length = 0
+
+    if current_chunk:
+        chunks.append(current_chunk)
+
+    for chunk in chunks:
+        step_runs.update(
+            {
+                run_step.name: run_step
+                for run_step in pagination_utils.depaginate(
+                    Client().list_run_steps,
+                    pipeline_run_id=pipeline_run.id,
+                    project=pipeline_run.project_id,
+                    name="oneof:" + json.dumps(chunk),
+                )
+            }
+        )
+    return step_runs
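The chunking in `fetch_step_runs_by_names` exists because the `oneof:` name filter is serialized into the request URL; splitting the names keeps each request under the ~6KB budget the comment mentions. A rough self-contained illustration of the same batching logic (the 6000-character budget and 5-character-per-item JSON overhead come from the diff; the sample names are made up):

```python
import json
from typing import List

def chunk_names(names: List[str], max_chunk_length: int = 6000) -> List[List[str]]:
    # Accumulate names until the estimated JSON-encoded size of the
    # chunk exceeds the budget, then start a new chunk.
    chunks: List[List[str]] = []
    current_chunk: List[str] = []
    current_length = 0
    for name in names:
        current_chunk.append(name)
        current_length += len(name) + 5  # rough per-item JSON overhead
        if current_length > max_chunk_length:
            chunks.append(current_chunk)
            current_chunk = []
            current_length = 0
    if current_chunk:
        chunks.append(current_chunk)
    return chunks

names = [f"step_{i}" for i in range(2000)]
chunks = chunk_names(names)
# Each chunk's serialized filter value stays comfortably small.
assert all(len(json.dumps(chunk)) < 7000 for chunk in chunks)
print(f"{len(names)} names -> {len(chunks)} requests")
```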
zenml/pipelines/pipeline_definition.py
CHANGED
@@ -863,8 +863,12 @@ To avoid this consider setting pipeline parameters only in one place (config or
             deployment = self._create_deployment(**self._run_args)

             self.log_pipeline_deployment_metadata(deployment)
-            run = create_placeholder_run(
-                deployment=deployment, logs=logs_model
+            run = (
+                create_placeholder_run(
+                    deployment=deployment, logs=logs_model
+                )
+                if not deployment.schedule
+                else None
             )

             analytics_handler.metadata = (
zenml/pipelines/run_utils.py
CHANGED
@@ -51,23 +51,19 @@ def get_default_run_name(pipeline_name: str) -> str:

 def create_placeholder_run(
     deployment: "PipelineDeploymentResponse",
+    orchestrator_run_id: Optional[str] = None,
     logs: Optional["LogsRequest"] = None,
-) -> Optional["PipelineRunResponse"]:
+) -> "PipelineRunResponse":
     """Create a placeholder run for the deployment.

-    If the deployment contains a schedule, no placeholder run will be
-    created.
-
     Args:
         deployment: The deployment for which to create the placeholder run.
+        orchestrator_run_id: The orchestrator run ID for the run.
         logs: The logs for the run.

     Returns:
-        The placeholder run
+        The placeholder run.
     """
-    if deployment.schedule:
-        return None
-
     start_time = utc_now()
     run_request = PipelineRunRequest(
         name=string_utils.format_name_template(
@@ -83,7 +79,7 @@ def create_placeholder_run(
         # the start_time is only set once the first step starts
        # running.
         start_time=start_time,
-        orchestrator_run_id=None,
+        orchestrator_run_id=orchestrator_run_id,
         project=deployment.project_id,
         deployment=deployment.id,
         pipeline=deployment.pipeline.id if deployment.pipeline else None,
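These two hunks move the schedule check out of `create_placeholder_run` and into its caller in `pipeline_definition.py`: the helper now always builds a run (and can carry a pre-known `orchestrator_run_id`), while the caller decides whether a placeholder is wanted at all. A hedged sketch of the caller-side pattern, using stand-in types rather than ZenML's real models:

```python
from typing import Optional

class Deployment:
    """Stand-in for PipelineDeploymentResponse with just the field we need."""
    def __init__(self, schedule: Optional[str] = None) -> None:
        self.schedule = schedule

def create_placeholder_run(
    deployment: Deployment, orchestrator_run_id: Optional[str] = None
) -> str:
    # Stub: the real helper builds and submits a PipelineRunRequest.
    return f"run(orchestrator_run_id={orchestrator_run_id})"

def submit(deployment: Deployment) -> Optional[str]:
    # After the change, the schedule check lives at the call site.
    return (
        create_placeholder_run(deployment)
        if not deployment.schedule
        else None
    )

assert submit(Deployment()) is not None
assert submit(Deployment(schedule="*/5 * * * *")) is None
```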
zenml/stack/stack_component.py
CHANGED
@@ -527,7 +527,7 @@ class StackComponent:
         )

         # Use the current config as a base
-        settings_dict = self.config.model_dump()
+        settings_dict = self.config.model_dump(exclude_unset=True)

         if key in all_settings:
             settings_dict.update(dict(all_settings[key]))
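The `stack_component.py` change dumps only explicitly-set config fields into the base settings dict. With plain `model_dump()`, every default is materialized and then participates in the merge; with `exclude_unset=True`, untouched fields stay absent, so they cannot mask values contributed by other settings layers. A small generic Pydantic example (not ZenML's actual config classes):

```python
from pydantic import BaseModel

class ResourceConfig(BaseModel):
    cpu: int = 1
    memory_gb: int = 4

config = ResourceConfig(cpu=2)  # memory_gb never set explicitly

print(config.model_dump())                    # {'cpu': 2, 'memory_gb': 4}
print(config.model_dump(exclude_unset=True))  # {'cpu': 2}

# Using the sparse dict as the merge base means an unset default can
# no longer shadow a value coming from another layer of the merge.
other_layer = {"memory_gb": 16}
merged = {**other_layer, **config.model_dump(exclude_unset=True)}
assert merged == {"memory_gb": 16, "cpu": 2}
```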
zenml/zen_stores/schemas/pipeline_run_schemas.py
CHANGED
@@ -20,7 +20,7 @@ from uuid import UUID

 from pydantic import ConfigDict
 from sqlalchemy import UniqueConstraint
-from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import selectinload
 from sqlalchemy.sql.base import ExecutableOption
 from sqlmodel import TEXT, Column, Field, Relationship

@@ -51,7 +51,9 @@ from zenml.zen_stores.schemas.pipeline_deployment_schemas import (
 from zenml.zen_stores.schemas.pipeline_schemas import PipelineSchema
 from zenml.zen_stores.schemas.project_schemas import ProjectSchema
 from zenml.zen_stores.schemas.schedule_schema import ScheduleSchema
-from zenml.zen_stores.schemas.schema_utils import
+from zenml.zen_stores.schemas.schema_utils import (
+    build_foreign_key_field,
+)
 from zenml.zen_stores.schemas.stack_schemas import StackSchema
 from zenml.zen_stores.schemas.trigger_schemas import TriggerExecutionSchema
 from zenml.zen_stores.schemas.user_schemas import UserSchema
@@ -259,19 +261,19 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
         from zenml.zen_stores.schemas import ModelVersionSchema

         options = [
-            joinedload(jl_arg(PipelineRunSchema.deployment)).joinedload(
+            selectinload(jl_arg(PipelineRunSchema.deployment)).joinedload(
                 jl_arg(PipelineDeploymentSchema.pipeline)
             ),
-            joinedload(jl_arg(PipelineRunSchema.deployment)).joinedload(
+            selectinload(jl_arg(PipelineRunSchema.deployment)).joinedload(
                 jl_arg(PipelineDeploymentSchema.stack)
             ),
-            joinedload(jl_arg(PipelineRunSchema.deployment)).joinedload(
+            selectinload(jl_arg(PipelineRunSchema.deployment)).joinedload(
                 jl_arg(PipelineDeploymentSchema.build)
             ),
-            joinedload(jl_arg(PipelineRunSchema.deployment)).joinedload(
+            selectinload(jl_arg(PipelineRunSchema.deployment)).joinedload(
                 jl_arg(PipelineDeploymentSchema.schedule)
             ),
-            joinedload(jl_arg(PipelineRunSchema.deployment)).joinedload(
+            selectinload(jl_arg(PipelineRunSchema.deployment)).joinedload(
                 jl_arg(PipelineDeploymentSchema.code_reference)
             ),
         ]
@@ -286,14 +288,14 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
         if include_resources:
             options.extend(
                 [
-                    joinedload(
+                    selectinload(
                         jl_arg(PipelineRunSchema.model_version)
                     ).joinedload(
                         jl_arg(ModelVersionSchema.model), innerjoin=True
                     ),
-                    joinedload(jl_arg(PipelineRunSchema.logs)),
-                    joinedload(jl_arg(PipelineRunSchema.user)),
-                    joinedload(jl_arg(PipelineRunSchema.tags)),
+                    selectinload(jl_arg(PipelineRunSchema.logs)),
+                    selectinload(jl_arg(PipelineRunSchema.user)),
+                    selectinload(jl_arg(PipelineRunSchema.tags)),
                 ]
             )

@@ -550,8 +552,8 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):

         Raises:
             RuntimeError: If the DB entry does not represent a placeholder run.
-            ValueError: If the run request
-
+            ValueError: If the run request is not a valid request to replace the
+                placeholder run.

         Returns:
             The updated `PipelineRunSchema`.
@@ -562,13 +564,33 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
                 "placeholder run."
             )

+        if request.is_placeholder_request:
+            raise ValueError(
+                "Cannot replace a placeholder run with another placeholder run."
+            )
+
         if (
             self.deployment_id != request.deployment
             or self.pipeline_id != request.pipeline
+            or self.project_id != request.project
         ):
             raise ValueError(
-                "Deployment or pipeline ID of placeholder run do not "
-                "match the IDs of the run request."
+                "Deployment, project or pipeline ID of placeholder run "
+                "do not match the IDs of the run request."
+            )
+
+        if not request.orchestrator_run_id:
+            raise ValueError(
+                "Orchestrator run ID is required to replace a placeholder run."
+            )
+
+        if (
+            self.orchestrator_run_id
+            and self.orchestrator_run_id != request.orchestrator_run_id
+        ):
+            raise ValueError(
+                "Orchestrator run ID of placeholder run does not match the "
+                "ID of the run request."
             )

         orchestrator_environment = json.dumps(request.orchestrator_environment)
@@ -587,7 +609,4 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
         Returns:
             Whether the pipeline run is a placeholder run.
         """
-        return (
-            self.orchestrator_run_id is None
-            and self.status == ExecutionStatus.INITIALIZING
-        )
+        return self.status == ExecutionStatus.INITIALIZING.value
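Across both schema files the eager-loading strategy moves from `joinedload` to `selectinload` for most relationships: instead of folding everything into one wide join, which fans out parent rows on to-many paths, `selectinload` issues a follow-up `SELECT ... WHERE ... IN (...)`. A generic SQLAlchemy 2.0 sketch of the two loader styles on toy models (not ZenML's schemas):

```python
from typing import List
from sqlalchemy import ForeignKey, create_engine, select
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Session,
    joinedload,
    mapped_column,
    relationship,
    selectinload,
)

class Base(DeclarativeBase):
    pass

class Run(Base):
    __tablename__ = "run"
    id: Mapped[int] = mapped_column(primary_key=True)
    tags: Mapped[List["RunTag"]] = relationship(back_populates="run")

class RunTag(Base):
    __tablename__ = "run_tag"
    id: Mapped[int] = mapped_column(primary_key=True)
    run_id: Mapped[int] = mapped_column(ForeignKey("run.id"))
    run: Mapped[Run] = relationship(back_populates="tags")

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Run(tags=[RunTag(), RunTag()]))
    session.commit()

    # joinedload: one LEFT OUTER JOIN; the parent row repeats per tag,
    # so .unique() is required when eagerly loading collections this way.
    joined = session.scalars(
        select(Run).options(joinedload(Run.tags))
    ).unique().all()

    # selectinload: two queries (runs, then tags WHERE run_id IN ...),
    # with no row fan-out on the parent query.
    selected = session.scalars(
        select(Run).options(selectinload(Run.tags))
    ).all()

    assert len(joined) == len(selected) == 1
```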
zenml/zen_stores/schemas/step_run_schemas.py
CHANGED
@@ -21,7 +21,7 @@ from uuid import UUID

 from pydantic import ConfigDict
 from sqlalchemy import TEXT, Column, String, UniqueConstraint
 from sqlalchemy.dialects.mysql import MEDIUMTEXT
-from sqlalchemy.orm import joinedload
+from sqlalchemy.orm import joinedload, selectinload
 from sqlalchemy.sql.base import ExecutableOption
 from sqlmodel import Field, Relationship, SQLModel

@@ -50,6 +50,7 @@ from zenml.zen_stores.schemas.base_schemas import NamedSchema
 from zenml.zen_stores.schemas.constants import MODEL_VERSION_TABLENAME
 from zenml.zen_stores.schemas.pipeline_deployment_schemas import (
     PipelineDeploymentSchema,
+    StepConfigurationSchema,
 )
 from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema
 from zenml.zen_stores.schemas.project_schemas import ProjectSchema
@@ -187,6 +188,14 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True):
     original_step_run: Optional["StepRunSchema"] = Relationship(
         sa_relationship_kwargs={"remote_side": "StepRunSchema.id"}
     )
+    step_configuration_schema: Optional["StepConfigurationSchema"] = (
+        Relationship(
+            sa_relationship_kwargs=dict(
+                viewonly=True,
+                primaryjoin="and_(foreign(StepConfigurationSchema.name) == StepRunSchema.name, foreign(StepConfigurationSchema.deployment_id) == StepRunSchema.deployment_id)",
+            ),
+        )
+    )

     model_config = ConfigDict(protected_namespaces=())  # type: ignore[assignment]

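The new `step_configuration_schema` relationship has no foreign key backing it; it joins `StepConfigurationSchema` on the composite (`name`, `deployment_id`) pair, which is why the diff declares it `viewonly=True` with an explicit string `primaryjoin` using `foreign()` annotations. A generic sketch of that SQLAlchemy pattern with toy tables (not the real schemas):

```python
from typing import Optional
from sqlalchemy import create_engine, select
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Session,
    mapped_column,
    relationship,
)

class Base(DeclarativeBase):
    pass

class StepConfig(Base):
    __tablename__ = "step_config"
    id: Mapped[int] = mapped_column(primary_key=True)
    deployment_id: Mapped[int]
    name: Mapped[str]
    config: Mapped[str]

class StepRun(Base):
    __tablename__ = "step_run"
    id: Mapped[int] = mapped_column(primary_key=True)
    deployment_id: Mapped[int]
    name: Mapped[str]

    # No FK between the tables: foreign() marks the join columns, and
    # viewonly=True keeps the ORM from trying to persist the relation.
    config_row: Mapped[Optional[StepConfig]] = relationship(
        viewonly=True,
        primaryjoin=(
            "and_(foreign(StepConfig.name) == StepRun.name, "
            "foreign(StepConfig.deployment_id) == StepRun.deployment_id)"
        ),
    )

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all(
        [
            StepConfig(deployment_id=1, name="trainer", config="{}"),
            StepRun(deployment_id=1, name="trainer"),
        ]
    )
    session.commit()

    run = session.scalars(select(StepRun)).one()
    assert run.config_row is not None and run.config_row.config == "{}"
```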
@@ -209,17 +218,25 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True):
         Returns:
             A list of query options.
         """
-        from zenml.zen_stores.schemas import ModelVersionSchema
+        from zenml.zen_stores.schemas import (
+            ArtifactVersionSchema,
+            ModelVersionSchema,
+        )

         options = [
-            joinedload(jl_arg(StepRunSchema.deployment)),
-            joinedload(jl_arg(StepRunSchema.pipeline_run)),
+            selectinload(jl_arg(StepRunSchema.deployment)).load_only(
+                jl_arg(PipelineDeploymentSchema.pipeline_configuration)
+            ),
+            selectinload(jl_arg(StepRunSchema.pipeline_run)).load_only(
+                jl_arg(PipelineRunSchema.start_time)
+            ),
+            joinedload(jl_arg(StepRunSchema.step_configuration_schema)),
         ]

         if include_metadata:
             options.extend(
                 [
-                    joinedload(jl_arg(StepRunSchema.logs)),
+                    selectinload(jl_arg(StepRunSchema.logs)),
                     # joinedload(jl_arg(StepRunSchema.parents)),
                     # joinedload(jl_arg(StepRunSchema.run_metadata)),
                 ]
@@ -228,12 +245,28 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True):
         if include_resources:
             options.extend(
                 [
-                    joinedload(jl_arg(StepRunSchema.model_version)).joinedload(
+                    selectinload(
+                        jl_arg(StepRunSchema.model_version)
+                    ).joinedload(
                         jl_arg(ModelVersionSchema.model), innerjoin=True
                     ),
-                    joinedload(jl_arg(StepRunSchema.user)),
-                    joinedload(jl_arg(StepRunSchema.input_artifacts)),
-                    joinedload(jl_arg(StepRunSchema.output_artifacts)),
+                    selectinload(jl_arg(StepRunSchema.user)),
+                    selectinload(jl_arg(StepRunSchema.input_artifacts))
+                    .joinedload(
+                        jl_arg(StepRunInputArtifactSchema.artifact_version),
+                        innerjoin=True,
+                    )
+                    .joinedload(
+                        jl_arg(ArtifactVersionSchema.artifact), innerjoin=True
+                    ),
+                    selectinload(jl_arg(StepRunSchema.output_artifacts))
+                    .joinedload(
+                        jl_arg(StepRunOutputArtifactSchema.artifact_version),
+                        innerjoin=True,
+                    )
+                    .joinedload(
+                        jl_arg(ArtifactVersionSchema.artifact), innerjoin=True
+                    ),
                 ]
             )

@@ -290,10 +323,7 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True):
         """
         step = None
         if self.deployment is not None:
-
-                include=[self.name]
-            )
-            if step_configurations:
+            if self.step_configuration_schema:
                 pipeline_configuration = (
                     PipelineConfiguration.model_validate_json(
                         self.deployment.pipeline_configuration
@@ -304,7 +334,7 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True):
                     inplace=True,
                 )
                 step = Step.from_dict(
-                    json.loads(
+                    json.loads(self.step_configuration_schema.config),
                     pipeline_configuration=pipeline_configuration,
                 )
         if not step and self.step_configuration:
|