zenml-nightly 0.83.0.dev20250619__py3-none-any.whl → 0.83.0.dev20250622__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. zenml/VERSION +1 -1
  2. zenml/__init__.py +12 -2
  3. zenml/analytics/context.py +4 -2
  4. zenml/config/server_config.py +6 -1
  5. zenml/constants.py +3 -0
  6. zenml/entrypoints/step_entrypoint_configuration.py +14 -0
  7. zenml/models/__init__.py +15 -0
  8. zenml/models/v2/core/api_transaction.py +193 -0
  9. zenml/models/v2/core/pipeline_build.py +4 -0
  10. zenml/models/v2/core/pipeline_deployment.py +8 -1
  11. zenml/models/v2/core/pipeline_run.py +7 -0
  12. zenml/models/v2/core/step_run.py +6 -0
  13. zenml/orchestrators/input_utils.py +34 -11
  14. zenml/utils/json_utils.py +1 -1
  15. zenml/zen_server/auth.py +53 -31
  16. zenml/zen_server/cloud_utils.py +19 -7
  17. zenml/zen_server/middleware.py +424 -0
  18. zenml/zen_server/rbac/endpoint_utils.py +5 -2
  19. zenml/zen_server/rbac/utils.py +12 -7
  20. zenml/zen_server/request_management.py +556 -0
  21. zenml/zen_server/routers/auth_endpoints.py +1 -0
  22. zenml/zen_server/routers/model_versions_endpoints.py +3 -3
  23. zenml/zen_server/routers/models_endpoints.py +3 -3
  24. zenml/zen_server/routers/pipeline_builds_endpoints.py +2 -2
  25. zenml/zen_server/routers/pipeline_deployments_endpoints.py +9 -4
  26. zenml/zen_server/routers/pipelines_endpoints.py +4 -4
  27. zenml/zen_server/routers/run_templates_endpoints.py +3 -3
  28. zenml/zen_server/routers/runs_endpoints.py +4 -4
  29. zenml/zen_server/routers/service_connectors_endpoints.py +6 -6
  30. zenml/zen_server/routers/steps_endpoints.py +3 -3
  31. zenml/zen_server/utils.py +230 -63
  32. zenml/zen_server/zen_server_api.py +34 -399
  33. zenml/zen_stores/migrations/versions/3d7e39f3ac92_split_up_step_configurations.py +138 -0
  34. zenml/zen_stores/migrations/versions/857843db1bcf_add_api_transaction_table.py +69 -0
  35. zenml/zen_stores/rest_zen_store.py +52 -42
  36. zenml/zen_stores/schemas/__init__.py +4 -0
  37. zenml/zen_stores/schemas/api_transaction_schemas.py +141 -0
  38. zenml/zen_stores/schemas/pipeline_deployment_schemas.py +88 -27
  39. zenml/zen_stores/schemas/pipeline_run_schemas.py +28 -11
  40. zenml/zen_stores/schemas/step_run_schemas.py +4 -4
  41. zenml/zen_stores/sql_zen_store.py +277 -42
  42. zenml/zen_stores/zen_store_interface.py +7 -1
  43. {zenml_nightly-0.83.0.dev20250619.dist-info → zenml_nightly-0.83.0.dev20250622.dist-info}/METADATA +1 -1
  44. {zenml_nightly-0.83.0.dev20250619.dist-info → zenml_nightly-0.83.0.dev20250622.dist-info}/RECORD +47 -41
  45. {zenml_nightly-0.83.0.dev20250619.dist-info → zenml_nightly-0.83.0.dev20250622.dist-info}/LICENSE +0 -0
  46. {zenml_nightly-0.83.0.dev20250619.dist-info → zenml_nightly-0.83.0.dev20250622.dist-info}/WHEEL +0 -0
  47. {zenml_nightly-0.83.0.dev20250619.dist-info → zenml_nightly-0.83.0.dev20250622.dist-info}/entry_points.txt +0 -0
zenml/VERSION CHANGED
@@ -1 +1 @@
1
- 0.83.0.dev20250619
1
+ 0.83.0.dev20250622
zenml/__init__.py CHANGED
@@ -14,6 +14,7 @@
14
14
  """Initialization for ZenML."""
15
15
 
16
16
  import os
17
+ from typing import Any
17
18
 
18
19
  ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
19
20
 
@@ -26,6 +27,16 @@ from zenml.logger import init_logging # noqa
26
27
 
27
28
  init_logging()
28
29
 
30
+ def __getattr__(name: str) -> Any:
31
+ # We allow directly accessing the entrypoint module as `zenml.entrypoint`
32
+ # as this is needed for some orchestrators. Instead of directly importing
33
+ # the entrypoint module here, we import it dynamically. This avoids a
34
+ # warning when running the `zenml.entrypoints.entrypoint` module directly.
35
+ if name == "entrypoint":
36
+ from zenml.entrypoints import entrypoint
37
+ return entrypoint
38
+
39
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
29
40
 
30
41
  # Need to import zenml.models before zenml.config to avoid circular imports
31
42
  from zenml.models import * # noqa: F401
@@ -50,7 +61,7 @@ from zenml.steps import step, get_step_context
50
61
  from zenml.steps.utils import log_step_metadata
51
62
  from zenml.utils.metadata_utils import log_metadata
52
63
  from zenml.utils.tag_utils import Tag, add_tags, remove_tags
53
- from zenml.entrypoints import entrypoint
64
+
54
65
 
55
66
  __all__ = [
56
67
  "add_tags",
@@ -72,5 +83,4 @@ __all__ = [
72
83
  "register_artifact",
73
84
  "show",
74
85
  "step",
75
- "entrypoint",
76
86
  ]
zenml/analytics/context.py CHANGED
@@ -105,8 +105,10 @@ class AnalyticsContext:
105
105
 
106
106
  # Fetch the `user_id`
107
107
  if self.in_server:
108
- from zenml.zen_server.auth import get_auth_context
109
- from zenml.zen_server.utils import server_config
108
+ from zenml.zen_server.utils import (
109
+ get_auth_context,
110
+ server_config,
111
+ )
110
112
 
111
113
  # If the code is running on the server, use the auth context.
112
114
  auth_context = get_auth_context()
zenml/config/server_config.py CHANGED
@@ -33,6 +33,7 @@ from zenml.constants import (
33
33
  DEFAULT_REPORTABLE_RESOURCES,
34
34
  DEFAULT_ZENML_JWT_TOKEN_ALGORITHM,
35
35
  DEFAULT_ZENML_JWT_TOKEN_LEEWAY,
36
+ DEFAULT_ZENML_SERVER_AUTH_THREAD_POOL_SIZE,
36
37
  DEFAULT_ZENML_SERVER_DEVICE_AUTH_POLLING,
37
38
  DEFAULT_ZENML_SERVER_DEVICE_AUTH_TIMEOUT,
38
39
  DEFAULT_ZENML_SERVER_FILE_DOWNLOAD_SIZE_LIMIT,
@@ -45,6 +46,7 @@ from zenml.constants import (
45
46
  DEFAULT_ZENML_SERVER_MAX_REQUEST_BODY_SIZE_IN_BYTES,
46
47
  DEFAULT_ZENML_SERVER_NAME,
47
48
  DEFAULT_ZENML_SERVER_PIPELINE_RUN_AUTH_WINDOW,
49
+ DEFAULT_ZENML_SERVER_REQUEST_CACHE_TIMEOUT,
48
50
  DEFAULT_ZENML_SERVER_REQUEST_TIMEOUT,
49
51
  DEFAULT_ZENML_SERVER_SECURE_HEADERS_CACHE,
50
52
  DEFAULT_ZENML_SERVER_SECURE_HEADERS_CONTENT,
@@ -354,7 +356,10 @@ class ServerConfiguration(BaseModel):
354
356
  auto_activate: bool = False
355
357
 
356
358
  thread_pool_size: int = DEFAULT_ZENML_SERVER_THREAD_POOL_SIZE
357
- server_request_timeout: int = DEFAULT_ZENML_SERVER_REQUEST_TIMEOUT
359
+ auth_thread_pool_size: int = DEFAULT_ZENML_SERVER_AUTH_THREAD_POOL_SIZE
360
+ request_timeout: int = DEFAULT_ZENML_SERVER_REQUEST_TIMEOUT
361
+ request_deduplication: bool = True
362
+ request_cache_timeout: int = DEFAULT_ZENML_SERVER_REQUEST_CACHE_TIMEOUT
358
363
 
359
364
  max_request_body_size_in_bytes: int = (
360
365
  DEFAULT_ZENML_SERVER_MAX_REQUEST_BODY_SIZE_IN_BYTES
zenml/constants.py CHANGED
@@ -274,6 +274,7 @@ CODE_HASH_PARAMETER_NAME = "step_source"
274
274
  # Server settings
275
275
  DEFAULT_ZENML_SERVER_NAME = "default"
276
276
  DEFAULT_ZENML_SERVER_THREAD_POOL_SIZE = 40
277
+ DEFAULT_ZENML_SERVER_AUTH_THREAD_POOL_SIZE = 5
277
278
  DEFAULT_ZENML_JWT_TOKEN_LEEWAY = 10
278
279
  DEFAULT_ZENML_JWT_TOKEN_ALGORITHM = "HS256"
279
280
  DEFAULT_ZENML_AUTH_SCHEME = AuthScheme.OAUTH2_PASSWORD_BEARER
@@ -283,6 +284,7 @@ DEFAULT_ZENML_SERVER_DEVICE_AUTH_TIMEOUT = 60 * 5 # 5 minutes
283
284
  DEFAULT_ZENML_SERVER_DEVICE_AUTH_POLLING = 5 # seconds
284
285
  DEFAULT_HTTP_TIMEOUT = 30
285
286
  DEFAULT_ZENML_SERVER_REQUEST_TIMEOUT = 20 # seconds
287
+ DEFAULT_ZENML_SERVER_REQUEST_CACHE_TIMEOUT = 300 # seconds
286
288
  SERVICE_CONNECTOR_VERIFY_REQUEST_TIMEOUT = 120 # seconds
287
289
  ZENML_API_KEY_PREFIX = "ZENKEY_"
288
290
  DEFAULT_ZENML_SERVER_PIPELINE_RUN_AUTH_WINDOW = 60 * 48 # 48 hours
@@ -372,6 +374,7 @@ EVENT_FLAVORS = "/event-flavors"
372
374
  EVENT_SOURCES = "/event-sources"
373
375
  FLAVORS = "/flavors"
374
376
  HEALTH = "/health"
377
+ READY = "/ready"
375
378
  INFO = "/info"
376
379
  LOAD_INFO = "/load-info"
377
380
  LOGIN = "/login"
zenml/entrypoints/step_entrypoint_configuration.py CHANGED
@@ -16,9 +16,11 @@
16
16
  import os
17
17
  import sys
18
18
  from typing import TYPE_CHECKING, Any, List, Set
19
+ from uuid import UUID
19
20
 
20
21
  from zenml.client import Client
21
22
  from zenml.entrypoints.base_entrypoint_configuration import (
23
+ DEPLOYMENT_ID_OPTION,
22
24
  BaseEntrypointConfiguration,
23
25
  )
24
26
  from zenml.integrations.registry import integration_registry
@@ -147,6 +149,18 @@ class StepEntrypointConfiguration(BaseEntrypointConfiguration):
147
149
  kwargs[STEP_NAME_OPTION],
148
150
  ]
149
151
 
152
+ def load_deployment(self) -> "PipelineDeploymentResponse":
153
+ """Loads the deployment.
154
+
155
+ Returns:
156
+ The deployment.
157
+ """
158
+ deployment_id = UUID(self.entrypoint_args[DEPLOYMENT_ID_OPTION])
159
+ step_name = self.entrypoint_args[STEP_NAME_OPTION]
160
+ return Client().zen_store.get_deployment(
161
+ deployment_id=deployment_id, step_configuration_filter=[step_name]
162
+ )
163
+
150
164
  def run(self) -> None:
151
165
  """Prepares the environment and runs the configured step."""
152
166
  deployment = self.load_deployment()
zenml/models/__init__.py CHANGED
@@ -80,6 +80,14 @@ from zenml.models.v2.core.api_key import (
80
80
  APIKeyInternalUpdate,
81
81
  APIKeyRotateRequest,
82
82
  )
83
+ from zenml.models.v2.core.api_transaction import (
84
+ ApiTransactionRequest,
85
+ ApiTransactionUpdate,
86
+ ApiTransactionResponse,
87
+ ApiTransactionResponseBody,
88
+ ApiTransactionResponseMetadata,
89
+ ApiTransactionResponseResources,
90
+ )
83
91
  from zenml.models.v2.core.artifact import (
84
92
  ArtifactFilter,
85
93
  ArtifactRequest,
@@ -433,6 +441,7 @@ from zenml.models.v2.misc.info_models import (
433
441
  ActionResponseResources.model_rebuild()
434
442
  ActionResponseMetadata.model_rebuild()
435
443
  APIKeyResponseBody.model_rebuild()
444
+ ApiTransactionResponse.model_rebuild()
436
445
  ArtifactResponse.model_rebuild()
437
446
  ArtifactResponseBody.model_rebuild()
438
447
  ArtifactResponseMetadata.model_rebuild()
@@ -570,6 +579,12 @@ __all__ = [
570
579
  "APIKeyInternalResponse",
571
580
  "APIKeyInternalUpdate",
572
581
  "APIKeyRotateRequest",
582
+ "ApiTransactionRequest",
583
+ "ApiTransactionUpdate",
584
+ "ApiTransactionResponse",
585
+ "ApiTransactionResponseBody",
586
+ "ApiTransactionResponseMetadata",
587
+ "ApiTransactionResponseResources",
573
588
  "ArtifactFilter",
574
589
  "ArtifactRequest",
575
590
  "ArtifactResponse",
zenml/models/v2/core/api_transaction.py ADDED
@@ -0,0 +1,193 @@
1
+ # Copyright (c) ZenML GmbH 2024. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12
+ # or implied. See the License for the specific language governing
13
+ # permissions and limitations under the License.
14
+ """Models representing API transactions."""
15
+
16
+ from typing import (
17
+ TYPE_CHECKING,
18
+ Optional,
19
+ TypeVar,
20
+ )
21
+ from uuid import UUID
22
+
23
+ from pydantic import Field, SecretStr
24
+
25
+ from zenml.models.v2.base.base import BaseUpdate
26
+ from zenml.models.v2.base.scoped import (
27
+ UserScopedRequest,
28
+ UserScopedResponse,
29
+ UserScopedResponseBody,
30
+ UserScopedResponseMetadata,
31
+ UserScopedResponseResources,
32
+ )
33
+ from zenml.utils.secret_utils import PlainSerializedSecretStr
34
+
35
+ if TYPE_CHECKING:
36
+ from zenml.zen_stores.schemas import BaseSchema
37
+
38
+ AnySchema = TypeVar("AnySchema", bound=BaseSchema)
39
+
40
+ # ------------------ Request Model ------------------
41
+
42
+
43
+ class ApiTransactionRequest(UserScopedRequest):
44
+ """Request model for API transactions."""
45
+
46
+ transaction_id: UUID = Field(
47
+ title="The ID of the transaction.",
48
+ )
49
+ method: str = Field(
50
+ title="The HTTP method of the transaction.",
51
+ )
52
+ url: str = Field(
53
+ title="The URL of the transaction.",
54
+ )
55
+
56
+
57
+ # ------------------ Update Model ------------------
58
+
59
+
60
+ class ApiTransactionUpdate(BaseUpdate):
61
+ """Update model for stack components."""
62
+
63
+ result: Optional[PlainSerializedSecretStr] = Field(
64
+ default=None,
65
+ title="The response payload.",
66
+ )
67
+ cache_time: int = Field(
68
+ title="The time in seconds that the transaction is kept around after "
69
+ "completion."
70
+ )
71
+
72
+ def get_result(self) -> Optional[str]:
73
+ """Get the result of the API transaction.
74
+
75
+ Returns:
76
+ the result of the API transaction.
77
+ """
78
+ result = self.result
79
+ if result is None:
80
+ return None
81
+ return result.get_secret_value()
82
+
83
+ def set_result(self, result: str) -> None:
84
+ """Set the result of the API transaction.
85
+
86
+ Args:
87
+ result: the result of the API transaction.
88
+ """
89
+ self.result = SecretStr(result)
90
+
91
+
92
+ # ------------------ Response Model ------------------
93
+
94
+
95
+ class ApiTransactionResponseBody(UserScopedResponseBody):
96
+ """Response body for API transactions."""
97
+
98
+ method: str = Field(
99
+ title="The HTTP method of the transaction.",
100
+ )
101
+ url: str = Field(
102
+ title="The URL of the transaction.",
103
+ )
104
+ completed: bool = Field(
105
+ title="Whether the transaction is completed.",
106
+ )
107
+ result: Optional[PlainSerializedSecretStr] = Field(
108
+ default=None,
109
+ title="The response payload.",
110
+ )
111
+
112
+
113
+ class ApiTransactionResponseMetadata(UserScopedResponseMetadata):
114
+ """Response metadata for API transactions."""
115
+
116
+
117
+ class ApiTransactionResponseResources(UserScopedResponseResources):
118
+ """Response resources for API transactions."""
119
+
120
+
121
+ class ApiTransactionResponse(
122
+ UserScopedResponse[
123
+ ApiTransactionResponseBody,
124
+ ApiTransactionResponseMetadata,
125
+ ApiTransactionResponseResources,
126
+ ]
127
+ ):
128
+ """Response model for API transactions."""
129
+
130
+ def get_hydrated_version(self) -> "ApiTransactionResponse":
131
+ """Get the hydrated version of this API transaction.
132
+
133
+ Returns:
134
+ an instance of the same entity with the metadata field attached.
135
+ """
136
+ return self
137
+
138
+ # Body and metadata properties
139
+
140
+ @property
141
+ def method(self) -> str:
142
+ """The `method` property.
143
+
144
+ Returns:
145
+ the value of the property.
146
+ """
147
+ return self.get_body().method
148
+
149
+ @property
150
+ def url(self) -> str:
151
+ """The `url` property.
152
+
153
+ Returns:
154
+ the value of the property.
155
+ """
156
+ return self.get_body().url
157
+
158
+ @property
159
+ def completed(self) -> bool:
160
+ """The `completed` property.
161
+
162
+ Returns:
163
+ the value of the property.
164
+ """
165
+ return self.get_body().completed
166
+
167
+ @property
168
+ def result(self) -> Optional[PlainSerializedSecretStr]:
169
+ """The `result` property.
170
+
171
+ Returns:
172
+ the value of the property.
173
+ """
174
+ return self.get_body().result
175
+
176
+ def get_result(self) -> Optional[str]:
177
+ """Get the result of the API transaction.
178
+
179
+ Returns:
180
+ the result of the API transaction.
181
+ """
182
+ result = self.result
183
+ if result is None:
184
+ return None
185
+ return result.get_secret_value()
186
+
187
+ def set_result(self, result: str) -> None:
188
+ """Set the result of the API transaction.
189
+
190
+ Args:
191
+ result: the result of the API transaction.
192
+ """
193
+ self.get_body().result = SecretStr(result)
zenml/models/v2/core/pipeline_build.py CHANGED
@@ -200,6 +200,10 @@ class PipelineBuildResponseBody(ProjectScopedResponseBody):
200
200
  class PipelineBuildResponseMetadata(ProjectScopedResponseMetadata):
201
201
  """Response metadata for pipeline builds."""
202
202
 
203
+ __zenml_skip_dehydration__: ClassVar[List[str]] = [
204
+ "images",
205
+ ]
206
+
203
207
  pipeline: Optional["PipelineResponse"] = Field(
204
208
  default=None, title="The pipeline that was used for this build."
205
209
  )
zenml/models/v2/core/pipeline_deployment.py CHANGED
@@ -13,7 +13,7 @@
13
13
  # permissions and limitations under the License.
14
14
  """Models representing pipeline deployments."""
15
15
 
16
- from typing import Any, Dict, Optional, Union
16
+ from typing import Any, ClassVar, Dict, List, Optional, Union
17
17
  from uuid import UUID
18
18
 
19
19
  from pydantic import Field
@@ -130,6 +130,13 @@ class PipelineDeploymentResponseBody(ProjectScopedResponseBody):
130
130
  class PipelineDeploymentResponseMetadata(ProjectScopedResponseMetadata):
131
131
  """Response metadata for pipeline deployments."""
132
132
 
133
+ __zenml_skip_dehydration__: ClassVar[List[str]] = [
134
+ "pipeline_configuration",
135
+ "step_configurations",
136
+ "client_environment",
137
+ "pipeline_spec",
138
+ ]
139
+
133
140
  run_name_template: str = Field(
134
141
  title="The run name template for runs created using this deployment.",
135
142
  )
zenml/models/v2/core/pipeline_run.py CHANGED
@@ -190,6 +190,13 @@ class PipelineRunResponseBody(ProjectScopedResponseBody):
190
190
  class PipelineRunResponseMetadata(ProjectScopedResponseMetadata):
191
191
  """Response metadata for pipeline runs."""
192
192
 
193
+ __zenml_skip_dehydration__: ClassVar[List[str]] = [
194
+ "run_metadata",
195
+ "config",
196
+ "client_environment",
197
+ "orchestrator_environment",
198
+ ]
199
+
193
200
  run_metadata: Dict[str, MetadataType] = Field(
194
201
  default={},
195
202
  title="Metadata associated with this pipeline run.",
zenml/models/v2/core/step_run.py CHANGED
@@ -199,6 +199,12 @@ class StepRunResponseBody(ProjectScopedResponseBody):
199
199
  class StepRunResponseMetadata(ProjectScopedResponseMetadata):
200
200
  """Response metadata for step runs."""
201
201
 
202
+ __zenml_skip_dehydration__: ClassVar[List[str]] = [
203
+ "config",
204
+ "spec",
205
+ "metadata",
206
+ ]
207
+
202
208
  # Configuration
203
209
  config: "StepConfiguration" = Field(title="The configuration of the step.")
204
210
  spec: "StepSpec" = Field(title="The spec of the step.")
zenml/orchestrators/input_utils.py CHANGED
@@ -65,17 +65,40 @@ def resolve_step_inputs(
65
65
  steps_to_fetch.difference_update(step_runs.keys())
66
66
 
67
67
  if steps_to_fetch:
68
- step_runs.update(
69
- {
70
- run_step.name: run_step
71
- for run_step in pagination_utils.depaginate(
72
- Client().list_run_steps,
73
- pipeline_run_id=pipeline_run.id,
74
- project=pipeline_run.project_id,
75
- name="oneof:" + json.dumps(list(steps_to_fetch)),
76
- )
77
- }
78
- )
68
+ # The list of steps might be too big to fit in the default max URL
69
+ # length of 8KB supported by most servers. So we need to split it into
70
+ # smaller chunks.
71
+ steps_list = list(steps_to_fetch)
72
+ chunks = []
73
+ current_chunk = []
74
+ current_length = 0
75
+ # stay under 6KB for good measure.
76
+ max_chunk_length = 6000
77
+
78
+ for step_name in steps_list:
79
+ current_chunk.append(step_name)
80
+ current_length += len(step_name) + 5 # 5 is for the JSON encoding
81
+
82
+ if current_length > max_chunk_length:
83
+ chunks.append(current_chunk)
84
+ current_chunk = []
85
+ current_length = 0
86
+
87
+ if current_chunk:
88
+ chunks.append(current_chunk)
89
+
90
+ for chunk in chunks:
91
+ step_runs.update(
92
+ {
93
+ run_step.name: run_step
94
+ for run_step in pagination_utils.depaginate(
95
+ Client().list_run_steps,
96
+ pipeline_run_id=pipeline_run.id,
97
+ project=pipeline_run.project_id,
98
+ name="oneof:" + json.dumps(chunk),
99
+ )
100
+ }
101
+ )
79
102
 
80
103
  input_artifacts: Dict[str, StepRunInputResponse] = {}
81
104
  for name, input_ in step.spec.inputs.items():
zenml/utils/json_utils.py CHANGED
@@ -113,7 +113,7 @@ def pydantic_encoder(obj: Any) -> Any:
113
113
  from pydantic import BaseModel
114
114
 
115
115
  if isinstance(obj, BaseModel):
116
- return obj.model_dump()
116
+ return obj.model_dump(mode="json")
117
117
  elif is_dataclass(obj):
118
118
  return asdict(obj)
119
119
 
zenml/zen_server/auth.py CHANGED
@@ -13,13 +13,16 @@
13
13
  # permissions and limitations under the License.
14
14
  """Authentication module for ZenML server."""
15
15
 
16
- from contextvars import ContextVar
16
+ import functools
17
17
  from datetime import datetime, timedelta
18
- from typing import Callable, Optional, Tuple, Union
18
+ from functools import wraps
19
+ from typing import Any, Awaitable, Callable, Optional, Tuple, Union
19
20
  from urllib.parse import urlencode, urlparse
20
21
  from uuid import UUID, uuid4
21
22
 
23
+ import anyio.to_thread
22
24
  import requests
25
+ from anyio import CapacityLimiter
23
26
  from fastapi import Depends, Response
24
27
  from fastapi.security import (
25
28
  HTTPBasic,
@@ -73,40 +76,13 @@ from zenml.zen_server.jwt import JWTToken
73
76
  from zenml.zen_server.utils import (
74
77
  get_zenml_headers,
75
78
  is_same_or_subdomain,
79
+ request_manager,
76
80
  server_config,
77
81
  zen_store,
78
82
  )
79
83
 
80
84
  logger = get_logger(__name__)
81
85
 
82
- # create a context variable to store the authentication context
83
- _auth_context: ContextVar[Optional["AuthContext"]] = ContextVar(
84
- "auth_context", default=None
85
- )
86
-
87
-
88
- def get_auth_context() -> Optional["AuthContext"]:
89
- """Returns the current authentication context.
90
-
91
- Returns:
92
- The authentication context.
93
- """
94
- auth_context = _auth_context.get()
95
- return auth_context
96
-
97
-
98
- def set_auth_context(auth_context: "AuthContext") -> "AuthContext":
99
- """Sets the current authentication context.
100
-
101
- Args:
102
- auth_context: The authentication context.
103
-
104
- Returns:
105
- The authentication context.
106
- """
107
- _auth_context.set(auth_context)
108
- return auth_context
109
-
110
86
 
111
87
  class AuthContext(BaseModel):
112
88
  """The authentication context."""
@@ -1128,4 +1104,50 @@ def authentication_provider() -> Callable[..., AuthContext]:
1128
1104
  raise ValueError(f"Unknown authentication scheme: {auth_scheme}")
1129
1105
 
1130
1106
 
1131
- authorize = authentication_provider()
1107
+ def get_authorization_provider() -> Callable[..., Awaitable[AuthContext]]:
1108
+ """Fetches the async authorization provider.
1109
+
1110
+ Returns:
1111
+ The async authorization provider.
1112
+ """
1113
+ provider = authentication_provider()
1114
+ # Create a custom thread pool limiter with a limit of 1 thread for all
1115
+ # auth calls
1116
+ thread_limiter = CapacityLimiter(server_config().auth_thread_pool_size)
1117
+
1118
+ @wraps(provider)
1119
+ async def async_authorize_fn(*args: Any, **kwargs: Any) -> AuthContext:
1120
+ from zenml.zen_server.utils import get_system_metrics_log_str
1121
+
1122
+ request_context = request_manager().current_request
1123
+
1124
+ @wraps(provider)
1125
+ def sync_authorize_fn(*args: Any, **kwargs: Any) -> AuthContext:
1126
+ assert request_context is not None
1127
+
1128
+ logger.debug(
1129
+ f"[{request_context.log_request_id}] API STATS - "
1130
+ f"{request_context.log_request} "
1131
+ f"AUTHORIZING "
1132
+ f"{get_system_metrics_log_str(request_context.request)}"
1133
+ )
1134
+
1135
+ try:
1136
+ auth_context = provider(*args, **kwargs)
1137
+ request_context.auth_context = auth_context
1138
+ return auth_context
1139
+ finally:
1140
+ logger.debug(
1141
+ f"[{request_context.log_request_id}] API STATS - "
1142
+ f"{request_context.log_request} "
1143
+ f"AUTHORIZED "
1144
+ f"{get_system_metrics_log_str(request_context.request)}"
1145
+ )
1146
+
1147
+ func = functools.partial(sync_authorize_fn, *args, **kwargs)
1148
+ return await anyio.to_thread.run_sync(func, limiter=thread_limiter)
1149
+
1150
+ return async_authorize_fn
1151
+
1152
+
1153
+ authorize = get_authorization_provider()
zenml/zen_server/cloud_utils.py CHANGED
@@ -1,7 +1,6 @@
1
1
  """Utils concerning anything concerning the cloud control plane backend."""
2
2
 
3
3
  import logging
4
- import threading
5
4
  import time
6
5
  from datetime import datetime, timedelta
7
6
  from threading import RLock
@@ -17,7 +16,11 @@ from zenml.exceptions import (
17
16
  )
18
17
  from zenml.logger import get_logger
19
18
  from zenml.utils.time_utils import utc_now
20
- from zenml.zen_server.utils import get_zenml_headers, server_config
19
+ from zenml.zen_server.utils import (
20
+ get_system_metrics_log_str,
21
+ get_zenml_headers,
22
+ server_config,
23
+ )
21
24
 
22
25
  logger = get_logger(__name__)
23
26
 
@@ -60,13 +63,21 @@ class ZenMLCloudConnection:
60
63
  Returns:
61
64
  The response.
62
65
  """
66
+ from zenml.zen_server.utils import get_current_request_context
67
+
63
68
  url = self._config.api_url + endpoint
64
69
 
70
+ log_request_id = "N/A"
71
+ try:
72
+ request_context = get_current_request_context()
73
+ log_request_id = request_context.log_request_id
74
+ except RuntimeError:
75
+ pass
76
+
65
77
  if logger.isEnabledFor(logging.DEBUG):
66
- # Get the request ID from the current thread object
67
- request_id = threading.current_thread().name
68
78
  logger.debug(
69
- f"[{request_id}] RBAC STATS - {method} {endpoint} started"
79
+ f"[{log_request_id}] RBAC STATS - {method} "
80
+ f"{endpoint} started {get_system_metrics_log_str()}"
70
81
  )
71
82
  start_time = time.time()
72
83
 
@@ -107,8 +118,9 @@ class ZenMLCloudConnection:
107
118
  if logger.isEnabledFor(logging.DEBUG):
108
119
  duration = (time.time() - start_time) * 1000
109
120
  logger.debug(
110
- f"[{request_id}] RBAC STATS - {status_code} {method} "
111
- f"{endpoint} completed in {duration:.2f}ms"
121
+ f"[{log_request_id}] RBAC STATS - "
122
+ f"{status_code} {method} {endpoint} completed in "
123
+ f"{duration:.2f}ms {get_system_metrics_log_str()}"
112
124
  )
113
125
 
114
126
  return response