zenml-nightly 0.72.0.dev20250115__py3-none-any.whl → 0.72.0.dev20250116__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/cli/login.py +126 -50
- zenml/cli/server.py +24 -5
- zenml/config/server_config.py +142 -17
- zenml/constants.py +2 -11
- zenml/login/credentials.py +38 -14
- zenml/login/credentials_store.py +53 -18
- zenml/login/pro/client.py +3 -7
- zenml/login/pro/constants.py +0 -6
- zenml/login/pro/tenant/models.py +4 -2
- zenml/login/pro/utils.py +11 -25
- zenml/login/server_info.py +52 -0
- zenml/login/web_login.py +11 -6
- zenml/models/v2/misc/auth_models.py +1 -1
- zenml/models/v2/misc/server_models.py +44 -0
- zenml/zen_server/auth.py +97 -8
- zenml/zen_server/cloud_utils.py +79 -87
- zenml/zen_server/csrf.py +91 -0
- zenml/zen_server/deploy/helm/templates/NOTES.txt +22 -0
- zenml/zen_server/deploy/helm/templates/_environment.tpl +50 -24
- zenml/zen_server/deploy/helm/templates/server-secret.yaml +11 -0
- zenml/zen_server/deploy/helm/values.yaml +76 -7
- zenml/zen_server/feature_gate/feature_gate_interface.py +1 -1
- zenml/zen_server/jwt.py +16 -1
- zenml/zen_server/rbac/endpoint_utils.py +3 -3
- zenml/zen_server/routers/auth_endpoints.py +44 -21
- zenml/zen_server/routers/models_endpoints.py +1 -2
- zenml/zen_server/routers/pipelines_endpoints.py +2 -2
- zenml/zen_server/routers/stack_deployment_endpoints.py +5 -5
- zenml/zen_server/routers/workspaces_endpoints.py +2 -2
- zenml/zen_server/utils.py +64 -0
- zenml/zen_server/zen_server_api.py +5 -0
- zenml/zen_stores/base_zen_store.py +19 -1
- zenml/zen_stores/rest_zen_store.py +30 -20
- {zenml_nightly-0.72.0.dev20250115.dist-info → zenml_nightly-0.72.0.dev20250116.dist-info}/METADATA +3 -1
- {zenml_nightly-0.72.0.dev20250115.dist-info → zenml_nightly-0.72.0.dev20250116.dist-info}/RECORD +39 -37
- {zenml_nightly-0.72.0.dev20250115.dist-info → zenml_nightly-0.72.0.dev20250116.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.72.0.dev20250115.dist-info → zenml_nightly-0.72.0.dev20250116.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.72.0.dev20250115.dist-info → zenml_nightly-0.72.0.dev20250116.dist-info}/entry_points.txt +0 -0
zenml/zen_server/auth.py
CHANGED
@@ -16,8 +16,8 @@
|
|
16
16
|
from contextvars import ContextVar
|
17
17
|
from datetime import datetime, timedelta
|
18
18
|
from typing import Callable, Optional, Union
|
19
|
-
from urllib.parse import urlencode
|
20
|
-
from uuid import UUID
|
19
|
+
from urllib.parse import urlencode, urlparse
|
20
|
+
from uuid import UUID, uuid4
|
21
21
|
|
22
22
|
import requests
|
23
23
|
from fastapi import Depends, Response
|
@@ -63,9 +63,15 @@ from zenml.models import (
|
|
63
63
|
UserUpdate,
|
64
64
|
)
|
65
65
|
from zenml.zen_server.cache import cache_result
|
66
|
+
from zenml.zen_server.csrf import CSRFToken
|
66
67
|
from zenml.zen_server.exceptions import http_exception_from_error
|
67
68
|
from zenml.zen_server.jwt import JWTToken
|
68
|
-
from zenml.zen_server.utils import
|
69
|
+
from zenml.zen_server.utils import (
|
70
|
+
get_zenml_headers,
|
71
|
+
is_same_or_subdomain,
|
72
|
+
server_config,
|
73
|
+
zen_store,
|
74
|
+
)
|
69
75
|
|
70
76
|
logger = get_logger(__name__)
|
71
77
|
|
@@ -174,6 +180,7 @@ def authenticate_credentials(
|
|
174
180
|
user_name_or_id: Optional[Union[str, UUID]] = None,
|
175
181
|
password: Optional[str] = None,
|
176
182
|
access_token: Optional[str] = None,
|
183
|
+
csrf_token: Optional[str] = None,
|
177
184
|
activation_token: Optional[str] = None,
|
178
185
|
) -> AuthContext:
|
179
186
|
"""Verify if user authentication credentials are valid.
|
@@ -192,6 +199,7 @@ def authenticate_credentials(
|
|
192
199
|
user_name_or_id: The username or user ID.
|
193
200
|
password: The password.
|
194
201
|
access_token: The access token.
|
202
|
+
csrf_token: The CSRF token.
|
195
203
|
activation_token: The activation token.
|
196
204
|
|
197
205
|
Returns:
|
@@ -253,6 +261,22 @@ def authenticate_credentials(
|
|
253
261
|
logger.exception(error)
|
254
262
|
raise CredentialsNotValid(error)
|
255
263
|
|
264
|
+
if decoded_token.session_id:
|
265
|
+
if not csrf_token:
|
266
|
+
error = "Authentication error: missing CSRF token"
|
267
|
+
logger.error(error)
|
268
|
+
raise CredentialsNotValid(error)
|
269
|
+
|
270
|
+
decoded_csrf_token = CSRFToken.decode_token(csrf_token)
|
271
|
+
|
272
|
+
if decoded_csrf_token.session_id != decoded_token.session_id:
|
273
|
+
error = (
|
274
|
+
"Authentication error: CSRF token does not match the "
|
275
|
+
"access token"
|
276
|
+
)
|
277
|
+
logger.error(error)
|
278
|
+
raise CredentialsNotValid(error)
|
279
|
+
|
256
280
|
try:
|
257
281
|
user_model = zen_store().get_user(
|
258
282
|
user_name_or_id=decoded_token.user_id, include_private=True
|
@@ -282,6 +306,14 @@ def authenticate_credentials(
|
|
282
306
|
|
283
307
|
device_model: Optional[OAuthDeviceInternalResponse] = None
|
284
308
|
if decoded_token.device_id:
|
309
|
+
if server_config().auth_scheme in [
|
310
|
+
AuthScheme.NO_AUTH,
|
311
|
+
AuthScheme.EXTERNAL,
|
312
|
+
]:
|
313
|
+
error = "Authentication error: device authorization is not supported."
|
314
|
+
logger.error(error)
|
315
|
+
raise CredentialsNotValid(error)
|
316
|
+
|
285
317
|
# Access tokens that have been issued for a device are only valid
|
286
318
|
# for that device, so we need to check if the device ID matches any
|
287
319
|
# of the valid devices in the database.
|
@@ -660,6 +692,7 @@ def authenticate_external_user(
|
|
660
692
|
# Get the user information from the external authenticator
|
661
693
|
user_info_url = config.external_user_info_url
|
662
694
|
headers = {"Authorization": "Bearer " + external_access_token}
|
695
|
+
headers.update(get_zenml_headers())
|
663
696
|
query_params = dict(server_id=str(config.get_external_server_id()))
|
664
697
|
|
665
698
|
try:
|
@@ -817,6 +850,7 @@ def authenticate_api_key(
|
|
817
850
|
def generate_access_token(
|
818
851
|
user_id: UUID,
|
819
852
|
response: Optional[Response] = None,
|
853
|
+
request: Optional[Request] = None,
|
820
854
|
device: Optional[OAuthDeviceInternalResponse] = None,
|
821
855
|
api_key: Optional[APIKeyInternalResponse] = None,
|
822
856
|
expires_in: Optional[int] = None,
|
@@ -828,7 +862,11 @@ def generate_access_token(
|
|
828
862
|
|
829
863
|
Args:
|
830
864
|
user_id: The ID of the user.
|
831
|
-
response: The FastAPI response object.
|
865
|
+
response: The FastAPI response object. If passed, the access
|
866
|
+
token will also be set as an HTTP only cookie in the response.
|
867
|
+
request: The FastAPI request object. Used to determine the request
|
868
|
+
origin and to decide whether to use cross-site security measures for
|
869
|
+
the access token cookie.
|
832
870
|
device: The device used for authentication.
|
833
871
|
api_key: The service account API key used for authentication.
|
834
872
|
expires_in: The number of seconds until the token expires. If not set,
|
@@ -866,6 +904,46 @@ def generate_access_token(
|
|
866
904
|
)
|
867
905
|
expires_in = config.jwt_token_expire_minutes * 60
|
868
906
|
|
907
|
+
# Figure out if this is a same-site request or a cross-site request
|
908
|
+
same_site = True
|
909
|
+
if response and request:
|
910
|
+
# Extract the origin domain from the request; use the referer as a
|
911
|
+
# fallback
|
912
|
+
origin_domain: Optional[str] = None
|
913
|
+
origin = request.headers.get("origin", request.headers.get("referer"))
|
914
|
+
if origin:
|
915
|
+
# If the request origin is known, we use it to determine whether
|
916
|
+
# this is a cross-site request and enable additional security
|
917
|
+
# measures.
|
918
|
+
origin_domain = urlparse(origin).netloc
|
919
|
+
|
920
|
+
server_domain: Optional[str] = config.auth_cookie_domain
|
921
|
+
# If the server's cookie domain is not explicitly set in the
|
922
|
+
# server's configuration, we use other sources to determine it:
|
923
|
+
#
|
924
|
+
# 1. the server's root URL, if set in the server's configuration
|
925
|
+
# 2. the X-Forwarded-Host header, if set by the reverse proxy
|
926
|
+
# 3. the request URL, if all else fails
|
927
|
+
if not server_domain and config.server_url:
|
928
|
+
server_domain = urlparse(config.server_url).netloc
|
929
|
+
if not server_domain:
|
930
|
+
server_domain = request.headers.get(
|
931
|
+
"x-forwarded-host", request.url.netloc
|
932
|
+
)
|
933
|
+
|
934
|
+
# Same-site requests can come from the same domain or from a
|
935
|
+
# subdomain of the domain used to issue cookies.
|
936
|
+
if origin_domain and server_domain:
|
937
|
+
same_site = is_same_or_subdomain(origin_domain, server_domain)
|
938
|
+
|
939
|
+
csrf_token: Optional[str] = None
|
940
|
+
session_id: Optional[UUID] = None
|
941
|
+
if not same_site:
|
942
|
+
# If responding to a cross-site login request, we need to generate and
|
943
|
+
# sign a CSRF token associated with the authentication session.
|
944
|
+
session_id = uuid4()
|
945
|
+
csrf_token = CSRFToken(session_id=session_id).encode()
|
946
|
+
|
869
947
|
access_token = JWTToken(
|
870
948
|
user_id=user_id,
|
871
949
|
device_id=device.id if device else None,
|
@@ -873,15 +951,18 @@ def generate_access_token(
|
|
873
951
|
schedule_id=schedule_id,
|
874
952
|
pipeline_run_id=pipeline_run_id,
|
875
953
|
step_run_id=step_run_id,
|
954
|
+
# Set the session ID if this is a cross-site request
|
955
|
+
session_id=session_id,
|
876
956
|
).encode(expires=expires)
|
877
957
|
|
878
|
-
if
|
958
|
+
if response:
|
879
959
|
# Also set the access token as an HTTP only cookie in the response
|
880
960
|
response.set_cookie(
|
881
961
|
key=config.get_auth_cookie_name(),
|
882
962
|
value=access_token,
|
883
963
|
httponly=True,
|
884
|
-
|
964
|
+
secure=not same_site,
|
965
|
+
samesite="lax" if same_site else "none",
|
885
966
|
max_age=config.jwt_token_expire_minutes * 60
|
886
967
|
if config.jwt_token_expire_minutes
|
887
968
|
else None,
|
@@ -889,7 +970,10 @@ def generate_access_token(
|
|
889
970
|
)
|
890
971
|
|
891
972
|
return OAuthTokenResponse(
|
892
|
-
access_token=access_token,
|
973
|
+
access_token=access_token,
|
974
|
+
expires_in=expires_in,
|
975
|
+
token_type="bearer",
|
976
|
+
csrf_token=csrf_token,
|
893
977
|
)
|
894
978
|
|
895
979
|
|
@@ -945,6 +1029,7 @@ class CookieOAuth2TokenBearer(OAuth2PasswordBearer):
|
|
945
1029
|
|
946
1030
|
|
947
1031
|
def oauth2_authentication(
|
1032
|
+
request: Request,
|
948
1033
|
token: str = Depends(
|
949
1034
|
CookieOAuth2TokenBearer(
|
950
1035
|
tokenUrl=server_config().root_url_path + API + VERSION_1 + LOGIN,
|
@@ -955,14 +1040,18 @@ def oauth2_authentication(
|
|
955
1040
|
|
956
1041
|
Args:
|
957
1042
|
token: The JWT bearer token to be authenticated.
|
1043
|
+
request: The FastAPI request object.
|
958
1044
|
|
959
1045
|
Returns:
|
960
1046
|
The authentication context reflecting the authenticated user.
|
961
1047
|
|
962
1048
|
# noqa: DAR401
|
963
1049
|
"""
|
1050
|
+
csrf_token = request.headers.get("X-CSRF-Token")
|
964
1051
|
try:
|
965
|
-
auth_context = authenticate_credentials(
|
1052
|
+
auth_context = authenticate_credentials(
|
1053
|
+
access_token=token, csrf_token=csrf_token
|
1054
|
+
)
|
966
1055
|
except CredentialsNotValid as e:
|
967
1056
|
# We want to be very explicit here and return a CredentialsNotValid
|
968
1057
|
# exception encoded as a 401 Unauthorized error encoded, so that the
|
zenml/zen_server/cloud_utils.py
CHANGED
@@ -1,114 +1,92 @@
|
|
1
1
|
"""Utils concerning anything concerning the cloud control plane backend."""
|
2
2
|
|
3
|
-
import os
|
4
3
|
from datetime import datetime, timedelta, timezone
|
5
4
|
from typing import Any, Dict, Optional
|
6
5
|
|
7
6
|
import requests
|
8
|
-
from pydantic import BaseModel, ConfigDict, field_validator
|
9
7
|
from requests.adapters import HTTPAdapter, Retry
|
10
8
|
|
9
|
+
from zenml.config.server_config import ServerProConfiguration
|
11
10
|
from zenml.exceptions import SubscriptionUpgradeRequiredError
|
12
|
-
from zenml.zen_server.utils import server_config
|
13
|
-
|
14
|
-
ZENML_CLOUD_RBAC_ENV_PREFIX = "ZENML_CLOUD_"
|
11
|
+
from zenml.zen_server.utils import get_zenml_headers, server_config
|
15
12
|
|
16
13
|
_cloud_connection: Optional["ZenMLCloudConnection"] = None
|
17
14
|
|
18
15
|
|
19
|
-
class ZenMLCloudConfiguration(BaseModel):
|
20
|
-
"""ZenML Pro RBAC configuration."""
|
21
|
-
|
22
|
-
api_url: str
|
23
|
-
oauth2_client_id: str
|
24
|
-
oauth2_client_secret: str
|
25
|
-
oauth2_audience: str
|
26
|
-
|
27
|
-
@field_validator("api_url")
|
28
|
-
@classmethod
|
29
|
-
def _strip_trailing_slashes_url(cls, url: str) -> str:
|
30
|
-
"""Strip any trailing slashes on the API URL.
|
31
|
-
|
32
|
-
Args:
|
33
|
-
url: The API URL.
|
34
|
-
|
35
|
-
Returns:
|
36
|
-
The API URL with potential trailing slashes removed.
|
37
|
-
"""
|
38
|
-
return url.rstrip("/")
|
39
|
-
|
40
|
-
@classmethod
|
41
|
-
def from_environment(cls) -> "ZenMLCloudConfiguration":
|
42
|
-
"""Get the RBAC configuration from environment variables.
|
43
|
-
|
44
|
-
Returns:
|
45
|
-
The RBAC configuration.
|
46
|
-
"""
|
47
|
-
env_config: Dict[str, Any] = {}
|
48
|
-
for k, v in os.environ.items():
|
49
|
-
if v == "":
|
50
|
-
continue
|
51
|
-
if k.startswith(ZENML_CLOUD_RBAC_ENV_PREFIX):
|
52
|
-
env_config[k[len(ZENML_CLOUD_RBAC_ENV_PREFIX) :].lower()] = v
|
53
|
-
|
54
|
-
return ZenMLCloudConfiguration(**env_config)
|
55
|
-
|
56
|
-
model_config = ConfigDict(
|
57
|
-
# Allow extra attributes from configs of previous ZenML versions to
|
58
|
-
# permit downgrading
|
59
|
-
extra="allow"
|
60
|
-
)
|
61
|
-
|
62
|
-
|
63
16
|
class ZenMLCloudConnection:
|
64
17
|
"""Class to use for communication between server and control plane."""
|
65
18
|
|
66
19
|
def __init__(self) -> None:
|
67
20
|
"""Initialize the RBAC component."""
|
68
|
-
self._config =
|
21
|
+
self._config = ServerProConfiguration.get_server_config()
|
69
22
|
self._session: Optional[requests.Session] = None
|
70
23
|
self._token: Optional[str] = None
|
71
24
|
self._token_expires_at: Optional[datetime] = None
|
72
25
|
|
73
|
-
def
|
74
|
-
self,
|
26
|
+
def request(
|
27
|
+
self,
|
28
|
+
method: str,
|
29
|
+
endpoint: str,
|
30
|
+
params: Optional[Dict[str, Any]] = None,
|
31
|
+
data: Optional[Dict[str, Any]] = None,
|
75
32
|
) -> requests.Response:
|
76
|
-
"""Send a
|
33
|
+
"""Send a request using the active session.
|
77
34
|
|
78
35
|
Args:
|
36
|
+
method: The HTTP method to use.
|
79
37
|
endpoint: The endpoint to send the request to. This will be appended
|
80
38
|
to the base URL.
|
81
39
|
params: Parameters to include in the request.
|
40
|
+
data: Data to include in the request.
|
82
41
|
|
83
42
|
Raises:
|
43
|
+
SubscriptionUpgradeRequiredError: If the current subscription tier
|
44
|
+
is insufficient for the attempted operation.
|
84
45
|
RuntimeError: If the request failed.
|
85
|
-
SubscriptionUpgradeRequiredError: In case the current subscription
|
86
|
-
tier is insufficient for the attempted operation.
|
87
46
|
|
88
47
|
Returns:
|
89
48
|
The response.
|
90
49
|
"""
|
91
50
|
url = self._config.api_url + endpoint
|
92
51
|
|
93
|
-
response = self.session.
|
52
|
+
response = self.session.request(
|
53
|
+
method=method, url=url, params=params, json=data, timeout=7
|
54
|
+
)
|
94
55
|
if response.status_code == 401:
|
95
|
-
#
|
96
|
-
# auth token and try again
|
56
|
+
# Refresh the auth token and try again
|
97
57
|
self._clear_session()
|
98
|
-
response = self.session.
|
58
|
+
response = self.session.request(
|
59
|
+
method=method, url=url, params=params, json=data, timeout=7
|
60
|
+
)
|
99
61
|
|
100
62
|
try:
|
101
63
|
response.raise_for_status()
|
102
|
-
except requests.HTTPError:
|
64
|
+
except requests.HTTPError as e:
|
103
65
|
if response.status_code == 402:
|
104
66
|
raise SubscriptionUpgradeRequiredError(response.json())
|
105
67
|
else:
|
106
68
|
raise RuntimeError(
|
107
|
-
f"Failed
|
69
|
+
f"Failed while trying to contact the central zenml pro "
|
70
|
+
f"service: {e}"
|
108
71
|
)
|
109
72
|
|
110
73
|
return response
|
111
74
|
|
75
|
+
def get(
|
76
|
+
self, endpoint: str, params: Optional[Dict[str, Any]]
|
77
|
+
) -> requests.Response:
|
78
|
+
"""Send a GET request using the active session.
|
79
|
+
|
80
|
+
Args:
|
81
|
+
endpoint: The endpoint to send the request to. This will be appended
|
82
|
+
to the base URL.
|
83
|
+
params: Parameters to include in the request.
|
84
|
+
|
85
|
+
Returns:
|
86
|
+
The response.
|
87
|
+
"""
|
88
|
+
return self.request(method="GET", endpoint=endpoint, params=params)
|
89
|
+
|
112
90
|
def post(
|
113
91
|
self,
|
114
92
|
endpoint: str,
|
@@ -123,33 +101,33 @@ class ZenMLCloudConnection:
|
|
123
101
|
params: Parameters to include in the request.
|
124
102
|
data: Data to include in the request.
|
125
103
|
|
126
|
-
Raises:
|
127
|
-
RuntimeError: If the request failed.
|
128
|
-
|
129
104
|
Returns:
|
130
105
|
The response.
|
131
106
|
"""
|
132
|
-
|
133
|
-
|
134
|
-
response = self.session.post(
|
135
|
-
url=url, params=params, json=data, timeout=7
|
107
|
+
return self.request(
|
108
|
+
method="POST", endpoint=endpoint, params=params, data=data
|
136
109
|
)
|
137
|
-
if response.status_code == 401:
|
138
|
-
# Refresh the auth token and try again
|
139
|
-
self._clear_session()
|
140
|
-
response = self.session.post(
|
141
|
-
url=url, params=params, json=data, timeout=7
|
142
|
-
)
|
143
110
|
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
111
|
+
def patch(
|
112
|
+
self,
|
113
|
+
endpoint: str,
|
114
|
+
params: Optional[Dict[str, Any]] = None,
|
115
|
+
data: Optional[Dict[str, Any]] = None,
|
116
|
+
) -> requests.Response:
|
117
|
+
"""Send a PATCH request using the active session.
|
151
118
|
|
152
|
-
|
119
|
+
Args:
|
120
|
+
endpoint: The endpoint to send the request to. This will be appended
|
121
|
+
to the base URL.
|
122
|
+
params: Parameters to include in the request.
|
123
|
+
data: Data to include in the request.
|
124
|
+
|
125
|
+
Returns:
|
126
|
+
The response.
|
127
|
+
"""
|
128
|
+
return self.request(
|
129
|
+
method="PATCH", endpoint=endpoint, params=params, data=data
|
130
|
+
)
|
153
131
|
|
154
132
|
@property
|
155
133
|
def session(self) -> requests.Session:
|
@@ -169,6 +147,8 @@ class ZenMLCloudConnection:
|
|
169
147
|
self._session = requests.Session()
|
170
148
|
token = self._fetch_auth_token()
|
171
149
|
self._session.headers.update({"Authorization": "Bearer " + token})
|
150
|
+
# Add the ZenML specific headers
|
151
|
+
self._session.headers.update(get_zenml_headers())
|
172
152
|
|
173
153
|
retries = Retry(
|
174
154
|
total=5, backoff_factor=0.1, status_forcelist=[502, 504]
|
@@ -194,7 +174,7 @@ class ZenMLCloudConnection:
|
|
194
174
|
self._token_expires_at = None
|
195
175
|
|
196
176
|
def _fetch_auth_token(self) -> str:
|
197
|
-
"""Fetch an auth token
|
177
|
+
"""Fetch an auth token from the Cloud API.
|
198
178
|
|
199
179
|
Raises:
|
200
180
|
RuntimeError: If the auth token can't be fetched.
|
@@ -210,11 +190,14 @@ class ZenMLCloudConnection:
|
|
210
190
|
):
|
211
191
|
return self._token
|
212
192
|
|
213
|
-
# Get an auth token from
|
193
|
+
# Get an auth token from the Cloud API
|
214
194
|
login_url = f"{self._config.api_url}/auth/login"
|
215
195
|
headers = {"content-type": "application/x-www-form-urlencoded"}
|
196
|
+
# Add zenml specific headers to the request
|
197
|
+
headers.update(get_zenml_headers())
|
216
198
|
payload = {
|
217
|
-
|
199
|
+
# The client ID is the external server ID
|
200
|
+
"client_id": str(server_config().get_external_server_id()),
|
218
201
|
"client_secret": self._config.oauth2_client_secret,
|
219
202
|
"audience": self._config.oauth2_audience,
|
220
203
|
"grant_type": "client_credentials",
|
@@ -225,7 +208,9 @@ class ZenMLCloudConnection:
|
|
225
208
|
)
|
226
209
|
response.raise_for_status()
|
227
210
|
except Exception as e:
|
228
|
-
raise RuntimeError(
|
211
|
+
raise RuntimeError(
|
212
|
+
f"Error fetching auth token from the Cloud API: {e}"
|
213
|
+
)
|
229
214
|
|
230
215
|
json_response = response.json()
|
231
216
|
access_token = json_response.get("access_token", "")
|
@@ -237,7 +222,9 @@ class ZenMLCloudConnection:
|
|
237
222
|
or not expires_in
|
238
223
|
or not isinstance(expires_in, int)
|
239
224
|
):
|
240
|
-
raise RuntimeError(
|
225
|
+
raise RuntimeError(
|
226
|
+
"Could not fetch auth token from the Cloud API."
|
227
|
+
)
|
241
228
|
|
242
229
|
self._token = access_token
|
243
230
|
self._token_expires_at = datetime.now(timezone.utc) + timedelta(
|
@@ -259,3 +246,8 @@ def cloud_connection() -> ZenMLCloudConnection:
|
|
259
246
|
_cloud_connection = ZenMLCloudConnection()
|
260
247
|
|
261
248
|
return _cloud_connection
|
249
|
+
|
250
|
+
|
251
|
+
def send_pro_tenant_status_update() -> None:
|
252
|
+
"""Send a tenant status update to the Cloud API."""
|
253
|
+
cloud_connection().patch("/tenant_status")
|
zenml/zen_server/csrf.py
ADDED
@@ -0,0 +1,91 @@
|
|
1
|
+
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
12
|
+
# or implied. See the License for the specific language governing
|
13
|
+
# permissions and limitations under the License.
|
14
|
+
"""CSRF token utilities module for ZenML server."""
|
15
|
+
|
16
|
+
from uuid import UUID
|
17
|
+
|
18
|
+
from pydantic import BaseModel
|
19
|
+
|
20
|
+
from zenml.exceptions import CredentialsNotValid
|
21
|
+
from zenml.logger import get_logger
|
22
|
+
from zenml.zen_server.utils import server_config
|
23
|
+
|
24
|
+
logger = get_logger(__name__)
|
25
|
+
|
26
|
+
|
27
|
+
class CSRFToken(BaseModel):
|
28
|
+
"""Pydantic object representing a CSRF token.
|
29
|
+
|
30
|
+
Attributes:
|
31
|
+
session_id: The id of the authenticated session.
|
32
|
+
"""
|
33
|
+
|
34
|
+
session_id: UUID
|
35
|
+
|
36
|
+
@classmethod
|
37
|
+
def decode_token(
|
38
|
+
cls,
|
39
|
+
token: str,
|
40
|
+
) -> "CSRFToken":
|
41
|
+
"""Decodes a CSRF token.
|
42
|
+
|
43
|
+
Decodes a CSRF access token and returns a `CSRFToken` object with the
|
44
|
+
information retrieved from its contents.
|
45
|
+
|
46
|
+
Args:
|
47
|
+
token: The encoded CSRF token.
|
48
|
+
|
49
|
+
Returns:
|
50
|
+
The decoded CSRF token.
|
51
|
+
|
52
|
+
Raises:
|
53
|
+
CredentialsNotValid: If the token is invalid.
|
54
|
+
"""
|
55
|
+
from itsdangerous import BadData, BadSignature, URLSafeSerializer
|
56
|
+
|
57
|
+
config = server_config()
|
58
|
+
|
59
|
+
serializer = URLSafeSerializer(config.jwt_secret_key)
|
60
|
+
try:
|
61
|
+
# Decode and verify the token
|
62
|
+
data = serializer.loads(token)
|
63
|
+
except BadSignature as e:
|
64
|
+
raise CredentialsNotValid(
|
65
|
+
"Invalid CSRF token: signature mismatch"
|
66
|
+
) from e
|
67
|
+
except BadData as e:
|
68
|
+
raise CredentialsNotValid("Invalid CSRF token") from e
|
69
|
+
|
70
|
+
try:
|
71
|
+
return CSRFToken(session_id=UUID(data))
|
72
|
+
except ValueError as e:
|
73
|
+
raise CredentialsNotValid(
|
74
|
+
"Invalid CSRF token: the session ID is not a valid UUID"
|
75
|
+
) from e
|
76
|
+
|
77
|
+
def encode(self) -> str:
|
78
|
+
"""Creates a CSRF token.
|
79
|
+
|
80
|
+
Encodes, signs and returns a CSRF access token.
|
81
|
+
|
82
|
+
Returns:
|
83
|
+
The generated CSRF token.
|
84
|
+
"""
|
85
|
+
from itsdangerous import URLSafeSerializer
|
86
|
+
|
87
|
+
config = server_config()
|
88
|
+
|
89
|
+
serializer = URLSafeSerializer(config.jwt_secret_key)
|
90
|
+
token = serializer.dumps(str(self.session_id))
|
91
|
+
return token
|
@@ -1,3 +1,24 @@
|
|
1
|
+
{{- if .Values.zenml.pro.enabled }}
|
2
|
+
|
3
|
+
The ZenML Pro server API is now active and ready to use at the following URL:
|
4
|
+
|
5
|
+
{{ .Values.zenml.serverURL }}
|
6
|
+
|
7
|
+
{{- if .Values.zenml.pro.enrollmentKey }}
|
8
|
+
|
9
|
+
The following enrollment key has been used to enroll your server in the ZenML Pro control plane:
|
10
|
+
|
11
|
+
{{ .Values.zenml.pro.enrollmentKey }}
|
12
|
+
|
13
|
+
{{- else }}
|
14
|
+
|
15
|
+
An enrollment key has been auto-generated for your server. Please use the following command to fetch the enrollment key:
|
16
|
+
|
17
|
+
kubectl get secret {{ include "zenml.fullname" . }} -o jsonpath="{.data.ZENML_SERVER_PRO_OAUTH2_CLIENT_SECRET}" | base64 --decode
|
18
|
+
|
19
|
+
{{- end }}
|
20
|
+
|
21
|
+
{{- else }}
|
1
22
|
{{- if .Values.zenml.ingress.enabled }}
|
2
23
|
{{- if .Values.zenml.ingress.host }}
|
3
24
|
|
@@ -28,3 +49,4 @@ You can get the ZenML server URL by running these commands:
|
|
28
49
|
{{- end }}
|
29
50
|
|
30
51
|
{{- end }}
|
52
|
+
{{- end }}
|