snowflake-ml-python 1.8.6__py3-none-any.whl → 1.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. snowflake/ml/_internal/utils/identifier.py +1 -1
  2. snowflake/ml/_internal/utils/mixins.py +61 -0
  3. snowflake/ml/jobs/_utils/constants.py +1 -1
  4. snowflake/ml/jobs/_utils/interop_utils.py +63 -4
  5. snowflake/ml/jobs/_utils/payload_utils.py +6 -5
  6. snowflake/ml/jobs/_utils/query_helper.py +9 -0
  7. snowflake/ml/jobs/_utils/spec_utils.py +6 -4
  8. snowflake/ml/jobs/decorators.py +18 -25
  9. snowflake/ml/jobs/job.py +179 -58
  10. snowflake/ml/jobs/manager.py +194 -145
  11. snowflake/ml/model/_client/ops/model_ops.py +12 -3
  12. snowflake/ml/model/_client/ops/service_ops.py +4 -2
  13. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -0
  14. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -10
  15. snowflake/ml/model/_packager/model_env/model_env.py +35 -27
  16. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
  17. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
  18. snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
  19. snowflake/ml/model/target_platform.py +11 -0
  20. snowflake/ml/model/task.py +9 -0
  21. snowflake/ml/model/type_hints.py +5 -13
  22. snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
  23. snowflake/ml/registry/_manager/model_manager.py +30 -15
  24. snowflake/ml/registry/registry.py +119 -42
  25. snowflake/ml/version.py +1 -1
  26. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/METADATA +52 -16
  27. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/RECORD +30 -26
  28. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/WHEEL +0 -0
  29. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/licenses/LICENSE.txt +0 -0
  30. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.0.dist-info}/top_level.txt +0 -0
@@ -7,6 +7,7 @@ from typing import Optional, cast
7
7
  import yaml
8
8
 
9
9
  from snowflake.ml._internal import env_utils
10
+ from snowflake.ml._internal.exceptions import error_codes, exceptions
10
11
  from snowflake.ml.data import data_source
11
12
  from snowflake.ml.model import type_hints
12
13
  from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
@@ -53,17 +54,44 @@ class ModelManifest:
53
54
  if options is None:
54
55
  options = {}
55
56
 
57
+ has_pip_requirements = len(model_meta.env.pip_requirements) > 0
58
+ only_spcs = (
59
+ target_platforms
60
+ and len(target_platforms) == 1
61
+ and target_platforms[0] == type_hints.TargetPlatform.SNOWPARK_CONTAINER_SERVICES
62
+ )
63
+
56
64
  if "relax_version" not in options:
57
- warnings.warn(
58
- (
59
- "`relax_version` is not set and therefore defaulted to True. Dependency version constraints relaxed"
60
- " from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility, "
61
- "reproducibility, etc., set `options={'relax_version': False}` when logging the model."
62
- ),
63
- category=UserWarning,
64
- stacklevel=2,
65
- )
66
- relax_version = options.get("relax_version", True)
65
+ if has_pip_requirements or only_spcs:
66
+ logger.info(
67
+ "Setting `relax_version=False` as this model will run in Snowpark Container Services "
68
+ "or in Warehouse with a specified artifact_repository_map where exact version "
69
+ " specifications will be honored."
70
+ )
71
+ relax_version = False
72
+ else:
73
+ warnings.warn(
74
+ (
75
+ "`relax_version` is not set and therefore defaulted to True. Dependency version constraints"
76
+ " relaxed from ==x.y.z to >=x.y, <(x+1). To use specific dependency versions for compatibility,"
77
+ " reproducibility, etc., set `options={'relax_version': False}` when logging the model."
78
+ ),
79
+ category=UserWarning,
80
+ stacklevel=2,
81
+ )
82
+ relax_version = True
83
+ options["relax_version"] = relax_version
84
+ else:
85
+ relax_version = options.get("relax_version", True)
86
+ if relax_version and (has_pip_requirements or only_spcs):
87
+ raise exceptions.SnowflakeMLException(
88
+ error_code=error_codes.INVALID_ARGUMENT,
89
+ original_exception=ValueError(
90
+ "Setting `relax_version=True` is only allowed for models to be run in Warehouse with "
91
+ "Snowflake Conda Channel dependencies. It cannot be used with pip requirements or when "
92
+ "targeting only Snowpark Container Services."
93
+ ),
94
+ )
67
95
 
68
96
  runtime_to_use = model_runtime.ModelRuntime(
69
97
  name=self._DEFAULT_RUNTIME_NAME,
@@ -9,6 +9,7 @@ from packaging import requirements, version
9
9
 
10
10
  from snowflake.ml import version as snowml_version
11
11
  from snowflake.ml._internal import env as snowml_env, env_utils
12
+ from snowflake.ml.model import type_hints as model_types
12
13
  from snowflake.ml.model._packager.model_meta import model_meta_schema
13
14
 
14
15
  # requirement: Full version requirement where name is conda package name.
@@ -30,6 +31,7 @@ class ModelEnv:
30
31
  conda_env_rel_path: Optional[str] = None,
31
32
  pip_requirements_rel_path: Optional[str] = None,
32
33
  prefer_pip: bool = False,
34
+ target_platforms: Optional[list[model_types.TargetPlatform]] = None,
33
35
  ) -> None:
34
36
  if conda_env_rel_path is None:
35
37
  conda_env_rel_path = os.path.join(_DEFAULT_ENV_DIR, _DEFAULT_CONDA_ENV_FILENAME)
@@ -45,6 +47,8 @@ class ModelEnv:
45
47
  self._python_version: version.Version = version.parse(snowml_env.PYTHON_VERSION)
46
48
  self._cuda_version: Optional[version.Version] = None
47
49
  self._snowpark_ml_version: version.Version = version.parse(snowml_version.VERSION)
50
+ self._target_platforms = target_platforms
51
+ self._warnings_shown: set[str] = set()
48
52
 
49
53
  @property
50
54
  def conda_dependencies(self) -> list[str]:
@@ -116,6 +120,17 @@ class ModelEnv:
116
120
  if snowpark_ml_version:
117
121
  self._snowpark_ml_version = version.parse(snowpark_ml_version)
118
122
 
123
+ @property
124
+ def targets_warehouse(self) -> bool:
125
+ """Returns True if warehouse is a target platform."""
126
+ return self._target_platforms is None or model_types.TargetPlatform.WAREHOUSE in self._target_platforms
127
+
128
+ def _warn_once(self, message: str, stacklevel: int = 2) -> None:
129
+ """Show warning only once per ModelEnv instance."""
130
+ if message not in self._warnings_shown:
131
+ warnings.warn(message, category=UserWarning, stacklevel=stacklevel)
132
+ self._warnings_shown.add(message)
133
+
119
134
  def include_if_absent(
120
135
  self,
121
136
  pkgs: list[ModelDependency],
@@ -130,14 +145,14 @@ class ModelEnv:
130
145
  """
131
146
  if (self.pip_requirements or self.prefer_pip) and not self.conda_dependencies and pkgs:
132
147
  pip_pkg_reqs: list[str] = []
133
- warnings.warn(
134
- (
135
- "Dependencies specified from pip requirements."
136
- " This may prevent model deploying to Snowflake Warehouse."
137
- ),
138
- category=UserWarning,
139
- stacklevel=2,
140
- )
148
+ if self.targets_warehouse:
149
+ self._warn_once(
150
+ (
151
+ "Dependencies specified from pip requirements."
152
+ " This may prevent model deploying to Snowflake Warehouse."
153
+ ),
154
+ stacklevel=2,
155
+ )
141
156
  for conda_req_str, pip_name in pkgs:
142
157
  _, conda_req = env_utils._validate_conda_dependency_string(conda_req_str)
143
158
  pip_req = requirements.Requirement(f"{pip_name}{conda_req.specifier}")
@@ -162,16 +177,15 @@ class ModelEnv:
162
177
  req_to_add.name = conda_req.name
163
178
  else:
164
179
  req_to_add = conda_req
165
- show_warning_message = conda_req_channel == env_utils.DEFAULT_CHANNEL_NAME
180
+ show_warning_message = conda_req_channel == env_utils.DEFAULT_CHANNEL_NAME and self.targets_warehouse
166
181
 
167
182
  if any(added_pip_req.name == pip_name for added_pip_req in self._pip_requirements):
168
183
  if show_warning_message:
169
- warnings.warn(
184
+ self._warn_once(
170
185
  (
171
186
  f"Basic dependency {req_to_add.name} specified from pip requirements."
172
187
  " This may prevent model deploying to Snowflake Warehouse."
173
188
  ),
174
- category=UserWarning,
175
189
  stacklevel=2,
176
190
  )
177
191
  continue
@@ -182,12 +196,11 @@ class ModelEnv:
182
196
  pass
183
197
  except env_utils.DuplicateDependencyInMultipleChannelsError:
184
198
  if show_warning_message:
185
- warnings.warn(
199
+ self._warn_once(
186
200
  (
187
201
  f"Basic dependency {req_to_add.name} specified from non-Snowflake channel."
188
202
  + " This may prevent model deploying to Snowflake Warehouse."
189
203
  ),
190
- category=UserWarning,
191
204
  stacklevel=2,
192
205
  )
193
206
 
@@ -272,22 +285,20 @@ class ModelEnv:
272
285
  )
273
286
 
274
287
  for channel, channel_dependencies in conda_dependencies_dict.items():
275
- if channel != env_utils.DEFAULT_CHANNEL_NAME:
276
- warnings.warn(
288
+ if channel != env_utils.DEFAULT_CHANNEL_NAME and self.targets_warehouse:
289
+ self._warn_once(
277
290
  (
278
291
  "Found dependencies specified in the conda file from non-Snowflake channel."
279
292
  " This may prevent model deploying to Snowflake Warehouse."
280
293
  ),
281
- category=UserWarning,
282
294
  stacklevel=2,
283
295
  )
284
- if len(channel_dependencies) == 0 and channel not in self._conda_dependencies:
285
- warnings.warn(
296
+ if len(channel_dependencies) == 0 and channel not in self._conda_dependencies and self.targets_warehouse:
297
+ self._warn_once(
286
298
  (
287
299
  f"Found additional conda channel {channel} specified in the conda file."
288
300
  " This may prevent model deploying to Snowflake Warehouse."
289
301
  ),
290
- category=UserWarning,
291
302
  stacklevel=2,
292
303
  )
293
304
  self._conda_dependencies[channel] = []
@@ -298,22 +309,20 @@ class ModelEnv:
298
309
  except env_utils.DuplicateDependencyError:
299
310
  pass
300
311
  except env_utils.DuplicateDependencyInMultipleChannelsError:
301
- warnings.warn(
312
+ self._warn_once(
302
313
  (
303
314
  f"Dependency {channel_dependency.name} appeared in multiple channels as conda dependency."
304
315
  " This may be unintentional."
305
316
  ),
306
- category=UserWarning,
307
317
  stacklevel=2,
308
318
  )
309
319
 
310
- if pip_requirements_list:
311
- warnings.warn(
320
+ if pip_requirements_list and self.targets_warehouse:
321
+ self._warn_once(
312
322
  (
313
323
  "Found dependencies specified as pip requirements."
314
324
  " This may prevent model deploying to Snowflake Warehouse."
315
325
  ),
316
- category=UserWarning,
317
326
  stacklevel=2,
318
327
  )
319
328
  for pip_dependency in pip_requirements_list:
@@ -333,13 +342,12 @@ class ModelEnv:
333
342
  def load_from_pip_file(self, pip_requirements_path: pathlib.Path) -> None:
334
343
  pip_requirements_list = env_utils.load_requirements_file(pip_requirements_path)
335
344
 
336
- if pip_requirements_list:
337
- warnings.warn(
345
+ if pip_requirements_list and self.targets_warehouse:
346
+ self._warn_once(
338
347
  (
339
348
  "Found dependencies specified as pip requirements."
340
349
  " This may prevent model deploying to Snowflake Warehouse."
341
350
  ),
342
- category=UserWarning,
343
351
  stacklevel=2,
344
352
  )
345
353
  for pip_dependency in pip_requirements_list:
@@ -167,7 +167,11 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
167
167
  model_blob_metadata = model_blobs_metadata[name]
168
168
  model_blob_filename = model_blob_metadata.path
169
169
  with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
170
- m = torch.load(f, map_location="cuda" if kwargs.get("use_gpu", False) else "cpu")
170
+ m = torch.load(
171
+ f,
172
+ map_location="cuda" if kwargs.get("use_gpu", False) else "cpu",
173
+ weights_only=False,
174
+ )
171
175
  assert isinstance(m, torch.nn.Module)
172
176
 
173
177
  return m
@@ -110,6 +110,7 @@ def create_model_metadata(
110
110
  python_version=python_version,
111
111
  embed_local_ml_library=embed_local_ml_library,
112
112
  prefer_pip=prefer_pip,
113
+ target_platforms=target_platforms,
113
114
  )
114
115
 
115
116
  if embed_local_ml_library:
@@ -162,8 +163,9 @@ def _create_env_for_model_metadata(
162
163
  python_version: Optional[str] = None,
163
164
  embed_local_ml_library: bool = False,
164
165
  prefer_pip: bool = False,
166
+ target_platforms: Optional[list[model_types.TargetPlatform]] = None,
165
167
  ) -> model_env.ModelEnv:
166
- env = model_env.ModelEnv(prefer_pip=prefer_pip)
168
+ env = model_env.ModelEnv(prefer_pip=prefer_pip, target_platforms=target_platforms)
167
169
 
168
170
  # Mypy doesn't like getter and setter have different types. See python/mypy #3004
169
171
  env.conda_dependencies = conda_dependencies # type: ignore[assignment]
@@ -60,12 +60,19 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
60
60
  data: snowflake.snowpark.DataFrame,
61
61
  ensure_serializable: bool = True,
62
62
  features: Optional[Sequence[core.BaseFeatureSpec]] = None,
63
+ statement_params: Optional[dict[str, Any]] = None,
63
64
  ) -> pd.DataFrame:
64
65
  # This method do things on top of to_pandas, to make sure the local dataframe got is in correct shape.
65
66
  dtype_map = {}
67
+
66
68
  if features:
69
+ quoted_identifiers_ignore_case = SnowparkDataFrameHandler._is_quoted_identifiers_ignore_case_enabled(
70
+ data.session, statement_params
71
+ )
67
72
  for feature in features:
68
- dtype_map[feature.name] = feature.as_dtype()
73
+ feature_name = feature.name.upper() if quoted_identifiers_ignore_case else feature.name
74
+ dtype_map[feature_name] = feature.as_dtype()
75
+
69
76
  df_local = data.to_pandas()
70
77
 
71
78
  # This is because Array will become string (Even though the correct schema is set)
@@ -93,6 +100,7 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
93
100
  df: pd.DataFrame,
94
101
  keep_order: bool = False,
95
102
  features: Optional[Sequence[core.BaseFeatureSpec]] = None,
103
+ statement_params: Optional[dict[str, Any]] = None,
96
104
  ) -> snowflake.snowpark.DataFrame:
97
105
  # This method is necessary to create the Snowpark Dataframe in correct schema.
98
106
  # However, in this case, the order could not be preserved. Thus, a _ID column has to be added,
@@ -100,6 +108,12 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
100
108
  # Although in this case, the column with array type can get correct ARRAY type, however, the element
101
109
  # type is not preserved, and will become string type. This affect the implementation of convert_from_df.
102
110
  df = pandas_handler.PandasDataFrameHandler.convert_to_df(df)
111
+ quoted_identifiers_ignore_case = SnowparkDataFrameHandler._is_quoted_identifiers_ignore_case_enabled(
112
+ session, statement_params
113
+ )
114
+ if quoted_identifiers_ignore_case:
115
+ df.columns = [str(col).upper() for col in df.columns]
116
+
103
117
  df_cols = df.columns
104
118
  if df_cols.dtype != np.object_:
105
119
  raise snowml_exceptions.SnowflakeMLException(
@@ -116,9 +130,47 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
116
130
  column_names = []
117
131
  columns = []
118
132
  for feature in features:
119
- column_names.append(identifier.get_inferred_name(feature.name))
120
- columns.append(F.col(identifier.get_inferred_name(feature.name)).cast(feature.as_snowpark_type()))
133
+ feature_name = identifier.get_inferred_name(feature.name)
134
+ if quoted_identifiers_ignore_case:
135
+ feature_name = feature_name.upper()
136
+ column_names.append(feature_name)
137
+ columns.append(F.col(feature_name).cast(feature.as_snowpark_type()))
121
138
 
122
139
  sp_df = sp_df.with_columns(column_names, columns)
123
140
 
124
141
  return sp_df
142
+
143
+ @staticmethod
144
+ def _is_quoted_identifiers_ignore_case_enabled(
145
+ session: snowflake.snowpark.Session, statement_params: Optional[dict[str, Any]] = None
146
+ ) -> bool:
147
+ """
148
+ Check if QUOTED_IDENTIFIERS_IGNORE_CASE parameter is enabled.
149
+
150
+ Args:
151
+ session: Snowpark session to check parameter for
152
+ statement_params: Optional statement parameters to check first
153
+
154
+ Returns:
155
+ bool: True if QUOTED_IDENTIFIERS_IGNORE_CASE is enabled, False otherwise
156
+ Returns False if the parameter cannot be retrieved (e.g., in stored procedures)
157
+ """
158
+ if statement_params is not None:
159
+ for key, value in statement_params.items():
160
+ if key.upper() == "QUOTED_IDENTIFIERS_IGNORE_CASE":
161
+ parameter_value = str(value)
162
+ return parameter_value.lower() == "true"
163
+
164
+ try:
165
+ result = session.sql(
166
+ "SHOW PARAMETERS LIKE 'QUOTED_IDENTIFIERS_IGNORE_CASE' IN SESSION",
167
+ _emit_ast=False,
168
+ ).collect(_emit_ast=False)
169
+
170
+ parameter_value = str(result[0].value)
171
+ return parameter_value.lower() == "true"
172
+
173
+ except Exception:
174
+ # Parameter query can fail in certain environments (e.g., in stored procedures)
175
+ # In that case, assume default behavior (case-sensitive)
176
+ return False
@@ -0,0 +1,11 @@
1
+ from enum import Enum
2
+
3
+
4
+ class TargetPlatform(Enum):
5
+ WAREHOUSE = "WAREHOUSE"
6
+ SNOWPARK_CONTAINER_SERVICES = "SNOWPARK_CONTAINER_SERVICES"
7
+
8
+
9
+ WAREHOUSE_ONLY = [TargetPlatform.WAREHOUSE]
10
+ SNOWPARK_CONTAINER_SERVICES_ONLY = [TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
11
+ BOTH_WAREHOUSE_AND_SNOWPARK_CONTAINER_SERVICES = [TargetPlatform.WAREHOUSE, TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
@@ -0,0 +1,9 @@
1
+ from enum import Enum
2
+
3
+
4
+ class Task(Enum):
5
+ UNKNOWN = "UNKNOWN"
6
+ TABULAR_BINARY_CLASSIFICATION = "TABULAR_BINARY_CLASSIFICATION"
7
+ TABULAR_MULTI_CLASSIFICATION = "TABULAR_MULTI_CLASSIFICATION"
8
+ TABULAR_REGRESSION = "TABULAR_REGRESSION"
9
+ TABULAR_RANKING = "TABULAR_RANKING"
@@ -1,10 +1,12 @@
1
1
  # mypy: disable-error-code="import"
2
- from enum import Enum
3
2
  from typing import TYPE_CHECKING, Literal, Sequence, TypedDict, TypeVar, Union
4
3
 
5
4
  import numpy.typing as npt
6
5
  from typing_extensions import NotRequired
7
6
 
7
+ from snowflake.ml.model.target_platform import TargetPlatform
8
+ from snowflake.ml.model.task import Task
9
+
8
10
  if TYPE_CHECKING:
9
11
  import catboost
10
12
  import keras
@@ -321,17 +323,7 @@ ModelLoadOption = Union[
321
323
  ]
322
324
 
323
325
 
324
- class Task(Enum):
325
- UNKNOWN = "UNKNOWN"
326
- TABULAR_BINARY_CLASSIFICATION = "TABULAR_BINARY_CLASSIFICATION"
327
- TABULAR_MULTI_CLASSIFICATION = "TABULAR_MULTI_CLASSIFICATION"
328
- TABULAR_REGRESSION = "TABULAR_REGRESSION"
329
- TABULAR_RANKING = "TABULAR_RANKING"
330
-
331
-
332
- class TargetPlatform(Enum):
333
- WAREHOUSE = "WAREHOUSE"
334
- SNOWPARK_CONTAINER_SERVICES = "SNOWPARK_CONTAINER_SERVICES"
326
+ SupportedTargetPlatformType = Union[TargetPlatform, str]
335
327
 
336
328
 
337
- SupportedTargetPlatformType = Union[TargetPlatform, str]
329
+ __all__ = ["TargetPlatform", "Task"]
@@ -60,6 +60,7 @@ def register_accumulator_udtf(*, session: Session, statement_params: dict[str, A
60
60
  ),
61
61
  input_types=[T.BinaryType()],
62
62
  packages=[f"numpy=={np.__version__}", f"cloudpickle=={cloudpickle.__version__}"],
63
+ imports=[], # Prevents unnecessary import resolution.
63
64
  name=accumulator,
64
65
  is_permanent=False,
65
66
  replace=True,
@@ -175,6 +176,7 @@ def register_sharded_dot_sum_computer(*, session: Session, statement_params: dic
175
176
  ),
176
177
  input_types=[T.ArrayType(), T.IntegerType(), T.IntegerType()],
177
178
  packages=[f"numpy=={np.__version__}", f"cloudpickle=={cloudpickle.__version__}"],
179
+ imports=[], # Prevents unnecessary import resolution.
178
180
  name=sharded_dot_and_sum_computer,
179
181
  is_permanent=False,
180
182
  replace=True,
@@ -1,5 +1,5 @@
1
1
  from types import ModuleType
2
- from typing import Any, Optional, Union
2
+ from typing import Any, Optional, Protocol, Union
3
3
 
4
4
  import pandas as pd
5
5
  from absl.logging import logging
@@ -8,7 +8,7 @@ from snowflake.ml._internal import env, platform_capabilities, telemetry
8
8
  from snowflake.ml._internal.exceptions import error_codes, exceptions
9
9
  from snowflake.ml._internal.human_readable_id import hrid_generator
10
10
  from snowflake.ml._internal.utils import sql_identifier
11
- from snowflake.ml.model import model_signature, type_hints as model_types
11
+ from snowflake.ml.model import model_signature, target_platform, task, type_hints
12
12
  from snowflake.ml.model._client.model import model_impl, model_version_impl
13
13
  from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
14
14
  from snowflake.ml.model._model_composer import model_composer
@@ -20,6 +20,14 @@ from snowflake.snowpark._internal import utils as snowpark_utils
20
20
  logger = logging.getLogger(__name__)
21
21
 
22
22
 
23
+ class EventHandler(Protocol):
24
+ """Protocol defining the interface for event handlers used during model operations."""
25
+
26
+ def update(self, message: str) -> None:
27
+ """Update with a progress message."""
28
+ ...
29
+
30
+
23
31
  class ModelManager:
24
32
  def __init__(
25
33
  self,
@@ -41,7 +49,7 @@ class ModelManager:
41
49
  def log_model(
42
50
  self,
43
51
  *,
44
- model: Union[model_types.SupportedModelType, model_version_impl.ModelVersion],
52
+ model: Union[type_hints.SupportedModelType, model_version_impl.ModelVersion],
45
53
  model_name: str,
46
54
  version_name: Optional[str] = None,
47
55
  comment: Optional[str] = None,
@@ -50,16 +58,17 @@ class ModelManager:
50
58
  pip_requirements: Optional[list[str]] = None,
51
59
  artifact_repository_map: Optional[dict[str, str]] = None,
52
60
  resource_constraint: Optional[dict[str, str]] = None,
53
- target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
61
+ target_platforms: Optional[list[type_hints.SupportedTargetPlatformType]] = None,
54
62
  python_version: Optional[str] = None,
55
63
  signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
56
- sample_input_data: Optional[model_types.SupportedDataType] = None,
64
+ sample_input_data: Optional[type_hints.SupportedDataType] = None,
57
65
  user_files: Optional[dict[str, list[str]]] = None,
58
66
  code_paths: Optional[list[str]] = None,
59
67
  ext_modules: Optional[list[ModuleType]] = None,
60
- task: model_types.Task = model_types.Task.UNKNOWN,
61
- options: Optional[model_types.ModelSaveOption] = None,
68
+ task: type_hints.Task = task.Task.UNKNOWN,
69
+ options: Optional[type_hints.ModelSaveOption] = None,
62
70
  statement_params: Optional[dict[str, Any]] = None,
71
+ event_handler: EventHandler,
63
72
  ) -> model_version_impl.ModelVersion:
64
73
 
65
74
  database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
@@ -143,11 +152,12 @@ class ModelManager:
143
152
  task=task,
144
153
  options=options,
145
154
  statement_params=statement_params,
155
+ event_handler=event_handler,
146
156
  )
147
157
 
148
158
  def _log_model(
149
159
  self,
150
- model: model_types.SupportedModelType,
160
+ model: type_hints.SupportedModelType,
151
161
  *,
152
162
  model_name: str,
153
163
  version_name: str,
@@ -157,16 +167,17 @@ class ModelManager:
157
167
  pip_requirements: Optional[list[str]] = None,
158
168
  artifact_repository_map: Optional[dict[str, str]] = None,
159
169
  resource_constraint: Optional[dict[str, str]] = None,
160
- target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
170
+ target_platforms: Optional[list[type_hints.SupportedTargetPlatformType]] = None,
161
171
  python_version: Optional[str] = None,
162
172
  signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
163
- sample_input_data: Optional[model_types.SupportedDataType] = None,
173
+ sample_input_data: Optional[type_hints.SupportedDataType] = None,
164
174
  user_files: Optional[dict[str, list[str]]] = None,
165
175
  code_paths: Optional[list[str]] = None,
166
176
  ext_modules: Optional[list[ModuleType]] = None,
167
- task: model_types.Task = model_types.Task.UNKNOWN,
168
- options: Optional[model_types.ModelSaveOption] = None,
177
+ task: type_hints.Task = task.Task.UNKNOWN,
178
+ options: Optional[type_hints.ModelSaveOption] = None,
169
179
  statement_params: Optional[dict[str, Any]] = None,
180
+ event_handler: EventHandler,
170
181
  ) -> model_version_impl.ModelVersion:
171
182
  database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
172
183
  version_name_id = sql_identifier.SqlIdentifier(version_name)
@@ -215,7 +226,7 @@ class ModelManager:
215
226
  # User specified target platforms are defaulted to None and will not show up in the generated manifest.
216
227
  if target_platforms:
217
228
  # Convert any string target platforms to TargetPlatform objects
218
- platforms = [model_types.TargetPlatform(platform) for platform in target_platforms]
229
+ platforms = [type_hints.TargetPlatform(platform) for platform in target_platforms]
219
230
  else:
220
231
  # Default the target platform to warehouse if not specified and any table function exists
221
232
  if options and (
@@ -231,7 +242,7 @@ class ModelManager:
231
242
  "Logging a partitioned model with a table function without specifying `target_platforms`. "
232
243
  'Default to `target_platforms=["WAREHOUSE"]`.'
233
244
  )
234
- platforms = [model_types.TargetPlatform.WAREHOUSE]
245
+ platforms = [target_platform.TargetPlatform.WAREHOUSE]
235
246
 
236
247
  # Default the target platform to SPCS if not specified when running in ML runtime
237
248
  if not platforms and env.IN_ML_RUNTIME:
@@ -239,7 +250,7 @@ class ModelManager:
239
250
  "Logging the model on Container Runtime for ML without specifying `target_platforms`. "
240
251
  'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.'
241
252
  )
242
- platforms = [model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
253
+ platforms = [target_platform.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
243
254
 
244
255
  if artifact_repository_map:
245
256
  for channel, artifact_repository_name in artifact_repository_map.items():
@@ -254,6 +265,7 @@ class ModelManager:
254
265
  )
255
266
 
256
267
  logger.info("Start packaging and uploading your model. It might take some time based on the size of the model.")
268
+ event_handler.update("📦 Packaging model...")
257
269
 
258
270
  # Extract save_location from options if present
259
271
  save_location = None
@@ -292,6 +304,7 @@ class ModelManager:
292
304
  )
293
305
 
294
306
  logger.info("Start creating MODEL object for you in the Snowflake.")
307
+ event_handler.update("🏗️ Creating model object in Snowflake...")
295
308
 
296
309
  self._model_ops.create_from_stage(
297
310
  composed_model=mc,
@@ -331,6 +344,8 @@ class ModelManager:
331
344
  statement_params=statement_params,
332
345
  )
333
346
 
347
+ event_handler.update("✅ Model logged successfully!")
348
+
334
349
  return mv
335
350
 
336
351
  def get_model(