zenml-nightly 0.71.0.dev20241212__py3-none-any.whl → 0.71.0.dev20241214__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. zenml/VERSION +1 -1
  2. zenml/artifacts/artifact_config.py +8 -5
  3. zenml/artifacts/utils.py +3 -1
  4. zenml/client.py +54 -2
  5. zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +0 -1
  6. zenml/model/model.py +12 -16
  7. zenml/model/utils.py +3 -1
  8. zenml/models/v2/base/filter.py +26 -30
  9. zenml/models/v2/base/scoped.py +258 -5
  10. zenml/models/v2/core/artifact_version.py +15 -26
  11. zenml/models/v2/core/code_repository.py +1 -12
  12. zenml/models/v2/core/component.py +5 -46
  13. zenml/models/v2/core/flavor.py +1 -11
  14. zenml/models/v2/core/model.py +1 -57
  15. zenml/models/v2/core/model_version.py +5 -33
  16. zenml/models/v2/core/model_version_artifact.py +11 -3
  17. zenml/models/v2/core/model_version_pipeline_run.py +14 -3
  18. zenml/models/v2/core/pipeline.py +47 -55
  19. zenml/models/v2/core/pipeline_build.py +67 -12
  20. zenml/models/v2/core/pipeline_deployment.py +0 -10
  21. zenml/models/v2/core/pipeline_run.py +91 -29
  22. zenml/models/v2/core/run_template.py +21 -29
  23. zenml/models/v2/core/schedule.py +0 -10
  24. zenml/models/v2/core/secret.py +0 -14
  25. zenml/models/v2/core/service.py +9 -16
  26. zenml/models/v2/core/service_connector.py +0 -11
  27. zenml/models/v2/core/stack.py +21 -30
  28. zenml/models/v2/core/step_run.py +18 -14
  29. zenml/models/v2/core/trigger.py +19 -3
  30. zenml/orchestrators/step_launcher.py +9 -13
  31. zenml/orchestrators/step_run_utils.py +8 -204
  32. zenml/pipelines/build_utils.py +12 -0
  33. zenml/zen_server/rbac/rbac_sql_zen_store.py +173 -0
  34. zenml/zen_server/utils.py +4 -3
  35. zenml/zen_stores/base_zen_store.py +10 -2
  36. zenml/zen_stores/migrations/versions/26351d482b9e_add_step_run_unique_constraint.py +37 -0
  37. zenml/zen_stores/migrations/versions/a1237ba94fd8_add_model_version_producer_run_unique_.py +68 -0
  38. zenml/zen_stores/schemas/model_schemas.py +42 -6
  39. zenml/zen_stores/schemas/pipeline_deployment_schemas.py +7 -7
  40. zenml/zen_stores/schemas/pipeline_schemas.py +5 -0
  41. zenml/zen_stores/schemas/step_run_schemas.py +8 -1
  42. zenml/zen_stores/sql_zen_store.py +327 -99
  43. {zenml_nightly-0.71.0.dev20241212.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/METADATA +1 -1
  44. {zenml_nightly-0.71.0.dev20241212.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/RECORD +47 -44
  45. {zenml_nightly-0.71.0.dev20241212.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/LICENSE +0 -0
  46. {zenml_nightly-0.71.0.dev20241212.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/WHEEL +0 -0
  47. {zenml_nightly-0.71.0.dev20241212.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/entry_points.txt +0 -0
@@ -14,7 +14,7 @@
14
14
  """Utilities for creating step runs."""
15
15
 
16
16
  from datetime import datetime
17
- from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
17
+ from typing import Dict, List, Optional, Set, Tuple
18
18
 
19
19
  from zenml.client import Client
20
20
  from zenml.config.step_configurations import Step
@@ -24,21 +24,13 @@ from zenml.logger import get_logger
24
24
  from zenml.model.utils import link_artifact_version_to_model_version
25
25
  from zenml.models import (
26
26
  ArtifactVersionResponse,
27
- ModelVersionPipelineRunRequest,
28
27
  ModelVersionResponse,
29
28
  PipelineDeploymentResponse,
30
29
  PipelineRunResponse,
31
- PipelineRunUpdate,
32
30
  StepRunRequest,
33
- StepRunResponse,
34
- StepRunUpdate,
35
31
  )
36
32
  from zenml.orchestrators import cache_utils, input_utils, utils
37
33
  from zenml.stack import Stack
38
- from zenml.utils import pagination_utils, string_utils
39
-
40
- if TYPE_CHECKING:
41
- from zenml.model.model import Model
42
34
 
43
35
  logger = get_logger(__name__)
44
36
 
@@ -293,10 +285,6 @@ def create_cached_step_runs(
293
285
  deployment=deployment, pipeline_run=pipeline_run, stack=stack
294
286
  )
295
287
 
296
- pipeline_model_version, pipeline_run = prepare_pipeline_run_model_version(
297
- pipeline_run=pipeline_run
298
- )
299
-
300
288
  while (
301
289
  cache_candidates := find_cacheable_invocation_candidates(
302
290
  deployment=deployment,
@@ -311,7 +299,9 @@ def create_cached_step_runs(
311
299
 
312
300
  # Make sure the request factory has the most up to date pipeline
313
301
  # run to avoid hydration calls
314
- request_factory.pipeline_run = pipeline_run
302
+ request_factory.pipeline_run = Client().get_pipeline_run(
303
+ pipeline_run.id
304
+ )
315
305
  try:
316
306
  step_run_request = request_factory.create_request(
317
307
  invocation_id
@@ -336,15 +326,10 @@ def create_cached_step_runs(
336
326
 
337
327
  step_run = Client().zen_store.create_run_step(step_run_request)
338
328
 
339
- # Refresh the pipeline run here to make sure we have the latest
340
- # state
341
- pipeline_run = Client().get_pipeline_run(pipeline_run.id)
342
-
343
- step_model_version, step_run = prepare_step_run_model_version(
344
- step_run=step_run, pipeline_run=pipeline_run
345
- )
346
-
347
- if model_version := step_model_version or pipeline_model_version:
329
+ if (
330
+ model_version := step_run.model_version
331
+ or pipeline_run.model_version
332
+ ):
348
333
  link_output_artifacts_to_model_version(
349
334
  artifacts=step_run.outputs,
350
335
  model_version=model_version,
@@ -356,169 +341,6 @@ def create_cached_step_runs(
356
341
  return cached_invocations
357
342
 
358
343
 
359
- def get_or_create_model_version_for_pipeline_run(
360
- model: "Model",
361
- pipeline_run: PipelineRunResponse,
362
- substitutions: Dict[str, str],
363
- ) -> Tuple[ModelVersionResponse, bool]:
364
- """Get or create a model version as part of a pipeline run.
365
-
366
- Args:
367
- model: The model to get or create.
368
- pipeline_run: The pipeline run for which the model should be created.
369
- substitutions: Substitutions to apply to the model version name.
370
-
371
- Returns:
372
- The model version and a boolean indicating whether it was newly created
373
- or not.
374
- """
375
- # Copy the model before modifying it so we don't accidently modify
376
- # configurations in which the model object is potentially referenced
377
- model = model.model_copy()
378
-
379
- if model.model_version_id:
380
- return model._get_model_version(), False
381
- elif model.version:
382
- if isinstance(model.version, str):
383
- model.version = string_utils.format_name_template(
384
- model.version,
385
- substitutions=substitutions,
386
- )
387
- model.name = string_utils.format_name_template(
388
- model.name,
389
- substitutions=substitutions,
390
- )
391
-
392
- return (
393
- model._get_or_create_model_version(),
394
- model._created_model_version,
395
- )
396
-
397
- # The model version should be created as part of this run
398
- # -> We first check if it was already created as part of this run, and if
399
- # not we do create it. If this is running in two parallel steps, we might
400
- # run into issues that this will create two versions. Ideally, all model
401
- # versions required for a pipeline run and its steps could be created
402
- # server-side at run creation time before the first step starts.
403
- if model_version := get_model_version_created_by_pipeline_run(
404
- model_name=model.name, pipeline_run=pipeline_run
405
- ):
406
- return model_version, False
407
- else:
408
- return model._get_or_create_model_version(), True
409
-
410
-
411
- def get_model_version_created_by_pipeline_run(
412
- model_name: str, pipeline_run: PipelineRunResponse
413
- ) -> Optional[ModelVersionResponse]:
414
- """Get a model version that was created by a specific pipeline run.
415
-
416
- This function does not refresh the pipeline run, so it will only try to
417
- fetch the model version from existing steps if they're already part of the
418
- response.
419
-
420
- Args:
421
- model_name: The model name for which to get the version.
422
- pipeline_run: The pipeline run for which to get the version.
423
-
424
- Returns:
425
- A model version with the given name created by the run, or None if such
426
- a model version does not exist.
427
- """
428
- if pipeline_run.config.model and pipeline_run.model_version:
429
- if (
430
- pipeline_run.config.model.name == model_name
431
- and pipeline_run.config.model.version is None
432
- ):
433
- return pipeline_run.model_version
434
-
435
- # We fetch a list of hydrated step runs here in order to avoid hydration
436
- # calls for each step separately.
437
- candidate_step_runs = pagination_utils.depaginate(
438
- Client().list_run_steps,
439
- pipeline_run_id=pipeline_run.id,
440
- model=model_name,
441
- hydrate=True,
442
- )
443
- for step_run in candidate_step_runs:
444
- if step_run.config.model and step_run.model_version:
445
- if (
446
- step_run.config.model.name == model_name
447
- and step_run.config.model.version is None
448
- ):
449
- return step_run.model_version
450
-
451
- return None
452
-
453
-
454
- def prepare_pipeline_run_model_version(
455
- pipeline_run: PipelineRunResponse,
456
- ) -> Tuple[Optional[ModelVersionResponse], PipelineRunResponse]:
457
- """Prepare the model version for a pipeline run.
458
-
459
- Args:
460
- pipeline_run: The pipeline run for which to prepare the model version.
461
-
462
- Returns:
463
- The prepared model version and the updated pipeline run.
464
- """
465
- model_version = None
466
-
467
- if pipeline_run.model_version:
468
- model_version = pipeline_run.model_version
469
- elif config_model := pipeline_run.config.model:
470
- model_version, _ = get_or_create_model_version_for_pipeline_run(
471
- model=config_model,
472
- pipeline_run=pipeline_run,
473
- substitutions=pipeline_run.config.substitutions,
474
- )
475
- pipeline_run = Client().zen_store.update_run(
476
- run_id=pipeline_run.id,
477
- run_update=PipelineRunUpdate(model_version_id=model_version.id),
478
- )
479
- link_pipeline_run_to_model_version(
480
- pipeline_run=pipeline_run, model_version=model_version
481
- )
482
- log_model_version_dashboard_url(model_version)
483
-
484
- return model_version, pipeline_run
485
-
486
-
487
- def prepare_step_run_model_version(
488
- step_run: StepRunResponse, pipeline_run: PipelineRunResponse
489
- ) -> Tuple[Optional[ModelVersionResponse], StepRunResponse]:
490
- """Prepare the model version for a step run.
491
-
492
- Args:
493
- step_run: The step run for which to prepare the model version.
494
- pipeline_run: The pipeline run of the step.
495
-
496
- Returns:
497
- The prepared model version and the updated step run.
498
- """
499
- model_version = None
500
-
501
- if step_run.model_version:
502
- model_version = step_run.model_version
503
- elif config_model := step_run.config.model:
504
- model_version, created = get_or_create_model_version_for_pipeline_run(
505
- model=config_model,
506
- pipeline_run=pipeline_run,
507
- substitutions=step_run.config.substitutions,
508
- )
509
- step_run = Client().zen_store.update_run_step(
510
- step_run_id=step_run.id,
511
- step_run_update=StepRunUpdate(model_version_id=model_version.id),
512
- )
513
- link_pipeline_run_to_model_version(
514
- pipeline_run=pipeline_run, model_version=model_version
515
- )
516
- if created:
517
- log_model_version_dashboard_url(model_version)
518
-
519
- return model_version, step_run
520
-
521
-
522
344
  def log_model_version_dashboard_url(
523
345
  model_version: ModelVersionResponse,
524
346
  ) -> None:
@@ -546,24 +368,6 @@ def log_model_version_dashboard_url(
546
368
  )
547
369
 
548
370
 
549
- def link_pipeline_run_to_model_version(
550
- pipeline_run: PipelineRunResponse, model_version: ModelVersionResponse
551
- ) -> None:
552
- """Link a pipeline run to a model version.
553
-
554
- Args:
555
- pipeline_run: The pipeline run to link.
556
- model_version: The model version to link.
557
- """
558
- client = Client()
559
- client.zen_store.create_model_version_pipeline_run_link(
560
- ModelVersionPipelineRunRequest(
561
- pipeline_run=pipeline_run.id,
562
- model_version=model_version.id,
563
- )
564
- )
565
-
566
-
567
371
  def link_output_artifacts_to_model_version(
568
372
  artifacts: Dict[str, List[ArtifactVersionResponse]],
569
373
  model_version: ModelVersionResponse,
@@ -249,6 +249,11 @@ def find_existing_build(
249
249
  client = Client()
250
250
  stack = client.active_stack
251
251
 
252
+ if not stack.container_registry:
253
+ # There can be no non-local builds that we can reuse if there is no
254
+ # container registry in the stack.
255
+ return None
256
+
252
257
  python_version_prefix = ".".join(platform.python_version_tuple()[:2])
253
258
  required_builds = stack.get_docker_builds(deployment=deployment)
254
259
 
@@ -263,6 +268,13 @@ def find_existing_build(
263
268
  sort_by="desc:created",
264
269
  size=1,
265
270
  stack_id=stack.id,
271
+ # Until we implement stack versioning, users can still update their
272
+ # stack to update/remove the container registry. In that case, we might
273
+ # try to pull an image from a container registry that we don't have
274
+ # access to. This is why we add an additional check for the container
275
+ # registry ID here. (This is still not perfect as users can update the
276
+ # container registry URI or config, but the best we can do)
277
+ container_registry_id=stack.container_registry.id,
266
278
  # The build is local and it's not clear whether the images
267
279
  # exist on the current machine or if they've been overwritten.
268
280
  # TODO: Should we support this by storing the unique Docker ID for
@@ -0,0 +1,173 @@
1
+ # Copyright (c) ZenML GmbH 2024. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at:
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
12
+ # or implied. See the License for the specific language governing
13
+ # permissions and limitations under the License.
14
+ """RBAC SQL Zen Store implementation."""
15
+
16
+ from typing import (
17
+ Optional,
18
+ Tuple,
19
+ )
20
+ from uuid import UUID
21
+
22
+ from zenml.logger import get_logger
23
+ from zenml.models import (
24
+ ModelRequest,
25
+ ModelResponse,
26
+ ModelVersionRequest,
27
+ ModelVersionResponse,
28
+ )
29
+ from zenml.zen_server.feature_gate.endpoint_utils import (
30
+ check_entitlement,
31
+ report_usage,
32
+ )
33
+ from zenml.zen_server.rbac.models import Action, ResourceType
34
+ from zenml.zen_server.rbac.utils import (
35
+ verify_permission,
36
+ verify_permission_for_model,
37
+ )
38
+ from zenml.zen_stores.sql_zen_store import SqlZenStore
39
+
40
+ logger = get_logger(__name__)
41
+
42
+
43
class RBACSqlZenStore(SqlZenStore):
    """SQL Zen Store subclass that adds RBAC checks to model operations.

    The get-or-create helpers are overridden so that a caller who is not
    allowed to create a model (version) can still fetch one that already
    exists, and only sees the creation error if it does not exist.
    """

    def _get_or_create_model(
        self, model_request: ModelRequest
    ) -> Tuple[bool, ModelResponse]:
        """Fetch a model, creating it first if the caller is allowed to.

        Args:
            model_request: The model request.

        # noqa: DAR401
        Raises:
            Exception: If the model does not exist and the user is not
                allowed to create it. The error captured during the
                permission/entitlement check is the one that is raised.

        Returns:
            A boolean whether the model was created or not, and the model.
        """
        # Remember why creation is denied (if it is) instead of failing
        # right away: the model might already exist and only need reading.
        creation_error: Optional[Exception] = None
        try:
            verify_permission(
                resource_type=ResourceType.MODEL, action=Action.CREATE
            )
            check_entitlement(resource_type=ResourceType.MODEL)
        except Exception as exc:
            creation_error = exc

        if creation_error is None:
            was_created, model = super()._get_or_create_model(model_request)
        else:
            try:
                model = self.get_model(model_request.name)
            except KeyError:
                # The model does not exist. Surface the reason it could not
                # be created rather than the lookup failure itself.
                assert creation_error
                raise creation_error from None
            was_created = False

        if was_created:
            report_usage(
                resource_type=ResourceType.MODEL, resource_id=model.id
            )
        else:
            # An existing model was returned, so the caller needs read access.
            verify_permission_for_model(model, action=Action.READ)

        return was_created, model

    def _get_model_version(
        self,
        model_id: UUID,
        version_name: Optional[str] = None,
        producer_run_id: Optional[UUID] = None,
    ) -> ModelVersionResponse:
        """Fetch a model version and verify read access to it.

        Args:
            model_id: The ID of the model.
            version_name: The name of the model version.
            producer_run_id: The ID of the producer pipeline run. If this is
                set, only numeric versions created as part of the pipeline run
                will be returned.

        Returns:
            The model version.
        """
        fetched_version = super()._get_model_version(
            model_id=model_id,
            version_name=version_name,
            producer_run_id=producer_run_id,
        )
        verify_permission_for_model(fetched_version, action=Action.READ)
        return fetched_version

    def _get_or_create_model_version(
        self,
        model_version_request: ModelVersionRequest,
        producer_run_id: Optional[UUID] = None,
    ) -> Tuple[bool, ModelVersionResponse]:
        """Fetch a model version, creating it first if the caller may do so.

        Args:
            model_version_request: The model version request.
            producer_run_id: ID of the producer pipeline run.

        # noqa: DAR401
        Raises:
            Exception: If the model version does not exist and the
                authenticated user is not allowed to create it.

        Returns:
            A boolean whether the model version was created or not, and the
            model version.
        """
        # Same approach as for models: keep the denial around and only raise
        # it if there is no existing version to fall back to.
        creation_error: Optional[Exception] = None
        try:
            verify_permission(
                resource_type=ResourceType.MODEL_VERSION, action=Action.CREATE
            )
        except Exception as exc:
            creation_error = exc

        if creation_error is None:
            was_created, version = super()._get_or_create_model_version(
                model_version_request, producer_run_id=producer_run_id
            )
        else:
            try:
                # Goes through the override in this class, which also
                # verifies read access to the returned version.
                version = self._get_model_version(
                    model_id=model_version_request.model,
                    version_name=model_version_request.name,
                    producer_run_id=producer_run_id,
                )
            except KeyError:
                # The version does not exist. Surface the reason it could
                # not be created rather than the lookup failure itself.
                assert creation_error
                raise creation_error from None
            was_created = False

        return was_created, version
zenml/zen_server/utils.py CHANGED
@@ -421,6 +421,8 @@ def make_dependable(cls: Type[BaseModel]) -> Callable[..., Any]:
421
421
  """
422
422
  from fastapi import Query
423
423
 
424
+ from zenml.zen_server.exceptions import error_detail
425
+
424
426
  def init_cls_and_handle_errors(*args: Any, **kwargs: Any) -> BaseModel:
425
427
  from fastapi import HTTPException
426
428
 
@@ -428,9 +430,8 @@ def make_dependable(cls: Type[BaseModel]) -> Callable[..., Any]:
428
430
  inspect.signature(init_cls_and_handle_errors).bind(*args, **kwargs)
429
431
  return cls(*args, **kwargs)
430
432
  except ValidationError as e:
431
- for error in e.errors():
432
- error["loc"] = tuple(["query"] + list(error["loc"]))
433
- raise HTTPException(422, detail=e.errors())
433
+ detail = error_detail(e, exception_type=ValueError)
434
+ raise HTTPException(422, detail=detail)
434
435
 
435
436
  params = {v.name: v for v in inspect.signature(cls).parameters.values()}
436
437
  query_params = getattr(cls, "API_MULTI_INPUT_PARAMS", [])
@@ -36,6 +36,7 @@ from zenml.constants import (
36
36
  DEFAULT_STACK_AND_COMPONENT_NAME,
37
37
  DEFAULT_WORKSPACE_NAME,
38
38
  ENV_ZENML_DEFAULT_WORKSPACE_NAME,
39
+ ENV_ZENML_SERVER,
39
40
  IS_DEBUG_ENV,
40
41
  )
41
42
  from zenml.enums import (
@@ -155,9 +156,16 @@ class BaseZenStore(
155
156
  TypeError: If the store type is unsupported.
156
157
  """
157
158
  if store_type == StoreType.SQL:
158
- from zenml.zen_stores.sql_zen_store import SqlZenStore
159
+ if os.environ.get(ENV_ZENML_SERVER):
160
+ from zenml.zen_server.rbac.rbac_sql_zen_store import (
161
+ RBACSqlZenStore,
162
+ )
163
+
164
+ return RBACSqlZenStore
165
+ else:
166
+ from zenml.zen_stores.sql_zen_store import SqlZenStore
159
167
 
160
- return SqlZenStore
168
+ return SqlZenStore
161
169
  elif store_type == StoreType.REST:
162
170
  from zenml.zen_stores.rest_zen_store import RestZenStore
163
171
 
@@ -0,0 +1,37 @@
1
+ """Add step run unique constraint [26351d482b9e].
2
+
3
+ Revision ID: 26351d482b9e
4
+ Revises: 0.71.0
5
+ Create Date: 2024-12-03 11:46:57.541578
6
+
7
+ """
8
+
9
+ from alembic import op
10
+
11
+ # revision identifiers, used by Alembic.
12
+ revision = "26351d482b9e"
13
+ down_revision = "0.71.0"
14
+ branch_labels = None
15
+ depends_on = None
16
+
17
+
18
def upgrade() -> None:
    """Upgrade database schema and/or data, creating a new revision.

    Adds the `unique_step_name_for_pipeline_run` unique constraint over the
    `name` and `pipeline_run_id` columns of the `step_run` table.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    constrained_columns = ["name", "pipeline_run_id"]
    with op.batch_alter_table("step_run", schema=None) as step_run_table:
        step_run_table.create_unique_constraint(
            "unique_step_name_for_pipeline_run", constrained_columns
        )
    # ### end Alembic commands ###
27
+
28
+
29
def downgrade() -> None:
    """Downgrade database schema and/or data back to the previous revision.

    Drops the `unique_step_name_for_pipeline_run` unique constraint from the
    `step_run` table.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    constraint_name = "unique_step_name_for_pipeline_run"
    with op.batch_alter_table("step_run", schema=None) as step_run_table:
        step_run_table.drop_constraint(constraint_name, type_="unique")
    # ### end Alembic commands ###
@@ -0,0 +1,68 @@
1
+ """Add model version producer run unique constraint [a1237ba94fd8].
2
+
3
+ Revision ID: a1237ba94fd8
4
+ Revises: 26351d482b9e
5
+ Create Date: 2024-12-13 10:28:55.432414
6
+
7
+ """
8
+
9
+ import sqlalchemy as sa
10
+ import sqlmodel
11
+ from alembic import op
12
+
13
+ # revision identifiers, used by Alembic.
14
+ revision = "a1237ba94fd8"
15
+ down_revision = "26351d482b9e"
16
+ branch_labels = None
17
+ depends_on = None
18
+
19
+
20
def upgrade() -> None:
    """Upgrade database schema and/or data, creating a new revision.

    Introduces the `producer_run_id_if_numeric` column on `model_version` in
    three phases: add it as nullable, fill it for all existing rows, then
    make it non-nullable and cover it (together with `model_id`) with a
    unique constraint.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # Phase 1: the column starts out nullable so existing rows remain valid.
    producer_run_column = sa.Column(
        "producer_run_id_if_numeric",
        sqlmodel.sql.sqltypes.GUID(),
        nullable=True,
    )
    with op.batch_alter_table("model_version", schema=None) as batch_op:
        batch_op.add_column(producer_run_column)

    # Phase 2: existing rows get their own model version ID as the value of
    # the new column.
    bind = op.get_bind()
    reflected = sa.MetaData()
    reflected.reflect(only=("model_version",), bind=bind)
    model_versions = sa.Table("model_version", reflected)

    backfill = model_versions.update().values(
        producer_run_id_if_numeric=model_versions.c.id
    )
    bind.execute(backfill)

    # Phase 3: every row has a value now, so the column can be tightened and
    # the unique constraint can be created.
    with op.batch_alter_table("model_version", schema=None) as batch_op:
        batch_op.alter_column(
            "producer_run_id_if_numeric",
            existing_type=sqlmodel.sql.sqltypes.GUID(),
            nullable=False,
        )
        batch_op.create_unique_constraint(
            "unique_numeric_version_for_pipeline_run",
            ["model_id", "producer_run_id_if_numeric"],
        )
    # ### end Alembic commands ###
57
+
58
+
59
def downgrade() -> None:
    """Downgrade database schema and/or data back to the previous revision.

    Removes the `unique_numeric_version_for_pipeline_run` constraint and then
    the `producer_run_id_if_numeric` column from the `model_version` table.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    constraint_name = "unique_numeric_version_for_pipeline_run"
    with op.batch_alter_table("model_version", schema=None) as batch_op:
        # The constraint references the column, so it is dropped first.
        batch_op.drop_constraint(constraint_name, type_="unique")
        batch_op.drop_column("producer_run_id_if_numeric")
    # ### end Alembic commands ###