snowflake-ml-python 1.10.0__py3-none-any.whl → 1.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. snowflake/ml/experiment/callback/keras.py +63 -0
  2. snowflake/ml/experiment/callback/lightgbm.py +5 -1
  3. snowflake/ml/experiment/callback/xgboost.py +5 -1
  4. snowflake/ml/jobs/_utils/__init__.py +0 -0
  5. snowflake/ml/jobs/_utils/constants.py +4 -1
  6. snowflake/ml/jobs/_utils/payload_utils.py +42 -14
  7. snowflake/ml/jobs/_utils/query_helper.py +5 -1
  8. snowflake/ml/jobs/_utils/runtime_env_utils.py +63 -0
  9. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +2 -2
  10. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +3 -3
  11. snowflake/ml/jobs/_utils/spec_utils.py +41 -8
  12. snowflake/ml/jobs/_utils/stage_utils.py +22 -9
  13. snowflake/ml/jobs/_utils/types.py +5 -7
  14. snowflake/ml/jobs/job.py +1 -1
  15. snowflake/ml/jobs/manager.py +1 -13
  16. snowflake/ml/model/_client/model/model_version_impl.py +166 -10
  17. snowflake/ml/model/_client/ops/service_ops.py +63 -28
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +103 -27
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +11 -5
  20. snowflake/ml/model/_model_composer/model_composer.py +1 -70
  21. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +2 -43
  22. snowflake/ml/model/inference_engine.py +5 -0
  23. snowflake/ml/model/models/huggingface_pipeline.py +4 -3
  24. snowflake/ml/registry/_manager/model_manager.py +7 -35
  25. snowflake/ml/registry/_manager/model_parameter_reconciler.py +194 -5
  26. snowflake/ml/version.py +1 -1
  27. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.11.0.dist-info}/METADATA +23 -4
  28. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.11.0.dist-info}/RECORD +31 -27
  29. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.11.0.dist-info}/WHEEL +0 -0
  30. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.11.0.dist-info}/licenses/LICENSE.txt +0 -0
  31. {snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.11.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_client/service/model_deployment_spec_schema.py CHANGED
@@ -10,10 +10,15 @@ class Model(BaseModel):
     version: str
 
 
+class InferenceEngineSpec(BaseModel):
+    inference_engine_name: str
+    inference_engine_args: Optional[list[str]] = None
+
+
 class ImageBuild(BaseModel):
-    compute_pool: str
-    image_repo: str
-    force_rebuild: bool
+    compute_pool: Optional[str] = None
+    image_repo: Optional[str] = None
+    force_rebuild: Optional[bool] = None
     external_access_integrations: Optional[list[str]] = None
 
 
@@ -27,6 +32,7 @@ class Service(BaseModel):
     gpu: Optional[str] = None
     num_workers: Optional[int] = None
     max_batch_rows: Optional[int] = None
+    inference_engine_spec: Optional[InferenceEngineSpec] = None
 
 
 class Job(BaseModel):
@@ -68,13 +74,13 @@ class ModelLogging(BaseModel):
 
 class ModelServiceDeploymentSpec(BaseModel):
     models: list[Model]
-    image_build: ImageBuild
+    image_build: Optional[ImageBuild] = None
     service: Service
     model_loggings: Optional[list[ModelLogging]] = None
 
 
 class ModelJobDeploymentSpec(BaseModel):
     models: list[Model]
-    image_build: ImageBuild
+    image_build: Optional[ImageBuild] = None
     job: Job
     model_loggings: Optional[list[ModelLogging]] = None
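The new `InferenceEngineSpec` block above lets a deployment spec name the inference engine that serves the model, and it plugs into `Service` through the new optional `inference_engine_spec` field. A minimal sketch of how the new schema composes, assuming the models are imported from the module shown above (the engine arguments are hypothetical placeholders):

```python
from snowflake.ml.model._client.service.model_deployment_spec_schema import (
    InferenceEngineSpec,
)

# Engine name matches the InferenceEngine enum value introduced in this release ("vllm").
engine_spec = InferenceEngineSpec(
    inference_engine_name="vllm",
    inference_engine_args=["--max-model-len=4096"],  # hypothetical engine arguments
)

# Attaches to a Service via the new optional field:
#     Service(..., inference_engine_spec=engine_spec)
# Note that image_build is now optional on both ModelServiceDeploymentSpec and ModelJobDeploymentSpec.
```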
snowflake/ml/model/_model_composer/model_composer.py CHANGED
@@ -1,17 +1,12 @@
 import pathlib
 import tempfile
 import uuid
-import warnings
 from types import ModuleType
 from typing import TYPE_CHECKING, Any, Optional, Union
 from urllib import parse
 
-from absl import logging
-from packaging import requirements
-
 from snowflake import snowpark
-from snowflake.ml import version as snowml_version
-from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
+from snowflake.ml._internal import file_utils
 from snowflake.ml._internal.lineage import lineage_utils
 from snowflake.ml.data import data_source
 from snowflake.ml.model import model_signature, type_hints as model_types
@@ -19,7 +14,6 @@ from snowflake.ml.model._model_composer.model_manifest import model_manifest
 from snowflake.ml.model._packager import model_packager
 from snowflake.ml.model._packager.model_meta import model_meta
 from snowflake.snowpark import Session
-from snowflake.snowpark._internal import utils as snowpark_utils
 
 if TYPE_CHECKING:
     from snowflake.ml.experiment._experiment_info import ExperimentInfo
@@ -142,73 +136,10 @@ class ModelComposer:
         experiment_info: Optional["ExperimentInfo"] = None,
         options: Optional[model_types.ModelSaveOption] = None,
     ) -> model_meta.ModelMetadata:
-        # set enable_explainability=False if the model is not runnable in WH or the target platforms include SPCS
-        conda_dep_dict = env_utils.validate_conda_dependency_string_list(
-            conda_dependencies if conda_dependencies else []
-        )
-
-        enable_explainability = None
-
-        if options:
-            enable_explainability = options.get("enable_explainability", None)
-
-        # skip everything if user said False explicitly
-        if enable_explainability is None or enable_explainability is True:
-            is_warehouse_runnable = (
-                not conda_dep_dict
-                or all(
-                    chan == env_utils.DEFAULT_CHANNEL_NAME or chan == env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
-                    for chan in conda_dep_dict
-                )
-            ) and (not pip_requirements)
-
-            only_spcs = (
-                target_platforms
-                and len(target_platforms) == 1
-                and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
-            )
-            if only_spcs or (not is_warehouse_runnable):
-                # if only SPCS and user asked for explainability we fail
-                if enable_explainability is True:
-                    raise ValueError(
-                        "`enable_explainability` cannot be set to True when the model is not runnable in WH "
-                        "or the target platforms include SPCS."
-                    )
-                elif not options:  # explicitly set flag to false in these cases if not specified
-                    options = model_types.BaseModelSaveOption()
-                    options["enable_explainability"] = False
-            elif (
-                target_platforms
-                and len(target_platforms) > 1
-                and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
-            ):  # if both then only available for WH
-                if enable_explainability is True:
-                    warnings.warn(
-                        ("Explain function will only be available for model deployed to warehouse."),
-                        category=UserWarning,
-                        stacklevel=2,
-                    )
 
         if not options:
             options = model_types.BaseModelSaveOption()
 
-        if not snowpark_utils.is_in_stored_procedure() and target_platforms != [  # type: ignore[no-untyped-call]
-            model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES  # no information schema check for SPCS-only models
-        ]:
-            snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
-                self.session,
-                reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")],
-                python_version=python_version or snowml_env.PYTHON_VERSION,
-                statement_params=self._statement_params,
-            ).get(env_utils.SNOWPARK_ML_PKG_NAME, [])
-
-            if len(snowml_matched_versions) < 1 and options.get("embed_local_ml_library", False) is False:
-                logging.info(
-                    f"Local snowflake-ml-python library has version {snowml_version.VERSION},"
-                    " which is not available in the Snowflake server, embedding local ML library automatically."
-                )
-                options["embed_local_ml_library"] = True
-
         model_metadata: model_meta.ModelMetadata = self.packager.save(
             name=name,
             model=model,
snowflake/ml/model/_model_composer/model_manifest/model_manifest.py CHANGED
@@ -1,13 +1,11 @@
 import collections
 import logging
 import pathlib
-import warnings
 from typing import TYPE_CHECKING, Optional, cast
 
 import yaml
 
 from snowflake.ml._internal import env_utils
-from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml.data import data_source
 from snowflake.ml.model import type_hints
 from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
@@ -55,47 +53,8 @@ class ModelManifest:
         experiment_info: Optional["ExperimentInfo"] = None,
         target_platforms: Optional[list[type_hints.TargetPlatform]] = None,
     ) -> None:
-        if options is None:
-            options = {}
-
-        has_pip_requirements = len(model_meta.env.pip_requirements) > 0
-        only_spcs = (
-            target_platforms
-            and len(target_platforms) == 1
-            and target_platforms[0] == type_hints.TargetPlatform.SNOWPARK_CONTAINER_SERVICES
-        )
-
-        if "relax_version" not in options:
-            if has_pip_requirements or only_spcs:
-                logger.info(
-                    "Setting `relax_version=False` as this model will run in Snowpark Container Services "
-                    "or in Warehouse with a specified artifact_repository_map where exact version "
-                    " specifications will be honored."
-                )
-                relax_version = False
-            else:
-                warnings.warn(
-                    (
-                        "`relax_version` is not set and therefore defaulted to True. Dependency version constraints"
-                        " relaxed from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility,"
-                        " reproducibility, etc., set `options={'relax_version': False}` when logging the model."
-                    ),
-                    category=UserWarning,
-                    stacklevel=2,
-                )
-                relax_version = True
-            options["relax_version"] = relax_version
-        else:
-            relax_version = options.get("relax_version", True)
-            if relax_version and (has_pip_requirements or only_spcs):
-                raise exceptions.SnowflakeMLException(
-                    error_code=error_codes.INVALID_ARGUMENT,
-                    original_exception=ValueError(
-                        "Setting `relax_version=True` is only allowed for models to be run in Warehouse with "
-                        "Snowflake Conda Channel dependencies. It cannot be used with pip requirements or when "
-                        "targeting only Snowpark Container Services."
-                    ),
-                )
+        assert options is not None, "ModelParameterReconciler should have set options with relax_version"
+        relax_version = options["relax_version"]
 
         runtime_to_use = model_runtime.ModelRuntime(
             name=self._DEFAULT_RUNTIME_NAME,
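The `relax_version` defaulting and validation that used to live in `ModelManifest.save` now happens earlier, in `ModelParameterReconciler` (see `model_parameter_reconciler.py` below), so the manifest simply reads `options["relax_version"]`. The user-facing knob is unchanged; a hedged sketch of pinning exact dependency versions at log time (the registry, model, and version names are placeholders):

```python
from snowflake.ml.registry import Registry

reg = Registry(session=session)  # assumes an existing Snowpark session

# Keep exact ==x.y.z dependency pins instead of the relaxed >=x.y, <(x+1) default.
mv = reg.log_model(
    model,
    model_name="MY_MODEL",      # placeholder
    version_name="V1",          # placeholder
    options={"relax_version": False},
)
```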
snowflake/ml/model/inference_engine.py ADDED
@@ -0,0 +1,5 @@
+import enum
+
+
+class InferenceEngine(enum.Enum):
+    VLLM = "vllm"
snowflake/ml/model/models/huggingface_pipeline.py CHANGED
@@ -258,7 +258,7 @@ class HuggingFacePipelineModel:
         # model_version_impl.create_service parameters
         service_name: str,
         service_compute_pool: str,
-        image_repo: str,
+        image_repo: Optional[str] = None,
         image_build_compute_pool: Optional[str] = None,
         ingress_enabled: bool = False,
         max_instances: int = 1,
@@ -282,7 +282,8 @@ class HuggingFacePipelineModel:
             comment: Comment for the model. Defaults to None.
             service_name: The name of the service to create.
             service_compute_pool: The compute pool for the service.
-            image_repo: The name of the image repository.
+            image_repo: The name of the image repository. This can be None, in that case a default hidden image
+                repository will be used.
             image_build_compute_pool: The name of the compute pool used to build the model inference image. It uses
                 the service compute pool if None.
             ingress_enabled: Whether ingress is enabled. Defaults to False.
@@ -356,7 +357,7 @@ class HuggingFacePipelineModel:
                 else sql_identifier.SqlIdentifier(service_compute_pool)
             ),
             service_compute_pool_name=sql_identifier.SqlIdentifier(service_compute_pool),
-            image_repo=image_repo,
+            image_repo_name=image_repo,
             ingress_enabled=ingress_enabled,
             max_instances=max_instances,
             cpu_requests=cpu_requests,
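Per the 1.11.0 release notes later in this diff, the same relaxation applies to `ModelVersion.create_service()`: when `image_repo` is omitted, a default (hidden) image repository is used on server versions that support it. A hedged sketch (the service and compute pool names are placeholders):

```python
# `mv` is a ModelVersion obtained from the registry, e.g. reg.get_model("MY_MODEL").version("V1").
mv.create_service(
    service_name="MY_SERVICE",               # placeholder
    service_compute_pool="MY_COMPUTE_POOL",  # placeholder
    ingress_enabled=True,
    max_instances=1,
    # image_repo omitted: a default image repository is used where available.
)
```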
snowflake/ml/registry/_manager/model_manager.py CHANGED
@@ -4,15 +4,14 @@ from typing import TYPE_CHECKING, Any, Optional, Union
 import pandas as pd
 from absl.logging import logging
 
-from snowflake.ml._internal import env, platform_capabilities, telemetry
+from snowflake.ml._internal import platform_capabilities, telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml._internal.human_readable_id import hrid_generator
 from snowflake.ml._internal.utils import sql_identifier
-from snowflake.ml.model import model_signature, target_platform, task, type_hints
+from snowflake.ml.model import model_signature, task, type_hints
 from snowflake.ml.model._client.model import model_impl, model_version_impl
 from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
 from snowflake.ml.model._model_composer import model_composer
-from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
 from snowflake.ml.model._packager.model_meta import model_meta
 from snowflake.ml.registry._manager import model_parameter_reconciler
 from snowflake.snowpark import exceptions as snowpark_exceptions, session
@@ -221,37 +220,8 @@ class ModelManager:
             statement_params=statement_params,
         )
 
-        platforms = None
-        # User specified target platforms are defaulted to None and will not show up in the generated manifest.
-        if target_platforms:
-            # Convert any string target platforms to TargetPlatform objects
-            platforms = [type_hints.TargetPlatform(platform) for platform in target_platforms]
-        else:
-            # Default the target platform to warehouse if not specified and any table function exists
-            if options and (
-                options.get("function_type") == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
-                or (
-                    any(
-                        opt.get("function_type") == "TABLE_FUNCTION"
-                        for opt in options.get("method_options", {}).values()
-                    )
-                )
-            ):
-                logger.info(
-                    "Logging a partitioned model with a table function without specifying `target_platforms`. "
-                    'Default to `target_platforms=["WAREHOUSE"]`.'
-                )
-                platforms = [target_platform.TargetPlatform.WAREHOUSE]
-
-            # Default the target platform to SPCS if not specified when running in ML runtime
-            if not platforms and env.IN_ML_RUNTIME:
-                logger.info(
-                    "Logging the model on Container Runtime for ML without specifying `target_platforms`. "
-                    'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.'
-                )
-                platforms = [target_platform.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
-
         reconciler = model_parameter_reconciler.ModelParameterReconciler(
+            session=self._model_ops._session,
             database_name=self._database_name,
             schema_name=self._schema_name,
             conda_dependencies=conda_dependencies,
@@ -259,6 +229,8 @@
             target_platforms=target_platforms,
             artifact_repository_map=artifact_repository_map,
             options=options,
+            python_version=python_version,
+            statement_params=statement_params,
         )
 
         model_params = reconciler.reconcile()
@@ -293,12 +265,12 @@
             pip_requirements=pip_requirements,
             artifact_repository_map=artifact_repository_map,
             resource_constraint=resource_constraint,
-            target_platforms=platforms,
+            target_platforms=model_params.target_platforms,
             python_version=python_version,
             user_files=user_files,
             code_paths=code_paths,
             ext_modules=ext_modules,
-            options=options,
+            options=model_params.options,
             task=task,
             experiment_info=experiment_info,
         )
snowflake/ml/registry/_manager/model_parameter_reconciler.py CHANGED
@@ -1,9 +1,20 @@
 import warnings
 from dataclasses import dataclass
-from typing import Optional
+from typing import Any, Optional
 
+from absl.logging import logging
+from packaging import requirements
+
+from snowflake.ml import version as snowml_version
+from snowflake.ml._internal import env, env as snowml_env, env_utils
+from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml._internal.utils import sql_identifier
-from snowflake.ml.model import type_hints as model_types
+from snowflake.ml.model import target_platform, type_hints as model_types
+from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
+from snowflake.snowpark import Session
+from snowflake.snowpark._internal import utils as snowpark_utils
+
+logger = logging.getLogger(__name__)
 
 
 @dataclass
@@ -12,7 +23,7 @@ class ReconciledParameters:
 
     conda_dependencies: Optional[list[str]] = None
     pip_requirements: Optional[list[str]] = None
-    target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None
+    target_platforms: Optional[list[model_types.TargetPlatform]] = None
     artifact_repository_map: Optional[dict[str, str]] = None
     options: Optional[model_types.ModelSaveOption] = None
     save_location: Optional[str] = None
@@ -23,6 +34,7 @@ class ModelParameterReconciler:
 
     def __init__(
         self,
+        session: Session,
         database_name: sql_identifier.SqlIdentifier,
         schema_name: sql_identifier.SqlIdentifier,
         conda_dependencies: Optional[list[str]] = None,
@@ -30,7 +42,10 @@ class ModelParameterReconciler:
         target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
         artifact_repository_map: Optional[dict[str, str]] = None,
        options: Optional[model_types.ModelSaveOption] = None,
+        python_version: Optional[str] = None,
+        statement_params: Optional[dict[str, str]] = None,
     ) -> None:
+        self._session = session
         self._database_name = database_name
         self._schema_name = schema_name
         self._conda_dependencies = conda_dependencies
@@ -38,20 +53,27 @@ class ModelParameterReconciler:
         self._target_platforms = target_platforms
         self._artifact_repository_map = artifact_repository_map
         self._options = options
+        self._python_version = python_version
+        self._statement_params = statement_params
 
     def reconcile(self) -> ReconciledParameters:
         """Perform all parameter reconciliation and return clean parameters."""
+
         reconciled_artifact_repository_map = self._reconcile_artifact_repository_map()
         reconciled_save_location = self._extract_save_location()
 
         self._validate_pip_requirements_warehouse_compatibility(reconciled_artifact_repository_map)
 
+        reconciled_target_platforms = self._reconcile_target_platforms()
+        reconciled_options = self._reconcile_explainability_options(reconciled_target_platforms)
+        reconciled_options = self._reconcile_relax_version(reconciled_options, reconciled_target_platforms)
+
         return ReconciledParameters(
             conda_dependencies=self._conda_dependencies,
             pip_requirements=self._pip_requirements,
-            target_platforms=self._target_platforms,
+            target_platforms=reconciled_target_platforms,
             artifact_repository_map=reconciled_artifact_repository_map,
-            options=self._options,
+            options=reconciled_options,
             save_location=reconciled_save_location,
         )
 
@@ -82,6 +104,45 @@ class ModelParameterReconciler:
 
         return None
 
+    def _reconcile_target_platforms(self) -> Optional[list[model_types.TargetPlatform]]:
+        """Reconcile target platforms with proper defaulting logic."""
+        # User specified target platforms are defaulted to None and will not show up in the generated manifest.
+        if self._target_platforms:
+            # Convert any string target platforms to TargetPlatform objects
+            return [model_types.TargetPlatform(platform) for platform in self._target_platforms]
+
+        # Default the target platform to warehouse if not specified and any table function exists
+        if self._has_table_function():
+            logger.info(
+                "Logging a partitioned model with a table function without specifying `target_platforms`. "
+                'Default to `target_platforms=["WAREHOUSE"]`.'
+            )
+            return [target_platform.TargetPlatform.WAREHOUSE]
+
+        # Default the target platform to SPCS if not specified when running in ML runtime
+        if env.IN_ML_RUNTIME:
+            logger.info(
+                "Logging the model on Container Runtime for ML without specifying `target_platforms`. "
+                'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.'
+            )
+            return [target_platform.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
+
+        return None
+
+    def _has_table_function(self) -> bool:
+        """Check if any table function exists in options."""
+        if self._options is None:
+            return False
+
+        if self._options.get("function_type") == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
+            return True
+
+        for opt in self._options.get("method_options", {}).values():
+            if opt.get("function_type") == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value:
+                return True
+
+        return False
+
     def _validate_pip_requirements_warehouse_compatibility(
         self, artifact_repository_map: Optional[dict[str, str]]
     ) -> None:
@@ -103,3 +164,131 @@ class ModelParameterReconciler:
             or model_types.TargetPlatform.WAREHOUSE in target_platforms
             or "WAREHOUSE" in target_platforms
         )
+
+    def _reconcile_explainability_options(
+        self, target_platforms: Optional[list[model_types.TargetPlatform]]
+    ) -> model_types.ModelSaveOption:
+        """Reconcile explainability settings and embed_local_ml_library based on warehouse runnability."""
+        options = self._options.copy() if self._options else model_types.BaseModelSaveOption()
+
+        conda_dep_dict = env_utils.validate_conda_dependency_string_list(self._conda_dependencies or [])
+
+        enable_explainability = options.get("enable_explainability", None)
+
+        # Handle case where user explicitly disabled explainability
+        if enable_explainability is False:
+            return self._handle_embed_local_ml_library(options, target_platforms)
+
+        target_platform_set = set(target_platforms) if target_platforms else set()
+
+        is_warehouse_runnable = self._is_warehouse_runnable(conda_dep_dict)
+        only_spcs = target_platform_set == set(target_platform.SNOWPARK_CONTAINER_SERVICES_ONLY)
+        has_both_platforms = target_platform_set == set(target_platform.BOTH_WAREHOUSE_AND_SNOWPARK_CONTAINER_SERVICES)
+
+        # Handle case where user explicitly requested explainability
+        if enable_explainability:
+            if only_spcs or not is_warehouse_runnable:
+                raise ValueError(
+                    "`enable_explainability` cannot be set to True when the model is not runnable in WH "
+                    "or the target platforms include SPCS."
+                )
+            elif has_both_platforms:
+                warnings.warn(
+                    ("Explain function will only be available for model deployed to warehouse."),
+                    category=UserWarning,
+                    stacklevel=2,
+                )
+
+        # Handle case where explainability is not specified (None) - set default behavior
+        if enable_explainability is None:
+            if only_spcs or not is_warehouse_runnable:
+                options["enable_explainability"] = False
+
+        return self._handle_embed_local_ml_library(options, target_platforms)
+
+    def _handle_embed_local_ml_library(
+        self, options: model_types.ModelSaveOption, target_platforms: Optional[list[model_types.TargetPlatform]]
+    ) -> model_types.ModelSaveOption:
+        """Handle embed_local_ml_library logic."""
+        if not snowpark_utils.is_in_stored_procedure() and target_platforms != [  # type: ignore[no-untyped-call]
+            model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES  # no information schema check for SPCS-only models
+        ]:
+            snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
+                self._session,
+                reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")],
+                python_version=self._python_version or snowml_env.PYTHON_VERSION,
+                statement_params=self._statement_params,
+            ).get(env_utils.SNOWPARK_ML_PKG_NAME, [])
+
+            if len(snowml_matched_versions) < 1 and not options.get("embed_local_ml_library", False):
+                logging.info(
+                    f"Local snowflake-ml-python library has version {snowml_version.VERSION},"
+                    " which is not available in the Snowflake server, embedding local ML library automatically."
+                )
+                options["embed_local_ml_library"] = True
+
+        return options
+
+    def _is_warehouse_runnable(self, conda_dep_dict: dict[str, list[Any]]) -> bool:
+        """Check if model can run in warehouse based on conda channels and pip requirements."""
+        # If pip requirements are present but no artifact repository map, model cannot run in warehouse
+        if self._pip_requirements and not self._artifact_repository_map:
+            return False
+
+        # If no conda dependencies, model can run in warehouse
+        if not conda_dep_dict:
+            return True
+
+        # Check if all conda channels are warehouse-compatible
+        warehouse_compatible_channels = {env_utils.DEFAULT_CHANNEL_NAME, env_utils.SNOWFLAKE_CONDA_CHANNEL_URL}
+        for channel in conda_dep_dict:
+            if channel not in warehouse_compatible_channels:
+                return False
+
+        return True
+
+    def _reconcile_relax_version(
+        self,
+        options: model_types.ModelSaveOption,
+        target_platforms: Optional[list[model_types.TargetPlatform]],
+    ) -> model_types.ModelSaveOption:
+        """Reconcile relax_version setting based on pip requirements and target platforms."""
+        target_platform_set = set(target_platforms) if target_platforms else set()
+        has_pip_requirements = bool(self._pip_requirements)
+        only_spcs = target_platform_set == set(target_platform.SNOWPARK_CONTAINER_SERVICES_ONLY)
+
+        if "relax_version" not in options:
+            if has_pip_requirements or only_spcs:
+                logger.info(
+                    "Setting `relax_version=False` as this model will run in Snowpark Container Services "
+                    "or in Warehouse with a specified artifact_repository_map where exact version "
+                    " specifications will be honored."
+                )
+                relax_version = False
+            else:
+                warnings.warn(
+                    (
+                        "`relax_version` is not set and therefore defaulted to True. Dependency version constraints"
+                        " relaxed from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility,"
+                        " reproducibility, etc., set `options={'relax_version': False}` when logging the model."
+                    ),
+                    category=UserWarning,
+                    stacklevel=2,
+                )
                relax_version = True
+            options["relax_version"] = relax_version
+            return options
+
+        # Handle case where relax_version is already set
+        relax_version = options["relax_version"]
+        if relax_version and (has_pip_requirements or only_spcs):
+            raise exceptions.SnowflakeMLException(
+                error_code=error_codes.INVALID_ARGUMENT,
+                original_exception=ValueError(
+                    "Setting `relax_version=True` is only allowed for models to be run in Warehouse with "
+                    "Snowflake Conda Channel dependencies. It cannot be used with pip requirements or when "
+                    "targeting only Snowpark Container Services."
+                ),
+            )
+
+        return options
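Taken together, `ModelParameterReconciler` now centralizes the target-platform, explainability, `embed_local_ml_library`, and `relax_version` defaulting that previously lived in `ModelManager`, `ModelComposer`, and `ModelManifest`. A hedged sketch of the flow as wired up in `model_manager.py` above (argument values are placeholders, and the `pip_requirements` parameter, not visible in this hunk, is assumed unchanged):

```python
reconciler = model_parameter_reconciler.ModelParameterReconciler(
    session=session,  # Snowpark session, passed as self._model_ops._session in ModelManager
    database_name=sql_identifier.SqlIdentifier("ML"),    # placeholder
    schema_name=sql_identifier.SqlIdentifier("PUBLIC"),  # placeholder
    conda_dependencies=None,
    pip_requirements=None,
    target_platforms=None,  # None triggers the defaulting logic in _reconcile_target_platforms
    artifact_repository_map=None,
    options=None,
    python_version=None,
    statement_params=None,
)

params = reconciler.reconcile()
# params.target_platforms and params.options (including "relax_version") then feed the
# downstream ModelComposer.save(...) and ModelManifest.save(...) calls.
```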
snowflake/ml/version.py CHANGED
@@ -1,2 +1,2 @@
 # This is parsed by regex in conda recipe meta file. Make sure not to break it.
-VERSION = "1.10.0"
+VERSION = "1.11.0"
{snowflake_ml_python-1.10.0.dist-info → snowflake_ml_python-1.11.0.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: snowflake-ml-python
-Version: 1.10.0
+Version: 1.11.0
 Summary: The machine learning client library that is used for interacting with Snowflake to build machine learning solutions.
 Author-email: "Snowflake, Inc" <support@snowflake.com>
 License:
@@ -410,12 +410,32 @@ NOTE: Version 1.7.0 is used as example here. Please choose the the latest versio
 
 # Release History
 
-## 1.10.0
+## 1.11.0
 
 ### Bug Fixes
 
+* ML Job: Fix `Error: Unable to retrieve head IP address` if not all instances start within the timeout.
+* ML Job: Fix `TypeError: SnowflakeCursor.execute() got an unexpected keyword argument '_force_qmark_paramstyle'`
+  when running inside Stored Procedures.
+
+### Behavior Changes
+
+### New Features
+
+* `ModelVersion.create_service()`: Made `image_repo` argument optional. By
+  default it will use a default image repo, which is
+  being rolled out in server version 9.22+.
+* Experiment Tracking (PrPr): Automatically log the model, metrics, and parameters while training Keras models with
+  `snowflake.ml.experiment.callback.keras.SnowflakeKerasCallback`.
+
+## 1.10.0
+
 ### Behavior Changes
 
+* Experiment Tracking (PrPr): The import paths for the auto-logging callbacks have changed to
+  `snowflake.ml.experiment.callback.xgboost.SnowflakeXgboostCallback` and
+  `snowflake.ml.experiment.callback.lightgbm.SnowflakeLightgbmCallback`.
+
 ### New Features
 
 * Registry: add progress bars for `ModelVersion.create_service` and `ModelVersion.log_model`.
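The Keras callback named in the 1.11.0 notes follows the same pattern as the XGBoost and LightGBM callbacks shown in the example below; a hedged sketch, assuming its constructor mirrors those callbacks (the data, model, and signature variables are placeholders, as in the existing example):

```python
from snowflake.ml.experiment import ExperimentTracking
from snowflake.ml.experiment.callback.keras import SnowflakeKerasCallback

exp = ExperimentTracking(session=sp_session, database_name="ML", schema_name="PUBLIC")
exp.set_experiment("MY_EXPERIMENT")

# Assumed to mirror the XGBoost/LightGBM callbacks; the exact constructor arguments may differ.
callback = SnowflakeKerasCallback(
    exp, log_model=True, log_metrics=True, log_params=True, model_name="model_name", model_signature=sig
)

with exp.start_run():
    model.fit(X, y, validation_data=(X_test, y_test), callbacks=[callback])
```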
@@ -436,13 +456,13 @@ NOTE: Version 1.7.0 is used as example here. Please choose the the latest versio
 
 ```python
 from snowflake.ml.experiment import ExperimentTracking
+from snowflake.ml.experiment.callback import SnowflakeXgboostCallback, SnowflakeLightgbmCallback
 
 exp = ExperimentTracking(session=sp_session, database_name="ML", schema_name="PUBLIC")
 
 exp.set_experiment("MY_EXPERIMENT")
 
 # XGBoost
-from snowflake.ml.experiment.callback.xgboost import SnowflakeXgboostCallback
 callback = SnowflakeXgboostCallback(
     exp, log_model=True, log_metrics=True, log_params=True, model_name="model_name", model_signature=sig
 )
@@ -451,7 +471,6 @@ with exp.start_run():
     model.fit(X, y, eval_set=[(X_test, y_test)])
 
 # LightGBM
-from snowflake.ml.experiment.callback.lightgbm import SnowflakeLightgbmCallback
 callback = SnowflakeLightgbmCallback(
     exp, log_model=True, log_metrics=True, log_params=True, model_name="model_name", model_signature=sig
 )