zenml-nightly 0.83.0.dev20250618__py3-none-any.whl → 0.83.0.dev20250621__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. zenml/VERSION +1 -1
  2. zenml/__init__.py +12 -2
  3. zenml/analytics/context.py +4 -2
  4. zenml/config/server_config.py +6 -1
  5. zenml/constants.py +3 -0
  6. zenml/entrypoints/step_entrypoint_configuration.py +14 -0
  7. zenml/models/__init__.py +15 -0
  8. zenml/models/v2/core/api_transaction.py +193 -0
  9. zenml/models/v2/core/pipeline_build.py +4 -0
  10. zenml/models/v2/core/pipeline_deployment.py +8 -1
  11. zenml/models/v2/core/pipeline_run.py +7 -0
  12. zenml/models/v2/core/step_run.py +6 -0
  13. zenml/orchestrators/input_utils.py +34 -11
  14. zenml/utils/json_utils.py +1 -1
  15. zenml/zen_server/auth.py +53 -31
  16. zenml/zen_server/cloud_utils.py +19 -7
  17. zenml/zen_server/middleware.py +424 -0
  18. zenml/zen_server/rbac/endpoint_utils.py +5 -2
  19. zenml/zen_server/rbac/utils.py +12 -7
  20. zenml/zen_server/request_management.py +556 -0
  21. zenml/zen_server/routers/auth_endpoints.py +1 -0
  22. zenml/zen_server/routers/model_versions_endpoints.py +3 -3
  23. zenml/zen_server/routers/models_endpoints.py +3 -3
  24. zenml/zen_server/routers/pipeline_builds_endpoints.py +2 -2
  25. zenml/zen_server/routers/pipeline_deployments_endpoints.py +9 -4
  26. zenml/zen_server/routers/pipelines_endpoints.py +4 -4
  27. zenml/zen_server/routers/run_templates_endpoints.py +3 -3
  28. zenml/zen_server/routers/runs_endpoints.py +4 -4
  29. zenml/zen_server/routers/service_connectors_endpoints.py +6 -6
  30. zenml/zen_server/routers/steps_endpoints.py +3 -3
  31. zenml/zen_server/utils.py +230 -63
  32. zenml/zen_server/zen_server_api.py +34 -399
  33. zenml/zen_stores/migrations/versions/3d7e39f3ac92_split_up_step_configurations.py +138 -0
  34. zenml/zen_stores/migrations/versions/857843db1bcf_add_api_transaction_table.py +69 -0
  35. zenml/zen_stores/rest_zen_store.py +52 -42
  36. zenml/zen_stores/schemas/__init__.py +4 -0
  37. zenml/zen_stores/schemas/api_transaction_schemas.py +141 -0
  38. zenml/zen_stores/schemas/pipeline_deployment_schemas.py +88 -27
  39. zenml/zen_stores/schemas/pipeline_run_schemas.py +28 -11
  40. zenml/zen_stores/schemas/step_run_schemas.py +4 -4
  41. zenml/zen_stores/sql_zen_store.py +277 -42
  42. zenml/zen_stores/zen_store_interface.py +7 -1
  43. {zenml_nightly-0.83.0.dev20250618.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/METADATA +1 -1
  44. {zenml_nightly-0.83.0.dev20250618.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/RECORD +47 -41
  45. {zenml_nightly-0.83.0.dev20250618.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/LICENSE +0 -0
  46. {zenml_nightly-0.83.0.dev20250618.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/WHEEL +0 -0
  47. {zenml_nightly-0.83.0.dev20250618.dist-info → zenml_nightly-0.83.0.dev20250621.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,69 @@
1
+ """add API transaction table [857843db1bcf].
2
+
3
+ Revision ID: 857843db1bcf
4
+ Revises: 0.83.0
5
+ Create Date: 2025-06-13 19:47:14.165158
6
+
7
+ """
8
+
9
+ import sqlalchemy as sa
10
+ import sqlmodel
11
+ from alembic import op
12
+ from sqlalchemy.dialects.mysql import MEDIUMTEXT
13
+
14
# Revision identifiers, used by Alembic. Per the module docstring
# ("Revises: 0.83.0"), this revision directly follows "0.83.0".
revision = "857843db1bcf"
# Parent revision this migration is applied on top of.
down_revision = "0.83.0"
branch_labels = None
depends_on = None
19
+
20
+
21
def upgrade() -> None:
    """Upgrade database schema and/or data, creating a new revision.

    Creates the `api_transaction` table (with a cascading foreign key to
    `user.id`) and a composite index over its `completed` and `expired`
    columns.
    """
    # The `result` payload can be large: MySQL gets the dedicated MEDIUMTEXT
    # type, every other dialect a string sized to the same 16 MiB - 1 limit.
    on_mysql = op.get_bind().dialect.name == "mysql"
    result_column_type = MEDIUMTEXT if on_mysql else sa.String(length=16777215)

    op.create_table(
        "api_transaction",
        sa.Column("id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
        sa.Column("created", sa.DateTime(), nullable=False),
        sa.Column("updated", sa.DateTime(), nullable=False),
        sa.Column(
            "method", sqlmodel.sql.sqltypes.AutoString(), nullable=False
        ),
        sa.Column("url", sa.TEXT(), nullable=False),
        sa.Column("completed", sa.Boolean(), nullable=False),
        sa.Column("result", result_column_type, nullable=True),
        sa.Column("user_id", sqlmodel.sql.sqltypes.GUID(), nullable=False),
        sa.Column("expired", sa.DateTime(), nullable=True),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
            name="fk_api_transaction_user_id_user",
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        "ix_api_transaction_completed_expired",
        "api_transaction",
        ["completed", "expired"],
        unique=False,
    )
59
+
60
+
61
def downgrade() -> None:
    """Downgrade database schema and/or data back to the previous revision.

    Removes the composite index first and then the `api_transaction` table
    itself, undoing `upgrade()`.
    """
    transaction_table = "api_transaction"
    op.drop_index(
        "ix_api_transaction_completed_expired",
        table_name=transaction_table,
    )
    op.drop_table(transaction_table)
@@ -1631,7 +1631,10 @@ class RestZenStore(BaseZenStore):
1631
1631
  )
1632
1632
 
1633
1633
  def get_deployment(
1634
- self, deployment_id: UUID, hydrate: bool = True
1634
+ self,
1635
+ deployment_id: UUID,
1636
+ hydrate: bool = True,
1637
+ step_configuration_filter: Optional[List[str]] = None,
1635
1638
  ) -> PipelineDeploymentResponse:
1636
1639
  """Get a deployment with a given ID.
1637
1640
 
@@ -1639,6 +1642,9 @@ class RestZenStore(BaseZenStore):
1639
1642
  deployment_id: ID of the deployment.
1640
1643
  hydrate: Flag deciding whether to hydrate the output model(s)
1641
1644
  by including metadata fields in the response.
1645
+ step_configuration_filter: List of step configurations to include in
1646
+ the response. If not given, all step configurations will be
1647
+ included.
1642
1648
 
1643
1649
  Returns:
1644
1650
  The deployment.
@@ -1647,7 +1653,10 @@ class RestZenStore(BaseZenStore):
1647
1653
  resource_id=deployment_id,
1648
1654
  route=PIPELINE_DEPLOYMENTS,
1649
1655
  response_model=PipelineDeploymentResponse,
1650
- params={"hydrate": hydrate},
1656
+ params={
1657
+ "hydrate": hydrate,
1658
+ "step_configuration_filter": step_configuration_filter,
1659
+ },
1651
1660
  )
1652
1661
 
1653
1662
  def list_deployments(
@@ -4194,20 +4203,6 @@ class RestZenStore(BaseZenStore):
4194
4203
  Returns:
4195
4204
  A requests session.
4196
4205
  """
4197
-
4198
- class AugmentedRetry(Retry):
4199
- """Augmented retry class that also retries on 429 status codes for POST requests."""
4200
-
4201
- def is_retry(
4202
- self,
4203
- method: str,
4204
- status_code: int,
4205
- has_retry_after: bool = False,
4206
- ) -> bool:
4207
- if status_code == 429:
4208
- return True
4209
- return super().is_retry(method, status_code, has_retry_after)
4210
-
4211
4206
  if self._session is None:
4212
4207
  # We only need to initialize the session once over the lifetime
4213
4208
  # of the client. We can swap the token out when it expires.
@@ -4217,12 +4212,14 @@ class RestZenStore(BaseZenStore):
4217
4212
  )
4218
4213
 
4219
4214
  self._session = requests.Session()
4220
- # Retries are triggered for idempotent HTTP methods (GET, HEAD, PUT,
4221
- # OPTIONS and DELETE) on specific HTTP status codes:
4215
+ # Retries are triggered for all HTTP methods (GET, HEAD, POST, PUT,
4216
+ # PATCH, OPTIONS and DELETE) on specific HTTP status codes:
4222
4217
  #
4218
+ # 408: Request Timeout.
4219
+ # 429: Too Many Requests.
4223
4220
  # 502: Bad Gateway.
4224
4221
  # 503: Service Unavailable.
4225
- # 504: Gateway Timeout.
4222
+ # 504: Gateway Timeout
4226
4223
  #
4227
4224
  # This also handles connection level errors, if a connection attempt
4228
4225
  # fails due to transient issues like:
@@ -4237,12 +4234,20 @@ class RestZenStore(BaseZenStore):
4237
4234
  # the timeout period.
4238
4235
  # Connection Refused: If the server refuses the connection.
4239
4236
  #
4240
- retries = AugmentedRetry(
4237
+ retries = Retry(
4241
4238
  connect=5,
4242
4239
  read=8,
4243
4240
  redirect=3,
4244
4241
  status=10,
4245
- allowed_methods=["HEAD", "GET", "PUT", "DELETE", "OPTIONS"],
4242
+ allowed_methods=[
4243
+ "HEAD",
4244
+ "GET",
4245
+ "POST",
4246
+ "PUT",
4247
+ "PATCH",
4248
+ "DELETE",
4249
+ "OPTIONS",
4250
+ ],
4246
4251
  status_forcelist=[
4247
4252
  408, # Request Timeout
4248
4253
  429, # Too Many Requests
@@ -4389,14 +4394,6 @@ class RestZenStore(BaseZenStore):
4389
4394
  self.session.headers.update(
4390
4395
  {source_context.name: source_context.get().value}
4391
4396
  )
4392
- # Add a request ID to the request headers
4393
- request_id = str(uuid4())[:8]
4394
- self.session.headers.update({"X-Request-ID": request_id})
4395
- path = url.removeprefix(self.url)
4396
- start_time = time.time()
4397
- logger.debug(
4398
- f"Sending {method} request to {path} with request ID {request_id}..."
4399
- )
4400
4397
 
4401
4398
  # If the server replies with a credentials validation (401 Unauthorized)
4402
4399
  # error, we (re-)authenticate and retry the request here in the
@@ -4413,18 +4410,31 @@ class RestZenStore(BaseZenStore):
4413
4410
  # two times: once after initial authentication and once after
4414
4411
  # re-authentication.
4415
4412
  re_authenticated = False
4413
+ path = url.removeprefix(self.url)
4416
4414
  while True:
4415
+ # Add a request ID to the request headers
4416
+ request_id = str(uuid4())[:8]
4417
+ self.session.headers.update({"X-Request-ID": request_id})
4418
+ # Add an idempotency key to the request headers to ensure that
4419
+ # requests are idempotent.
4420
+ self.session.headers.update({"Idempotency-Key": str(uuid4())})
4421
+
4422
+ start_time = time.time()
4423
+ logger.debug(f"[{request_id}] {method} {path} started...")
4424
+ status_code = "failed"
4425
+
4417
4426
  try:
4418
- return self._handle_response(
4419
- self.session.request(
4420
- method,
4421
- url,
4422
- params=params if params else {},
4423
- verify=self.config.verify_ssl,
4424
- timeout=timeout or self.config.http_timeout,
4425
- **kwargs,
4426
- )
4427
+ response = self.session.request(
4428
+ method,
4429
+ url,
4430
+ params=params if params else {},
4431
+ verify=self.config.verify_ssl,
4432
+ timeout=timeout or self.config.http_timeout,
4433
+ **kwargs,
4427
4434
  )
4435
+
4436
+ status_code = str(response.status_code)
4437
+ return self._handle_response(response)
4428
4438
  except CredentialsNotValid as e:
4429
4439
  # NOTE: CredentialsNotValid is raised only when the server
4430
4440
  # explicitly indicates that the credentials are not valid and
@@ -4438,7 +4448,7 @@ class RestZenStore(BaseZenStore):
4438
4448
  # request again, this time with a valid API token in the
4439
4449
  # header.
4440
4450
  logger.debug(
4441
- f"The last request with ID {request_id} was not "
4451
+ f"[{request_id}] The last request was not "
4442
4452
  f"authenticated: {e}\n"
4443
4453
  "Re-authenticating and retrying..."
4444
4454
  )
@@ -4466,7 +4476,7 @@ class RestZenStore(BaseZenStore):
4466
4476
  # that was rejected by the server. We attempt a
4467
4477
  # re-authentication here and then retry the request.
4468
4478
  logger.debug(
4469
- f"The last request with ID {request_id} was authenticated "
4479
+ f"[{request_id}] The last request was authenticated "
4470
4480
  "with an API token that was rejected by the server: "
4471
4481
  f"{e}\n"
4472
4482
  "Re-authenticating and retrying..."
@@ -4480,7 +4490,7 @@ class RestZenStore(BaseZenStore):
4480
4490
  # The last request was made after re-authenticating but
4481
4491
  # still failed. Bailing out.
4482
4492
  logger.debug(
4483
- f"The last request with ID {request_id} failed after "
4493
+ f"[{request_id}] The last request failed after "
4484
4494
  "re-authenticating: {e}\n"
4485
4495
  "Bailing out..."
4486
4496
  )
@@ -4492,7 +4502,7 @@ class RestZenStore(BaseZenStore):
4492
4502
  end_time = time.time()
4493
4503
  duration = (end_time - start_time) * 1000
4494
4504
  logger.debug(
4495
- f"Request to {path} with request ID {request_id} took "
4505
+ f"[{request_id}] {status_code} {method} {path} took "
4496
4506
  f"{duration:.2f}ms."
4497
4507
  )
4498
4508
 
@@ -35,6 +35,7 @@ from zenml.zen_stores.schemas.flavor_schemas import FlavorSchema
35
35
  from zenml.zen_stores.schemas.server_settings_schemas import ServerSettingsSchema
36
36
  from zenml.zen_stores.schemas.pipeline_deployment_schemas import (
37
37
  PipelineDeploymentSchema,
38
+ StepConfigurationSchema,
38
39
  )
39
40
  from zenml.zen_stores.schemas.pipeline_run_schemas import PipelineRunSchema
40
41
  from zenml.zen_stores.schemas.pipeline_schemas import PipelineSchema
@@ -74,6 +75,7 @@ from zenml.zen_stores.schemas.model_schemas import (
74
75
  )
75
76
  from zenml.zen_stores.schemas.run_template_schemas import RunTemplateSchema
76
77
  from zenml.zen_stores.schemas.server_settings_schemas import ServerSettingsSchema
78
+ from zenml.zen_stores.schemas.api_transaction_schemas import ApiTransactionSchema
77
79
 
78
80
  __all__ = [
79
81
  "ActionSchema",
@@ -91,6 +93,7 @@ __all__ = [
91
93
  "OAuthDeviceSchema",
92
94
  "PipelineBuildSchema",
93
95
  "PipelineDeploymentSchema",
96
+ "StepConfigurationSchema",
94
97
  "PipelineRunSchema",
95
98
  "PipelineSchema",
96
99
  "RunMetadataResourceSchema",
@@ -119,4 +122,5 @@ __all__ = [
119
122
  "ModelVersionArtifactSchema",
120
123
  "ModelVersionPipelineRunSchema",
121
124
  "ProjectSchema",
125
+ "ApiTransactionSchema",
122
126
  ]
@@ -0,0 +1,141 @@
1
+ # Copyright (c) ZenML GmbH 2025. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12
+ # or implied. See the License for the specific language governing
13
+ # permissions and limitations under the License.
14
+ """SQLModel implementation of idempotent API transaction tables."""
15
+
16
+ from datetime import datetime, timedelta
17
+ from typing import Any, Optional
18
+ from uuid import UUID
19
+
20
+ from sqlalchemy import TEXT, Column, String
21
+ from sqlalchemy.dialects.mysql import MEDIUMTEXT
22
+ from sqlmodel import Field
23
+
24
+ from zenml.constants import MEDIUMTEXT_MAX_LENGTH
25
+ from zenml.models import (
26
+ ApiTransactionRequest,
27
+ ApiTransactionResponse,
28
+ ApiTransactionResponseBody,
29
+ ApiTransactionUpdate,
30
+ )
31
+ from zenml.utils.time_utils import utc_now
32
+ from zenml.zen_stores.schemas.base_schemas import BaseSchema
33
+ from zenml.zen_stores.schemas.schema_utils import (
34
+ build_foreign_key_field,
35
+ build_index,
36
+ )
37
+ from zenml.zen_stores.schemas.user_schemas import UserSchema
38
+
39
+
40
class ApiTransactionSchema(BaseSchema, table=True):
    """SQL Model for API transactions.

    Each row records one API request (HTTP method and URL) issued by a user,
    whether it has completed, its optional serialized result, and an
    `expired` timestamp that `update()` sets to the update time plus the
    requested cache time.
    """

    __tablename__ = "api_transaction"
    __table_args__ = (
        build_index(
            table_name=__tablename__,
            column_names=[
                "completed",
                "expired",
            ],
        ),
    )
    method: str
    url: str = Field(sa_column=Column(TEXT, nullable=False))
    completed: bool = Field(default=False)
    # Serialized result payload: MEDIUMTEXT on MySQL, a sized string
    # (MEDIUMTEXT_MAX_LENGTH) on other dialects.
    result: Optional[str] = Field(
        default=None,
        sa_column=Column(
            String(length=MEDIUMTEXT_MAX_LENGTH).with_variant(
                MEDIUMTEXT, "mysql"
            ),
            nullable=True,
        ),
    )
    # Populated by `update()` as `updated + cache_time`.
    expired: Optional[datetime] = Field(default=None, nullable=True)

    user_id: UUID = build_foreign_key_field(
        source=__tablename__,
        target=UserSchema.__tablename__,
        source_column="user_id",
        target_column="id",
        ondelete="CASCADE",
        nullable=False,
    )

    @classmethod
    def from_request(
        cls, request: ApiTransactionRequest
    ) -> "ApiTransactionSchema":
        """Create a new API transaction from a request.

        The transaction is created in the not-completed state, using the
        request's transaction ID as its primary key.

        Args:
            request: The API transaction request.

        Raises:
            ValueError: If the request does not carry a user.

        Returns:
            The API transaction schema.
        """
        # Explicit check rather than `assert`: assertions are stripped when
        # Python runs with `-O`, which would let a missing user slip through
        # to the non-nullable `user_id` column.
        if request.user is None:
            raise ValueError("User must be set.")
        return cls(
            id=request.transaction_id,
            user_id=request.user,
            method=request.method,
            url=request.url,
            completed=False,
        )

    def to_model(
        self,
        include_metadata: bool = False,
        include_resources: bool = False,
        **kwargs: Any,
    ) -> ApiTransactionResponse:
        """Convert the SQL model to a ZenML model.

        The `include_*` flags are accepted for interface compatibility with
        other schemas; this response has only a body.

        Args:
            include_metadata: Whether to include metadata in the response.
            include_resources: Whether to include resources in the response.
            **kwargs: Additional keyword arguments.

        Returns:
            The API transaction response, with the stored result attached
            when one exists.
        """
        response = ApiTransactionResponse(
            id=self.id,
            body=ApiTransactionResponseBody(
                method=self.method,
                url=self.url,
                created=self.created,
                updated=self.updated,
                user_id=self.user_id,
                completed=self.completed,
            ),
        )
        if self.result is not None:
            response.set_result(self.result)
        return response

    def update(self, update: ApiTransactionUpdate) -> "ApiTransactionSchema":
        """Update the API transaction.

        Stores the update's result (when present), refreshes `updated`, and
        recomputes `expired` from the update's cache time. Note that this
        method does not modify `completed`.

        Args:
            update: The API transaction update.

        Returns:
            The API transaction schema.
        """
        if update.result is not None:
            self.result = update.get_result()
        self.updated = utc_now()
        self.expired = self.updated + timedelta(seconds=update.cache_time)
        return self
@@ -17,11 +17,11 @@ import json
17
17
  from typing import TYPE_CHECKING, Any, List, Optional, Sequence
18
18
  from uuid import UUID
19
19
 
20
- from sqlalchemy import TEXT, Column, String
20
+ from sqlalchemy import TEXT, Column, String, UniqueConstraint
21
21
  from sqlalchemy.dialects.mysql import MEDIUMTEXT
22
- from sqlalchemy.orm import joinedload
22
+ from sqlalchemy.orm import joinedload, object_session
23
23
  from sqlalchemy.sql.base import ExecutableOption
24
- from sqlmodel import Field, Relationship
24
+ from sqlmodel import Field, Relationship, asc, col, select
25
25
 
26
26
  from zenml.config.pipeline_configurations import PipelineConfiguration
27
27
  from zenml.config.pipeline_spec import PipelineSpec
@@ -69,14 +69,6 @@ class PipelineDeploymentSchema(BaseSchema, table=True):
69
69
  nullable=False,
70
70
  )
71
71
  )
72
- step_configurations: str = Field(
73
- sa_column=Column(
74
- String(length=MEDIUMTEXT_MAX_LENGTH).with_variant(
75
- MEDIUMTEXT, "mysql"
76
- ),
77
- nullable=False,
78
- )
79
- )
80
72
  client_environment: str = Field(sa_column=Column(TEXT, nullable=False))
81
73
  run_name_template: str = Field(nullable=False)
82
74
  client_version: str = Field(nullable=True)
@@ -174,6 +166,46 @@ class PipelineDeploymentSchema(BaseSchema, table=True):
174
166
  step_runs: List["StepRunSchema"] = Relationship(
175
167
  sa_relationship_kwargs={"cascade": "delete"}
176
168
  )
169
+ step_configurations: List["StepConfigurationSchema"] = Relationship(
170
+ sa_relationship_kwargs={
171
+ "cascade": "delete",
172
+ "order_by": "asc(StepConfigurationSchema.index)",
173
+ }
174
+ )
175
+ step_count: int
176
+
177
def get_step_configurations(
    self, include: Optional[List[str]] = None
) -> List["StepConfigurationSchema"]:
    """Get step configurations for the deployment.

    Queries the `StepConfigurationSchema` rows belonging to this deployment
    through the session this schema object is attached to.

    Args:
        include: Names of the steps whose configurations should be
            returned. If not given (or empty), all step configurations of
            the deployment are returned.

    Raises:
        RuntimeError: If no session for the schema exists.

    Returns:
        List of step configurations, sorted by ascending `index`.
    """
    db_session = object_session(self)
    if not db_session:
        raise RuntimeError(
            "Missing DB session to fetch step configurations."
        )

    # All criteria are ANDed together by `where(*criteria)`.
    criteria = [StepConfigurationSchema.deployment_id == self.id]
    if include:
        criteria.append(col(StepConfigurationSchema.name).in_(include))

    statement = (
        select(StepConfigurationSchema)
        .where(*criteria)
        .order_by(asc(StepConfigurationSchema.index))
    )
    rows = db_session.execute(statement).scalars().all()
    return list(rows)
177
209
 
178
210
  @classmethod
179
211
  def get_query_options(
@@ -230,14 +262,6 @@ class PipelineDeploymentSchema(BaseSchema, table=True):
230
262
  Returns:
231
263
  The created `PipelineDeploymentSchema`.
232
264
  """
233
- # Don't include the merged config in the step configurations, we
234
- # reconstruct it in the `to_model` method using the pipeline
235
- # configuration.
236
- step_configurations = {
237
- invocation_id: step.model_dump(mode="json", exclude={"config"})
238
- for invocation_id, step in request.step_configurations.items()
239
- }
240
-
241
265
  client_env = json.dumps(request.client_environment)
242
266
  if len(client_env) > TEXT_FIELD_MAX_LENGTH:
243
267
  logger.warning(
@@ -257,10 +281,7 @@ class PipelineDeploymentSchema(BaseSchema, table=True):
257
281
  code_reference_id=code_reference_id,
258
282
  run_name_template=request.run_name_template,
259
283
  pipeline_configuration=request.pipeline_configuration.model_dump_json(),
260
- step_configurations=json.dumps(
261
- step_configurations,
262
- sort_keys=False,
263
- ),
284
+ step_count=len(request.step_configurations),
264
285
  client_environment=client_env,
265
286
  client_version=request.client_version,
266
287
  server_version=request.server_version,
@@ -278,6 +299,7 @@ class PipelineDeploymentSchema(BaseSchema, table=True):
278
299
  include_metadata: bool = False,
279
300
  include_resources: bool = False,
280
301
  include_python_packages: bool = False,
302
+ step_configuration_filter: Optional[List[str]] = None,
281
303
  **kwargs: Any,
282
304
  ) -> PipelineDeploymentResponse:
283
305
  """Convert a `PipelineDeploymentSchema` to a `PipelineDeploymentResponse`.
@@ -286,6 +308,9 @@ class PipelineDeploymentSchema(BaseSchema, table=True):
286
308
  include_metadata: Whether the metadata will be filled.
287
309
  include_resources: Whether the resources will be filled.
288
310
  include_python_packages: Whether the python packages will be filled.
311
+ step_configuration_filter: List of step configurations to include in
312
+ the response. If not given, all step configurations will be
313
+ included.
289
314
  **kwargs: Keyword arguments to allow schema specific logic
290
315
 
291
316
 
@@ -303,10 +328,13 @@ class PipelineDeploymentSchema(BaseSchema, table=True):
303
328
  pipeline_configuration = PipelineConfiguration.model_validate_json(
304
329
  self.pipeline_configuration
305
330
  )
306
- step_configurations = json.loads(self.step_configurations)
307
- for invocation_id, step in step_configurations.items():
308
- step_configurations[invocation_id] = Step.from_dict(
309
- step, pipeline_configuration
331
+ step_configurations = {}
332
+ for step_configuration in self.get_step_configurations(
333
+ include=step_configuration_filter
334
+ ):
335
+ step_configurations[step_configuration.name] = Step.from_dict(
336
+ json.loads(step_configuration.config),
337
+ pipeline_configuration,
310
338
  )
311
339
 
312
340
  client_environment = json.loads(self.client_environment)
@@ -348,3 +376,36 @@ class PipelineDeploymentSchema(BaseSchema, table=True):
348
376
  metadata=metadata,
349
377
  resources=resources,
350
378
  )
379
+
380
+
381
class StepConfigurationSchema(BaseSchema, table=True):
    """SQL Model for step configurations.

    Each row holds the configuration of one step of a pipeline deployment.
    The (`deployment_id`, `name`) pair is unique, so a deployment cannot
    store two configurations under the same step name.
    """

    __tablename__ = "step_configuration"
    __table_args__ = (
        UniqueConstraint(
            "deployment_id",
            "name",
            name="unique_step_name_for_deployment",
        ),
    )

    # Position of the step within its deployment; step configurations are
    # read back ordered by this value.
    index: int
    # Step name, unique per deployment (see `__table_args__`).
    name: str
    # Step configuration serialized as JSON text; MEDIUMTEXT on MySQL, a
    # string sized to MEDIUMTEXT_MAX_LENGTH on other dialects.
    config: str = Field(
        sa_column=Column(
            String(length=MEDIUMTEXT_MAX_LENGTH).with_variant(
                MEDIUMTEXT, "mysql"
            ),
            nullable=False,
        )
    )

    # Owning deployment; rows are removed with it via ON DELETE CASCADE.
    deployment_id: UUID = build_foreign_key_field(
        source=__tablename__,
        target=PipelineDeploymentSchema.__tablename__,
        source_column="deployment_id",
        target_column="id",
        ondelete="CASCADE",
        nullable=False,
    )
@@ -389,19 +389,36 @@ class PipelineRunSchema(NamedSchema, RunMetadataInterface, table=True):
389
389
  RuntimeError: if the model creation fails.
390
390
  """
391
391
  if self.deployment is not None:
392
- deployment = self.deployment.to_model(
393
- include_metadata=True,
394
- include_python_packages=include_python_packages,
392
+ config = PipelineConfiguration.model_validate_json(
393
+ self.deployment.pipeline_configuration
395
394
  )
395
+ client_environment = json.loads(self.deployment.client_environment)
396
396
 
397
- config = deployment.pipeline_configuration
398
- client_environment = deployment.client_environment
399
-
400
- stack = deployment.stack
401
- pipeline = deployment.pipeline
402
- build = deployment.build
403
- schedule = deployment.schedule
404
- code_reference = deployment.code_reference
397
+ stack = (
398
+ self.deployment.stack.to_model()
399
+ if self.deployment.stack
400
+ else None
401
+ )
402
+ pipeline = (
403
+ self.deployment.pipeline.to_model()
404
+ if self.deployment.pipeline
405
+ else None
406
+ )
407
+ build = (
408
+ self.deployment.build.to_model()
409
+ if self.deployment.build
410
+ else None
411
+ )
412
+ schedule = (
413
+ self.deployment.schedule.to_model()
414
+ if self.deployment.schedule
415
+ else None
416
+ )
417
+ code_reference = (
418
+ self.deployment.code_reference.to_model()
419
+ if self.deployment.code_reference
420
+ else None
421
+ )
405
422
 
406
423
  elif self.pipeline_configuration is not None:
407
424
  config = PipelineConfiguration.model_validate_json(
@@ -290,10 +290,10 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True):
290
290
  """
291
291
  step = None
292
292
  if self.deployment is not None:
293
- step_configurations = json.loads(
294
- self.deployment.step_configurations
293
+ step_configurations = self.deployment.get_step_configurations(
294
+ include=[self.name]
295
295
  )
296
- if self.name in step_configurations:
296
+ if step_configurations:
297
297
  pipeline_configuration = (
298
298
  PipelineConfiguration.model_validate_json(
299
299
  self.deployment.pipeline_configuration
@@ -304,7 +304,7 @@ class StepRunSchema(NamedSchema, RunMetadataInterface, table=True):
304
304
  inplace=True,
305
305
  )
306
306
  step = Step.from_dict(
307
- step_configurations[self.name],
307
+ json.loads(step_configurations[0].config),
308
308
  pipeline_configuration=pipeline_configuration,
309
309
  )
310
310
  if not step and self.step_configuration: