zenml-nightly 0.80.2.dev20250414__py3-none-any.whl → 0.80.2.dev20250416__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. zenml/VERSION +1 -1
  2. zenml/artifacts/utils.py +7 -2
  3. zenml/cli/utils.py +13 -11
  4. zenml/config/compiler.py +1 -0
  5. zenml/config/global_config.py +1 -1
  6. zenml/config/pipeline_configurations.py +1 -0
  7. zenml/config/pipeline_run_configuration.py +1 -0
  8. zenml/config/server_config.py +7 -0
  9. zenml/constants.py +8 -0
  10. zenml/integrations/gcp/orchestrators/vertex_orchestrator.py +47 -5
  11. zenml/integrations/gcp/vertex_custom_job_parameters.py +15 -1
  12. zenml/integrations/kubernetes/flavors/kubernetes_step_operator_flavor.py +12 -0
  13. zenml/integrations/kubernetes/orchestrators/kube_utils.py +92 -0
  14. zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator.py +12 -3
  15. zenml/integrations/kubernetes/orchestrators/kubernetes_orchestrator_entrypoint.py +11 -65
  16. zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +11 -3
  17. zenml/logging/step_logging.py +41 -21
  18. zenml/login/credentials_store.py +31 -0
  19. zenml/materializers/path_materializer.py +17 -2
  20. zenml/models/v2/base/base.py +8 -4
  21. zenml/models/v2/base/filter.py +1 -1
  22. zenml/models/v2/core/pipeline_run.py +19 -0
  23. zenml/orchestrators/step_launcher.py +2 -3
  24. zenml/orchestrators/step_runner.py +2 -2
  25. zenml/orchestrators/utils.py +2 -5
  26. zenml/pipelines/pipeline_context.py +1 -0
  27. zenml/pipelines/pipeline_decorator.py +4 -0
  28. zenml/pipelines/pipeline_definition.py +83 -22
  29. zenml/pipelines/run_utils.py +4 -0
  30. zenml/steps/utils.py +1 -1
  31. zenml/utils/io_utils.py +23 -0
  32. zenml/zen_server/auth.py +96 -64
  33. zenml/zen_server/cloud_utils.py +7 -1
  34. zenml/zen_server/download_utils.py +123 -0
  35. zenml/zen_server/jwt.py +0 -14
  36. zenml/zen_server/rbac/rbac_interface.py +10 -3
  37. zenml/zen_server/rbac/utils.py +13 -3
  38. zenml/zen_server/rbac/zenml_cloud_rbac.py +14 -8
  39. zenml/zen_server/routers/artifact_version_endpoints.py +86 -3
  40. zenml/zen_server/routers/auth_endpoints.py +5 -36
  41. zenml/zen_server/routers/pipeline_deployments_endpoints.py +63 -26
  42. zenml/zen_server/routers/runs_endpoints.py +57 -0
  43. zenml/zen_server/routers/users_endpoints.py +13 -8
  44. zenml/zen_server/template_execution/utils.py +3 -3
  45. zenml/zen_stores/migrations/versions/ff538a321a92_migrate_onboarding_state.py +123 -0
  46. zenml/zen_stores/rest_zen_store.py +16 -13
  47. zenml/zen_stores/schemas/pipeline_run_schemas.py +1 -0
  48. zenml/zen_stores/schemas/server_settings_schemas.py +4 -1
  49. zenml/zen_stores/sql_zen_store.py +18 -0
  50. {zenml_nightly-0.80.2.dev20250414.dist-info → zenml_nightly-0.80.2.dev20250416.dist-info}/METADATA +2 -1
  51. {zenml_nightly-0.80.2.dev20250414.dist-info → zenml_nightly-0.80.2.dev20250416.dist-info}/RECORD +54 -52
  52. {zenml_nightly-0.80.2.dev20250414.dist-info → zenml_nightly-0.80.2.dev20250416.dist-info}/LICENSE +0 -0
  53. {zenml_nightly-0.80.2.dev20250414.dist-info → zenml_nightly-0.80.2.dev20250416.dist-info}/WHEEL +0 -0
  54. {zenml_nightly-0.80.2.dev20250414.dist-info → zenml_nightly-0.80.2.dev20250416.dist-info}/entry_points.txt +0 -0
zenml/pipelines/run_utils.py CHANGED
@@ -15,6 +15,7 @@ from zenml.enums import ExecutionStatus
15
15
  from zenml.logger import get_logger
16
16
  from zenml.models import (
17
17
  FlavorFilter,
18
+ LogsRequest,
18
19
  PipelineDeploymentBase,
19
20
  PipelineDeploymentResponse,
20
21
  PipelineRunRequest,
@@ -49,6 +50,7 @@ def get_default_run_name(pipeline_name: str) -> str:
49
50
 
50
51
  def create_placeholder_run(
51
52
  deployment: "PipelineDeploymentResponse",
53
+ logs: Optional["LogsRequest"] = None,
52
54
  ) -> Optional["PipelineRunResponse"]:
53
55
  """Create a placeholder run for the deployment.
54
56
 
@@ -57,6 +59,7 @@ def create_placeholder_run(
57
59
 
58
60
  Args:
59
61
  deployment: The deployment for which to create the placeholder run.
62
+ logs: The logs for the run.
60
63
 
61
64
  Returns:
62
65
  The placeholder run or `None` if no run was created.
@@ -86,6 +89,7 @@ def create_placeholder_run(
86
89
  pipeline=deployment.pipeline.id if deployment.pipeline else None,
87
90
  status=ExecutionStatus.INITIALIZING,
88
91
  tags=deployment.pipeline_configuration.tags,
92
+ logs=logs,
89
93
  )
90
94
  run, _ = Client().zen_store.get_or_create_run(run_request)
91
95
  return run
zenml/steps/utils.py CHANGED
@@ -553,7 +553,7 @@ def run_as_single_step_pipeline(
553
553
  orchestrator = Client().active_stack.orchestrator
554
554
 
555
555
  pipeline_settings: Any = {}
556
- if "synchronous" in orchestrator.config.model_fields:
556
+ if "synchronous" in type(orchestrator.config).model_fields:
557
557
  # Make sure the orchestrator runs sync so we stream the logs
558
558
  key = settings_utils.get_stack_component_setting_key(orchestrator)
559
559
  pipeline_settings[key] = BaseSettings(synchronous=True)
zenml/utils/io_utils.py CHANGED
@@ -205,6 +205,29 @@ def resolve_relative_path(path: str) -> str:
205
205
  return str(Path(path).resolve())
206
206
 
207
207
 
208
+ def is_path_within_directory(path: str, directory: str) -> bool:
209
+ """Checks if a path is contained within a given directory.
210
+
211
+ This utility function verifies that a path (absolute or relative) resolves
212
+ to a location that is within the specified directory. This is useful for
213
+ security checks such as preventing path traversal attacks when extracting
214
+ archives (CVE-2007-4559) or whenever path containment needs to be verified.
215
+
216
+ Args:
217
+ path: The path to check (can be relative or absolute).
218
+ directory: The directory that should contain the path.
219
+
220
+ Returns:
221
+ Boolean indicating whether the path is contained within the directory (True)
222
+ or not (False).
223
+ """
224
+ # Convert to absolute path, ensuring it's normalized
225
+ abs_path = os.path.abspath(os.path.join(directory, path))
226
+ # Check if the path is within the target directory
227
+ dir_abs = os.path.abspath(directory)
228
+ return abs_path.startswith(dir_abs + os.sep) or abs_path == dir_abs
229
+
230
+
208
231
  def move(source: str, destination: str, overwrite: bool = False) -> None:
209
232
  """Moves dir or file from source to destination. Can be used to rename.
210
233
 
zenml/zen_server/auth.py CHANGED
@@ -15,7 +15,7 @@
15
15
 
16
16
  from contextvars import ContextVar
17
17
  from datetime import datetime, timedelta
18
- from typing import Callable, Optional, Union
18
+ from typing import Callable, Optional, Tuple, Union
19
19
  from urllib.parse import urlencode, urlparse
20
20
  from uuid import UUID, uuid4
21
21
 
@@ -33,9 +33,12 @@ from zenml.analytics.context import AnalyticsContext
33
33
  from zenml.constants import (
34
34
  API,
35
35
  DEFAULT_USERNAME,
36
+ DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME,
37
+ ENV_ZENML_WORKLOAD_TOKEN_EXPIRATION_LEEWAY,
36
38
  EXTERNAL_AUTHENTICATOR_TIMEOUT,
37
39
  LOGIN,
38
40
  VERSION_1,
41
+ handle_int_env_var,
39
42
  )
40
43
  from zenml.enums import (
41
44
  AuthScheme,
@@ -420,28 +423,32 @@ def authenticate_credentials(
420
423
  @cache_result(expiry=30)
421
424
  def get_pipeline_run_status(
422
425
  pipeline_run_id: UUID,
423
- ) -> Optional[ExecutionStatus]:
426
+ ) -> Tuple[Optional[ExecutionStatus], Optional[datetime]]:
424
427
  """Get the status of a pipeline run.
425
428
 
426
429
  Args:
427
430
  pipeline_run_id: The pipeline run ID.
428
431
 
429
432
  Returns:
430
- The pipeline run status or None if the pipeline run does not
431
- exist.
433
+ The pipeline run status and end time or None if the pipeline
434
+ run does not exist.
432
435
  """
433
436
  try:
434
437
  pipeline_run = zen_store().get_run(
435
- pipeline_run_id, hydrate=False
438
+ pipeline_run_id, hydrate=True
436
439
  )
437
440
  except KeyError:
438
- return None
441
+ return None, None
439
442
 
440
- return pipeline_run.status
443
+ return (
444
+ pipeline_run.status,
445
+ pipeline_run.end_time,
446
+ )
441
447
 
442
- pipeline_run_status = get_pipeline_run_status(
443
- decoded_token.pipeline_run_id
444
- )
448
+ (
449
+ pipeline_run_status,
450
+ pipeline_run_end_time,
451
+ ) = get_pipeline_run_status(decoded_token.pipeline_run_id)
445
452
  if pipeline_run_status is None:
446
453
  error = (
447
454
  f"Authentication error: error retrieving token pipeline run "
@@ -450,59 +457,35 @@ def authenticate_credentials(
450
457
  logger.error(error)
451
458
  raise CredentialsNotValid(error)
452
459
 
460
+ leeway = handle_int_env_var(
461
+ ENV_ZENML_WORKLOAD_TOKEN_EXPIRATION_LEEWAY,
462
+ DEFAULT_ZENML_SERVER_GENERIC_API_TOKEN_LIFETIME,
463
+ )
453
464
  if pipeline_run_status.is_finished:
454
- error = (
455
- f"The execution of pipeline run "
456
- f"{decoded_token.pipeline_run_id} has already concluded and "
457
- "API tokens scoped to it are no longer valid."
458
- )
459
- logger.error(error)
460
- raise CredentialsNotValid(error)
461
-
462
- if decoded_token.step_run_id:
463
- # If the token contains a step run ID, we need to check if the
464
- # step run exists in the database and the step run has not concluded.
465
- # We use a cached version of the step run status to avoid unnecessary
466
- # database queries.
467
-
468
- @cache_result(expiry=30)
469
- def get_step_run_status(
470
- step_run_id: UUID,
471
- ) -> Optional[ExecutionStatus]:
472
- """Get the status of a step run.
473
-
474
- Args:
475
- step_run_id: The step run ID.
476
-
477
- Returns:
478
- The step run status or None if the step run does not exist.
479
- """
480
- try:
481
- step_run = zen_store().get_run_step(
482
- step_run_id, hydrate=False
465
+ if leeway < 0:
466
+ # The token should never expire, we don't need to check
467
+ # the end time.
468
+ pass
469
+ elif (
470
+ # We don't know the end time. This should never happen, but
471
+ # just in case we always expire the token.
472
+ pipeline_run_end_time is None
473
+ # Calculate whether the token has expired.
474
+ or utc_now(tz_aware=pipeline_run_end_time)
475
+ > pipeline_run_end_time + timedelta(seconds=leeway)
476
+ ):
477
+ error = (
478
+ f"The pipeline run {decoded_token.pipeline_run_id} has "
479
+ "finished and API tokens scoped to it are no longer "
480
+ "valid. If you want to increase the expiration time "
481
+ "of the token to allow steps to continue for longer "
482
+ "after other steps have failed, you can do so by "
483
+ "configuring the "
484
+ f"`{ENV_ZENML_WORKLOAD_TOKEN_EXPIRATION_LEEWAY}` "
485
+ "ZenML server environment variable."
483
486
  )
484
- except KeyError:
485
- return None
486
-
487
- return step_run.status
488
-
489
- step_run_status = get_step_run_status(decoded_token.step_run_id)
490
- if step_run_status is None:
491
- error = (
492
- f"Authentication error: error retrieving token step run "
493
- f"{decoded_token.step_run_id}"
494
- )
495
- logger.error(error)
496
- raise CredentialsNotValid(error)
497
-
498
- if step_run_status.is_finished:
499
- error = (
500
- f"The execution of step run "
501
- f"{decoded_token.step_run_id} has already concluded and "
502
- "API tokens scoped to it are no longer valid."
503
- )
504
- logger.error(error)
505
- raise CredentialsNotValid(error)
487
+ logger.error(error)
488
+ raise CredentialsNotValid(error)
506
489
 
507
490
  auth_context = AuthContext(
508
491
  user=user_model,
@@ -861,7 +844,6 @@ def generate_access_token(
861
844
  expires_in: Optional[int] = None,
862
845
  schedule_id: Optional[UUID] = None,
863
846
  pipeline_run_id: Optional[UUID] = None,
864
- step_run_id: Optional[UUID] = None,
865
847
  ) -> OAuthTokenResponse:
866
848
  """Generates an access token for the given user.
867
849
 
@@ -880,7 +862,6 @@ def generate_access_token(
880
862
  expire.
881
863
  schedule_id: The ID of the schedule to scope the token to.
882
864
  pipeline_run_id: The ID of the pipeline run to scope the token to.
883
- step_run_id: The ID of the step run to scope the token to.
884
865
 
885
866
  Returns:
886
867
  An authentication response with an access token.
@@ -956,7 +937,6 @@ def generate_access_token(
956
937
  api_key_id=api_key.id if api_key else None,
957
938
  schedule_id=schedule_id,
958
939
  pipeline_run_id=pipeline_run_id,
959
- step_run_id=step_run_id,
960
940
  # Set the session ID if this is a cross-site request
961
941
  session_id=session_id,
962
942
  ).encode(expires=expires)
@@ -983,6 +963,58 @@ def generate_access_token(
983
963
  )
984
964
 
985
965
 
966
+ def generate_artifact_download_token(artifact_version_id: UUID) -> str:
967
+ """Generate a JWT token for artifact download.
968
+
969
+ Args:
970
+ artifact_version_id: The ID of the artifact version to download.
971
+
972
+ Returns:
973
+ The JWT token for the artifact download.
974
+ """
975
+ import jwt
976
+
977
+ config = server_config()
978
+
979
+ return jwt.encode(
980
+ {
981
+ "exp": utc_now() + timedelta(seconds=30),
982
+ "artifact_version_id": str(artifact_version_id),
983
+ },
984
+ key=config.jwt_secret_key,
985
+ algorithm=config.jwt_token_algorithm,
986
+ )
987
+
988
+
989
+ def verify_artifact_download_token(
990
+ token: str, artifact_version_id: UUID
991
+ ) -> None:
992
+ """Verify a JWT token for artifact download.
993
+
994
+ Args:
995
+ token: The JWT token to verify.
996
+ artifact_version_id: The ID of the artifact version to download.
997
+
998
+ Raises:
999
+ CredentialsNotValid: If the token is invalid or the artifact version
1000
+ ID does not match.
1001
+ """
1002
+ import jwt
1003
+
1004
+ config = server_config()
1005
+ try:
1006
+ claims = jwt.decode(
1007
+ token,
1008
+ config.jwt_secret_key,
1009
+ algorithms=[config.jwt_token_algorithm],
1010
+ )
1011
+ except jwt.PyJWTError as e:
1012
+ raise CredentialsNotValid(f"Invalid JWT token: {e}") from e
1013
+
1014
+ if claims["artifact_version_id"] != str(artifact_version_id):
1015
+ raise CredentialsNotValid("Invalid artifact version ID")
1016
+
1017
+
986
1018
  def http_authentication(
987
1019
  credentials: HTTPBasicCredentials = Depends(HTTPBasic()),
988
1020
  ) -> AuthContext:
zenml/zen_server/cloud_utils.py CHANGED
@@ -7,7 +7,10 @@ import requests
7
7
  from requests.adapters import HTTPAdapter, Retry
8
8
 
9
9
  from zenml.config.server_config import ServerProConfiguration
10
- from zenml.exceptions import SubscriptionUpgradeRequiredError
10
+ from zenml.exceptions import (
11
+ IllegalOperationError,
12
+ SubscriptionUpgradeRequiredError,
13
+ )
11
14
  from zenml.utils.time_utils import utc_now
12
15
  from zenml.zen_server.utils import get_zenml_headers, server_config
13
16
 
@@ -43,6 +46,7 @@ class ZenMLCloudConnection:
43
46
  Raises:
44
47
  SubscriptionUpgradeRequiredError: If the current subscription tier
45
48
  is insufficient for the attempted operation.
49
+ IllegalOperationError: If the request failed with a 403 status code.
46
50
  RuntimeError: If the request failed.
47
51
 
48
52
  Returns:
@@ -65,6 +69,8 @@ class ZenMLCloudConnection:
65
69
  except requests.HTTPError as e:
66
70
  if response.status_code == 402:
67
71
  raise SubscriptionUpgradeRequiredError(response.json())
72
+ elif response.status_code == 403:
73
+ raise IllegalOperationError(response.json())
68
74
  else:
69
75
  raise RuntimeError(
70
76
  f"Failed while trying to contact the central zenml pro "
zenml/zen_server/download_utils.py ADDED
@@ -0,0 +1,123 @@
1
+ # Copyright (c) ZenML GmbH 2025. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12
+ # or implied. See the License for the specific language governing
13
+ # permissions and limitations under the License.
14
+ """Utility functions for downloading artifacts."""
15
+
16
+ import os
17
+ import tarfile
18
+ import tempfile
19
+ from typing import (
20
+ TYPE_CHECKING,
21
+ )
22
+
23
+ from zenml.artifacts.utils import _load_artifact_store
24
+ from zenml.exceptions import (
25
+ IllegalOperationError,
26
+ )
27
+ from zenml.models import (
28
+ ArtifactVersionResponse,
29
+ )
30
+ from zenml.zen_server.utils import server_config, zen_store
31
+
32
+ if TYPE_CHECKING:
33
+ from zenml.artifact_stores.base_artifact_store import BaseArtifactStore
34
+
35
+
36
+ def verify_artifact_is_downloadable(
37
+ artifact: "ArtifactVersionResponse",
38
+ ) -> "BaseArtifactStore":
39
+ """Verify that the given artifact is downloadable.
40
+
41
+ Args:
42
+ artifact: The artifact to verify.
43
+
44
+ Raises:
45
+ IllegalOperationError: If the artifact is too large to be archived.
46
+ KeyError: If the artifact store is not found or the artifact URI does
47
+ not exist.
48
+
49
+ Returns:
50
+ The artifact store.
51
+ """
52
+ if not artifact.artifact_store_id:
53
+ raise KeyError(
54
+ f"Artifact '{artifact.id}' cannot be downloaded because the "
55
+ "underlying artifact store was deleted."
56
+ )
57
+
58
+ artifact_store = _load_artifact_store(
59
+ artifact_store_id=artifact.artifact_store_id, zen_store=zen_store()
60
+ )
61
+
62
+ if not artifact_store.exists(artifact.uri):
63
+ raise KeyError(f"The artifact URI '{artifact.uri}' does not exist.")
64
+
65
+ size = artifact_store.size(artifact.uri)
66
+ max_download_size = server_config().file_download_size_limit
67
+
68
+ if size and size > max_download_size:
69
+ raise IllegalOperationError(
70
+ f"The artifact '{artifact.id}' is too large to be downloaded. "
71
+ f"The maximum download size allowed by your ZenML server is "
72
+ f"{max_download_size} bytes."
73
+ )
74
+
75
+ return artifact_store
76
+
77
+
78
+ def create_artifact_archive(
79
+ artifact: "ArtifactVersionResponse",
80
+ ) -> str:
81
+ """Create an archive of the given artifact.
82
+
83
+ Args:
84
+ artifact: The artifact to archive.
85
+
86
+ Returns:
87
+ The path to the created archive.
88
+ """
89
+ artifact_store = verify_artifact_is_downloadable(artifact)
90
+
91
+ def _prepare_tarinfo(path: str) -> tarfile.TarInfo:
92
+ archive_path = os.path.relpath(path, artifact.uri)
93
+ tarinfo = tarfile.TarInfo(name=archive_path)
94
+ if size := artifact_store.size(path):
95
+ tarinfo.size = size
96
+ return tarinfo
97
+
98
+ with tempfile.NamedTemporaryFile(delete=False) as temp_file:
99
+ with tarfile.open(fileobj=temp_file, mode="w:gz") as tar:
100
+ if artifact_store.isdir(artifact.uri):
101
+ for dir, _, files in artifact_store.walk(artifact.uri):
102
+ dir = dir.decode() if isinstance(dir, bytes) else dir
103
+ dir_info = tarfile.TarInfo(
104
+ name=os.path.relpath(dir, artifact.uri)
105
+ )
106
+ dir_info.type = tarfile.DIRTYPE
107
+ dir_info.mode = 0o755
108
+ tar.addfile(dir_info)
109
+
110
+ for file in files:
111
+ file = (
112
+ file.decode() if isinstance(file, bytes) else file
113
+ )
114
+ path = os.path.join(dir, file)
115
+ tarinfo = _prepare_tarinfo(path)
116
+ with artifact_store.open(path, "rb") as f:
117
+ tar.addfile(tarinfo, fileobj=f)
118
+ else:
119
+ tarinfo = _prepare_tarinfo(artifact.uri)
120
+ with artifact_store.open(artifact.uri, "rb") as f:
121
+ tar.addfile(tarinfo, fileobj=f)
122
+
123
+ return temp_file.name
zenml/zen_server/jwt.py CHANGED
@@ -54,7 +54,6 @@ class JWTToken(BaseModel):
54
54
  api_key_id: Optional[UUID] = None
55
55
  schedule_id: Optional[UUID] = None
56
56
  pipeline_run_id: Optional[UUID] = None
57
- step_run_id: Optional[UUID] = None
58
57
  session_id: Optional[UUID] = None
59
58
  claims: Dict[str, Any] = {}
60
59
 
@@ -148,16 +147,6 @@ class JWTToken(BaseModel):
148
147
  "UUID"
149
148
  )
150
149
 
151
- step_run_id: Optional[UUID] = None
152
- if "step_run_id" in claims:
153
- try:
154
- step_run_id = UUID(claims.pop("step_run_id"))
155
- except ValueError:
156
- raise CredentialsNotValid(
157
- "Invalid JWT token: the step_run_id claim is not a valid "
158
- "UUID"
159
- )
160
-
161
150
  session_id: Optional[UUID] = None
162
151
  if "session_id" in claims:
163
152
  try:
@@ -174,7 +163,6 @@ class JWTToken(BaseModel):
174
163
  api_key_id=api_key_id,
175
164
  schedule_id=schedule_id,
176
165
  pipeline_run_id=pipeline_run_id,
177
- step_run_id=step_run_id,
178
166
  session_id=session_id,
179
167
  claims=claims,
180
168
  )
@@ -212,8 +200,6 @@ class JWTToken(BaseModel):
212
200
  claims["schedule_id"] = str(self.schedule_id)
213
201
  if self.pipeline_run_id:
214
202
  claims["pipeline_run_id"] = str(self.pipeline_run_id)
215
- if self.step_run_id:
216
- claims["step_run_id"] = str(self.step_run_id)
217
203
  if self.session_id:
218
204
  claims["session_id"] = str(self.session_id)
219
205
 
zenml/zen_server/rbac/rbac_interface.py CHANGED
@@ -14,7 +14,7 @@
14
14
  """RBAC interface definition."""
15
15
 
16
16
  from abc import ABC, abstractmethod
17
- from typing import TYPE_CHECKING, Dict, List, Set, Tuple
17
+ from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
18
18
 
19
19
  from zenml.zen_server.rbac.models import Action, Resource
20
20
 
@@ -63,15 +63,22 @@ class RBACInterface(ABC):
63
63
 
64
64
  @abstractmethod
65
65
  def update_resource_membership(
66
- self, user: "UserResponse", resource: Resource, actions: List[Action]
66
+ self,
67
+ sharing_user: "UserResponse",
68
+ resource: Resource,
69
+ actions: List[Action],
70
+ user_id: Optional[str] = None,
71
+ team_id: Optional[str] = None,
67
72
  ) -> None:
68
73
  """Update the resource membership of a user.
69
74
 
70
75
  Args:
71
- user: User for which the resource membership should be updated.
76
+ sharing_user: User that is sharing the resource.
72
77
  resource: The resource.
73
78
  actions: The actions that the user should be able to perform on the
74
79
  resource.
80
+ user_id: ID of the user for which to update the membership.
81
+ team_id: ID of the team for which to update the membership.
75
82
  """
76
83
 
77
84
  @abstractmethod
zenml/zen_server/rbac/utils.py CHANGED
@@ -688,21 +688,31 @@ def get_schema_for_resource_type(
688
688
 
689
689
 
690
690
  def update_resource_membership(
691
- user: UserResponse, resource: Resource, actions: List[Action]
691
+ sharing_user: "UserResponse",
692
+ resource: Resource,
693
+ actions: List[Action],
694
+ user_id: Optional[str] = None,
695
+ team_id: Optional[str] = None,
692
696
  ) -> None:
693
697
  """Update the resource membership of a user.
694
698
 
695
699
  Args:
696
- user: User for which the resource membership should be updated.
700
+ sharing_user: User that is sharing the resource.
697
701
  resource: The resource.
698
702
  actions: The actions that the user should be able to perform on the
699
703
  resource.
704
+ user_id: ID of the user for which to update the membership.
705
+ team_id: ID of the team for which to update the membership.
700
706
  """
701
707
  if not server_config().rbac_enabled:
702
708
  return
703
709
 
704
710
  rbac().update_resource_membership(
705
- user=user, resource=resource, actions=actions
711
+ sharing_user=sharing_user,
712
+ resource=resource,
713
+ actions=actions,
714
+ user_id=user_id,
715
+ team_id=team_id,
706
716
  )
707
717
 
708
718
 
zenml/zen_server/rbac/zenml_cloud_rbac.py CHANGED
@@ -13,7 +13,7 @@
13
13
  # permissions and limitations under the License.
14
14
  """Cloud RBAC implementation."""
15
15
 
16
- from typing import TYPE_CHECKING, Dict, List, Set, Tuple
16
+ from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
17
17
 
18
18
  from zenml.zen_server.cloud_utils import cloud_connection
19
19
  from zenml.zen_server.rbac.models import Action, Resource
@@ -117,22 +117,28 @@ class ZenMLCloudRBAC(RBACInterface):
117
117
  return full_resource_access, allowed_ids
118
118
 
119
119
  def update_resource_membership(
120
- self, user: "UserResponse", resource: Resource, actions: List[Action]
120
+ self,
121
+ sharing_user: "UserResponse",
122
+ resource: Resource,
123
+ actions: List[Action],
124
+ user_id: Optional[str] = None,
125
+ team_id: Optional[str] = None,
121
126
  ) -> None:
122
127
  """Update the resource membership of a user.
123
128
 
124
129
  Args:
125
- user: User for which the resource membership should be updated.
130
+ sharing_user: User that is sharing the resource.
126
131
  resource: The resource.
127
132
  actions: The actions that the user should be able to perform on the
128
133
  resource.
134
+ user_id: ID of the user for which to update the membership.
135
+ team_id: ID of the team for which to update the membership.
129
136
  """
130
- if user.is_service_account:
131
- # Service accounts have full permissions for now
132
- return
133
-
137
+ assert sharing_user.external_user_id
134
138
  data = {
135
- "user_id": str(user.external_user_id),
139
+ "user_id": user_id,
140
+ "team_id": team_id,
141
+ "sharing_user_id": str(sharing_user.external_user_id),
136
142
  "resource": str(resource),
137
143
  "actions": [str(action) for action in actions],
138
144
  }