zenml-nightly 0.68.1.dev20241111__py3-none-any.whl → 0.68.1.dev20241112__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,192 @@
1
+ """Add pipeline, model and run unique constraints [904464ea4041].
2
+
3
+ Revision ID: 904464ea4041
4
+ Revises: b557b2871693
5
+ Create Date: 2024-11-04 10:27:05.450092
6
+
7
+ """
8
+
9
+ from collections import defaultdict
10
+ from typing import Any, Dict, Set
11
+
12
+ import sqlalchemy as sa
13
+ from alembic import op
14
+
15
+ from zenml.logger import get_logger
16
+
17
+ logger = get_logger(__name__)  # reports entities renamed during upgrade
18
+
19
+ # Revision identifiers, used by Alembic to link this migration into history.
20
+ revision = "904464ea4041"  # ID of this migration (see module docstring)
21
+ down_revision = "b557b2871693"  # the revision this migration revises
22
+ branch_labels = None
23
+ depends_on = None
24
+
25
+
26
def _make_unique_name(name: str, id_: Any, taken: Set[str]) -> str:
    """Build a replacement name for a duplicated entity name.

    The preferred form is `<name>_<first 6 characters of the ID>`. That
    candidate may itself already be in use (by another row, or by another
    duplicate whose ID shares the same prefix). In that case a numeric
    counter is appended until the result is free.

    Args:
        name: The duplicated name.
        id_: ID of the row being renamed.
        taken: Names that must not be reused.

    Returns:
        A name that is not contained in `taken`.
    """
    prefix = f"{name}_{str(id_)[:6]}"
    candidate = prefix
    counter = 1
    while candidate in taken:
        candidate = f"{prefix}_{counter}"
        counter += 1
    return candidate


def _resolve_duplicate_names(
    connection: "sa.engine.Connection", table: "sa.Table"
) -> None:
    """Rename rows of `table` whose name is already used in their workspace.

    The first row seen with a given name keeps it; every later row with the
    same name in the same workspace is renamed.

    Args:
        connection: Connection used for the migration.
        table: Reflected table with `id`, `name` and `workspace_id` columns.
    """
    rows = connection.execute(
        sa.select(table.c.id, table.c.name, table.c.workspace_id)
    ).all()

    # All names present per workspace (plus names generated below). This
    # ensures a generated name never collides with any row, including rows
    # that have not been visited yet.
    taken: Dict[Any, Set[str]] = defaultdict(set)
    for _, name, workspace_id in rows:
        taken[workspace_id].add(name)

    # Names already kept by a previously visited row, per workspace.
    claimed: Dict[Any, Set[str]] = defaultdict(set)
    for id_, name, workspace_id in rows:
        if name not in claimed[workspace_id]:
            claimed[workspace_id].add(name)
            continue

        new_name = _make_unique_name(name, id_, taken[workspace_id])
        taken[workspace_id].add(new_name)
        logger.warning(
            "Migrating %s name from %s to %s to resolve duplicate name.",
            table.name,
            name,
            new_name,
        )
        connection.execute(
            sa.update(table).where(table.c.id == id_).values(name=new_name)
        )


def _resolve_duplicate_model_versions(
    connection: "sa.engine.Connection", table: "sa.Table"
) -> None:
    """Rename/renumber model versions whose name or number clashes.

    Args:
        connection: Connection used for the migration.
        table: Reflected model version table with `id`, `name`, `number`
            and `model_id` columns.
    """
    rows = connection.execute(
        sa.select(
            table.c.id,
            table.c.name,
            table.c.number,
            table.c.model_id,
        )
    ).all()

    existing_names: Dict[Any, Set[str]] = defaultdict(set)
    existing_numbers: Dict[Any, Set[int]] = defaultdict(set)
    needs_update = []

    # First pass: collect every name/number per model and find the rows
    # that repeat a name or number already seen for the same model.
    for id_, name, number, model_id in rows:
        names_for_model = existing_names[model_id]
        numbers_for_model = existing_numbers[model_id]

        needs_new_name = name in names_for_model
        needs_new_number = number in numbers_for_model

        if needs_new_name or needs_new_number:
            needs_update.append(
                (id_, name, number, model_id, needs_new_name, needs_new_number)
            )

        names_for_model.add(name)
        numbers_for_model.add(number)

    # Second pass: assign fresh values. Generated values are registered
    # immediately so that later rows cannot be given the same ones.
    for (
        id_,
        name,
        number,
        model_id,
        needs_new_name,
        needs_new_number,
    ) in needs_update:
        names_for_model = existing_names[model_id]
        numbers_for_model = existing_numbers[model_id]
        values: Dict[str, Any] = {}

        if str(number) == name:
            # Numeric versions keep name == str(number), so no matter which
            # of the two clashes, both are updated together.
            new_number = max(numbers_for_model) + 1
            # Skip numbers whose string form is already used as a name by an
            # explicitly named version of this model.
            while str(new_number) in names_for_model:
                new_number += 1
            values["number"] = new_number
            values["name"] = str(new_number)
        else:
            if needs_new_name:
                values["name"] = _make_unique_name(
                    name, id_, names_for_model
                )
            if needs_new_number:
                values["number"] = max(numbers_for_model) + 1

        if "name" in values:
            names_for_model.add(values["name"])
            logger.warning(
                "Migrating model version %s to %s to resolve duplicate name.",
                name,
                values["name"],
            )
        if "number" in values:
            numbers_for_model.add(values["number"])

        connection.execute(
            sa.update(table).where(table.c.id == id_).values(**values)
        )


def resolve_duplicate_entities() -> None:
    """Resolve duplicate entities.

    This renames:
    - pipeline runs, pipelines and models whose name is already used by
      another entity of the same type in the same workspace;
    - model versions whose name or number is already used by another version
      of the same model (these are also renumbered where needed).

    Generated names are guaranteed not to clash with any existing or
    previously generated name. The unique constraints created afterwards can
    therefore be added safely.
    """
    connection = op.get_bind()
    meta = sa.MetaData()
    meta.reflect(
        bind=connection,
        only=("pipeline_run", "pipeline", "model", "model_version"),
    )

    for table_name in ["pipeline_run", "pipeline", "model"]:
        _resolve_duplicate_names(connection, sa.Table(table_name, meta))

    _resolve_duplicate_model_versions(
        connection, sa.Table("model_version", meta)
    )
135
+
136
+
137
def upgrade() -> None:
    """Upgrade database schema and/or data, creating a new revision.

    Clashing names/numbers are resolved first. Otherwise, creating the
    unique constraints below would fail on tables that already contain
    duplicate rows.
    """
    resolve_duplicate_entities()

    # (table name, ((constraint name, columns), ...)) in creation order.
    new_constraints = (
        (
            "pipeline",
            (("unique_pipeline_name_in_workspace", ["name", "workspace_id"]),),
        ),
        (
            "pipeline_run",
            (("unique_run_name_in_workspace", ["name", "workspace_id"]),),
        ),
        (
            "model",
            (("unique_model_name_in_workspace", ["name", "workspace_id"]),),
        ),
        (
            "model_version",
            (
                ("unique_version_for_model_id", ["name", "model_id"]),
                ("unique_version_number_for_model_id", ["number", "model_id"]),
            ),
        ),
    )
    for table_name, constraints in new_constraints:
        # Batch mode keeps the operation portable to backends (e.g. SQLite)
        # that cannot add constraints via a plain ALTER TABLE.
        with op.batch_alter_table(table_name, schema=None) as batch_op:
            for constraint_name, columns in constraints:
                batch_op.create_unique_constraint(constraint_name, columns)
166
+
167
+
168
def downgrade() -> None:
    """Downgrade database schema and/or data back to the previous revision.

    Only the unique constraints are dropped, in the reverse order of their
    creation. Rows renamed during the upgrade keep their new names.
    """
    # (table name, (constraint name, ...)) in drop order.
    constraints_to_drop = (
        (
            "model_version",
            (
                "unique_version_number_for_model_id",
                "unique_version_for_model_id",
            ),
        ),
        ("model", ("unique_model_name_in_workspace",)),
        ("pipeline_run", ("unique_run_name_in_workspace",)),
        ("pipeline", ("unique_pipeline_name_in_workspace",)),
    )
    for table_name, constraint_names in constraints_to_drop:
        with op.batch_alter_table(table_name, schema=None) as batch_op:
            for constraint_name in constraint_names:
                batch_op.drop_constraint(constraint_name, type_="unique")
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, cast
19
19
  from uuid import UUID
20
20
 
21
21
  from pydantic import ConfigDict
22
- from sqlalchemy import BOOLEAN, INTEGER, TEXT, Column
22
+ from sqlalchemy import BOOLEAN, INTEGER, TEXT, Column, UniqueConstraint
23
23
  from sqlmodel import Field, Relationship
24
24
 
25
25
  from zenml.enums import MetadataResourceTypes, TaggableResourceTypes
@@ -62,6 +62,13 @@ class ModelSchema(NamedSchema, table=True):
62
62
  """SQL Model for model."""
63
63
 
64
64
  __tablename__ = "model"
65
+ __table_args__ = (
66
+ UniqueConstraint(
67
+ "name",
68
+ "workspace_id",
69
+ name="unique_model_name_in_workspace",
70
+ ),
71
+ )
65
72
 
66
73
  workspace_id: UUID = build_foreign_key_field(
67
74
  source=__tablename__,
@@ -220,6 +227,23 @@ class ModelVersionSchema(NamedSchema, table=True):
220
227
  """SQL Model for model version."""
221
228
 
222
229
  __tablename__ = MODEL_VERSION_TABLENAME
230
+ __table_args__ = (
231
+ # We need two unique constraints here:
232
+ # - The first to ensure that each model version for a
233
+ # model has a unique version number
234
+ # - The second one to ensure that explicit names given by
235
+ # users are unique
236
+ UniqueConstraint(
237
+ "number",
238
+ "model_id",
239
+ name="unique_version_number_for_model_id",
240
+ ),
241
+ UniqueConstraint(
242
+ "name",
243
+ "model_id",
244
+ name="unique_version_for_model_id",
245
+ ),
246
+ )
223
247
 
224
248
  workspace_id: UUID = build_foreign_key_field(
225
249
  source=__tablename__,
@@ -72,6 +72,11 @@ class PipelineRunSchema(NamedSchema, table=True):
72
72
  "orchestrator_run_id",
73
73
  name="unique_orchestrator_run_id_for_deployment_id",
74
74
  ),
75
+ UniqueConstraint(
76
+ "name",
77
+ "workspace_id",
78
+ name="unique_run_name_in_workspace",
79
+ ),
75
80
  )
76
81
 
77
82
  # Fields
@@ -17,7 +17,7 @@ from datetime import datetime
17
17
  from typing import TYPE_CHECKING, Any, List, Optional
18
18
  from uuid import UUID
19
19
 
20
- from sqlalchemy import TEXT, Column
20
+ from sqlalchemy import TEXT, Column, UniqueConstraint
21
21
  from sqlmodel import Field, Relationship
22
22
 
23
23
  from zenml.enums import TaggableResourceTypes
@@ -50,7 +50,13 @@ class PipelineSchema(NamedSchema, table=True):
50
50
  """SQL Model for pipelines."""
51
51
 
52
52
  __tablename__ = "pipeline"
53
-
53
+ __table_args__ = (
54
+ UniqueConstraint(
55
+ "name",
56
+ "workspace_id",
57
+ name="unique_pipeline_name_in_workspace",
58
+ ),
59
+ )
54
60
  # Fields
55
61
  description: Optional[str] = Field(sa_column=Column(TEXT, nullable=True))
56
62