snowflake-ml-python 1.8.5__py3-none-any.whl → 1.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/telemetry.py +6 -9
- snowflake/ml/_internal/utils/connection_params.py +196 -0
- snowflake/ml/_internal/utils/identifier.py +1 -1
- snowflake/ml/_internal/utils/mixins.py +61 -0
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +3 -2
- snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
- snowflake/ml/jobs/_utils/interop_utils.py +63 -4
- snowflake/ml/jobs/_utils/payload_utils.py +89 -40
- snowflake/ml/jobs/_utils/query_helper.py +9 -0
- snowflake/ml/jobs/_utils/scripts/constants.py +19 -3
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +8 -26
- snowflake/ml/jobs/_utils/spec_utils.py +29 -5
- snowflake/ml/jobs/_utils/stage_utils.py +119 -0
- snowflake/ml/jobs/_utils/types.py +5 -1
- snowflake/ml/jobs/decorators.py +20 -28
- snowflake/ml/jobs/job.py +197 -61
- snowflake/ml/jobs/manager.py +253 -121
- snowflake/ml/model/_client/model/model_impl.py +58 -0
- snowflake/ml/model/_client/model/model_version_impl.py +90 -0
- snowflake/ml/model/_client/ops/model_ops.py +18 -6
- snowflake/ml/model/_client/ops/service_ops.py +23 -6
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -0
- snowflake/ml/model/_client/sql/service.py +68 -20
- snowflake/ml/model/_client/sql/stage.py +5 -2
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +38 -10
- snowflake/ml/model/_packager/model_env/model_env.py +35 -27
- snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
- snowflake/ml/model/_signatures/core.py +24 -0
- snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
- snowflake/ml/model/target_platform.py +11 -0
- snowflake/ml/model/task.py +9 -0
- snowflake/ml/model/type_hints.py +5 -13
- snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
- snowflake/ml/monitoring/explain_visualize.py +2 -2
- snowflake/ml/monitoring/model_monitor.py +0 -4
- snowflake/ml/registry/_manager/model_manager.py +30 -15
- snowflake/ml/registry/registry.py +144 -47
- snowflake/ml/utils/connection_params.py +1 -1
- snowflake/ml/utils/html_utils.py +263 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/METADATA +64 -19
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/RECORD +48 -41
- snowflake/ml/monitoring/model_monitor_version.py +0 -1
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.5.dist-info → snowflake_ml_python-1.9.0.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_packager/model_handlers/snowmlmodel.py
CHANGED

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, Callable, Optional, cast, final
 import cloudpickle
 import numpy as np
 import pandas as pd
+import shap
 from typing_extensions import TypeGuard, Unpack
 
 from snowflake.ml._internal import type_utils
@@ -25,6 +26,19 @@ if TYPE_CHECKING:
     from snowflake.ml.modeling.framework.base import BaseEstimator
 
 
+def _apply_transforms_up_to_last_step(
+    model: "BaseEstimator",
+    data: model_types.SupportedDataType,
+) -> pd.DataFrame:
+    """Apply all transformations in the snowml pipeline model up to the last step."""
+    if type_utils.LazyType("snowflake.ml.modeling.pipeline.Pipeline").isinstance(model):
+        for step_name, step in model.steps[:-1]:  # type: ignore[attr-defined]
+            if not hasattr(step, "transform"):
+                raise ValueError(f"Step '{step_name}' does not have a 'transform' method.")
+            data = pd.DataFrame(step.transform(data))
+    return data
+
+
 @final
 class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
     """Handler for SnowML based model.
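Note: The new module-level helper `_apply_transforms_up_to_last_step` is what lets the handler explain pipeline models: every step except the final estimator is replayed so SHAP sees the feature space the last step was actually trained on. A minimal sketch of the same idea, using a plain scikit-learn Pipeline as a stand-in for the Snowpark ML Pipeline (toy data and step names are illustrative, not from this diff):

    import pandas as pd
    from sklearn.linear_model import LogisticRegression
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler

    pipe = Pipeline([("scale", StandardScaler()), ("clf", LogisticRegression())])
    X = pd.DataFrame({"a": [0.0, 1.0, 2.0, 3.0], "b": [1.0, 0.0, 1.0, 0.0]})
    y = [0, 0, 1, 1]
    pipe.fit(X, y)

    # Replay every step except the final estimator, mirroring the helper above.
    data = X
    for step_name, step in pipe.steps[:-1]:
        if not hasattr(step, "transform"):
            raise ValueError(f"Step '{step_name}' does not have a 'transform' method.")
        data = pd.DataFrame(step.transform(data))

    print(data)  # the scaled features the final estimator actually receives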
@@ -39,7 +53,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
     _HANDLER_MIGRATOR_PLANS: dict[str, type[base_migrator.BaseModelHandlerMigrator]] = {}
 
     DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"]
-    EXPLAIN_TARGET_METHODS = ["
+    EXPLAIN_TARGET_METHODS = ["predict_proba", "predict", "predict_log_proba"]
 
     IS_AUTO_SIGNATURE = True
 
@@ -97,11 +111,6 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
                 return result
             except exceptions.SnowflakeMLException:
                 pass  # Do nothing and continue to the next method
-
-        if enable_explainability:
-            raise ValueError(
-                "Explain only supported for xgboost, lightgbm and sklearn (not pipeline) Snowpark ML models."
-            )
         return None
 
     @classmethod
@@ -189,23 +198,46 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         else:
             enable_explainability = True
         if enable_explainability:
-            model_task_and_output_type = model_task_utils.resolve_model_task_and_output_type(
-                python_base_obj, model_meta.task
-            )
-            model_meta.task = model_task_and_output_type.task
-            model_meta = handlers_utils.add_explain_method_signature(
-                model_meta=model_meta,
-                explain_method="explain",
-                target_method=explain_target_method,
-                output_return_type=model_task_and_output_type.output_type,
-            )
-            background_data = handlers_utils.get_explainability_supported_background(
-                sample_input_data, model_meta, explain_target_method
-            )
-            if background_data is not None:
-                handlers_utils.save_background_data(
-                    model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
+            try:
+                model_task_and_output_type = model_task_utils.resolve_model_task_and_output_type(
+                    python_base_obj, model_meta.task
+                )
+                model_meta.task = model_task_and_output_type.task
+                background_data = handlers_utils.get_explainability_supported_background(
+                    sample_input_data, model_meta, explain_target_method
                 )
+                if type_utils.LazyType("snowflake.ml.modeling.pipeline.Pipeline").isinstance(model):
+                    transformed_df = _apply_transforms_up_to_last_step(model, sample_input_data)
+                    explain_fn = cls._build_explain_fn(model, background_data)
+                    model_meta = handlers_utils.add_inferred_explain_method_signature(
+                        model_meta=model_meta,
+                        explain_method="explain",
+                        target_method=explain_target_method,  # type: ignore[arg-type]
+                        background_data=background_data,
+                        explain_fn=explain_fn,
+                        output_feature_names=transformed_df.columns,
+                    )
+                else:
+                    model_meta = handlers_utils.add_explain_method_signature(
+                        model_meta=model_meta,
+                        explain_method="explain",
+                        target_method=explain_target_method,
+                        output_return_type=model_task_and_output_type.output_type,
+                    )
+                if background_data is not None:
+                    handlers_utils.save_background_data(
+                        model_blobs_dir_path,
+                        cls.EXPLAIN_ARTIFACTS_DIR,
+                        cls.BG_DATA_FILE_SUFFIX,
+                        name,
+                        background_data,
+                    )
+            except Exception:
+                if kwargs.get("enable_explainability", None):
+                    # user explicitly enabled explainability, so we should raise the error
+                    raise ValueError(
+                        "Explainability for this model is not supported. Please set `enable_explainability=False`"
+                    )
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
@@ -251,6 +283,53 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         assert isinstance(m, BaseEstimator)
         return m
 
+    @classmethod
+    def _build_explain_fn(
+        cls, model: "BaseEstimator", background_data: model_types.SupportedDataType
+    ) -> Callable[[model_types.SupportedDataType], pd.DataFrame]:
+
+        predictor = model
+        is_pipeline = type_utils.LazyType("snowflake.ml.modeling.pipeline.Pipeline").isinstance(model)
+        if is_pipeline:
+            background_data = _apply_transforms_up_to_last_step(model, background_data)
+            predictor = model.steps[-1][1]  # type: ignore[attr-defined]
+
+        def explain_fn(data: model_types.SupportedDataType) -> pd.DataFrame:
+            data = _apply_transforms_up_to_last_step(model, data)
+            tree_methods = ["to_xgboost", "to_lightgbm"]
+            non_tree_methods = ["to_sklearn", None]  # None just uses the predictor directly
+            for method_name in tree_methods:
+                try:
+                    base_model = getattr(predictor, method_name)()
+                    explainer = shap.TreeExplainer(base_model)
+                    return handlers_utils.convert_explanations_to_2D_df(model, explainer.shap_values(data))
+                except exceptions.SnowflakeMLException:
+                    pass  # Do nothing and continue to the next method
+            for method_name in non_tree_methods:  # type: ignore[assignment]
+                try:
+                    base_model = getattr(predictor, method_name)() if method_name is not None else predictor
+                    try:
+                        explainer = shap.Explainer(base_model, masker=background_data)
+                        return handlers_utils.convert_explanations_to_2D_df(base_model, explainer(data).values)
+                    except TypeError:
+                        for explain_target_method in cls.EXPLAIN_TARGET_METHODS:
+                            if not hasattr(base_model, explain_target_method):
+                                continue
+                            explain_target_method_fn = getattr(base_model, explain_target_method)
+                            if isinstance(data, np.ndarray):
+                                explainer = shap.Explainer(
+                                    explain_target_method_fn,
+                                    background_data.values,  # type: ignore[union-attr]
+                                )
+                            else:
+                                explainer = shap.Explainer(explain_target_method_fn, background_data)
+                            return handlers_utils.convert_explanations_to_2D_df(base_model, explainer(data).values)
+                except Exception:
+                    pass  # Do nothing and continue to the next method
+            raise ValueError("Explainability for this model is not supported.")
+
+        return explain_fn
+
     @classmethod
     def convert_as_custom_model(
         cls,
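Note: `_build_explain_fn` centralizes the explainer-selection cascade that used to be inlined in `convert_as_custom_model` (see the next hunk): try `shap.TreeExplainer` on a tree backend, fall back to the generic `shap.Explainer`, and finally explain a prediction function directly. A standalone sketch of that cascade against a plain scikit-learn model (toy data; simplified relative to the handler's version):

    import numpy as np
    import pandas as pd
    import shap
    from sklearn.ensemble import RandomForestClassifier

    rng = np.random.RandomState(0)
    X = pd.DataFrame(rng.rand(50, 3), columns=["f0", "f1", "f2"])
    y = (X["f0"] + X["f1"] > 1.0).astype(int)
    model = RandomForestClassifier(n_estimators=10, random_state=0).fit(X, y)

    try:
        # Fast path: dedicated explainer for tree ensembles.
        values = shap.TreeExplainer(model).shap_values(X)
    except Exception:
        try:
            # Generic path: shap picks an algorithm; X doubles as the masker/background.
            values = shap.Explainer(model, masker=X)(X).values
        except TypeError:
            # Last resort: model-agnostic explanation of a prediction function.
            values = shap.Explainer(model.predict_proba, X)(X).values

    print(np.asarray(values).shape)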
@@ -286,57 +365,8 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
 
         @custom_model.inference_api
         def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
-            import shap
-
-            tree_methods = ["to_xgboost", "to_lightgbm"]
-            non_tree_methods = ["to_sklearn"]
-            for method_name in tree_methods:
-                try:
-                    base_model = getattr(raw_model, method_name)()
-                    explainer = shap.TreeExplainer(base_model)
-                    df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer.shap_values(X))
-                    return model_signature_utils.rename_pandas_df(df, signature.outputs)
-                except exceptions.SnowflakeMLException:
-                    pass  # Do nothing and continue to the next method
-            for method_name in non_tree_methods:
-                try:
-                    base_model = getattr(raw_model, method_name)()
-                    try:
-                        explainer = shap.Explainer(base_model, masker=background_data)
-                        df = handlers_utils.convert_explanations_to_2D_df(base_model, explainer(X).values)
-                    except TypeError:
-                        try:
-                            dtype_map = {
-                                spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs
-                            }
-
-                            if isinstance(X, pd.DataFrame):
-                                X = X.astype(dtype_map, copy=False)
-                            if hasattr(base_model, "predict_proba"):
-                                if isinstance(X, np.ndarray):
-                                    explainer = shap.Explainer(
-                                        base_model.predict_proba,
-                                        background_data.values,  # type: ignore[union-attr]
-                                    )
-                                else:
-                                    explainer = shap.Explainer(base_model.predict_proba, background_data)
-                            elif hasattr(base_model, "predict"):
-                                if isinstance(X, np.ndarray):
-                                    explainer = shap.Explainer(
-                                        base_model.predict, background_data.values  # type: ignore[union-attr]
-                                    )
-                                else:
-                                    explainer = shap.Explainer(base_model.predict, background_data)
-                            else:
-                                raise ValueError("Missing any supported target method to explain.")
-                            df = handlers_utils.convert_explanations_to_2D_df(base_model, explainer(X).values)
-                        except TypeError as e:
-                            raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
-                    return model_signature_utils.rename_pandas_df(df, signature.outputs)
-
-                except exceptions.SnowflakeMLException:
-                    pass  # Do nothing and continue to the next method
-            raise ValueError("The model must be an xgboost, lightgbm or sklearn (not pipeline) estimator.")
+            fn = cls._build_explain_fn(raw_model, background_data)
+            return model_signature_utils.rename_pandas_df(fn(X), signature.outputs)
 
         if target_method == "explain":
             return explain_fn
snowflake/ml/model/_packager/model_meta/model_meta.py
CHANGED

@@ -110,6 +110,7 @@ def create_model_metadata(
         python_version=python_version,
         embed_local_ml_library=embed_local_ml_library,
         prefer_pip=prefer_pip,
+        target_platforms=target_platforms,
     )
 
     if embed_local_ml_library:

@@ -162,8 +163,9 @@ def _create_env_for_model_metadata(
     python_version: Optional[str] = None,
     embed_local_ml_library: bool = False,
     prefer_pip: bool = False,
+    target_platforms: Optional[list[model_types.TargetPlatform]] = None,
 ) -> model_env.ModelEnv:
-    env = model_env.ModelEnv(prefer_pip=prefer_pip)
+    env = model_env.ModelEnv(prefer_pip=prefer_pip, target_platforms=target_platforms)
 
     # Mypy doesn't like getter and setter have different types. See python/mypy #3004
     env.conda_dependencies = conda_dependencies  # type: ignore[assignment]

snowflake/ml/model/_signatures/core.py
CHANGED

@@ -559,6 +559,30 @@ class ModelSignature:
         )"""
         )
 
+    def _repr_html_(self) -> str:
+        """Generate an HTML representation of the model signature.
+
+        Returns:
+            str: HTML string containing formatted signature details.
+        """
+        from snowflake.ml.utils import html_utils
+
+        # Create collapsible sections for inputs and outputs
+        inputs_content = html_utils.create_features_html(self.inputs, "Input")
+        outputs_content = html_utils.create_features_html(self.outputs, "Output")
+
+        inputs_section = html_utils.create_collapsible_section("Inputs", inputs_content, open_by_default=True)
+        outputs_section = html_utils.create_collapsible_section("Outputs", outputs_content, open_by_default=True)
+
+        content = f"""
+        <div style="margin-top: 10px;">
+            {inputs_section}
+            {outputs_section}
+        </div>
+        """
+
+        return html_utils.create_base_container("Model Signature", content)
+
     @classmethod
     def from_mlflow_sig(cls, mlflow_sig: "mlflow.models.ModelSignature") -> "ModelSignature":
         return ModelSignature(
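Note: `_repr_html_` is the IPython/Jupyter rich-display hook: when a cell's result defines it, the front end renders the returned HTML instead of calling `repr()`. A toy illustration of the mechanism (the markup here is not the library's `html_utils` output):

    class Point:
        def __init__(self, x: float, y: float) -> None:
            self.x, self.y = x, y

        def _repr_html_(self) -> str:
            # Jupyter calls this hook automatically on the cell's result object.
            return (
                "<b>Point</b>"
                f"<table><tr><td>x</td><td>{self.x}</td></tr>"
                f"<tr><td>y</td><td>{self.y}</td></tr></table>"
            )

    # In a notebook, evaluating `Point(1.0, 2.0)` renders the table;
    # a plain REPL falls back to the default repr().
    Point(1.0, 2.0)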
snowflake/ml/model/_signatures/snowpark_handler.py
CHANGED

@@ -60,12 +60,19 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
         data: snowflake.snowpark.DataFrame,
         ensure_serializable: bool = True,
         features: Optional[Sequence[core.BaseFeatureSpec]] = None,
+        statement_params: Optional[dict[str, Any]] = None,
     ) -> pd.DataFrame:
         # This method do things on top of to_pandas, to make sure the local dataframe got is in correct shape.
         dtype_map = {}
+
         if features:
+            quoted_identifiers_ignore_case = SnowparkDataFrameHandler._is_quoted_identifiers_ignore_case_enabled(
+                data.session, statement_params
+            )
             for feature in features:
-                dtype_map[feature.name] = feature.as_dtype()
+                feature_name = feature.name.upper() if quoted_identifiers_ignore_case else feature.name
+                dtype_map[feature_name] = feature.as_dtype()
+
         df_local = data.to_pandas()
 
         # This is because Array will become string (Even though the correct schema is set)

@@ -93,6 +100,7 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
         df: pd.DataFrame,
         keep_order: bool = False,
         features: Optional[Sequence[core.BaseFeatureSpec]] = None,
+        statement_params: Optional[dict[str, Any]] = None,
     ) -> snowflake.snowpark.DataFrame:
         # This method is necessary to create the Snowpark Dataframe in correct schema.
         # However, in this case, the order could not be preserved. Thus, a _ID column has to be added,

@@ -100,6 +108,12 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
         # Although in this case, the column with array type can get correct ARRAY type, however, the element
         # type is not preserved, and will become string type. This affect the implementation of convert_from_df.
         df = pandas_handler.PandasDataFrameHandler.convert_to_df(df)
+        quoted_identifiers_ignore_case = SnowparkDataFrameHandler._is_quoted_identifiers_ignore_case_enabled(
+            session, statement_params
+        )
+        if quoted_identifiers_ignore_case:
+            df.columns = [str(col).upper() for col in df.columns]
+
         df_cols = df.columns
         if df_cols.dtype != np.object_:
             raise snowml_exceptions.SnowflakeMLException(

@@ -116,9 +130,47 @@ class SnowparkDataFrameHandler(base_handler.BaseDataHandler[snowflake.snowpark.D
         column_names = []
         columns = []
         for feature in features:
-            column_names.append(identifier.get_inferred_name(feature.name))
-            columns.append(F.col(identifier.get_inferred_name(feature.name)).cast(feature.as_snowpark_type()))
+            feature_name = identifier.get_inferred_name(feature.name)
+            if quoted_identifiers_ignore_case:
+                feature_name = feature_name.upper()
+            column_names.append(feature_name)
+            columns.append(F.col(feature_name).cast(feature.as_snowpark_type()))
 
         sp_df = sp_df.with_columns(column_names, columns)
 
         return sp_df
+
+    @staticmethod
+    def _is_quoted_identifiers_ignore_case_enabled(
+        session: snowflake.snowpark.Session, statement_params: Optional[dict[str, Any]] = None
+    ) -> bool:
+        """
+        Check if QUOTED_IDENTIFIERS_IGNORE_CASE parameter is enabled.
+
+        Args:
+            session: Snowpark session to check parameter for
+            statement_params: Optional statement parameters to check first
+
+        Returns:
+            bool: True if QUOTED_IDENTIFIERS_IGNORE_CASE is enabled, False otherwise
+                Returns False if the parameter cannot be retrieved (e.g., in stored procedures)
+        """
+        if statement_params is not None:
+            for key, value in statement_params.items():
+                if key.upper() == "QUOTED_IDENTIFIERS_IGNORE_CASE":
+                    parameter_value = str(value)
+                    return parameter_value.lower() == "true"
+
+        try:
+            result = session.sql(
+                "SHOW PARAMETERS LIKE 'QUOTED_IDENTIFIERS_IGNORE_CASE' IN SESSION",
+                _emit_ast=False,
+            ).collect(_emit_ast=False)
+
+            parameter_value = str(result[0].value)
+            return parameter_value.lower() == "true"
+
+        except Exception:
+            # Parameter query can fail in certain environments (e.g., in stored procedures)
+            # In that case, assume default behavior (case-sensitive)
+            return False
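Note: The helper gives explicit `statement_params` precedence over the session-level setting and only then falls back to a `SHOW PARAMETERS` query, assuming the Snowflake default (case-sensitive) whenever the query is unavailable, e.g. inside stored procedures. The precedence logic in isolation, with the session query stubbed out (illustrative sketch only):

    from typing import Any, Optional

    def is_ignore_case_enabled(statement_params: Optional[dict[str, Any]]) -> bool:
        # 1) statement_params win if they mention the parameter at all.
        if statement_params is not None:
            for key, value in statement_params.items():
                if key.upper() == "QUOTED_IDENTIFIERS_IGNORE_CASE":
                    return str(value).lower() == "true"
        # 2) The real helper would now run SHOW PARAMETERS against the session;
        #    on any failure it assumes the parameter is off.
        return False

    assert is_ignore_case_enabled({"quoted_identifiers_ignore_case": True}) is True
    assert is_ignore_case_enabled({"OTHER_PARAM": "x"}) is False
    assert is_ignore_case_enabled(None) is False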
snowflake/ml/model/target_platform.py
ADDED

@@ -0,0 +1,11 @@
+from enum import Enum
+
+
+class TargetPlatform(Enum):
+    WAREHOUSE = "WAREHOUSE"
+    SNOWPARK_CONTAINER_SERVICES = "SNOWPARK_CONTAINER_SERVICES"
+
+
+WAREHOUSE_ONLY = [TargetPlatform.WAREHOUSE]
+SNOWPARK_CONTAINER_SERVICES_ONLY = [TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
+BOTH_WAREHOUSE_AND_SNOWPARK_CONTAINER_SERVICES = [TargetPlatform.WAREHOUSE, TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
snowflake/ml/model/task.py
ADDED

@@ -0,0 +1,9 @@
+from enum import Enum
+
+
+class Task(Enum):
+    UNKNOWN = "UNKNOWN"
+    TABULAR_BINARY_CLASSIFICATION = "TABULAR_BINARY_CLASSIFICATION"
+    TABULAR_MULTI_CLASSIFICATION = "TABULAR_MULTI_CLASSIFICATION"
+    TABULAR_REGRESSION = "TABULAR_REGRESSION"
+    TABULAR_RANKING = "TABULAR_RANKING"

snowflake/ml/model/type_hints.py
CHANGED

@@ -1,10 +1,12 @@
 # mypy: disable-error-code="import"
-from enum import Enum
 from typing import TYPE_CHECKING, Literal, Sequence, TypedDict, TypeVar, Union
 
 import numpy.typing as npt
 from typing_extensions import NotRequired
 
+from snowflake.ml.model.target_platform import TargetPlatform
+from snowflake.ml.model.task import Task
+
 if TYPE_CHECKING:
     import catboost
     import keras

@@ -321,17 +323,7 @@ ModelLoadOption = Union[
 ]
 
 
-class Task(Enum):
-    UNKNOWN = "UNKNOWN"
-    TABULAR_BINARY_CLASSIFICATION = "TABULAR_BINARY_CLASSIFICATION"
-    TABULAR_MULTI_CLASSIFICATION = "TABULAR_MULTI_CLASSIFICATION"
-    TABULAR_REGRESSION = "TABULAR_REGRESSION"
-    TABULAR_RANKING = "TABULAR_RANKING"
-
-
-class TargetPlatform(Enum):
-    WAREHOUSE = "WAREHOUSE"
-    SNOWPARK_CONTAINER_SERVICES = "SNOWPARK_CONTAINER_SERVICES"
+SupportedTargetPlatformType = Union[TargetPlatform, str]
 
 
-
+__all__ = ["TargetPlatform", "Task"]
snowflake/ml/modeling/metrics/metrics_utils.py
CHANGED

@@ -60,6 +60,7 @@ def register_accumulator_udtf(*, session: Session, statement_params: dict[str, A
         ),
         input_types=[T.BinaryType()],
         packages=[f"numpy=={np.__version__}", f"cloudpickle=={cloudpickle.__version__}"],
+        imports=[],  # Prevents unnecessary import resolution.
         name=accumulator,
         is_permanent=False,
         replace=True,

@@ -175,6 +176,7 @@ def register_sharded_dot_sum_computer(*, session: Session, statement_params: dic
         ),
         input_types=[T.ArrayType(), T.IntegerType(), T.IntegerType()],
         packages=[f"numpy=={np.__version__}", f"cloudpickle=={cloudpickle.__version__}"],
+        imports=[],  # Prevents unnecessary import resolution.
         name=sharded_dot_and_sum_computer,
         is_permanent=False,
         replace=True,
snowflake/ml/monitoring/explain_visualize.py
CHANGED

@@ -272,8 +272,8 @@ def plot_influence_sensitivity(
     If Streamlit is not available and a DataFrame is passed in, an ImportError will be raised.
 
     Args:
-        shap_values: pandas Series containing the SHAP values for a specific feature
-        feature_values: pandas Series containing the feature values for the same feature
+        shap_values: pandas Series or 2D array containing the SHAP values for a specific feature
+        feature_values: pandas Series or 2D array containing the feature values for the same feature
         figsize: tuple of (width, height) for the plot
 
     Returns:

snowflake/ml/monitoring/model_monitor.py
CHANGED

@@ -1,7 +1,5 @@
-from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import sql_identifier
-from snowflake.ml.monitoring import model_monitor_version
 from snowflake.ml.monitoring._client import model_monitor_sql_client
 
 

@@ -29,7 +27,6 @@ class ModelMonitor:
         project=telemetry.TelemetryProject.MLOPS.value,
         subproject=telemetry.TelemetrySubProject.MONITORING.value,
     )
-    @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
     def suspend(self) -> None:
         """Suspend the Model Monitor"""
         statement_params = telemetry.get_statement_params(

@@ -42,7 +39,6 @@ class ModelMonitor:
         project=telemetry.TelemetryProject.MLOPS.value,
         subproject=telemetry.TelemetrySubProject.MONITORING.value,
     )
-    @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
     def resume(self) -> None:
         """Resume the Model Monitor"""
         statement_params = telemetry.get_statement_params(
snowflake/ml/registry/_manager/model_manager.py
CHANGED

@@ -1,5 +1,5 @@
 from types import ModuleType
-from typing import Any, Optional, Union
+from typing import Any, Optional, Protocol, Union
 
 import pandas as pd
 from absl.logging import logging

@@ -8,7 +8,7 @@ from snowflake.ml._internal import env, platform_capabilities, telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml._internal.human_readable_id import hrid_generator
 from snowflake.ml._internal.utils import sql_identifier
-from snowflake.ml.model import model_signature,
+from snowflake.ml.model import model_signature, target_platform, task, type_hints
 from snowflake.ml.model._client.model import model_impl, model_version_impl
 from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
 from snowflake.ml.model._model_composer import model_composer

@@ -20,6 +20,14 @@ from snowflake.snowpark._internal import utils as snowpark_utils
 logger = logging.getLogger(__name__)
 
 
+class EventHandler(Protocol):
+    """Protocol defining the interface for event handlers used during model operations."""
+
+    def update(self, message: str) -> None:
+        """Update with a progress message."""
+        ...
+
+
 class ModelManager:
     def __init__(
         self,
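Note: `EventHandler` is a structural `typing.Protocol`: any object exposing `update(message: str) -> None` satisfies it without inheriting from it, so callers can plug in console loggers, notebook progress bars, and so on. A minimal conforming implementation:

    from typing import Protocol

    class EventHandler(Protocol):  # mirrors the protocol added above
        def update(self, message: str) -> None:
            ...

    class ConsoleEventHandler:
        """Satisfies EventHandler structurally, no subclassing required."""

        def update(self, message: str) -> None:
            print(message)

    handler: EventHandler = ConsoleEventHandler()  # accepted by static type checkers
    handler.update("📦 Packaging model...")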
@@ -41,7 +49,7 @@ class ModelManager:
     def log_model(
         self,
         *,
-        model: Union[
+        model: Union[type_hints.SupportedModelType, model_version_impl.ModelVersion],
         model_name: str,
         version_name: Optional[str] = None,
         comment: Optional[str] = None,

@@ -50,16 +58,17 @@ class ModelManager:
         pip_requirements: Optional[list[str]] = None,
         artifact_repository_map: Optional[dict[str, str]] = None,
         resource_constraint: Optional[dict[str, str]] = None,
-        target_platforms: Optional[list[
+        target_platforms: Optional[list[type_hints.SupportedTargetPlatformType]] = None,
         python_version: Optional[str] = None,
         signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
-        sample_input_data: Optional[
+        sample_input_data: Optional[type_hints.SupportedDataType] = None,
         user_files: Optional[dict[str, list[str]]] = None,
         code_paths: Optional[list[str]] = None,
         ext_modules: Optional[list[ModuleType]] = None,
-        task:
-        options: Optional[
+        task: type_hints.Task = task.Task.UNKNOWN,
+        options: Optional[type_hints.ModelSaveOption] = None,
         statement_params: Optional[dict[str, Any]] = None,
+        event_handler: EventHandler,
     ) -> model_version_impl.ModelVersion:
 
         database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)

@@ -143,11 +152,12 @@ class ModelManager:
             task=task,
             options=options,
             statement_params=statement_params,
+            event_handler=event_handler,
         )
 
     def _log_model(
         self,
-        model:
+        model: type_hints.SupportedModelType,
         *,
         model_name: str,
         version_name: str,

@@ -157,16 +167,17 @@ class ModelManager:
         pip_requirements: Optional[list[str]] = None,
         artifact_repository_map: Optional[dict[str, str]] = None,
         resource_constraint: Optional[dict[str, str]] = None,
-        target_platforms: Optional[list[
+        target_platforms: Optional[list[type_hints.SupportedTargetPlatformType]] = None,
         python_version: Optional[str] = None,
         signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
-        sample_input_data: Optional[
+        sample_input_data: Optional[type_hints.SupportedDataType] = None,
         user_files: Optional[dict[str, list[str]]] = None,
         code_paths: Optional[list[str]] = None,
         ext_modules: Optional[list[ModuleType]] = None,
-        task:
-        options: Optional[
+        task: type_hints.Task = task.Task.UNKNOWN,
+        options: Optional[type_hints.ModelSaveOption] = None,
         statement_params: Optional[dict[str, Any]] = None,
+        event_handler: EventHandler,
     ) -> model_version_impl.ModelVersion:
         database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
         version_name_id = sql_identifier.SqlIdentifier(version_name)

@@ -215,7 +226,7 @@ class ModelManager:
         # User specified target platforms are defaulted to None and will not show up in the generated manifest.
         if target_platforms:
             # Convert any string target platforms to TargetPlatform objects
-            platforms = [
+            platforms = [type_hints.TargetPlatform(platform) for platform in target_platforms]
         else:
             # Default the target platform to warehouse if not specified and any table function exists
             if options and (

@@ -231,7 +242,7 @@ class ModelManager:
                     "Logging a partitioned model with a table function without specifying `target_platforms`. "
                     'Default to `target_platforms=["WAREHOUSE"]`.'
                 )
-                platforms = [
+                platforms = [target_platform.TargetPlatform.WAREHOUSE]
 
             # Default the target platform to SPCS if not specified when running in ML runtime
             if not platforms and env.IN_ML_RUNTIME:

@@ -239,7 +250,7 @@ class ModelManager:
                     "Logging the model on Container Runtime for ML without specifying `target_platforms`. "
                     'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.'
                 )
-                platforms = [
+                platforms = [target_platform.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
 
         if artifact_repository_map:
             for channel, artifact_repository_name in artifact_repository_map.items():

@@ -254,6 +265,7 @@ class ModelManager:
         )
 
         logger.info("Start packaging and uploading your model. It might take some time based on the size of the model.")
+        event_handler.update("📦 Packaging model...")
 
         # Extract save_location from options if present
         save_location = None

@@ -292,6 +304,7 @@ class ModelManager:
         )
 
         logger.info("Start creating MODEL object for you in the Snowflake.")
+        event_handler.update("🏗️ Creating model object in Snowflake...")
 
         self._model_ops.create_from_stage(
             composed_model=mc,

@@ -331,6 +344,8 @@ class ModelManager:
             statement_params=statement_params,
         )
 
+        event_handler.update("✅ Model logged successfully!")
+
         return mv
 
     def get_model(