snowflake-ml-python 1.9.1__py3-none-any.whl → 1.10.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/utils/mixins.py +6 -4
- snowflake/ml/_internal/utils/service_logger.py +118 -4
- snowflake/ml/data/_internal/arrow_ingestor.py +4 -1
- snowflake/ml/data/data_connector.py +4 -34
- snowflake/ml/dataset/dataset.py +1 -1
- snowflake/ml/dataset/dataset_reader.py +2 -8
- snowflake/ml/experiment/__init__.py +3 -0
- snowflake/ml/experiment/callback/lightgbm.py +55 -0
- snowflake/ml/experiment/callback/xgboost.py +63 -0
- snowflake/ml/experiment/utils.py +14 -0
- snowflake/ml/jobs/_utils/constants.py +15 -4
- snowflake/ml/jobs/_utils/payload_utils.py +159 -52
- snowflake/ml/jobs/_utils/scripts/constants.py +0 -22
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +126 -23
- snowflake/ml/jobs/_utils/spec_utils.py +1 -1
- snowflake/ml/jobs/_utils/stage_utils.py +30 -14
- snowflake/ml/jobs/_utils/types.py +64 -4
- snowflake/ml/jobs/job.py +22 -6
- snowflake/ml/jobs/manager.py +5 -3
- snowflake/ml/model/_client/model/model_version_impl.py +56 -48
- snowflake/ml/model/_client/ops/service_ops.py +194 -14
- snowflake/ml/model/_client/sql/service.py +1 -38
- snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -5
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -0
- snowflake/ml/model/_signatures/pandas_handler.py +3 -0
- snowflake/ml/model/_signatures/utils.py +4 -0
- snowflake/ml/model/event_handler.py +87 -18
- snowflake/ml/model/model_signature.py +2 -0
- snowflake/ml/model/models/huggingface_pipeline.py +71 -49
- snowflake/ml/model/type_hints.py +26 -1
- snowflake/ml/registry/_manager/model_manager.py +30 -35
- snowflake/ml/registry/_manager/model_parameter_reconciler.py +105 -0
- snowflake/ml/registry/registry.py +0 -19
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/METADATA +542 -491
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/RECORD +39 -34
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.9.1.dist-info → snowflake_ml_python-1.10.0.dist-info}/top_level.txt +0 -0
|
@@ -299,6 +299,7 @@ class HuggingFacePipelineModel:
|
|
|
299
299
|
Raises:
|
|
300
300
|
ValueError: if database and schema name is not provided and session doesn't have a
|
|
301
301
|
database and schema name.
|
|
302
|
+
exceptions.SnowparkSQLException: if service already exists.
|
|
302
303
|
|
|
303
304
|
Returns:
|
|
304
305
|
The service ID or an async job object.
|
|
@@ -327,7 +328,6 @@ class HuggingFacePipelineModel:
|
|
|
327
328
|
version_name = name_generator.generate()[1]
|
|
328
329
|
|
|
329
330
|
service_db_id, service_schema_id, service_id = sql_identifier.parse_fully_qualified_name(service_name)
|
|
330
|
-
image_repo_db_id, image_repo_schema_id, image_repo_id = sql_identifier.parse_fully_qualified_name(image_repo)
|
|
331
331
|
|
|
332
332
|
service_operator = service_ops.ServiceOperator(
|
|
333
333
|
session=session,
|
|
@@ -336,51 +336,73 @@ class HuggingFacePipelineModel:
|
|
|
336
336
|
)
|
|
337
337
|
logger.info(f"A service job is going to register the hf model as: {model_name}.{version_name}")
|
|
338
338
|
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
339
|
+
from snowflake.ml.model import event_handler
|
|
340
|
+
from snowflake.snowpark import exceptions
|
|
341
|
+
|
|
342
|
+
hf_event_handler = event_handler.ModelEventHandler()
|
|
343
|
+
with hf_event_handler.status("Creating HuggingFace model service", total=6, block=block) as status:
|
|
344
|
+
try:
|
|
345
|
+
result = service_operator.create_service(
|
|
346
|
+
database_name=database_name_id,
|
|
347
|
+
schema_name=schema_name_id,
|
|
348
|
+
model_name=model_name_id,
|
|
349
|
+
version_name=sql_identifier.SqlIdentifier(version_name),
|
|
350
|
+
service_database_name=service_db_id,
|
|
351
|
+
service_schema_name=service_schema_id,
|
|
352
|
+
service_name=service_id,
|
|
353
|
+
image_build_compute_pool_name=(
|
|
354
|
+
sql_identifier.SqlIdentifier(image_build_compute_pool)
|
|
355
|
+
if image_build_compute_pool
|
|
356
|
+
else sql_identifier.SqlIdentifier(service_compute_pool)
|
|
357
|
+
),
|
|
358
|
+
service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
|
|
359
|
+
image_repo=image_repo,
|
|
360
|
+
ingress_enabled=ingress_enabled,
|
|
361
|
+
max_instances=max_instances,
|
|
362
|
+
cpu_requests=cpu_requests,
|
|
363
|
+
memory_requests=memory_requests,
|
|
364
|
+
gpu_requests=gpu_requests,
|
|
365
|
+
num_workers=num_workers,
|
|
366
|
+
max_batch_rows=max_batch_rows,
|
|
367
|
+
force_rebuild=force_rebuild,
|
|
368
|
+
build_external_access_integrations=(
|
|
369
|
+
None
|
|
370
|
+
if build_external_access_integrations is None
|
|
371
|
+
else [sql_identifier.SqlIdentifier(eai) for eai in build_external_access_integrations]
|
|
372
|
+
),
|
|
373
|
+
block=block,
|
|
374
|
+
progress_status=status,
|
|
375
|
+
statement_params=statement_params,
|
|
376
|
+
# hf model
|
|
377
|
+
hf_model_args=service_ops.HFModelArgs(
|
|
378
|
+
hf_model_name=self.model,
|
|
379
|
+
hf_task=self.task,
|
|
380
|
+
hf_tokenizer=self.tokenizer,
|
|
381
|
+
hf_revision=self.revision,
|
|
382
|
+
hf_token=self.token,
|
|
383
|
+
hf_trust_remote_code=bool(self.trust_remote_code),
|
|
384
|
+
hf_model_kwargs=self.model_kwargs,
|
|
385
|
+
pip_requirements=pip_requirements,
|
|
386
|
+
conda_dependencies=conda_dependencies,
|
|
387
|
+
comment=comment,
|
|
388
|
+
# TODO: remove warehouse in the next release
|
|
389
|
+
warehouse=session.get_current_warehouse(),
|
|
390
|
+
),
|
|
391
|
+
)
|
|
392
|
+
status.update(label="HuggingFace model service created successfully", state="complete", expanded=False)
|
|
393
|
+
return result
|
|
394
|
+
except exceptions.SnowparkSQLException as e:
|
|
395
|
+
# Check if the error is because the service already exists
|
|
396
|
+
if "already exists" in str(e).lower() or "100132" in str(
|
|
397
|
+
e
|
|
398
|
+
): # 100132 is Snowflake error code for object already exists
|
|
399
|
+
# Update progress to show service already exists (preserve exception behavior)
|
|
400
|
+
status.update("service already exists")
|
|
401
|
+
status.complete() # Complete progress to full state
|
|
402
|
+
status.update(label="Service already exists", state="error", expanded=False)
|
|
403
|
+
# Re-raise the exception to preserve existing API behavior
|
|
404
|
+
raise
|
|
405
|
+
else:
|
|
406
|
+
# Re-raise other SQL exceptions
|
|
407
|
+
status.update(label="Service creation failed", state="error", expanded=False)
|
|
408
|
+
raise
|
snowflake/ml/model/type_hints.py
CHANGED
|
@@ -1,5 +1,14 @@
|
|
|
1
1
|
# mypy: disable-error-code="import"
|
|
2
|
-
from typing import
|
|
2
|
+
from typing import (
|
|
3
|
+
TYPE_CHECKING,
|
|
4
|
+
Any,
|
|
5
|
+
Literal,
|
|
6
|
+
Protocol,
|
|
7
|
+
Sequence,
|
|
8
|
+
TypedDict,
|
|
9
|
+
TypeVar,
|
|
10
|
+
Union,
|
|
11
|
+
)
|
|
3
12
|
|
|
4
13
|
import numpy.typing as npt
|
|
5
14
|
from typing_extensions import NotRequired
|
|
@@ -326,4 +335,20 @@ ModelLoadOption = Union[
|
|
|
326
335
|
SupportedTargetPlatformType = Union[TargetPlatform, str]
|
|
327
336
|
|
|
328
337
|
|
|
338
|
+
class ProgressStatus(Protocol):
|
|
339
|
+
"""Protocol for tracking progress during long-running operations."""
|
|
340
|
+
|
|
341
|
+
def update(self, message: str, *, state: str = "running", expanded: bool = True, **kwargs: Any) -> None:
|
|
342
|
+
"""Update the progress status with a new message."""
|
|
343
|
+
...
|
|
344
|
+
|
|
345
|
+
def increment(self) -> None:
|
|
346
|
+
"""Increment the progress by one step."""
|
|
347
|
+
...
|
|
348
|
+
|
|
349
|
+
def complete(self) -> None:
|
|
350
|
+
"""Complete the progress bar to full state."""
|
|
351
|
+
...
|
|
352
|
+
|
|
353
|
+
|
|
329
354
|
__all__ = ["TargetPlatform", "Task"]
|
|
@@ -14,6 +14,7 @@ from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
|
|
|
14
14
|
from snowflake.ml.model._model_composer import model_composer
|
|
15
15
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
|
16
16
|
from snowflake.ml.model._packager.model_meta import model_meta
|
|
17
|
+
from snowflake.ml.registry._manager import model_parameter_reconciler
|
|
17
18
|
from snowflake.snowpark import exceptions as snowpark_exceptions, session
|
|
18
19
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
|
19
20
|
|
|
@@ -46,6 +47,7 @@ class ModelManager:
|
|
|
46
47
|
*,
|
|
47
48
|
model: Union[type_hints.SupportedModelType, model_version_impl.ModelVersion],
|
|
48
49
|
model_name: str,
|
|
50
|
+
progress_status: type_hints.ProgressStatus,
|
|
49
51
|
version_name: Optional[str] = None,
|
|
50
52
|
comment: Optional[str] = None,
|
|
51
53
|
metrics: Optional[dict[str, Any]] = None,
|
|
@@ -64,7 +66,6 @@ class ModelManager:
|
|
|
64
66
|
experiment_info: Optional["ExperimentInfo"] = None,
|
|
65
67
|
options: Optional[type_hints.ModelSaveOption] = None,
|
|
66
68
|
statement_params: Optional[dict[str, Any]] = None,
|
|
67
|
-
progress_status: Optional[Any] = None,
|
|
68
69
|
) -> model_version_impl.ModelVersion:
|
|
69
70
|
|
|
70
71
|
database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
|
|
@@ -158,6 +159,7 @@ class ModelManager:
|
|
|
158
159
|
*,
|
|
159
160
|
model_name: str,
|
|
160
161
|
version_name: str,
|
|
162
|
+
progress_status: type_hints.ProgressStatus,
|
|
161
163
|
comment: Optional[str] = None,
|
|
162
164
|
metrics: Optional[dict[str, Any]] = None,
|
|
163
165
|
conda_dependencies: Optional[list[str]] = None,
|
|
@@ -175,7 +177,6 @@ class ModelManager:
|
|
|
175
177
|
experiment_info: Optional["ExperimentInfo"] = None,
|
|
176
178
|
options: Optional[type_hints.ModelSaveOption] = None,
|
|
177
179
|
statement_params: Optional[dict[str, Any]] = None,
|
|
178
|
-
progress_status: Optional[Any] = None,
|
|
179
180
|
) -> model_version_impl.ModelVersion:
|
|
180
181
|
database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
|
|
181
182
|
version_name_id = sql_identifier.SqlIdentifier(version_name)
|
|
@@ -250,27 +251,27 @@ class ModelManager:
|
|
|
250
251
|
)
|
|
251
252
|
platforms = [target_platform.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
|
|
252
253
|
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
254
|
+
reconciler = model_parameter_reconciler.ModelParameterReconciler(
|
|
255
|
+
database_name=self._database_name,
|
|
256
|
+
schema_name=self._schema_name,
|
|
257
|
+
conda_dependencies=conda_dependencies,
|
|
258
|
+
pip_requirements=pip_requirements,
|
|
259
|
+
target_platforms=target_platforms,
|
|
260
|
+
artifact_repository_map=artifact_repository_map,
|
|
261
|
+
options=options,
|
|
262
|
+
)
|
|
256
263
|
|
|
257
|
-
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
|
|
261
|
-
|
|
262
|
-
self._schema_name,
|
|
263
|
-
)
|
|
264
|
+
model_params = reconciler.reconcile()
|
|
265
|
+
|
|
266
|
+
# Use reconciled parameters
|
|
267
|
+
artifact_repository_map = model_params.artifact_repository_map
|
|
268
|
+
save_location = model_params.save_location
|
|
264
269
|
|
|
265
270
|
logger.info("Start packaging and uploading your model. It might take some time based on the size of the model.")
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
# Extract save_location from options if present
|
|
271
|
-
save_location = None
|
|
272
|
-
if options and "save_location" in options:
|
|
273
|
-
save_location = options.get("save_location")
|
|
271
|
+
progress_status.update("packaging model...")
|
|
272
|
+
progress_status.increment()
|
|
273
|
+
|
|
274
|
+
if save_location:
|
|
274
275
|
logger.info(f"Model will be saved to local directory: {save_location}")
|
|
275
276
|
|
|
276
277
|
mc = model_composer.ModelComposer(
|
|
@@ -280,9 +281,8 @@ class ModelManager:
|
|
|
280
281
|
save_location=save_location,
|
|
281
282
|
)
|
|
282
283
|
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
progress_status.increment()
|
|
284
|
+
progress_status.update("creating model manifest...")
|
|
285
|
+
progress_status.increment()
|
|
286
286
|
|
|
287
287
|
model_metadata: model_meta.ModelMetadata = mc.save(
|
|
288
288
|
name=model_name_id.resolved(),
|
|
@@ -303,9 +303,8 @@ class ModelManager:
|
|
|
303
303
|
experiment_info=experiment_info,
|
|
304
304
|
)
|
|
305
305
|
|
|
306
|
-
|
|
307
|
-
|
|
308
|
-
progress_status.increment()
|
|
306
|
+
progress_status.update("uploading model files...")
|
|
307
|
+
progress_status.increment()
|
|
309
308
|
statement_params = telemetry.add_statement_params_custom_tags(
|
|
310
309
|
statement_params, model_metadata.telemetry_metadata()
|
|
311
310
|
)
|
|
@@ -313,10 +312,8 @@ class ModelManager:
|
|
|
313
312
|
statement_params, {"model_version_name": version_name_id}
|
|
314
313
|
)
|
|
315
314
|
|
|
316
|
-
|
|
317
|
-
|
|
318
|
-
progress_status.update("creating model object in Snowflake...")
|
|
319
|
-
progress_status.increment()
|
|
315
|
+
progress_status.update("creating model object in Snowflake...")
|
|
316
|
+
progress_status.increment()
|
|
320
317
|
|
|
321
318
|
self._model_ops.create_from_stage(
|
|
322
319
|
composed_model=mc,
|
|
@@ -343,9 +340,8 @@ class ModelManager:
|
|
|
343
340
|
version_name=version_name_id,
|
|
344
341
|
)
|
|
345
342
|
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
progress_status.increment()
|
|
343
|
+
progress_status.update("setting model metadata...")
|
|
344
|
+
progress_status.increment()
|
|
349
345
|
|
|
350
346
|
if comment:
|
|
351
347
|
mv.comment = comment
|
|
@@ -360,8 +356,7 @@ class ModelManager:
|
|
|
360
356
|
statement_params=statement_params,
|
|
361
357
|
)
|
|
362
358
|
|
|
363
|
-
|
|
364
|
-
progress_status.update("model logged successfully!")
|
|
359
|
+
progress_status.update("model logged successfully!")
|
|
365
360
|
|
|
366
361
|
return mv
|
|
367
362
|
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
import warnings
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import Optional
|
|
4
|
+
|
|
5
|
+
from snowflake.ml._internal.utils import sql_identifier
|
|
6
|
+
from snowflake.ml.model import type_hints as model_types
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
|
|
10
|
+
class ReconciledParameters:
|
|
11
|
+
"""Holds the reconciled and validated parameters after processing."""
|
|
12
|
+
|
|
13
|
+
conda_dependencies: Optional[list[str]] = None
|
|
14
|
+
pip_requirements: Optional[list[str]] = None
|
|
15
|
+
target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None
|
|
16
|
+
artifact_repository_map: Optional[dict[str, str]] = None
|
|
17
|
+
options: Optional[model_types.ModelSaveOption] = None
|
|
18
|
+
save_location: Optional[str] = None
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class ModelParameterReconciler:
|
|
22
|
+
"""Centralizes all complex log_model parameter validation, transformation, and reconciliation logic."""
|
|
23
|
+
|
|
24
|
+
def __init__(
|
|
25
|
+
self,
|
|
26
|
+
database_name: sql_identifier.SqlIdentifier,
|
|
27
|
+
schema_name: sql_identifier.SqlIdentifier,
|
|
28
|
+
conda_dependencies: Optional[list[str]] = None,
|
|
29
|
+
pip_requirements: Optional[list[str]] = None,
|
|
30
|
+
target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
|
|
31
|
+
artifact_repository_map: Optional[dict[str, str]] = None,
|
|
32
|
+
options: Optional[model_types.ModelSaveOption] = None,
|
|
33
|
+
) -> None:
|
|
34
|
+
self._database_name = database_name
|
|
35
|
+
self._schema_name = schema_name
|
|
36
|
+
self._conda_dependencies = conda_dependencies
|
|
37
|
+
self._pip_requirements = pip_requirements
|
|
38
|
+
self._target_platforms = target_platforms
|
|
39
|
+
self._artifact_repository_map = artifact_repository_map
|
|
40
|
+
self._options = options
|
|
41
|
+
|
|
42
|
+
def reconcile(self) -> ReconciledParameters:
|
|
43
|
+
"""Perform all parameter reconciliation and return clean parameters."""
|
|
44
|
+
reconciled_artifact_repository_map = self._reconcile_artifact_repository_map()
|
|
45
|
+
reconciled_save_location = self._extract_save_location()
|
|
46
|
+
|
|
47
|
+
self._validate_pip_requirements_warehouse_compatibility(reconciled_artifact_repository_map)
|
|
48
|
+
|
|
49
|
+
return ReconciledParameters(
|
|
50
|
+
conda_dependencies=self._conda_dependencies,
|
|
51
|
+
pip_requirements=self._pip_requirements,
|
|
52
|
+
target_platforms=self._target_platforms,
|
|
53
|
+
artifact_repository_map=reconciled_artifact_repository_map,
|
|
54
|
+
options=self._options,
|
|
55
|
+
save_location=reconciled_save_location,
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
def _reconcile_artifact_repository_map(self) -> Optional[dict[str, str]]:
|
|
59
|
+
"""Transform artifact_repository_map to use fully qualified names."""
|
|
60
|
+
if not self._artifact_repository_map:
|
|
61
|
+
return None
|
|
62
|
+
|
|
63
|
+
transformed_map = {}
|
|
64
|
+
|
|
65
|
+
for channel, artifact_repository_name in self._artifact_repository_map.items():
|
|
66
|
+
db_id, schema_id, repo_id = sql_identifier.parse_fully_qualified_name(artifact_repository_name)
|
|
67
|
+
|
|
68
|
+
transformed_map[channel] = sql_identifier.get_fully_qualified_name(
|
|
69
|
+
db_id,
|
|
70
|
+
schema_id,
|
|
71
|
+
repo_id,
|
|
72
|
+
self._database_name,
|
|
73
|
+
self._schema_name,
|
|
74
|
+
)
|
|
75
|
+
|
|
76
|
+
return transformed_map
|
|
77
|
+
|
|
78
|
+
def _extract_save_location(self) -> Optional[str]:
|
|
79
|
+
"""Extract save_location from options."""
|
|
80
|
+
if self._options and "save_location" in self._options:
|
|
81
|
+
return self._options.get("save_location")
|
|
82
|
+
|
|
83
|
+
return None
|
|
84
|
+
|
|
85
|
+
def _validate_pip_requirements_warehouse_compatibility(
|
|
86
|
+
self, artifact_repository_map: Optional[dict[str, str]]
|
|
87
|
+
) -> None:
|
|
88
|
+
"""Validate pip_requirements compatibility with warehouse deployment."""
|
|
89
|
+
if self._pip_requirements and not artifact_repository_map and self._targets_warehouse(self._target_platforms):
|
|
90
|
+
warnings.warn(
|
|
91
|
+
"Models logged specifying `pip_requirements` cannot be executed in a Snowflake Warehouse "
|
|
92
|
+
"without specifying `artifact_repository_map`. This model can be run in Snowpark Container "
|
|
93
|
+
"Services. See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container.",
|
|
94
|
+
category=UserWarning,
|
|
95
|
+
stacklevel=1,
|
|
96
|
+
)
|
|
97
|
+
|
|
98
|
+
@staticmethod
|
|
99
|
+
def _targets_warehouse(target_platforms: Optional[list[model_types.SupportedTargetPlatformType]]) -> bool:
|
|
100
|
+
"""Returns True if warehouse is a target platform (None defaults to True)."""
|
|
101
|
+
return (
|
|
102
|
+
target_platforms is None
|
|
103
|
+
or model_types.TargetPlatform.WAREHOUSE in target_platforms
|
|
104
|
+
or "WAREHOUSE" in target_platforms
|
|
105
|
+
)
|
|
@@ -1,4 +1,3 @@
|
|
|
1
|
-
import warnings
|
|
2
1
|
from types import ModuleType
|
|
3
2
|
from typing import Any, Optional, Union, overload
|
|
4
3
|
|
|
@@ -442,15 +441,6 @@ class Registry:
|
|
|
442
441
|
if task is not type_hints.Task.UNKNOWN:
|
|
443
442
|
raise ValueError("`task` cannot be specified when calling log_model with a ModelVersion.")
|
|
444
443
|
|
|
445
|
-
if pip_requirements and not artifact_repository_map and self._targets_warehouse(target_platforms):
|
|
446
|
-
warnings.warn(
|
|
447
|
-
"Models logged specifying `pip_requirements` cannot be executed in a Snowflake Warehouse "
|
|
448
|
-
"without specifying `artifact_repository_map`. This model can be run in Snowpark Container "
|
|
449
|
-
"Services. See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container.",
|
|
450
|
-
category=UserWarning,
|
|
451
|
-
stacklevel=1,
|
|
452
|
-
)
|
|
453
|
-
|
|
454
444
|
registry_event_handler = event_handler.ModelEventHandler()
|
|
455
445
|
with registry_event_handler.status("Logging model", total=6) as status:
|
|
456
446
|
# Step 1: Validation and setup
|
|
@@ -662,12 +652,3 @@ class Registry:
|
|
|
662
652
|
if not self.enable_monitoring:
|
|
663
653
|
raise ValueError(_MODEL_MONITORING_DISABLED_ERROR)
|
|
664
654
|
self._model_monitor_manager.delete_monitor(name)
|
|
665
|
-
|
|
666
|
-
@staticmethod
|
|
667
|
-
def _targets_warehouse(target_platforms: Optional[list[type_hints.SupportedTargetPlatformType]]) -> bool:
|
|
668
|
-
"""Returns True if warehouse is a target platform (None defaults to True)."""
|
|
669
|
-
return (
|
|
670
|
-
target_platforms is None
|
|
671
|
-
or type_hints.TargetPlatform.WAREHOUSE in target_platforms
|
|
672
|
-
or "WAREHOUSE" in target_platforms
|
|
673
|
-
)
|
snowflake/ml/version.py
CHANGED
|
@@ -1,2 +1,2 @@
|
|
|
1
1
|
# This is parsed by regex in conda recipe meta file. Make sure not to break it.
|
|
2
|
-
VERSION = "1.9.1"
|
|
2
|
+
VERSION = "1.10.0"
|