zenml-nightly 0.71.0.dev20241213__py3-none-any.whl → 0.71.0.dev20241214__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. zenml/VERSION +1 -1
  2. zenml/client.py +44 -2
  3. zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +0 -1
  4. zenml/model/model.py +12 -16
  5. zenml/models/v2/base/filter.py +26 -30
  6. zenml/models/v2/base/scoped.py +258 -5
  7. zenml/models/v2/core/artifact_version.py +15 -26
  8. zenml/models/v2/core/code_repository.py +1 -12
  9. zenml/models/v2/core/component.py +5 -46
  10. zenml/models/v2/core/flavor.py +1 -11
  11. zenml/models/v2/core/model.py +1 -57
  12. zenml/models/v2/core/model_version.py +5 -33
  13. zenml/models/v2/core/model_version_artifact.py +11 -3
  14. zenml/models/v2/core/model_version_pipeline_run.py +14 -3
  15. zenml/models/v2/core/pipeline.py +47 -55
  16. zenml/models/v2/core/pipeline_build.py +19 -12
  17. zenml/models/v2/core/pipeline_deployment.py +0 -10
  18. zenml/models/v2/core/pipeline_run.py +91 -29
  19. zenml/models/v2/core/run_template.py +21 -29
  20. zenml/models/v2/core/schedule.py +0 -10
  21. zenml/models/v2/core/secret.py +0 -14
  22. zenml/models/v2/core/service.py +9 -16
  23. zenml/models/v2/core/service_connector.py +0 -11
  24. zenml/models/v2/core/stack.py +21 -30
  25. zenml/models/v2/core/step_run.py +18 -14
  26. zenml/models/v2/core/trigger.py +19 -3
  27. zenml/orchestrators/step_launcher.py +9 -13
  28. zenml/orchestrators/step_run_utils.py +8 -204
  29. zenml/zen_server/rbac/rbac_sql_zen_store.py +173 -0
  30. zenml/zen_server/utils.py +4 -3
  31. zenml/zen_stores/base_zen_store.py +10 -2
  32. zenml/zen_stores/migrations/versions/a1237ba94fd8_add_model_version_producer_run_unique_.py +68 -0
  33. zenml/zen_stores/schemas/model_schemas.py +42 -6
  34. zenml/zen_stores/schemas/pipeline_deployment_schemas.py +7 -7
  35. zenml/zen_stores/schemas/pipeline_schemas.py +5 -0
  36. zenml/zen_stores/sql_zen_store.py +322 -86
  37. {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/METADATA +1 -1
  38. {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/RECORD +41 -39
  39. {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/LICENSE +0 -0
  40. {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/WHEEL +0 -0
  41. {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/entry_points.txt +0 -0
@@ -279,16 +279,6 @@ class ScheduleResponse(
279
279
  class ScheduleFilter(WorkspaceScopedFilter):
280
280
  """Model to enable advanced filtering of all Users."""
281
281
 
282
- workspace_id: Optional[Union[UUID, str]] = Field(
283
- default=None,
284
- description="Workspace scope of the schedule.",
285
- union_mode="left_to_right",
286
- )
287
- user_id: Optional[Union[UUID, str]] = Field(
288
- default=None,
289
- description="User that created the schedule",
290
- union_mode="left_to_right",
291
- )
292
282
  pipeline_id: Optional[Union[UUID, str]] = Field(
293
283
  default=None,
294
284
  description="Pipeline that the schedule is attached to.",
@@ -15,7 +15,6 @@
15
15
 
16
16
  from datetime import datetime
17
17
  from typing import Any, ClassVar, Dict, List, Optional, Union
18
- from uuid import UUID
19
18
 
20
19
  from pydantic import Field, SecretStr
21
20
 
@@ -253,25 +252,12 @@ class SecretFilter(WorkspaceScopedFilter):
253
252
  default=None,
254
253
  description="Name of the secret",
255
254
  )
256
-
257
255
  scope: Optional[Union[SecretScope, str]] = Field(
258
256
  default=None,
259
257
  description="Scope in which to filter secrets",
260
258
  union_mode="left_to_right",
261
259
  )
262
260
 
263
- workspace_id: Optional[Union[UUID, str]] = Field(
264
- default=None,
265
- description="Workspace of the Secret",
266
- union_mode="left_to_right",
267
- )
268
-
269
- user_id: Optional[Union[UUID, str]] = Field(
270
- default=None,
271
- description="User that created the Secret",
272
- union_mode="left_to_right",
273
- )
274
-
275
261
  @staticmethod
276
262
  def _get_filtering_value(value: Optional[Any]) -> str:
277
263
  """Convert the value to a string that can be used for lexicographical filtering and sorting.
@@ -15,19 +15,20 @@
15
15
 
16
16
  from datetime import datetime
17
17
  from typing import (
18
+ TYPE_CHECKING,
18
19
  Any,
19
20
  ClassVar,
20
21
  Dict,
21
22
  List,
22
23
  Optional,
23
24
  Type,
25
+ TypeVar,
24
26
  Union,
25
27
  )
26
28
  from uuid import UUID
27
29
 
28
30
  from pydantic import BaseModel, ConfigDict, Field
29
31
  from sqlalchemy.sql.elements import ColumnElement
30
- from sqlmodel import SQLModel
31
32
 
32
33
  from zenml.constants import STR_FIELD_MAX_LENGTH
33
34
  from zenml.models.v2.base.scoped import (
@@ -37,11 +38,15 @@ from zenml.models.v2.base.scoped import (
37
38
  WorkspaceScopedResponseBody,
38
39
  WorkspaceScopedResponseMetadata,
39
40
  WorkspaceScopedResponseResources,
40
- WorkspaceScopedTaggableFilter,
41
41
  )
42
42
  from zenml.services.service_status import ServiceState
43
43
  from zenml.services.service_type import ServiceType
44
44
 
45
+ if TYPE_CHECKING:
46
+ from zenml.zen_stores.schemas import BaseSchema
47
+
48
+ AnySchema = TypeVar("AnySchema", bound=BaseSchema)
49
+
45
50
  # ------------------ Request Model ------------------
46
51
 
47
52
 
@@ -376,16 +381,6 @@ class ServiceFilter(WorkspaceScopedFilter):
376
381
  description="Name of the service. Use this to filter services by "
377
382
  "their name.",
378
383
  )
379
- workspace_id: Optional[Union[UUID, str]] = Field(
380
- default=None,
381
- description="Workspace of the service",
382
- union_mode="left_to_right",
383
- )
384
- user_id: Optional[Union[UUID, str]] = Field(
385
- default=None,
386
- description="User of the service",
387
- union_mode="left_to_right",
388
- )
389
384
  type: Optional[str] = Field(
390
385
  default=None,
391
386
  description="Type of the service. Filter services by their type.",
@@ -457,9 +452,7 @@ class ServiceFilter(WorkspaceScopedFilter):
457
452
  "config",
458
453
  ]
459
454
  CLI_EXCLUDE_FIELDS: ClassVar[List[str]] = [
460
- *WorkspaceScopedTaggableFilter.CLI_EXCLUDE_FIELDS,
461
- "workspace_id",
462
- "user_id",
455
+ *WorkspaceScopedFilter.CLI_EXCLUDE_FIELDS,
463
456
  "flavor",
464
457
  "type",
465
458
  "pipeline_step_name",
@@ -468,7 +461,7 @@ class ServiceFilter(WorkspaceScopedFilter):
468
461
  ]
469
462
 
470
463
  def generate_filter(
471
- self, table: Type["SQLModel"]
464
+ self, table: Type["AnySchema"]
472
465
  ) -> Union["ColumnElement[bool]"]:
473
466
  """Generate the filter for the query.
474
467
 
@@ -801,7 +801,6 @@ class ServiceConnectorFilter(WorkspaceScopedFilter):
801
801
  default=None,
802
802
  description="The type to scope this query to.",
803
803
  )
804
-
805
804
  name: Optional[str] = Field(
806
805
  default=None,
807
806
  description="The name to filter by",
@@ -810,16 +809,6 @@ class ServiceConnectorFilter(WorkspaceScopedFilter):
810
809
  default=None,
811
810
  description="The type of service connector to filter by",
812
811
  )
813
- workspace_id: Optional[Union[UUID, str]] = Field(
814
- default=None,
815
- description="Workspace to filter by",
816
- union_mode="left_to_right",
817
- )
818
- user_id: Optional[Union[UUID, str]] = Field(
819
- default=None,
820
- description="User to filter by",
821
- union_mode="left_to_right",
822
- )
823
812
  auth_method: Optional[str] = Field(
824
813
  default=None,
825
814
  title="Filter by the authentication method configured for the "
@@ -14,7 +14,17 @@
14
14
  """Models representing stacks."""
15
15
 
16
16
  import json
17
- from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union
17
+ from typing import (
18
+ TYPE_CHECKING,
19
+ Any,
20
+ ClassVar,
21
+ Dict,
22
+ List,
23
+ Optional,
24
+ Type,
25
+ TypeVar,
26
+ Union,
27
+ )
18
28
  from uuid import UUID
19
29
 
20
30
  from pydantic import Field, model_validator
@@ -39,6 +49,9 @@ if TYPE_CHECKING:
39
49
  from sqlalchemy.sql.elements import ColumnElement
40
50
 
41
51
  from zenml.models.v2.core.component import ComponentResponse
52
+ from zenml.zen_stores.schemas import BaseSchema
53
+
54
+ AnySchema = TypeVar("AnySchema", bound=BaseSchema)
42
55
 
43
56
 
44
57
  # ------------------ Request Model ------------------
@@ -323,7 +336,6 @@ class StackFilter(WorkspaceScopedFilter):
323
336
  FILTER_EXCLUDE_FIELDS: ClassVar[List[str]] = [
324
337
  *WorkspaceScopedFilter.FILTER_EXCLUDE_FIELDS,
325
338
  "component_id",
326
- "user",
327
339
  "component",
328
340
  ]
329
341
 
@@ -334,42 +346,32 @@ class StackFilter(WorkspaceScopedFilter):
334
346
  description: Optional[str] = Field(
335
347
  default=None, description="Description of the stack"
336
348
  )
337
- workspace_id: Optional[Union[UUID, str]] = Field(
338
- default=None,
339
- description="Workspace of the stack",
340
- union_mode="left_to_right",
341
- )
342
- user_id: Optional[Union[UUID, str]] = Field(
343
- default=None,
344
- description="User of the stack",
345
- union_mode="left_to_right",
346
- )
347
349
  component_id: Optional[Union[UUID, str]] = Field(
348
350
  default=None,
349
351
  description="Component in the stack",
350
352
  union_mode="left_to_right",
351
353
  )
352
- user: Optional[Union[UUID, str]] = Field(
353
- default=None,
354
- description="Name/ID of the user that created the stack.",
355
- )
356
354
  component: Optional[Union[UUID, str]] = Field(
357
355
  default=None, description="Name/ID of a component in the stack."
358
356
  )
359
357
 
360
- def get_custom_filters(self) -> List["ColumnElement[bool]"]:
358
+ def get_custom_filters(
359
+ self, table: Type["AnySchema"]
360
+ ) -> List["ColumnElement[bool]"]:
361
361
  """Get custom filters.
362
362
 
363
+ Args:
364
+ table: The query table.
365
+
363
366
  Returns:
364
367
  A list of custom filters.
365
368
  """
366
- custom_filters = super().get_custom_filters()
369
+ custom_filters = super().get_custom_filters(table)
367
370
 
368
371
  from zenml.zen_stores.schemas import (
369
372
  StackComponentSchema,
370
373
  StackCompositionSchema,
371
374
  StackSchema,
372
- UserSchema,
373
375
  )
374
376
 
375
377
  if self.component_id:
@@ -379,17 +381,6 @@ class StackFilter(WorkspaceScopedFilter):
379
381
  )
380
382
  custom_filters.append(component_id_filter)
381
383
 
382
- if self.user:
383
- user_filter = and_(
384
- StackSchema.user_id == UserSchema.id,
385
- self.generate_name_or_id_query_conditions(
386
- value=self.user,
387
- table=UserSchema,
388
- additional_columns=["full_name"],
389
- ),
390
- )
391
- custom_filters.append(user_filter)
392
-
393
384
  if self.component:
394
385
  component_filter = and_(
395
386
  StackCompositionSchema.stack_id == StackSchema.id,
@@ -14,7 +14,16 @@
14
14
  """Models representing steps runs."""
15
15
 
16
16
  from datetime import datetime
17
- from typing import TYPE_CHECKING, ClassVar, Dict, List, Optional, Union
17
+ from typing import (
18
+ TYPE_CHECKING,
19
+ ClassVar,
20
+ Dict,
21
+ List,
22
+ Optional,
23
+ Type,
24
+ TypeVar,
25
+ Union,
26
+ )
18
27
  from uuid import UUID
19
28
 
20
29
  from pydantic import BaseModel, ConfigDict, Field
@@ -41,6 +50,9 @@ if TYPE_CHECKING:
41
50
  LogsRequest,
42
51
  LogsResponse,
43
52
  )
53
+ from zenml.zen_stores.schemas import BaseSchema
54
+
55
+ AnySchema = TypeVar("AnySchema", bound=BaseSchema)
44
56
 
45
57
 
46
58
  class StepRunInputResponse(ArtifactVersionResponse):
@@ -553,16 +565,6 @@ class StepRunFilter(WorkspaceScopedFilter):
553
565
  description="Original id for this step run",
554
566
  union_mode="left_to_right",
555
567
  )
556
- user_id: Optional[Union[UUID, str]] = Field(
557
- default=None,
558
- description="User that produced this step run",
559
- union_mode="left_to_right",
560
- )
561
- workspace_id: Optional[Union[UUID, str]] = Field(
562
- default=None,
563
- description="Workspace of this step run",
564
- union_mode="left_to_right",
565
- )
566
568
  model_version_id: Optional[Union[UUID, str]] = Field(
567
569
  default=None,
568
570
  description="Model version associated with the step run.",
@@ -576,18 +578,20 @@ class StepRunFilter(WorkspaceScopedFilter):
576
578
  default=None,
577
579
  description="The run_metadata to filter the step runs by.",
578
580
  )
579
-
580
581
  model_config = ConfigDict(protected_namespaces=())
581
582
 
582
583
  def get_custom_filters(
583
- self,
584
+ self, table: Type["AnySchema"]
584
585
  ) -> List["ColumnElement[bool]"]:
585
586
  """Get custom filters.
586
587
 
588
+ Args:
589
+ table: The query table.
590
+
587
591
  Returns:
588
592
  A list of custom filters.
589
593
  """
590
- custom_filters = super().get_custom_filters()
594
+ custom_filters = super().get_custom_filters(table)
591
595
 
592
596
  from sqlmodel import and_
593
597
 
@@ -13,7 +13,17 @@
13
13
  # permissions and limitations under the License.
14
14
  """Collection of all models concerning triggers."""
15
15
 
16
- from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Union
16
+ from typing import (
17
+ TYPE_CHECKING,
18
+ Any,
19
+ ClassVar,
20
+ Dict,
21
+ List,
22
+ Optional,
23
+ Type,
24
+ TypeVar,
25
+ Union,
26
+ )
17
27
  from uuid import UUID
18
28
 
19
29
  from pydantic import Field, model_validator
@@ -39,6 +49,9 @@ if TYPE_CHECKING:
39
49
  ActionResponse,
40
50
  )
41
51
  from zenml.models.v2.core.event_source import EventSourceResponse
52
+ from zenml.zen_stores.schemas import BaseSchema
53
+
54
+ AnySchema = TypeVar("AnySchema", bound=BaseSchema)
42
55
 
43
56
 
44
57
  # ------------------ Request Model ------------------
@@ -358,10 +371,13 @@ class TriggerFilter(WorkspaceScopedFilter):
358
371
  )
359
372
 
360
373
  def get_custom_filters(
361
- self,
374
+ self, table: Type["AnySchema"]
362
375
  ) -> List["ColumnElement[bool]"]:
363
376
  """Get custom filters.
364
377
 
378
+ Args:
379
+ table: The query table.
380
+
365
381
  Returns:
366
382
  A list of custom filters.
367
383
  """
@@ -373,7 +389,7 @@ class TriggerFilter(WorkspaceScopedFilter):
373
389
  TriggerSchema,
374
390
  )
375
391
 
376
- custom_filters = super().get_custom_filters()
392
+ custom_filters = super().get_custom_filters(table)
377
393
 
378
394
  if self.event_source_flavor:
379
395
  event_source_flavor_filter = and_(
@@ -179,12 +179,10 @@ class StepLauncher:
179
179
  pipeline_run_id=pipeline_run.id,
180
180
  pipeline_run_metadata=pipeline_run_metadata,
181
181
  )
182
-
183
- pipeline_model_version, pipeline_run = (
184
- step_run_utils.prepare_pipeline_run_model_version(
185
- pipeline_run
186
- )
187
- )
182
+ if model_version := pipeline_run.model_version:
183
+ step_run_utils.log_model_version_dashboard_url(
184
+ model_version=model_version
185
+ )
188
186
 
189
187
  request_factory = step_run_utils.StepRunRequestFactory(
190
188
  deployment=self._deployment,
@@ -209,12 +207,10 @@ class StepLauncher:
209
207
  step_run = Client().zen_store.create_run_step(
210
208
  step_run_request
211
209
  )
212
-
213
- step_model_version, step_run = (
214
- step_run_utils.prepare_step_run_model_version(
215
- step_run=step_run, pipeline_run=pipeline_run
210
+ if model_version := step_run.model_version:
211
+ step_run_utils.log_model_version_dashboard_url(
212
+ model_version=model_version
216
213
  )
217
- )
218
214
 
219
215
  if not step_run.status.is_finished:
220
216
  logger.info(f"Step `{self._step_name}` has started.")
@@ -289,8 +285,8 @@ class StepLauncher:
289
285
  f"Using cached version of step `{self._step_name}`."
290
286
  )
291
287
  if (
292
- model_version := step_model_version
293
- or pipeline_model_version
288
+ model_version := step_run.model_version
289
+ or pipeline_run.model_version
294
290
  ):
295
291
  step_run_utils.link_output_artifacts_to_model_version(
296
292
  artifacts=step_run.outputs,
@@ -14,7 +14,7 @@
14
14
  """Utilities for creating step runs."""
15
15
 
16
16
  from datetime import datetime
17
- from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
17
+ from typing import Dict, List, Optional, Set, Tuple
18
18
 
19
19
  from zenml.client import Client
20
20
  from zenml.config.step_configurations import Step
@@ -24,21 +24,13 @@ from zenml.logger import get_logger
24
24
  from zenml.model.utils import link_artifact_version_to_model_version
25
25
  from zenml.models import (
26
26
  ArtifactVersionResponse,
27
- ModelVersionPipelineRunRequest,
28
27
  ModelVersionResponse,
29
28
  PipelineDeploymentResponse,
30
29
  PipelineRunResponse,
31
- PipelineRunUpdate,
32
30
  StepRunRequest,
33
- StepRunResponse,
34
- StepRunUpdate,
35
31
  )
36
32
  from zenml.orchestrators import cache_utils, input_utils, utils
37
33
  from zenml.stack import Stack
38
- from zenml.utils import pagination_utils, string_utils
39
-
40
- if TYPE_CHECKING:
41
- from zenml.model.model import Model
42
34
 
43
35
  logger = get_logger(__name__)
44
36
 
@@ -293,10 +285,6 @@ def create_cached_step_runs(
293
285
  deployment=deployment, pipeline_run=pipeline_run, stack=stack
294
286
  )
295
287
 
296
- pipeline_model_version, pipeline_run = prepare_pipeline_run_model_version(
297
- pipeline_run=pipeline_run
298
- )
299
-
300
288
  while (
301
289
  cache_candidates := find_cacheable_invocation_candidates(
302
290
  deployment=deployment,
@@ -311,7 +299,9 @@ def create_cached_step_runs(
311
299
 
312
300
  # Make sure the request factory has the most up to date pipeline
313
301
  # run to avoid hydration calls
314
- request_factory.pipeline_run = pipeline_run
302
+ request_factory.pipeline_run = Client().get_pipeline_run(
303
+ pipeline_run.id
304
+ )
315
305
  try:
316
306
  step_run_request = request_factory.create_request(
317
307
  invocation_id
@@ -336,15 +326,10 @@ def create_cached_step_runs(
336
326
 
337
327
  step_run = Client().zen_store.create_run_step(step_run_request)
338
328
 
339
- # Refresh the pipeline run here to make sure we have the latest
340
- # state
341
- pipeline_run = Client().get_pipeline_run(pipeline_run.id)
342
-
343
- step_model_version, step_run = prepare_step_run_model_version(
344
- step_run=step_run, pipeline_run=pipeline_run
345
- )
346
-
347
- if model_version := step_model_version or pipeline_model_version:
329
+ if (
330
+ model_version := step_run.model_version
331
+ or pipeline_run.model_version
332
+ ):
348
333
  link_output_artifacts_to_model_version(
349
334
  artifacts=step_run.outputs,
350
335
  model_version=model_version,
@@ -356,169 +341,6 @@ def create_cached_step_runs(
356
341
  return cached_invocations
357
342
 
358
343
 
359
- def get_or_create_model_version_for_pipeline_run(
360
- model: "Model",
361
- pipeline_run: PipelineRunResponse,
362
- substitutions: Dict[str, str],
363
- ) -> Tuple[ModelVersionResponse, bool]:
364
- """Get or create a model version as part of a pipeline run.
365
-
366
- Args:
367
- model: The model to get or create.
368
- pipeline_run: The pipeline run for which the model should be created.
369
- substitutions: Substitutions to apply to the model version name.
370
-
371
- Returns:
372
- The model version and a boolean indicating whether it was newly created
373
- or not.
374
- """
375
- # Copy the model before modifying it so we don't accidently modify
376
- # configurations in which the model object is potentially referenced
377
- model = model.model_copy()
378
-
379
- if model.model_version_id:
380
- return model._get_model_version(), False
381
- elif model.version:
382
- if isinstance(model.version, str):
383
- model.version = string_utils.format_name_template(
384
- model.version,
385
- substitutions=substitutions,
386
- )
387
- model.name = string_utils.format_name_template(
388
- model.name,
389
- substitutions=substitutions,
390
- )
391
-
392
- return (
393
- model._get_or_create_model_version(),
394
- model._created_model_version,
395
- )
396
-
397
- # The model version should be created as part of this run
398
- # -> We first check if it was already created as part of this run, and if
399
- # not we do create it. If this is running in two parallel steps, we might
400
- # run into issues that this will create two versions. Ideally, all model
401
- # versions required for a pipeline run and its steps could be created
402
- # server-side at run creation time before the first step starts.
403
- if model_version := get_model_version_created_by_pipeline_run(
404
- model_name=model.name, pipeline_run=pipeline_run
405
- ):
406
- return model_version, False
407
- else:
408
- return model._get_or_create_model_version(), True
409
-
410
-
411
- def get_model_version_created_by_pipeline_run(
412
- model_name: str, pipeline_run: PipelineRunResponse
413
- ) -> Optional[ModelVersionResponse]:
414
- """Get a model version that was created by a specific pipeline run.
415
-
416
- This function does not refresh the pipeline run, so it will only try to
417
- fetch the model version from existing steps if they're already part of the
418
- response.
419
-
420
- Args:
421
- model_name: The model name for which to get the version.
422
- pipeline_run: The pipeline run for which to get the version.
423
-
424
- Returns:
425
- A model version with the given name created by the run, or None if such
426
- a model version does not exist.
427
- """
428
- if pipeline_run.config.model and pipeline_run.model_version:
429
- if (
430
- pipeline_run.config.model.name == model_name
431
- and pipeline_run.config.model.version is None
432
- ):
433
- return pipeline_run.model_version
434
-
435
- # We fetch a list of hydrated step runs here in order to avoid hydration
436
- # calls for each step separately.
437
- candidate_step_runs = pagination_utils.depaginate(
438
- Client().list_run_steps,
439
- pipeline_run_id=pipeline_run.id,
440
- model=model_name,
441
- hydrate=True,
442
- )
443
- for step_run in candidate_step_runs:
444
- if step_run.config.model and step_run.model_version:
445
- if (
446
- step_run.config.model.name == model_name
447
- and step_run.config.model.version is None
448
- ):
449
- return step_run.model_version
450
-
451
- return None
452
-
453
-
454
- def prepare_pipeline_run_model_version(
455
- pipeline_run: PipelineRunResponse,
456
- ) -> Tuple[Optional[ModelVersionResponse], PipelineRunResponse]:
457
- """Prepare the model version for a pipeline run.
458
-
459
- Args:
460
- pipeline_run: The pipeline run for which to prepare the model version.
461
-
462
- Returns:
463
- The prepared model version and the updated pipeline run.
464
- """
465
- model_version = None
466
-
467
- if pipeline_run.model_version:
468
- model_version = pipeline_run.model_version
469
- elif config_model := pipeline_run.config.model:
470
- model_version, _ = get_or_create_model_version_for_pipeline_run(
471
- model=config_model,
472
- pipeline_run=pipeline_run,
473
- substitutions=pipeline_run.config.substitutions,
474
- )
475
- pipeline_run = Client().zen_store.update_run(
476
- run_id=pipeline_run.id,
477
- run_update=PipelineRunUpdate(model_version_id=model_version.id),
478
- )
479
- link_pipeline_run_to_model_version(
480
- pipeline_run=pipeline_run, model_version=model_version
481
- )
482
- log_model_version_dashboard_url(model_version)
483
-
484
- return model_version, pipeline_run
485
-
486
-
487
- def prepare_step_run_model_version(
488
- step_run: StepRunResponse, pipeline_run: PipelineRunResponse
489
- ) -> Tuple[Optional[ModelVersionResponse], StepRunResponse]:
490
- """Prepare the model version for a step run.
491
-
492
- Args:
493
- step_run: The step run for which to prepare the model version.
494
- pipeline_run: The pipeline run of the step.
495
-
496
- Returns:
497
- The prepared model version and the updated step run.
498
- """
499
- model_version = None
500
-
501
- if step_run.model_version:
502
- model_version = step_run.model_version
503
- elif config_model := step_run.config.model:
504
- model_version, created = get_or_create_model_version_for_pipeline_run(
505
- model=config_model,
506
- pipeline_run=pipeline_run,
507
- substitutions=step_run.config.substitutions,
508
- )
509
- step_run = Client().zen_store.update_run_step(
510
- step_run_id=step_run.id,
511
- step_run_update=StepRunUpdate(model_version_id=model_version.id),
512
- )
513
- link_pipeline_run_to_model_version(
514
- pipeline_run=pipeline_run, model_version=model_version
515
- )
516
- if created:
517
- log_model_version_dashboard_url(model_version)
518
-
519
- return model_version, step_run
520
-
521
-
522
344
  def log_model_version_dashboard_url(
523
345
  model_version: ModelVersionResponse,
524
346
  ) -> None:
@@ -546,24 +368,6 @@ def log_model_version_dashboard_url(
546
368
  )
547
369
 
548
370
 
549
- def link_pipeline_run_to_model_version(
550
- pipeline_run: PipelineRunResponse, model_version: ModelVersionResponse
551
- ) -> None:
552
- """Link a pipeline run to a model version.
553
-
554
- Args:
555
- pipeline_run: The pipeline run to link.
556
- model_version: The model version to link.
557
- """
558
- client = Client()
559
- client.zen_store.create_model_version_pipeline_run_link(
560
- ModelVersionPipelineRunRequest(
561
- pipeline_run=pipeline_run.id,
562
- model_version=model_version.id,
563
- )
564
- )
565
-
566
-
567
371
  def link_output_artifacts_to_model_version(
568
372
  artifacts: Dict[str, List[ArtifactVersionResponse]],
569
373
  model_version: ModelVersionResponse,