zenml-nightly 0.72.0.dev20250115__py3-none-any.whl → 0.72.0.dev20250117__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/cli/login.py +126 -50
- zenml/cli/server.py +24 -5
- zenml/config/server_config.py +142 -17
- zenml/constants.py +2 -11
- zenml/login/credentials.py +38 -14
- zenml/login/credentials_store.py +53 -18
- zenml/login/pro/client.py +3 -7
- zenml/login/pro/constants.py +0 -6
- zenml/login/pro/tenant/models.py +4 -2
- zenml/login/pro/utils.py +11 -25
- zenml/login/server_info.py +52 -0
- zenml/login/web_login.py +11 -6
- zenml/models/v2/misc/auth_models.py +1 -1
- zenml/models/v2/misc/server_models.py +44 -0
- zenml/zen_server/auth.py +97 -8
- zenml/zen_server/cloud_utils.py +79 -87
- zenml/zen_server/csrf.py +91 -0
- zenml/zen_server/deploy/helm/templates/NOTES.txt +22 -0
- zenml/zen_server/deploy/helm/templates/_environment.tpl +50 -24
- zenml/zen_server/deploy/helm/templates/server-secret.yaml +11 -0
- zenml/zen_server/deploy/helm/values.yaml +76 -7
- zenml/zen_server/feature_gate/feature_gate_interface.py +1 -1
- zenml/zen_server/jwt.py +16 -1
- zenml/zen_server/rbac/endpoint_utils.py +3 -3
- zenml/zen_server/routers/auth_endpoints.py +44 -21
- zenml/zen_server/routers/models_endpoints.py +1 -2
- zenml/zen_server/routers/pipelines_endpoints.py +2 -2
- zenml/zen_server/routers/stack_deployment_endpoints.py +5 -5
- zenml/zen_server/routers/workspaces_endpoints.py +2 -2
- zenml/zen_server/utils.py +64 -0
- zenml/zen_server/zen_server_api.py +5 -0
- zenml/zen_stores/base_zen_store.py +19 -1
- zenml/zen_stores/rest_zen_store.py +30 -20
- {zenml_nightly-0.72.0.dev20250115.dist-info → zenml_nightly-0.72.0.dev20250117.dist-info}/METADATA +3 -1
- {zenml_nightly-0.72.0.dev20250115.dist-info → zenml_nightly-0.72.0.dev20250117.dist-info}/RECORD +39 -37
- {zenml_nightly-0.72.0.dev20250115.dist-info → zenml_nightly-0.72.0.dev20250117.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.72.0.dev20250115.dist-info → zenml_nightly-0.72.0.dev20250117.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.72.0.dev20250115.dist-info → zenml_nightly-0.72.0.dev20250117.dist-info}/entry_points.txt +0 -0
zenml/constants.py
CHANGED
@@ -176,11 +176,9 @@ ENV_ZENML_PIPELINE_RUN_API_TOKEN_EXPIRATION = (
|
|
176
176
|
|
177
177
|
# ZenML Server environment variables
|
178
178
|
ENV_ZENML_SERVER_PREFIX = "ZENML_SERVER_"
|
179
|
+
ENV_ZENML_SERVER_PRO_PREFIX = "ZENML_SERVER_PRO_"
|
179
180
|
ENV_ZENML_SERVER_DEPLOYMENT_TYPE = f"{ENV_ZENML_SERVER_PREFIX}DEPLOYMENT_TYPE"
|
180
181
|
ENV_ZENML_SERVER_AUTH_SCHEME = f"{ENV_ZENML_SERVER_PREFIX}AUTH_SCHEME"
|
181
|
-
ENV_ZENML_SERVER_REPORTABLE_RESOURCES = (
|
182
|
-
f"{ENV_ZENML_SERVER_PREFIX}REPORTABLE_RESOURCES"
|
183
|
-
)
|
184
182
|
ENV_ZENML_SERVER_AUTO_ACTIVATE = f"{ENV_ZENML_SERVER_PREFIX}AUTO_ACTIVATE"
|
185
183
|
ENV_ZENML_RUN_SINGLE_STEPS_WITHOUT_STACK = (
|
186
184
|
"ZENML_RUN_SINGLE_STEPS_WITHOUT_STACK"
|
@@ -321,14 +319,7 @@ DEFAULT_ZENML_SERVER_SECURE_HEADERS_REPORT_TO = "default"
|
|
321
319
|
DEFAULT_ZENML_SERVER_REPORT_USER_ACTIVITY_TO_DB_SECONDS = 30
|
322
320
|
DEFAULT_ZENML_SERVER_MAX_REQUEST_BODY_SIZE_IN_BYTES = 256 * 1024 * 1024
|
323
321
|
|
324
|
-
|
325
|
-
# entitlement in the case of a cloud deployment. Expected Format is this:
|
326
|
-
# ENV_ZENML_REPORTABLE_RESOURCES='["Foo", "bar"]'
|
327
|
-
REPORTABLE_RESOURCES: List[str] = handle_json_env_var(
|
328
|
-
ENV_ZENML_SERVER_REPORTABLE_RESOURCES,
|
329
|
-
expected_type=list,
|
330
|
-
default=["pipeline", "pipeline_run", "model"],
|
331
|
-
)
|
322
|
+
DEFAULT_REPORTABLE_RESOURCES = ["pipeline", "pipeline_run", "model"]
|
332
323
|
REQUIRES_CUSTOM_RESOURCE_REPORTING = ["pipeline", "pipeline_run"]
|
333
324
|
|
334
325
|
# API Endpoint paths:
|
zenml/login/credentials.py
CHANGED
@@ -23,6 +23,7 @@ from pydantic import BaseModel, ConfigDict
|
|
23
23
|
from zenml.login.pro.constants import ZENML_PRO_API_URL, ZENML_PRO_URL
|
24
24
|
from zenml.login.pro.tenant.models import TenantRead, TenantStatus
|
25
25
|
from zenml.models import ServerModel
|
26
|
+
from zenml.models.v2.misc.server_models import ServerDeploymentType
|
26
27
|
from zenml.services.service_status import ServiceState
|
27
28
|
from zenml.utils.enum_utils import StrEnum
|
28
29
|
from zenml.utils.string_utils import get_human_readable_time
|
@@ -44,7 +45,6 @@ class APIToken(BaseModel):
|
|
44
45
|
expires_in: Optional[int] = None
|
45
46
|
expires_at: Optional[datetime] = None
|
46
47
|
leeway: Optional[int] = None
|
47
|
-
cookie_name: Optional[str] = None
|
48
48
|
device_id: Optional[UUID] = None
|
49
49
|
device_metadata: Optional[Dict[str, Any]] = None
|
50
50
|
|
@@ -89,13 +89,20 @@ class ServerCredentials(BaseModel):
|
|
89
89
|
password: Optional[str] = None
|
90
90
|
|
91
91
|
# Extra server attributes
|
92
|
+
deployment_type: Optional[ServerDeploymentType] = None
|
92
93
|
server_id: Optional[UUID] = None
|
93
94
|
server_name: Optional[str] = None
|
94
|
-
organization_name: Optional[str] = None
|
95
|
-
organization_id: Optional[UUID] = None
|
96
95
|
status: Optional[str] = None
|
97
96
|
version: Optional[str] = None
|
98
97
|
|
98
|
+
# Pro server attributes
|
99
|
+
organization_name: Optional[str] = None
|
100
|
+
organization_id: Optional[UUID] = None
|
101
|
+
tenant_name: Optional[str] = None
|
102
|
+
tenant_id: Optional[UUID] = None
|
103
|
+
pro_api_url: Optional[str] = None
|
104
|
+
pro_dashboard_url: Optional[str] = None
|
105
|
+
|
99
106
|
@property
|
100
107
|
def id(self) -> str:
|
101
108
|
"""Get the server identifier.
|
@@ -114,11 +121,13 @@ class ServerCredentials(BaseModel):
|
|
114
121
|
Returns:
|
115
122
|
The server type.
|
116
123
|
"""
|
117
|
-
|
118
|
-
|
124
|
+
if self.deployment_type == ServerDeploymentType.CLOUD:
|
125
|
+
return ServerType.PRO
|
119
126
|
if self.url == ZENML_PRO_API_URL:
|
120
127
|
return ServerType.PRO_API
|
121
|
-
if self.
|
128
|
+
if self.url == self.pro_api_url:
|
129
|
+
return ServerType.PRO_API
|
130
|
+
if self.organization_id or self.tenant_id:
|
122
131
|
return ServerType.PRO
|
123
132
|
if urlparse(self.url).hostname in [
|
124
133
|
"localhost",
|
@@ -139,25 +148,39 @@ class ServerCredentials(BaseModel):
|
|
139
148
|
if isinstance(server_info, ServerModel):
|
140
149
|
# The server ID doesn't change during the lifetime of the server
|
141
150
|
self.server_id = self.server_id or server_info.id
|
142
|
-
|
143
151
|
# All other attributes can change during the lifetime of the server
|
152
|
+
self.deployment_type = server_info.deployment_type
|
144
153
|
server_name = (
|
145
|
-
server_info.
|
154
|
+
server_info.pro_tenant_name
|
155
|
+
or server_info.metadata.get("tenant_name")
|
156
|
+
or server_info.name
|
146
157
|
)
|
147
158
|
if server_name:
|
148
159
|
self.server_name = server_name
|
149
|
-
|
150
|
-
|
151
|
-
|
160
|
+
if server_info.pro_organization_id:
|
161
|
+
self.organization_id = server_info.pro_organization_id
|
162
|
+
if server_info.pro_tenant_id:
|
163
|
+
self.server_id = server_info.pro_tenant_id
|
164
|
+
if server_info.pro_organization_name:
|
165
|
+
self.organization_name = server_info.pro_organization_name
|
166
|
+
if server_info.pro_tenant_name:
|
167
|
+
self.tenant_name = server_info.pro_tenant_name
|
168
|
+
if server_info.pro_api_url:
|
169
|
+
self.pro_api_url = server_info.pro_api_url
|
170
|
+
if server_info.pro_dashboard_url:
|
171
|
+
self.pro_dashboard_url = server_info.pro_dashboard_url
|
152
172
|
self.version = server_info.version or self.version
|
153
173
|
# The server information was retrieved from the server itself, so we
|
154
174
|
# can assume that the server is available
|
155
175
|
self.status = "available"
|
156
176
|
else:
|
177
|
+
self.deployment_type = ServerDeploymentType.CLOUD
|
157
178
|
self.server_id = server_info.id
|
158
179
|
self.server_name = server_info.name
|
159
180
|
self.organization_name = server_info.organization_name
|
160
181
|
self.organization_id = server_info.organization_id
|
182
|
+
self.tenant_name = server_info.name
|
183
|
+
self.tenant_id = server_info.id
|
161
184
|
self.status = server_info.status
|
162
185
|
self.version = server_info.version
|
163
186
|
|
@@ -248,9 +271,10 @@ class ServerCredentials(BaseModel):
|
|
248
271
|
"""
|
249
272
|
if self.organization_id and self.server_id:
|
250
273
|
return (
|
251
|
-
ZENML_PRO_URL
|
274
|
+
(self.pro_dashboard_url or ZENML_PRO_URL)
|
252
275
|
+ f"/organizations/{str(self.organization_id)}/tenants/{str(self.server_id)}"
|
253
276
|
)
|
277
|
+
|
254
278
|
return self.url
|
255
279
|
|
256
280
|
@property
|
@@ -262,8 +286,8 @@ class ServerCredentials(BaseModel):
|
|
262
286
|
"""
|
263
287
|
if self.organization_id:
|
264
288
|
return (
|
265
|
-
|
266
|
-
)
|
289
|
+
self.pro_dashboard_url or ZENML_PRO_URL
|
290
|
+
) + f"/organizations/{str(self.organization_id)}"
|
267
291
|
return ""
|
268
292
|
|
269
293
|
@property
|
zenml/login/credentials_store.py
CHANGED
@@ -25,7 +25,6 @@ from zenml.constants import (
|
|
25
25
|
from zenml.io import fileio
|
26
26
|
from zenml.logger import get_logger
|
27
27
|
from zenml.login.credentials import APIToken, ServerCredentials, ServerType
|
28
|
-
from zenml.login.pro.constants import ZENML_PRO_API_URL
|
29
28
|
from zenml.login.pro.tenant.models import TenantRead
|
30
29
|
from zenml.models import OAuthTokenResponse, ServerModel
|
31
30
|
from zenml.utils import yaml_utils
|
@@ -289,10 +288,13 @@ class CredentialsStore(metaclass=SingletonMetaClass):
|
|
289
288
|
self.check_and_reload_from_file()
|
290
289
|
return self.credentials.get(server_url)
|
291
290
|
|
292
|
-
def get_pro_token(
|
293
|
-
|
291
|
+
def get_pro_token(
|
292
|
+
self, pro_api_url: str, allow_expired: bool = False
|
293
|
+
) -> Optional[APIToken]:
|
294
|
+
"""Retrieve a valid token from the credentials store for a ZenML Pro API server.
|
294
295
|
|
295
296
|
Args:
|
297
|
+
pro_api_url: The URL of the ZenML Pro API server.
|
296
298
|
allow_expired: Whether to allow expired tokens to be returned. The
|
297
299
|
default behavior is to return None if a token does exist but is
|
298
300
|
expired.
|
@@ -300,14 +302,18 @@ class CredentialsStore(metaclass=SingletonMetaClass):
|
|
300
302
|
Returns:
|
301
303
|
The stored token if it exists and is not expired, None otherwise.
|
302
304
|
"""
|
303
|
-
|
305
|
+
credential = self.get_pro_credentials(pro_api_url, allow_expired)
|
306
|
+
if credential:
|
307
|
+
return credential.api_token
|
308
|
+
return None
|
304
309
|
|
305
310
|
def get_pro_credentials(
|
306
|
-
self, allow_expired: bool = False
|
311
|
+
self, pro_api_url: str, allow_expired: bool = False
|
307
312
|
) -> Optional[ServerCredentials]:
|
308
|
-
"""Retrieve a valid token from the credentials store for
|
313
|
+
"""Retrieve a valid token from the credentials store for a ZenML Pro API server.
|
309
314
|
|
310
315
|
Args:
|
316
|
+
pro_api_url: The URL of the ZenML Pro API server.
|
311
317
|
allow_expired: Whether to allow expired tokens to be returned. The
|
312
318
|
default behavior is to return None if a token does exist but is
|
313
319
|
expired.
|
@@ -315,26 +321,47 @@ class CredentialsStore(metaclass=SingletonMetaClass):
|
|
315
321
|
Returns:
|
316
322
|
The stored credentials if they exist and are not expired, None otherwise.
|
317
323
|
"""
|
318
|
-
credential = self.get_credentials(
|
324
|
+
credential = self.get_credentials(pro_api_url)
|
319
325
|
if (
|
320
326
|
credential
|
327
|
+
and credential.type == ServerType.PRO_API
|
321
328
|
and credential.api_token
|
322
329
|
and (not credential.api_token.expired or allow_expired)
|
323
330
|
):
|
324
331
|
return credential
|
325
332
|
return None
|
326
333
|
|
327
|
-
def clear_pro_credentials(self) -> None:
|
328
|
-
"""Delete the token from the store for
|
329
|
-
|
334
|
+
def clear_pro_credentials(self, pro_api_url: str) -> None:
|
335
|
+
"""Delete the token from the store for a ZenML Pro API server.
|
336
|
+
|
337
|
+
Args:
|
338
|
+
pro_api_url: The URL of the ZenML Pro API server.
|
339
|
+
"""
|
340
|
+
self.clear_token(pro_api_url)
|
341
|
+
|
342
|
+
def clear_all_pro_tokens(
|
343
|
+
self, pro_api_url: str
|
344
|
+
) -> List[ServerCredentials]:
|
345
|
+
"""Delete all tokens from the store for ZenML Pro servers connected to a given API server.
|
346
|
+
|
347
|
+
Args:
|
348
|
+
pro_api_url: The URL of the ZenML Pro API server.
|
330
349
|
|
331
|
-
|
332
|
-
|
350
|
+
Returns:
|
351
|
+
A list of the credentials that were cleared.
|
352
|
+
"""
|
353
|
+
credentials_to_clear = []
|
333
354
|
for server_url, server in self.credentials.copy().items():
|
334
|
-
if
|
355
|
+
if (
|
356
|
+
server.type == ServerType.PRO
|
357
|
+
and server.pro_api_url
|
358
|
+
and server.pro_api_url == pro_api_url
|
359
|
+
):
|
335
360
|
if server.api_key:
|
336
361
|
continue
|
337
362
|
self.clear_token(server_url)
|
363
|
+
credentials_to_clear.append(server)
|
364
|
+
return credentials_to_clear
|
338
365
|
|
339
366
|
def has_valid_authentication(self, url: str) -> bool:
|
340
367
|
"""Check if a valid authentication credential for the given server URL is stored.
|
@@ -357,13 +384,16 @@ class CredentialsStore(metaclass=SingletonMetaClass):
|
|
357
384
|
token = credential.api_token
|
358
385
|
return token is not None and not token.expired
|
359
386
|
|
360
|
-
def has_valid_pro_authentication(self) -> bool:
|
361
|
-
"""Check if a valid token for
|
387
|
+
def has_valid_pro_authentication(self, pro_api_url: str) -> bool:
|
388
|
+
"""Check if a valid token for a ZenML Pro API server is stored.
|
389
|
+
|
390
|
+
Args:
|
391
|
+
pro_api_url: The URL of the ZenML Pro API server.
|
362
392
|
|
363
393
|
Returns:
|
364
394
|
bool: True if a valid token is stored, False otherwise.
|
365
395
|
"""
|
366
|
-
return self.
|
396
|
+
return self.get_pro_token(pro_api_url) is not None
|
367
397
|
|
368
398
|
def set_api_key(
|
369
399
|
self,
|
@@ -433,12 +463,14 @@ class CredentialsStore(metaclass=SingletonMetaClass):
|
|
433
463
|
self,
|
434
464
|
server_url: str,
|
435
465
|
token_response: OAuthTokenResponse,
|
466
|
+
is_zenml_pro: bool = False,
|
436
467
|
) -> APIToken:
|
437
468
|
"""Store an API token received from an OAuth2 server.
|
438
469
|
|
439
470
|
Args:
|
440
471
|
server_url: The server URL for which the token is to be stored.
|
441
472
|
token_response: Token response received from an OAuth2 server.
|
473
|
+
is_zenml_pro: Whether the token is for a ZenML Pro server.
|
442
474
|
|
443
475
|
Returns:
|
444
476
|
APIToken: The stored token.
|
@@ -468,7 +500,6 @@ class CredentialsStore(metaclass=SingletonMetaClass):
|
|
468
500
|
expires_in=token_response.expires_in,
|
469
501
|
expires_at=expires_at,
|
470
502
|
leeway=leeway,
|
471
|
-
cookie_name=token_response.cookie_name,
|
472
503
|
device_id=token_response.device_id,
|
473
504
|
device_metadata=token_response.device_metadata,
|
474
505
|
)
|
@@ -477,10 +508,14 @@ class CredentialsStore(metaclass=SingletonMetaClass):
|
|
477
508
|
if credential:
|
478
509
|
credential.api_token = api_token
|
479
510
|
else:
|
480
|
-
self.credentials[server_url] = ServerCredentials(
|
511
|
+
credential = self.credentials[server_url] = ServerCredentials(
|
481
512
|
url=server_url, api_token=api_token
|
482
513
|
)
|
483
514
|
|
515
|
+
if is_zenml_pro:
|
516
|
+
# This is how we encode that the token is for a ZenML Pro server
|
517
|
+
credential.pro_api_url = server_url
|
518
|
+
|
484
519
|
self._save_credentials()
|
485
520
|
|
486
521
|
return api_token
|
zenml/login/pro/client.py
CHANGED
@@ -33,7 +33,6 @@ from zenml.exceptions import AuthorizationException
|
|
33
33
|
from zenml.logger import get_logger
|
34
34
|
from zenml.login.credentials import APIToken
|
35
35
|
from zenml.login.credentials_store import get_credentials_store
|
36
|
-
from zenml.login.pro.constants import ZENML_PRO_API_URL
|
37
36
|
from zenml.login.pro.models import BaseRestAPIModel
|
38
37
|
from zenml.utils.singleton import SingletonMetaClass
|
39
38
|
from zenml.zen_server.exceptions import exception_from_response
|
@@ -60,14 +59,11 @@ class ZenMLProClient(metaclass=SingletonMetaClass):
|
|
60
59
|
_tenant: Optional["TenantClient"] = None
|
61
60
|
_organization: Optional["OrganizationClient"] = None
|
62
61
|
|
63
|
-
def __init__(
|
64
|
-
self, url: Optional[str] = None, api_token: Optional[APIToken] = None
|
65
|
-
) -> None:
|
62
|
+
def __init__(self, url: str, api_token: Optional[APIToken] = None) -> None:
|
66
63
|
"""Initialize the ZenML Pro client.
|
67
64
|
|
68
65
|
Args:
|
69
|
-
url: The URL of the ZenML Pro API server.
|
70
|
-
default ZenML Pro API server URL is used.
|
66
|
+
url: The URL of the ZenML Pro API server.
|
71
67
|
api_token: The API token to use for authentication. If not provided,
|
72
68
|
the token is fetched from the credentials store.
|
73
69
|
|
@@ -75,7 +71,7 @@ class ZenMLProClient(metaclass=SingletonMetaClass):
|
|
75
71
|
AuthorizationException: If no API token is provided and no token
|
76
72
|
is found in the credentials store.
|
77
73
|
"""
|
78
|
-
self._url = url
|
74
|
+
self._url = url
|
79
75
|
if api_token is None:
|
80
76
|
logger.debug(
|
81
77
|
"No ZenML Pro API token provided. Fetching from credentials "
|
zenml/login/pro/constants.py
CHANGED
@@ -26,9 +26,3 @@ ENV_ZENML_PRO_URL = "ZENML_PRO_URL"
|
|
26
26
|
DEFAULT_ZENML_PRO_URL = "https://cloud.zenml.io"
|
27
27
|
|
28
28
|
ZENML_PRO_URL = os.getenv(ENV_ZENML_PRO_URL, default=DEFAULT_ZENML_PRO_URL)
|
29
|
-
|
30
|
-
ENV_ZENML_PRO_SERVER_SUBDOMAIN = "ZENML_PRO_SERVER_SUBDOMAIN"
|
31
|
-
DEFAULT_ZENML_PRO_SERVER_SUBDOMAIN = "cloudinfra.zenml.io"
|
32
|
-
ZENML_PRO_SERVER_SUBDOMAIN = os.getenv(
|
33
|
-
ENV_ZENML_PRO_SERVER_SUBDOMAIN, default=DEFAULT_ZENML_PRO_SERVER_SUBDOMAIN
|
34
|
-
)
|
zenml/login/pro/tenant/models.py
CHANGED
@@ -76,7 +76,7 @@ class ZenMLServiceStatus(BaseRestAPIModel):
|
|
76
76
|
class ZenMLServiceRead(BaseRestAPIModel):
|
77
77
|
"""Pydantic Model for viewing a ZenML service."""
|
78
78
|
|
79
|
-
configuration: ZenMLServiceConfiguration = Field(
|
79
|
+
configuration: Optional[ZenMLServiceConfiguration] = Field(
|
80
80
|
description="The service configuration."
|
81
81
|
)
|
82
82
|
|
@@ -133,7 +133,9 @@ class TenantRead(BaseRestAPIModel):
|
|
133
133
|
Returns:
|
134
134
|
The ZenML service version.
|
135
135
|
"""
|
136
|
-
version =
|
136
|
+
version = None
|
137
|
+
if self.zenml_service.configuration:
|
138
|
+
version = self.zenml_service.configuration.version
|
137
139
|
if self.zenml_service.status and self.zenml_service.status.version:
|
138
140
|
version = self.zenml_service.status.version
|
139
141
|
|
zenml/login/pro/utils.py
CHANGED
@@ -13,37 +13,16 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""ZenML Pro login utils."""
|
15
15
|
|
16
|
-
import re
|
17
|
-
|
18
16
|
from zenml.logger import get_logger
|
17
|
+
from zenml.login.credentials import ServerType
|
19
18
|
from zenml.login.credentials_store import get_credentials_store
|
20
19
|
from zenml.login.pro.client import ZenMLProClient
|
21
|
-
from zenml.login.pro.constants import
|
20
|
+
from zenml.login.pro.constants import ZENML_PRO_API_URL
|
22
21
|
from zenml.login.pro.tenant.models import TenantStatus
|
23
22
|
|
24
23
|
logger = get_logger(__name__)
|
25
24
|
|
26
25
|
|
27
|
-
def is_zenml_pro_server_url(url: str) -> bool:
|
28
|
-
"""Check if a given URL is a ZenML Pro server.
|
29
|
-
|
30
|
-
Args:
|
31
|
-
url: URL to check
|
32
|
-
|
33
|
-
Returns:
|
34
|
-
True if the URL is a ZenML Pro tenant, False otherwise
|
35
|
-
"""
|
36
|
-
domain_regex = ZENML_PRO_SERVER_SUBDOMAIN.replace(".", r"\.")
|
37
|
-
return bool(
|
38
|
-
re.match(
|
39
|
-
r"^(https://)?[a-zA-Z0-9-\.]+\.{domain}/?$".format(
|
40
|
-
domain=domain_regex
|
41
|
-
),
|
42
|
-
url,
|
43
|
-
)
|
44
|
-
)
|
45
|
-
|
46
|
-
|
47
26
|
def get_troubleshooting_instructions(url: str) -> str:
|
48
27
|
"""Get troubleshooting instructions for a given ZenML Pro server URL.
|
49
28
|
|
@@ -54,8 +33,15 @@ def get_troubleshooting_instructions(url: str) -> str:
|
|
54
33
|
Troubleshooting instructions
|
55
34
|
"""
|
56
35
|
credentials_store = get_credentials_store()
|
57
|
-
|
58
|
-
|
36
|
+
|
37
|
+
credentials = credentials_store.get_credentials(url)
|
38
|
+
if credentials and credentials.type == ServerType.PRO:
|
39
|
+
pro_api_url = credentials.pro_api_url or ZENML_PRO_API_URL
|
40
|
+
|
41
|
+
if pro_api_url and credentials_store.has_valid_pro_authentication(
|
42
|
+
pro_api_url
|
43
|
+
):
|
44
|
+
client = ZenMLProClient(pro_api_url)
|
59
45
|
|
60
46
|
try:
|
61
47
|
servers = client.tenant.list(url=url, member_only=False)
|
@@ -0,0 +1,52 @@
|
|
1
|
+
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
12
|
+
# or implied. See the License for the specific language governing
|
13
|
+
# permissions and limitations under the License.
|
14
|
+
"""ZenML server information retrieval."""
|
15
|
+
|
16
|
+
from typing import Optional
|
17
|
+
|
18
|
+
from zenml.logger import get_logger
|
19
|
+
from zenml.models import ServerModel
|
20
|
+
from zenml.zen_stores.rest_zen_store import (
|
21
|
+
RestZenStore,
|
22
|
+
RestZenStoreConfiguration,
|
23
|
+
)
|
24
|
+
|
25
|
+
logger = get_logger(__name__)
|
26
|
+
|
27
|
+
|
28
|
+
def get_server_info(url: str) -> Optional[ServerModel]:
|
29
|
+
"""Retrieve server information from a remote ZenML server.
|
30
|
+
|
31
|
+
Args:
|
32
|
+
url: The URL of the ZenML server.
|
33
|
+
|
34
|
+
Returns:
|
35
|
+
The server information or None if the server info could not be fetched.
|
36
|
+
"""
|
37
|
+
# Here we try to leverage the existing RestZenStore support to fetch the
|
38
|
+
# server info and only the server info, which doesn't actually need
|
39
|
+
# any authentication.
|
40
|
+
try:
|
41
|
+
store = RestZenStore(
|
42
|
+
config=RestZenStoreConfiguration(
|
43
|
+
url=url,
|
44
|
+
)
|
45
|
+
)
|
46
|
+
return store.server_info
|
47
|
+
except Exception as e:
|
48
|
+
logger.warning(
|
49
|
+
f"Failed to fetch server info from the server running at {url}: {e}"
|
50
|
+
)
|
51
|
+
|
52
|
+
return None
|
zenml/login/web_login.py
CHANGED
@@ -34,13 +34,14 @@ from zenml.exceptions import AuthorizationException, OAuthError
|
|
34
34
|
from zenml.logger import get_logger
|
35
35
|
from zenml.login.credentials import APIToken
|
36
36
|
from zenml.login.pro.constants import ZENML_PRO_API_URL
|
37
|
-
from zenml.login.pro.utils import is_zenml_pro_server_url
|
38
37
|
|
39
38
|
logger = get_logger(__name__)
|
40
39
|
|
41
40
|
|
42
41
|
def web_login(
|
43
|
-
url: Optional[str] = None,
|
42
|
+
url: Optional[str] = None,
|
43
|
+
verify_ssl: Optional[Union[str, bool]] = None,
|
44
|
+
pro_api_url: Optional[str] = None,
|
44
45
|
) -> APIToken:
|
45
46
|
"""Implements the OAuth2 Device Authorization Grant flow.
|
46
47
|
|
@@ -61,6 +62,8 @@ def web_login(
|
|
61
62
|
verify_ssl: Whether to verify the SSL certificate of the OAuth2 server.
|
62
63
|
If a string is passed, it is interpreted as the path to a CA bundle
|
63
64
|
file.
|
65
|
+
pro_api_url: The URL of the ZenML Pro API server. If not provided, the
|
66
|
+
default ZenML Pro API server URL is used.
|
64
67
|
|
65
68
|
Returns:
|
66
69
|
The response returned by the OAuth2 server.
|
@@ -103,16 +106,16 @@ def web_login(
|
|
103
106
|
if not url:
|
104
107
|
# If no URL is provided, we use the ZenML Pro API server by default
|
105
108
|
zenml_pro = True
|
106
|
-
url = base_url = ZENML_PRO_API_URL
|
109
|
+
url = base_url = pro_api_url or ZENML_PRO_API_URL
|
107
110
|
else:
|
108
111
|
# Get rid of any trailing slashes to prevent issues when having double
|
109
112
|
# slashes in the URL
|
110
113
|
url = url.rstrip("/")
|
111
|
-
if
|
114
|
+
if pro_api_url:
|
112
115
|
# This is a ZenML Pro server. The device authentication is done
|
113
116
|
# through the ZenML Pro API.
|
114
117
|
zenml_pro = True
|
115
|
-
base_url =
|
118
|
+
base_url = pro_api_url
|
116
119
|
else:
|
117
120
|
base_url = url
|
118
121
|
|
@@ -240,4 +243,6 @@ def web_login(
|
|
240
243
|
)
|
241
244
|
|
242
245
|
# Save the token in the credentials store
|
243
|
-
return credentials_store.set_token(
|
246
|
+
return credentials_store.set_token(
|
247
|
+
url, token_response, is_zenml_pro=zenml_pro
|
248
|
+
)
|
@@ -119,8 +119,8 @@ class OAuthTokenResponse(BaseModel):
|
|
119
119
|
token_type: str
|
120
120
|
expires_in: Optional[int] = None
|
121
121
|
refresh_token: Optional[str] = None
|
122
|
+
csrf_token: Optional[str] = None
|
122
123
|
scope: Optional[str] = None
|
123
|
-
cookie_name: Optional[str] = None
|
124
124
|
device_id: Optional[UUID] = None
|
125
125
|
device_metadata: Optional[Dict[str, Any]] = None
|
126
126
|
|
@@ -107,6 +107,42 @@ class ServerModel(BaseModel):
|
|
107
107
|
title="Timestamp of latest user activity traced on the server.",
|
108
108
|
)
|
109
109
|
|
110
|
+
pro_dashboard_url: Optional[str] = Field(
|
111
|
+
None,
|
112
|
+
title="The base URL of the ZenML Pro dashboard to which the server "
|
113
|
+
"is connected. Only set if the server is a ZenML Pro server.",
|
114
|
+
)
|
115
|
+
|
116
|
+
pro_api_url: Optional[str] = Field(
|
117
|
+
None,
|
118
|
+
title="The base URL of the ZenML Pro API to which the server is "
|
119
|
+
"connected. Only set if the server is a ZenML Pro server.",
|
120
|
+
)
|
121
|
+
|
122
|
+
pro_organization_id: Optional[UUID] = Field(
|
123
|
+
None,
|
124
|
+
title="The ID of the ZenML Pro organization to which the server is "
|
125
|
+
"connected. Only set if the server is a ZenML Pro server.",
|
126
|
+
)
|
127
|
+
|
128
|
+
pro_organization_name: Optional[str] = Field(
|
129
|
+
None,
|
130
|
+
title="The name of the ZenML Pro organization to which the server is "
|
131
|
+
"connected. Only set if the server is a ZenML Pro server.",
|
132
|
+
)
|
133
|
+
|
134
|
+
pro_tenant_id: Optional[UUID] = Field(
|
135
|
+
None,
|
136
|
+
title="The ID of the ZenML Pro tenant to which the server is connected. "
|
137
|
+
"Only set if the server is a ZenML Pro server.",
|
138
|
+
)
|
139
|
+
|
140
|
+
pro_tenant_name: Optional[str] = Field(
|
141
|
+
None,
|
142
|
+
title="The name of the ZenML Pro tenant to which the server is connected. "
|
143
|
+
"Only set if the server is a ZenML Pro server.",
|
144
|
+
)
|
145
|
+
|
110
146
|
def is_local(self) -> bool:
|
111
147
|
"""Return whether the server is running locally.
|
112
148
|
|
@@ -119,6 +155,14 @@ class ServerModel(BaseModel):
|
|
119
155
|
# server ID is the same as the local client (user) ID.
|
120
156
|
return self.id == GlobalConfiguration().user_id
|
121
157
|
|
158
|
+
def is_pro_server(self) -> bool:
|
159
|
+
"""Return whether the server is a ZenML Pro server.
|
160
|
+
|
161
|
+
Returns:
|
162
|
+
True if the server is a ZenML Pro server, False otherwise.
|
163
|
+
"""
|
164
|
+
return self.deployment_type == ServerDeploymentType.CLOUD
|
165
|
+
|
122
166
|
|
123
167
|
class ServerLoadInfo(BaseModel):
|
124
168
|
"""Domain model for ZenML server load information."""
|