zenml-nightly 0.73.0.dev20250129__py3-none-any.whl → 0.73.0.dev20250131__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/cli/code_repository.py +26 -0
- zenml/cli/utils.py +14 -9
- zenml/client.py +2 -7
- zenml/code_repositories/base_code_repository.py +30 -2
- zenml/code_repositories/git/local_git_repository_context.py +26 -10
- zenml/code_repositories/local_repository_context.py +11 -8
- zenml/constants.py +3 -0
- zenml/integrations/gcp/constants.py +1 -1
- zenml/integrations/gcp/flavors/vertex_step_operator_flavor.py +3 -1
- zenml/integrations/gcp/step_operators/vertex_step_operator.py +1 -0
- zenml/integrations/github/code_repositories/github_code_repository.py +17 -2
- zenml/integrations/gitlab/code_repositories/gitlab_code_repository.py +17 -2
- zenml/integrations/huggingface/services/huggingface_deployment.py +72 -29
- zenml/integrations/pytorch/materializers/base_pytorch_materializer.py +1 -1
- zenml/integrations/vllm/services/vllm_deployment.py +6 -1
- zenml/pipelines/build_utils.py +42 -35
- zenml/pipelines/pipeline_definition.py +5 -2
- zenml/utils/code_repository_utils.py +11 -2
- zenml/utils/downloaded_repository_context.py +3 -5
- zenml/utils/source_utils.py +3 -3
- zenml/zen_stores/migrations/utils.py +48 -1
- zenml/zen_stores/migrations/versions/4d5524b92a30_add_run_metadata_tag_index.py +67 -0
- zenml/zen_stores/rest_zen_store.py +3 -13
- zenml/zen_stores/schemas/run_metadata_schemas.py +15 -2
- zenml/zen_stores/schemas/schema_utils.py +34 -2
- zenml/zen_stores/schemas/tag_schemas.py +14 -1
- zenml/zen_stores/secrets_stores/sql_secrets_store.py +5 -2
- zenml/zen_stores/sql_zen_store.py +24 -17
- {zenml_nightly-0.73.0.dev20250129.dist-info → zenml_nightly-0.73.0.dev20250131.dist-info}/METADATA +1 -1
- {zenml_nightly-0.73.0.dev20250129.dist-info → zenml_nightly-0.73.0.dev20250131.dist-info}/RECORD +34 -33
- {zenml_nightly-0.73.0.dev20250129.dist-info → zenml_nightly-0.73.0.dev20250131.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.73.0.dev20250129.dist-info → zenml_nightly-0.73.0.dev20250131.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.73.0.dev20250129.dist-info → zenml_nightly-0.73.0.dev20250131.dist-info}/entry_points.txt +0 -0
@@ -152,9 +152,14 @@ class VLLMDeploymentService(LocalDaemonService, BaseDeploymentService):
|
|
152
152
|
parser: argparse.ArgumentParser = make_arg_parser(
|
153
153
|
FlexibleArgumentParser()
|
154
154
|
)
|
155
|
-
|
155
|
+
# pass in empty list to get default args
|
156
|
+
# otherwise it will try to get the args from sys.argv
|
157
|
+
# and if there's a --config in there, it will want to use
|
158
|
+
# that file for vLLM configuration, which is not what we want
|
159
|
+
args: argparse.Namespace = parser.parse_args(args=[])
|
156
160
|
# Override port with the available port
|
157
161
|
self.config.port = self.endpoint.status.port or self.config.port
|
162
|
+
|
158
163
|
# Update the arguments in place
|
159
164
|
args.__dict__.update(self.config.model_dump())
|
160
165
|
uvloop.run(run_server(args=args))
|
zenml/pipelines/build_utils.py
CHANGED
@@ -517,32 +517,9 @@ def verify_local_repository_context(
|
|
517
517
|
"changes."
|
518
518
|
)
|
519
519
|
|
520
|
-
if local_repo_context:
|
521
|
-
if local_repo_context.is_dirty:
|
522
|
-
logger.warning(
|
523
|
-
"Unable to use code repository to download code for this "
|
524
|
-
"run as there are uncommitted changes."
|
525
|
-
)
|
526
|
-
elif local_repo_context.has_local_changes:
|
527
|
-
logger.warning(
|
528
|
-
"Unable to use code repository to download code for this "
|
529
|
-
"run as there are unpushed changes."
|
530
|
-
)
|
531
|
-
|
532
520
|
code_repository = None
|
533
521
|
if local_repo_context and not local_repo_context.has_local_changes:
|
534
|
-
|
535
|
-
local_repo_context.code_repository_id
|
536
|
-
)
|
537
|
-
code_repository = BaseCodeRepository.from_model(model)
|
538
|
-
|
539
|
-
if will_download_from_code_repository(
|
540
|
-
deployment=deployment, local_repo_context=local_repo_context
|
541
|
-
):
|
542
|
-
logger.info(
|
543
|
-
"Using code repository `%s` to download code for this run.",
|
544
|
-
model.name,
|
545
|
-
)
|
522
|
+
code_repository = local_repo_context.code_repository
|
546
523
|
|
547
524
|
return code_repository
|
548
525
|
|
@@ -738,25 +715,17 @@ def should_upload_code(
|
|
738
715
|
return False
|
739
716
|
|
740
717
|
|
741
|
-
def will_download_from_code_repository(
|
718
|
+
def allows_download_from_code_repository(
|
742
719
|
deployment: PipelineDeploymentBase,
|
743
|
-
local_repo_context: "LocalRepositoryContext",
|
744
720
|
) -> bool:
|
745
|
-
"""Checks whether a code repository
|
721
|
+
"""Checks whether a code repository can be used to download code.
|
746
722
|
|
747
723
|
Args:
|
748
724
|
deployment: The deployment.
|
749
|
-
local_repo_context: The local repository context.
|
750
725
|
|
751
726
|
Returns:
|
752
|
-
Whether a code repository
|
727
|
+
Whether a code repository can be used to download code.
|
753
728
|
"""
|
754
|
-
if not build_required(deployment=deployment):
|
755
|
-
return False
|
756
|
-
|
757
|
-
if local_repo_context.has_local_changes:
|
758
|
-
return False
|
759
|
-
|
760
729
|
for step in deployment.step_configurations.values():
|
761
730
|
docker_settings = step.config.docker_settings
|
762
731
|
|
@@ -764,3 +733,41 @@ def will_download_from_code_repository(
|
|
764
733
|
return True
|
765
734
|
|
766
735
|
return False
|
736
|
+
|
737
|
+
|
738
|
+
def log_code_repository_usage(
|
739
|
+
deployment: PipelineDeploymentBase,
|
740
|
+
local_repo_context: "LocalRepositoryContext",
|
741
|
+
) -> None:
|
742
|
+
"""Log what the code repository can (not) be used for given a deployment.
|
743
|
+
|
744
|
+
Args:
|
745
|
+
deployment: The deployment.
|
746
|
+
local_repo_context: The local repository context.
|
747
|
+
"""
|
748
|
+
if build_required(deployment) and allows_download_from_code_repository(
|
749
|
+
deployment
|
750
|
+
):
|
751
|
+
if local_repo_context.is_dirty:
|
752
|
+
logger.warning(
|
753
|
+
"Unable to use code repository `%s` to download code or track "
|
754
|
+
"the commit hash as there are uncommitted or untracked files.",
|
755
|
+
local_repo_context.code_repository.name,
|
756
|
+
)
|
757
|
+
elif local_repo_context.has_local_changes:
|
758
|
+
logger.warning(
|
759
|
+
"Unable to use code repository `%s` to download code as there "
|
760
|
+
"are unpushed commits.",
|
761
|
+
local_repo_context.code_repository.name,
|
762
|
+
)
|
763
|
+
else:
|
764
|
+
logger.info(
|
765
|
+
"Using code repository `%s` to download code for this run.",
|
766
|
+
local_repo_context.code_repository.name,
|
767
|
+
)
|
768
|
+
elif local_repo_context.is_dirty:
|
769
|
+
logger.warning(
|
770
|
+
"Unable to use code repository `%s` to track the commit hash as "
|
771
|
+
"there are uncommitted or untracked files.",
|
772
|
+
local_repo_context.code_repository.name,
|
773
|
+
)
|
@@ -643,7 +643,6 @@ To avoid this consider setting pipeline parameters only in one place (config or
|
|
643
643
|
pipeline_id = None
|
644
644
|
if register_pipeline:
|
645
645
|
pipeline_id = self._register().id
|
646
|
-
|
647
646
|
else:
|
648
647
|
logger.debug(f"Pipeline {self.name} is unlisted.")
|
649
648
|
|
@@ -702,6 +701,10 @@ To avoid this consider setting pipeline parameters only in one place (config or
|
|
702
701
|
deployment=deployment, local_repo_context=local_repo_context
|
703
702
|
)
|
704
703
|
can_download_from_code_repository = code_repository is not None
|
704
|
+
if local_repo_context:
|
705
|
+
build_utils.log_code_repository_usage(
|
706
|
+
deployment=deployment, local_repo_context=local_repo_context
|
707
|
+
)
|
705
708
|
|
706
709
|
if prevent_build_reuse:
|
707
710
|
logger.warning(
|
@@ -731,7 +734,7 @@ To avoid this consider setting pipeline parameters only in one place (config or
|
|
731
734
|
code_reference = CodeReferenceRequest(
|
732
735
|
commit=local_repo_context.current_commit,
|
733
736
|
subdirectory=subdirectory.as_posix(),
|
734
|
-
code_repository=local_repo_context.code_repository_id,
|
737
|
+
code_repository=local_repo_context.code_repository.id,
|
735
738
|
)
|
736
739
|
|
737
740
|
code_path = None
|
@@ -79,7 +79,7 @@ def set_custom_local_repository(
|
|
79
79
|
|
80
80
|
path = os.path.abspath(source_utils.get_source_root())
|
81
81
|
_CODE_REPOSITORY_CACHE[path] = _DownloadedRepositoryContext(
|
82
|
-
|
82
|
+
code_repository=repo, root=root, commit=commit
|
83
83
|
)
|
84
84
|
|
85
85
|
|
@@ -106,7 +106,8 @@ def find_active_code_repository(
|
|
106
106
|
return _CODE_REPOSITORY_CACHE[path]
|
107
107
|
|
108
108
|
local_context: Optional["LocalRepositoryContext"] = None
|
109
|
-
|
109
|
+
code_repositories = depaginate(list_method=Client().list_code_repositories)
|
110
|
+
for model in code_repositories:
|
110
111
|
try:
|
111
112
|
repo = BaseCodeRepository.from_model(model)
|
112
113
|
except ImportError:
|
@@ -125,6 +126,14 @@ def find_active_code_repository(
|
|
125
126
|
local_context = repo.get_local_context(path)
|
126
127
|
if local_context:
|
127
128
|
break
|
129
|
+
else:
|
130
|
+
if code_repositories:
|
131
|
+
# There are registered code repositories, but none was matching the
|
132
|
+
# current path -> We log the path to help in debugging issues
|
133
|
+
# related to the source root.
|
134
|
+
logger.info(
|
135
|
+
"No matching code repository found for path `%s`.", path
|
136
|
+
)
|
128
137
|
|
129
138
|
_CODE_REPOSITORY_CACHE[path] = local_context
|
130
139
|
return local_context
|
@@ -13,9 +13,7 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Downloaded code repository."""
|
15
15
|
|
16
|
-
from
|
17
|
-
|
18
|
-
from zenml.code_repositories import LocalRepositoryContext
|
16
|
+
from zenml.code_repositories import BaseCodeRepository, LocalRepositoryContext
|
19
17
|
|
20
18
|
|
21
19
|
class _DownloadedRepositoryContext(LocalRepositoryContext):
|
@@ -27,11 +25,11 @@ class _DownloadedRepositoryContext(LocalRepositoryContext):
|
|
27
25
|
|
28
26
|
def __init__(
|
29
27
|
self,
|
30
|
-
|
28
|
+
code_repository: BaseCodeRepository,
|
31
29
|
root: str,
|
32
30
|
commit: str,
|
33
31
|
):
|
34
|
-
super().__init__(
|
32
|
+
super().__init__(code_repository=code_repository)
|
35
33
|
self._root = root
|
36
34
|
self._commit = commit
|
37
35
|
|
zenml/utils/source_utils.py
CHANGED
@@ -226,7 +226,7 @@ def resolve(
|
|
226
226
|
subdir = PurePath(source_root).relative_to(local_repo_context.root)
|
227
227
|
|
228
228
|
return CodeRepositorySource(
|
229
|
-
repository_id=local_repo_context.code_repository_id,
|
229
|
+
repository_id=local_repo_context.code_repository.id,
|
230
230
|
commit=local_repo_context.current_commit,
|
231
231
|
subdirectory=subdir.as_posix(),
|
232
232
|
module=module_name,
|
@@ -482,7 +482,7 @@ def _warn_about_potential_source_loading_issues(
|
|
482
482
|
source.repository_id,
|
483
483
|
get_source_root(),
|
484
484
|
)
|
485
|
-
elif local_repo.code_repository_id != source.repository_id:
|
485
|
+
elif local_repo.code_repository.id != source.repository_id:
|
486
486
|
logger.warning(
|
487
487
|
"Potential issue when loading the source `%s`: The source "
|
488
488
|
"references the code repository `%s` but there is a different "
|
@@ -492,7 +492,7 @@ def _warn_about_potential_source_loading_issues(
|
|
492
492
|
"source was originally stored.",
|
493
493
|
source.import_path,
|
494
494
|
source.repository_id,
|
495
|
-
local_repo.code_repository_id,
|
495
|
+
local_repo.code_repository.id,
|
496
496
|
get_source_root(),
|
497
497
|
)
|
498
498
|
elif local_repo.current_commit != source.commit:
|
@@ -34,7 +34,7 @@ from sqlalchemy.engine import URL, Engine
|
|
34
34
|
from sqlalchemy.exc import (
|
35
35
|
OperationalError,
|
36
36
|
)
|
37
|
-
from sqlalchemy.schema import CreateTable
|
37
|
+
from sqlalchemy.schema import CreateIndex, CreateTable
|
38
38
|
from sqlmodel import (
|
39
39
|
create_engine,
|
40
40
|
select,
|
@@ -249,6 +249,7 @@ class MigrationUtils(BaseModel):
|
|
249
249
|
# them to the create table statement.
|
250
250
|
|
251
251
|
# Extract the unique constraints from the table schema
|
252
|
+
index_create_statements = []
|
252
253
|
unique_constraints = []
|
253
254
|
for index in table.indexes:
|
254
255
|
if index.unique:
|
@@ -258,6 +259,38 @@ class MigrationUtils(BaseModel):
|
|
258
259
|
unique_constraints.append(
|
259
260
|
f"UNIQUE KEY `{index.name}` ({', '.join(unique_columns)})"
|
260
261
|
)
|
262
|
+
else:
|
263
|
+
if index.name in {
|
264
|
+
fk.name for fk in table.foreign_key_constraints
|
265
|
+
}:
|
266
|
+
# Foreign key indices are already handled by the
|
267
|
+
# table creation statement.
|
268
|
+
continue
|
269
|
+
|
270
|
+
index_create = str(CreateIndex(index)).strip() # type: ignore[no-untyped-call]
|
271
|
+
index_create = index_create.replace(
|
272
|
+
f"CREATE INDEX {index.name}",
|
273
|
+
f"CREATE INDEX `{index.name}`",
|
274
|
+
)
|
275
|
+
index_create = index_create.replace(
|
276
|
+
f"ON {table.name}", f"ON `{table.name}`"
|
277
|
+
)
|
278
|
+
|
279
|
+
for column_name in index.columns.keys():
|
280
|
+
# We need this logic here to avoid the column names
|
281
|
+
# inside the index name
|
282
|
+
index_create = index_create.replace(
|
283
|
+
f"({column_name}", f"(`{column_name}`"
|
284
|
+
)
|
285
|
+
index_create = index_create.replace(
|
286
|
+
f"{column_name},", f"`{column_name}`,"
|
287
|
+
)
|
288
|
+
index_create = index_create.replace(
|
289
|
+
f"{column_name})", f"`{column_name}`)"
|
290
|
+
)
|
291
|
+
|
292
|
+
index_create = index_create.replace('"', "") + ";"
|
293
|
+
index_create_statements.append(index_create)
|
261
294
|
|
262
295
|
# Add the unique constraints to the create table statement
|
263
296
|
if unique_constraints:
|
@@ -290,6 +323,14 @@ class MigrationUtils(BaseModel):
|
|
290
323
|
)
|
291
324
|
)
|
292
325
|
|
326
|
+
for stmt in index_create_statements:
|
327
|
+
store_db_info(
|
328
|
+
dict(
|
329
|
+
table=table.name,
|
330
|
+
index_create_stmt=stmt,
|
331
|
+
)
|
332
|
+
)
|
333
|
+
|
293
334
|
# 2. extract the table data in batches
|
294
335
|
order_by = [col for col in table.primary_key]
|
295
336
|
|
@@ -356,6 +397,12 @@ class MigrationUtils(BaseModel):
|
|
356
397
|
"self_references", False
|
357
398
|
)
|
358
399
|
|
400
|
+
if "index_create_stmt" in table_dump:
|
401
|
+
# execute the index creation statement
|
402
|
+
connection.execute(text(table_dump["index_create_stmt"]))
|
403
|
+
# Reload the database metadata after creating the index
|
404
|
+
metadata.reflect(bind=self.engine)
|
405
|
+
|
359
406
|
if "data" in table_dump:
|
360
407
|
# insert the data into the database
|
361
408
|
table = metadata.tables[table_name]
|
@@ -0,0 +1,67 @@
|
|
1
|
+
"""Add run metadata and tag index [4d5524b92a30].
|
2
|
+
|
3
|
+
Revision ID: 4d5524b92a30
|
4
|
+
Revises: 0.73.0
|
5
|
+
Create Date: 2025-01-30 11:30:36.736452
|
6
|
+
|
7
|
+
"""
|
8
|
+
|
9
|
+
from alembic import op
|
10
|
+
from sqlalchemy import inspect
|
11
|
+
|
12
|
+
# revision identifiers, used by Alembic.
|
13
|
+
revision = "4d5524b92a30"
|
14
|
+
down_revision = "0.73.0"
|
15
|
+
branch_labels = None
|
16
|
+
depends_on = None
|
17
|
+
|
18
|
+
|
19
|
+
def upgrade() -> None:
|
20
|
+
"""Upgrade database schema and/or data, creating a new revision."""
|
21
|
+
connection = op.get_bind()
|
22
|
+
|
23
|
+
inspector = inspect(connection)
|
24
|
+
for index in inspector.get_indexes("run_metadata_resource"):
|
25
|
+
# This index was manually added to some databases to improve the
|
26
|
+
# speed and cache utilisation. In this case we simply return here and
|
27
|
+
# don't continue with the migration.
|
28
|
+
if (
|
29
|
+
index["name"]
|
30
|
+
== "ix_run_metadata_resource_resource_id_resource_type_run_metadata_"
|
31
|
+
):
|
32
|
+
return
|
33
|
+
|
34
|
+
# ### commands auto generated by Alembic - please adjust! ###
|
35
|
+
with op.batch_alter_table(
|
36
|
+
"run_metadata_resource", schema=None
|
37
|
+
) as batch_op:
|
38
|
+
batch_op.create_index(
|
39
|
+
"ix_run_metadata_resource_resource_id_resource_type_run_metadata_",
|
40
|
+
["resource_id", "resource_type", "run_metadata_id"],
|
41
|
+
unique=False,
|
42
|
+
)
|
43
|
+
|
44
|
+
with op.batch_alter_table("tag_resource", schema=None) as batch_op:
|
45
|
+
batch_op.create_index(
|
46
|
+
"ix_tag_resource_resource_id_resource_type_tag_id",
|
47
|
+
["resource_id", "resource_type", "tag_id"],
|
48
|
+
unique=False,
|
49
|
+
)
|
50
|
+
|
51
|
+
# ### end Alembic commands ###
|
52
|
+
|
53
|
+
|
54
|
+
def downgrade() -> None:
|
55
|
+
"""Downgrade database schema and/or data back to the previous revision."""
|
56
|
+
# ### commands auto generated by Alembic - please adjust! ###
|
57
|
+
with op.batch_alter_table("tag_resource", schema=None) as batch_op:
|
58
|
+
batch_op.drop_index("ix_tag_resource_resource_id_resource_type_tag_id")
|
59
|
+
|
60
|
+
with op.batch_alter_table(
|
61
|
+
"run_metadata_resource", schema=None
|
62
|
+
) as batch_op:
|
63
|
+
batch_op.drop_index(
|
64
|
+
"ix_run_metadata_resource_resource_id_resource_type_run_metadata_"
|
65
|
+
)
|
66
|
+
|
67
|
+
# ### end Alembic commands ###
|
@@ -365,17 +365,16 @@ class RestZenStoreConfiguration(StoreConfiguration):
|
|
365
365
|
|
366
366
|
if os.path.isfile(verify_ssl):
|
367
367
|
with open(verify_ssl, "r") as f:
|
368
|
-
|
368
|
+
cert_content = f.read()
|
369
369
|
|
370
370
|
fileio.makedirs(str(secret_folder))
|
371
371
|
file_path = Path(secret_folder, "ca_bundle.pem")
|
372
372
|
with os.fdopen(
|
373
373
|
os.open(file_path, flags=os.O_RDWR | os.O_CREAT, mode=0o600), "w"
|
374
374
|
) as f:
|
375
|
-
f.write(
|
376
|
-
verify_ssl = str(file_path)
|
375
|
+
f.write(cert_content)
|
377
376
|
|
378
|
-
return
|
377
|
+
return str(file_path)
|
379
378
|
|
380
379
|
@classmethod
|
381
380
|
def supports_url_scheme(cls, url: str) -> bool:
|
@@ -389,15 +388,6 @@ class RestZenStoreConfiguration(StoreConfiguration):
|
|
389
388
|
"""
|
390
389
|
return urlparse(url).scheme in ("http", "https")
|
391
390
|
|
392
|
-
def expand_certificates(self) -> None:
|
393
|
-
"""Expands the certificates in the verify_ssl field."""
|
394
|
-
# Load the certificate values back into the configuration
|
395
|
-
if isinstance(self.verify_ssl, str) and os.path.isfile(
|
396
|
-
self.verify_ssl
|
397
|
-
):
|
398
|
-
with open(self.verify_ssl, "r") as f:
|
399
|
-
self.verify_ssl = f.read()
|
400
|
-
|
401
391
|
@model_validator(mode="before")
|
402
392
|
@classmethod
|
403
393
|
@before_validator_handler
|
@@ -11,7 +11,7 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
12
12
|
# or implied. See the License for the specific language governing
|
13
13
|
# permissions and limitations under the License.
|
14
|
-
"""SQLModel implementation of
|
14
|
+
"""SQLModel implementation of run metadata tables."""
|
15
15
|
|
16
16
|
from typing import Optional
|
17
17
|
from uuid import UUID, uuid4
|
@@ -21,7 +21,10 @@ from sqlmodel import Field, Relationship, SQLModel
|
|
21
21
|
|
22
22
|
from zenml.zen_stores.schemas.base_schemas import BaseSchema
|
23
23
|
from zenml.zen_stores.schemas.component_schemas import StackComponentSchema
|
24
|
-
from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field
|
24
|
+
from zenml.zen_stores.schemas.schema_utils import (
|
25
|
+
build_foreign_key_field,
|
26
|
+
build_index,
|
27
|
+
)
|
25
28
|
from zenml.zen_stores.schemas.step_run_schemas import StepRunSchema
|
26
29
|
from zenml.zen_stores.schemas.user_schemas import UserSchema
|
27
30
|
from zenml.zen_stores.schemas.workspace_schemas import WorkspaceSchema
|
@@ -82,6 +85,16 @@ class RunMetadataResourceSchema(SQLModel, table=True):
|
|
82
85
|
"""Table for linking resources to run metadata entries."""
|
83
86
|
|
84
87
|
__tablename__ = "run_metadata_resource"
|
88
|
+
__table_args__ = (
|
89
|
+
build_index(
|
90
|
+
table_name=__tablename__,
|
91
|
+
column_names=[
|
92
|
+
"resource_id",
|
93
|
+
"resource_type",
|
94
|
+
"run_metadata_id",
|
95
|
+
],
|
96
|
+
),
|
97
|
+
)
|
85
98
|
|
86
99
|
id: UUID = Field(default_factory=uuid4, primary_key=True)
|
87
100
|
resource_id: UUID
|
@@ -13,9 +13,9 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Utility functions for SQLModel schemas."""
|
15
15
|
|
16
|
-
from typing import Any
|
16
|
+
from typing import Any, List
|
17
17
|
|
18
|
-
from sqlalchemy import Column, ForeignKey
|
18
|
+
from sqlalchemy import Column, ForeignKey, Index
|
19
19
|
from sqlmodel import Field
|
20
20
|
|
21
21
|
|
@@ -84,3 +84,35 @@ def build_foreign_key_field(
|
|
84
84
|
**sa_column_kwargs,
|
85
85
|
),
|
86
86
|
)
|
87
|
+
|
88
|
+
|
89
|
+
def get_index_name(table_name: str, column_names: List[str]) -> str:
|
90
|
+
"""Get the name for an index.
|
91
|
+
|
92
|
+
Args:
|
93
|
+
table_name: The name of the table for which the index will be created.
|
94
|
+
column_names: Names of the columns on which the index will be created.
|
95
|
+
|
96
|
+
Returns:
|
97
|
+
The index name.
|
98
|
+
"""
|
99
|
+
columns = "_".join(column_names)
|
100
|
+
# MySQL allows a maximum of 64 characters in identifiers
|
101
|
+
return f"ix_{table_name}_{columns}"[:64]
|
102
|
+
|
103
|
+
|
104
|
+
def build_index(
|
105
|
+
table_name: str, column_names: List[str], **kwargs: Any
|
106
|
+
) -> Index:
|
107
|
+
"""Build an index object.
|
108
|
+
|
109
|
+
Args:
|
110
|
+
table_name: The name of the table for which the index will be created.
|
111
|
+
column_names: Names of the columns on which the index will be created.
|
112
|
+
**kwargs: Additional keyword arguments to pass to the Index.
|
113
|
+
|
114
|
+
Returns:
|
115
|
+
The index.
|
116
|
+
"""
|
117
|
+
name = get_index_name(table_name=table_name, column_names=column_names)
|
118
|
+
return Index(name, *column_names, **kwargs)
|
@@ -31,7 +31,10 @@ from zenml.models import (
|
|
31
31
|
)
|
32
32
|
from zenml.utils.time_utils import utc_now
|
33
33
|
from zenml.zen_stores.schemas.base_schemas import BaseSchema, NamedSchema
|
34
|
-
from zenml.zen_stores.schemas.schema_utils import build_foreign_key_field
|
34
|
+
from zenml.zen_stores.schemas.schema_utils import (
|
35
|
+
build_foreign_key_field,
|
36
|
+
build_index,
|
37
|
+
)
|
35
38
|
|
36
39
|
|
37
40
|
class TagSchema(NamedSchema, table=True):
|
@@ -111,6 +114,16 @@ class TagResourceSchema(BaseSchema, table=True):
|
|
111
114
|
"""SQL Model for tag resource relationship."""
|
112
115
|
|
113
116
|
__tablename__ = "tag_resource"
|
117
|
+
__table_args__ = (
|
118
|
+
build_index(
|
119
|
+
table_name=__tablename__,
|
120
|
+
column_names=[
|
121
|
+
"resource_id",
|
122
|
+
"resource_type",
|
123
|
+
"tag_id",
|
124
|
+
],
|
125
|
+
),
|
126
|
+
)
|
114
127
|
|
115
128
|
tag_id: UUID = build_foreign_key_field(
|
116
129
|
source=__tablename__,
|
@@ -37,6 +37,7 @@ from zenml.exceptions import (
|
|
37
37
|
IllegalOperationError,
|
38
38
|
)
|
39
39
|
from zenml.logger import get_logger
|
40
|
+
from zenml.utils.secret_utils import PlainSerializedSecretStr
|
40
41
|
from zenml.zen_stores.schemas import (
|
41
42
|
SecretSchema,
|
42
43
|
)
|
@@ -62,7 +63,7 @@ class SqlSecretsStoreConfiguration(SecretsStoreConfiguration):
|
|
62
63
|
"""
|
63
64
|
|
64
65
|
type: SecretsStoreType = SecretsStoreType.SQL
|
65
|
-
encryption_key: Optional[
|
66
|
+
encryption_key: Optional[PlainSerializedSecretStr] = None
|
66
67
|
model_config = ConfigDict(
|
67
68
|
# Don't validate attributes when assigning them. This is necessary
|
68
69
|
# because the certificate attributes can be expanded to the contents
|
@@ -159,7 +160,9 @@ class SqlSecretsStore(BaseSecretsStore):
|
|
159
160
|
# Initialize the encryption engine
|
160
161
|
if self.config.encryption_key:
|
161
162
|
self._encryption_engine = AesGcmEngine()
|
162
|
-
self._encryption_engine._update_key(
|
163
|
+
self._encryption_engine._update_key(
|
164
|
+
self.config.encryption_key.get_secret_value()
|
165
|
+
)
|
163
166
|
|
164
167
|
# Nothing else to do here, the SQL ZenML store back-end is already
|
165
168
|
# initialized
|
@@ -304,6 +304,7 @@ from zenml.utils.networking_utils import (
|
|
304
304
|
replace_localhost_with_internal_hostname,
|
305
305
|
)
|
306
306
|
from zenml.utils.pydantic_utils import before_validator_handler
|
307
|
+
from zenml.utils.secret_utils import PlainSerializedSecretStr
|
307
308
|
from zenml.utils.string_utils import (
|
308
309
|
format_name_template,
|
309
310
|
random_str,
|
@@ -460,11 +461,11 @@ class SqlZenStoreConfiguration(StoreConfiguration):
|
|
460
461
|
|
461
462
|
driver: Optional[SQLDatabaseDriver] = None
|
462
463
|
database: Optional[str] = None
|
463
|
-
username: Optional[
|
464
|
-
password: Optional[
|
465
|
-
ssl_ca: Optional[
|
466
|
-
ssl_cert: Optional[
|
467
|
-
ssl_key: Optional[
|
464
|
+
username: Optional[PlainSerializedSecretStr] = None
|
465
|
+
password: Optional[PlainSerializedSecretStr] = None
|
466
|
+
ssl_ca: Optional[PlainSerializedSecretStr] = None
|
467
|
+
ssl_cert: Optional[PlainSerializedSecretStr] = None
|
468
|
+
ssl_key: Optional[PlainSerializedSecretStr] = None
|
468
469
|
ssl_verify_server_cert: bool = False
|
469
470
|
pool_size: int = 20
|
470
471
|
max_overflow: int = 20
|
@@ -611,10 +612,10 @@ class SqlZenStoreConfiguration(StoreConfiguration):
|
|
611
612
|
self.database = sql_url.database
|
612
613
|
elif sql_url.drivername == SQLDatabaseDriver.MYSQL:
|
613
614
|
if sql_url.username:
|
614
|
-
self.username = sql_url.username
|
615
|
+
self.username = PlainSerializedSecretStr(sql_url.username)
|
615
616
|
sql_url = sql_url._replace(username=None)
|
616
617
|
if sql_url.password:
|
617
|
-
self.password = sql_url.password
|
618
|
+
self.password = PlainSerializedSecretStr(sql_url.password)
|
618
619
|
sql_url = sql_url._replace(password=None)
|
619
620
|
if sql_url.database:
|
620
621
|
self.database = sql_url.database
|
@@ -642,13 +643,13 @@ class SqlZenStoreConfiguration(StoreConfiguration):
|
|
642
643
|
for k, v in sql_url.query.items():
|
643
644
|
if k == "ssl_ca":
|
644
645
|
if r := _get_query_result(v):
|
645
|
-
self.ssl_ca = r
|
646
|
+
self.ssl_ca = PlainSerializedSecretStr(r)
|
646
647
|
elif k == "ssl_cert":
|
647
648
|
if r := _get_query_result(v):
|
648
|
-
self.ssl_cert = r
|
649
|
+
self.ssl_cert = PlainSerializedSecretStr(r)
|
649
650
|
elif k == "ssl_key":
|
650
651
|
if r := _get_query_result(v):
|
651
|
-
self.ssl_key = r
|
652
|
+
self.ssl_key = PlainSerializedSecretStr(r)
|
652
653
|
elif k == "ssl_verify_server_cert":
|
653
654
|
if r := _get_query_result(v):
|
654
655
|
if is_true_string_value(r):
|
@@ -688,7 +689,7 @@ class SqlZenStoreConfiguration(StoreConfiguration):
|
|
688
689
|
)
|
689
690
|
for key in ["ssl_key", "ssl_ca", "ssl_cert"]:
|
690
691
|
content = getattr(self, key)
|
691
|
-
if content and not os.path.isfile(content):
|
692
|
+
if content and not os.path.isfile(content.get_secret_value()):
|
692
693
|
fileio.makedirs(str(secret_folder))
|
693
694
|
file_path = Path(secret_folder, f"{key}.pem")
|
694
695
|
with os.fdopen(
|
@@ -697,7 +698,7 @@ class SqlZenStoreConfiguration(StoreConfiguration):
|
|
697
698
|
),
|
698
699
|
"w",
|
699
700
|
) as f:
|
700
|
-
f.write(content)
|
701
|
+
f.write(content.get_secret_value())
|
701
702
|
setattr(self, key, str(file_path))
|
702
703
|
|
703
704
|
self.url = str(sql_url)
|
@@ -732,7 +733,7 @@ class SqlZenStoreConfiguration(StoreConfiguration):
|
|
732
733
|
# Load the certificate values back into the configuration
|
733
734
|
for key in ["ssl_key", "ssl_ca", "ssl_cert"]:
|
734
735
|
file_path = getattr(self, key, None)
|
735
|
-
if file_path and os.path.isfile(file_path):
|
736
|
+
if file_path and os.path.isfile(file_path.get_secret_value()):
|
736
737
|
with open(file_path, "r") as f:
|
737
738
|
setattr(self, key, f.read())
|
738
739
|
|
@@ -780,8 +781,8 @@ class SqlZenStoreConfiguration(StoreConfiguration):
|
|
780
781
|
|
781
782
|
sql_url = sql_url._replace(
|
782
783
|
drivername="mysql+pymysql",
|
783
|
-
username=self.username,
|
784
|
-
password=self.password,
|
784
|
+
username=self.username.get_secret_value(),
|
785
|
+
password=self.password.get_secret_value(),
|
785
786
|
database=database,
|
786
787
|
)
|
787
788
|
|
@@ -792,11 +793,17 @@ class SqlZenStoreConfiguration(StoreConfiguration):
|
|
792
793
|
ssl_setting = getattr(self, key)
|
793
794
|
if not ssl_setting:
|
794
795
|
continue
|
795
|
-
if not os.path.isfile(ssl_setting):
|
796
|
+
if not os.path.isfile(ssl_setting.get_secret_value()):
|
796
797
|
logger.warning(
|
797
798
|
f"Database SSL setting `{key}` is not a file. "
|
798
799
|
)
|
799
|
-
sqlalchemy_ssl_args[key.
|
800
|
+
sqlalchemy_ssl_args[key.lstrip("ssl_")] = (
|
801
|
+
ssl_setting.get_secret_value()
|
802
|
+
)
|
803
|
+
sqlalchemy_ssl_args[key.removeprefix("ssl_")] = (
|
804
|
+
ssl_setting.get_secret_value()
|
805
|
+
)
|
806
|
+
|
800
807
|
if len(sqlalchemy_ssl_args) > 0:
|
801
808
|
sqlalchemy_ssl_args["check_hostname"] = (
|
802
809
|
self.ssl_verify_server_cert
|