zenml-nightly 0.80.2.dev20250414__py3-none-any.whl → 0.80.2.dev20250416__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/artifacts/utils.py +7 -2
- zenml/cli/utils.py +13 -11
- zenml/config/compiler.py +1 -0
- zenml/config/global_config.py +1 -1
- zenml/config/pipeline_configurations.py +1 -0
- zenml/config/pipeline_run_configuration.py +1 -0
- zenml/config/server_config.py +7 -0
- zenml/constants.py +8 -0
- zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +47 -5
- zenml/integrations/gcp/vertex_custom_job_parameters.py +15 -1
- zenml/integrations/kubernetes/flavors/kubernetes_step_operator_flavor.py +12 -0
- zenml/integrations/kubernetes/orchestrators/kube_utils.py +92 -0
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +12 -3
- zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +11 -65
- zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +11 -3
- zenml/logging/step_logging.py +41 -21
- zenml/login/credentials_store.py +31 -0
- zenml/materializers/path_materializer.py +17 -2
- zenml/models/v2/base/base.py +8 -4
- zenml/models/v2/base/filter.py +1 -1
- zenml/models/v2/core/pipeline_run.py +19 -0
- zenml/orchestrators/step_launcher.py +2 -3
- zenml/orchestrators/step_runner.py +2 -2
- zenml/orchestrators/utils.py +2 -5
- zenml/pipelines/pipeline_context.py +1 -0
- zenml/pipelines/pipeline_decorator.py +4 -0
- zenml/pipelines/pipeline_definition.py +83 -22
- zenml/pipelines/run_utils.py +4 -0
- zenml/steps/utils.py +1 -1
- zenml/utils/io_utils.py +23 -0
- zenml/zen_server/auth.py +96 -64
- zenml/zen_server/cloud_utils.py +7 -1
- zenml/zen_server/download_utils.py +123 -0
- zenml/zen_server/jwt.py +0 -14
- zenml/zen_server/rbac/rbac_interface.py +10 -3
- zenml/zen_server/rbac/utils.py +13 -3
- zenml/zen_server/rbac/zenml_cloud_rbac.py +14 -8
- zenml/zen_server/routers/artifact_version_endpoints.py +86 -3
- zenml/zen_server/routers/auth_endpoints.py +5 -36
- zenml/zen_server/routers/pipeline_deployments_endpoints.py +63 -26
- zenml/zen_server/routers/runs_endpoints.py +57 -0
- zenml/zen_server/routers/users_endpoints.py +13 -8
- zenml/zen_server/template_execution/utils.py +3 -3
- zenml/zen_stores/migrations/versions/ff538a321a92_migrate_onboarding_state.py +123 -0
- zenml/zen_stores/rest_zen_store.py +16 -13
- zenml/zen_stores/schemas/pipeline_run_schemas.py +1 -0
- zenml/zen_stores/schemas/server_settings_schemas.py +4 -1
- zenml/zen_stores/sql_zen_store.py +18 -0
- {zenml_nightly-0.80.2.dev20250414.dist-info → zenml_nightly-0.80.2.dev20250416.dist-info}/METADATA +2 -1
- {zenml_nightly-0.80.2.dev20250414.dist-info → zenml_nightly-0.80.2.dev20250416.dist-info}/RECORD +54 -52
- {zenml_nightly-0.80.2.dev20250414.dist-info → zenml_nightly-0.80.2.dev20250416.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.80.2.dev20250414.dist-info → zenml_nightly-0.80.2.dev20250416.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.80.2.dev20250414.dist-info → zenml_nightly-0.80.2.dev20250416.dist-info}/entry_points.txt +0 -0
zenml/pipelines/run_utils.py
CHANGED
@@ -15,6 +15,7 @@ from zenml.enums import ExecutionStatus
|
|
15
15
|
from zenml.logger import get_logger
|
16
16
|
from zenml.models import (
|
17
17
|
FlavorFilter,
|
18
|
+
LogsRequest,
|
18
19
|
PipelineDeploymentBase,
|
19
20
|
PipelineDeploymentResponse,
|
20
21
|
PipelineRunRequest,
|
@@ -49,6 +50,7 @@ def get_default_run_name(pipeline_name: str) -> str:
|
|
49
50
|
|
50
51
|
def create_placeholder_run(
|
51
52
|
deployment: "PipelineDeploymentResponse",
|
53
|
+
logs: Optional["LogsRequest"] = None,
|
52
54
|
) -> Optional["PipelineRunResponse"]:
|
53
55
|
"""Create a placeholder run for the deployment.
|
54
56
|
|
@@ -57,6 +59,7 @@ def create_placeholder_run(
|
|
57
59
|
|
58
60
|
Args:
|
59
61
|
deployment: The deployment for which to create the placeholder run.
|
62
|
+
logs: The logs for the run.
|
60
63
|
|
61
64
|
Returns:
|
62
65
|
The placeholder run or `None` if no run was created.
|
@@ -86,6 +89,7 @@ def create_placeholder_run(
|
|
86
89
|
pipeline=deployment.pipeline.id if deployment.pipeline else None,
|
87
90
|
status=ExecutionStatus.INITIALIZING,
|
88
91
|
tags=deployment.pipeline_configuration.tags,
|
92
|
+
logs=logs,
|
89
93
|
)
|
90
94
|
run, _ = Client().zen_store.get_or_create_run(run_request)
|
91
95
|
return run
|
zenml/steps/utils.py
CHANGED
@@ -553,7 +553,7 @@ def run_as_single_step_pipeline(
|
|
553
553
|
orchestrator = Client().active_stack.orchestrator
|
554
554
|
|
555
555
|
pipeline_settings: Any = {}
|
556
|
-
if "synchronous" in orchestrator.config.model_fields:
|
556
|
+
if "synchronous" in type(orchestrator.config).model_fields:
|
557
557
|
# Make sure the orchestrator runs sync so we stream the logs
|
558
558
|
key = settings_utils.get_stack_component_setting_key(orchestrator)
|
559
559
|
pipeline_settings[key] = BaseSettings(synchronous=True)
|
zenml/utils/io_utils.py
CHANGED
@@ -205,6 +205,29 @@ def resolve_relative_path(path: str) -> str:
|
|
205
205
|
return str(Path(path).resolve())
|
206
206
|
|
207
207
|
|
208
|
+
def is_path_within_directory(path: str, directory: str) -> bool:
|
209
|
+
"""Checks if a path is contained within a given directory.
|
210
|
+
|
211
|
+
This utility function verifies that a path (absolute or relative) resolves
|
212
|
+
to a location that is within the specified directory. This is useful for
|
213
|
+
security checks such as preventing path traversal attacks when extracting
|
214
|
+
archives (CVE-2007-4559) or whenever path containment needs to be verified.
|
215
|
+
|
216
|
+
Args:
|
217
|
+
path: The path to check (can be relative or absolute).
|
218
|
+
directory: The directory that should contain the path.
|
219
|
+
|
220
|
+
Returns:
|
221
|
+
Boolean indicating whether the path is contained within the directory (True)
|
222
|
+
or not (False).
|
223
|
+
"""
|
224
|
+
# Convert to absolute path, ensuring it's normalized
|
225
|
+
abs_path = os.path.abspath(os.path.join(directory, path))
|
226
|
+
# Check if the path is within the target directory
|
227
|
+
dir_abs = os.path.abspath(directory)
|
228
|
+
return abs_path.startswith(dir_abs + os.sep) or abs_path == dir_abs
|
229
|
+
|
230
|
+
|
208
231
|
def move(source: str, destination: str, overwrite: bool = False) -> None:
|
209
232
|
"""Moves dir or file from source to destination. Can be used to rename.
|
210
233
|
|
zenml/zen_server/auth.py
CHANGED
@@ -15,7 +15,7 @@
|
|
15
15
|
|
16
16
|
from contextvars import ContextVar
|
17
17
|
from datetime import datetime, timedelta
|
18
|
-
from typing import Callable, Optional, Union
|
18
|
+
from typing import Callable, Optional, Tuple, Union
|
19
19
|
from urllib.parse import urlencode, urlparse
|
20
20
|
from uuid import UUID, uuid4
|
21
21
|
|
@@ -33,9 +33,12 @@ from zenml.analytics.context import AnalyticsContext
|
|
33
33
|
from zenml.constants import (
|
34
34
|
API,
|
35
35
|
DEFAULT_USERNAME,
|
36
|
+
DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME,
|
37
|
+
ENV_ZENML_WORKLOAD_TOKEN_EXPIRATION_LEEWAY,
|
36
38
|
EXTERNAL_AUTHENTICATOR_TIMEOUT,
|
37
39
|
LOGIN,
|
38
40
|
VERSION_1,
|
41
|
+
handle_int_env_var,
|
39
42
|
)
|
40
43
|
from zenml.enums import (
|
41
44
|
AuthScheme,
|
@@ -420,28 +423,32 @@ def authenticate_credentials(
|
|
420
423
|
@cache_result(expiry=30)
|
421
424
|
def get_pipeline_run_status(
|
422
425
|
pipeline_run_id: UUID,
|
423
|
-
) -> Optional[ExecutionStatus]:
|
426
|
+
) -> Tuple[Optional[ExecutionStatus], Optional[datetime]]:
|
424
427
|
"""Get the status of a pipeline run.
|
425
428
|
|
426
429
|
Args:
|
427
430
|
pipeline_run_id: The pipeline run ID.
|
428
431
|
|
429
432
|
Returns:
|
430
|
-
The pipeline run status or None if the pipeline
|
431
|
-
exist.
|
433
|
+
The pipeline run status and end time or None if the pipeline
|
434
|
+
run does not exist.
|
432
435
|
"""
|
433
436
|
try:
|
434
437
|
pipeline_run = zen_store().get_run(
|
435
|
-
pipeline_run_id, hydrate=
|
438
|
+
pipeline_run_id, hydrate=True
|
436
439
|
)
|
437
440
|
except KeyError:
|
438
|
-
return None
|
441
|
+
return None, None
|
439
442
|
|
440
|
-
return
|
443
|
+
return (
|
444
|
+
pipeline_run.status,
|
445
|
+
pipeline_run.end_time,
|
446
|
+
)
|
441
447
|
|
442
|
-
|
443
|
-
|
444
|
-
|
448
|
+
(
|
449
|
+
pipeline_run_status,
|
450
|
+
pipeline_run_end_time,
|
451
|
+
) = get_pipeline_run_status(decoded_token.pipeline_run_id)
|
445
452
|
if pipeline_run_status is None:
|
446
453
|
error = (
|
447
454
|
f"Authentication error: error retrieving token pipeline run "
|
@@ -450,59 +457,35 @@ def authenticate_credentials(
|
|
450
457
|
logger.error(error)
|
451
458
|
raise CredentialsNotValid(error)
|
452
459
|
|
460
|
+
leeway = handle_int_env_var(
|
461
|
+
ENV_ZENML_WORKLOAD_TOKEN_EXPIRATION_LEEWAY,
|
462
|
+
DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME,
|
463
|
+
)
|
453
464
|
if pipeline_run_status.is_finished:
|
454
|
-
|
455
|
-
|
456
|
-
|
457
|
-
|
458
|
-
|
459
|
-
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
|
473
|
-
|
474
|
-
|
475
|
-
step_run_id: The step run ID.
|
476
|
-
|
477
|
-
Returns:
|
478
|
-
The step run status or None if the step run does not exist.
|
479
|
-
"""
|
480
|
-
try:
|
481
|
-
step_run = zen_store().get_run_step(
|
482
|
-
step_run_id, hydrate=False
|
465
|
+
if leeway < 0:
|
466
|
+
# The token should never expire, we don't need to check
|
467
|
+
# the end time.
|
468
|
+
pass
|
469
|
+
elif (
|
470
|
+
# We don't know the end time. This should never happen, but
|
471
|
+
# just in case we always expire the token.
|
472
|
+
pipeline_run_end_time is None
|
473
|
+
# Calculate whether the token has expired.
|
474
|
+
or utc_now(tz_aware=pipeline_run_end_time)
|
475
|
+
> pipeline_run_end_time + timedelta(seconds=leeway)
|
476
|
+
):
|
477
|
+
error = (
|
478
|
+
f"The pipeline run {decoded_token.pipeline_run_id} has "
|
479
|
+
"finished and API tokens scoped to it are no longer "
|
480
|
+
"valid. If you want to increase the expiration time "
|
481
|
+
"of the token to allow steps to continue for longer "
|
482
|
+
"after other steps have failed, you can do so by "
|
483
|
+
"configuring the "
|
484
|
+
f"`{ENV_ZENML_WORKLOAD_TOKEN_EXPIRATION_LEEWAY}` "
|
485
|
+
"ZenML server environment variable."
|
483
486
|
)
|
484
|
-
|
485
|
-
|
486
|
-
|
487
|
-
return step_run.status
|
488
|
-
|
489
|
-
step_run_status = get_step_run_status(decoded_token.step_run_id)
|
490
|
-
if step_run_status is None:
|
491
|
-
error = (
|
492
|
-
f"Authentication error: error retrieving token step run "
|
493
|
-
f"{decoded_token.step_run_id}"
|
494
|
-
)
|
495
|
-
logger.error(error)
|
496
|
-
raise CredentialsNotValid(error)
|
497
|
-
|
498
|
-
if step_run_status.is_finished:
|
499
|
-
error = (
|
500
|
-
f"The execution of step run "
|
501
|
-
f"{decoded_token.step_run_id} has already concluded and "
|
502
|
-
"API tokens scoped to it are no longer valid."
|
503
|
-
)
|
504
|
-
logger.error(error)
|
505
|
-
raise CredentialsNotValid(error)
|
487
|
+
logger.error(error)
|
488
|
+
raise CredentialsNotValid(error)
|
506
489
|
|
507
490
|
auth_context = AuthContext(
|
508
491
|
user=user_model,
|
@@ -861,7 +844,6 @@ def generate_access_token(
|
|
861
844
|
expires_in: Optional[int] = None,
|
862
845
|
schedule_id: Optional[UUID] = None,
|
863
846
|
pipeline_run_id: Optional[UUID] = None,
|
864
|
-
step_run_id: Optional[UUID] = None,
|
865
847
|
) -> OAuthTokenResponse:
|
866
848
|
"""Generates an access token for the given user.
|
867
849
|
|
@@ -880,7 +862,6 @@ def generate_access_token(
|
|
880
862
|
expire.
|
881
863
|
schedule_id: The ID of the schedule to scope the token to.
|
882
864
|
pipeline_run_id: The ID of the pipeline run to scope the token to.
|
883
|
-
step_run_id: The ID of the step run to scope the token to.
|
884
865
|
|
885
866
|
Returns:
|
886
867
|
An authentication response with an access token.
|
@@ -956,7 +937,6 @@ def generate_access_token(
|
|
956
937
|
api_key_id=api_key.id if api_key else None,
|
957
938
|
schedule_id=schedule_id,
|
958
939
|
pipeline_run_id=pipeline_run_id,
|
959
|
-
step_run_id=step_run_id,
|
960
940
|
# Set the session ID if this is a cross-site request
|
961
941
|
session_id=session_id,
|
962
942
|
).encode(expires=expires)
|
@@ -983,6 +963,58 @@ def generate_access_token(
|
|
983
963
|
)
|
984
964
|
|
985
965
|
|
966
|
+
def generate_artifact_download_token(artifact_version_id: UUID) -> str:
|
967
|
+
"""Generate a JWT token for artifact download.
|
968
|
+
|
969
|
+
Args:
|
970
|
+
artifact_version_id: The ID of the artifact version to download.
|
971
|
+
|
972
|
+
Returns:
|
973
|
+
The JWT token for the artifact download.
|
974
|
+
"""
|
975
|
+
import jwt
|
976
|
+
|
977
|
+
config = server_config()
|
978
|
+
|
979
|
+
return jwt.encode(
|
980
|
+
{
|
981
|
+
"exp": utc_now() + timedelta(seconds=30),
|
982
|
+
"artifact_version_id": str(artifact_version_id),
|
983
|
+
},
|
984
|
+
key=config.jwt_secret_key,
|
985
|
+
algorithm=config.jwt_token_algorithm,
|
986
|
+
)
|
987
|
+
|
988
|
+
|
989
|
+
def verify_artifact_download_token(
|
990
|
+
token: str, artifact_version_id: UUID
|
991
|
+
) -> None:
|
992
|
+
"""Verify a JWT token for artifact download.
|
993
|
+
|
994
|
+
Args:
|
995
|
+
token: The JWT token to verify.
|
996
|
+
artifact_version_id: The ID of the artifact version to download.
|
997
|
+
|
998
|
+
Raises:
|
999
|
+
CredentialsNotValid: If the token is invalid or the artifact version
|
1000
|
+
ID does not match.
|
1001
|
+
"""
|
1002
|
+
import jwt
|
1003
|
+
|
1004
|
+
config = server_config()
|
1005
|
+
try:
|
1006
|
+
claims = jwt.decode(
|
1007
|
+
token,
|
1008
|
+
config.jwt_secret_key,
|
1009
|
+
algorithms=[config.jwt_token_algorithm],
|
1010
|
+
)
|
1011
|
+
except jwt.PyJWTError as e:
|
1012
|
+
raise CredentialsNotValid(f"Invalid JWT token: {e}") from e
|
1013
|
+
|
1014
|
+
if claims["artifact_version_id"] != str(artifact_version_id):
|
1015
|
+
raise CredentialsNotValid("Invalid artifact version ID")
|
1016
|
+
|
1017
|
+
|
986
1018
|
def http_authentication(
|
987
1019
|
credentials: HTTPBasicCredentials = Depends(HTTPBasic()),
|
988
1020
|
) -> AuthContext:
|
zenml/zen_server/cloud_utils.py
CHANGED
@@ -7,7 +7,10 @@ import requests
|
|
7
7
|
from requests.adapters import HTTPAdapter, Retry
|
8
8
|
|
9
9
|
from zenml.config.server_config import ServerProConfiguration
|
10
|
-
from zenml.exceptions import
|
10
|
+
from zenml.exceptions import (
|
11
|
+
IllegalOperationError,
|
12
|
+
SubscriptionUpgradeRequiredError,
|
13
|
+
)
|
11
14
|
from zenml.utils.time_utils import utc_now
|
12
15
|
from zenml.zen_server.utils import get_zenml_headers, server_config
|
13
16
|
|
@@ -43,6 +46,7 @@ class ZenMLCloudConnection:
|
|
43
46
|
Raises:
|
44
47
|
SubscriptionUpgradeRequiredError: If the current subscription tier
|
45
48
|
is insufficient for the attempted operation.
|
49
|
+
IllegalOperationError: If the request failed with a 403 status code.
|
46
50
|
RuntimeError: If the request failed.
|
47
51
|
|
48
52
|
Returns:
|
@@ -65,6 +69,8 @@ class ZenMLCloudConnection:
|
|
65
69
|
except requests.HTTPError as e:
|
66
70
|
if response.status_code == 402:
|
67
71
|
raise SubscriptionUpgradeRequiredError(response.json())
|
72
|
+
elif response.status_code == 403:
|
73
|
+
raise IllegalOperationError(response.json())
|
68
74
|
else:
|
69
75
|
raise RuntimeError(
|
70
76
|
f"Failed while trying to contact the central zenml pro "
|
@@ -0,0 +1,123 @@
|
|
1
|
+
# Copyright (c) ZenML GmbH 2025. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
12
|
+
# or implied. See the License for the specific language governing
|
13
|
+
# permissions and limitations under the License.
|
14
|
+
"""Utility functions for downloading artifacts."""
|
15
|
+
|
16
|
+
import os
|
17
|
+
import tarfile
|
18
|
+
import tempfile
|
19
|
+
from typing import (
|
20
|
+
TYPE_CHECKING,
|
21
|
+
)
|
22
|
+
|
23
|
+
from zenml.artifacts.utils import _load_artifact_store
|
24
|
+
from zenml.exceptions import (
|
25
|
+
IllegalOperationError,
|
26
|
+
)
|
27
|
+
from zenml.models import (
|
28
|
+
ArtifactVersionResponse,
|
29
|
+
)
|
30
|
+
from zenml.zen_server.utils import server_config, zen_store
|
31
|
+
|
32
|
+
if TYPE_CHECKING:
|
33
|
+
from zenml.artifact_stores.base_artifact_store import BaseArtifactStore
|
34
|
+
|
35
|
+
|
36
|
+
def verify_artifact_is_downloadable(
|
37
|
+
artifact: "ArtifactVersionResponse",
|
38
|
+
) -> "BaseArtifactStore":
|
39
|
+
"""Verify that the given artifact is downloadable.
|
40
|
+
|
41
|
+
Args:
|
42
|
+
artifact: The artifact to verify.
|
43
|
+
|
44
|
+
Raises:
|
45
|
+
IllegalOperationError: If the artifact is too large to be archived.
|
46
|
+
KeyError: If the artifact store is not found or the artifact URI does
|
47
|
+
not exist.
|
48
|
+
|
49
|
+
Returns:
|
50
|
+
The artifact store.
|
51
|
+
"""
|
52
|
+
if not artifact.artifact_store_id:
|
53
|
+
raise KeyError(
|
54
|
+
f"Artifact '{artifact.id}' cannot be downloaded because the "
|
55
|
+
"underlying artifact store was deleted."
|
56
|
+
)
|
57
|
+
|
58
|
+
artifact_store = _load_artifact_store(
|
59
|
+
artifact_store_id=artifact.artifact_store_id, zen_store=zen_store()
|
60
|
+
)
|
61
|
+
|
62
|
+
if not artifact_store.exists(artifact.uri):
|
63
|
+
raise KeyError(f"The artifact URI '{artifact.uri}' does not exist.")
|
64
|
+
|
65
|
+
size = artifact_store.size(artifact.uri)
|
66
|
+
max_download_size = server_config().file_download_size_limit
|
67
|
+
|
68
|
+
if size and size > max_download_size:
|
69
|
+
raise IllegalOperationError(
|
70
|
+
f"The artifact '{artifact.id}' is too large to be downloaded. "
|
71
|
+
f"The maximum download size allowed by your ZenML server is "
|
72
|
+
f"{max_download_size} bytes."
|
73
|
+
)
|
74
|
+
|
75
|
+
return artifact_store
|
76
|
+
|
77
|
+
|
78
|
+
def create_artifact_archive(
|
79
|
+
artifact: "ArtifactVersionResponse",
|
80
|
+
) -> str:
|
81
|
+
"""Create an archive of the given artifact.
|
82
|
+
|
83
|
+
Args:
|
84
|
+
artifact: The artifact to archive.
|
85
|
+
|
86
|
+
Returns:
|
87
|
+
The path to the created archive.
|
88
|
+
"""
|
89
|
+
artifact_store = verify_artifact_is_downloadable(artifact)
|
90
|
+
|
91
|
+
def _prepare_tarinfo(path: str) -> tarfile.TarInfo:
|
92
|
+
archive_path = os.path.relpath(path, artifact.uri)
|
93
|
+
tarinfo = tarfile.TarInfo(name=archive_path)
|
94
|
+
if size := artifact_store.size(path):
|
95
|
+
tarinfo.size = size
|
96
|
+
return tarinfo
|
97
|
+
|
98
|
+
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
|
99
|
+
with tarfile.open(fileobj=temp_file, mode="w:gz") as tar:
|
100
|
+
if artifact_store.isdir(artifact.uri):
|
101
|
+
for dir, _, files in artifact_store.walk(artifact.uri):
|
102
|
+
dir = dir.decode() if isinstance(dir, bytes) else dir
|
103
|
+
dir_info = tarfile.TarInfo(
|
104
|
+
name=os.path.relpath(dir, artifact.uri)
|
105
|
+
)
|
106
|
+
dir_info.type = tarfile.DIRTYPE
|
107
|
+
dir_info.mode = 0o755
|
108
|
+
tar.addfile(dir_info)
|
109
|
+
|
110
|
+
for file in files:
|
111
|
+
file = (
|
112
|
+
file.decode() if isinstance(file, bytes) else file
|
113
|
+
)
|
114
|
+
path = os.path.join(dir, file)
|
115
|
+
tarinfo = _prepare_tarinfo(path)
|
116
|
+
with artifact_store.open(path, "rb") as f:
|
117
|
+
tar.addfile(tarinfo, fileobj=f)
|
118
|
+
else:
|
119
|
+
tarinfo = _prepare_tarinfo(artifact.uri)
|
120
|
+
with artifact_store.open(artifact.uri, "rb") as f:
|
121
|
+
tar.addfile(tarinfo, fileobj=f)
|
122
|
+
|
123
|
+
return temp_file.name
|
zenml/zen_server/jwt.py
CHANGED
@@ -54,7 +54,6 @@ class JWTToken(BaseModel):
|
|
54
54
|
api_key_id: Optional[UUID] = None
|
55
55
|
schedule_id: Optional[UUID] = None
|
56
56
|
pipeline_run_id: Optional[UUID] = None
|
57
|
-
step_run_id: Optional[UUID] = None
|
58
57
|
session_id: Optional[UUID] = None
|
59
58
|
claims: Dict[str, Any] = {}
|
60
59
|
|
@@ -148,16 +147,6 @@ class JWTToken(BaseModel):
|
|
148
147
|
"UUID"
|
149
148
|
)
|
150
149
|
|
151
|
-
step_run_id: Optional[UUID] = None
|
152
|
-
if "step_run_id" in claims:
|
153
|
-
try:
|
154
|
-
step_run_id = UUID(claims.pop("step_run_id"))
|
155
|
-
except ValueError:
|
156
|
-
raise CredentialsNotValid(
|
157
|
-
"Invalid JWT token: the step_run_id claim is not a valid "
|
158
|
-
"UUID"
|
159
|
-
)
|
160
|
-
|
161
150
|
session_id: Optional[UUID] = None
|
162
151
|
if "session_id" in claims:
|
163
152
|
try:
|
@@ -174,7 +163,6 @@ class JWTToken(BaseModel):
|
|
174
163
|
api_key_id=api_key_id,
|
175
164
|
schedule_id=schedule_id,
|
176
165
|
pipeline_run_id=pipeline_run_id,
|
177
|
-
step_run_id=step_run_id,
|
178
166
|
session_id=session_id,
|
179
167
|
claims=claims,
|
180
168
|
)
|
@@ -212,8 +200,6 @@ class JWTToken(BaseModel):
|
|
212
200
|
claims["schedule_id"] = str(self.schedule_id)
|
213
201
|
if self.pipeline_run_id:
|
214
202
|
claims["pipeline_run_id"] = str(self.pipeline_run_id)
|
215
|
-
if self.step_run_id:
|
216
|
-
claims["step_run_id"] = str(self.step_run_id)
|
217
203
|
if self.session_id:
|
218
204
|
claims["session_id"] = str(self.session_id)
|
219
205
|
|
@@ -14,7 +14,7 @@
|
|
14
14
|
"""RBAC interface definition."""
|
15
15
|
|
16
16
|
from abc import ABC, abstractmethod
|
17
|
-
from typing import TYPE_CHECKING, Dict, List, Set, Tuple
|
17
|
+
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
|
18
18
|
|
19
19
|
from zenml.zen_server.rbac.models import Action, Resource
|
20
20
|
|
@@ -63,15 +63,22 @@ class RBACInterface(ABC):
|
|
63
63
|
|
64
64
|
@abstractmethod
|
65
65
|
def update_resource_membership(
|
66
|
-
self,
|
66
|
+
self,
|
67
|
+
sharing_user: "UserResponse",
|
68
|
+
resource: Resource,
|
69
|
+
actions: List[Action],
|
70
|
+
user_id: Optional[str] = None,
|
71
|
+
team_id: Optional[str] = None,
|
67
72
|
) -> None:
|
68
73
|
"""Update the resource membership of a user.
|
69
74
|
|
70
75
|
Args:
|
71
|
-
|
76
|
+
sharing_user: User that is sharing the resource.
|
72
77
|
resource: The resource.
|
73
78
|
actions: The actions that the user should be able to perform on the
|
74
79
|
resource.
|
80
|
+
user_id: ID of the user for which to update the membership.
|
81
|
+
team_id: ID of the team for which to update the membership.
|
75
82
|
"""
|
76
83
|
|
77
84
|
@abstractmethod
|
zenml/zen_server/rbac/utils.py
CHANGED
@@ -688,21 +688,31 @@ def get_schema_for_resource_type(
|
|
688
688
|
|
689
689
|
|
690
690
|
def update_resource_membership(
|
691
|
-
|
691
|
+
sharing_user: "UserResponse",
|
692
|
+
resource: Resource,
|
693
|
+
actions: List[Action],
|
694
|
+
user_id: Optional[str] = None,
|
695
|
+
team_id: Optional[str] = None,
|
692
696
|
) -> None:
|
693
697
|
"""Update the resource membership of a user.
|
694
698
|
|
695
699
|
Args:
|
696
|
-
|
700
|
+
sharing_user: User that is sharing the resource.
|
697
701
|
resource: The resource.
|
698
702
|
actions: The actions that the user should be able to perform on the
|
699
703
|
resource.
|
704
|
+
user_id: ID of the user for which to update the membership.
|
705
|
+
team_id: ID of the team for which to update the membership.
|
700
706
|
"""
|
701
707
|
if not server_config().rbac_enabled:
|
702
708
|
return
|
703
709
|
|
704
710
|
rbac().update_resource_membership(
|
705
|
-
|
711
|
+
sharing_user=sharing_user,
|
712
|
+
resource=resource,
|
713
|
+
actions=actions,
|
714
|
+
user_id=user_id,
|
715
|
+
team_id=team_id,
|
706
716
|
)
|
707
717
|
|
708
718
|
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Cloud RBAC implementation."""
|
15
15
|
|
16
|
-
from typing import TYPE_CHECKING, Dict, List, Set, Tuple
|
16
|
+
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
|
17
17
|
|
18
18
|
from zenml.zen_server.cloud_utils import cloud_connection
|
19
19
|
from zenml.zen_server.rbac.models import Action, Resource
|
@@ -117,22 +117,28 @@ class ZenMLCloudRBAC(RBACInterface):
|
|
117
117
|
return full_resource_access, allowed_ids
|
118
118
|
|
119
119
|
def update_resource_membership(
|
120
|
-
self,
|
120
|
+
self,
|
121
|
+
sharing_user: "UserResponse",
|
122
|
+
resource: Resource,
|
123
|
+
actions: List[Action],
|
124
|
+
user_id: Optional[str] = None,
|
125
|
+
team_id: Optional[str] = None,
|
121
126
|
) -> None:
|
122
127
|
"""Update the resource membership of a user.
|
123
128
|
|
124
129
|
Args:
|
125
|
-
|
130
|
+
sharing_user: User that is sharing the resource.
|
126
131
|
resource: The resource.
|
127
132
|
actions: The actions that the user should be able to perform on the
|
128
133
|
resource.
|
134
|
+
user_id: ID of the user for which to update the membership.
|
135
|
+
team_id: ID of the team for which to update the membership.
|
129
136
|
"""
|
130
|
-
|
131
|
-
# Service accounts have full permissions for now
|
132
|
-
return
|
133
|
-
|
137
|
+
assert sharing_user.external_user_id
|
134
138
|
data = {
|
135
|
-
"user_id":
|
139
|
+
"user_id": user_id,
|
140
|
+
"team_id": team_id,
|
141
|
+
"sharing_user_id": str(sharing_user.external_user_id),
|
136
142
|
"resource": str(resource),
|
137
143
|
"actions": [str(action) for action in actions],
|
138
144
|
}
|