snowflake-ml-python 1.8.5__py3-none-any.whl → 1.9.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only.
Files changed (49)
  1. snowflake/ml/_internal/telemetry.py +6 -9
  2. snowflake/ml/_internal/utils/connection_params.py +196 -0
  3. snowflake/ml/_internal/utils/identifier.py +1 -1
  4. snowflake/ml/_internal/utils/mixins.py +61 -0
  5. snowflake/ml/jobs/__init__.py +2 -0
  6. snowflake/ml/jobs/_utils/constants.py +3 -2
  7. snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
  8. snowflake/ml/jobs/_utils/interop_utils.py +63 -4
  9. snowflake/ml/jobs/_utils/payload_utils.py +89 -40
  10. snowflake/ml/jobs/_utils/query_helper.py +9 -0
  11. snowflake/ml/jobs/_utils/scripts/constants.py +19 -3
  12. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -26
  13. snowflake/ml/jobs/_utils/spec_utils.py +29 -5
  14. snowflake/ml/jobs/_utils/stage_utils.py +119 -0
  15. snowflake/ml/jobs/_utils/types.py +5 -1
  16. snowflake/ml/jobs/decorators.py +20 -28
  17. snowflake/ml/jobs/job.py +197 -61
  18. snowflake/ml/jobs/manager.py +253 -121
  19. snowflake/ml/model/_client/model/model_impl.py +58 -0
  20. snowflake/ml/model/_client/model/model_version_impl.py +90 -0
  21. snowflake/ml/model/_client/ops/model_ops.py +18 -6
  22. snowflake/ml/model/_client/ops/service_ops.py +23 -6
  23. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -0
  24. snowflake/ml/model/_client/sql/service.py +68 -20
  25. snowflake/ml/model/_client/sql/stage.py +5 -2
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -10
  27. snowflake/ml/model/_packager/model_env/model_env.py +35 -27
  28. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
  29. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
  30. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
  31. snowflake/ml/model/_signatures/core.py +24 -0
  32. snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
  33. snowflake/ml/model/target_platform.py +11 -0
  34. snowflake/ml/model/task.py +9 -0
  35. snowflake/ml/model/type_hints.py +5 -13
  36. snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
  37. snowflake/ml/monitoring/explain_visualize.py +2 -2
  38. snowflake/ml/monitoring/model_monitor.py +0 -4
  39. snowflake/ml/registry/_manager/model_manager.py +30 -15
  40. snowflake/ml/registry/registry.py +144 -47
  41. snowflake/ml/utils/connection_params.py +1 -1
  42. snowflake/ml/utils/html_utils.py +263 -0
  43. snowflake/ml/version.py +1 -1
  44. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/METADATA +64 -19
  45. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/RECORD +48 -41
  46. snowflake/ml/monitoring/model_monitor_version.py +0 -1
  47. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/WHEEL +0 -0
  48. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/licenses/LICENSE.txt +0 -0
  49. {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/top_level.txt +0 -0
--- a/snowflake/ml/model/_packager/model_handlers/snowmlmodel.py
+++ b/snowflake/ml/model/_packager/model_handlers/snowmlmodel.py
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, cast, final
 import cloudpickle
 import numpy as np
 import pandas as pd
+import shap
 from typing_extensions import TypeGuard, Unpack
 
 from snowflake.ml._internal import type_utils
@@ -25,6 +26,19 @@ if TYPE_CHECKING:
     from snowflake.ml.modeling.framework.base import BaseEstimator
 
 
+def _apply_transforms_up_to_last_step(
+    model: "BaseEstimator",
+    data: model_types.SupportedDataType,
+) -> pd.DataFrame:
+    """Apply all transformations in the snowml pipeline model up to the last step."""
+    if type_utils.LazyType("snowflake.ml.modeling.pipeline.Pipeline").isinstance(model):
+        for step_name, step in model.steps[:-1]:  # type: ignore[attr-defined]
+            if not hasattr(step, "transform"):
+                raise ValueError(f"Step '{step_name}' does not have a 'transform' method.")
+            data = pd.DataFrame(step.transform(data))
+    return data
+
+
 @final
 class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
     """Handler for SnowML based model.
@@ -39,7 +53,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
     _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
 
     DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"]
-    EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"]
+    EXPLAIN_TARGET_METHODS = ["predict_proba", "predict", "predict_log_proba"]
 
     IS_AUTO_SIGNATURE = True
 
@@ -97,11 +111,6 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
                 return result
             except exceptions.SnowflakeMLException:
                 pass  # Do nothing and continue to the next method
-
-        if enable_explainability:
-            raise ValueError(
-                "Explain only supported for xgboost, lightgbm and sklearn (not pipeline) Snowpark ML models."
-            )
         return None
 
     @classmethod
@@ -189,23 +198,46 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         else:
             enable_explainability = True
         if enable_explainability:
-            model_task_and_output_type = model_task_utils.resolve_model_task_and_output_type(
-                python_base_obj, model_meta.task
-            )
-            model_meta.task = model_task_and_output_type.task
-            model_meta = handlers_utils.add_explain_method_signature(
-                model_meta=model_meta,
-                explain_method="explain",
-                target_method=explain_target_method,
-                output_return_type=model_task_and_output_type.output_type,
-            )
-            background_data = handlers_utils.get_explainability_supported_background(
-                sample_input_data, model_meta, explain_target_method
-            )
-            if background_data is not None:
-                handlers_utils.save_background_data(
-                    model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
+            try:
+                model_task_and_output_type = model_task_utils.resolve_model_task_and_output_type(
+                    python_base_obj, model_meta.task
+                )
+                model_meta.task = model_task_and_output_type.task
+                background_data = handlers_utils.get_explainability_supported_background(
+                    sample_input_data, model_meta, explain_target_method
                 )
+                if type_utils.LazyType("snowflake.ml.modeling.pipeline.Pipeline").isinstance(model):
+                    transformed_df = _apply_transforms_up_to_last_step(model, sample_input_data)
+                    explain_fn = cls._build_explain_fn(model, background_data)
+                    model_meta = handlers_utils.add_inferred_explain_method_signature(
+                        model_meta=model_meta,
+                        explain_method="explain",
+                        target_method=explain_target_method,  # type: ignore[arg-type]
+                        background_data=background_data,
+                        explain_fn=explain_fn,
+                        output_feature_names=transformed_df.columns,
+                    )
+                else:
+                    model_meta = handlers_utils.add_explain_method_signature(
+                        model_meta=model_meta,
+                        explain_method="explain",
+                        target_method=explain_target_method,
+                        output_return_type=model_task_and_output_type.output_type,
+                    )
+                if background_data is not None:
+                    handlers_utils.save_background_data(
+                        model_blobs_dir_path,
+                        cls.EXPLAIN_ARTIFACTS_DIR,
+                        cls.BG_DATA_FILE_SUFFIX,
+                        name,
+                        background_data,
+                    )
+            except Exception:
+                if kwargs.get("enable_explainability", None):
+                    # user explicitly enabled explainability, so we should raise the error
+                    raise ValueError(
+                        "Explainability for this model is not supported. Please set `enable_explainability=False`"
+                    )
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
@@ -251,6 +283,53 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         assert isinstance(m, BaseEstimator)
         return m
 
+    @classmethod
+    def _build_explain_fn(
+        cls, model: "BaseEstimator", background_data: model_types.SupportedDataType
+    ) -> Callable[[model_types.SupportedDataType], pd.DataFrame]:
+
+        predictor = model
+        is_pipeline = type_utils.LazyType("snowflake.ml.modeling.pipeline.Pipeline").isinstance(model)
+        if is_pipeline:
+            background_data = _apply_transforms_up_to_last_step(model, background_data)
+            predictor = model.steps[-1][1]  # type: ignore[attr-defined]
+
+        def explain_fn(data: model_types.SupportedDataType) -> pd.DataFrame:
+            data = _apply_transforms_up_to_last_step(model, data)
+            tree_methods = ["to_xgboost", "to_lightgbm"]
+            non_tree_methods = ["to_sklearn", None]  # None just uses the predictor directly
+            for method_name in tree_methods:
+                try:
+                    base_model = getattr(predictor, method_name)()
+                    explainer = shap.TreeExplainer(base_model)
+                    return handlers_utils.convert_explanations_to_2D_df(model, explainer.shap_values(data))
+                except exceptions.SnowflakeMLException:
+                    pass  # Do nothing and continue to the next method
+            for method_name in non_tree_methods:  # type: ignore[assignment]
+                try:
+                    base_model = getattr(predictor, method_name)() if method_name is not None else predictor
+                    try:
+                        explainer = shap.Explainer(base_model, masker=background_data)
+                        return handlers_utils.convert_explanations_to_2D_df(base_model, explainer(data).values)
+                    except TypeError:
+                        for explain_target_method in cls.EXPLAIN_TARGET_METHODS:
+                            if not hasattr(base_model, explain_target_method):
+                                continue
+                            explain_target_method_fn = getattr(base_model, explain_target_method)
+                            if isinstance(data, np.ndarray):
+                                explainer = shap.Explainer(
+                                    explain_target_method_fn,
+                                    background_data.values,  # type: ignore[union-attr]
+                                )
+                            else:
+                                explainer = shap.Explainer(explain_target_method_fn, background_data)
+                            return handlers_utils.convert_explanations_to_2D_df(base_model, explainer(data).values)
+                except Exception:
+                    pass  # Do nothing and continue to the next method
+            raise ValueError("Explainability for this model is not supported.")
+
+        return explain_fn
+
     @classmethod
     def convert_as_custom_model(
         cls,
@@ -286,57 +365,8 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
 
         @custom_model.inference_api
         def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
-            import shap
-
-            tree_methods = ["to_xgboost", "to_lightgbm"]
-            non_tree_methods = ["to_sklearn"]
-            for method_name in tree_methods:
-                try:
-                    base_model = getattr(raw_model, method_name)()
-                    explainer = shap.TreeExplainer(base_model)
-                    df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer.shap_values(X))
-                    return model_signature_utils.rename_pandas_df(df, signature.outputs)
-                except exceptions.SnowflakeMLException:
-                    pass  # Do nothing and continue to the next method
-            for method_name in non_tree_methods:
-                try:
-                    base_model = getattr(raw_model, method_name)()
-                    try:
-                        explainer = shap.Explainer(base_model, masker=background_data)
-                        df = handlers_utils.convert_explanations_to_2D_df(base_model, explainer(X).values)
-                    except TypeError:
-                        try:
-                            dtype_map = {
-                                spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs
-                            }
-
-                            if isinstance(X, pd.DataFrame):
-                                X = X.astype(dtype_map, copy=False)
-                            if hasattr(base_model, "predict_proba"):
-                                if isinstance(X, np.ndarray):
-                                    explainer = shap.Explainer(
-                                        base_model.predict_proba,
-                                        background_data.values,  # type: ignore[union-attr]
-                                    )
-                                else:
-                                    explainer = shap.Explainer(base_model.predict_proba, background_data)
-                            elif hasattr(base_model, "predict"):
-                                if isinstance(X, np.ndarray):
-                                    explainer = shap.Explainer(
-                                        base_model.predict, background_data.values  # type: ignore[union-attr]
-                                    )
-                                else:
-                                    explainer = shap.Explainer(base_model.predict, background_data)
-                            else:
-                                raise ValueError("Missing any supported target method to explain.")
-                            df = handlers_utils.convert_explanations_to_2D_df(base_model, explainer(X).values)
-                        except TypeError as e:
-                            raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
-                    return model_signature_utils.rename_pandas_df(df, signature.outputs)
-
-                except exceptions.SnowflakeMLException:
-                    pass  # Do nothing and continue to the next method
-            raise ValueError("The model must be an xgboost, lightgbm or sklearn (not pipeline) estimator.")
+            fn = cls._build_explain_fn(raw_model, background_data)
+            return model_signature_utils.rename_pandas_df(fn(X), signature.outputs)
 
         if target_method == "explain":
             return explain_fn
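
Editor's note: the refactor above centralizes SHAP explainer selection in `_build_explain_fn`, which tries tree-specific explainers first (via `to_xgboost`/`to_lightgbm`) and then falls back to the generic `shap.Explainer` over the background data, finally iterating `EXPLAIN_TARGET_METHODS` (now ordered with `predict_proba` first). A minimal standalone sketch of that fallback pattern, assuming `shap` and `scikit-learn` are installed; the model and data here are illustrative, not from the package:

```python
import numpy as np
import pandas as pd
import shap
from sklearn.linear_model import LogisticRegression

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(100, 3)), columns=["a", "b", "c"])
y = (X["a"] + X["b"] > 0).astype(int)
model = LogisticRegression().fit(X, y)

try:
    # Fast path: only valid for tree ensembles (xgboost, lightgbm, ...)
    explainer = shap.TreeExplainer(model)
    values = explainer.shap_values(X)
except Exception:
    # Generic fallback: explain a prediction function against background data
    explainer = shap.Explainer(model.predict_proba, X)
    values = explainer(X).values

print(np.asarray(values).shape)  # SHAP values per sample x feature (x class)
```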
--- a/snowflake/ml/model/_packager/model_meta/model_meta.py
+++ b/snowflake/ml/model/_packager/model_meta/model_meta.py
@@ -110,6 +110,7 @@ def create_model_metadata(
         python_version=python_version,
         embed_local_ml_library=embed_local_ml_library,
         prefer_pip=prefer_pip,
+        target_platforms=target_platforms,
     )
 
     if embed_local_ml_library:
@@ -162,8 +163,9 @@ def _create_env_for_model_metadata(
     python_version: Optional[str] = None,
     embed_local_ml_library: bool = False,
     prefer_pip: bool = False,
+    target_platforms: Optional[list[model_types.TargetPlatform]] = None,
 ) -> model_env.ModelEnv:
-    env = model_env.ModelEnv(prefer_pip=prefer_pip)
+    env = model_env.ModelEnv(prefer_pip=prefer_pip, target_platforms=target_platforms)
 
     # Mypy doesn't like getter and setter have different types. See python/mypy #3004
     env.conda_dependencies = conda_dependencies  # type: ignore[assignment]
--- a/snowflake/ml/model/_signatures/core.py
+++ b/snowflake/ml/model/_signatures/core.py
@@ -559,6 +559,30 @@ class ModelSignature:
                )"""
         )
 
+    def _repr_html_(self) -> str:
+        """Generate an HTML representation of the model signature.
+
+        Returns:
+            str: HTML string containing formatted signature details.
+        """
+        from snowflake.ml.utils import html_utils
+
+        # Create collapsible sections for inputs and outputs
+        inputs_content = html_utils.create_features_html(self.inputs, "Input")
+        outputs_content = html_utils.create_features_html(self.outputs, "Output")
+
+        inputs_section = html_utils.create_collapsible_section("Inputs", inputs_content, open_by_default=True)
+        outputs_section = html_utils.create_collapsible_section("Outputs", outputs_content, open_by_default=True)
+
+        content = f"""
+        <div style="margin-top: 10px;">
+            {inputs_section}
+            {outputs_section}
+        </div>
+        """
+
+        return html_utils.create_base_container("Model Signature", content)
+
     @classmethod
     def from_mlflow_sig(cls, mlflow_sig: "mlflow.models.ModelSignature") -> "ModelSignature":
         return ModelSignature(
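
Editor's note: `_repr_html_` is the hook notebooks use for rich display, so a signature rendered in Jupyter shows the collapsible HTML view automatically. A minimal sketch, assuming the public `model_signature` API surface; the feature names are illustrative:

```python
from snowflake.ml.model.model_signature import DataType, FeatureSpec, ModelSignature

sig = ModelSignature(
    inputs=[FeatureSpec(name="feature_0", dtype=DataType.DOUBLE)],
    outputs=[FeatureSpec(name="output", dtype=DataType.DOUBLE)],
)

# In a notebook, displaying `sig` invokes sig._repr_html_() implicitly;
# elsewhere the HTML string can be inspected directly:
html = sig._repr_html_()
print(html[:60])
```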
--- a/snowflake/ml/model/_signatures/snowpark_handler.py
+++ b/snowflake/ml/model/_signatures/snowpark_handler.py
@@ -60,12 +60,19 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
         data: snowflake.snowpark.DataFrame,
         ensure_serializable: bool = True,
         features: Optional[Sequence[core.BaseFeatureSpec]] = None,
+        statement_params: Optional[dict[str, Any]] = None,
     ) -> pd.DataFrame:
         # This method do things on top of to_pandas, to make sure the local dataframe got is in correct shape.
         dtype_map = {}
+
         if features:
+            quoted_identifiers_ignore_case = SnowparkDataFrameHandler._is_quoted_identifiers_ignore_case_enabled(
+                data.session, statement_params
+            )
             for feature in features:
-                dtype_map[feature.name] = feature.as_dtype()
+                feature_name = feature.name.upper() if quoted_identifiers_ignore_case else feature.name
+                dtype_map[feature_name] = feature.as_dtype()
+
         df_local = data.to_pandas()
 
         # This is because Array will become string (Even though the correct schema is set)
@@ -93,6 +100,7 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
         df: pd.DataFrame,
         keep_order: bool = False,
         features: Optional[Sequence[core.BaseFeatureSpec]] = None,
+        statement_params: Optional[dict[str, Any]] = None,
     ) -> snowflake.snowpark.DataFrame:
         # This method is necessary to create the Snowpark Dataframe in correct schema.
         # However, in this case, the order could not be preserved. Thus, a _ID column has to be added,
@@ -100,6 +108,12 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
         # Although in this case, the column with array type can get correct ARRAY type, however, the element
         # type is not preserved, and will become string type. This affect the implementation of convert_from_df.
         df = pandas_handler.PandasDataFrameHandler.convert_to_df(df)
+        quoted_identifiers_ignore_case = SnowparkDataFrameHandler._is_quoted_identifiers_ignore_case_enabled(
+            session, statement_params
+        )
+        if quoted_identifiers_ignore_case:
+            df.columns = [str(col).upper() for col in df.columns]
+
         df_cols = df.columns
         if df_cols.dtype != np.object_:
             raise snowml_exceptions.SnowflakeMLException(
@@ -116,9 +130,47 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
         column_names = []
         columns = []
         for feature in features:
-            column_names.append(identifier.get_inferred_name(feature.name))
-            columns.append(F.col(identifier.get_inferred_name(feature.name)).cast(feature.as_snowpark_type()))
+            feature_name = identifier.get_inferred_name(feature.name)
+            if quoted_identifiers_ignore_case:
+                feature_name = feature_name.upper()
+            column_names.append(feature_name)
+            columns.append(F.col(feature_name).cast(feature.as_snowpark_type()))
 
         sp_df = sp_df.with_columns(column_names, columns)
 
         return sp_df
+
+    @staticmethod
+    def _is_quoted_identifiers_ignore_case_enabled(
+        session: snowflake.snowpark.Session, statement_params: Optional[dict[str, Any]] = None
+    ) -> bool:
+        """
+        Check if QUOTED_IDENTIFIERS_IGNORE_CASE parameter is enabled.
+
+        Args:
+            session: Snowpark session to check parameter for
+            statement_params: Optional statement parameters to check first
+
+        Returns:
+            bool: True if QUOTED_IDENTIFIERS_IGNORE_CASE is enabled, False otherwise
+            Returns False if the parameter cannot be retrieved (e.g., in stored procedures)
+        """
+        if statement_params is not None:
+            for key, value in statement_params.items():
+                if key.upper() == "QUOTED_IDENTIFIERS_IGNORE_CASE":
+                    parameter_value = str(value)
+                    return parameter_value.lower() == "true"
+
+        try:
+            result = session.sql(
+                "SHOW PARAMETERS LIKE 'QUOTED_IDENTIFIERS_IGNORE_CASE' IN SESSION",
+                _emit_ast=False,
+            ).collect(_emit_ast=False)
+
+            parameter_value = str(result[0].value)
+            return parameter_value.lower() == "true"
+
+        except Exception:
+            # Parameter query can fail in certain environments (e.g., in stored procedures)
+            # In that case, assume default behavior (case-sensitive)
+            return False
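
Editor's note: the new helper gives `statement_params` precedence over the session-level `SHOW PARAMETERS` query, and fails closed (case-sensitive) when that query is unavailable. A minimal sketch of the precedence branch above; the standalone helper name is illustrative:

```python
from typing import Any, Optional

def ignore_case_from_statement_params(statement_params: Optional[dict[str, Any]]) -> Optional[bool]:
    """Return the QUOTED_IDENTIFIERS_IGNORE_CASE override, or None to defer to the session."""
    if statement_params is not None:
        for key, value in statement_params.items():
            if key.upper() == "QUOTED_IDENTIFIERS_IGNORE_CASE":
                return str(value).lower() == "true"
    return None  # fall through to SHOW PARAMETERS ... IN SESSION

assert ignore_case_from_statement_params({"quoted_identifiers_ignore_case": "TRUE"}) is True
assert ignore_case_from_statement_params({"OTHER_PARAM": "1"}) is None
```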
--- /dev/null
+++ b/snowflake/ml/model/target_platform.py
@@ -0,0 +1,11 @@
+from enum import Enum
+
+
+class TargetPlatform(Enum):
+    WAREHOUSE = "WAREHOUSE"
+    SNOWPARK_CONTAINER_SERVICES = "SNOWPARK_CONTAINER_SERVICES"
+
+
+WAREHOUSE_ONLY = [TargetPlatform.WAREHOUSE]
+SNOWPARK_CONTAINER_SERVICES_ONLY = [TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
+BOTH_WAREHOUSE_AND_SNOWPARK_CONTAINER_SERVICES = [TargetPlatform.WAREHOUSE, TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
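
Editor's note: `SupportedTargetPlatformType` stays `Union[TargetPlatform, str]` (see the type_hints.py hunk below), and `ModelManager._log_model` coerces strings via `TargetPlatform(platform)`, so both enum members and plain strings are accepted. A short usage sketch; the `registry.log_model` call is shown only for illustration:

```python
from snowflake.ml.model.target_platform import TargetPlatform, WAREHOUSE_ONLY

# Strings round-trip through the enum, which is how string inputs are coerced.
assert TargetPlatform("WAREHOUSE") is TargetPlatform.WAREHOUSE
assert WAREHOUSE_ONLY == [TargetPlatform.WAREHOUSE]

# e.g.:
# registry.log_model(model, model_name="my_model",
#                    target_platforms=["SNOWPARK_CONTAINER_SERVICES"])
```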
--- /dev/null
+++ b/snowflake/ml/model/task.py
@@ -0,0 +1,9 @@
+from enum import Enum
+
+
+class Task(Enum):
+    UNKNOWN = "UNKNOWN"
+    TABULAR_BINARY_CLASSIFICATION = "TABULAR_BINARY_CLASSIFICATION"
+    TABULAR_MULTI_CLASSIFICATION = "TABULAR_MULTI_CLASSIFICATION"
+    TABULAR_REGRESSION = "TABULAR_REGRESSION"
+    TABULAR_RANKING = "TABULAR_RANKING"
--- a/snowflake/ml/model/type_hints.py
+++ b/snowflake/ml/model/type_hints.py
@@ -1,10 +1,12 @@
 # mypy: disable-error-code="import"
-from enum import Enum
 from typing import TYPE_CHECKING, Literal, Sequence, TypedDict, TypeVar, Union
 
 import numpy.typing as npt
 from typing_extensions import NotRequired
 
+from snowflake.ml.model.target_platform import TargetPlatform
+from snowflake.ml.model.task import Task
+
 if TYPE_CHECKING:
     import catboost
     import keras
@@ -321,17 +323,7 @@ ModelLoadOption = Union[
 ]
 
 
-class Task(Enum):
-    UNKNOWN = "UNKNOWN"
-    TABULAR_BINARY_CLASSIFICATION = "TABULAR_BINARY_CLASSIFICATION"
-    TABULAR_MULTI_CLASSIFICATION = "TABULAR_MULTI_CLASSIFICATION"
-    TABULAR_REGRESSION = "TABULAR_REGRESSION"
-    TABULAR_RANKING = "TABULAR_RANKING"
-
-
-class TargetPlatform(Enum):
-    WAREHOUSE = "WAREHOUSE"
-    SNOWPARK_CONTAINER_SERVICES = "SNOWPARK_CONTAINER_SERVICES"
+SupportedTargetPlatformType = Union[TargetPlatform, str]
 
 
-SupportedTargetPlatformType = Union[TargetPlatform, str]
+__all__ = ["TargetPlatform", "Task"]
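
Editor's note: because `type_hints.py` re-exports the relocated enums through `__all__`, existing imports should keep working; both spellings below refer to the same classes. A quick sanity sketch, assuming the package is installed:

```python
from snowflake.ml.model import type_hints
from snowflake.ml.model.target_platform import TargetPlatform
from snowflake.ml.model.task import Task

assert type_hints.TargetPlatform is TargetPlatform
assert type_hints.Task is Task
assert Task("TABULAR_REGRESSION") is Task.TABULAR_REGRESSION
```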
--- a/snowflake/ml/modeling/metrics/metrics_utils.py
+++ b/snowflake/ml/modeling/metrics/metrics_utils.py
@@ -60,6 +60,7 @@ def register_accumulator_udtf(*, session: Session, statement_params: dict[str, A
         ),
         input_types=[T.BinaryType()],
         packages=[f"numpy=={np.__version__}", f"cloudpickle=={cloudpickle.__version__}"],
+        imports=[],  # Prevents unnecessary import resolution.
         name=accumulator,
         is_permanent=False,
         replace=True,
@@ -175,6 +176,7 @@ def register_sharded_dot_sum_computer(*, session: Session, statement_params: dic
         ),
         input_types=[T.ArrayType(), T.IntegerType(), T.IntegerType()],
         packages=[f"numpy=={np.__version__}", f"cloudpickle=={cloudpickle.__version__}"],
+        imports=[],  # Prevents unnecessary import resolution.
         name=sharded_dot_and_sum_computer,
         is_permanent=False,
         replace=True,
--- a/snowflake/ml/monitoring/explain_visualize.py
+++ b/snowflake/ml/monitoring/explain_visualize.py
@@ -272,8 +272,8 @@ def plot_influence_sensitivity(
     If Streamlit is not available and a DataFrame is passed in, an ImportError will be raised.
 
     Args:
-        feature_values: pandas Series or 2D array containing the feature values for a specific feature
-        shap_values: pandas Series or 2D array containing the SHAP values for the same feature
+        shap_values: pandas Series or 2D array containing the SHAP values for a specific feature
+        feature_values: pandas Series or 2D array containing the feature values for the same feature
        figsize: tuple of (width, height) for the plot
 
     Returns:
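
Editor's note: this hunk only reorders the `Args:` entries so the docstring matches the actual parameter order of `plot_influence_sensitivity` (SHAP values first, then feature values). A hedged call sketch; the data is illustrative and the call itself is commented since it needs the plotting stack at runtime:

```python
import pandas as pd

shap_col = pd.Series([0.12, -0.30, 0.05])  # SHAP values for one feature
x_col = pd.Series([1.0, 2.0, 3.0])         # matching feature values

# Positional order per the corrected docstring:
# from snowflake.ml.monitoring.explain_visualize import plot_influence_sensitivity
# plot_influence_sensitivity(shap_col, x_col)
```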
--- a/snowflake/ml/monitoring/model_monitor.py
+++ b/snowflake/ml/monitoring/model_monitor.py
@@ -1,7 +1,5 @@
-from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import sql_identifier
-from snowflake.ml.monitoring import model_monitor_version
 from snowflake.ml.monitoring._client import model_monitor_sql_client
 
 
@@ -29,7 +27,6 @@ class ModelMonitor:
         project=telemetry.TelemetryProject.MLOPS.value,
         subproject=telemetry.TelemetrySubProject.MONITORING.value,
     )
-    @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
     def suspend(self) -> None:
         """Suspend the Model Monitor"""
         statement_params = telemetry.get_statement_params(
@@ -42,7 +39,6 @@ class ModelMonitor:
         project=telemetry.TelemetryProject.MLOPS.value,
         subproject=telemetry.TelemetrySubProject.MONITORING.value,
     )
-    @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
     def resume(self) -> None:
         """Resume the Model Monitor"""
         statement_params = telemetry.get_statement_params(
--- a/snowflake/ml/registry/_manager/model_manager.py
+++ b/snowflake/ml/registry/_manager/model_manager.py
@@ -1,5 +1,5 @@
 from types import ModuleType
-from typing import Any, Optional, Union
+from typing import Any, Optional, Protocol, Union
 
 import pandas as pd
 from absl.logging import logging
@@ -8,7 +8,7 @@ from snowflake.ml._internal import env, platform_capabilities, telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml._internal.human_readable_id import hrid_generator
 from snowflake.ml._internal.utils import sql_identifier
-from snowflake.ml.model import model_signature, type_hints as model_types
+from snowflake.ml.model import model_signature, target_platform, task, type_hints
 from snowflake.ml.model._client.model import model_impl, model_version_impl
 from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
 from snowflake.ml.model._model_composer import model_composer
@@ -20,6 +20,14 @@ from snowflake.snowpark._internal import utils as snowpark_utils
 logger = logging.getLogger(__name__)
 
 
+class EventHandler(Protocol):
+    """Protocol defining the interface for event handlers used during model operations."""
+
+    def update(self, message: str) -> None:
+        """Update with a progress message."""
+        ...
+
+
 class ModelManager:
     def __init__(
         self,
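
Editor's note: `EventHandler` is a `typing.Protocol`, so callers satisfy it structurally; any object with a matching `update(message: str)` method works, no subclassing required. Note in the hunks below that `event_handler` becomes a required keyword-only argument of `log_model`/`_log_model`. A minimal sketch; the handler name is illustrative:

```python
class PrintEventHandler:
    """Structurally satisfies EventHandler: it only needs update(message: str)."""

    def update(self, message: str) -> None:
        print(message)

# e.g.:
# manager.log_model(model=my_model, model_name="m", event_handler=PrintEventHandler())
```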
@@ -41,7 +49,7 @@ class ModelManager:
     def log_model(
         self,
         *,
-        model: Union[model_types.SupportedModelType, model_version_impl.ModelVersion],
+        model: Union[type_hints.SupportedModelType, model_version_impl.ModelVersion],
         model_name: str,
         version_name: Optional[str] = None,
         comment: Optional[str] = None,
@@ -50,16 +58,17 @@ class ModelManager:
         pip_requirements: Optional[list[str]] = None,
         artifact_repository_map: Optional[dict[str, str]] = None,
         resource_constraint: Optional[dict[str, str]] = None,
-        target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
+        target_platforms: Optional[list[type_hints.SupportedTargetPlatformType]] = None,
         python_version: Optional[str] = None,
         signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
-        sample_input_data: Optional[model_types.SupportedDataType] = None,
+        sample_input_data: Optional[type_hints.SupportedDataType] = None,
         user_files: Optional[dict[str, list[str]]] = None,
         code_paths: Optional[list[str]] = None,
         ext_modules: Optional[list[ModuleType]] = None,
-        task: model_types.Task = model_types.Task.UNKNOWN,
-        options: Optional[model_types.ModelSaveOption] = None,
+        task: type_hints.Task = task.Task.UNKNOWN,
+        options: Optional[type_hints.ModelSaveOption] = None,
         statement_params: Optional[dict[str, Any]] = None,
+        event_handler: EventHandler,
     ) -> model_version_impl.ModelVersion:
 
         database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
@@ -143,11 +152,12 @@ class ModelManager:
             task=task,
             options=options,
             statement_params=statement_params,
+            event_handler=event_handler,
         )
 
     def _log_model(
         self,
-        model: model_types.SupportedModelType,
+        model: type_hints.SupportedModelType,
         *,
         model_name: str,
         version_name: str,
@@ -157,16 +167,17 @@ class ModelManager:
         pip_requirements: Optional[list[str]] = None,
         artifact_repository_map: Optional[dict[str, str]] = None,
         resource_constraint: Optional[dict[str, str]] = None,
-        target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
+        target_platforms: Optional[list[type_hints.SupportedTargetPlatformType]] = None,
         python_version: Optional[str] = None,
         signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
-        sample_input_data: Optional[model_types.SupportedDataType] = None,
+        sample_input_data: Optional[type_hints.SupportedDataType] = None,
         user_files: Optional[dict[str, list[str]]] = None,
         code_paths: Optional[list[str]] = None,
         ext_modules: Optional[list[ModuleType]] = None,
-        task: model_types.Task = model_types.Task.UNKNOWN,
-        options: Optional[model_types.ModelSaveOption] = None,
+        task: type_hints.Task = task.Task.UNKNOWN,
+        options: Optional[type_hints.ModelSaveOption] = None,
         statement_params: Optional[dict[str, Any]] = None,
+        event_handler: EventHandler,
     ) -> model_version_impl.ModelVersion:
         database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
         version_name_id = sql_identifier.SqlIdentifier(version_name)
@@ -215,7 +226,7 @@ class ModelManager:
         # User specified target platforms are defaulted to None and will not show up in the generated manifest.
         if target_platforms:
             # Convert any string target platforms to TargetPlatform objects
-            platforms = [model_types.TargetPlatform(platform) for platform in target_platforms]
+            platforms = [type_hints.TargetPlatform(platform) for platform in target_platforms]
         else:
             # Default the target platform to warehouse if not specified and any table function exists
             if options and (
@@ -231,7 +242,7 @@ class ModelManager:
                     "Logging a partitioned model with a table function without specifying `target_platforms`. "
                     'Default to `target_platforms=["WAREHOUSE"]`.'
                 )
-                platforms = [model_types.TargetPlatform.WAREHOUSE]
+                platforms = [target_platform.TargetPlatform.WAREHOUSE]
 
             # Default the target platform to SPCS if not specified when running in ML runtime
             if not platforms and env.IN_ML_RUNTIME:
@@ -239,7 +250,7 @@ class ModelManager:
                     "Logging the model on Container Runtime for ML without specifying `target_platforms`. "
                     'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.'
                 )
-                platforms = [model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
+                platforms = [target_platform.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
 
         if artifact_repository_map:
             for channel, artifact_repository_name in artifact_repository_map.items():
@@ -254,6 +265,7 @@ class ModelManager:
                )
 
         logger.info("Start packaging and uploading your model. It might take some time based on the size of the model.")
+        event_handler.update("📦 Packaging model...")
 
         # Extract save_location from options if present
         save_location = None
@@ -292,6 +304,7 @@ class ModelManager:
         )
 
         logger.info("Start creating MODEL object for you in the Snowflake.")
+        event_handler.update("🏗️ Creating model object in Snowflake...")
 
         self._model_ops.create_from_stage(
             composed_model=mc,
@@ -331,6 +344,8 @@ class ModelManager:
             statement_params=statement_params,
         )
 
+        event_handler.update("✅ Model logged successfully!")
+
         return mv
 
     def get_model(