zenml-nightly 0.71.0.dev20241213__py3-none-any.whl → 0.71.0.dev20241214__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- zenml/VERSION +1 -1
- zenml/client.py +44 -2
- zenml/integrations/kubernetes/step_operators/kubernetes_step_operator.py +0 -1
- zenml/model/model.py +12 -16
- zenml/models/v2/base/filter.py +26 -30
- zenml/models/v2/base/scoped.py +258 -5
- zenml/models/v2/core/artifact_version.py +15 -26
- zenml/models/v2/core/code_repository.py +1 -12
- zenml/models/v2/core/component.py +5 -46
- zenml/models/v2/core/flavor.py +1 -11
- zenml/models/v2/core/model.py +1 -57
- zenml/models/v2/core/model_version.py +5 -33
- zenml/models/v2/core/model_version_artifact.py +11 -3
- zenml/models/v2/core/model_version_pipeline_run.py +14 -3
- zenml/models/v2/core/pipeline.py +47 -55
- zenml/models/v2/core/pipeline_build.py +19 -12
- zenml/models/v2/core/pipeline_deployment.py +0 -10
- zenml/models/v2/core/pipeline_run.py +91 -29
- zenml/models/v2/core/run_template.py +21 -29
- zenml/models/v2/core/schedule.py +0 -10
- zenml/models/v2/core/secret.py +0 -14
- zenml/models/v2/core/service.py +9 -16
- zenml/models/v2/core/service_connector.py +0 -11
- zenml/models/v2/core/stack.py +21 -30
- zenml/models/v2/core/step_run.py +18 -14
- zenml/models/v2/core/trigger.py +19 -3
- zenml/orchestrators/step_launcher.py +9 -13
- zenml/orchestrators/step_run_utils.py +8 -204
- zenml/zen_server/rbac/rbac_sql_zen_store.py +173 -0
- zenml/zen_server/utils.py +4 -3
- zenml/zen_stores/base_zen_store.py +10 -2
- zenml/zen_stores/migrations/versions/a1237ba94fd8_add_model_version_producer_run_unique_.py +68 -0
- zenml/zen_stores/schemas/model_schemas.py +42 -6
- zenml/zen_stores/schemas/pipeline_deployment_schemas.py +7 -7
- zenml/zen_stores/schemas/pipeline_schemas.py +5 -0
- zenml/zen_stores/sql_zen_store.py +322 -86
- {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/METADATA +1 -1
- {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/RECORD +41 -39
- {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/LICENSE +0 -0
- {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/WHEEL +0 -0
- {zenml_nightly-0.71.0.dev20241213.dist-info → zenml_nightly-0.71.0.dev20241214.dist-info}/entry_points.txt +0 -0
@@ -55,7 +55,7 @@ from pydantic import (
|
|
55
55
|
field_validator,
|
56
56
|
model_validator,
|
57
57
|
)
|
58
|
-
from sqlalchemy import
|
58
|
+
from sqlalchemy import func
|
59
59
|
from sqlalchemy.engine import URL, Engine, make_url
|
60
60
|
from sqlalchemy.exc import (
|
61
61
|
ArgumentError,
|
@@ -71,6 +71,7 @@ from sqlmodel import (
|
|
71
71
|
col,
|
72
72
|
create_engine,
|
73
73
|
delete,
|
74
|
+
desc,
|
74
75
|
or_,
|
75
76
|
select,
|
76
77
|
)
|
@@ -100,7 +101,6 @@ from zenml.constants import (
|
|
100
101
|
ENV_ZENML_SERVER,
|
101
102
|
FINISHED_ONBOARDING_SURVEY_KEY,
|
102
103
|
MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION,
|
103
|
-
SORT_PIPELINES_BY_LATEST_RUN_KEY,
|
104
104
|
SQL_STORE_BACKUP_DIRECTORY_NAME,
|
105
105
|
TEXT_FIELD_MAX_LENGTH,
|
106
106
|
handle_bool_env_var,
|
@@ -117,7 +117,6 @@ from zenml.enums import (
|
|
117
117
|
OnboardingStep,
|
118
118
|
SecretScope,
|
119
119
|
SecretsStoreType,
|
120
|
-
SorterOps,
|
121
120
|
StackComponentType,
|
122
121
|
StackDeploymentProvider,
|
123
122
|
StepRunInputArtifactType,
|
@@ -298,7 +297,11 @@ from zenml.utils.networking_utils import (
|
|
298
297
|
replace_localhost_with_internal_hostname,
|
299
298
|
)
|
300
299
|
from zenml.utils.pydantic_utils import before_validator_handler
|
301
|
-
from zenml.utils.string_utils import
|
300
|
+
from zenml.utils.string_utils import (
|
301
|
+
format_name_template,
|
302
|
+
random_str,
|
303
|
+
validate_name,
|
304
|
+
)
|
302
305
|
from zenml.zen_stores import template_utils
|
303
306
|
from zenml.zen_stores.base_zen_store import (
|
304
307
|
BaseZenStore,
|
@@ -4358,69 +4361,14 @@ class SqlZenStore(BaseZenStore):
|
|
4358
4361
|
Returns:
|
4359
4362
|
A list of all pipelines matching the filter criteria.
|
4360
4363
|
"""
|
4361
|
-
query: Union[Select[Any], SelectOfScalar[Any]] = select(PipelineSchema)
|
4362
|
-
_custom_conversion: Optional[Callable[[Any], PipelineResponse]] = None
|
4363
|
-
|
4364
|
-
column, operand = pipeline_filter_model.sorting_params
|
4365
|
-
if column == SORT_PIPELINES_BY_LATEST_RUN_KEY:
|
4366
|
-
with Session(self.engine) as session:
|
4367
|
-
max_date_subquery = (
|
4368
|
-
# If no run exists for the pipeline yet, we use the pipeline
|
4369
|
-
# creation date as a fallback, otherwise newly created
|
4370
|
-
# pipeline would always be at the top/bottom
|
4371
|
-
select(
|
4372
|
-
PipelineSchema.id,
|
4373
|
-
case(
|
4374
|
-
(
|
4375
|
-
func.max(PipelineRunSchema.created).is_(None),
|
4376
|
-
PipelineSchema.created,
|
4377
|
-
),
|
4378
|
-
else_=func.max(PipelineRunSchema.created),
|
4379
|
-
).label("run_or_created"),
|
4380
|
-
)
|
4381
|
-
.outerjoin(
|
4382
|
-
PipelineRunSchema,
|
4383
|
-
PipelineSchema.id == PipelineRunSchema.pipeline_id, # type: ignore[arg-type]
|
4384
|
-
)
|
4385
|
-
.group_by(col(PipelineSchema.id))
|
4386
|
-
.subquery()
|
4387
|
-
)
|
4388
|
-
|
4389
|
-
if operand == SorterOps.DESCENDING:
|
4390
|
-
sort_clause = desc
|
4391
|
-
else:
|
4392
|
-
sort_clause = asc
|
4393
|
-
|
4394
|
-
query = (
|
4395
|
-
# We need to include the subquery in the select here to
|
4396
|
-
# make this query work with the distinct statement. This
|
4397
|
-
# result will be removed in the custom conversion function
|
4398
|
-
# applied later
|
4399
|
-
select(PipelineSchema, max_date_subquery.c.run_or_created)
|
4400
|
-
.where(PipelineSchema.id == max_date_subquery.c.id)
|
4401
|
-
.order_by(sort_clause(max_date_subquery.c.run_or_created))
|
4402
|
-
# We always add the `id` column as a tiebreaker to ensure a
|
4403
|
-
# stable, repeatable order of items, otherwise subsequent
|
4404
|
-
# pages might contain the same items.
|
4405
|
-
.order_by(col(PipelineSchema.id))
|
4406
|
-
)
|
4407
|
-
|
4408
|
-
def _custom_conversion(row: Any) -> PipelineResponse:
|
4409
|
-
return cast(
|
4410
|
-
PipelineResponse,
|
4411
|
-
row[0].to_model(
|
4412
|
-
include_metadata=hydrate, include_resources=True
|
4413
|
-
),
|
4414
|
-
)
|
4415
|
-
|
4416
4364
|
with Session(self.engine) as session:
|
4365
|
+
query = select(PipelineSchema)
|
4417
4366
|
return self.filter_and_paginate(
|
4418
4367
|
session=session,
|
4419
4368
|
query=query,
|
4420
4369
|
table=PipelineSchema,
|
4421
4370
|
filter_model=pipeline_filter_model,
|
4422
4371
|
hydrate=hydrate,
|
4423
|
-
custom_schema_to_model_conversion=_custom_conversion,
|
4424
4372
|
)
|
4425
4373
|
|
4426
4374
|
def count_pipelines(self, filter_model: Optional[PipelineFilter]) -> int:
|
@@ -5211,6 +5159,20 @@ class SqlZenStore(BaseZenStore):
|
|
5211
5159
|
"already exists."
|
5212
5160
|
)
|
5213
5161
|
|
5162
|
+
if model_version_id := self._get_or_create_model_version_for_run(
|
5163
|
+
new_run
|
5164
|
+
):
|
5165
|
+
new_run.model_version_id = model_version_id
|
5166
|
+
session.add(new_run)
|
5167
|
+
session.commit()
|
5168
|
+
|
5169
|
+
self.create_model_version_pipeline_run_link(
|
5170
|
+
ModelVersionPipelineRunRequest(
|
5171
|
+
model_version=model_version_id, pipeline_run=new_run.id
|
5172
|
+
)
|
5173
|
+
)
|
5174
|
+
session.refresh(new_run)
|
5175
|
+
|
5214
5176
|
return new_run.to_model(
|
5215
5177
|
include_metadata=True, include_resources=True
|
5216
5178
|
)
|
@@ -8273,6 +8235,21 @@ class SqlZenStore(BaseZenStore):
|
|
8273
8235
|
session.commit()
|
8274
8236
|
session.refresh(step_schema)
|
8275
8237
|
|
8238
|
+
if model_version_id := self._get_or_create_model_version_for_run(
|
8239
|
+
step_schema
|
8240
|
+
):
|
8241
|
+
step_schema.model_version_id = model_version_id
|
8242
|
+
session.add(step_schema)
|
8243
|
+
session.commit()
|
8244
|
+
|
8245
|
+
self.create_model_version_pipeline_run_link(
|
8246
|
+
ModelVersionPipelineRunRequest(
|
8247
|
+
model_version=model_version_id,
|
8248
|
+
pipeline_run=step_schema.pipeline_run_id,
|
8249
|
+
)
|
8250
|
+
)
|
8251
|
+
session.refresh(step_schema)
|
8252
|
+
|
8276
8253
|
return step_schema.to_model(
|
8277
8254
|
include_metadata=True, include_resources=True
|
8278
8255
|
)
|
@@ -10275,6 +10252,22 @@ class SqlZenStore(BaseZenStore):
|
|
10275
10252
|
|
10276
10253
|
# ----------------------------- Model Versions -----------------------------
|
10277
10254
|
|
10255
|
+
def _get_or_create_model(
|
10256
|
+
self, model_request: ModelRequest
|
10257
|
+
) -> Tuple[bool, ModelResponse]:
|
10258
|
+
"""Get or create a model.
|
10259
|
+
|
10260
|
+
Args:
|
10261
|
+
model_request: The model request.
|
10262
|
+
|
10263
|
+
Returns:
|
10264
|
+
A boolean whether the model was created or not, and the model.
|
10265
|
+
"""
|
10266
|
+
try:
|
10267
|
+
return True, self.create_model(model_request)
|
10268
|
+
except EntityExistsError:
|
10269
|
+
return False, self.get_model(model_request.name)
|
10270
|
+
|
10278
10271
|
def _get_next_numeric_version_for_model(
|
10279
10272
|
self, session: Session, model_id: UUID
|
10280
10273
|
) -> int:
|
@@ -10299,55 +10292,276 @@ class SqlZenStore(BaseZenStore):
|
|
10299
10292
|
else:
|
10300
10293
|
return int(current_max_version) + 1
|
10301
10294
|
|
10302
|
-
def _model_version_exists(
|
10295
|
+
def _model_version_exists(
|
10296
|
+
self,
|
10297
|
+
model_id: UUID,
|
10298
|
+
version: Optional[str] = None,
|
10299
|
+
producer_run_id: Optional[UUID] = None,
|
10300
|
+
) -> bool:
|
10303
10301
|
"""Check if a model version with a certain version exists.
|
10304
10302
|
|
10305
10303
|
Args:
|
10306
10304
|
model_id: The model ID of the version.
|
10307
10305
|
version: The version name.
|
10306
|
+
producer_run_id: The producer run ID. If given, checks if a numeric
|
10307
|
+
version for the producer run exists.
|
10308
10308
|
|
10309
10309
|
Returns:
|
10310
|
-
If a model version
|
10310
|
+
If a model version for the given arguments exists.
|
10311
10311
|
"""
|
10312
|
+
query = select(ModelVersionSchema.id).where(
|
10313
|
+
ModelVersionSchema.model_id == model_id
|
10314
|
+
)
|
10315
|
+
|
10316
|
+
if version:
|
10317
|
+
query = query.where(ModelVersionSchema.name == version)
|
10318
|
+
|
10319
|
+
if producer_run_id:
|
10320
|
+
query = query.where(
|
10321
|
+
ModelVersionSchema.producer_run_id_if_numeric
|
10322
|
+
== producer_run_id,
|
10323
|
+
)
|
10324
|
+
|
10312
10325
|
with Session(self.engine) as session:
|
10313
|
-
return (
|
10314
|
-
|
10315
|
-
|
10316
|
-
|
10317
|
-
|
10318
|
-
|
10319
|
-
|
10326
|
+
return session.exec(query).first() is not None
|
10327
|
+
|
10328
|
+
def _get_model_version(
|
10329
|
+
self,
|
10330
|
+
model_id: UUID,
|
10331
|
+
version_name: Optional[str] = None,
|
10332
|
+
producer_run_id: Optional[UUID] = None,
|
10333
|
+
) -> ModelVersionResponse:
|
10334
|
+
"""Get a model version.
|
10335
|
+
|
10336
|
+
Args:
|
10337
|
+
model_id: The ID of the model.
|
10338
|
+
version_name: The name of the model version.
|
10339
|
+
producer_run_id: The ID of the producer pipeline run. If this is
|
10340
|
+
set, only numeric versions created as part of the pipeline run
|
10341
|
+
will be returned.
|
10342
|
+
|
10343
|
+
Raises:
|
10344
|
+
ValueError: If no version name or producer run ID was provided.
|
10345
|
+
KeyError: If no model version was found.
|
10346
|
+
|
10347
|
+
Returns:
|
10348
|
+
The model version.
|
10349
|
+
"""
|
10350
|
+
query = select(ModelVersionSchema).where(
|
10351
|
+
ModelVersionSchema.model_id == model_id
|
10352
|
+
)
|
10353
|
+
|
10354
|
+
if version_name:
|
10355
|
+
if version_name.isnumeric():
|
10356
|
+
query = query.where(
|
10357
|
+
ModelVersionSchema.number == int(version_name)
|
10358
|
+
)
|
10359
|
+
error_text = (
|
10360
|
+
f"No version with number {version_name} found "
|
10361
|
+
f"for model {model_id}."
|
10362
|
+
)
|
10363
|
+
elif version_name in ModelStages.values():
|
10364
|
+
if version_name == ModelStages.LATEST:
|
10365
|
+
query = query.order_by(
|
10366
|
+
desc(col(ModelVersionSchema.number))
|
10367
|
+
).limit(1)
|
10368
|
+
else:
|
10369
|
+
query = query.where(
|
10370
|
+
ModelVersionSchema.stage == version_name
|
10371
|
+
)
|
10372
|
+
error_text = (
|
10373
|
+
f"No {version_name} stage version found for "
|
10374
|
+
f"model {model_id}."
|
10375
|
+
)
|
10376
|
+
else:
|
10377
|
+
query = query.where(ModelVersionSchema.name == version_name)
|
10378
|
+
error_text = (
|
10379
|
+
f"No {version_name} version found for model {model_id}."
|
10380
|
+
)
|
10381
|
+
|
10382
|
+
elif producer_run_id:
|
10383
|
+
query = query.where(
|
10384
|
+
ModelVersionSchema.producer_run_id_if_numeric
|
10385
|
+
== producer_run_id,
|
10386
|
+
)
|
10387
|
+
error_text = (
|
10388
|
+
f"No numeric model version found for model {model_id} "
|
10389
|
+
f"and producer run {producer_run_id}."
|
10390
|
+
)
|
10391
|
+
else:
|
10392
|
+
raise ValueError(
|
10393
|
+
"Version name or producer run id need to be specified."
|
10320
10394
|
)
|
10321
10395
|
|
10322
|
-
|
10323
|
-
|
10324
|
-
|
10396
|
+
with Session(self.engine) as session:
|
10397
|
+
schema = session.exec(query).one_or_none()
|
10398
|
+
|
10399
|
+
if not schema:
|
10400
|
+
raise KeyError(error_text)
|
10401
|
+
|
10402
|
+
return schema.to_model(
|
10403
|
+
include_metadata=True, include_resources=True
|
10404
|
+
)
|
10405
|
+
|
10406
|
+
def _get_or_create_model_version(
|
10407
|
+
self,
|
10408
|
+
model_version_request: ModelVersionRequest,
|
10409
|
+
producer_run_id: Optional[UUID] = None,
|
10410
|
+
) -> Tuple[bool, ModelVersionResponse]:
|
10411
|
+
"""Get or create a model version.
|
10412
|
+
|
10413
|
+
Args:
|
10414
|
+
model_version_request: The model version request.
|
10415
|
+
producer_run_id: ID of the producer pipeline run.
|
10416
|
+
|
10417
|
+
Raises:
|
10418
|
+
EntityCreationError: If the model version creation failed.
|
10419
|
+
|
10420
|
+
Returns:
|
10421
|
+
A boolean whether the model version was created or not, and the
|
10422
|
+
model version.
|
10423
|
+
"""
|
10424
|
+
try:
|
10425
|
+
model_version = self._create_model_version(
|
10426
|
+
model_version=model_version_request,
|
10427
|
+
producer_run_id=producer_run_id,
|
10428
|
+
)
|
10429
|
+
track(event=AnalyticsEvent.CREATED_MODEL_VERSION)
|
10430
|
+
return True, model_version
|
10431
|
+
except EntityCreationError:
|
10432
|
+
# Need to explicitly re-raise this here as otherwise the catching
|
10433
|
+
# of the RuntimeError would include this
|
10434
|
+
raise
|
10435
|
+
except RuntimeError:
|
10436
|
+
return False, self._get_model_version(
|
10437
|
+
model_id=model_version_request.model,
|
10438
|
+
producer_run_id=producer_run_id,
|
10439
|
+
)
|
10440
|
+
except EntityExistsError:
|
10441
|
+
return False, self._get_model_version(
|
10442
|
+
model_id=model_version_request.model,
|
10443
|
+
version_name=model_version_request.name,
|
10444
|
+
)
|
10445
|
+
|
10446
|
+
def _get_or_create_model_version_for_run(
|
10447
|
+
self, pipeline_or_step_run: Union[PipelineRunSchema, StepRunSchema]
|
10448
|
+
) -> Optional[UUID]:
|
10449
|
+
"""Get or create a model version for a pipeline or step run.
|
10450
|
+
|
10451
|
+
Args:
|
10452
|
+
pipeline_or_step_run: The pipeline or step run for which to create
|
10453
|
+
the model version.
|
10454
|
+
|
10455
|
+
Returns:
|
10456
|
+
The model version.
|
10457
|
+
"""
|
10458
|
+
if isinstance(pipeline_or_step_run, PipelineRunSchema):
|
10459
|
+
producer_run_id = pipeline_or_step_run.id
|
10460
|
+
pipeline_run = pipeline_or_step_run.to_model(include_metadata=True)
|
10461
|
+
configured_model = pipeline_run.config.model
|
10462
|
+
substitutions = pipeline_run.config.substitutions
|
10463
|
+
else:
|
10464
|
+
producer_run_id = pipeline_or_step_run.pipeline_run_id
|
10465
|
+
step_run = pipeline_or_step_run.to_model(include_metadata=True)
|
10466
|
+
configured_model = step_run.config.model
|
10467
|
+
substitutions = step_run.config.substitutions
|
10468
|
+
|
10469
|
+
if not configured_model:
|
10470
|
+
return None
|
10471
|
+
|
10472
|
+
model_request = ModelRequest(
|
10473
|
+
name=format_name_template(
|
10474
|
+
configured_model.name, substitutions=substitutions
|
10475
|
+
),
|
10476
|
+
license=configured_model.license,
|
10477
|
+
description=configured_model.description,
|
10478
|
+
audience=configured_model.audience,
|
10479
|
+
use_cases=configured_model.use_cases,
|
10480
|
+
limitations=configured_model.limitations,
|
10481
|
+
trade_offs=configured_model.trade_offs,
|
10482
|
+
ethics=configured_model.ethics,
|
10483
|
+
save_models_to_registry=configured_model.save_models_to_registry,
|
10484
|
+
user=pipeline_or_step_run.user_id,
|
10485
|
+
workspace=pipeline_or_step_run.workspace_id,
|
10486
|
+
)
|
10487
|
+
|
10488
|
+
_, model_response = self._get_or_create_model(
|
10489
|
+
model_request=model_request
|
10490
|
+
)
|
10491
|
+
|
10492
|
+
version_name = None
|
10493
|
+
if configured_model.version is not None:
|
10494
|
+
version_name = format_name_template(
|
10495
|
+
str(configured_model.version), substitutions=substitutions
|
10496
|
+
)
|
10497
|
+
|
10498
|
+
# If the model version was specified to be a numeric version or
|
10499
|
+
# stage we don't try to create it (which will fail because it is not
|
10500
|
+
# allowed) but try to fetch it immediately
|
10501
|
+
if (
|
10502
|
+
version_name.isnumeric()
|
10503
|
+
or version_name in ModelStages.values()
|
10504
|
+
):
|
10505
|
+
return self._get_model_version(
|
10506
|
+
model_id=model_response.id, version_name=version_name
|
10507
|
+
).id
|
10508
|
+
|
10509
|
+
model_version_request = ModelVersionRequest(
|
10510
|
+
model=model_response.id,
|
10511
|
+
name=version_name,
|
10512
|
+
description=configured_model.description,
|
10513
|
+
tags=configured_model.tags,
|
10514
|
+
user=pipeline_or_step_run.user_id,
|
10515
|
+
workspace=pipeline_or_step_run.workspace_id,
|
10516
|
+
)
|
10517
|
+
|
10518
|
+
_, model_version_response = self._get_or_create_model_version(
|
10519
|
+
model_version_request=model_version_request,
|
10520
|
+
producer_run_id=producer_run_id,
|
10521
|
+
)
|
10522
|
+
return model_version_response.id
|
10523
|
+
|
10524
|
+
def _create_model_version(
|
10525
|
+
self,
|
10526
|
+
model_version: ModelVersionRequest,
|
10527
|
+
producer_run_id: Optional[UUID] = None,
|
10325
10528
|
) -> ModelVersionResponse:
|
10326
10529
|
"""Creates a new model version.
|
10327
10530
|
|
10328
10531
|
Args:
|
10329
10532
|
model_version: the Model Version to be created.
|
10533
|
+
producer_run_id: ID of the pipeline run that produced this model
|
10534
|
+
version.
|
10330
10535
|
|
10331
10536
|
Returns:
|
10332
10537
|
The newly created model version.
|
10333
10538
|
|
10334
10539
|
Raises:
|
10335
|
-
ValueError: If
|
10540
|
+
ValueError: If the requested version name is invalid.
|
10336
10541
|
EntityExistsError: If a model version with the given name already
|
10337
10542
|
exists.
|
10338
10543
|
EntityCreationError: If the model version creation failed.
|
10544
|
+
RuntimeError: If an auto-incremented model version already exists
|
10545
|
+
for the producer run.
|
10339
10546
|
"""
|
10340
|
-
|
10341
|
-
|
10342
|
-
|
10343
|
-
)
|
10547
|
+
has_custom_name = False
|
10548
|
+
if model_version.name:
|
10549
|
+
has_custom_name = True
|
10550
|
+
validate_name(model_version)
|
10344
10551
|
|
10345
|
-
|
10552
|
+
if model_version.name.isnumeric():
|
10553
|
+
raise ValueError(
|
10554
|
+
"Can't create model version with custom numeric model "
|
10555
|
+
"version name."
|
10556
|
+
)
|
10346
10557
|
|
10347
|
-
|
10348
|
-
|
10349
|
-
|
10558
|
+
if str(model_version.name).lower() in ModelStages.values():
|
10559
|
+
raise ValueError(
|
10560
|
+
"Can't create model version with a name that is used as a "
|
10561
|
+
f"model version stage ({ModelStages.values()})."
|
10562
|
+
)
|
10350
10563
|
|
10564
|
+
model = self.get_model(model_version.model)
|
10351
10565
|
model_version_id = None
|
10352
10566
|
|
10353
10567
|
remaining_tries = MAX_RETRIES_FOR_VERSIONED_ENTITY_CREATION
|
@@ -10355,17 +10569,19 @@ class SqlZenStore(BaseZenStore):
|
|
10355
10569
|
remaining_tries -= 1
|
10356
10570
|
try:
|
10357
10571
|
with Session(self.engine) as session:
|
10358
|
-
|
10572
|
+
model_version_number = (
|
10359
10573
|
self._get_next_numeric_version_for_model(
|
10360
10574
|
session=session,
|
10361
10575
|
model_id=model.id,
|
10362
10576
|
)
|
10363
10577
|
)
|
10364
10578
|
if not has_custom_name:
|
10365
|
-
model_version.name = str(
|
10579
|
+
model_version.name = str(model_version_number)
|
10366
10580
|
|
10367
10581
|
model_version_schema = ModelVersionSchema.from_request(
|
10368
|
-
model_version
|
10582
|
+
model_version,
|
10583
|
+
model_version_number=model_version_number,
|
10584
|
+
producer_run_id=producer_run_id,
|
10369
10585
|
)
|
10370
10586
|
session.add(model_version_schema)
|
10371
10587
|
session.commit()
|
@@ -10386,6 +10602,13 @@ class SqlZenStore(BaseZenStore):
|
|
10386
10602
|
f"{model_version.name}): A model with the "
|
10387
10603
|
"same name and version already exists."
|
10388
10604
|
)
|
10605
|
+
elif producer_run_id and self._model_version_exists(
|
10606
|
+
model_id=model.id, producer_run_id=producer_run_id
|
10607
|
+
):
|
10608
|
+
raise RuntimeError(
|
10609
|
+
"Auto-incremented model version already exists for "
|
10610
|
+
f"producer run {producer_run_id}."
|
10611
|
+
)
|
10389
10612
|
elif remaining_tries == 0:
|
10390
10613
|
raise EntityCreationError(
|
10391
10614
|
f"Failed to create version for model "
|
@@ -10404,10 +10627,9 @@ class SqlZenStore(BaseZenStore):
|
|
10404
10627
|
)
|
10405
10628
|
logger.debug(
|
10406
10629
|
"Failed to create model version %s "
|
10407
|
-
"
|
10630
|
+
"due to an integrity error. "
|
10408
10631
|
"Retrying in %f seconds.",
|
10409
10632
|
model.name,
|
10410
|
-
model_version.number,
|
10411
10633
|
sleep_duration,
|
10412
10634
|
)
|
10413
10635
|
time.sleep(sleep_duration)
|
@@ -10422,6 +10644,20 @@ class SqlZenStore(BaseZenStore):
|
|
10422
10644
|
|
10423
10645
|
return self.get_model_version(model_version_id)
|
10424
10646
|
|
10647
|
+
@track_decorator(AnalyticsEvent.CREATED_MODEL_VERSION)
|
10648
|
+
def create_model_version(
|
10649
|
+
self, model_version: ModelVersionRequest
|
10650
|
+
) -> ModelVersionResponse:
|
10651
|
+
"""Creates a new model version.
|
10652
|
+
|
10653
|
+
Args:
|
10654
|
+
model_version: the Model Version to be created.
|
10655
|
+
|
10656
|
+
Returns:
|
10657
|
+
The newly created model version.
|
10658
|
+
"""
|
10659
|
+
return self._create_model_version(model_version=model_version)
|
10660
|
+
|
10425
10661
|
def get_model_version(
|
10426
10662
|
self, model_version_id: UUID, hydrate: bool = True
|
10427
10663
|
) -> ModelVersionResponse:
|