zenml-nightly 0.72.0.dev20250115__py3-none-any.whl → 0.72.0.dev20250117__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. zenml/VERSION +1 -1
  2. zenml/cli/login.py +126 -50
  3. zenml/cli/server.py +24 -5
  4. zenml/config/server_config.py +142 -17
  5. zenml/constants.py +2 -11
  6. zenml/login/credentials.py +38 -14
  7. zenml/login/credentials_store.py +53 -18
  8. zenml/login/pro/client.py +3 -7
  9. zenml/login/pro/constants.py +0 -6
  10. zenml/login/pro/tenant/models.py +4 -2
  11. zenml/login/pro/utils.py +11 -25
  12. zenml/login/server_info.py +52 -0
  13. zenml/login/web_login.py +11 -6
  14. zenml/models/v2/misc/auth_models.py +1 -1
  15. zenml/models/v2/misc/server_models.py +44 -0
  16. zenml/zen_server/auth.py +97 -8
  17. zenml/zen_server/cloud_utils.py +79 -87
  18. zenml/zen_server/csrf.py +91 -0
  19. zenml/zen_server/deploy/helm/templates/NOTES.txt +22 -0
  20. zenml/zen_server/deploy/helm/templates/_environment.tpl +50 -24
  21. zenml/zen_server/deploy/helm/templates/server-secret.yaml +11 -0
  22. zenml/zen_server/deploy/helm/values.yaml +76 -7
  23. zenml/zen_server/feature_gate/feature_gate_interface.py +1 -1
  24. zenml/zen_server/jwt.py +16 -1
  25. zenml/zen_server/rbac/endpoint_utils.py +3 -3
  26. zenml/zen_server/routers/auth_endpoints.py +44 -21
  27. zenml/zen_server/routers/models_endpoints.py +1 -2
  28. zenml/zen_server/routers/pipelines_endpoints.py +2 -2
  29. zenml/zen_server/routers/stack_deployment_endpoints.py +5 -5
  30. zenml/zen_server/routers/workspaces_endpoints.py +2 -2
  31. zenml/zen_server/utils.py +64 -0
  32. zenml/zen_server/zen_server_api.py +5 -0
  33. zenml/zen_stores/base_zen_store.py +19 -1
  34. zenml/zen_stores/rest_zen_store.py +30 -20
  35. {zenml_nightly-0.72.0.dev20250115.dist-info → zenml_nightly-0.72.0.dev20250117.dist-info}/METADATA +3 -1
  36. {zenml_nightly-0.72.0.dev20250115.dist-info → zenml_nightly-0.72.0.dev20250117.dist-info}/RECORD +39 -37
  37. {zenml_nightly-0.72.0.dev20250115.dist-info → zenml_nightly-0.72.0.dev20250117.dist-info}/LICENSE +0 -0
  38. {zenml_nightly-0.72.0.dev20250115.dist-info → zenml_nightly-0.72.0.dev20250117.dist-info}/WHEEL +0 -0
  39. {zenml_nightly-0.72.0.dev20250115.dist-info → zenml_nightly-0.72.0.dev20250117.dist-info}/entry_points.txt +0 -0
zenml/constants.py CHANGED
@@ -176,11 +176,9 @@ ENV_ZENML_PIPELINE_RUN_API_TOKEN_EXPIRATION = (
176
176
 
177
177
  # ZenML Server environment variables
178
178
  ENV_ZENML_SERVER_PREFIX = "ZENML_SERVER_"
179
+ ENV_ZENML_SERVER_PRO_PREFIX = "ZENML_SERVER_PRO_"
179
180
  ENV_ZENML_SERVER_DEPLOYMENT_TYPE = f"{ENV_ZENML_SERVER_PREFIX}DEPLOYMENT_TYPE"
180
181
  ENV_ZENML_SERVER_AUTH_SCHEME = f"{ENV_ZENML_SERVER_PREFIX}AUTH_SCHEME"
181
- ENV_ZENML_SERVER_REPORTABLE_RESOURCES = (
182
- f"{ENV_ZENML_SERVER_PREFIX}REPORTABLE_RESOURCES"
183
- )
184
182
  ENV_ZENML_SERVER_AUTO_ACTIVATE = f"{ENV_ZENML_SERVER_PREFIX}AUTO_ACTIVATE"
185
183
  ENV_ZENML_RUN_SINGLE_STEPS_WITHOUT_STACK = (
186
184
  "ZENML_RUN_SINGLE_STEPS_WITHOUT_STACK"
@@ -321,14 +319,7 @@ DEFAULT_ZENML_SERVER_SECURE_HEADERS_REPORT_TO = "default"
321
319
  DEFAULT_ZENML_SERVER_REPORT_USER_ACTIVITY_TO_DB_SECONDS = 30
322
320
  DEFAULT_ZENML_SERVER_MAX_REQUEST_BODY_SIZE_IN_BYTES = 256 * 1024 * 1024
323
321
 
324
- # Configurations to decide which resources report their usage and check for
325
- # entitlement in the case of a cloud deployment. Expected Format is this:
326
- # ENV_ZENML_REPORTABLE_RESOURCES='["Foo", "bar"]'
327
- REPORTABLE_RESOURCES: List[str] = handle_json_env_var(
328
- ENV_ZENML_SERVER_REPORTABLE_RESOURCES,
329
- expected_type=list,
330
- default=["pipeline", "pipeline_run", "model"],
331
- )
322
+ DEFAULT_REPORTABLE_RESOURCES = ["pipeline", "pipeline_run", "model"]
332
323
  REQUIRES_CUSTOM_RESOURCE_REPORTING = ["pipeline", "pipeline_run"]
333
324
 
334
325
  # API Endpoint paths:
@@ -23,6 +23,7 @@ from pydantic import BaseModel, ConfigDict
23
23
  from zenml.login.pro.constants import ZENML_PRO_API_URL, ZENML_PRO_URL
24
24
  from zenml.login.pro.tenant.models import TenantRead, TenantStatus
25
25
  from zenml.models import ServerModel
26
+ from zenml.models.v2.misc.server_models import ServerDeploymentType
26
27
  from zenml.services.service_status import ServiceState
27
28
  from zenml.utils.enum_utils import StrEnum
28
29
  from zenml.utils.string_utils import get_human_readable_time
@@ -44,7 +45,6 @@ class APIToken(BaseModel):
44
45
  expires_in: Optional[int] = None
45
46
  expires_at: Optional[datetime] = None
46
47
  leeway: Optional[int] = None
47
- cookie_name: Optional[str] = None
48
48
  device_id: Optional[UUID] = None
49
49
  device_metadata: Optional[Dict[str, Any]] = None
50
50
 
@@ -89,13 +89,20 @@ class ServerCredentials(BaseModel):
89
89
  password: Optional[str] = None
90
90
 
91
91
  # Extra server attributes
92
+ deployment_type: Optional[ServerDeploymentType] = None
92
93
  server_id: Optional[UUID] = None
93
94
  server_name: Optional[str] = None
94
- organization_name: Optional[str] = None
95
- organization_id: Optional[UUID] = None
96
95
  status: Optional[str] = None
97
96
  version: Optional[str] = None
98
97
 
98
+ # Pro server attributes
99
+ organization_name: Optional[str] = None
100
+ organization_id: Optional[UUID] = None
101
+ tenant_name: Optional[str] = None
102
+ tenant_id: Optional[UUID] = None
103
+ pro_api_url: Optional[str] = None
104
+ pro_dashboard_url: Optional[str] = None
105
+
99
106
  @property
100
107
  def id(self) -> str:
101
108
  """Get the server identifier.
@@ -114,11 +121,13 @@ class ServerCredentials(BaseModel):
114
121
  Returns:
115
122
  The server type.
116
123
  """
117
- from zenml.login.pro.utils import is_zenml_pro_server_url
118
-
124
+ if self.deployment_type == ServerDeploymentType.CLOUD:
125
+ return ServerType.PRO
119
126
  if self.url == ZENML_PRO_API_URL:
120
127
  return ServerType.PRO_API
121
- if self.organization_id or is_zenml_pro_server_url(self.url):
128
+ if self.url == self.pro_api_url:
129
+ return ServerType.PRO_API
130
+ if self.organization_id or self.tenant_id:
122
131
  return ServerType.PRO
123
132
  if urlparse(self.url).hostname in [
124
133
  "localhost",
@@ -139,25 +148,39 @@ class ServerCredentials(BaseModel):
139
148
  if isinstance(server_info, ServerModel):
140
149
  # The server ID doesn't change during the lifetime of the server
141
150
  self.server_id = self.server_id or server_info.id
142
-
143
151
  # All other attributes can change during the lifetime of the server
152
+ self.deployment_type = server_info.deployment_type
144
153
  server_name = (
145
- server_info.metadata.get("tenant_name") or server_info.name
154
+ server_info.pro_tenant_name
155
+ or server_info.metadata.get("tenant_name")
156
+ or server_info.name
146
157
  )
147
158
  if server_name:
148
159
  self.server_name = server_name
149
- organization_id = server_info.metadata.get("organization_id")
150
- if organization_id:
151
- self.organization_id = UUID(organization_id)
160
+ if server_info.pro_organization_id:
161
+ self.organization_id = server_info.pro_organization_id
162
+ if server_info.pro_tenant_id:
163
+ self.server_id = server_info.pro_tenant_id
164
+ if server_info.pro_organization_name:
165
+ self.organization_name = server_info.pro_organization_name
166
+ if server_info.pro_tenant_name:
167
+ self.tenant_name = server_info.pro_tenant_name
168
+ if server_info.pro_api_url:
169
+ self.pro_api_url = server_info.pro_api_url
170
+ if server_info.pro_dashboard_url:
171
+ self.pro_dashboard_url = server_info.pro_dashboard_url
152
172
  self.version = server_info.version or self.version
153
173
  # The server information was retrieved from the server itself, so we
154
174
  # can assume that the server is available
155
175
  self.status = "available"
156
176
  else:
177
+ self.deployment_type = ServerDeploymentType.CLOUD
157
178
  self.server_id = server_info.id
158
179
  self.server_name = server_info.name
159
180
  self.organization_name = server_info.organization_name
160
181
  self.organization_id = server_info.organization_id
182
+ self.tenant_name = server_info.name
183
+ self.tenant_id = server_info.id
161
184
  self.status = server_info.status
162
185
  self.version = server_info.version
163
186
 
@@ -248,9 +271,10 @@ class ServerCredentials(BaseModel):
248
271
  """
249
272
  if self.organization_id and self.server_id:
250
273
  return (
251
- ZENML_PRO_URL
274
+ (self.pro_dashboard_url or ZENML_PRO_URL)
252
275
  + f"/organizations/{str(self.organization_id)}/tenants/{str(self.server_id)}"
253
276
  )
277
+
254
278
  return self.url
255
279
 
256
280
  @property
@@ -262,8 +286,8 @@ class ServerCredentials(BaseModel):
262
286
  """
263
287
  if self.organization_id:
264
288
  return (
265
- ZENML_PRO_URL + f"/organizations/{str(self.organization_id)}"
266
- )
289
+ self.pro_dashboard_url or ZENML_PRO_URL
290
+ ) + f"/organizations/{str(self.organization_id)}"
267
291
  return ""
268
292
 
269
293
  @property
@@ -25,7 +25,6 @@ from zenml.constants import (
25
25
  from zenml.io import fileio
26
26
  from zenml.logger import get_logger
27
27
  from zenml.login.credentials import APIToken, ServerCredentials, ServerType
28
- from zenml.login.pro.constants import ZENML_PRO_API_URL
29
28
  from zenml.login.pro.tenant.models import TenantRead
30
29
  from zenml.models import OAuthTokenResponse, ServerModel
31
30
  from zenml.utils import yaml_utils
@@ -289,10 +288,13 @@ class CredentialsStore(metaclass=SingletonMetaClass):
289
288
  self.check_and_reload_from_file()
290
289
  return self.credentials.get(server_url)
291
290
 
292
- def get_pro_token(self, allow_expired: bool = False) -> Optional[APIToken]:
293
- """Retrieve a valid token from the credentials store for the ZenML Pro API server.
291
+ def get_pro_token(
292
+ self, pro_api_url: str, allow_expired: bool = False
293
+ ) -> Optional[APIToken]:
294
+ """Retrieve a valid token from the credentials store for a ZenML Pro API server.
294
295
 
295
296
  Args:
297
+ pro_api_url: The URL of the ZenML Pro API server.
296
298
  allow_expired: Whether to allow expired tokens to be returned. The
297
299
  default behavior is to return None if a token does exist but is
298
300
  expired.
@@ -300,14 +302,18 @@ class CredentialsStore(metaclass=SingletonMetaClass):
300
302
  Returns:
301
303
  The stored token if it exists and is not expired, None otherwise.
302
304
  """
303
- return self.get_token(ZENML_PRO_API_URL, allow_expired)
305
+ credential = self.get_pro_credentials(pro_api_url, allow_expired)
306
+ if credential:
307
+ return credential.api_token
308
+ return None
304
309
 
305
310
  def get_pro_credentials(
306
- self, allow_expired: bool = False
311
+ self, pro_api_url: str, allow_expired: bool = False
307
312
  ) -> Optional[ServerCredentials]:
308
- """Retrieve a valid token from the credentials store for the ZenML Pro API server.
313
+ """Retrieve a valid token from the credentials store for a ZenML Pro API server.
309
314
 
310
315
  Args:
316
+ pro_api_url: The URL of the ZenML Pro API server.
311
317
  allow_expired: Whether to allow expired tokens to be returned. The
312
318
  default behavior is to return None if a token does exist but is
313
319
  expired.
@@ -315,26 +321,47 @@ class CredentialsStore(metaclass=SingletonMetaClass):
315
321
  Returns:
316
322
  The stored credentials if they exist and are not expired, None otherwise.
317
323
  """
318
- credential = self.get_credentials(ZENML_PRO_API_URL)
324
+ credential = self.get_credentials(pro_api_url)
319
325
  if (
320
326
  credential
327
+ and credential.type == ServerType.PRO_API
321
328
  and credential.api_token
322
329
  and (not credential.api_token.expired or allow_expired)
323
330
  ):
324
331
  return credential
325
332
  return None
326
333
 
327
- def clear_pro_credentials(self) -> None:
328
- """Delete the token from the store for the ZenML Pro API server."""
329
- self.clear_token(ZENML_PRO_API_URL)
334
+ def clear_pro_credentials(self, pro_api_url: str) -> None:
335
+ """Delete the token from the store for a ZenML Pro API server.
336
+
337
+ Args:
338
+ pro_api_url: The URL of the ZenML Pro API server.
339
+ """
340
+ self.clear_token(pro_api_url)
341
+
342
+ def clear_all_pro_tokens(
343
+ self, pro_api_url: str
344
+ ) -> List[ServerCredentials]:
345
+ """Delete all tokens from the store for ZenML Pro servers connected to a given API server.
346
+
347
+ Args:
348
+ pro_api_url: The URL of the ZenML Pro API server.
330
349
 
331
- def clear_all_pro_tokens(self) -> None:
332
- """Delete all tokens from the store for ZenML Pro API servers."""
350
+ Returns:
351
+ A list of the credentials that were cleared.
352
+ """
353
+ credentials_to_clear = []
333
354
  for server_url, server in self.credentials.copy().items():
334
- if server.type == ServerType.PRO:
355
+ if (
356
+ server.type == ServerType.PRO
357
+ and server.pro_api_url
358
+ and server.pro_api_url == pro_api_url
359
+ ):
335
360
  if server.api_key:
336
361
  continue
337
362
  self.clear_token(server_url)
363
+ credentials_to_clear.append(server)
364
+ return credentials_to_clear
338
365
 
339
366
  def has_valid_authentication(self, url: str) -> bool:
340
367
  """Check if a valid authentication credential for the given server URL is stored.
@@ -357,13 +384,16 @@ class CredentialsStore(metaclass=SingletonMetaClass):
357
384
  token = credential.api_token
358
385
  return token is not None and not token.expired
359
386
 
360
- def has_valid_pro_authentication(self) -> bool:
361
- """Check if a valid token for the ZenML Pro API server is stored.
387
+ def has_valid_pro_authentication(self, pro_api_url: str) -> bool:
388
+ """Check if a valid token for a ZenML Pro API server is stored.
389
+
390
+ Args:
391
+ pro_api_url: The URL of the ZenML Pro API server.
362
392
 
363
393
  Returns:
364
394
  bool: True if a valid token is stored, False otherwise.
365
395
  """
366
- return self.get_token(ZENML_PRO_API_URL) is not None
396
+ return self.get_pro_token(pro_api_url) is not None
367
397
 
368
398
  def set_api_key(
369
399
  self,
@@ -433,12 +463,14 @@ class CredentialsStore(metaclass=SingletonMetaClass):
433
463
  self,
434
464
  server_url: str,
435
465
  token_response: OAuthTokenResponse,
466
+ is_zenml_pro: bool = False,
436
467
  ) -> APIToken:
437
468
  """Store an API token received from an OAuth2 server.
438
469
 
439
470
  Args:
440
471
  server_url: The server URL for which the token is to be stored.
441
472
  token_response: Token response received from an OAuth2 server.
473
+ is_zenml_pro: Whether the token is for a ZenML Pro server.
442
474
 
443
475
  Returns:
444
476
  APIToken: The stored token.
@@ -468,7 +500,6 @@ class CredentialsStore(metaclass=SingletonMetaClass):
468
500
  expires_in=token_response.expires_in,
469
501
  expires_at=expires_at,
470
502
  leeway=leeway,
471
- cookie_name=token_response.cookie_name,
472
503
  device_id=token_response.device_id,
473
504
  device_metadata=token_response.device_metadata,
474
505
  )
@@ -477,10 +508,14 @@ class CredentialsStore(metaclass=SingletonMetaClass):
477
508
  if credential:
478
509
  credential.api_token = api_token
479
510
  else:
480
- self.credentials[server_url] = ServerCredentials(
511
+ credential = self.credentials[server_url] = ServerCredentials(
481
512
  url=server_url, api_token=api_token
482
513
  )
483
514
 
515
+ if is_zenml_pro:
516
+ # This is how we encode that the token is for a ZenML Pro server
517
+ credential.pro_api_url = server_url
518
+
484
519
  self._save_credentials()
485
520
 
486
521
  return api_token
zenml/login/pro/client.py CHANGED
@@ -33,7 +33,6 @@ from zenml.exceptions import AuthorizationException
33
33
  from zenml.logger import get_logger
34
34
  from zenml.login.credentials import APIToken
35
35
  from zenml.login.credentials_store import get_credentials_store
36
- from zenml.login.pro.constants import ZENML_PRO_API_URL
37
36
  from zenml.login.pro.models import BaseRestAPIModel
38
37
  from zenml.utils.singleton import SingletonMetaClass
39
38
  from zenml.zen_server.exceptions import exception_from_response
@@ -60,14 +59,11 @@ class ZenMLProClient(metaclass=SingletonMetaClass):
60
59
  _tenant: Optional["TenantClient"] = None
61
60
  _organization: Optional["OrganizationClient"] = None
62
61
 
63
- def __init__(
64
- self, url: Optional[str] = None, api_token: Optional[APIToken] = None
65
- ) -> None:
62
+ def __init__(self, url: str, api_token: Optional[APIToken] = None) -> None:
66
63
  """Initialize the ZenML Pro client.
67
64
 
68
65
  Args:
69
- url: The URL of the ZenML Pro API server. If not provided, the
70
- default ZenML Pro API server URL is used.
66
+ url: The URL of the ZenML Pro API server.
71
67
  api_token: The API token to use for authentication. If not provided,
72
68
  the token is fetched from the credentials store.
73
69
 
@@ -75,7 +71,7 @@ class ZenMLProClient(metaclass=SingletonMetaClass):
75
71
  AuthorizationException: If no API token is provided and no token
76
72
  is found in the credentials store.
77
73
  """
78
- self._url = url or ZENML_PRO_API_URL
74
+ self._url = url
79
75
  if api_token is None:
80
76
  logger.debug(
81
77
  "No ZenML Pro API token provided. Fetching from credentials "
@@ -26,9 +26,3 @@ ENV_ZENML_PRO_URL = "ZENML_PRO_URL"
26
26
  DEFAULT_ZENML_PRO_URL = "https://cloud.zenml.io"
27
27
 
28
28
  ZENML_PRO_URL = os.getenv(ENV_ZENML_PRO_URL, default=DEFAULT_ZENML_PRO_URL)
29
-
30
- ENV_ZENML_PRO_SERVER_SUBDOMAIN = "ZENML_PRO_SERVER_SUBDOMAIN"
31
- DEFAULT_ZENML_PRO_SERVER_SUBDOMAIN = "cloudinfra.zenml.io"
32
- ZENML_PRO_SERVER_SUBDOMAIN = os.getenv(
33
- ENV_ZENML_PRO_SERVER_SUBDOMAIN, default=DEFAULT_ZENML_PRO_SERVER_SUBDOMAIN
34
- )
@@ -76,7 +76,7 @@ class ZenMLServiceStatus(BaseRestAPIModel):
76
76
  class ZenMLServiceRead(BaseRestAPIModel):
77
77
  """Pydantic Model for viewing a ZenML service."""
78
78
 
79
- configuration: ZenMLServiceConfiguration = Field(
79
+ configuration: Optional[ZenMLServiceConfiguration] = Field(
80
80
  description="The service configuration."
81
81
  )
82
82
 
@@ -133,7 +133,9 @@ class TenantRead(BaseRestAPIModel):
133
133
  Returns:
134
134
  The ZenML service version.
135
135
  """
136
- version = self.zenml_service.configuration.version
136
+ version = None
137
+ if self.zenml_service.configuration:
138
+ version = self.zenml_service.configuration.version
137
139
  if self.zenml_service.status and self.zenml_service.status.version:
138
140
  version = self.zenml_service.status.version
139
141
 
zenml/login/pro/utils.py CHANGED
@@ -13,37 +13,16 @@
13
13
  # permissions and limitations under the License.
14
14
  """ZenML Pro login utils."""
15
15
 
16
- import re
17
-
18
16
  from zenml.logger import get_logger
17
+ from zenml.login.credentials import ServerType
19
18
  from zenml.login.credentials_store import get_credentials_store
20
19
  from zenml.login.pro.client import ZenMLProClient
21
- from zenml.login.pro.constants import ZENML_PRO_SERVER_SUBDOMAIN
20
+ from zenml.login.pro.constants import ZENML_PRO_API_URL
22
21
  from zenml.login.pro.tenant.models import TenantStatus
23
22
 
24
23
  logger = get_logger(__name__)
25
24
 
26
25
 
27
- def is_zenml_pro_server_url(url: str) -> bool:
28
- """Check if a given URL is a ZenML Pro server.
29
-
30
- Args:
31
- url: URL to check
32
-
33
- Returns:
34
- True if the URL is a ZenML Pro tenant, False otherwise
35
- """
36
- domain_regex = ZENML_PRO_SERVER_SUBDOMAIN.replace(".", r"\.")
37
- return bool(
38
- re.match(
39
- r"^(https://)?[a-zA-Z0-9-\.]+\.{domain}/?$".format(
40
- domain=domain_regex
41
- ),
42
- url,
43
- )
44
- )
45
-
46
-
47
26
  def get_troubleshooting_instructions(url: str) -> str:
48
27
  """Get troubleshooting instructions for a given ZenML Pro server URL.
49
28
 
@@ -54,8 +33,15 @@ def get_troubleshooting_instructions(url: str) -> str:
54
33
  Troubleshooting instructions
55
34
  """
56
35
  credentials_store = get_credentials_store()
57
- if credentials_store.has_valid_pro_authentication():
58
- client = ZenMLProClient()
36
+
37
+ credentials = credentials_store.get_credentials(url)
38
+ if credentials and credentials.type == ServerType.PRO:
39
+ pro_api_url = credentials.pro_api_url or ZENML_PRO_API_URL
40
+
41
+ if pro_api_url and credentials_store.has_valid_pro_authentication(
42
+ pro_api_url
43
+ ):
44
+ client = ZenMLProClient(pro_api_url)
59
45
 
60
46
  try:
61
47
  servers = client.tenant.list(url=url, member_only=False)
@@ -0,0 +1,52 @@
1
+ # Copyright (c) ZenML GmbH 2024. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12
+ # or implied. See the License for the specific language governing
13
+ # permissions and limitations under the License.
14
+ """ZenML server information retrieval."""
15
+
16
+ from typing import Optional
17
+
18
+ from zenml.logger import get_logger
19
+ from zenml.models import ServerModel
20
+ from zenml.zen_stores.rest_zen_store import (
21
+ RestZenStore,
22
+ RestZenStoreConfiguration,
23
+ )
24
+
25
+ logger = get_logger(__name__)
26
+
27
+
28
+ def get_server_info(url: str) -> Optional[ServerModel]:
29
+ """Retrieve server information from a remote ZenML server.
30
+
31
+ Args:
32
+ url: The URL of the ZenML server.
33
+
34
+ Returns:
35
+ The server information or None if the server info could not be fetched.
36
+ """
37
+ # Here we try to leverage the existing RestZenStore support to fetch the
38
+ # server info and only the server info, which doesn't actually need
39
+ # any authentication.
40
+ try:
41
+ store = RestZenStore(
42
+ config=RestZenStoreConfiguration(
43
+ url=url,
44
+ )
45
+ )
46
+ return store.server_info
47
+ except Exception as e:
48
+ logger.warning(
49
+ f"Failed to fetch server info from the server running at {url}: {e}"
50
+ )
51
+
52
+ return None
zenml/login/web_login.py CHANGED
@@ -34,13 +34,14 @@ from zenml.exceptions import AuthorizationException, OAuthError
34
34
  from zenml.logger import get_logger
35
35
  from zenml.login.credentials import APIToken
36
36
  from zenml.login.pro.constants import ZENML_PRO_API_URL
37
- from zenml.login.pro.utils import is_zenml_pro_server_url
38
37
 
39
38
  logger = get_logger(__name__)
40
39
 
41
40
 
42
41
  def web_login(
43
- url: Optional[str] = None, verify_ssl: Optional[Union[str, bool]] = None
42
+ url: Optional[str] = None,
43
+ verify_ssl: Optional[Union[str, bool]] = None,
44
+ pro_api_url: Optional[str] = None,
44
45
  ) -> APIToken:
45
46
  """Implements the OAuth2 Device Authorization Grant flow.
46
47
 
@@ -61,6 +62,8 @@ def web_login(
61
62
  verify_ssl: Whether to verify the SSL certificate of the OAuth2 server.
62
63
  If a string is passed, it is interpreted as the path to a CA bundle
63
64
  file.
65
+ pro_api_url: The URL of the ZenML Pro API server. If not provided, the
66
+ default ZenML Pro API server URL is used.
64
67
 
65
68
  Returns:
66
69
  The response returned by the OAuth2 server.
@@ -103,16 +106,16 @@ def web_login(
103
106
  if not url:
104
107
  # If no URL is provided, we use the ZenML Pro API server by default
105
108
  zenml_pro = True
106
- url = base_url = ZENML_PRO_API_URL
109
+ url = base_url = pro_api_url or ZENML_PRO_API_URL
107
110
  else:
108
111
  # Get rid of any trailing slashes to prevent issues when having double
109
112
  # slashes in the URL
110
113
  url = url.rstrip("/")
111
- if is_zenml_pro_server_url(url):
114
+ if pro_api_url:
112
115
  # This is a ZenML Pro server. The device authentication is done
113
116
  # through the ZenML Pro API.
114
117
  zenml_pro = True
115
- base_url = ZENML_PRO_API_URL
118
+ base_url = pro_api_url
116
119
  else:
117
120
  base_url = url
118
121
 
@@ -240,4 +243,6 @@ def web_login(
240
243
  )
241
244
 
242
245
  # Save the token in the credentials store
243
- return credentials_store.set_token(url, token_response)
246
+ return credentials_store.set_token(
247
+ url, token_response, is_zenml_pro=zenml_pro
248
+ )
@@ -119,8 +119,8 @@ class OAuthTokenResponse(BaseModel):
119
119
  token_type: str
120
120
  expires_in: Optional[int] = None
121
121
  refresh_token: Optional[str] = None
122
+ csrf_token: Optional[str] = None
122
123
  scope: Optional[str] = None
123
- cookie_name: Optional[str] = None
124
124
  device_id: Optional[UUID] = None
125
125
  device_metadata: Optional[Dict[str, Any]] = None
126
126
 
@@ -107,6 +107,42 @@ class ServerModel(BaseModel):
107
107
  title="Timestamp of latest user activity traced on the server.",
108
108
  )
109
109
 
110
+ pro_dashboard_url: Optional[str] = Field(
111
+ None,
112
+ title="The base URL of the ZenML Pro dashboard to which the server "
113
+ "is connected. Only set if the server is a ZenML Pro server.",
114
+ )
115
+
116
+ pro_api_url: Optional[str] = Field(
117
+ None,
118
+ title="The base URL of the ZenML Pro API to which the server is "
119
+ "connected. Only set if the server is a ZenML Pro server.",
120
+ )
121
+
122
+ pro_organization_id: Optional[UUID] = Field(
123
+ None,
124
+ title="The ID of the ZenML Pro organization to which the server is "
125
+ "connected. Only set if the server is a ZenML Pro server.",
126
+ )
127
+
128
+ pro_organization_name: Optional[str] = Field(
129
+ None,
130
+ title="The name of the ZenML Pro organization to which the server is "
131
+ "connected. Only set if the server is a ZenML Pro server.",
132
+ )
133
+
134
+ pro_tenant_id: Optional[UUID] = Field(
135
+ None,
136
+ title="The ID of the ZenML Pro tenant to which the server is connected. "
137
+ "Only set if the server is a ZenML Pro server.",
138
+ )
139
+
140
+ pro_tenant_name: Optional[str] = Field(
141
+ None,
142
+ title="The name of the ZenML Pro tenant to which the server is connected. "
143
+ "Only set if the server is a ZenML Pro server.",
144
+ )
145
+
110
146
  def is_local(self) -> bool:
111
147
  """Return whether the server is running locally.
112
148
 
@@ -119,6 +155,14 @@ class ServerModel(BaseModel):
119
155
  # server ID is the same as the local client (user) ID.
120
156
  return self.id == GlobalConfiguration().user_id
121
157
 
158
+ def is_pro_server(self) -> bool:
159
+ """Return whether the server is a ZenML Pro server.
160
+
161
+ Returns:
162
+ True if the server is a ZenML Pro server, False otherwise.
163
+ """
164
+ return self.deployment_type == ServerDeploymentType.CLOUD
165
+
122
166
 
123
167
  class ServerLoadInfo(BaseModel):
124
168
  """Domain model for ZenML server load information."""