zenml-nightly 0.71.0.dev20241213__py3-none-any.whl → 0.71.0.dev20241214__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/client.py +44 -2
- zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +0 -1
- zenml/model/model.py +12 -16
- zenml/models/v2/base/filter.py +26 -30
- zenml/models/v2/base/scoped.py +258 -5
- zenml/models/v2/core/artifact_version.py +15 -26
- zenml/models/v2/core/code_repository.py +1 -12
- zenml/models/v2/core/component.py +5 -46
- zenml/models/v2/core/flavor.py +1 -11
- zenml/models/v2/core/model.py +1 -57
- zenml/models/v2/core/model_version.py +5 -33
- zenml/models/v2/core/model_version_artifact.py +11 -3
- zenml/models/v2/core/model_version_pipeline_run.py +14 -3
- zenml/models/v2/core/pipeline.py +47 -55
- zenml/models/v2/core/pipeline_build.py +19 -12
- zenml/models/v2/core/pipeline_deployment.py +0 -10
- zenml/models/v2/core/pipeline_run.py +91 -29
- zenml/models/v2/core/run_template.py +21 -29
- zenml/models/v2/core/schedule.py +0 -10
- zenml/models/v2/core/secret.py +0 -14
- zenml/models/v2/core/service.py +9 -16
- zenml/models/v2/core/service_connector.py +0 -11
- zenml/models/v2/core/stack.py +21 -30
- zenml/models/v2/core/step_run.py +18 -14
- zenml/models/v2/core/trigger.py +19 -3
- zenml/orchestrators/step_launcher.py +9 -13
- zenml/orchestrators/step_run_utils.py +8 -204
- zenml/zen_server/rbac/rbac_sql_zen_store.py +173 -0
- zenml/zen_server/utils.py +4 -3
- zenml/zen_stores/base_zen_store.py +10 -2
- zenml/zen_stores/migrations/versions/a1237ba94fd8_add_model_version_producer_run_unique_.py +68 -0
- zenml/zen_stores/schemas/model_schemas.py +42 -6
- zenml/zen_stores/schemas/pipeline_deployment_schemas.py +7 -7
- zenml/zen_stores/schemas/pipeline_schemas.py +5 -0
- zenml/zen_stores/sql_zen_store.py +322 -86
- {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/METADATA +1 -1
- {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/RECORD +41 -39
- {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/entry_points.txt +0 -0
zenml/models/v2/core/schedule.py
CHANGED
@@ -279,16 +279,6 @@ class ScheduleResponse(
|
|
279
279
|
class ScheduleFilter(WorkspaceScopedFilter):
|
280
280
|
"""Model to enable advanced filtering of all Users."""
|
281
281
|
|
282
|
-
workspace_id: Optional[Union[UUID, str]] = Field(
|
283
|
-
default=None,
|
284
|
-
description="Workspace scope of the schedule.",
|
285
|
-
union_mode="left_to_right",
|
286
|
-
)
|
287
|
-
user_id: Optional[Union[UUID, str]] = Field(
|
288
|
-
default=None,
|
289
|
-
description="User that created the schedule",
|
290
|
-
union_mode="left_to_right",
|
291
|
-
)
|
292
282
|
pipeline_id: Optional[Union[UUID, str]] = Field(
|
293
283
|
default=None,
|
294
284
|
description="Pipeline that the schedule is attached to.",
|
zenml/models/v2/core/secret.py
CHANGED
@@ -15,7 +15,6 @@
|
|
15
15
|
|
16
16
|
from datetime import datetime
|
17
17
|
from typing import Any, ClassVar, Dict, List, Optional, Union
|
18
|
-
from uuid import UUID
|
19
18
|
|
20
19
|
from pydantic import Field, SecretStr
|
21
20
|
|
@@ -253,25 +252,12 @@ class SecretFilter(WorkspaceScopedFilter):
|
|
253
252
|
default=None,
|
254
253
|
description="Name of the secret",
|
255
254
|
)
|
256
|
-
|
257
255
|
scope: Optional[Union[SecretScope, str]] = Field(
|
258
256
|
default=None,
|
259
257
|
description="Scope in which to filter secrets",
|
260
258
|
union_mode="left_to_right",
|
261
259
|
)
|
262
260
|
|
263
|
-
workspace_id: Optional[Union[UUID, str]] = Field(
|
264
|
-
default=None,
|
265
|
-
description="Workspace of the Secret",
|
266
|
-
union_mode="left_to_right",
|
267
|
-
)
|
268
|
-
|
269
|
-
user_id: Optional[Union[UUID, str]] = Field(
|
270
|
-
default=None,
|
271
|
-
description="User that created the Secret",
|
272
|
-
union_mode="left_to_right",
|
273
|
-
)
|
274
|
-
|
275
261
|
@staticmethod
|
276
262
|
def _get_filtering_value(value: Optional[Any]) -> str:
|
277
263
|
"""Convert the value to a string that can be used for lexicographical filtering and sorting.
|
zenml/models/v2/core/service.py
CHANGED
@@ -15,19 +15,20 @@
|
|
15
15
|
|
16
16
|
from datetime import datetime
|
17
17
|
from typing import (
|
18
|
+
TYPE_CHECKING,
|
18
19
|
Any,
|
19
20
|
ClassVar,
|
20
21
|
Dict,
|
21
22
|
List,
|
22
23
|
Optional,
|
23
24
|
Type,
|
25
|
+
TypeVar,
|
24
26
|
Union,
|
25
27
|
)
|
26
28
|
from uuid import UUID
|
27
29
|
|
28
30
|
from pydantic import BaseModel, ConfigDict, Field
|
29
31
|
from sqlalchemy.sql.elements import ColumnElement
|
30
|
-
from sqlmodel import SQLModel
|
31
32
|
|
32
33
|
from zenml.constants import STR_FIELD_MAX_LENGTH
|
33
34
|
from zenml.models.v2.base.scoped import (
|
@@ -37,11 +38,15 @@ from zenml.models.v2.base.scoped import (
|
|
37
38
|
WorkspaceScopedResponseBody,
|
38
39
|
WorkspaceScopedResponseMetadata,
|
39
40
|
WorkspaceScopedResponseResources,
|
40
|
-
WorkspaceScopedTaggableFilter,
|
41
41
|
)
|
42
42
|
from zenml.services.service_status import ServiceState
|
43
43
|
from zenml.services.service_type import ServiceType
|
44
44
|
|
45
|
+
if TYPE_CHECKING:
|
46
|
+
from zenml.zen_stores.schemas import BaseSchema
|
47
|
+
|
48
|
+
AnySchema = TypeVar("AnySchema", bound=BaseSchema)
|
49
|
+
|
45
50
|
# ------------------ Request Model ------------------
|
46
51
|
|
47
52
|
|
@@ -376,16 +381,6 @@ class ServiceFilter(WorkspaceScopedFilter):
|
|
376
381
|
description="Name of the service. Use this to filter services by "
|
377
382
|
"their name.",
|
378
383
|
)
|
379
|
-
workspace_id: Optional[Union[UUID, str]] = Field(
|
380
|
-
default=None,
|
381
|
-
description="Workspace of the service",
|
382
|
-
union_mode="left_to_right",
|
383
|
-
)
|
384
|
-
user_id: Optional[Union[UUID, str]] = Field(
|
385
|
-
default=None,
|
386
|
-
description="User of the service",
|
387
|
-
union_mode="left_to_right",
|
388
|
-
)
|
389
384
|
type: Optional[str] = Field(
|
390
385
|
default=None,
|
391
386
|
description="Type of the service. Filter services by their type.",
|
@@ -457,9 +452,7 @@ class ServiceFilter(WorkspaceScopedFilter):
|
|
457
452
|
"config",
|
458
453
|
]
|
459
454
|
CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [
|
460
|
-
*
|
461
|
-
"workspace_id",
|
462
|
-
"user_id",
|
455
|
+
*WorkspaceScopedFilter.CLI_EXCLUDE_FIELDS,
|
463
456
|
"flavor",
|
464
457
|
"type",
|
465
458
|
"pipeline_step_name",
|
@@ -468,7 +461,7 @@ class ServiceFilter(WorkspaceScopedFilter):
|
|
468
461
|
]
|
469
462
|
|
470
463
|
def generate_filter(
|
471
|
-
self, table: Type["
|
464
|
+
self, table: Type["AnySchema"]
|
472
465
|
) -> Union["ColumnElement[bool]"]:
|
473
466
|
"""Generate the filter for the query.
|
474
467
|
|
@@ -801,7 +801,6 @@ class ServiceConnectorFilter(WorkspaceScopedFilter):
|
|
801
801
|
default=None,
|
802
802
|
description="The type to scope this query to.",
|
803
803
|
)
|
804
|
-
|
805
804
|
name: Optional[str] = Field(
|
806
805
|
default=None,
|
807
806
|
description="The name to filter by",
|
@@ -810,16 +809,6 @@ class ServiceConnectorFilter(WorkspaceScopedFilter):
|
|
810
809
|
default=None,
|
811
810
|
description="The type of service connector to filter by",
|
812
811
|
)
|
813
|
-
workspace_id: Optional[Union[UUID, str]] = Field(
|
814
|
-
default=None,
|
815
|
-
description="Workspace to filter by",
|
816
|
-
union_mode="left_to_right",
|
817
|
-
)
|
818
|
-
user_id: Optional[Union[UUID, str]] = Field(
|
819
|
-
default=None,
|
820
|
-
description="User to filter by",
|
821
|
-
union_mode="left_to_right",
|
822
|
-
)
|
823
812
|
auth_method: Optional[str] = Field(
|
824
813
|
default=None,
|
825
814
|
title="Filter by the authentication method configured for the "
|
zenml/models/v2/core/stack.py
CHANGED
@@ -14,7 +14,17 @@
|
|
14
14
|
"""Models representing stacks."""
|
15
15
|
|
16
16
|
import json
|
17
|
-
from typing import
|
17
|
+
from typing import (
|
18
|
+
TYPE_CHECKING,
|
19
|
+
Any,
|
20
|
+
ClassVar,
|
21
|
+
Dict,
|
22
|
+
List,
|
23
|
+
Optional,
|
24
|
+
Type,
|
25
|
+
TypeVar,
|
26
|
+
Union,
|
27
|
+
)
|
18
28
|
from uuid import UUID
|
19
29
|
|
20
30
|
from pydantic import Field, model_validator
|
@@ -39,6 +49,9 @@ if TYPE_CHECKING:
|
|
39
49
|
from sqlalchemy.sql.elements import ColumnElement
|
40
50
|
|
41
51
|
from zenml.models.v2.core.component import ComponentResponse
|
52
|
+
from zenml.zen_stores.schemas import BaseSchema
|
53
|
+
|
54
|
+
AnySchema = TypeVar("AnySchema", bound=BaseSchema)
|
42
55
|
|
43
56
|
|
44
57
|
# ------------------ Request Model ------------------
|
@@ -323,7 +336,6 @@ class StackFilter(WorkspaceScopedFilter):
|
|
323
336
|
FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [
|
324
337
|
*WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS,
|
325
338
|
"component_id",
|
326
|
-
"user",
|
327
339
|
"component",
|
328
340
|
]
|
329
341
|
|
@@ -334,42 +346,32 @@ class StackFilter(WorkspaceScopedFilter):
|
|
334
346
|
description: Optional[str] = Field(
|
335
347
|
default=None, description="Description of the stack"
|
336
348
|
)
|
337
|
-
workspace_id: Optional[Union[UUID, str]] = Field(
|
338
|
-
default=None,
|
339
|
-
description="Workspace of the stack",
|
340
|
-
union_mode="left_to_right",
|
341
|
-
)
|
342
|
-
user_id: Optional[Union[UUID, str]] = Field(
|
343
|
-
default=None,
|
344
|
-
description="User of the stack",
|
345
|
-
union_mode="left_to_right",
|
346
|
-
)
|
347
349
|
component_id: Optional[Union[UUID, str]] = Field(
|
348
350
|
default=None,
|
349
351
|
description="Component in the stack",
|
350
352
|
union_mode="left_to_right",
|
351
353
|
)
|
352
|
-
user: Optional[Union[UUID, str]] = Field(
|
353
|
-
default=None,
|
354
|
-
description="Name/ID of the user that created the stack.",
|
355
|
-
)
|
356
354
|
component: Optional[Union[UUID, str]] = Field(
|
357
355
|
default=None, description="Name/ID of a component in the stack."
|
358
356
|
)
|
359
357
|
|
360
|
-
def get_custom_filters(
|
358
|
+
def get_custom_filters(
|
359
|
+
self, table: Type["AnySchema"]
|
360
|
+
) -> List["ColumnElement[bool]"]:
|
361
361
|
"""Get custom filters.
|
362
362
|
|
363
|
+
Args:
|
364
|
+
table: The query table.
|
365
|
+
|
363
366
|
Returns:
|
364
367
|
A list of custom filters.
|
365
368
|
"""
|
366
|
-
custom_filters = super().get_custom_filters()
|
369
|
+
custom_filters = super().get_custom_filters(table)
|
367
370
|
|
368
371
|
from zenml.zen_stores.schemas import (
|
369
372
|
StackComponentSchema,
|
370
373
|
StackCompositionSchema,
|
371
374
|
StackSchema,
|
372
|
-
UserSchema,
|
373
375
|
)
|
374
376
|
|
375
377
|
if self.component_id:
|
@@ -379,17 +381,6 @@ class StackFilter(WorkspaceScopedFilter):
|
|
379
381
|
)
|
380
382
|
custom_filters.append(component_id_filter)
|
381
383
|
|
382
|
-
if self.user:
|
383
|
-
user_filter = and_(
|
384
|
-
StackSchema.user_id == UserSchema.id,
|
385
|
-
self.generate_name_or_id_query_conditions(
|
386
|
-
value=self.user,
|
387
|
-
table=UserSchema,
|
388
|
-
additional_columns=["full_name"],
|
389
|
-
),
|
390
|
-
)
|
391
|
-
custom_filters.append(user_filter)
|
392
|
-
|
393
384
|
if self.component:
|
394
385
|
component_filter = and_(
|
395
386
|
StackCompositionSchema.stack_id == StackSchema.id,
|
zenml/models/v2/core/step_run.py
CHANGED
@@ -14,7 +14,16 @@
|
|
14
14
|
"""Models representing steps runs."""
|
15
15
|
|
16
16
|
from datetime import datetime
|
17
|
-
from typing import
|
17
|
+
from typing import (
|
18
|
+
TYPE_CHECKING,
|
19
|
+
ClassVar,
|
20
|
+
Dict,
|
21
|
+
List,
|
22
|
+
Optional,
|
23
|
+
Type,
|
24
|
+
TypeVar,
|
25
|
+
Union,
|
26
|
+
)
|
18
27
|
from uuid import UUID
|
19
28
|
|
20
29
|
from pydantic import BaseModel, ConfigDict, Field
|
@@ -41,6 +50,9 @@ if TYPE_CHECKING:
|
|
41
50
|
LogsRequest,
|
42
51
|
LogsResponse,
|
43
52
|
)
|
53
|
+
from zenml.zen_stores.schemas import BaseSchema
|
54
|
+
|
55
|
+
AnySchema = TypeVar("AnySchema", bound=BaseSchema)
|
44
56
|
|
45
57
|
|
46
58
|
class StepRunInputResponse(ArtifactVersionResponse):
|
@@ -553,16 +565,6 @@ class StepRunFilter(WorkspaceScopedFilter):
|
|
553
565
|
description="Original id for this step run",
|
554
566
|
union_mode="left_to_right",
|
555
567
|
)
|
556
|
-
user_id: Optional[Union[UUID, str]] = Field(
|
557
|
-
default=None,
|
558
|
-
description="User that produced this step run",
|
559
|
-
union_mode="left_to_right",
|
560
|
-
)
|
561
|
-
workspace_id: Optional[Union[UUID, str]] = Field(
|
562
|
-
default=None,
|
563
|
-
description="Workspace of this step run",
|
564
|
-
union_mode="left_to_right",
|
565
|
-
)
|
566
568
|
model_version_id: Optional[Union[UUID, str]] = Field(
|
567
569
|
default=None,
|
568
570
|
description="Model version associated with the step run.",
|
@@ -576,18 +578,20 @@ class StepRunFilter(WorkspaceScopedFilter):
|
|
576
578
|
default=None,
|
577
579
|
description="The run_metadata to filter the step runs by.",
|
578
580
|
)
|
579
|
-
|
580
581
|
model_config = ConfigDict(protected_namespaces=())
|
581
582
|
|
582
583
|
def get_custom_filters(
|
583
|
-
self,
|
584
|
+
self, table: Type["AnySchema"]
|
584
585
|
) -> List["ColumnElement[bool]"]:
|
585
586
|
"""Get custom filters.
|
586
587
|
|
588
|
+
Args:
|
589
|
+
table: The query table.
|
590
|
+
|
587
591
|
Returns:
|
588
592
|
A list of custom filters.
|
589
593
|
"""
|
590
|
-
custom_filters = super().get_custom_filters()
|
594
|
+
custom_filters = super().get_custom_filters(table)
|
591
595
|
|
592
596
|
from sqlmodel import and_
|
593
597
|
|
zenml/models/v2/core/trigger.py
CHANGED
@@ -13,7 +13,17 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Collection of all models concerning triggers."""
|
15
15
|
|
16
|
-
from typing import
|
16
|
+
from typing import (
|
17
|
+
TYPE_CHECKING,
|
18
|
+
Any,
|
19
|
+
ClassVar,
|
20
|
+
Dict,
|
21
|
+
List,
|
22
|
+
Optional,
|
23
|
+
Type,
|
24
|
+
TypeVar,
|
25
|
+
Union,
|
26
|
+
)
|
17
27
|
from uuid import UUID
|
18
28
|
|
19
29
|
from pydantic import Field, model_validator
|
@@ -39,6 +49,9 @@ if TYPE_CHECKING:
|
|
39
49
|
ActionResponse,
|
40
50
|
)
|
41
51
|
from zenml.models.v2.core.event_source import EventSourceResponse
|
52
|
+
from zenml.zen_stores.schemas import BaseSchema
|
53
|
+
|
54
|
+
AnySchema = TypeVar("AnySchema", bound=BaseSchema)
|
42
55
|
|
43
56
|
|
44
57
|
# ------------------ Request Model ------------------
|
@@ -358,10 +371,13 @@ class TriggerFilter(WorkspaceScopedFilter):
|
|
358
371
|
)
|
359
372
|
|
360
373
|
def get_custom_filters(
|
361
|
-
self,
|
374
|
+
self, table: Type["AnySchema"]
|
362
375
|
) -> List["ColumnElement[bool]"]:
|
363
376
|
"""Get custom filters.
|
364
377
|
|
378
|
+
Args:
|
379
|
+
table: The query table.
|
380
|
+
|
365
381
|
Returns:
|
366
382
|
A list of custom filters.
|
367
383
|
"""
|
@@ -373,7 +389,7 @@ class TriggerFilter(WorkspaceScopedFilter):
|
|
373
389
|
TriggerSchema,
|
374
390
|
)
|
375
391
|
|
376
|
-
custom_filters = super().get_custom_filters()
|
392
|
+
custom_filters = super().get_custom_filters(table)
|
377
393
|
|
378
394
|
if self.event_source_flavor:
|
379
395
|
event_source_flavor_filter = and_(
|
@@ -179,12 +179,10 @@ class StepLauncher:
|
|
179
179
|
pipeline_run_id=pipeline_run.id,
|
180
180
|
pipeline_run_metadata=pipeline_run_metadata,
|
181
181
|
)
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
)
|
187
|
-
)
|
182
|
+
if model_version := pipeline_run.model_version:
|
183
|
+
step_run_utils.log_model_version_dashboard_url(
|
184
|
+
model_version=model_version
|
185
|
+
)
|
188
186
|
|
189
187
|
request_factory = step_run_utils.StepRunRequestFactory(
|
190
188
|
deployment=self._deployment,
|
@@ -209,12 +207,10 @@ class StepLauncher:
|
|
209
207
|
step_run = Client().zen_store.create_run_step(
|
210
208
|
step_run_request
|
211
209
|
)
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
step_run=step_run, pipeline_run=pipeline_run
|
210
|
+
if model_version := step_run.model_version:
|
211
|
+
step_run_utils.log_model_version_dashboard_url(
|
212
|
+
model_version=model_version
|
216
213
|
)
|
217
|
-
)
|
218
214
|
|
219
215
|
if not step_run.status.is_finished:
|
220
216
|
logger.info(f"Step `{self._step_name}` has started.")
|
@@ -289,8 +285,8 @@ class StepLauncher:
|
|
289
285
|
f"Using cached version of step `{self._step_name}`."
|
290
286
|
)
|
291
287
|
if (
|
292
|
-
model_version :=
|
293
|
-
or
|
288
|
+
model_version := step_run.model_version
|
289
|
+
or pipeline_run.model_version
|
294
290
|
):
|
295
291
|
step_run_utils.link_output_artifacts_to_model_version(
|
296
292
|
artifacts=step_run.outputs,
|
@@ -14,7 +14,7 @@
|
|
14
14
|
"""Utilities for creating step runs."""
|
15
15
|
|
16
16
|
from datetime import datetime
|
17
|
-
from typing import
|
17
|
+
from typing import Dict, List, Optional, Set, Tuple
|
18
18
|
|
19
19
|
from zenml.client import Client
|
20
20
|
from zenml.config.step_configurations import Step
|
@@ -24,21 +24,13 @@ from zenml.logger import get_logger
|
|
24
24
|
from zenml.model.utils import link_artifact_version_to_model_version
|
25
25
|
from zenml.models import (
|
26
26
|
ArtifactVersionResponse,
|
27
|
-
ModelVersionPipelineRunRequest,
|
28
27
|
ModelVersionResponse,
|
29
28
|
PipelineDeploymentResponse,
|
30
29
|
PipelineRunResponse,
|
31
|
-
PipelineRunUpdate,
|
32
30
|
StepRunRequest,
|
33
|
-
StepRunResponse,
|
34
|
-
StepRunUpdate,
|
35
31
|
)
|
36
32
|
from zenml.orchestrators import cache_utils, input_utils, utils
|
37
33
|
from zenml.stack import Stack
|
38
|
-
from zenml.utils import pagination_utils, string_utils
|
39
|
-
|
40
|
-
if TYPE_CHECKING:
|
41
|
-
from zenml.model.model import Model
|
42
34
|
|
43
35
|
logger = get_logger(__name__)
|
44
36
|
|
@@ -293,10 +285,6 @@ def create_cached_step_runs(
|
|
293
285
|
deployment=deployment, pipeline_run=pipeline_run, stack=stack
|
294
286
|
)
|
295
287
|
|
296
|
-
pipeline_model_version, pipeline_run = prepare_pipeline_run_model_version(
|
297
|
-
pipeline_run=pipeline_run
|
298
|
-
)
|
299
|
-
|
300
288
|
while (
|
301
289
|
cache_candidates := find_cacheable_invocation_candidates(
|
302
290
|
deployment=deployment,
|
@@ -311,7 +299,9 @@ def create_cached_step_runs(
|
|
311
299
|
|
312
300
|
# Make sure the request factory has the most up to date pipeline
|
313
301
|
# run to avoid hydration calls
|
314
|
-
request_factory.pipeline_run =
|
302
|
+
request_factory.pipeline_run = Client().get_pipeline_run(
|
303
|
+
pipeline_run.id
|
304
|
+
)
|
315
305
|
try:
|
316
306
|
step_run_request = request_factory.create_request(
|
317
307
|
invocation_id
|
@@ -336,15 +326,10 @@ def create_cached_step_runs(
|
|
336
326
|
|
337
327
|
step_run = Client().zen_store.create_run_step(step_run_request)
|
338
328
|
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
step_model_version, step_run = prepare_step_run_model_version(
|
344
|
-
step_run=step_run, pipeline_run=pipeline_run
|
345
|
-
)
|
346
|
-
|
347
|
-
if model_version := step_model_version or pipeline_model_version:
|
329
|
+
if (
|
330
|
+
model_version := step_run.model_version
|
331
|
+
or pipeline_run.model_version
|
332
|
+
):
|
348
333
|
link_output_artifacts_to_model_version(
|
349
334
|
artifacts=step_run.outputs,
|
350
335
|
model_version=model_version,
|
@@ -356,169 +341,6 @@ def create_cached_step_runs(
|
|
356
341
|
return cached_invocations
|
357
342
|
|
358
343
|
|
359
|
-
def get_or_create_model_version_for_pipeline_run(
|
360
|
-
model: "Model",
|
361
|
-
pipeline_run: PipelineRunResponse,
|
362
|
-
substitutions: Dict[str, str],
|
363
|
-
) -> Tuple[ModelVersionResponse, bool]:
|
364
|
-
"""Get or create a model version as part of a pipeline run.
|
365
|
-
|
366
|
-
Args:
|
367
|
-
model: The model to get or create.
|
368
|
-
pipeline_run: The pipeline run for which the model should be created.
|
369
|
-
substitutions: Substitutions to apply to the model version name.
|
370
|
-
|
371
|
-
Returns:
|
372
|
-
The model version and a boolean indicating whether it was newly created
|
373
|
-
or not.
|
374
|
-
"""
|
375
|
-
# Copy the model before modifying it so we don't accidently modify
|
376
|
-
# configurations in which the model object is potentially referenced
|
377
|
-
model = model.model_copy()
|
378
|
-
|
379
|
-
if model.model_version_id:
|
380
|
-
return model._get_model_version(), False
|
381
|
-
elif model.version:
|
382
|
-
if isinstance(model.version, str):
|
383
|
-
model.version = string_utils.format_name_template(
|
384
|
-
model.version,
|
385
|
-
substitutions=substitutions,
|
386
|
-
)
|
387
|
-
model.name = string_utils.format_name_template(
|
388
|
-
model.name,
|
389
|
-
substitutions=substitutions,
|
390
|
-
)
|
391
|
-
|
392
|
-
return (
|
393
|
-
model._get_or_create_model_version(),
|
394
|
-
model._created_model_version,
|
395
|
-
)
|
396
|
-
|
397
|
-
# The model version should be created as part of this run
|
398
|
-
# -> We first check if it was already created as part of this run, and if
|
399
|
-
# not we do create it. If this is running in two parallel steps, we might
|
400
|
-
# run into issues that this will create two versions. Ideally, all model
|
401
|
-
# versions required for a pipeline run and its steps could be created
|
402
|
-
# server-side at run creation time before the first step starts.
|
403
|
-
if model_version := get_model_version_created_by_pipeline_run(
|
404
|
-
model_name=model.name, pipeline_run=pipeline_run
|
405
|
-
):
|
406
|
-
return model_version, False
|
407
|
-
else:
|
408
|
-
return model._get_or_create_model_version(), True
|
409
|
-
|
410
|
-
|
411
|
-
def get_model_version_created_by_pipeline_run(
|
412
|
-
model_name: str, pipeline_run: PipelineRunResponse
|
413
|
-
) -> Optional[ModelVersionResponse]:
|
414
|
-
"""Get a model version that was created by a specific pipeline run.
|
415
|
-
|
416
|
-
This function does not refresh the pipeline run, so it will only try to
|
417
|
-
fetch the model version from existing steps if they're already part of the
|
418
|
-
response.
|
419
|
-
|
420
|
-
Args:
|
421
|
-
model_name: The model name for which to get the version.
|
422
|
-
pipeline_run: The pipeline run for which to get the version.
|
423
|
-
|
424
|
-
Returns:
|
425
|
-
A model version with the given name created by the run, or None if such
|
426
|
-
a model version does not exist.
|
427
|
-
"""
|
428
|
-
if pipeline_run.config.model and pipeline_run.model_version:
|
429
|
-
if (
|
430
|
-
pipeline_run.config.model.name == model_name
|
431
|
-
and pipeline_run.config.model.version is None
|
432
|
-
):
|
433
|
-
return pipeline_run.model_version
|
434
|
-
|
435
|
-
# We fetch a list of hydrated step runs here in order to avoid hydration
|
436
|
-
# calls for each step separately.
|
437
|
-
candidate_step_runs = pagination_utils.depaginate(
|
438
|
-
Client().list_run_steps,
|
439
|
-
pipeline_run_id=pipeline_run.id,
|
440
|
-
model=model_name,
|
441
|
-
hydrate=True,
|
442
|
-
)
|
443
|
-
for step_run in candidate_step_runs:
|
444
|
-
if step_run.config.model and step_run.model_version:
|
445
|
-
if (
|
446
|
-
step_run.config.model.name == model_name
|
447
|
-
and step_run.config.model.version is None
|
448
|
-
):
|
449
|
-
return step_run.model_version
|
450
|
-
|
451
|
-
return None
|
452
|
-
|
453
|
-
|
454
|
-
def prepare_pipeline_run_model_version(
|
455
|
-
pipeline_run: PipelineRunResponse,
|
456
|
-
) -> Tuple[Optional[ModelVersionResponse], PipelineRunResponse]:
|
457
|
-
"""Prepare the model version for a pipeline run.
|
458
|
-
|
459
|
-
Args:
|
460
|
-
pipeline_run: The pipeline run for which to prepare the model version.
|
461
|
-
|
462
|
-
Returns:
|
463
|
-
The prepared model version and the updated pipeline run.
|
464
|
-
"""
|
465
|
-
model_version = None
|
466
|
-
|
467
|
-
if pipeline_run.model_version:
|
468
|
-
model_version = pipeline_run.model_version
|
469
|
-
elif config_model := pipeline_run.config.model:
|
470
|
-
model_version, _ = get_or_create_model_version_for_pipeline_run(
|
471
|
-
model=config_model,
|
472
|
-
pipeline_run=pipeline_run,
|
473
|
-
substitutions=pipeline_run.config.substitutions,
|
474
|
-
)
|
475
|
-
pipeline_run = Client().zen_store.update_run(
|
476
|
-
run_id=pipeline_run.id,
|
477
|
-
run_update=PipelineRunUpdate(model_version_id=model_version.id),
|
478
|
-
)
|
479
|
-
link_pipeline_run_to_model_version(
|
480
|
-
pipeline_run=pipeline_run, model_version=model_version
|
481
|
-
)
|
482
|
-
log_model_version_dashboard_url(model_version)
|
483
|
-
|
484
|
-
return model_version, pipeline_run
|
485
|
-
|
486
|
-
|
487
|
-
def prepare_step_run_model_version(
|
488
|
-
step_run: StepRunResponse, pipeline_run: PipelineRunResponse
|
489
|
-
) -> Tuple[Optional[ModelVersionResponse], StepRunResponse]:
|
490
|
-
"""Prepare the model version for a step run.
|
491
|
-
|
492
|
-
Args:
|
493
|
-
step_run: The step run for which to prepare the model version.
|
494
|
-
pipeline_run: The pipeline run of the step.
|
495
|
-
|
496
|
-
Returns:
|
497
|
-
The prepared model version and the updated step run.
|
498
|
-
"""
|
499
|
-
model_version = None
|
500
|
-
|
501
|
-
if step_run.model_version:
|
502
|
-
model_version = step_run.model_version
|
503
|
-
elif config_model := step_run.config.model:
|
504
|
-
model_version, created = get_or_create_model_version_for_pipeline_run(
|
505
|
-
model=config_model,
|
506
|
-
pipeline_run=pipeline_run,
|
507
|
-
substitutions=step_run.config.substitutions,
|
508
|
-
)
|
509
|
-
step_run = Client().zen_store.update_run_step(
|
510
|
-
step_run_id=step_run.id,
|
511
|
-
step_run_update=StepRunUpdate(model_version_id=model_version.id),
|
512
|
-
)
|
513
|
-
link_pipeline_run_to_model_version(
|
514
|
-
pipeline_run=pipeline_run, model_version=model_version
|
515
|
-
)
|
516
|
-
if created:
|
517
|
-
log_model_version_dashboard_url(model_version)
|
518
|
-
|
519
|
-
return model_version, step_run
|
520
|
-
|
521
|
-
|
522
344
|
def log_model_version_dashboard_url(
|
523
345
|
model_version: ModelVersionResponse,
|
524
346
|
) -> None:
|
@@ -546,24 +368,6 @@ def log_model_version_dashboard_url(
|
|
546
368
|
)
|
547
369
|
|
548
370
|
|
549
|
-
def link_pipeline_run_to_model_version(
|
550
|
-
pipeline_run: PipelineRunResponse, model_version: ModelVersionResponse
|
551
|
-
) -> None:
|
552
|
-
"""Link a pipeline run to a model version.
|
553
|
-
|
554
|
-
Args:
|
555
|
-
pipeline_run: The pipeline run to link.
|
556
|
-
model_version: The model version to link.
|
557
|
-
"""
|
558
|
-
client = Client()
|
559
|
-
client.zen_store.create_model_version_pipeline_run_link(
|
560
|
-
ModelVersionPipelineRunRequest(
|
561
|
-
pipeline_run=pipeline_run.id,
|
562
|
-
model_version=model_version.id,
|
563
|
-
)
|
564
|
-
)
|
565
|
-
|
566
|
-
|
567
371
|
def link_output_artifacts_to_model_version(
|
568
372
|
artifacts: Dict[str, List[ArtifactVersionResponse]],
|
569
373
|
model_version: ModelVersionResponse,
|