zenml-nightly 0.71.0.dev20241212__py3-none-any.whl → 0.71.0.dev20241214__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/artifacts/artifact_config.py +8 -5
- zenml/artifacts/utils.py +3 -1
- zenml/client.py +54 -2
- zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +0 -1
- zenml/model/model.py +12 -16
- zenml/model/utils.py +3 -1
- zenml/models/v2/base/filter.py +26 -30
- zenml/models/v2/base/scoped.py +258 -5
- zenml/models/v2/core/artifact_version.py +15 -26
- zenml/models/v2/core/code_repository.py +1 -12
- zenml/models/v2/core/component.py +5 -46
- zenml/models/v2/core/flavor.py +1 -11
- zenml/models/v2/core/model.py +1 -57
- zenml/models/v2/core/model_version.py +5 -33
- zenml/models/v2/core/model_version_artifact.py +11 -3
- zenml/models/v2/core/model_version_pipeline_run.py +14 -3
- zenml/models/v2/core/pipeline.py +47 -55
- zenml/models/v2/core/pipeline_build.py +67 -12
- zenml/models/v2/core/pipeline_deployment.py +0 -10
- zenml/models/v2/core/pipeline_run.py +91 -29
- zenml/models/v2/core/run_template.py +21 -29
- zenml/models/v2/core/schedule.py +0 -10
- zenml/models/v2/core/secret.py +0 -14
- zenml/models/v2/core/service.py +9 -16
- zenml/models/v2/core/service_connector.py +0 -11
- zenml/models/v2/core/stack.py +21 -30
- zenml/models/v2/core/step_run.py +18 -14
- zenml/models/v2/core/trigger.py +19 -3
- zenml/orchestrators/step_launcher.py +9 -13
- zenml/orchestrators/step_run_utils.py +8 -204
- zenml/pipelines/build_utils.py +12 -0
- zenml/zen_server/rbac/rbac_sql_zen_store.py +173 -0
- zenml/zen_server/utils.py +4 -3
- zenml/zen_stores/base_zen_store.py +10 -2
- zenml/zen_stores/migrations/versions/26351d482b9e_add_step_run_unique_constraint.py +37 -0
- zenml/zen_stores/migrations/versions/a1237ba94fd8_add_model_version_producer_run_unique_.py +68 -0
- zenml/zen_stores/schemas/model_schemas.py +42 -6
- zenml/zen_stores/schemas/pipeline_deployment_schemas.py +7 -7
- zenml/zen_stores/schemas/pipeline_schemas.py +5 -0
- zenml/zen_stores/schemas/step_run_schemas.py +8 -1
- zenml/zen_stores/sql_zen_store.py +327 -99
- {zenml_nightly-0.71.0.dev20241212.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/METADATA +1 -1
- {zenml_nightly-0.71.0.dev20241212.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/RECORD +47 -44
- {zenml_nightly-0.71.0.dev20241212.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.71.0.dev20241212.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.71.0.dev20241212.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/entry_points.txt +0 -0
@@ -55,7 +55,7 @@ from pydantic import (
|
|
55
55
|
field_validator,
|
56
56
|
model_validator,
|
57
57
|
)
|
58
|
-
from sqlalchemy import
|
58
|
+
from sqlalchemy import func
|
59
59
|
from sqlalchemy.engine import URL, Engine, make_url
|
60
60
|
from sqlalchemy.exc import (
|
61
61
|
ArgumentError,
|
@@ -71,6 +71,7 @@ from sqlmodel import (
|
|
71
71
|
col,
|
72
72
|
create_engine,
|
73
73
|
delete,
|
74
|
+
desc,
|
74
75
|
or_,
|
75
76
|
select,
|
76
77
|
)
|
@@ -100,7 +101,6 @@ from zenml.constants import (
|
|
100
101
|
ENV_ZENML_SERVER,
|
101
102
|
FINISHED_ONBOARDING_SURVEY_KEY,
|
102
103
|
MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION,
|
103
|
-
SORT_PIPELINES_BY_LATEST_RUN_KEY,
|
104
104
|
SQL_STORE_BACKUP_DIRECTORY_NAME,
|
105
105
|
TEXT_FIELD_MAX_LENGTH,
|
106
106
|
handle_bool_env_var,
|
@@ -117,7 +117,6 @@ from zenml.enums import (
|
|
117
117
|
OnboardingStep,
|
118
118
|
SecretScope,
|
119
119
|
SecretsStoreType,
|
120
|
-
SorterOps,
|
121
120
|
StackComponentType,
|
122
121
|
StackDeploymentProvider,
|
123
122
|
StepRunInputArtifactType,
|
@@ -298,7 +297,11 @@ from zenml.utils.networking_utils import (
|
|
298
297
|
replace_localhost_with_internal_hostname,
|
299
298
|
)
|
300
299
|
from zenml.utils.pydantic_utils import before_validator_handler
|
301
|
-
from zenml.utils.string_utils import
|
300
|
+
from zenml.utils.string_utils import (
|
301
|
+
format_name_template,
|
302
|
+
random_str,
|
303
|
+
validate_name,
|
304
|
+
)
|
302
305
|
from zenml.zen_stores import template_utils
|
303
306
|
from zenml.zen_stores.base_zen_store import (
|
304
307
|
BaseZenStore,
|
@@ -4358,69 +4361,14 @@ class SqlZenStore(BaseZenStore):
|
|
4358
4361
|
Returns:
|
4359
4362
|
A list of all pipelines matching the filter criteria.
|
4360
4363
|
"""
|
4361
|
-
query: Union[Select[Any], SelectOfScalar[Any]] = select(PipelineSchema)
|
4362
|
-
_custom_conversion: Optional[Callable[[Any], PipelineResponse]] = None
|
4363
|
-
|
4364
|
-
column, operand = pipeline_filter_model.sorting_params
|
4365
|
-
if column == SORT_PIPELINES_BY_LATEST_RUN_KEY:
|
4366
|
-
with Session(self.engine) as session:
|
4367
|
-
max_date_subquery = (
|
4368
|
-
# If no run exists for the pipeline yet, we use the pipeline
|
4369
|
-
# creation date as a fallback, otherwise newly created
|
4370
|
-
# pipeline would always be at the top/bottom
|
4371
|
-
select(
|
4372
|
-
PipelineSchema.id,
|
4373
|
-
case(
|
4374
|
-
(
|
4375
|
-
func.max(PipelineRunSchema.created).is_(None),
|
4376
|
-
PipelineSchema.created,
|
4377
|
-
),
|
4378
|
-
else_=func.max(PipelineRunSchema.created),
|
4379
|
-
).label("run_or_created"),
|
4380
|
-
)
|
4381
|
-
.outerjoin(
|
4382
|
-
PipelineRunSchema,
|
4383
|
-
PipelineSchema.id == PipelineRunSchema.pipeline_id, # type: ignore[arg-type]
|
4384
|
-
)
|
4385
|
-
.group_by(col(PipelineSchema.id))
|
4386
|
-
.subquery()
|
4387
|
-
)
|
4388
|
-
|
4389
|
-
if operand == SorterOps.DESCENDING:
|
4390
|
-
sort_clause = desc
|
4391
|
-
else:
|
4392
|
-
sort_clause = asc
|
4393
|
-
|
4394
|
-
query = (
|
4395
|
-
# We need to include the subquery in the select here to
|
4396
|
-
# make this query work with the distinct statement. This
|
4397
|
-
# result will be removed in the custom conversion function
|
4398
|
-
# applied later
|
4399
|
-
select(PipelineSchema, max_date_subquery.c.run_or_created)
|
4400
|
-
.where(PipelineSchema.id == max_date_subquery.c.id)
|
4401
|
-
.order_by(sort_clause(max_date_subquery.c.run_or_created))
|
4402
|
-
# We always add the `id` column as a tiebreaker to ensure a
|
4403
|
-
# stable, repeatable order of items, otherwise subsequent
|
4404
|
-
# pages might contain the same items.
|
4405
|
-
.order_by(col(PipelineSchema.id))
|
4406
|
-
)
|
4407
|
-
|
4408
|
-
def _custom_conversion(row: Any) -> PipelineResponse:
|
4409
|
-
return cast(
|
4410
|
-
PipelineResponse,
|
4411
|
-
row[0].to_model(
|
4412
|
-
include_metadata=hydrate, include_resources=True
|
4413
|
-
),
|
4414
|
-
)
|
4415
|
-
|
4416
4364
|
with Session(self.engine) as session:
|
4365
|
+
query = select(PipelineSchema)
|
4417
4366
|
return self.filter_and_paginate(
|
4418
4367
|
session=session,
|
4419
4368
|
query=query,
|
4420
4369
|
table=PipelineSchema,
|
4421
4370
|
filter_model=pipeline_filter_model,
|
4422
4371
|
hydrate=hydrate,
|
4423
|
-
custom_schema_to_model_conversion=_custom_conversion,
|
4424
4372
|
)
|
4425
4373
|
|
4426
4374
|
def count_pipelines(self, filter_model: Optional[PipelineFilter]) -> int:
|
@@ -5211,6 +5159,20 @@ class SqlZenStore(BaseZenStore):
|
|
5211
5159
|
"already exists."
|
5212
5160
|
)
|
5213
5161
|
|
5162
|
+
if model_version_id := self._get_or_create_model_version_for_run(
|
5163
|
+
new_run
|
5164
|
+
):
|
5165
|
+
new_run.model_version_id = model_version_id
|
5166
|
+
session.add(new_run)
|
5167
|
+
session.commit()
|
5168
|
+
|
5169
|
+
self.create_model_version_pipeline_run_link(
|
5170
|
+
ModelVersionPipelineRunRequest(
|
5171
|
+
model_version=model_version_id, pipeline_run=new_run.id
|
5172
|
+
)
|
5173
|
+
)
|
5174
|
+
session.refresh(new_run)
|
5175
|
+
|
5214
5176
|
return new_run.to_model(
|
5215
5177
|
include_metadata=True, include_resources=True
|
5216
5178
|
)
|
@@ -8167,25 +8129,17 @@ class SqlZenStore(BaseZenStore):
|
|
8167
8129
|
f"with ID '{step_run.pipeline_run_id}' found."
|
8168
8130
|
)
|
8169
8131
|
|
8170
|
-
|
8171
|
-
|
8172
|
-
|
8173
|
-
.
|
8174
|
-
|
8175
|
-
StepRunSchema.pipeline_run_id == step_run.pipeline_run_id
|
8176
|
-
)
|
8177
|
-
).first()
|
8178
|
-
if existing_step_run is not None:
|
8132
|
+
step_schema = StepRunSchema.from_request(step_run)
|
8133
|
+
session.add(step_schema)
|
8134
|
+
try:
|
8135
|
+
session.commit()
|
8136
|
+
except IntegrityError:
|
8179
8137
|
raise EntityExistsError(
|
8180
8138
|
f"Unable to create step `{step_run.name}`: A step with "
|
8181
8139
|
f"this name already exists in the pipeline run with ID "
|
8182
8140
|
f"'{step_run.pipeline_run_id}'."
|
8183
8141
|
)
|
8184
8142
|
|
8185
|
-
# Create the step
|
8186
|
-
step_schema = StepRunSchema.from_request(step_run)
|
8187
|
-
session.add(step_schema)
|
8188
|
-
|
8189
8143
|
# Add logs entry for the step if exists
|
8190
8144
|
if step_run.logs is not None:
|
8191
8145
|
log_entry = LogsSchema(
|
@@ -8281,6 +8235,21 @@ class SqlZenStore(BaseZenStore):
|
|
8281
8235
|
session.commit()
|
8282
8236
|
session.refresh(step_schema)
|
8283
8237
|
|
8238
|
+
if model_version_id := self._get_or_create_model_version_for_run(
|
8239
|
+
step_schema
|
8240
|
+
):
|
8241
|
+
step_schema.model_version_id = model_version_id
|
8242
|
+
session.add(step_schema)
|
8243
|
+
session.commit()
|
8244
|
+
|
8245
|
+
self.create_model_version_pipeline_run_link(
|
8246
|
+
ModelVersionPipelineRunRequest(
|
8247
|
+
model_version=model_version_id,
|
8248
|
+
pipeline_run=step_schema.pipeline_run_id,
|
8249
|
+
)
|
8250
|
+
)
|
8251
|
+
session.refresh(step_schema)
|
8252
|
+
|
8284
8253
|
return step_schema.to_model(
|
8285
8254
|
include_metadata=True, include_resources=True
|
8286
8255
|
)
|
@@ -10283,6 +10252,22 @@ class SqlZenStore(BaseZenStore):
|
|
10283
10252
|
|
10284
10253
|
# ----------------------------- Model Versions -----------------------------
|
10285
10254
|
|
10255
|
+
def _get_or_create_model(
|
10256
|
+
self, model_request: ModelRequest
|
10257
|
+
) -> Tuple[bool, ModelResponse]:
|
10258
|
+
"""Get or create a model.
|
10259
|
+
|
10260
|
+
Args:
|
10261
|
+
model_request: The model request.
|
10262
|
+
|
10263
|
+
Returns:
|
10264
|
+
A boolean whether the model was created or not, and the model.
|
10265
|
+
"""
|
10266
|
+
try:
|
10267
|
+
return True, self.create_model(model_request)
|
10268
|
+
except EntityExistsError:
|
10269
|
+
return False, self.get_model(model_request.name)
|
10270
|
+
|
10286
10271
|
def _get_next_numeric_version_for_model(
|
10287
10272
|
self, session: Session, model_id: UUID
|
10288
10273
|
) -> int:
|
@@ -10307,55 +10292,276 @@ class SqlZenStore(BaseZenStore):
|
|
10307
10292
|
else:
|
10308
10293
|
return int(current_max_version) + 1
|
10309
10294
|
|
10310
|
-
def _model_version_exists(
|
10295
|
+
def _model_version_exists(
|
10296
|
+
self,
|
10297
|
+
model_id: UUID,
|
10298
|
+
version: Optional[str] = None,
|
10299
|
+
producer_run_id: Optional[UUID] = None,
|
10300
|
+
) -> bool:
|
10311
10301
|
"""Check if a model version with a certain version exists.
|
10312
10302
|
|
10313
10303
|
Args:
|
10314
10304
|
model_id: The model ID of the version.
|
10315
10305
|
version: The version name.
|
10306
|
+
producer_run_id: The producer run ID. If given, checks if a numeric
|
10307
|
+
version for the producer run exists.
|
10316
10308
|
|
10317
10309
|
Returns:
|
10318
|
-
If a model version
|
10310
|
+
If a model version for the given arguments exists.
|
10319
10311
|
"""
|
10312
|
+
query = select(ModelVersionSchema.id).where(
|
10313
|
+
ModelVersionSchema.model_id == model_id
|
10314
|
+
)
|
10315
|
+
|
10316
|
+
if version:
|
10317
|
+
query = query.where(ModelVersionSchema.name == version)
|
10318
|
+
|
10319
|
+
if producer_run_id:
|
10320
|
+
query = query.where(
|
10321
|
+
ModelVersionSchema.producer_run_id_if_numeric
|
10322
|
+
== producer_run_id,
|
10323
|
+
)
|
10324
|
+
|
10320
10325
|
with Session(self.engine) as session:
|
10321
|
-
return (
|
10322
|
-
|
10323
|
-
|
10324
|
-
|
10325
|
-
|
10326
|
-
|
10327
|
-
|
10326
|
+
return session.exec(query).first() is not None
|
10327
|
+
|
10328
|
+
def _get_model_version(
|
10329
|
+
self,
|
10330
|
+
model_id: UUID,
|
10331
|
+
version_name: Optional[str] = None,
|
10332
|
+
producer_run_id: Optional[UUID] = None,
|
10333
|
+
) -> ModelVersionResponse:
|
10334
|
+
"""Get a model version.
|
10335
|
+
|
10336
|
+
Args:
|
10337
|
+
model_id: The ID of the model.
|
10338
|
+
version_name: The name of the model version.
|
10339
|
+
producer_run_id: The ID of the producer pipeline run. If this is
|
10340
|
+
set, only numeric versions created as part of the pipeline run
|
10341
|
+
will be returned.
|
10342
|
+
|
10343
|
+
Raises:
|
10344
|
+
ValueError: If no version name or producer run ID was provided.
|
10345
|
+
KeyError: If no model version was found.
|
10346
|
+
|
10347
|
+
Returns:
|
10348
|
+
The model version.
|
10349
|
+
"""
|
10350
|
+
query = select(ModelVersionSchema).where(
|
10351
|
+
ModelVersionSchema.model_id == model_id
|
10352
|
+
)
|
10353
|
+
|
10354
|
+
if version_name:
|
10355
|
+
if version_name.isnumeric():
|
10356
|
+
query = query.where(
|
10357
|
+
ModelVersionSchema.number == int(version_name)
|
10358
|
+
)
|
10359
|
+
error_text = (
|
10360
|
+
f"No version with number {version_name} found "
|
10361
|
+
f"for model {model_id}."
|
10362
|
+
)
|
10363
|
+
elif version_name in ModelStages.values():
|
10364
|
+
if version_name == ModelStages.LATEST:
|
10365
|
+
query = query.order_by(
|
10366
|
+
desc(col(ModelVersionSchema.number))
|
10367
|
+
).limit(1)
|
10368
|
+
else:
|
10369
|
+
query = query.where(
|
10370
|
+
ModelVersionSchema.stage == version_name
|
10371
|
+
)
|
10372
|
+
error_text = (
|
10373
|
+
f"No {version_name} stage version found for "
|
10374
|
+
f"model {model_id}."
|
10375
|
+
)
|
10376
|
+
else:
|
10377
|
+
query = query.where(ModelVersionSchema.name == version_name)
|
10378
|
+
error_text = (
|
10379
|
+
f"No {version_name} version found for model {model_id}."
|
10380
|
+
)
|
10381
|
+
|
10382
|
+
elif producer_run_id:
|
10383
|
+
query = query.where(
|
10384
|
+
ModelVersionSchema.producer_run_id_if_numeric
|
10385
|
+
== producer_run_id,
|
10386
|
+
)
|
10387
|
+
error_text = (
|
10388
|
+
f"No numeric model version found for model {model_id} "
|
10389
|
+
f"and producer run {producer_run_id}."
|
10390
|
+
)
|
10391
|
+
else:
|
10392
|
+
raise ValueError(
|
10393
|
+
"Version name or producer run id need to be specified."
|
10328
10394
|
)
|
10329
10395
|
|
10330
|
-
|
10331
|
-
|
10332
|
-
|
10396
|
+
with Session(self.engine) as session:
|
10397
|
+
schema = session.exec(query).one_or_none()
|
10398
|
+
|
10399
|
+
if not schema:
|
10400
|
+
raise KeyError(error_text)
|
10401
|
+
|
10402
|
+
return schema.to_model(
|
10403
|
+
include_metadata=True, include_resources=True
|
10404
|
+
)
|
10405
|
+
|
10406
|
+
def _get_or_create_model_version(
|
10407
|
+
self,
|
10408
|
+
model_version_request: ModelVersionRequest,
|
10409
|
+
producer_run_id: Optional[UUID] = None,
|
10410
|
+
) -> Tuple[bool, ModelVersionResponse]:
|
10411
|
+
"""Get or create a model version.
|
10412
|
+
|
10413
|
+
Args:
|
10414
|
+
model_version_request: The model version request.
|
10415
|
+
producer_run_id: ID of the producer pipeline run.
|
10416
|
+
|
10417
|
+
Raises:
|
10418
|
+
EntityCreationError: If the model version creation failed.
|
10419
|
+
|
10420
|
+
Returns:
|
10421
|
+
A boolean whether the model version was created or not, and the
|
10422
|
+
model version.
|
10423
|
+
"""
|
10424
|
+
try:
|
10425
|
+
model_version = self._create_model_version(
|
10426
|
+
model_version=model_version_request,
|
10427
|
+
producer_run_id=producer_run_id,
|
10428
|
+
)
|
10429
|
+
track(event=AnalyticsEvent.CREATED_MODEL_VERSION)
|
10430
|
+
return True, model_version
|
10431
|
+
except EntityCreationError:
|
10432
|
+
# Need to explicitly re-raise this here as otherwise the catching
|
10433
|
+
# of the RuntimeError would include this
|
10434
|
+
raise
|
10435
|
+
except RuntimeError:
|
10436
|
+
return False, self._get_model_version(
|
10437
|
+
model_id=model_version_request.model,
|
10438
|
+
producer_run_id=producer_run_id,
|
10439
|
+
)
|
10440
|
+
except EntityExistsError:
|
10441
|
+
return False, self._get_model_version(
|
10442
|
+
model_id=model_version_request.model,
|
10443
|
+
version_name=model_version_request.name,
|
10444
|
+
)
|
10445
|
+
|
10446
|
+
def _get_or_create_model_version_for_run(
|
10447
|
+
self, pipeline_or_step_run: Union[PipelineRunSchema, StepRunSchema]
|
10448
|
+
) -> Optional[UUID]:
|
10449
|
+
"""Get or create a model version for a pipeline or step run.
|
10450
|
+
|
10451
|
+
Args:
|
10452
|
+
pipeline_or_step_run: The pipeline or step run for which to create
|
10453
|
+
the model version.
|
10454
|
+
|
10455
|
+
Returns:
|
10456
|
+
The model version.
|
10457
|
+
"""
|
10458
|
+
if isinstance(pipeline_or_step_run, PipelineRunSchema):
|
10459
|
+
producer_run_id = pipeline_or_step_run.id
|
10460
|
+
pipeline_run = pipeline_or_step_run.to_model(include_metadata=True)
|
10461
|
+
configured_model = pipeline_run.config.model
|
10462
|
+
substitutions = pipeline_run.config.substitutions
|
10463
|
+
else:
|
10464
|
+
producer_run_id = pipeline_or_step_run.pipeline_run_id
|
10465
|
+
step_run = pipeline_or_step_run.to_model(include_metadata=True)
|
10466
|
+
configured_model = step_run.config.model
|
10467
|
+
substitutions = step_run.config.substitutions
|
10468
|
+
|
10469
|
+
if not configured_model:
|
10470
|
+
return None
|
10471
|
+
|
10472
|
+
model_request = ModelRequest(
|
10473
|
+
name=format_name_template(
|
10474
|
+
configured_model.name, substitutions=substitutions
|
10475
|
+
),
|
10476
|
+
license=configured_model.license,
|
10477
|
+
description=configured_model.description,
|
10478
|
+
audience=configured_model.audience,
|
10479
|
+
use_cases=configured_model.use_cases,
|
10480
|
+
limitations=configured_model.limitations,
|
10481
|
+
trade_offs=configured_model.trade_offs,
|
10482
|
+
ethics=configured_model.ethics,
|
10483
|
+
save_models_to_registry=configured_model.save_models_to_registry,
|
10484
|
+
user=pipeline_or_step_run.user_id,
|
10485
|
+
workspace=pipeline_or_step_run.workspace_id,
|
10486
|
+
)
|
10487
|
+
|
10488
|
+
_, model_response = self._get_or_create_model(
|
10489
|
+
model_request=model_request
|
10490
|
+
)
|
10491
|
+
|
10492
|
+
version_name = None
|
10493
|
+
if configured_model.version is not None:
|
10494
|
+
version_name = format_name_template(
|
10495
|
+
str(configured_model.version), substitutions=substitutions
|
10496
|
+
)
|
10497
|
+
|
10498
|
+
# If the model version was specified to be a numeric version or
|
10499
|
+
# stage we don't try to create it (which will fail because it is not
|
10500
|
+
# allowed) but try to fetch it immediately
|
10501
|
+
if (
|
10502
|
+
version_name.isnumeric()
|
10503
|
+
or version_name in ModelStages.values()
|
10504
|
+
):
|
10505
|
+
return self._get_model_version(
|
10506
|
+
model_id=model_response.id, version_name=version_name
|
10507
|
+
).id
|
10508
|
+
|
10509
|
+
model_version_request = ModelVersionRequest(
|
10510
|
+
model=model_response.id,
|
10511
|
+
name=version_name,
|
10512
|
+
description=configured_model.description,
|
10513
|
+
tags=configured_model.tags,
|
10514
|
+
user=pipeline_or_step_run.user_id,
|
10515
|
+
workspace=pipeline_or_step_run.workspace_id,
|
10516
|
+
)
|
10517
|
+
|
10518
|
+
_, model_version_response = self._get_or_create_model_version(
|
10519
|
+
model_version_request=model_version_request,
|
10520
|
+
producer_run_id=producer_run_id,
|
10521
|
+
)
|
10522
|
+
return model_version_response.id
|
10523
|
+
|
10524
|
+
def _create_model_version(
|
10525
|
+
self,
|
10526
|
+
model_version: ModelVersionRequest,
|
10527
|
+
producer_run_id: Optional[UUID] = None,
|
10333
10528
|
) -> ModelVersionResponse:
|
10334
10529
|
"""Creates a new model version.
|
10335
10530
|
|
10336
10531
|
Args:
|
10337
10532
|
model_version: the Model Version to be created.
|
10533
|
+
producer_run_id: ID of the pipeline run that produced this model
|
10534
|
+
version.
|
10338
10535
|
|
10339
10536
|
Returns:
|
10340
10537
|
The newly created model version.
|
10341
10538
|
|
10342
10539
|
Raises:
|
10343
|
-
ValueError: If
|
10540
|
+
ValueError: If the requested version name is invalid.
|
10344
10541
|
EntityExistsError: If a model version with the given name already
|
10345
10542
|
exists.
|
10346
10543
|
EntityCreationError: If the model version creation failed.
|
10544
|
+
RuntimeError: If an auto-incremented model version already exists
|
10545
|
+
for the producer run.
|
10347
10546
|
"""
|
10348
|
-
|
10349
|
-
|
10350
|
-
|
10351
|
-
)
|
10547
|
+
has_custom_name = False
|
10548
|
+
if model_version.name:
|
10549
|
+
has_custom_name = True
|
10550
|
+
validate_name(model_version)
|
10352
10551
|
|
10353
|
-
|
10552
|
+
if model_version.name.isnumeric():
|
10553
|
+
raise ValueError(
|
10554
|
+
"Can't create model version with custom numeric model "
|
10555
|
+
"version name."
|
10556
|
+
)
|
10354
10557
|
|
10355
|
-
|
10356
|
-
|
10357
|
-
|
10558
|
+
if str(model_version.name).lower() in ModelStages.values():
|
10559
|
+
raise ValueError(
|
10560
|
+
"Can't create model version with a name that is used as a "
|
10561
|
+
f"model version stage ({ModelStages.values()})."
|
10562
|
+
)
|
10358
10563
|
|
10564
|
+
model = self.get_model(model_version.model)
|
10359
10565
|
model_version_id = None
|
10360
10566
|
|
10361
10567
|
remaining_tries = MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION
|
@@ -10363,17 +10569,19 @@ class SqlZenStore(BaseZenStore):
|
|
10363
10569
|
remaining_tries -= 1
|
10364
10570
|
try:
|
10365
10571
|
with Session(self.engine) as session:
|
10366
|
-
|
10572
|
+
model_version_number = (
|
10367
10573
|
self._get_next_numeric_version_for_model(
|
10368
10574
|
session=session,
|
10369
10575
|
model_id=model.id,
|
10370
10576
|
)
|
10371
10577
|
)
|
10372
10578
|
if not has_custom_name:
|
10373
|
-
model_version.name = str(
|
10579
|
+
model_version.name = str(model_version_number)
|
10374
10580
|
|
10375
10581
|
model_version_schema = ModelVersionSchema.from_request(
|
10376
|
-
model_version
|
10582
|
+
model_version,
|
10583
|
+
model_version_number=model_version_number,
|
10584
|
+
producer_run_id=producer_run_id,
|
10377
10585
|
)
|
10378
10586
|
session.add(model_version_schema)
|
10379
10587
|
session.commit()
|
@@ -10394,6 +10602,13 @@ class SqlZenStore(BaseZenStore):
|
|
10394
10602
|
f"{model_version.name}): A model with the "
|
10395
10603
|
"same name and version already exists."
|
10396
10604
|
)
|
10605
|
+
elif producer_run_id and self._model_version_exists(
|
10606
|
+
model_id=model.id, producer_run_id=producer_run_id
|
10607
|
+
):
|
10608
|
+
raise RuntimeError(
|
10609
|
+
"Auto-incremented model version already exists for "
|
10610
|
+
f"producer run {producer_run_id}."
|
10611
|
+
)
|
10397
10612
|
elif remaining_tries == 0:
|
10398
10613
|
raise EntityCreationError(
|
10399
10614
|
f"Failed to create version for model "
|
@@ -10412,10 +10627,9 @@ class SqlZenStore(BaseZenStore):
|
|
10412
10627
|
)
|
10413
10628
|
logger.debug(
|
10414
10629
|
"Failed to create model version %s "
|
10415
|
-
"
|
10630
|
+
"due to an integrity error. "
|
10416
10631
|
"Retrying in %f seconds.",
|
10417
10632
|
model.name,
|
10418
|
-
model_version.number,
|
10419
10633
|
sleep_duration,
|
10420
10634
|
)
|
10421
10635
|
time.sleep(sleep_duration)
|
@@ -10430,6 +10644,20 @@ class SqlZenStore(BaseZenStore):
|
|
10430
10644
|
|
10431
10645
|
return self.get_model_version(model_version_id)
|
10432
10646
|
|
10647
|
+
@track_decorator(AnalyticsEvent.CREATED_MODEL_VERSION)
|
10648
|
+
def create_model_version(
|
10649
|
+
self, model_version: ModelVersionRequest
|
10650
|
+
) -> ModelVersionResponse:
|
10651
|
+
"""Creates a new model version.
|
10652
|
+
|
10653
|
+
Args:
|
10654
|
+
model_version: the Model Version to be created.
|
10655
|
+
|
10656
|
+
Returns:
|
10657
|
+
The newly created model version.
|
10658
|
+
"""
|
10659
|
+
return self._create_model_version(model_version=model_version)
|
10660
|
+
|
10433
10661
|
def get_model_version(
|
10434
10662
|
self, model_version_id: UUID, hydrate: bool = True
|
10435
10663
|
) -> ModelVersionResponse:
|