zenml-nightly 0.73.0.dev20250129__py3-none-any.whl → 0.73.0.dev20250131__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34) hide show
  1. zenml/VERSION +1 -1
  2. zenml/cli/code_repository.py +26 -0
  3. zenml/cli/utils.py +14 -9
  4. zenml/client.py +2 -7
  5. zenml/code_repositories/base_code_repository.py +30 -2
  6. zenml/code_repositories/git/local_git_repository_context.py +26 -10
  7. zenml/code_repositories/local_repository_context.py +11 -8
  8. zenml/constants.py +3 -0
  9. zenml/integrations/gcp/constants.py +1 -1
  10. zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py +3 -1
  11. zenml/integrations/gcp/step_operators/vertex_step_operator.py +1 -0
  12. zenml/integrations/github/code_repositories/github_code_repository.py +17 -2
  13. zenml/integrations/gitlab/code_repositories/gitlab_code_repository.py +17 -2
  14. zenml/integrations/huggingface/services/huggingface_deployment.py +72 -29
  15. zenml/integrations/pytorch/materializers/base_pytorch_materializer.py +1 -1
  16. zenml/integrations/vllm/services/vllm_deployment.py +6 -1
  17. zenml/pipelines/build_utils.py +42 -35
  18. zenml/pipelines/pipeline_definition.py +5 -2
  19. zenml/utils/code_repository_utils.py +11 -2
  20. zenml/utils/downloaded_repository_context.py +3 -5
  21. zenml/utils/source_utils.py +3 -3
  22. zenml/zen_stores/migrations/utils.py +48 -1
  23. zenml/zen_stores/migrations/versions/4d5524b92a30_add_run_metadata_tag_index.py +67 -0
  24. zenml/zen_stores/rest_zen_store.py +3 -13
  25. zenml/zen_stores/schemas/run_metadata_schemas.py +15 -2
  26. zenml/zen_stores/schemas/schema_utils.py +34 -2
  27. zenml/zen_stores/schemas/tag_schemas.py +14 -1
  28. zenml/zen_stores/secrets_stores/sql_secrets_store.py +5 -2
  29. zenml/zen_stores/sql_zen_store.py +24 -17
  30. {zenml_nightly-0.73.0.dev20250129.dist-info → zenml_nightly-0.73.0.dev20250131.dist-info}/METADATA +1 -1
  31. {zenml_nightly-0.73.0.dev20250129.dist-info → zenml_nightly-0.73.0.dev20250131.dist-info}/RECORD +34 -33
  32. {zenml_nightly-0.73.0.dev20250129.dist-info → zenml_nightly-0.73.0.dev20250131.dist-info}/LICENSE +0 -0
  33. {zenml_nightly-0.73.0.dev20250129.dist-info → zenml_nightly-0.73.0.dev20250131.dist-info}/WHEEL +0 -0
  34. {zenml_nightly-0.73.0.dev20250129.dist-info → zenml_nightly-0.73.0.dev20250131.dist-info}/entry_points.txt +0 -0
@@ -152,9 +152,14 @@ class VLLMDeploymentService(LocalDaemonService, BaseDeploymentService):
152
152
  parser: argparse.ArgumentParser = make_arg_parser(
153
153
  FlexibleArgumentParser()
154
154
  )
155
- args: argparse.Namespace = parser.parse_args()
155
+ # pass in empty list to get default args
156
+ # otherwise it will try to get the args from sys.argv
157
+ # and if there's a --config in there, it will want to use
158
+ # that file for vLLM configuration, which is not what we want
159
+ args: argparse.Namespace = parser.parse_args(args=[])
156
160
  # Override port with the available port
157
161
  self.config.port = self.endpoint.status.port or self.config.port
162
+
158
163
  # Update the arguments in place
159
164
  args.__dict__.update(self.config.model_dump())
160
165
  uvloop.run(run_server(args=args))
@@ -517,32 +517,9 @@ def verify_local_repository_context(
517
517
  "changes."
518
518
  )
519
519
 
520
- if local_repo_context:
521
- if local_repo_context.is_dirty:
522
- logger.warning(
523
- "Unable to use code repository to download code for this "
524
- "run as there are uncommitted changes."
525
- )
526
- elif local_repo_context.has_local_changes:
527
- logger.warning(
528
- "Unable to use code repository to download code for this "
529
- "run as there are unpushed changes."
530
- )
531
-
532
520
  code_repository = None
533
521
  if local_repo_context and not local_repo_context.has_local_changes:
534
- model = Client().get_code_repository(
535
- local_repo_context.code_repository_id
536
- )
537
- code_repository = BaseCodeRepository.from_model(model)
538
-
539
- if will_download_from_code_repository(
540
- deployment=deployment, local_repo_context=local_repo_context
541
- ):
542
- logger.info(
543
- "Using code repository `%s` to download code for this run.",
544
- model.name,
545
- )
522
+ code_repository = local_repo_context.code_repository
546
523
 
547
524
  return code_repository
548
525
 
@@ -738,25 +715,17 @@ def should_upload_code(
738
715
  return False
739
716
 
740
717
 
741
- def will_download_from_code_repository(
718
+ def allows_download_from_code_repository(
742
719
  deployment: PipelineDeploymentBase,
743
- local_repo_context: "LocalRepositoryContext",
744
720
  ) -> bool:
745
- """Checks whether a code repository will be used to download code.
721
+ """Checks whether a code repository can be used to download code.
746
722
 
747
723
  Args:
748
724
  deployment: The deployment.
749
- local_repo_context: The local repository context.
750
725
 
751
726
  Returns:
752
- Whether a code repository will be used to download code.
727
+ Whether a code repository can be used to download code.
753
728
  """
754
- if not build_required(deployment=deployment):
755
- return False
756
-
757
- if local_repo_context.has_local_changes:
758
- return False
759
-
760
729
  for step in deployment.step_configurations.values():
761
730
  docker_settings = step.config.docker_settings
762
731
 
@@ -764,3 +733,41 @@ def will_download_from_code_repository(
764
733
  return True
765
734
 
766
735
  return False
736
+
737
+
738
+ def log_code_repository_usage(
739
+ deployment: PipelineDeploymentBase,
740
+ local_repo_context: "LocalRepositoryContext",
741
+ ) -> None:
742
+ """Log what the code repository can (not) be used for given a deployment.
743
+
744
+ Args:
745
+ deployment: The deployment.
746
+ local_repo_context: The local repository context.
747
+ """
748
+ if build_required(deployment) and allows_download_from_code_repository(
749
+ deployment
750
+ ):
751
+ if local_repo_context.is_dirty:
752
+ logger.warning(
753
+ "Unable to use code repository `%s` to download code or track "
754
+ "the commit hash as there are uncommitted or untracked files.",
755
+ local_repo_context.code_repository.name,
756
+ )
757
+ elif local_repo_context.has_local_changes:
758
+ logger.warning(
759
+ "Unable to use code repository `%s` to download code as there "
760
+ "are unpushed commits.",
761
+ local_repo_context.code_repository.name,
762
+ )
763
+ else:
764
+ logger.info(
765
+ "Using code repository `%s` to download code for this run.",
766
+ local_repo_context.code_repository.name,
767
+ )
768
+ elif local_repo_context.is_dirty:
769
+ logger.warning(
770
+ "Unable to use code repository `%s` to track the commit hash as "
771
+ "there are uncommitted or untracked files.",
772
+ local_repo_context.code_repository.name,
773
+ )
@@ -643,7 +643,6 @@ To avoid this consider setting pipeline parameters only in one place (config or
643
643
  pipeline_id = None
644
644
  if register_pipeline:
645
645
  pipeline_id = self._register().id
646
-
647
646
  else:
648
647
  logger.debug(f"Pipeline {self.name} is unlisted.")
649
648
 
@@ -702,6 +701,10 @@ To avoid this consider setting pipeline parameters only in one place (config or
702
701
  deployment=deployment, local_repo_context=local_repo_context
703
702
  )
704
703
  can_download_from_code_repository = code_repository is not None
704
+ if local_repo_context:
705
+ build_utils.log_code_repository_usage(
706
+ deployment=deployment, local_repo_context=local_repo_context
707
+ )
705
708
 
706
709
  if prevent_build_reuse:
707
710
  logger.warning(
@@ -731,7 +734,7 @@ To avoid this consider setting pipeline parameters only in one place (config or
731
734
  code_reference = CodeReferenceRequest(
732
735
  commit=local_repo_context.current_commit,
733
736
  subdirectory=subdirectory.as_posix(),
734
- code_repository=local_repo_context.code_repository_id,
737
+ code_repository=local_repo_context.code_repository.id,
735
738
  )
736
739
 
737
740
  code_path = None
@@ -79,7 +79,7 @@ def set_custom_local_repository(
79
79
 
80
80
  path = os.path.abspath(source_utils.get_source_root())
81
81
  _CODE_REPOSITORY_CACHE[path] = _DownloadedRepositoryContext(
82
- code_repository_id=repo.id, root=root, commit=commit
82
+ code_repository=repo, root=root, commit=commit
83
83
  )
84
84
 
85
85
 
@@ -106,7 +106,8 @@ def find_active_code_repository(
106
106
  return _CODE_REPOSITORY_CACHE[path]
107
107
 
108
108
  local_context: Optional["LocalRepositoryContext"] = None
109
- for model in depaginate(list_method=Client().list_code_repositories):
109
+ code_repositories = depaginate(list_method=Client().list_code_repositories)
110
+ for model in code_repositories:
110
111
  try:
111
112
  repo = BaseCodeRepository.from_model(model)
112
113
  except ImportError:
@@ -125,6 +126,14 @@ def find_active_code_repository(
125
126
  local_context = repo.get_local_context(path)
126
127
  if local_context:
127
128
  break
129
+ else:
130
+ if code_repositories:
131
+ # There are registered code repositories, but none was matching the
132
+ # current path -> We log the path to help in debugging issues
133
+ # related to the source root.
134
+ logger.info(
135
+ "No matching code repository found for path `%s`.", path
136
+ )
128
137
 
129
138
  _CODE_REPOSITORY_CACHE[path] = local_context
130
139
  return local_context
@@ -13,9 +13,7 @@
13
13
  # permissions and limitations under the License.
14
14
  """Downloaded code repository."""
15
15
 
16
- from uuid import UUID
17
-
18
- from zenml.code_repositories import LocalRepositoryContext
16
+ from zenml.code_repositories import BaseCodeRepository, LocalRepositoryContext
19
17
 
20
18
 
21
19
  class _DownloadedRepositoryContext(LocalRepositoryContext):
@@ -27,11 +25,11 @@ class _DownloadedRepositoryContext(LocalRepositoryContext):
27
25
 
28
26
  def __init__(
29
27
  self,
30
- code_repository_id: UUID,
28
+ code_repository: BaseCodeRepository,
31
29
  root: str,
32
30
  commit: str,
33
31
  ):
34
- super().__init__(code_repository_id=code_repository_id)
32
+ super().__init__(code_repository=code_repository)
35
33
  self._root = root
36
34
  self._commit = commit
37
35
 
@@ -226,7 +226,7 @@ def resolve(
226
226
  subdir = PurePath(source_root).relative_to(local_repo_context.root)
227
227
 
228
228
  return CodeRepositorySource(
229
- repository_id=local_repo_context.code_repository_id,
229
+ repository_id=local_repo_context.code_repository.id,
230
230
  commit=local_repo_context.current_commit,
231
231
  subdirectory=subdir.as_posix(),
232
232
  module=module_name,
@@ -482,7 +482,7 @@ def _warn_about_potential_source_loading_issues(
482
482
  source.repository_id,
483
483
  get_source_root(),
484
484
  )
485
- elif local_repo.code_repository_id != source.repository_id:
485
+ elif local_repo.code_repository.id != source.repository_id:
486
486
  logger.warning(
487
487
  "Potential issue when loading the source `%s`: The source "
488
488
  "references the code repository `%s` but there is a different "
@@ -492,7 +492,7 @@ def _warn_about_potential_source_loading_issues(
492
492
  "source was originally stored.",
493
493
  source.import_path,
494
494
  source.repository_id,
495
- local_repo.code_repository_id,
495
+ local_repo.code_repository.id,
496
496
  get_source_root(),
497
497
  )
498
498
  elif local_repo.current_commit != source.commit:
@@ -34,7 +34,7 @@ from sqlalchemy.engine import URL, Engine
34
34
  from sqlalchemy.exc import (
35
35
  OperationalError,
36
36
  )
37
- from sqlalchemy.schema import CreateTable
37
+ from sqlalchemy.schema import CreateIndex, CreateTable
38
38
  from sqlmodel import (
39
39
  create_engine,
40
40
  select,
@@ -249,6 +249,7 @@ class MigrationUtils(BaseModel):
249
249
  # them to the create table statement.
250
250
 
251
251
  # Extract the unique constraints from the table schema
252
+ index_create_statements = []
252
253
  unique_constraints = []
253
254
  for index in table.indexes:
254
255
  if index.unique:
@@ -258,6 +259,38 @@ class MigrationUtils(BaseModel):
258
259
  unique_constraints.append(
259
260
  f"UNIQUE KEY `{index.name}` ({', '.join(unique_columns)})"
260
261
  )
262
+ else:
263
+ if index.name in {
264
+ fk.name for fk in table.foreign_key_constraints
265
+ }:
266
+ # Foreign key indices are already handled by the
267
+ # table creation statement.
268
+ continue
269
+
270
+ index_create = str(CreateIndex(index)).strip() # type: ignore[no-untyped-call]
271
+ index_create = index_create.replace(
272
+ f"CREATE INDEX {index.name}",
273
+ f"CREATE INDEX `{index.name}`",
274
+ )
275
+ index_create = index_create.replace(
276
+ f"ON {table.name}", f"ON `{table.name}`"
277
+ )
278
+
279
+ for column_name in index.columns.keys():
280
+ # We need this logic here to avoid the column names
281
+ # inside the index name
282
+ index_create = index_create.replace(
283
+ f"({column_name}", f"(`{column_name}`"
284
+ )
285
+ index_create = index_create.replace(
286
+ f"{column_name},", f"`{column_name}`,"
287
+ )
288
+ index_create = index_create.replace(
289
+ f"{column_name})", f"`{column_name}`)"
290
+ )
291
+
292
+ index_create = index_create.replace('"', "") + ";"
293
+ index_create_statements.append(index_create)
261
294
 
262
295
  # Add the unique constraints to the create table statement
263
296
  if unique_constraints:
@@ -290,6 +323,14 @@ class MigrationUtils(BaseModel):
290
323
  )
291
324
  )
292
325
 
326
+ for stmt in index_create_statements:
327
+ store_db_info(
328
+ dict(
329
+ table=table.name,
330
+ index_create_stmt=stmt,
331
+ )
332
+ )
333
+
293
334
  # 2. extract the table data in batches
294
335
  order_by = [col for col in table.primary_key]
295
336
 
@@ -356,6 +397,12 @@ class MigrationUtils(BaseModel):
356
397
  "self_references", False
357
398
  )
358
399
 
400
+ if "index_create_stmt" in table_dump:
401
+ # execute the index creation statement
402
+ connection.execute(text(table_dump["index_create_stmt"]))
403
+ # Reload the database metadata after creating the index
404
+ metadata.reflect(bind=self.engine)
405
+
359
406
  if "data" in table_dump:
360
407
  # insert the data into the database
361
408
  table = metadata.tables[table_name]
@@ -0,0 +1,67 @@
1
+ """Add run metadata and tag index [4d5524b92a30].
2
+
3
+ Revision ID: 4d5524b92a30
4
+ Revises: 0.73.0
5
+ Create Date: 2025-01-30 11:30:36.736452
6
+
7
+ """
8
+
9
+ from alembic import op
10
+ from sqlalchemy import inspect
11
+
12
+ # revision identifiers, used by Alembic.
13
+ revision = "4d5524b92a30"
14
+ down_revision = "0.73.0"
15
+ branch_labels = None
16
+ depends_on = None
17
+
18
+
19
+ def upgrade() -> None:
20
+ """Upgrade database schema and/or data, creating a new revision."""
21
+ connection = op.get_bind()
22
+
23
+ inspector = inspect(connection)
24
+ for index in inspector.get_indexes("run_metadata_resource"):
25
+ # This index was manually added to some databases to improve the
26
+ # speed and cache utilisation. In this case we simply return here and
27
+ # don't continue with the migration.
28
+ if (
29
+ index["name"]
30
+ == "ix_run_metadata_resource_resource_id_resource_type_run_metadata_"
31
+ ):
32
+ return
33
+
34
+ # ### commands auto generated by Alembic - please adjust! ###
35
+ with op.batch_alter_table(
36
+ "run_metadata_resource", schema=None
37
+ ) as batch_op:
38
+ batch_op.create_index(
39
+ "ix_run_metadata_resource_resource_id_resource_type_run_metadata_",
40
+ ["resource_id", "resource_type", "run_metadata_id"],
41
+ unique=False,
42
+ )
43
+
44
+ with op.batch_alter_table("tag_resource", schema=None) as batch_op:
45
+ batch_op.create_index(
46
+ "ix_tag_resource_resource_id_resource_type_tag_id",
47
+ ["resource_id", "resource_type", "tag_id"],
48
+ unique=False,
49
+ )
50
+
51
+ # ### end Alembic commands ###
52
+
53
+
54
+ def downgrade() -> None:
55
+ """Downgrade database schema and/or data back to the previous revision."""
56
+ # ### commands auto generated by Alembic - please adjust! ###
57
+ with op.batch_alter_table("tag_resource", schema=None) as batch_op:
58
+ batch_op.drop_index("ix_tag_resource_resource_id_resource_type_tag_id")
59
+
60
+ with op.batch_alter_table(
61
+ "run_metadata_resource", schema=None
62
+ ) as batch_op:
63
+ batch_op.drop_index(
64
+ "ix_run_metadata_resource_resource_id_resource_type_run_metadata_"
65
+ )
66
+
67
+ # ### end Alembic commands ###
@@ -365,17 +365,16 @@ class RestZenStoreConfiguration(StoreConfiguration):
365
365
 
366
366
  if os.path.isfile(verify_ssl):
367
367
  with open(verify_ssl, "r") as f:
368
- verify_ssl = f.read()
368
+ cert_content = f.read()
369
369
 
370
370
  fileio.makedirs(str(secret_folder))
371
371
  file_path = Path(secret_folder, "ca_bundle.pem")
372
372
  with os.fdopen(
373
373
  os.open(file_path, flags=os.O_RDWR | os.O_CREAT, mode=0o600), "w"
374
374
  ) as f:
375
- f.write(verify_ssl)
376
- verify_ssl = str(file_path)
375
+ f.write(cert_content)
377
376
 
378
- return verify_ssl
377
+ return str(file_path)
379
378
 
380
379
  @classmethod
381
380
  def supports_url_scheme(cls, url: str) -> bool:
@@ -389,15 +388,6 @@ class RestZenStoreConfiguration(StoreConfiguration):
389
388
  """
390
389
  return urlparse(url).scheme in ("http", "https")
391
390
 
392
- def expand_certificates(self) -> None:
393
- """Expands the certificates in the verify_ssl field."""
394
- # Load the certificate values back into the configuration
395
- if isinstance(self.verify_ssl, str) and os.path.isfile(
396
- self.verify_ssl
397
- ):
398
- with open(self.verify_ssl, "r") as f:
399
- self.verify_ssl = f.read()
400
-
401
391
  @model_validator(mode="before")
402
392
  @classmethod
403
393
  @before_validator_handler
@@ -11,7 +11,7 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12
12
  # or implied. See the License for the specific language governing
13
13
  # permissions and limitations under the License.
14
- """SQLModel implementation of pipeline run metadata tables."""
14
+ """SQLModel implementation of run metadata tables."""
15
15
 
16
16
  from typing import Optional
17
17
  from uuid import UUID, uuid4
@@ -21,7 +21,10 @@ from sqlmodel import Field, Relationship, SQLModel
21
21
 
22
22
  from zenml.zen_stores.schemas.base_schemas import BaseSchema
23
23
  from zenml.zen_stores.schemas.component_schemas import StackComponentSchema
24
- from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field
24
+ from zenml.zen_stores.schemas.schema_utils import (
25
+ build_foreign_key_field,
26
+ build_index,
27
+ )
25
28
  from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema
26
29
  from zenml.zen_stores.schemas.user_schemas import UserSchema
27
30
  from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema
@@ -82,6 +85,16 @@ class RunMetadataResourceSchema(SQLModel, table=True):
82
85
  """Table for linking resources to run metadata entries."""
83
86
 
84
87
  __tablename__ = "run_metadata_resource"
88
+ __table_args__ = (
89
+ build_index(
90
+ table_name=__tablename__,
91
+ column_names=[
92
+ "resource_id",
93
+ "resource_type",
94
+ "run_metadata_id",
95
+ ],
96
+ ),
97
+ )
85
98
 
86
99
  id: UUID = Field(default_factory=uuid4, primary_key=True)
87
100
  resource_id: UUID
@@ -13,9 +13,9 @@
13
13
  # permissions and limitations under the License.
14
14
  """Utility functions for SQLModel schemas."""
15
15
 
16
- from typing import Any
16
+ from typing import Any, List
17
17
 
18
- from sqlalchemy import Column, ForeignKey
18
+ from sqlalchemy import Column, ForeignKey, Index
19
19
  from sqlmodel import Field
20
20
 
21
21
 
@@ -84,3 +84,35 @@ def build_foreign_key_field(
84
84
  **sa_column_kwargs,
85
85
  ),
86
86
  )
87
+
88
+
89
+ def get_index_name(table_name: str, column_names: List[str]) -> str:
90
+ """Get the name for an index.
91
+
92
+ Args:
93
+ table_name: The name of the table for which the index will be created.
94
+ column_names: Names of the columns on which the index will be created.
95
+
96
+ Returns:
97
+ The index name.
98
+ """
99
+ columns = "_".join(column_names)
100
+ # MySQL allows a maximum of 64 characters in identifiers
101
+ return f"ix_{table_name}_{columns}"[:64]
102
+
103
+
104
+ def build_index(
105
+ table_name: str, column_names: List[str], **kwargs: Any
106
+ ) -> Index:
107
+ """Build an index object.
108
+
109
+ Args:
110
+ table_name: The name of the table for which the index will be created.
111
+ column_names: Names of the columns on which the index will be created.
112
+ **kwargs: Additional keyword arguments to pass to the Index.
113
+
114
+ Returns:
115
+ The index.
116
+ """
117
+ name = get_index_name(table_name=table_name, column_names=column_names)
118
+ return Index(name, *column_names, **kwargs)
@@ -31,7 +31,10 @@ from zenml.models import (
31
31
  )
32
32
  from zenml.utils.time_utils import utc_now
33
33
  from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema
34
- from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field
34
+ from zenml.zen_stores.schemas.schema_utils import (
35
+ build_foreign_key_field,
36
+ build_index,
37
+ )
35
38
 
36
39
 
37
40
  class TagSchema(NamedSchema, table=True):
@@ -111,6 +114,16 @@ class TagResourceSchema(BaseSchema, table=True):
111
114
  """SQL Model for tag resource relationship."""
112
115
 
113
116
  __tablename__ = "tag_resource"
117
+ __table_args__ = (
118
+ build_index(
119
+ table_name=__tablename__,
120
+ column_names=[
121
+ "resource_id",
122
+ "resource_type",
123
+ "tag_id",
124
+ ],
125
+ ),
126
+ )
114
127
 
115
128
  tag_id: UUID = build_foreign_key_field(
116
129
  source=__tablename__,
@@ -37,6 +37,7 @@ from zenml.exceptions import (
37
37
  IllegalOperationError,
38
38
  )
39
39
  from zenml.logger import get_logger
40
+ from zenml.utils.secret_utils import PlainSerializedSecretStr
40
41
  from zenml.zen_stores.schemas import (
41
42
  SecretSchema,
42
43
  )
@@ -62,7 +63,7 @@ class SqlSecretsStoreConfiguration(SecretsStoreConfiguration):
62
63
  """
63
64
 
64
65
  type: SecretsStoreType = SecretsStoreType.SQL
65
- encryption_key: Optional[str] = None
66
+ encryption_key: Optional[PlainSerializedSecretStr] = None
66
67
  model_config = ConfigDict(
67
68
  # Don't validate attributes when assigning them. This is necessary
68
69
  # because the certificate attributes can be expanded to the contents
@@ -159,7 +160,9 @@ class SqlSecretsStore(BaseSecretsStore):
159
160
  # Initialize the encryption engine
160
161
  if self.config.encryption_key:
161
162
  self._encryption_engine = AesGcmEngine()
162
- self._encryption_engine._update_key(self.config.encryption_key)
163
+ self._encryption_engine._update_key(
164
+ self.config.encryption_key.get_secret_value()
165
+ )
163
166
 
164
167
  # Nothing else to do here, the SQL ZenML store back-end is already
165
168
  # initialized
@@ -304,6 +304,7 @@ from zenml.utils.networking_utils import (
304
304
  replace_localhost_with_internal_hostname,
305
305
  )
306
306
  from zenml.utils.pydantic_utils import before_validator_handler
307
+ from zenml.utils.secret_utils import PlainSerializedSecretStr
307
308
  from zenml.utils.string_utils import (
308
309
  format_name_template,
309
310
  random_str,
@@ -460,11 +461,11 @@ class SqlZenStoreConfiguration(StoreConfiguration):
460
461
 
461
462
  driver: Optional[SQLDatabaseDriver] = None
462
463
  database: Optional[str] = None
463
- username: Optional[str] = None
464
- password: Optional[str] = None
465
- ssl_ca: Optional[str] = None
466
- ssl_cert: Optional[str] = None
467
- ssl_key: Optional[str] = None
464
+ username: Optional[PlainSerializedSecretStr] = None
465
+ password: Optional[PlainSerializedSecretStr] = None
466
+ ssl_ca: Optional[PlainSerializedSecretStr] = None
467
+ ssl_cert: Optional[PlainSerializedSecretStr] = None
468
+ ssl_key: Optional[PlainSerializedSecretStr] = None
468
469
  ssl_verify_server_cert: bool = False
469
470
  pool_size: int = 20
470
471
  max_overflow: int = 20
@@ -611,10 +612,10 @@ class SqlZenStoreConfiguration(StoreConfiguration):
611
612
  self.database = sql_url.database
612
613
  elif sql_url.drivername == SQLDatabaseDriver.MYSQL:
613
614
  if sql_url.username:
614
- self.username = sql_url.username
615
+ self.username = PlainSerializedSecretStr(sql_url.username)
615
616
  sql_url = sql_url._replace(username=None)
616
617
  if sql_url.password:
617
- self.password = sql_url.password
618
+ self.password = PlainSerializedSecretStr(sql_url.password)
618
619
  sql_url = sql_url._replace(password=None)
619
620
  if sql_url.database:
620
621
  self.database = sql_url.database
@@ -642,13 +643,13 @@ class SqlZenStoreConfiguration(StoreConfiguration):
642
643
  for k, v in sql_url.query.items():
643
644
  if k == "ssl_ca":
644
645
  if r := _get_query_result(v):
645
- self.ssl_ca = r
646
+ self.ssl_ca = PlainSerializedSecretStr(r)
646
647
  elif k == "ssl_cert":
647
648
  if r := _get_query_result(v):
648
- self.ssl_cert = r
649
+ self.ssl_cert = PlainSerializedSecretStr(r)
649
650
  elif k == "ssl_key":
650
651
  if r := _get_query_result(v):
651
- self.ssl_key = r
652
+ self.ssl_key = PlainSerializedSecretStr(r)
652
653
  elif k == "ssl_verify_server_cert":
653
654
  if r := _get_query_result(v):
654
655
  if is_true_string_value(r):
@@ -688,7 +689,7 @@ class SqlZenStoreConfiguration(StoreConfiguration):
688
689
  )
689
690
  for key in ["ssl_key", "ssl_ca", "ssl_cert"]:
690
691
  content = getattr(self, key)
691
- if content and not os.path.isfile(content):
692
+ if content and not os.path.isfile(content.get_secret_value()):
692
693
  fileio.makedirs(str(secret_folder))
693
694
  file_path = Path(secret_folder, f"{key}.pem")
694
695
  with os.fdopen(
@@ -697,7 +698,7 @@ class SqlZenStoreConfiguration(StoreConfiguration):
697
698
  ),
698
699
  "w",
699
700
  ) as f:
700
- f.write(content)
701
+ f.write(content.get_secret_value())
701
702
  setattr(self, key, str(file_path))
702
703
 
703
704
  self.url = str(sql_url)
@@ -732,7 +733,7 @@ class SqlZenStoreConfiguration(StoreConfiguration):
732
733
  # Load the certificate values back into the configuration
733
734
  for key in ["ssl_key", "ssl_ca", "ssl_cert"]:
734
735
  file_path = getattr(self, key, None)
735
- if file_path and os.path.isfile(file_path):
736
+ if file_path and os.path.isfile(file_path.get_secret_value()):
736
737
  with open(file_path, "r") as f:
737
738
  setattr(self, key, f.read())
738
739
 
@@ -780,8 +781,8 @@ class SqlZenStoreConfiguration(StoreConfiguration):
780
781
 
781
782
  sql_url = sql_url._replace(
782
783
  drivername="mysql+pymysql",
783
- username=self.username,
784
- password=self.password,
784
+ username=self.username.get_secret_value(),
785
+ password=self.password.get_secret_value(),
785
786
  database=database,
786
787
  )
787
788
 
@@ -792,11 +793,17 @@ class SqlZenStoreConfiguration(StoreConfiguration):
792
793
  ssl_setting = getattr(self, key)
793
794
  if not ssl_setting:
794
795
  continue
795
- if not os.path.isfile(ssl_setting):
796
+ if not os.path.isfile(ssl_setting.get_secret_value()):
796
797
  logger.warning(
797
798
  f"Database SSL setting `{key}` is not a file. "
798
799
  )
799
- sqlalchemy_ssl_args[key.removeprefix("ssl_")] = ssl_setting
800
+ sqlalchemy_ssl_args[key.lstrip("ssl_")] = (
801
+ ssl_setting.get_secret_value()
802
+ )
803
+ sqlalchemy_ssl_args[key.removeprefix("ssl_")] = (
804
+ ssl_setting.get_secret_value()
805
+ )
806
+
800
807
  if len(sqlalchemy_ssl_args) > 0:
801
808
  sqlalchemy_ssl_args["check_hostname"] = (
802
809
  self.ssl_verify_server_cert