zenml-nightly 0.83.0.dev20250618__py3-none-any.whl → 0.83.0.dev20250621__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/__init__.py +12 -2
- zenml/analytics/context.py +4 -2
- zenml/config/server_config.py +6 -1
- zenml/constants.py +3 -0
- zenml/entrypoints/step_entrypoint_configuration.py +14 -0
- zenml/models/__init__.py +15 -0
- zenml/models/v2/core/api_transaction.py +193 -0
- zenml/models/v2/core/pipeline_build.py +4 -0
- zenml/models/v2/core/pipeline_deployment.py +8 -1
- zenml/models/v2/core/pipeline_run.py +7 -0
- zenml/models/v2/core/step_run.py +6 -0
- zenml/orchestrators/input_utils.py +34 -11
- zenml/utils/json_utils.py +1 -1
- zenml/zen_server/auth.py +53 -31
- zenml/zen_server/cloud_utils.py +19 -7
- zenml/zen_server/middleware.py +424 -0
- zenml/zen_server/rbac/endpoint_utils.py +5 -2
- zenml/zen_server/rbac/utils.py +12 -7
- zenml/zen_server/request_management.py +556 -0
- zenml/zen_server/routers/auth_endpoints.py +1 -0
- zenml/zen_server/routers/model_versions_endpoints.py +3 -3
- zenml/zen_server/routers/models_endpoints.py +3 -3
- zenml/zen_server/routers/pipeline_builds_endpoints.py +2 -2
- zenml/zen_server/routers/pipeline_deployments_endpoints.py +9 -4
- zenml/zen_server/routers/pipelines_endpoints.py +4 -4
- zenml/zen_server/routers/run_templates_endpoints.py +3 -3
- zenml/zen_server/routers/runs_endpoints.py +4 -4
- zenml/zen_server/routers/service_connectors_endpoints.py +6 -6
- zenml/zen_server/routers/steps_endpoints.py +3 -3
- zenml/zen_server/utils.py +230 -63
- zenml/zen_server/zen_server_api.py +34 -399
- zenml/zen_stores/migrations/versions/3d7e39f3ac92_split_up_step_configurations.py +138 -0
- zenml/zen_stores/migrations/versions/857843db1bcf_add_api_transaction_table.py +69 -0
- zenml/zen_stores/rest_zen_store.py +52 -42
- zenml/zen_stores/schemas/__init__.py +4 -0
- zenml/zen_stores/schemas/api_transaction_schemas.py +141 -0
- zenml/zen_stores/schemas/pipeline_deployment_schemas.py +88 -27
- zenml/zen_stores/schemas/pipeline_run_schemas.py +28 -11
- zenml/zen_stores/schemas/step_run_schemas.py +4 -4
- zenml/zen_stores/sql_zen_store.py +277 -42
- zenml/zen_stores/zen_store_interface.py +7 -1
- {zenml_nightly-0.83.0.dev20250618.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/METADATA +1 -1
- {zenml_nightly-0.83.0.dev20250618.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/RECORD +47 -41
- {zenml_nightly-0.83.0.dev20250618.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.83.0.dev20250618.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.83.0.dev20250618.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,69 @@
|
|
1
|
+
"""add API transaction table [857843db1bcf].
|
2
|
+
|
3
|
+
Revision ID: 857843db1bcf
|
4
|
+
Revises: 0.83.0
|
5
|
+
Create Date: 2025-06-13 19:47:14.165158
|
6
|
+
|
7
|
+
"""
|
8
|
+
|
9
|
+
import sqlalchemy as sa
|
10
|
+
import sqlmodel
|
11
|
+
from alembic import op
|
12
|
+
from sqlalchemy.dialects.mysql import MEDIUMTEXT
|
13
|
+
|
14
|
+
# Alembic revision identifiers: this migration (857843db1bcf) directly
# follows revision "0.83.0" and has no branch labels or extra dependencies.
revision = "857843db1bcf"
down_revision = "0.83.0"
branch_labels = depends_on = None
|
19
|
+
|
20
|
+
|
21
|
+
def upgrade() -> None:
    """Upgrade the schema by creating the `api_transaction` table.

    The table is created together with a non-unique composite index on its
    `completed` and `expired` columns.
    """
    # The `result` column uses native MEDIUMTEXT on MySQL and a String with
    # maximum length 16777215 on every other dialect.
    dialect_name = op.get_bind().dialect.name
    result_column_type = (
        MEDIUMTEXT if dialect_name == "mysql" else sa.String(length=16777215)
    )

    op.create_table(
        "api_transaction",
        sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
        sa.Column("created", sa.DateTime(), nullable=False),
        sa.Column("updated", sa.DateTime(), nullable=False),
        sa.Column(
            "method", sqlmodel.sql.sqltypes.AutoString(), nullable=False
        ),
        sa.Column("url", sa.TEXT(), nullable=False),
        sa.Column("completed", sa.Boolean(), nullable=False),
        sa.Column("result", result_column_type, nullable=True),
        sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
        sa.Column("expired", sa.DateTime(), nullable=True),
        # Rows are removed together with the user that owns them.
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
            name="fk_api_transaction_user_id_user",
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        "ix_api_transaction_completed_expired",
        "api_transaction",
        ["completed", "expired"],
        unique=False,
    )
|
59
|
+
|
60
|
+
|
61
|
+
def downgrade() -> None:
    """Revert `upgrade` by removing the `api_transaction` table.

    Objects are dropped in the reverse order of their creation: first the
    composite index, then the table itself.
    """
    index_owner = "api_transaction"
    op.drop_index(
        "ix_api_transaction_completed_expired", table_name=index_owner
    )
    op.drop_table(index_owner)
|
@@ -1631,7 +1631,10 @@ class RestZenStore(BaseZenStore):
|
|
1631
1631
|
)
|
1632
1632
|
|
1633
1633
|
def get_deployment(
|
1634
|
-
self,
|
1634
|
+
self,
|
1635
|
+
deployment_id: UUID,
|
1636
|
+
hydrate: bool = True,
|
1637
|
+
step_configuration_filter: Optional[List[str]] = None,
|
1635
1638
|
) -> PipelineDeploymentResponse:
|
1636
1639
|
"""Get a deployment with a given ID.
|
1637
1640
|
|
@@ -1639,6 +1642,9 @@ class RestZenStore(BaseZenStore):
|
|
1639
1642
|
deployment_id: ID of the deployment.
|
1640
1643
|
hydrate: Flag deciding whether to hydrate the output model(s)
|
1641
1644
|
by including metadata fields in the response.
|
1645
|
+
step_configuration_filter: List of step configurations to include in
|
1646
|
+
the response. If not given, all step configurations will be
|
1647
|
+
included.
|
1642
1648
|
|
1643
1649
|
Returns:
|
1644
1650
|
The deployment.
|
@@ -1647,7 +1653,10 @@ class RestZenStore(BaseZenStore):
|
|
1647
1653
|
resource_id=deployment_id,
|
1648
1654
|
route=PIPELINE_DEPLOYMENTS,
|
1649
1655
|
response_model=PipelineDeploymentResponse,
|
1650
|
-
params={
|
1656
|
+
params={
|
1657
|
+
"hydrate": hydrate,
|
1658
|
+
"step_configuration_filter": step_configuration_filter,
|
1659
|
+
},
|
1651
1660
|
)
|
1652
1661
|
|
1653
1662
|
def list_deployments(
|
@@ -4194,20 +4203,6 @@ class RestZenStore(BaseZenStore):
|
|
4194
4203
|
Returns:
|
4195
4204
|
A requests session.
|
4196
4205
|
"""
|
4197
|
-
|
4198
|
-
class AugmentedRetry(Retry):
|
4199
|
-
"""Augmented retry class that also retries on 429 status codes for POST requests."""
|
4200
|
-
|
4201
|
-
def is_retry(
|
4202
|
-
self,
|
4203
|
-
method: str,
|
4204
|
-
status_code: int,
|
4205
|
-
has_retry_after: bool = False,
|
4206
|
-
) -> bool:
|
4207
|
-
if status_code == 429:
|
4208
|
-
return True
|
4209
|
-
return super().is_retry(method, status_code, has_retry_after)
|
4210
|
-
|
4211
4206
|
if self._session is None:
|
4212
4207
|
# We only need to initialize the session once over the lifetime
|
4213
4208
|
# of the client. We can swap the token out when it expires.
|
@@ -4217,12 +4212,14 @@ class RestZenStore(BaseZenStore):
|
|
4217
4212
|
)
|
4218
4213
|
|
4219
4214
|
self._session = requests.Session()
|
4220
|
-
# Retries are triggered for
|
4221
|
-
# OPTIONS and DELETE) on specific HTTP status codes:
|
4215
|
+
# Retries are triggered for all HTTP methods (GET, HEAD, POST, PUT,
|
4216
|
+
# PATCH, OPTIONS and DELETE) on specific HTTP status codes:
|
4222
4217
|
#
|
4218
|
+
# 408: Request Timeout.
|
4219
|
+
# 429: Too Many Requests.
|
4223
4220
|
# 502: Bad Gateway.
|
4224
4221
|
# 503: Service Unavailable.
|
4225
|
-
# 504: Gateway Timeout
|
4222
|
+
# 504: Gateway Timeout
|
4226
4223
|
#
|
4227
4224
|
# This also handles connection level errors, if a connection attempt
|
4228
4225
|
# fails due to transient issues like:
|
@@ -4237,12 +4234,20 @@ class RestZenStore(BaseZenStore):
|
|
4237
4234
|
# the timeout period.
|
4238
4235
|
# Connection Refused: If the server refuses the connection.
|
4239
4236
|
#
|
4240
|
-
retries =
|
4237
|
+
retries = Retry(
|
4241
4238
|
connect=5,
|
4242
4239
|
read=8,
|
4243
4240
|
redirect=3,
|
4244
4241
|
status=10,
|
4245
|
-
allowed_methods=[
|
4242
|
+
allowed_methods=[
|
4243
|
+
"HEAD",
|
4244
|
+
"GET",
|
4245
|
+
"POST",
|
4246
|
+
"PUT",
|
4247
|
+
"PATCH",
|
4248
|
+
"DELETE",
|
4249
|
+
"OPTIONS",
|
4250
|
+
],
|
4246
4251
|
status_forcelist=[
|
4247
4252
|
408, # Request Timeout
|
4248
4253
|
429, # Too Many Requests
|
@@ -4389,14 +4394,6 @@ class RestZenStore(BaseZenStore):
|
|
4389
4394
|
self.session.headers.update(
|
4390
4395
|
{source_context.name: source_context.get().value}
|
4391
4396
|
)
|
4392
|
-
# Add a request ID to the request headers
|
4393
|
-
request_id = str(uuid4())[:8]
|
4394
|
-
self.session.headers.update({"X-Request-ID": request_id})
|
4395
|
-
path = url.removeprefix(self.url)
|
4396
|
-
start_time = time.time()
|
4397
|
-
logger.debug(
|
4398
|
-
f"Sending {method} request to {path} with request ID {request_id}..."
|
4399
|
-
)
|
4400
4397
|
|
4401
4398
|
# If the server replies with a credentials validation (401 Unauthorized)
|
4402
4399
|
# error, we (re-)authenticate and retry the request here in the
|
@@ -4413,18 +4410,31 @@ class RestZenStore(BaseZenStore):
|
|
4413
4410
|
# two times: once after initial authentication and once after
|
4414
4411
|
# re-authentication.
|
4415
4412
|
re_authenticated = False
|
4413
|
+
path = url.removeprefix(self.url)
|
4416
4414
|
while True:
|
4415
|
+
# Add a request ID to the request headers
|
4416
|
+
request_id = str(uuid4())[:8]
|
4417
|
+
self.session.headers.update({"X-Request-ID": request_id})
|
4418
|
+
# Add an idempotency key to the request headers to ensure that
|
4419
|
+
# requests are idempotent.
|
4420
|
+
self.session.headers.update({"Idempotency-Key": str(uuid4())})
|
4421
|
+
|
4422
|
+
start_time = time.time()
|
4423
|
+
logger.debug(f"[{request_id}] {method} {path} started...")
|
4424
|
+
status_code = "failed"
|
4425
|
+
|
4417
4426
|
try:
|
4418
|
-
|
4419
|
-
|
4420
|
-
|
4421
|
-
|
4422
|
-
|
4423
|
-
|
4424
|
-
|
4425
|
-
**kwargs,
|
4426
|
-
)
|
4427
|
+
response = self.session.request(
|
4428
|
+
method,
|
4429
|
+
url,
|
4430
|
+
params=params if params else {},
|
4431
|
+
verify=self.config.verify_ssl,
|
4432
|
+
timeout=timeout or self.config.http_timeout,
|
4433
|
+
**kwargs,
|
4427
4434
|
)
|
4435
|
+
|
4436
|
+
status_code = str(response.status_code)
|
4437
|
+
return self._handle_response(response)
|
4428
4438
|
except CredentialsNotValid as e:
|
4429
4439
|
# NOTE: CredentialsNotValid is raised only when the server
|
4430
4440
|
# explicitly indicates that the credentials are not valid and
|
@@ -4438,7 +4448,7 @@ class RestZenStore(BaseZenStore):
|
|
4438
4448
|
# request again, this time with a valid API token in the
|
4439
4449
|
# header.
|
4440
4450
|
logger.debug(
|
4441
|
-
f"The last request
|
4451
|
+
f"[{request_id}] The last request was not "
|
4442
4452
|
f"authenticated: {e}\n"
|
4443
4453
|
"Re-authenticating and retrying..."
|
4444
4454
|
)
|
@@ -4466,7 +4476,7 @@ class RestZenStore(BaseZenStore):
|
|
4466
4476
|
# that was rejected by the server. We attempt a
|
4467
4477
|
# re-authentication here and then retry the request.
|
4468
4478
|
logger.debug(
|
4469
|
-
f"The last request
|
4479
|
+
f"[{request_id}] The last request was authenticated "
|
4470
4480
|
"with an API token that was rejected by the server: "
|
4471
4481
|
f"{e}\n"
|
4472
4482
|
"Re-authenticating and retrying..."
|
@@ -4480,7 +4490,7 @@ class RestZenStore(BaseZenStore):
|
|
4480
4490
|
# The last request was made after re-authenticating but
|
4481
4491
|
# still failed. Bailing out.
|
4482
4492
|
logger.debug(
|
4483
|
-
f"The last request
|
4493
|
+
f"[{request_id}] The last request failed after "
|
4484
4494
|
"re-authenticating: {e}\n"
|
4485
4495
|
"Bailing out..."
|
4486
4496
|
)
|
@@ -4492,7 +4502,7 @@ class RestZenStore(BaseZenStore):
|
|
4492
4502
|
end_time = time.time()
|
4493
4503
|
duration = (end_time - start_time) * 1000
|
4494
4504
|
logger.debug(
|
4495
|
-
f"
|
4505
|
+
f"[{request_id}] {status_code} {method} {path} took "
|
4496
4506
|
f"{duration:.2f}ms."
|
4497
4507
|
)
|
4498
4508
|
|
@@ -35,6 +35,7 @@ from zenml.zen_stores.schemas.flavor_schemas import FlavorSchema
|
|
35
35
|
from zenml.zen_stores.schemas.server_settings_schemas import ServerSettingsSchema
|
36
36
|
from zenml.zen_stores.schemas.pipeline_deployment_schemas import (
|
37
37
|
PipelineDeploymentSchema,
|
38
|
+
StepConfigurationSchema,
|
38
39
|
)
|
39
40
|
from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema
|
40
41
|
from zenml.zen_stores.schemas.pipeline_schemas import PipelineSchema
|
@@ -74,6 +75,7 @@ from zenml.zen_stores.schemas.model_schemas import (
|
|
74
75
|
)
|
75
76
|
from zenml.zen_stores.schemas.run_template_schemas import RunTemplateSchema
|
76
77
|
from zenml.zen_stores.schemas.server_settings_schemas import ServerSettingsSchema
|
78
|
+
from zenml.zen_stores.schemas.api_transaction_schemas import ApiTransactionSchema
|
77
79
|
|
78
80
|
__all__ = [
|
79
81
|
"ActionSchema",
|
@@ -91,6 +93,7 @@ __all__ = [
|
|
91
93
|
"OAuthDeviceSchema",
|
92
94
|
"PipelineBuildSchema",
|
93
95
|
"PipelineDeploymentSchema",
|
96
|
+
"StepConfigurationSchema",
|
94
97
|
"PipelineRunSchema",
|
95
98
|
"PipelineSchema",
|
96
99
|
"RunMetadataResourceSchema",
|
@@ -119,4 +122,5 @@ __all__ = [
|
|
119
122
|
"ModelVersionArtifactSchema",
|
120
123
|
"ModelVersionPipelineRunSchema",
|
121
124
|
"ProjectSchema",
|
125
|
+
"ApiTransactionSchema",
|
122
126
|
]
|
@@ -0,0 +1,141 @@
|
|
1
|
+
# Copyright (c) ZenML GmbH 2025. All Rights Reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at:
|
6
|
+
#
|
7
|
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
|
12
|
+
# or implied. See the License for the specific language governing
|
13
|
+
# permissions and limitations under the License.
|
14
|
+
"""SQLModel implementation of idempotent API transaction tables."""
|
15
|
+
|
16
|
+
from datetime import datetime, timedelta
|
17
|
+
from typing import Any, Optional
|
18
|
+
from uuid import UUID
|
19
|
+
|
20
|
+
from sqlalchemy import TEXT, Column, String
|
21
|
+
from sqlalchemy.dialects.mysql import MEDIUMTEXT
|
22
|
+
from sqlmodel import Field
|
23
|
+
|
24
|
+
from zenml.constants import MEDIUMTEXT_MAX_LENGTH
|
25
|
+
from zenml.models import (
|
26
|
+
ApiTransactionRequest,
|
27
|
+
ApiTransactionResponse,
|
28
|
+
ApiTransactionResponseBody,
|
29
|
+
ApiTransactionUpdate,
|
30
|
+
)
|
31
|
+
from zenml.utils.time_utils import utc_now
|
32
|
+
from zenml.zen_stores.schemas.base_schemas import BaseSchema
|
33
|
+
from zenml.zen_stores.schemas.schema_utils import (
|
34
|
+
build_foreign_key_field,
|
35
|
+
build_index,
|
36
|
+
)
|
37
|
+
from zenml.zen_stores.schemas.user_schemas import UserSchema
|
38
|
+
|
39
|
+
|
40
|
+
class ApiTransactionSchema(BaseSchema, table=True):
    """Database table recording API transactions.

    Each row is keyed by the transaction ID of the originating request and
    stores the HTTP method and URL, the owning user, whether the
    transaction has completed, an optional stored result string and the
    time at which that row is considered expired.
    """

    __tablename__ = "api_transaction"
    __table_args__ = (
        # Composite, non-unique index over (completed, expired).
        build_index(
            table_name=__tablename__,
            column_names=["completed", "expired"],
        ),
    )

    method: str
    url: str = Field(sa_column=Column(TEXT, nullable=False))
    completed: bool = Field(default=False)
    # Stored result: MEDIUMTEXT on MySQL, a length-bounded String elsewhere.
    result: Optional[str] = Field(
        default=None,
        sa_column=Column(
            String(length=MEDIUMTEXT_MAX_LENGTH).with_variant(
                MEDIUMTEXT, "mysql"
            ),
            nullable=True,
        ),
    )
    # Set by `update` to `updated + cache_time` seconds.
    expired: Optional[datetime] = Field(default=None, nullable=True)

    # Owning user; rows are deleted together with the user (CASCADE).
    user_id: UUID = build_foreign_key_field(
        source=__tablename__,
        target=UserSchema.__tablename__,
        source_column="user_id",
        target_column="id",
        ondelete="CASCADE",
        nullable=False,
    )

    @classmethod
    def from_request(
        cls, request: ApiTransactionRequest
    ) -> "ApiTransactionSchema":
        """Build a fresh, not-yet-completed transaction row.

        Args:
            request: The API transaction request. Its `user` field must be
                set.

        Returns:
            A new schema instance whose ID is the request's transaction ID
            and whose `completed` flag is False.
        """
        assert request.user is not None, "User must be set."
        column_values = dict(
            id=request.transaction_id,
            user_id=request.user,
            method=request.method,
            url=request.url,
            completed=False,
        )
        return cls(**column_values)

    def to_model(
        self,
        include_metadata: bool = False,
        include_resources: bool = False,
        **kwargs: Any,
    ) -> ApiTransactionResponse:
        """Convert this row into an `ApiTransactionResponse`.

        Args:
            include_metadata: Accepted for interface compatibility; not
                used by this schema.
            include_resources: Accepted for interface compatibility; not
                used by this schema.
            **kwargs: Additional keyword arguments (ignored).

        Returns:
            The response model; its result is set only when this row has a
            non-None `result`.
        """
        response_body = ApiTransactionResponseBody(
            method=self.method,
            url=self.url,
            created=self.created,
            updated=self.updated,
            user_id=self.user_id,
            completed=self.completed,
        )
        model = ApiTransactionResponse(id=self.id, body=response_body)
        if self.result is None:
            return model
        model.set_result(self.result)
        return model

    def update(self, update: ApiTransactionUpdate) -> "ApiTransactionSchema":
        """Apply an update to this transaction row in place.

        The stored result is replaced only when the update carries one.
        `updated` is always refreshed, and `expired` is moved to
        `cache_time` seconds after it.

        Args:
            update: The API transaction update.

        Returns:
            This (mutated) schema instance.
        """
        if update.result is not None:
            self.result = update.get_result()
        refreshed_at = utc_now()
        self.updated = refreshed_at
        self.expired = refreshed_at + timedelta(seconds=update.cache_time)
        return self
|
@@ -17,11 +17,11 @@ import json
|
|
17
17
|
from typing import TYPE_CHECKING, Any, List, Optional, Sequence
|
18
18
|
from uuid import UUID
|
19
19
|
|
20
|
-
from sqlalchemy import TEXT, Column, String
|
20
|
+
from sqlalchemy import TEXT, Column, String, UniqueConstraint
|
21
21
|
from sqlalchemy.dialects.mysql import MEDIUMTEXT
|
22
|
-
from sqlalchemy.orm import joinedload
|
22
|
+
from sqlalchemy.orm import joinedload, object_session
|
23
23
|
from sqlalchemy.sql.base import ExecutableOption
|
24
|
-
from sqlmodel import Field, Relationship
|
24
|
+
from sqlmodel import Field, Relationship, asc, col, select
|
25
25
|
|
26
26
|
from zenml.config.pipeline_configurations import PipelineConfiguration
|
27
27
|
from zenml.config.pipeline_spec import PipelineSpec
|
@@ -69,14 +69,6 @@ class PipelineDeploymentSchema(BaseSchema, table=True):
|
|
69
69
|
nullable=False,
|
70
70
|
)
|
71
71
|
)
|
72
|
-
step_configurations: str = Field(
|
73
|
-
sa_column=Column(
|
74
|
-
String(length=MEDIUMTEXT_MAX_LENGTH).with_variant(
|
75
|
-
MEDIUMTEXT, "mysql"
|
76
|
-
),
|
77
|
-
nullable=False,
|
78
|
-
)
|
79
|
-
)
|
80
72
|
client_environment: str = Field(sa_column=Column(TEXT, nullable=False))
|
81
73
|
run_name_template: str = Field(nullable=False)
|
82
74
|
client_version: str = Field(nullable=True)
|
@@ -174,6 +166,46 @@ class PipelineDeploymentSchema(BaseSchema, table=True):
|
|
174
166
|
step_runs: List["StepRunSchema"] = Relationship(
|
175
167
|
sa_relationship_kwargs={"cascade": "delete"}
|
176
168
|
)
|
169
|
+
step_configurations: List["StepConfigurationSchema"] = Relationship(
|
170
|
+
sa_relationship_kwargs={
|
171
|
+
"cascade": "delete",
|
172
|
+
"order_by": "asc(StepConfigurationSchema.index)",
|
173
|
+
}
|
174
|
+
)
|
175
|
+
step_count: int
|
176
|
+
|
177
|
+
def get_step_configurations(
    self, include: Optional[List[str]] = None
) -> List["StepConfigurationSchema"]:
    """Query this deployment's step configuration rows, ordered by index.

    Args:
        include: Optional list of step names to restrict the result to.
            When `None` or empty, every step configuration belonging to
            this deployment is returned.

    Raises:
        RuntimeError: If this schema instance is not attached to a DB
            session.

    Returns:
        The matching step configuration schemas, sorted by ascending
        `index`.
    """
    session = object_session(self)
    if not session:
        raise RuntimeError(
            "Missing DB session to fetch step configurations."
        )

    statement = select(StepConfigurationSchema).where(
        StepConfigurationSchema.deployment_id == self.id
    )
    # An empty list is falsy and therefore behaves like "no filter".
    if include:
        statement = statement.where(
            col(StepConfigurationSchema.name).in_(include)
        )
    statement = statement.order_by(asc(StepConfigurationSchema.index))

    return list(session.execute(statement).scalars().all())
|
177
209
|
|
178
210
|
@classmethod
|
179
211
|
def get_query_options(
|
@@ -230,14 +262,6 @@ class PipelineDeploymentSchema(BaseSchema, table=True):
|
|
230
262
|
Returns:
|
231
263
|
The created `PipelineDeploymentSchema`.
|
232
264
|
"""
|
233
|
-
# Don't include the merged config in the step configurations, we
|
234
|
-
# reconstruct it in the `to_model` method using the pipeline
|
235
|
-
# configuration.
|
236
|
-
step_configurations = {
|
237
|
-
invocation_id: step.model_dump(mode="json", exclude={"config"})
|
238
|
-
for invocation_id, step in request.step_configurations.items()
|
239
|
-
}
|
240
|
-
|
241
265
|
client_env = json.dumps(request.client_environment)
|
242
266
|
if len(client_env) > TEXT_FIELD_MAX_LENGTH:
|
243
267
|
logger.warning(
|
@@ -257,10 +281,7 @@ class PipelineDeploymentSchema(BaseSchema, table=True):
|
|
257
281
|
code_reference_id=code_reference_id,
|
258
282
|
run_name_template=request.run_name_template,
|
259
283
|
pipeline_configuration=request.pipeline_configuration.model_dump_json(),
|
260
|
-
|
261
|
-
step_configurations,
|
262
|
-
sort_keys=False,
|
263
|
-
),
|
284
|
+
step_count=len(request.step_configurations),
|
264
285
|
client_environment=client_env,
|
265
286
|
client_version=request.client_version,
|
266
287
|
server_version=request.server_version,
|
@@ -278,6 +299,7 @@ class PipelineDeploymentSchema(BaseSchema, table=True):
|
|
278
299
|
include_metadata: bool = False,
|
279
300
|
include_resources: bool = False,
|
280
301
|
include_python_packages: bool = False,
|
302
|
+
step_configuration_filter: Optional[List[str]] = None,
|
281
303
|
**kwargs: Any,
|
282
304
|
) -> PipelineDeploymentResponse:
|
283
305
|
"""Convert a `PipelineDeploymentSchema` to a `PipelineDeploymentResponse`.
|
@@ -286,6 +308,9 @@ class PipelineDeploymentSchema(BaseSchema, table=True):
|
|
286
308
|
include_metadata: Whether the metadata will be filled.
|
287
309
|
include_resources: Whether the resources will be filled.
|
288
310
|
include_python_packages: Whether the python packages will be filled.
|
311
|
+
step_configuration_filter: List of step configurations to include in
|
312
|
+
the response. If not given, all step configurations will be
|
313
|
+
included.
|
289
314
|
**kwargs: Keyword arguments to allow schema specific logic
|
290
315
|
|
291
316
|
|
@@ -303,10 +328,13 @@ class PipelineDeploymentSchema(BaseSchema, table=True):
|
|
303
328
|
pipeline_configuration = PipelineConfiguration.model_validate_json(
|
304
329
|
self.pipeline_configuration
|
305
330
|
)
|
306
|
-
step_configurations =
|
307
|
-
for
|
308
|
-
|
309
|
-
|
331
|
+
step_configurations = {}
|
332
|
+
for step_configuration in self.get_step_configurations(
|
333
|
+
include=step_configuration_filter
|
334
|
+
):
|
335
|
+
step_configurations[step_configuration.name] = Step.from_dict(
|
336
|
+
json.loads(step_configuration.config),
|
337
|
+
pipeline_configuration,
|
310
338
|
)
|
311
339
|
|
312
340
|
client_environment = json.loads(self.client_environment)
|
@@ -348,3 +376,36 @@ class PipelineDeploymentSchema(BaseSchema, table=True):
|
|
348
376
|
metadata=metadata,
|
349
377
|
resources=resources,
|
350
378
|
)
|
379
|
+
|
380
|
+
|
381
|
+
class StepConfigurationSchema(BaseSchema, table=True):
    """Per-step configuration row belonging to a pipeline deployment.

    Each row holds one step's JSON-encoded configuration. It is identified
    within its deployment by the step `name` and carries an `index` used
    to order the steps.
    """

    __tablename__ = "step_configuration"
    __table_args__ = (
        # A step name may appear at most once per deployment.
        UniqueConstraint(
            "deployment_id", "name", name="unique_step_name_for_deployment"
        ),
    )

    # Ordering position of the step within its deployment.
    index: int
    # Step name; unique per deployment.
    name: str
    # JSON-encoded step configuration (MEDIUMTEXT on MySQL).
    config: str = Field(
        sa_column=Column(
            String(length=MEDIUMTEXT_MAX_LENGTH).with_variant(
                MEDIUMTEXT, "mysql"
            ),
            nullable=False,
        )
    )

    # Owning deployment; rows are deleted with it (CASCADE).
    deployment_id: UUID = build_foreign_key_field(
        source=__tablename__,
        target=PipelineDeploymentSchema.__tablename__,
        source_column="deployment_id",
        target_column="id",
        ondelete="CASCADE",
        nullable=False,
    )
|
@@ -389,19 +389,36 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
|
|
389
389
|
RuntimeError: if the model creation fails.
|
390
390
|
"""
|
391
391
|
if self.deployment is not None:
|
392
|
-
|
393
|
-
|
394
|
-
include_python_packages=include_python_packages,
|
392
|
+
config = PipelineConfiguration.model_validate_json(
|
393
|
+
self.deployment.pipeline_configuration
|
395
394
|
)
|
395
|
+
client_environment = json.loads(self.deployment.client_environment)
|
396
396
|
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
397
|
+
stack = (
|
398
|
+
self.deployment.stack.to_model()
|
399
|
+
if self.deployment.stack
|
400
|
+
else None
|
401
|
+
)
|
402
|
+
pipeline = (
|
403
|
+
self.deployment.pipeline.to_model()
|
404
|
+
if self.deployment.pipeline
|
405
|
+
else None
|
406
|
+
)
|
407
|
+
build = (
|
408
|
+
self.deployment.build.to_model()
|
409
|
+
if self.deployment.build
|
410
|
+
else None
|
411
|
+
)
|
412
|
+
schedule = (
|
413
|
+
self.deployment.schedule.to_model()
|
414
|
+
if self.deployment.schedule
|
415
|
+
else None
|
416
|
+
)
|
417
|
+
code_reference = (
|
418
|
+
self.deployment.code_reference.to_model()
|
419
|
+
if self.deployment.code_reference
|
420
|
+
else None
|
421
|
+
)
|
405
422
|
|
406
423
|
elif self.pipeline_configuration is not None:
|
407
424
|
config = PipelineConfiguration.model_validate_json(
|
@@ -290,10 +290,10 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True):
|
|
290
290
|
"""
|
291
291
|
step = None
|
292
292
|
if self.deployment is not None:
|
293
|
-
step_configurations =
|
294
|
-
self.
|
293
|
+
step_configurations = self.deployment.get_step_configurations(
|
294
|
+
include=[self.name]
|
295
295
|
)
|
296
|
-
if
|
296
|
+
if step_configurations:
|
297
297
|
pipeline_configuration = (
|
298
298
|
PipelineConfiguration.model_validate_json(
|
299
299
|
self.deployment.pipeline_configuration
|
@@ -304,7 +304,7 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True):
|
|
304
304
|
inplace=True,
|
305
305
|
)
|
306
306
|
step = Step.from_dict(
|
307
|
-
step_configurations[
|
307
|
+
json.loads(step_configurations[0].config),
|
308
308
|
pipeline_configuration=pipeline_configuration,
|
309
309
|
)
|
310
310
|
if not step and self.step_configuration:
|