zenml-nightly 0.83.0.dev20250619__py3-none-any.whl → 0.83.0.dev20250621__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/__init__.py +12 -2
- zenml/analytics/context.py +4 -2
- zenml/config/server_config.py +6 -1
- zenml/constants.py +3 -0
- zenml/entrypoints/step_entrypoint_configuration.py +14 -0
- zenml/models/__init__.py +15 -0
- zenml/models/v2/core/api_transaction.py +193 -0
- zenml/models/v2/core/pipeline_build.py +4 -0
- zenml/models/v2/core/pipeline_deployment.py +8 -1
- zenml/models/v2/core/pipeline_run.py +7 -0
- zenml/models/v2/core/step_run.py +6 -0
- zenml/orchestrators/input_utils.py +34 -11
- zenml/utils/json_utils.py +1 -1
- zenml/zen_server/auth.py +53 -31
- zenml/zen_server/cloud_utils.py +19 -7
- zenml/zen_server/middleware.py +424 -0
- zenml/zen_server/rbac/endpoint_utils.py +5 -2
- zenml/zen_server/rbac/utils.py +12 -7
- zenml/zen_server/request_management.py +556 -0
- zenml/zen_server/routers/auth_endpoints.py +1 -0
- zenml/zen_server/routers/model_versions_endpoints.py +3 -3
- zenml/zen_server/routers/models_endpoints.py +3 -3
- zenml/zen_server/routers/pipeline_builds_endpoints.py +2 -2
- zenml/zen_server/routers/pipeline_deployments_endpoints.py +9 -4
- zenml/zen_server/routers/pipelines_endpoints.py +4 -4
- zenml/zen_server/routers/run_templates_endpoints.py +3 -3
- zenml/zen_server/routers/runs_endpoints.py +4 -4
- zenml/zen_server/routers/service_connectors_endpoints.py +6 -6
- zenml/zen_server/routers/steps_endpoints.py +3 -3
- zenml/zen_server/utils.py +230 -63
- zenml/zen_server/zen_server_api.py +34 -399
- zenml/zen_stores/migrations/versions/3d7e39f3ac92_split_up_step_configurations.py +138 -0
- zenml/zen_stores/migrations/versions/857843db1bcf_add_api_transaction_table.py +69 -0
- zenml/zen_stores/rest_zen_store.py +52 -42
- zenml/zen_stores/schemas/__init__.py +4 -0
- zenml/zen_stores/schemas/api_transaction_schemas.py +141 -0
- zenml/zen_stores/schemas/pipeline_deployment_schemas.py +88 -27
- zenml/zen_stores/schemas/pipeline_run_schemas.py +28 -11
- zenml/zen_stores/schemas/step_run_schemas.py +4 -4
- zenml/zen_stores/sql_zen_store.py +277 -42
- zenml/zen_stores/zen_store_interface.py +7 -1
- {zenml_nightly-0.83.0.dev20250619.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/METADATA +1 -1
- {zenml_nightly-0.83.0.dev20250619.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/RECORD +47 -41
- {zenml_nightly-0.83.0.dev20250619.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.83.0.dev20250619.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.83.0.dev20250619.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/entry_points.txt +0 -0
zenml/VERSION
CHANGED
@@ -1 +1 @@
|
|
1
|
-
0.83.0.
|
1
|
+
0.83.0.dev20250621
|
zenml/__init__.py
CHANGED
@@ -14,6 +14,7 @@
|
|
14
14
|
"""Initialization for ZenML."""
|
15
15
|
|
16
16
|
import os
|
17
|
+
from typing import Any
|
17
18
|
|
18
19
|
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
19
20
|
|
@@ -26,6 +27,16 @@ from zenml.logger import init_logging # noqa
|
|
26
27
|
|
27
28
|
init_logging()
|
28
29
|
|
30
|
+
def __getattr__(name: str) -> Any:
|
31
|
+
# We allow directly accessing the entrypoint module as `zenml.entrypoint`
|
32
|
+
# as this is needed for some orchestrators. Instead of directly importing
|
33
|
+
# the entrypoint module here, we import it dynamically. This avoids a
|
34
|
+
# warning when running the `zenml.entrypoints.entrypoint` module directly.
|
35
|
+
if name == "entrypoint":
|
36
|
+
from zenml.entrypoints import entrypoint
|
37
|
+
return entrypoint
|
38
|
+
|
39
|
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
29
40
|
|
30
41
|
# Need to import zenml.models before zenml.config to avoid circular imports
|
31
42
|
from zenml.models import * # noqa: F401
|
@@ -50,7 +61,7 @@ from zenml.steps import step, get_step_context
|
|
50
61
|
from zenml.steps.utils import log_step_metadata
|
51
62
|
from zenml.utils.metadata_utils import log_metadata
|
52
63
|
from zenml.utils.tag_utils import Tag, add_tags, remove_tags
|
53
|
-
|
64
|
+
|
54
65
|
|
55
66
|
__all__ = [
|
56
67
|
"add_tags",
|
@@ -72,5 +83,4 @@ __all__ = [
|
|
72
83
|
"register_artifact",
|
73
84
|
"show",
|
74
85
|
"step",
|
75
|
-
"entrypoint",
|
76
86
|
]
|
zenml/analytics/context.py
CHANGED
@@ -105,8 +105,10 @@ class AnalyticsContext:
|
|
105
105
|
|
106
106
|
# Fetch the `user_id`
|
107
107
|
if self.in_server:
|
108
|
-
from zenml.zen_server.
|
109
|
-
|
108
|
+
from zenml.zen_server.utils import (
|
109
|
+
get_auth_context,
|
110
|
+
server_config,
|
111
|
+
)
|
110
112
|
|
111
113
|
# If the code is running on the server, use the auth context.
|
112
114
|
auth_context = get_auth_context()
|
zenml/config/server_config.py
CHANGED
@@ -33,6 +33,7 @@ from zenml.constants import (
|
|
33
33
|
DEFAULT_REPORTABLE_RESOURCES,
|
34
34
|
DEFAULT_ZENML_JWT_TOKEN_ALGORITHM,
|
35
35
|
DEFAULT_ZENML_JWT_TOKEN_LEEWAY,
|
36
|
+
DEFAULT_ZENML_SERVER_AUTH_THREAD_POOL_SIZE,
|
36
37
|
DEFAULT_ZENML_SERVER_DEVICE_AUTH_POLLING,
|
37
38
|
DEFAULT_ZENML_SERVER_DEVICE_AUTH_TIMEOUT,
|
38
39
|
DEFAULT_ZENML_SERVER_FILE_DOWNLOAD_SIZE_LIMIT,
|
@@ -45,6 +46,7 @@ from zenml.constants import (
|
|
45
46
|
DEFAULT_ZENML_SERVER_MAX_REQUEST_BODY_SIZE_IN_BYTES,
|
46
47
|
DEFAULT_ZENML_SERVER_NAME,
|
47
48
|
DEFAULT_ZENML_SERVER_PIPELINE_RUN_AUTH_WINDOW,
|
49
|
+
DEFAULT_ZENML_SERVER_REQUEST_CACHE_TIMEOUT,
|
48
50
|
DEFAULT_ZENML_SERVER_REQUEST_TIMEOUT,
|
49
51
|
DEFAULT_ZENML_SERVER_SECURE_HEADERS_CACHE,
|
50
52
|
DEFAULT_ZENML_SERVER_SECURE_HEADERS_CONTENT,
|
@@ -354,7 +356,10 @@ class ServerConfiguration(BaseModel):
|
|
354
356
|
auto_activate: bool = False
|
355
357
|
|
356
358
|
thread_pool_size: int = DEFAULT_ZENML_SERVER_THREAD_POOL_SIZE
|
357
|
-
|
359
|
+
auth_thread_pool_size: int = DEFAULT_ZENML_SERVER_AUTH_THREAD_POOL_SIZE
|
360
|
+
request_timeout: int = DEFAULT_ZENML_SERVER_REQUEST_TIMEOUT
|
361
|
+
request_deduplication: bool = True
|
362
|
+
request_cache_timeout: int = DEFAULT_ZENML_SERVER_REQUEST_CACHE_TIMEOUT
|
358
363
|
|
359
364
|
max_request_body_size_in_bytes: int = (
|
360
365
|
DEFAULT_ZENML_SERVER_MAX_REQUEST_BODY_SIZE_IN_BYTES
|
zenml/constants.py
CHANGED
@@ -274,6 +274,7 @@ CODE_HASH_PARAMETER_NAME = "step_source"
|
|
274
274
|
# Server settings
|
275
275
|
DEFAULT_ZENML_SERVER_NAME = "default"
|
276
276
|
DEFAULT_ZENML_SERVER_THREAD_POOL_SIZE = 40
|
277
|
+
DEFAULT_ZENML_SERVER_AUTH_THREAD_POOL_SIZE = 5
|
277
278
|
DEFAULT_ZENML_JWT_TOKEN_LEEWAY = 10
|
278
279
|
DEFAULT_ZENML_JWT_TOKEN_ALGORITHM = "HS256"
|
279
280
|
DEFAULT_ZENML_AUTH_SCHEME = AuthScheme.OAUTH2_PASSWORD_BEARER
|
@@ -283,6 +284,7 @@ DEFAULT_ZENML_SERVER_DEVICE_AUTH_TIMEOUT = 60 * 5 # 5 minutes
|
|
283
284
|
DEFAULT_ZENML_SERVER_DEVICE_AUTH_POLLING = 5 # seconds
|
284
285
|
DEFAULT_HTTP_TIMEOUT = 30
|
285
286
|
DEFAULT_ZENML_SERVER_REQUEST_TIMEOUT = 20 # seconds
|
287
|
+
DEFAULT_ZENML_SERVER_REQUEST_CACHE_TIMEOUT = 300 # seconds
|
286
288
|
SERVICE_CONNECTOR_VERIFY_REQUEST_TIMEOUT = 120 # seconds
|
287
289
|
ZENML_API_KEY_PREFIX = "ZENKEY_"
|
288
290
|
DEFAULT_ZENML_SERVER_PIPELINE_RUN_AUTH_WINDOW = 60 * 48 # 48 hours
|
@@ -372,6 +374,7 @@ EVENT_FLAVORS = "/event-flavors"
|
|
372
374
|
EVENT_SOURCES = "/event-sources"
|
373
375
|
FLAVORS = "/flavors"
|
374
376
|
HEALTH = "/health"
|
377
|
+
READY = "/ready"
|
375
378
|
INFO = "/info"
|
376
379
|
LOAD_INFO = "/load-info"
|
377
380
|
LOGIN = "/login"
|
@@ -16,9 +16,11 @@
|
|
16
16
|
import os
|
17
17
|
import sys
|
18
18
|
from typing import TYPE_CHECKING, Any, List, Set
|
19
|
+
from uuid import UUID
|
19
20
|
|
20
21
|
from zenml.client import Client
|
21
22
|
from zenml.entrypoints.base_entrypoint_configuration import (
|
23
|
+
DEPLOYMENT_ID_OPTION,
|
22
24
|
BaseEntrypointConfiguration,
|
23
25
|
)
|
24
26
|
from zenml.integrations.registry import integration_registry
|
@@ -147,6 +149,18 @@ class StepEntrypointConfiguration(BaseEntrypointConfiguration):
|
|
147
149
|
kwargs[STEP_NAME_OPTION],
|
148
150
|
]
|
149
151
|
|
152
|
+
def load_deployment(self) -> "PipelineDeploymentResponse":
|
153
|
+
"""Loads the deployment.
|
154
|
+
|
155
|
+
Returns:
|
156
|
+
The deployment.
|
157
|
+
"""
|
158
|
+
deployment_id = UUID(self.entrypoint_args[DEPLOYMENT_ID_OPTION])
|
159
|
+
step_name = self.entrypoint_args[STEP_NAME_OPTION]
|
160
|
+
return Client().zen_store.get_deployment(
|
161
|
+
deployment_id=deployment_id, step_configuration_filter=[step_name]
|
162
|
+
)
|
163
|
+
|
150
164
|
def run(self) -> None:
|
151
165
|
"""Prepares the environment and runs the configured step."""
|
152
166
|
deployment = self.load_deployment()
|
zenml/models/__init__.py
CHANGED
@@ -80,6 +80,14 @@ from zenml.models.v2.core.api_key import (
|
|
80
80
|
APIKeyInternalUpdate,
|
81
81
|
APIKeyRotateRequest,
|
82
82
|
)
|
83
|
+
from zenml.models.v2.core.api_transaction import (
|
84
|
+
ApiTransactionRequest,
|
85
|
+
ApiTransactionUpdate,
|
86
|
+
ApiTransactionResponse,
|
87
|
+
ApiTransactionResponseBody,
|
88
|
+
ApiTransactionResponseMetadata,
|
89
|
+
ApiTransactionResponseResources,
|
90
|
+
)
|
83
91
|
from zenml.models.v2.core.artifact import (
|
84
92
|
ArtifactFilter,
|
85
93
|
ArtifactRequest,
|
@@ -433,6 +441,7 @@ from zenml.models.v2.misc.info_models import (
|
|
433
441
|
ActionResponseResources.model_rebuild()
|
434
442
|
ActionResponseMetadata.model_rebuild()
|
435
443
|
APIKeyResponseBody.model_rebuild()
|
444
|
+
ApiTransactionResponse.model_rebuild()
|
436
445
|
ArtifactResponse.model_rebuild()
|
437
446
|
ArtifactResponseBody.model_rebuild()
|
438
447
|
ArtifactResponseMetadata.model_rebuild()
|
@@ -570,6 +579,12 @@ __all__ = [
|
|
570
579
|
"APIKeyInternalResponse",
|
571
580
|
"APIKeyInternalUpdate",
|
572
581
|
"APIKeyRotateRequest",
|
582
|
+
"ApiTransactionRequest",
|
583
|
+
"ApiTransactionUpdate",
|
584
|
+
"ApiTransactionResponse",
|
585
|
+
"ApiTransactionResponseBody",
|
586
|
+
"ApiTransactionResponseMetadata",
|
587
|
+
"ApiTransactionResponseResources",
|
573
588
|
"ArtifactFilter",
|
574
589
|
"ArtifactRequest",
|
575
590
|
"ArtifactResponse",
|
@@ -0,0 +1,193 @@
|
|
1
|
+
# Copyright (c) ZenML GmbH 2024. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
12
|
+
# or implied. See the License for the specific language governing
|
13
|
+
# permissions and limitations under the License.
|
14
|
+
"""Models representing API transactions."""
|
15
|
+
|
16
|
+
from typing import (
|
17
|
+
TYPE_CHECKING,
|
18
|
+
Optional,
|
19
|
+
TypeVar,
|
20
|
+
)
|
21
|
+
from uuid import UUID
|
22
|
+
|
23
|
+
from pydantic import Field, SecretStr
|
24
|
+
|
25
|
+
from zenml.models.v2.base.base import BaseUpdate
|
26
|
+
from zenml.models.v2.base.scoped import (
|
27
|
+
UserScopedRequest,
|
28
|
+
UserScopedResponse,
|
29
|
+
UserScopedResponseBody,
|
30
|
+
UserScopedResponseMetadata,
|
31
|
+
UserScopedResponseResources,
|
32
|
+
)
|
33
|
+
from zenml.utils.secret_utils import PlainSerializedSecretStr
|
34
|
+
|
35
|
+
if TYPE_CHECKING:
|
36
|
+
from zenml.zen_stores.schemas import BaseSchema
|
37
|
+
|
38
|
+
AnySchema = TypeVar("AnySchema", bound=BaseSchema)
|
39
|
+
|
40
|
+
# ------------------ Request Model ------------------
|
41
|
+
|
42
|
+
|
43
|
+
class ApiTransactionRequest(UserScopedRequest):
|
44
|
+
"""Request model for API transactions."""
|
45
|
+
|
46
|
+
transaction_id: UUID = Field(
|
47
|
+
title="The ID of the transaction.",
|
48
|
+
)
|
49
|
+
method: str = Field(
|
50
|
+
title="The HTTP method of the transaction.",
|
51
|
+
)
|
52
|
+
url: str = Field(
|
53
|
+
title="The URL of the transaction.",
|
54
|
+
)
|
55
|
+
|
56
|
+
|
57
|
+
# ------------------ Update Model ------------------
|
58
|
+
|
59
|
+
|
60
|
+
class ApiTransactionUpdate(BaseUpdate):
|
61
|
+
"""Update model for stack components."""
|
62
|
+
|
63
|
+
result: Optional[PlainSerializedSecretStr] = Field(
|
64
|
+
default=None,
|
65
|
+
title="The response payload.",
|
66
|
+
)
|
67
|
+
cache_time: int = Field(
|
68
|
+
title="The time in seconds that the transaction is kept around after "
|
69
|
+
"completion."
|
70
|
+
)
|
71
|
+
|
72
|
+
def get_result(self) -> Optional[str]:
|
73
|
+
"""Get the result of the API transaction.
|
74
|
+
|
75
|
+
Returns:
|
76
|
+
the result of the API transaction.
|
77
|
+
"""
|
78
|
+
result = self.result
|
79
|
+
if result is None:
|
80
|
+
return None
|
81
|
+
return result.get_secret_value()
|
82
|
+
|
83
|
+
def set_result(self, result: str) -> None:
|
84
|
+
"""Set the result of the API transaction.
|
85
|
+
|
86
|
+
Args:
|
87
|
+
result: the result of the API transaction.
|
88
|
+
"""
|
89
|
+
self.result = SecretStr(result)
|
90
|
+
|
91
|
+
|
92
|
+
# ------------------ Response Model ------------------
|
93
|
+
|
94
|
+
|
95
|
+
class ApiTransactionResponseBody(UserScopedResponseBody):
|
96
|
+
"""Response body for API transactions."""
|
97
|
+
|
98
|
+
method: str = Field(
|
99
|
+
title="The HTTP method of the transaction.",
|
100
|
+
)
|
101
|
+
url: str = Field(
|
102
|
+
title="The URL of the transaction.",
|
103
|
+
)
|
104
|
+
completed: bool = Field(
|
105
|
+
title="Whether the transaction is completed.",
|
106
|
+
)
|
107
|
+
result: Optional[PlainSerializedSecretStr] = Field(
|
108
|
+
default=None,
|
109
|
+
title="The response payload.",
|
110
|
+
)
|
111
|
+
|
112
|
+
|
113
|
+
class ApiTransactionResponseMetadata(UserScopedResponseMetadata):
|
114
|
+
"""Response metadata for API transactions."""
|
115
|
+
|
116
|
+
|
117
|
+
class ApiTransactionResponseResources(UserScopedResponseResources):
|
118
|
+
"""Response resources for API transactions."""
|
119
|
+
|
120
|
+
|
121
|
+
class ApiTransactionResponse(
|
122
|
+
UserScopedResponse[
|
123
|
+
ApiTransactionResponseBody,
|
124
|
+
ApiTransactionResponseMetadata,
|
125
|
+
ApiTransactionResponseResources,
|
126
|
+
]
|
127
|
+
):
|
128
|
+
"""Response model for API transactions."""
|
129
|
+
|
130
|
+
def get_hydrated_version(self) -> "ApiTransactionResponse":
|
131
|
+
"""Get the hydrated version of this API transaction.
|
132
|
+
|
133
|
+
Returns:
|
134
|
+
an instance of the same entity with the metadata field attached.
|
135
|
+
"""
|
136
|
+
return self
|
137
|
+
|
138
|
+
# Body and metadata properties
|
139
|
+
|
140
|
+
@property
|
141
|
+
def method(self) -> str:
|
142
|
+
"""The `method` property.
|
143
|
+
|
144
|
+
Returns:
|
145
|
+
the value of the property.
|
146
|
+
"""
|
147
|
+
return self.get_body().method
|
148
|
+
|
149
|
+
@property
|
150
|
+
def url(self) -> str:
|
151
|
+
"""The `url` property.
|
152
|
+
|
153
|
+
Returns:
|
154
|
+
the value of the property.
|
155
|
+
"""
|
156
|
+
return self.get_body().url
|
157
|
+
|
158
|
+
@property
|
159
|
+
def completed(self) -> bool:
|
160
|
+
"""The `completed` property.
|
161
|
+
|
162
|
+
Returns:
|
163
|
+
the value of the property.
|
164
|
+
"""
|
165
|
+
return self.get_body().completed
|
166
|
+
|
167
|
+
@property
|
168
|
+
def result(self) -> Optional[PlainSerializedSecretStr]:
|
169
|
+
"""The `result` property.
|
170
|
+
|
171
|
+
Returns:
|
172
|
+
the value of the property.
|
173
|
+
"""
|
174
|
+
return self.get_body().result
|
175
|
+
|
176
|
+
def get_result(self) -> Optional[str]:
|
177
|
+
"""Get the result of the API transaction.
|
178
|
+
|
179
|
+
Returns:
|
180
|
+
the result of the API transaction.
|
181
|
+
"""
|
182
|
+
result = self.result
|
183
|
+
if result is None:
|
184
|
+
return None
|
185
|
+
return result.get_secret_value()
|
186
|
+
|
187
|
+
def set_result(self, result: str) -> None:
|
188
|
+
"""Set the result of the API transaction.
|
189
|
+
|
190
|
+
Args:
|
191
|
+
result: the result of the API transaction.
|
192
|
+
"""
|
193
|
+
self.get_body().result = SecretStr(result)
|
@@ -200,6 +200,10 @@ class PipelineBuildResponseBody(ProjectScopedResponseBody):
|
|
200
200
|
class PipelineBuildResponseMetadata(ProjectScopedResponseMetadata):
|
201
201
|
"""Response metadata for pipeline builds."""
|
202
202
|
|
203
|
+
__zenml_skip_dehydration__: ClassVar[List[str]] = [
|
204
|
+
"images",
|
205
|
+
]
|
206
|
+
|
203
207
|
pipeline: Optional["PipelineResponse"] = Field(
|
204
208
|
default=None, title="The pipeline that was used for this build."
|
205
209
|
)
|
@@ -13,7 +13,7 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Models representing pipeline deployments."""
|
15
15
|
|
16
|
-
from typing import Any, Dict, Optional, Union
|
16
|
+
from typing import Any, ClassVar, Dict, List, Optional, Union
|
17
17
|
from uuid import UUID
|
18
18
|
|
19
19
|
from pydantic import Field
|
@@ -130,6 +130,13 @@ class PipelineDeploymentResponseBody(ProjectScopedResponseBody):
|
|
130
130
|
class PipelineDeploymentResponseMetadata(ProjectScopedResponseMetadata):
|
131
131
|
"""Response metadata for pipeline deployments."""
|
132
132
|
|
133
|
+
__zenml_skip_dehydration__: ClassVar[List[str]] = [
|
134
|
+
"pipeline_configuration",
|
135
|
+
"step_configurations",
|
136
|
+
"client_environment",
|
137
|
+
"pipeline_spec",
|
138
|
+
]
|
139
|
+
|
133
140
|
run_name_template: str = Field(
|
134
141
|
title="The run name template for runs created using this deployment.",
|
135
142
|
)
|
@@ -190,6 +190,13 @@ class PipelineRunResponseBody(ProjectScopedResponseBody):
|
|
190
190
|
class PipelineRunResponseMetadata(ProjectScopedResponseMetadata):
|
191
191
|
"""Response metadata for pipeline runs."""
|
192
192
|
|
193
|
+
__zenml_skip_dehydration__: ClassVar[List[str]] = [
|
194
|
+
"run_metadata",
|
195
|
+
"config",
|
196
|
+
"client_environment",
|
197
|
+
"orchestrator_environment",
|
198
|
+
]
|
199
|
+
|
193
200
|
run_metadata: Dict[str, MetadataType] = Field(
|
194
201
|
default={},
|
195
202
|
title="Metadata associated with this pipeline run.",
|
zenml/models/v2/core/step_run.py
CHANGED
@@ -199,6 +199,12 @@ class StepRunResponseBody(ProjectScopedResponseBody):
|
|
199
199
|
class StepRunResponseMetadata(ProjectScopedResponseMetadata):
|
200
200
|
"""Response metadata for step runs."""
|
201
201
|
|
202
|
+
__zenml_skip_dehydration__: ClassVar[List[str]] = [
|
203
|
+
"config",
|
204
|
+
"spec",
|
205
|
+
"metadata",
|
206
|
+
]
|
207
|
+
|
202
208
|
# Configuration
|
203
209
|
config: "StepConfiguration" = Field(title="The configuration of the step.")
|
204
210
|
spec: "StepSpec" = Field(title="The spec of the step.")
|
@@ -65,17 +65,40 @@ def resolve_step_inputs(
|
|
65
65
|
steps_to_fetch.difference_update(step_runs.keys())
|
66
66
|
|
67
67
|
if steps_to_fetch:
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
68
|
+
# The list of steps might be too big to fit in the default max URL
|
69
|
+
# length of 8KB supported by most servers. So we need to split it into
|
70
|
+
# smaller chunks.
|
71
|
+
steps_list = list(steps_to_fetch)
|
72
|
+
chunks = []
|
73
|
+
current_chunk = []
|
74
|
+
current_length = 0
|
75
|
+
# stay under 6KB for good measure.
|
76
|
+
max_chunk_length = 6000
|
77
|
+
|
78
|
+
for step_name in steps_list:
|
79
|
+
current_chunk.append(step_name)
|
80
|
+
current_length += len(step_name) + 5 # 5 is for the JSON encoding
|
81
|
+
|
82
|
+
if current_length > max_chunk_length:
|
83
|
+
chunks.append(current_chunk)
|
84
|
+
current_chunk = []
|
85
|
+
current_length = 0
|
86
|
+
|
87
|
+
if current_chunk:
|
88
|
+
chunks.append(current_chunk)
|
89
|
+
|
90
|
+
for chunk in chunks:
|
91
|
+
step_runs.update(
|
92
|
+
{
|
93
|
+
run_step.name: run_step
|
94
|
+
for run_step in pagination_utils.depaginate(
|
95
|
+
Client().list_run_steps,
|
96
|
+
pipeline_run_id=pipeline_run.id,
|
97
|
+
project=pipeline_run.project_id,
|
98
|
+
name="oneof:" + json.dumps(chunk),
|
99
|
+
)
|
100
|
+
}
|
101
|
+
)
|
79
102
|
|
80
103
|
input_artifacts: Dict[str, StepRunInputResponse] = {}
|
81
104
|
for name, input_ in step.spec.inputs.items():
|
zenml/utils/json_utils.py
CHANGED
zenml/zen_server/auth.py
CHANGED
@@ -13,13 +13,16 @@
|
|
13
13
|
# permissions and limitations under the License.
|
14
14
|
"""Authentication module for ZenML server."""
|
15
15
|
|
16
|
-
|
16
|
+
import functools
|
17
17
|
from datetime import datetime, timedelta
|
18
|
-
from
|
18
|
+
from functools import wraps
|
19
|
+
from typing import Any, Awaitable, Callable, Optional, Tuple, Union
|
19
20
|
from urllib.parse import urlencode, urlparse
|
20
21
|
from uuid import UUID, uuid4
|
21
22
|
|
23
|
+
import anyio.to_thread
|
22
24
|
import requests
|
25
|
+
from anyio import CapacityLimiter
|
23
26
|
from fastapi import Depends, Response
|
24
27
|
from fastapi.security import (
|
25
28
|
HTTPBasic,
|
@@ -73,40 +76,13 @@ from zenml.zen_server.jwt import JWTToken
|
|
73
76
|
from zenml.zen_server.utils import (
|
74
77
|
get_zenml_headers,
|
75
78
|
is_same_or_subdomain,
|
79
|
+
request_manager,
|
76
80
|
server_config,
|
77
81
|
zen_store,
|
78
82
|
)
|
79
83
|
|
80
84
|
logger = get_logger(__name__)
|
81
85
|
|
82
|
-
# create a context variable to store the authentication context
|
83
|
-
_auth_context: ContextVar[Optional["AuthContext"]] = ContextVar(
|
84
|
-
"auth_context", default=None
|
85
|
-
)
|
86
|
-
|
87
|
-
|
88
|
-
def get_auth_context() -> Optional["AuthContext"]:
|
89
|
-
"""Returns the current authentication context.
|
90
|
-
|
91
|
-
Returns:
|
92
|
-
The authentication context.
|
93
|
-
"""
|
94
|
-
auth_context = _auth_context.get()
|
95
|
-
return auth_context
|
96
|
-
|
97
|
-
|
98
|
-
def set_auth_context(auth_context: "AuthContext") -> "AuthContext":
|
99
|
-
"""Sets the current authentication context.
|
100
|
-
|
101
|
-
Args:
|
102
|
-
auth_context: The authentication context.
|
103
|
-
|
104
|
-
Returns:
|
105
|
-
The authentication context.
|
106
|
-
"""
|
107
|
-
_auth_context.set(auth_context)
|
108
|
-
return auth_context
|
109
|
-
|
110
86
|
|
111
87
|
class AuthContext(BaseModel):
|
112
88
|
"""The authentication context."""
|
@@ -1128,4 +1104,50 @@ def authentication_provider() -> Callable[..., AuthContext]:
|
|
1128
1104
|
raise ValueError(f"Unknown authentication scheme: {auth_scheme}")
|
1129
1105
|
|
1130
1106
|
|
1131
|
-
|
1107
|
+
def get_authorization_provider() -> Callable[..., Awaitable[AuthContext]]:
|
1108
|
+
"""Fetches the async authorization provider.
|
1109
|
+
|
1110
|
+
Returns:
|
1111
|
+
The async authorization provider.
|
1112
|
+
"""
|
1113
|
+
provider = authentication_provider()
|
1114
|
+
# Create a custom thread pool limiter with a limit of 1 thread for all
|
1115
|
+
# auth calls
|
1116
|
+
thread_limiter = CapacityLimiter(server_config().auth_thread_pool_size)
|
1117
|
+
|
1118
|
+
@wraps(provider)
|
1119
|
+
async def async_authorize_fn(*args: Any, **kwargs: Any) -> AuthContext:
|
1120
|
+
from zenml.zen_server.utils import get_system_metrics_log_str
|
1121
|
+
|
1122
|
+
request_context = request_manager().current_request
|
1123
|
+
|
1124
|
+
@wraps(provider)
|
1125
|
+
def sync_authorize_fn(*args: Any, **kwargs: Any) -> AuthContext:
|
1126
|
+
assert request_context is not None
|
1127
|
+
|
1128
|
+
logger.debug(
|
1129
|
+
f"[{request_context.log_request_id}] API STATS - "
|
1130
|
+
f"{request_context.log_request} "
|
1131
|
+
f"AUTHORIZING "
|
1132
|
+
f"{get_system_metrics_log_str(request_context.request)}"
|
1133
|
+
)
|
1134
|
+
|
1135
|
+
try:
|
1136
|
+
auth_context = provider(*args, **kwargs)
|
1137
|
+
request_context.auth_context = auth_context
|
1138
|
+
return auth_context
|
1139
|
+
finally:
|
1140
|
+
logger.debug(
|
1141
|
+
f"[{request_context.log_request_id}] API STATS - "
|
1142
|
+
f"{request_context.log_request} "
|
1143
|
+
f"AUTHORIZED "
|
1144
|
+
f"{get_system_metrics_log_str(request_context.request)}"
|
1145
|
+
)
|
1146
|
+
|
1147
|
+
func = functools.partial(sync_authorize_fn, *args, **kwargs)
|
1148
|
+
return await anyio.to_thread.run_sync(func, limiter=thread_limiter)
|
1149
|
+
|
1150
|
+
return async_authorize_fn
|
1151
|
+
|
1152
|
+
|
1153
|
+
authorize = get_authorization_provider()
|
zenml/zen_server/cloud_utils.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1
1
|
"""Utils concerning anything concerning the cloud control plane backend."""
|
2
2
|
|
3
3
|
import logging
|
4
|
-
import threading
|
5
4
|
import time
|
6
5
|
from datetime import datetime, timedelta
|
7
6
|
from threading import RLock
|
@@ -17,7 +16,11 @@ from zenml.exceptions import (
|
|
17
16
|
)
|
18
17
|
from zenml.logger import get_logger
|
19
18
|
from zenml.utils.time_utils import utc_now
|
20
|
-
from zenml.zen_server.utils import
|
19
|
+
from zenml.zen_server.utils import (
|
20
|
+
get_system_metrics_log_str,
|
21
|
+
get_zenml_headers,
|
22
|
+
server_config,
|
23
|
+
)
|
21
24
|
|
22
25
|
logger = get_logger(__name__)
|
23
26
|
|
@@ -60,13 +63,21 @@ class ZenMLCloudConnection:
|
|
60
63
|
Returns:
|
61
64
|
The response.
|
62
65
|
"""
|
66
|
+
from zenml.zen_server.utils import get_current_request_context
|
67
|
+
|
63
68
|
url = self._config.api_url + endpoint
|
64
69
|
|
70
|
+
log_request_id = "N/A"
|
71
|
+
try:
|
72
|
+
request_context = get_current_request_context()
|
73
|
+
log_request_id = request_context.log_request_id
|
74
|
+
except RuntimeError:
|
75
|
+
pass
|
76
|
+
|
65
77
|
if logger.isEnabledFor(logging.DEBUG):
|
66
|
-
# Get the request ID from the current thread object
|
67
|
-
request_id = threading.current_thread().name
|
68
78
|
logger.debug(
|
69
|
-
f"[{
|
79
|
+
f"[{log_request_id}] RBAC STATS - {method} "
|
80
|
+
f"{endpoint} started {get_system_metrics_log_str()}"
|
70
81
|
)
|
71
82
|
start_time = time.time()
|
72
83
|
|
@@ -107,8 +118,9 @@ class ZenMLCloudConnection:
|
|
107
118
|
if logger.isEnabledFor(logging.DEBUG):
|
108
119
|
duration = (time.time() - start_time) * 1000
|
109
120
|
logger.debug(
|
110
|
-
f"[{
|
111
|
-
f"{endpoint} completed in
|
121
|
+
f"[{log_request_id}] RBAC STATS - "
|
122
|
+
f"{status_code} {method} {endpoint} completed in "
|
123
|
+
f"{duration:.2f}ms {get_system_metrics_log_str()}"
|
112
124
|
)
|
113
125
|
|
114
126
|
return response
|