zenml-nightly 0.68.1.dev20241107__py3-none-any.whl → 0.68.1.dev20241108__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. zenml/VERSION +1 -1
  2. zenml/artifacts/external_artifact.py +2 -1
  3. zenml/artifacts/utils.py +13 -20
  4. zenml/cli/base.py +4 -4
  5. zenml/cli/model.py +1 -6
  6. zenml/cli/stack.py +1 -0
  7. zenml/client.py +21 -73
  8. zenml/enums.py +12 -4
  9. zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py +1 -1
  10. zenml/integrations/azure/orchestrators/azureml_orchestrator.py +1 -1
  11. zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +1 -1
  12. zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py +60 -54
  13. zenml/metadata/lazy_load.py +20 -7
  14. zenml/model/model.py +1 -2
  15. zenml/models/__init__.py +0 -12
  16. zenml/models/v2/core/artifact_version.py +19 -7
  17. zenml/models/v2/core/model_version.py +3 -5
  18. zenml/models/v2/core/pipeline_run.py +3 -5
  19. zenml/models/v2/core/run_metadata.py +2 -217
  20. zenml/models/v2/core/step_run.py +40 -24
  21. zenml/orchestrators/input_utils.py +44 -19
  22. zenml/orchestrators/step_launcher.py +2 -2
  23. zenml/orchestrators/step_run_utils.py +19 -15
  24. zenml/orchestrators/step_runner.py +8 -3
  25. zenml/steps/base_step.py +1 -1
  26. zenml/steps/entrypoint_function_utils.py +3 -5
  27. zenml/steps/step_context.py +3 -2
  28. zenml/steps/utils.py +8 -2
  29. zenml/zen_server/rbac/utils.py +0 -2
  30. zenml/zen_server/routers/workspaces_endpoints.py +3 -4
  31. zenml/zen_server/zen_server_api.py +0 -2
  32. zenml/zen_stores/migrations/versions/1cb6477f72d6_move_artifact_save_type.py +99 -0
  33. zenml/zen_stores/migrations/versions/b557b2871693_update_step_run_input_types.py +33 -0
  34. zenml/zen_stores/rest_zen_store.py +3 -54
  35. zenml/zen_stores/schemas/artifact_schemas.py +8 -1
  36. zenml/zen_stores/schemas/model_schemas.py +2 -2
  37. zenml/zen_stores/schemas/pipeline_run_schemas.py +1 -1
  38. zenml/zen_stores/schemas/run_metadata_schemas.py +1 -48
  39. zenml/zen_stores/schemas/step_run_schemas.py +18 -10
  40. zenml/zen_stores/sql_zen_store.py +52 -98
  41. zenml/zen_stores/zen_store_interface.py +2 -42
  42. {zenml_nightly-0.68.1.dev20241107.dist-info → zenml_nightly-0.68.1.dev20241108.dist-info}/METADATA +1 -1
  43. {zenml_nightly-0.68.1.dev20241107.dist-info → zenml_nightly-0.68.1.dev20241108.dist-info}/RECORD +46 -45
  44. zenml/zen_server/routers/run_metadata_endpoints.py +0 -96
  45. {zenml_nightly-0.68.1.dev20241107.dist-info → zenml_nightly-0.68.1.dev20241108.dist-info}/LICENSE +0 -0
  46. {zenml_nightly-0.68.1.dev20241107.dist-info → zenml_nightly-0.68.1.dev20241108.dist-info}/WHEEL +0 -0
  47. {zenml_nightly-0.68.1.dev20241107.dist-info → zenml_nightly-0.68.1.dev20241108.dist-info}/entry_points.txt +0 -0
zenml/VERSION CHANGED
@@ -1 +1 @@
1
- 0.68.1.dev20241107
1
+ 0.68.1.dev20241108
zenml/artifacts/external_artifact.py CHANGED
@@ -23,6 +23,7 @@ from zenml.artifacts.external_artifact_config import (
23
23
  ExternalArtifactConfiguration,
24
24
  )
25
25
  from zenml.config.source import Source
26
+ from zenml.enums import ArtifactSaveType
26
27
  from zenml.logger import get_logger
27
28
  from zenml.materializers.base_materializer import BaseMaterializer
28
29
 
@@ -114,7 +115,7 @@ class ExternalArtifact(ExternalArtifactConfiguration):
114
115
  materializer=self.materializer,
115
116
  uri=uri,
116
117
  has_custom_name=False,
117
- manual_save=False,
118
+ save_type=ArtifactSaveType.EXTERNAL,
118
119
  )
119
120
 
120
121
  # To avoid duplicate uploads, switch to referencing the uploaded
zenml/artifacts/utils.py CHANGED
@@ -39,6 +39,7 @@ from zenml.constants import (
39
39
  MODEL_METADATA_YAML_FILE_NAME,
40
40
  )
41
41
  from zenml.enums import (
42
+ ArtifactSaveType,
42
43
  ArtifactType,
43
44
  ExecutionStatus,
44
45
  MetadataResourceTypes,
@@ -115,6 +116,7 @@ def _store_artifact_data_and_prepare_request(
115
116
  name: str,
116
117
  uri: str,
117
118
  materializer_class: Type["BaseMaterializer"],
119
+ save_type: ArtifactSaveType,
118
120
  version: Optional[Union[int, str]] = None,
119
121
  tags: Optional[List[str]] = None,
120
122
  store_metadata: bool = True,
@@ -130,6 +132,7 @@ def _store_artifact_data_and_prepare_request(
130
132
  uri: The artifact URI.
131
133
  materializer_class: The materializer class to use for storing the
132
134
  artifact data.
135
+ save_type: Save type of the artifact version.
133
136
  version: The artifact version.
134
137
  tags: Tags for the artifact version.
135
138
  store_metadata: Whether to store metadata for the artifact version.
@@ -182,6 +185,7 @@ def _store_artifact_data_and_prepare_request(
182
185
  artifact_store_id=artifact_store.id,
183
186
  visualizations=visualizations,
184
187
  has_custom_name=has_custom_name,
188
+ save_type=save_type,
185
189
  metadata=validate_metadata(combined_metadata)
186
190
  if combined_metadata
187
191
  else None,
@@ -203,7 +207,7 @@ def save_artifact(
203
207
  is_model_artifact: bool = False,
204
208
  is_deployment_artifact: bool = False,
205
209
  # TODO: remove these once external artifact does not use this function anymore
206
- manual_save: bool = True,
210
+ save_type: ArtifactSaveType = ArtifactSaveType.MANUAL,
207
211
  has_custom_name: bool = True,
208
212
  ) -> "ArtifactVersionResponse":
209
213
  """Upload and publish an artifact.
@@ -224,8 +228,7 @@ def save_artifact(
224
228
  `custom_artifacts/{name}/{version}`.
225
229
  is_model_artifact: If the artifact is a model artifact.
226
230
  is_deployment_artifact: If the artifact is a deployment artifact.
227
- manual_save: If this function is called manually and should therefore
228
- link the artifact to the current step run.
231
+ save_type: The type of save operation that created the artifact version.
229
232
  has_custom_name: If the artifact name is custom and should be listed in
230
233
  the dashboard "Artifacts" tab.
231
234
 
@@ -245,7 +248,7 @@ def save_artifact(
245
248
  if not uri.startswith(artifact_store.path):
246
249
  uri = os.path.join(artifact_store.path, uri)
247
250
 
248
- if manual_save:
251
+ if save_type == ArtifactSaveType.MANUAL:
249
252
  # This check is only necessary for manual saves as we already check
250
253
  # it when creating the directory for step output artifacts
251
254
  _check_if_artifact_with_given_uri_already_registered(
@@ -268,6 +271,7 @@ def save_artifact(
268
271
  name=name,
269
272
  uri=uri,
270
273
  materializer_class=materializer_class,
274
+ save_type=save_type,
271
275
  version=version,
272
276
  tags=tags,
273
277
  store_metadata=extract_metadata,
@@ -279,7 +283,7 @@ def save_artifact(
279
283
  artifact_version=artifact_version_request
280
284
  )
281
285
 
282
- if manual_save:
286
+ if save_type == ArtifactSaveType.MANUAL:
283
287
  _link_artifact_version_to_the_step_and_model(
284
288
  artifact_version=artifact_version,
285
289
  is_model_artifact=is_model_artifact,
@@ -343,6 +347,7 @@ def register_artifact(
343
347
  version=version,
344
348
  tags=tags,
345
349
  type=ArtifactType.DATA,
350
+ save_type=ArtifactSaveType.PREEXISTING,
346
351
  uri=folder_or_file_uri,
347
352
  materializer=source_utils.resolve(PreexistingDataMaterializer),
348
353
  data_type=source_utils.resolve(Path),
@@ -382,17 +387,6 @@ def load_artifact(
382
387
  The loaded artifact.
383
388
  """
384
389
  artifact = Client().get_artifact_version(name_or_id, version)
385
- try:
386
- step_run = get_step_context().step_run
387
- client = Client()
388
- client.zen_store.update_run_step(
389
- step_run_id=step_run.id,
390
- step_run_update=StepRunUpdate(
391
- loaded_artifact_versions={artifact.name: artifact.id}
392
- ),
393
- )
394
- except RuntimeError:
395
- pass # Cannot link to step run if called outside of a step
396
390
  return load_artifact_from_response(artifact)
397
391
 
398
392
 
@@ -636,7 +630,8 @@ def get_artifacts_versions_of_pipeline_run(
636
630
  artifact_versions: List["ArtifactVersionResponse"] = []
637
631
  for step in pipeline_run.steps.values():
638
632
  if not only_produced or step.status == ExecutionStatus.COMPLETED:
639
- artifact_versions.extend(step.outputs.values())
633
+ for output in step.outputs.values():
634
+ artifact_versions.extend(output)
640
635
  return artifact_versions
641
636
 
642
637
 
@@ -697,9 +692,7 @@ def _link_artifact_version_to_the_step_and_model(
697
692
  client.zen_store.update_run_step(
698
693
  step_run_id=step_run.id,
699
694
  step_run_update=StepRunUpdate(
700
- saved_artifact_versions={
701
- artifact_version.artifact.name: artifact_version.id
702
- }
695
+ outputs={artifact_version.artifact.name: artifact_version.id}
703
696
  ),
704
697
  )
705
698
  error_message = "model"
zenml/cli/base.py CHANGED
@@ -79,19 +79,19 @@ class ZenMLProjectTemplateLocation(BaseModel):
79
79
  ZENML_PROJECT_TEMPLATES = dict(
80
80
  e2e_batch=ZenMLProjectTemplateLocation(
81
81
  github_url="zenml-io/template-e2e-batch",
82
- github_tag="2024.10.10", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
82
+ github_tag="2024.10.30", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
83
83
  ),
84
84
  starter=ZenMLProjectTemplateLocation(
85
85
  github_url="zenml-io/template-starter",
86
- github_tag="2024.09.24", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
86
+ github_tag="2024.10.30", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
87
87
  ),
88
88
  nlp=ZenMLProjectTemplateLocation(
89
89
  github_url="zenml-io/template-nlp",
90
- github_tag="2024.09.23", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
90
+ github_tag="2024.10.30", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
91
91
  ),
92
92
  llm_finetuning=ZenMLProjectTemplateLocation(
93
93
  github_url="zenml-io/template-llm-finetuning",
94
- github_tag="2024.09.24", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
94
+ github_tag="2024.10.30", # Make sure it is aligned with .github/workflows/update-templates-to-examples.yml
95
95
  ),
96
96
  )
97
97
 
zenml/cli/model.py CHANGED
@@ -59,11 +59,6 @@ def _model_to_print(model: ModelResponse) -> Dict[str, Any]:
59
59
  def _model_version_to_print(
60
60
  model_version: ModelVersionResponse,
61
61
  ) -> Dict[str, Any]:
62
- run_metadata = None
63
- if model_version.run_metadata:
64
- run_metadata = {
65
- k: v.value for k, v in model_version.run_metadata.items()
66
- }
67
62
  return {
68
63
  "id": model_version.id,
69
64
  "model": model_version.model.name,
@@ -71,7 +66,7 @@ def _model_version_to_print(
71
66
  "number": model_version.number,
72
67
  "description": model_version.description,
73
68
  "stage": model_version.stage,
74
- "run_metadata": run_metadata,
69
+ "run_metadata": model_version.run_metadata,
75
70
  "tags": [t.name for t in model_version.tags],
76
71
  "data_artifacts_count": len(model_version.data_artifact_ids),
77
72
  "model_artifacts_count": len(model_version.model_artifact_ids),
zenml/cli/stack.py CHANGED
@@ -408,6 +408,7 @@ def register_stack(
408
408
  component_type, preset_name
409
409
  )
410
410
  component_info = component_response.id
411
+ component_name = component_response.name
411
412
  else:
412
413
  if isinstance(service_connector, UUID):
413
414
  # find existing components under same connector
zenml/client.py CHANGED
@@ -136,9 +136,7 @@ from zenml.models import (
136
136
  PipelineResponse,
137
137
  PipelineRunFilter,
138
138
  PipelineRunResponse,
139
- RunMetadataFilter,
140
139
  RunMetadataRequest,
141
- RunMetadataResponse,
142
140
  RunTemplateFilter,
143
141
  RunTemplateRequest,
144
142
  RunTemplateResponse,
@@ -190,6 +188,7 @@ from zenml.models import (
190
188
  WorkspaceResponse,
191
189
  WorkspaceUpdate,
192
190
  )
191
+ from zenml.models.v2.core.step_run import StepRunUpdate
193
192
  from zenml.services.service import ServiceConfig
194
193
  from zenml.services.service_status import ServiceState
195
194
  from zenml.services.service_type import ServiceType
@@ -4166,6 +4165,8 @@ class Client(metaclass=ClientMetaClass):
4166
4165
  Returns:
4167
4166
  The artifact version.
4168
4167
  """
4168
+ from zenml import get_step_context
4169
+
4169
4170
  if cll := client_lazy_loader(
4170
4171
  method_name="get_artifact_version",
4171
4172
  name_id_or_prefix=name_id_or_prefix,
@@ -4173,13 +4174,26 @@ class Client(metaclass=ClientMetaClass):
4173
4174
  hydrate=hydrate,
4174
4175
  ):
4175
4176
  return cll # type: ignore[return-value]
4176
- return self._get_entity_version_by_id_or_name_or_prefix(
4177
+
4178
+ artifact = self._get_entity_version_by_id_or_name_or_prefix(
4177
4179
  get_method=self.zen_store.get_artifact_version,
4178
4180
  list_method=self.list_artifact_versions,
4179
4181
  name_id_or_prefix=name_id_or_prefix,
4180
4182
  version=version,
4181
4183
  hydrate=hydrate,
4182
4184
  )
4185
+ try:
4186
+ step_run = get_step_context().step_run
4187
+ client = Client()
4188
+ client.zen_store.update_run_step(
4189
+ step_run_id=step_run.id,
4190
+ step_run_update=StepRunUpdate(
4191
+ loaded_artifact_versions={artifact.name: artifact.id}
4192
+ ),
4193
+ )
4194
+ except RuntimeError:
4195
+ pass # Cannot link to step run if called outside of a step
4196
+ return artifact
4183
4197
 
4184
4198
  def list_artifact_versions(
4185
4199
  self,
@@ -4417,7 +4431,7 @@ class Client(metaclass=ClientMetaClass):
4417
4431
  resource_id: UUID,
4418
4432
  resource_type: MetadataResourceTypes,
4419
4433
  stack_component_id: Optional[UUID] = None,
4420
- ) -> List[RunMetadataResponse]:
4434
+ ) -> None:
4421
4435
  """Create run metadata.
4422
4436
 
4423
4437
  Args:
@@ -4430,7 +4444,7 @@ class Client(metaclass=ClientMetaClass):
4430
4444
  the metadata.
4431
4445
 
4432
4446
  Returns:
4433
- The created metadata, as string to model dictionary.
4447
+ None
4434
4448
  """
4435
4449
  from zenml.metadata.metadata_types import get_metadata_type
4436
4450
 
@@ -4465,74 +4479,8 @@ class Client(metaclass=ClientMetaClass):
4465
4479
  values=values,
4466
4480
  types=types,
4467
4481
  )
4468
- return self.zen_store.create_run_metadata(run_metadata)
4469
-
4470
- def list_run_metadata(
4471
- self,
4472
- sort_by: str = "created",
4473
- page: int = PAGINATION_STARTING_PAGE,
4474
- size: int = PAGE_SIZE_DEFAULT,
4475
- logical_operator: LogicalOperators = LogicalOperators.AND,
4476
- id: Optional[Union[UUID, str]] = None,
4477
- created: Optional[Union[datetime, str]] = None,
4478
- updated: Optional[Union[datetime, str]] = None,
4479
- workspace_id: Optional[UUID] = None,
4480
- user_id: Optional[UUID] = None,
4481
- resource_id: Optional[UUID] = None,
4482
- resource_type: Optional[MetadataResourceTypes] = None,
4483
- stack_component_id: Optional[UUID] = None,
4484
- key: Optional[str] = None,
4485
- value: Optional["MetadataType"] = None,
4486
- type: Optional[str] = None,
4487
- hydrate: bool = False,
4488
- ) -> Page[RunMetadataResponse]:
4489
- """List run metadata.
4490
-
4491
- Args:
4492
- sort_by: The field to sort the results by.
4493
- page: The page number to return.
4494
- size: The number of results to return per page.
4495
- logical_operator: The logical operator to use for filtering.
4496
- id: The ID of the metadata.
4497
- created: The creation time of the metadata.
4498
- updated: The last update time of the metadata.
4499
- workspace_id: The ID of the workspace the metadata belongs to.
4500
- user_id: The ID of the user that created the metadata.
4501
- resource_id: The ID of the resource the metadata belongs to.
4502
- resource_type: The type of the resource the metadata belongs to.
4503
- stack_component_id: The ID of the stack component that produced
4504
- the metadata.
4505
- key: The key of the metadata.
4506
- value: The value of the metadata.
4507
- type: The type of the metadata.
4508
- hydrate: Flag deciding whether to hydrate the output model(s)
4509
- by including metadata fields in the response.
4510
-
4511
- Returns:
4512
- The run metadata.
4513
- """
4514
- metadata_filter_model = RunMetadataFilter(
4515
- sort_by=sort_by,
4516
- page=page,
4517
- size=size,
4518
- logical_operator=logical_operator,
4519
- id=id,
4520
- created=created,
4521
- updated=updated,
4522
- workspace_id=workspace_id,
4523
- user_id=user_id,
4524
- resource_id=resource_id,
4525
- resource_type=resource_type,
4526
- stack_component_id=stack_component_id,
4527
- key=key,
4528
- value=value,
4529
- type=type,
4530
- )
4531
- metadata_filter_model.set_scope_workspace(self.active_workspace.id)
4532
- return self.zen_store.list_run_metadata(
4533
- metadata_filter_model,
4534
- hydrate=hydrate,
4535
- )
4482
+ self.zen_store.create_run_metadata(run_metadata)
4483
+ return None
4536
4484
 
4537
4485
  # -------------------------------- Secrets ---------------------------------
4538
4486
 
zenml/enums.py CHANGED
@@ -34,15 +34,23 @@ class ArtifactType(StrEnum):
34
34
  class StepRunInputArtifactType(StrEnum):
35
35
  """All possible types of a step run input artifact."""
36
36
 
37
- DEFAULT = "default" # input argument that is the output of a previous step
37
+ STEP_OUTPUT = (
38
+ "step_output" # input argument that is the output of a previous step
39
+ )
38
40
  MANUAL = "manual" # manually loaded via `zenml.load_artifact()`
41
+ EXTERNAL = "external" # loaded via `ExternalArtifact(value=...)`
42
+ LAZY_LOADED = "lazy" # loaded via various lazy methods
39
43
 
40
44
 
41
- class StepRunOutputArtifactType(StrEnum):
42
- """All possible types of a step run output artifact."""
45
+ class ArtifactSaveType(StrEnum):
46
+ """All possible method types of how artifact versions can be saved."""
43
47
 
44
- DEFAULT = "default" # output of the current step
48
+ STEP_OUTPUT = "step_output" # output of the current step
45
49
  MANUAL = "manual" # manually saved via `zenml.save_artifact()`
50
+ PREEXISTING = "preexisting" # register via `zenml.register_artifact()`
51
+ EXTERNAL = (
52
+ "external" # saved via `zenml.ExternalArtifact.upload_by_value()`
53
+ )
46
54
 
47
55
 
48
56
  class VisualizationType(StrEnum):
zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py CHANGED
@@ -566,7 +566,7 @@ class SagemakerOrchestrator(ContainerizedOrchestrator):
566
566
 
567
567
  # Fetch the status of the _PipelineExecution
568
568
  if METADATA_ORCHESTRATOR_RUN_ID in run.run_metadata:
569
- run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID].value
569
+ run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID]
570
570
  elif run.orchestrator_run_id is not None:
571
571
  run_id = run.orchestrator_run_id
572
572
  else:
zenml/integrations/azure/orchestrators/azureml_orchestrator.py CHANGED
@@ -482,7 +482,7 @@ class AzureMLOrchestrator(ContainerizedOrchestrator):
482
482
 
483
483
  # Fetch the status of the PipelineJob
484
484
  if METADATA_ORCHESTRATOR_RUN_ID in run.run_metadata:
485
- run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID].value
485
+ run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID]
486
486
  elif run.orchestrator_run_id is not None:
487
487
  run_id = run.orchestrator_run_id
488
488
  else:
zenml/integrations/gcp/orchestrators/vertex_orchestrator.py CHANGED
@@ -835,7 +835,7 @@ class VertexOrchestrator(ContainerizedOrchestrator, GoogleCredentialsMixin):
835
835
 
836
836
  # Fetch the status of the PipelineJob
837
837
  if METADATA_ORCHESTRATOR_RUN_ID in run.run_metadata:
838
- run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID].value
838
+ run_id = run.run_metadata[METADATA_ORCHESTRATOR_RUN_ID]
839
839
  elif run.orchestrator_run_id is not None:
840
840
  run_id = run.orchestrator_run_id
841
841
  else:
zenml/integrations/tensorboard/visualizers/tensorboard_visualizer.py CHANGED
@@ -87,41 +87,44 @@ class TensorboardVisualizer:
87
87
  *args: Additional arguments.
88
88
  **kwargs: Additional keyword arguments.
89
89
  """
90
- for _, artifact_view in object.outputs.items():
91
- # filter out anything but model artifacts
92
- if artifact_view.type == ArtifactType.MODEL:
93
- logdir = os.path.dirname(artifact_view.uri)
94
-
95
- # first check if a TensorBoard server is already running for
96
- # the same logdir location and use that one
97
- running_server = self.find_running_tensorboard_server(logdir)
98
- if running_server:
99
- self.visualize_tensorboard(running_server.port, height)
100
- return
101
-
102
- if sys.platform == "win32":
103
- # Daemon service functionality is currently not supported
104
- # on Windows
105
- print(
106
- "You can run:\n"
107
- f"[italic green] tensorboard --logdir {logdir}"
108
- "[/italic green]\n"
109
- "...to visualize the TensorBoard logs for your trained model."
90
+ for output in object.outputs.values():
91
+ for artifact_view in output:
92
+ # filter out anything but model artifacts
93
+ if artifact_view.type == ArtifactType.MODEL:
94
+ logdir = os.path.dirname(artifact_view.uri)
95
+
96
+ # first check if a TensorBoard server is already running for
97
+ # the same logdir location and use that one
98
+ running_server = self.find_running_tensorboard_server(
99
+ logdir
110
100
  )
111
- else:
112
- # start a new TensorBoard server
113
- service = TensorboardService(
114
- TensorboardServiceConfig(
115
- logdir=logdir,
116
- name=f"zenml-tensorboard-{logdir}",
101
+ if running_server:
102
+ self.visualize_tensorboard(running_server.port, height)
103
+ return
104
+
105
+ if sys.platform == "win32":
106
+ # Daemon service functionality is currently not supported
107
+ # on Windows
108
+ print(
109
+ "You can run:\n"
110
+ f"[italic green] tensorboard --logdir {logdir}"
111
+ "[/italic green]\n"
112
+ "...to visualize the TensorBoard logs for your trained model."
117
113
  )
118
- )
119
- service.start(timeout=60)
120
- if service.endpoint.status.port:
121
- self.visualize_tensorboard(
122
- service.endpoint.status.port, height
114
+ else:
115
+ # start a new TensorBoard server
116
+ service = TensorboardService(
117
+ TensorboardServiceConfig(
118
+ logdir=logdir,
119
+ name=f"zenml-tensorboard-{logdir}",
120
+ )
123
121
  )
124
- return
122
+ service.start(timeout=60)
123
+ if service.endpoint.status.port:
124
+ self.visualize_tensorboard(
125
+ service.endpoint.status.port, height
126
+ )
127
+ return
125
128
 
126
129
  def visualize_tensorboard(
127
130
  self,
@@ -154,31 +157,34 @@ class TensorboardVisualizer:
154
157
  Args:
155
158
  object: StepRunResponseModel fetched from get_step().
156
159
  """
157
- for _, artifact_view in object.outputs.items():
158
- # filter out anything but model artifacts
159
- if artifact_view.type == ArtifactType.MODEL:
160
- logdir = os.path.dirname(artifact_view.uri)
161
-
162
- # first check if a TensorBoard server is already running for
163
- # the same logdir location and use that one
164
- running_server = self.find_running_tensorboard_server(logdir)
165
- if not running_server:
166
- return
160
+ for output in object.outputs.values():
161
+ for artifact_view in output:
162
+ # filter out anything but model artifacts
163
+ if artifact_view.type == ArtifactType.MODEL:
164
+ logdir = os.path.dirname(artifact_view.uri)
165
+
166
+ # first check if a TensorBoard server is already running for
167
+ # the same logdir location and use that one
168
+ running_server = self.find_running_tensorboard_server(
169
+ logdir
170
+ )
171
+ if not running_server:
172
+ return
167
173
 
168
- logger.debug(
169
- "Stopping tensorboard server with PID '%d' ...",
170
- running_server.pid,
171
- )
172
- try:
173
- p = psutil.Process(running_server.pid)
174
- except psutil.Error:
175
- logger.error(
176
- "Could not find process for PID '%d' ...",
174
+ logger.debug(
175
+ "Stopping tensorboard server with PID '%d' ...",
177
176
  running_server.pid,
178
177
  )
179
- continue
180
- p.kill()
181
- return
178
+ try:
179
+ p = psutil.Process(running_server.pid)
180
+ except psutil.Error:
181
+ logger.error(
182
+ "Could not find process for PID '%d' ...",
183
+ running_server.pid,
184
+ )
185
+ continue
186
+ p.kill()
187
+ return
182
188
 
183
189
 
184
190
  def get_step(pipeline_name: str, step_name: str) -> "StepRunResponse":
zenml/metadata/lazy_load.py CHANGED
@@ -13,10 +13,25 @@
13
13
  # permissions and limitations under the License.
14
14
  """Run Metadata Lazy Loader definition."""
15
15
 
16
- from typing import TYPE_CHECKING, Optional
16
+ from typing import Optional
17
17
 
18
- if TYPE_CHECKING:
19
- from zenml.models import RunMetadataResponse
18
+ from pydantic import BaseModel
19
+
20
+ from zenml.metadata.metadata_types import MetadataType
21
+
22
+
23
+ class LazyRunMetadataResponse(BaseModel):
24
+ """Lazy run metadata response.
25
+
26
+ Used if the run metadata is accessed from the model in
27
+ a pipeline context available only during pipeline compilation.
28
+ """
29
+
30
+ lazy_load_artifact_name: Optional[str] = None
31
+ lazy_load_artifact_version: Optional[str] = None
32
+ lazy_load_metadata_name: Optional[str] = None
33
+ lazy_load_model_name: str
34
+ lazy_load_model_version: Optional[str] = None
20
35
 
21
36
 
22
37
  class RunMetadataLazyGetter:
@@ -47,7 +62,7 @@ class RunMetadataLazyGetter:
47
62
  self._lazy_load_artifact_name = _lazy_load_artifact_name
48
63
  self._lazy_load_artifact_version = _lazy_load_artifact_version
49
64
 
50
- def __getitem__(self, key: str) -> "RunMetadataResponse":
65
+ def __getitem__(self, key: str) -> MetadataType:
51
66
  """Get the metadata for the given key.
52
67
 
53
68
  Args:
@@ -56,9 +71,7 @@ class RunMetadataLazyGetter:
56
71
  Returns:
57
72
  The metadata lazy loader wrapper for the given key.
58
73
  """
59
- from zenml.models.v2.core.run_metadata import LazyRunMetadataResponse
60
-
61
- return LazyRunMetadataResponse(
74
+ return LazyRunMetadataResponse( # type: ignore[return-value]
62
75
  lazy_load_model_name=self._lazy_load_model_name,
63
76
  lazy_load_model_version=self._lazy_load_model_version,
64
77
  lazy_load_artifact_name=self._lazy_load_artifact_name,
zenml/model/model.py CHANGED
@@ -43,7 +43,6 @@ if TYPE_CHECKING:
43
43
  ModelResponse,
44
44
  ModelVersionResponse,
45
45
  PipelineRunResponse,
46
- RunMetadataResponse,
47
46
  StepRunResponse,
48
47
  )
49
48
 
@@ -349,7 +348,7 @@ class Model(BaseModel):
349
348
  )
350
349
 
351
350
  @property
352
- def run_metadata(self) -> Dict[str, "RunMetadataResponse"]:
351
+ def run_metadata(self) -> Dict[str, "MetadataType"]:
353
352
  """Get model version run metadata.
354
353
 
355
354
  Returns:
zenml/models/__init__.py CHANGED
@@ -239,12 +239,7 @@ from zenml.models.v2.core.run_template import (
239
239
  )
240
240
  from zenml.models.v2.base.base_plugin_flavor import BasePluginFlavorResponse
241
241
  from zenml.models.v2.core.run_metadata import (
242
- LazyRunMetadataResponse,
243
242
  RunMetadataRequest,
244
- RunMetadataFilter,
245
- RunMetadataResponse,
246
- RunMetadataResponseBody,
247
- RunMetadataResponseMetadata,
248
243
  )
249
244
  from zenml.models.v2.core.schedule import (
250
245
  ScheduleRequest,
@@ -418,7 +413,6 @@ EventSourceResponseResources.model_rebuild()
418
413
  FlavorResponseBody.model_rebuild()
419
414
  FlavorResponseMetadata.model_rebuild()
420
415
  LazyArtifactVersionResponse.model_rebuild()
421
- LazyRunMetadataResponse.model_rebuild()
422
416
  ModelResponseBody.model_rebuild()
423
417
  ModelResponseMetadata.model_rebuild()
424
418
  ModelVersionResponseBody.model_rebuild()
@@ -444,8 +438,6 @@ RunTemplateResponseBody.model_rebuild()
444
438
  RunTemplateResponseMetadata.model_rebuild()
445
439
  RunTemplateResponseResources.model_rebuild()
446
440
  RunTemplateResponseBody.model_rebuild()
447
- RunMetadataResponseBody.model_rebuild()
448
- RunMetadataResponseMetadata.model_rebuild()
449
441
  ScheduleResponseBody.model_rebuild()
450
442
  ScheduleResponseMetadata.model_rebuild()
451
443
  SecretResponseBody.model_rebuild()
@@ -637,10 +629,6 @@ __all__ = [
637
629
  "RunTemplateResponseResources",
638
630
  "RunTemplateFilter",
639
631
  "RunMetadataRequest",
640
- "RunMetadataFilter",
641
- "RunMetadataResponse",
642
- "RunMetadataResponseBody",
643
- "RunMetadataResponseMetadata",
644
632
  "ScheduleRequest",
645
633
  "ScheduleUpdate",
646
634
  "ScheduleFilter",