snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +14 -0
  5. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  6. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  7. snowflake/ml/data/data_connector.py +59 -6
  8. snowflake/ml/data/data_ingestor.py +18 -1
  9. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  10. snowflake/ml/data/torch_dataset.py +33 -0
  11. snowflake/ml/dataset/dataset_metadata.py +3 -1
  12. snowflake/ml/dataset/dataset_reader.py +9 -3
  13. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  14. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  15. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  16. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  17. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  18. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  19. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  20. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  21. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  22. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  23. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  24. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  25. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  26. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  27. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  29. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  30. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  31. snowflake/ml/feature_store/feature_store.py +59 -24
  32. snowflake/ml/feature_store/feature_view.py +148 -4
  33. snowflake/ml/model/_client/model/model_impl.py +11 -2
  34. snowflake/ml/model/_client/model/model_version_impl.py +171 -20
  35. snowflake/ml/model/_client/ops/model_ops.py +105 -27
  36. snowflake/ml/model/_client/ops/service_ops.py +121 -0
  37. snowflake/ml/model/_client/service/model_deployment_spec.py +95 -0
  38. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +31 -0
  39. snowflake/ml/model/_client/sql/model_version.py +13 -4
  40. snowflake/ml/model/_client/sql/service.py +129 -0
  41. snowflake/ml/model/_model_composer/model_composer.py +3 -0
  42. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +10 -2
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  44. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  45. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +19 -12
  47. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  48. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +9 -5
  49. snowflake/ml/model/_packager/model_handlers/lightgbm.py +27 -18
  50. snowflake/ml/model/_packager/model_handlers/llm.py +7 -3
  51. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  52. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  53. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  54. snowflake/ml/model/_packager/model_handlers/sklearn.py +87 -4
  55. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  57. snowflake/ml/model/_packager/model_handlers/torchscript.py +8 -3
  58. snowflake/ml/model/_packager/model_handlers/xgboost.py +25 -16
  59. snowflake/ml/model/_packager/model_meta/model_meta.py +32 -2
  60. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +19 -0
  61. snowflake/ml/model/_packager/model_packager.py +2 -1
  62. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  63. snowflake/ml/model/type_hints.py +1 -3
  64. snowflake/ml/modeling/framework/base.py +28 -19
  65. snowflake/ml/modeling/pipeline/pipeline.py +3 -0
  66. snowflake/ml/registry/_manager/model_manager.py +16 -2
  67. snowflake/ml/utils/sql_client.py +22 -0
  68. snowflake/ml/version.py +1 -1
  69. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA +35 -2
  70. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/RECORD +73 -62
  71. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  72. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/LICENSE.txt +0 -0
  73. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/WHEEL +0 -0
  74. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_packager/model_handlers/lightgbm.py:

@@ -41,7 +41,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
     _MIN_SNOWPARK_ML_VERSION = "1.3.1"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-    MODELE_BLOB_FILE_OR_DIR = "model.pkl"
+    MODEL_BLOB_FILE_OR_DIR = "model.pkl"
     DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
     _BINARY_CLASSIFICATION_OBJECTIVES = ["binary"]
     _MULTI_CLASSIFICATION_OBJECTIVES = ["multiclass", "multiclassova"]
@@ -59,29 +59,31 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
     ]
 
     @classmethod
-    def get_model_objective(cls, model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]) -> _base.ModelObjective:
+    def get_model_objective(
+        cls, model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]
+    ) -> model_meta_schema.ModelObjective:
         import lightgbm
 
         # does not account for cross-entropy and custom
         if isinstance(model, lightgbm.LGBMClassifier):
             num_classes = handlers_utils.get_num_classes_if_exists(model)
             if num_classes == 2:
-                return _base.ModelObjective.BINARY_CLASSIFICATION
-            return _base.ModelObjective.MULTI_CLASSIFICATION
+                return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
+            return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
         if isinstance(model, lightgbm.LGBMRanker):
-            return _base.ModelObjective.RANKING
+            return model_meta_schema.ModelObjective.RANKING
         if isinstance(model, lightgbm.LGBMRegressor):
-            return _base.ModelObjective.REGRESSION
+            return model_meta_schema.ModelObjective.REGRESSION
         model_objective = model.params["objective"]
         if model_objective in cls._BINARY_CLASSIFICATION_OBJECTIVES:
-            return _base.ModelObjective.BINARY_CLASSIFICATION
+            return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
         if model_objective in cls._MULTI_CLASSIFICATION_OBJECTIVES:
-            return _base.ModelObjective.MULTI_CLASSIFICATION
+            return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
        if model_objective in cls._RANKING_OBJECTIVES:
-            return _base.ModelObjective.RANKING
+            return model_meta_schema.ModelObjective.RANKING
         if model_objective in cls._REGRESSION_OBJECTIVES:
-            return _base.ModelObjective.REGRESSION
-        return _base.ModelObjective.UNKNOWN
+            return model_meta_schema.ModelObjective.REGRESSION
+        return model_meta_schema.ModelObjective.UNKNOWN
 
     @classmethod
     def can_handle(
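
To make the renamed helper concrete, here is a hedged sketch of the objective inference on toy models (assumes lightgbm and numpy are installed; the handler lives in a private `_packager` module, so the import path is an implementation detail that may move between releases):

    # Illustrative only: exercising get_model_objective on toy models.
    import lightgbm
    import numpy as np

    from snowflake.ml.model._packager.model_handlers.lightgbm import LGBMModelHandler

    X = np.random.rand(100, 4)
    clf = lightgbm.LGBMClassifier(n_estimators=2).fit(X, np.tile([0, 1], 50))
    reg = lightgbm.LGBMRegressor(n_estimators=2).fit(X, np.random.rand(100))

    LGBMModelHandler.get_model_objective(clf)  # ModelObjective.BINARY_CLASSIFICATION
    LGBMModelHandler.get_model_objective(reg)  # ModelObjective.REGRESSION
    # Raw Boosters fall back to params["objective"]; unrecognized values map to UNKNOWN.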
@@ -144,11 +146,13 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
             sample_input_data=sample_input_data,
             get_prediction_fn=get_prediction,
         )
-        if kwargs.get("enable_explainability", False):
+        model_objective = cls.get_model_objective(model)
+        model_meta.model_objective = model_objective
+        if kwargs.get("enable_explainability", True):
             output_type = model_signature.DataType.DOUBLE
-            if cls.get_model_objective(model) in [
-                _base.ModelObjective.BINARY_CLASSIFICATION,
-                _base.ModelObjective.MULTI_CLASSIFICATION,
+            if model_objective in [
+                model_meta_schema.ModelObjective.BINARY_CLASSIFICATION,
+                model_meta_schema.ModelObjective.MULTI_CLASSIFICATION,
             ]:
                 output_type = model_signature.DataType.STRING
             model_meta = handlers_utils.add_explain_method_signature(
@@ -157,11 +161,14 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
                 target_method="predict",
                 output_return_type=output_type,
             )
+            model_meta.function_properties = {
+                "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
+            }
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
 
-        model_save_path = os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR)
+        model_save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
         with open(model_save_path, "wb") as f:
             cloudpickle.dump(model, f)
 
@@ -169,7 +176,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
             options=model_meta_schema.LightGBMModelBlobOptions({"lightgbm_estimator_type": model.__class__.__name__}),
         )
         model_meta.models[name] = base_meta
@@ -182,11 +189,12 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
             ],
             check_local_version=True,
         )
-        if kwargs.get("enable_explainability", False):
+        if kwargs.get("enable_explainability", True):
             model_meta.env.include_if_absent(
                 [model_env.ModelDependency(requirement="shap", pip_name="shap")],
                 check_local_version=True,
             )
+            model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
 
         return None
 
@@ -226,6 +234,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
         cls,
         raw_model: Union["lightgbm.Booster", "lightgbm.XGBModel"],
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.LGBMModelLoadOptions],
     ) -> custom_model.CustomModel:
         import lightgbm
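
The net effect of the LightGBM changes: explainability now defaults to on, the inferred model objective is recorded in the model metadata, and a SHAP-backed `explain` method is registered at log time. A hedged usage sketch, not the authoritative API reference (`session`, `clf`, and `X_sample` are placeholders):

    # Sketch: logging a fitted LGBMClassifier now also produces an "explain" method.
    from snowflake.ml.registry import Registry

    registry = Registry(session=session)  # assumes an open Snowpark session
    mv = registry.log_model(
        clf,
        model_name="lgbm_demo",
        version_name="v1",
        sample_input_data=X_sample,
        # options={"enable_explainability": False},  # opt out of the new default
    )
    shap_df = mv.run(X_sample, function_name="explain")  # per-feature SHAP values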
snowflake/ml/model/_packager/model_handlers/llm.py:

@@ -28,7 +28,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
     _MIN_SNOWPARK_ML_VERSION = "1.0.12"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-    MODELE_BLOB_FILE_OR_DIR = "model"
+    MODEL_BLOB_FILE_OR_DIR = "model"
     LLM_META = "llm_meta"
     IS_AUTO_SIGNATURE = True
 
@@ -59,9 +59,12 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
         **kwargs: Unpack[model_types.LLMSaveOptions],
     ) -> None:
         assert not is_sub_model, "LLM can not be sub-model."
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for llm model.")
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        model_blob_dir_path = os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR)
+        model_blob_dir_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
 
         sig = model_signature.ModelSignature(
             inputs=[
@@ -86,7 +89,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
             options=model_meta_schema.LLMModelBlobOptions(
                 {
                     "batch_size": model.max_batch_size,
@@ -143,6 +146,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
         cls,
         raw_model: llm.LLM,
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.LLMLoadOptions],
     ) -> custom_model.CustomModel:
         import gc
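
The same save-time guard recurs nearly verbatim in the MLFlow, PyTorch, sentence-transformers, SnowML, TensorFlow, and TorchScript handlers below. Distilled into a hypothetical standalone helper (not code from the package):

    # Hypothetical distillation of the recurring guard: handlers without
    # SHAP support now fail fast when explainability is requested.
    from typing import Any, Dict

    def reject_explainability(kwargs: Dict[str, Any], model_kind: str) -> None:
        if kwargs.get("enable_explainability", False):
            raise NotImplementedError(f"Explainability is not supported for {model_kind} model.")

    reject_explainability({}, "llm")                                 # no-op
    # reject_explainability({"enable_explainability": True}, "llm")  # raises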
snowflake/ml/model/_packager/model_handlers/mlflow.py:

@@ -63,7 +63,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
     _MIN_SNOWPARK_ML_VERSION = "1.0.12"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-    MODELE_BLOB_FILE_OR_DIR = "model"
+    MODEL_BLOB_FILE_OR_DIR = "model"
     _DEFAULT_TARGET_METHOD = "predict"
     DEFAULT_TARGET_METHODS = [_DEFAULT_TARGET_METHOD]
     IS_AUTO_SIGNATURE = True
@@ -97,6 +97,10 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.MLFlowSaveOptions],
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for MLFlow model.")
+
         import mlflow
 
         assert isinstance(model, mlflow.pyfunc.PyFuncModel)
@@ -142,13 +146,13 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
         except (mlflow.MlflowException, OSError):
             raise ValueError("Cannot load MLFlow model artifacts.")
 
-        file_utils.copy_file_or_tree(local_path, os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR))
+        file_utils.copy_file_or_tree(local_path, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
 
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
             options=model_meta_schema.MLFlowModelBlobOptions({"artifact_path": model_info.artifact_path}),
         )
         model_meta.models[name] = base_meta
@@ -194,6 +198,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
         cls,
         raw_model: "mlflow.pyfunc.PyFuncModel",
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.MLFlowLoadOptions],
     ) -> custom_model.CustomModel:
         from snowflake.ml.model import custom_model
snowflake/ml/model/_packager/model_handlers/pytorch.py:

@@ -37,7 +37,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
     _MIN_SNOWPARK_ML_VERSION = "1.0.12"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-    MODELE_BLOB_FILE_OR_DIR = "model.pt"
+    MODEL_BLOB_FILE_OR_DIR = "model.pt"
     DEFAULT_TARGET_METHODS = ["forward"]
 
     @classmethod
@@ -73,6 +73,10 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.PyTorchSaveOptions],
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for PyTorch model.")
+
         import torch
 
         assert isinstance(model, torch.nn.Module)
@@ -115,13 +119,13 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
             cloudpickle.register_pickle_by_value(sys.modules[model.__module__])
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        with open(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR), "wb") as f:
+        with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
             torch.save(model, f, pickle_module=cloudpickle)
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -156,6 +160,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
         cls,
         raw_model: "torch.nn.Module",
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.PyTorchLoadOptions],
     ) -> custom_model.CustomModel:
         import torch
snowflake/ml/model/_packager/model_handlers/sentence_transformers.py:

@@ -31,7 +31,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
     _MIN_SNOWPARK_ML_VERSION = "1.3.1"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-    MODELE_BLOB_FILE_OR_DIR = "model"
+    MODEL_BLOB_FILE_OR_DIR = "model"
     DEFAULT_TARGET_METHODS = ["encode"]
 
     @classmethod
@@ -64,6 +64,10 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.SentenceTransformersSaveOptions],  # registry.log_model(options={...})
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for Sentence Transformer model.")
+
         # Validate target methods and signature (if possible)
         if not is_sub_model:
             target_methods = handlers_utils.get_target_methods(
@@ -101,14 +105,14 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
         # save model
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        model.save(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR))
+        model.save(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
 
         # save model metadata
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -154,6 +158,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
         cls,
         raw_model: "sentence_transformers.SentenceTransformer",
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.SentenceTransformersLoadOptions],
     ) -> custom_model.CustomModel:
         import sentence_transformers
snowflake/ml/model/_packager/model_handlers/sklearn.py:

@@ -6,6 +6,7 @@ import numpy as np
 import pandas as pd
 from typing_extensions import TypeGuard, Unpack
 
+import snowflake.snowpark.dataframe as sp_df
 from snowflake.ml._internal import type_utils
 from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
 from snowflake.ml.model._packager.model_env import model_env
@@ -14,8 +15,13 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
 from snowflake.ml.model._packager.model_meta import (
     model_blob_meta,
     model_meta as model_meta_api,
+    model_meta_schema,
+)
+from snowflake.ml.model._signatures import (
+    numpy_handler,
+    snowpark_handler,
+    utils as model_signature_utils,
 )
-from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
 
 if TYPE_CHECKING:
     import sklearn.base
@@ -36,6 +42,27 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
 
     DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"]
 
+    @classmethod
+    def get_model_objective(
+        cls, model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]
+    ) -> model_meta_schema.ModelObjective:
+        import sklearn.pipeline
+        from sklearn.base import is_classifier, is_regressor
+
+        if isinstance(model, sklearn.pipeline.Pipeline):
+            return model_meta_schema.ModelObjective.UNKNOWN
+        if is_regressor(model):
+            return model_meta_schema.ModelObjective.REGRESSION
+        if is_classifier(model):
+            classes_list = getattr(model, "classes_", [])
+            num_classes = getattr(model, "n_classes_", None) or len(classes_list)
+            if isinstance(num_classes, int):
+                if num_classes > 2:
+                    return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
+                return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
+            return model_meta_schema.ModelObjective.UNKNOWN
+        return model_meta_schema.ModelObjective.UNKNOWN
+
     @classmethod
     def can_handle(
         cls,
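
A quick check of the inference rule above, under the same assumptions the handler makes (scikit-learn installed; `n_classes_` preferred, `classes_` as the fallback):

    # Illustrative sketch, not package code: a three-class LogisticRegression
    # is reported as MULTI_CLASSIFICATION.
    import numpy as np
    from sklearn.base import is_classifier
    from sklearn.linear_model import LogisticRegression

    X = np.random.rand(30, 3)
    y = np.array([0, 1, 2] * 10)
    model = LogisticRegression().fit(X, y)

    assert is_classifier(model)
    # LogisticRegression exposes classes_ but not n_classes_, so the fallback counts classes_.
    num_classes = getattr(model, "n_classes_", None) or len(getattr(model, "classes_", []))
    print(num_classes)  # 3 -> ModelObjective.MULTI_CLASSIFICATION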
@@ -79,11 +106,33 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.SKLModelSaveOptions],
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+
         import sklearn.base
         import sklearn.pipeline
 
         assert isinstance(model, sklearn.base.BaseEstimator) or isinstance(model, sklearn.pipeline.Pipeline)
 
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            # TODO: Currently limited to pandas df, need to extend to other types.
+            if sample_input_data is None or not (
+                isinstance(sample_input_data, pd.DataFrame) or isinstance(sample_input_data, sp_df.DataFrame)
+            ):
+                raise ValueError(
+                    "Sample input data is required to enable explainability. Currently we only support this for "
+                    + "`pandas.DataFrame` and `snowflake.snowpark.dataframe.DataFrame`."
+                )
+            sample_input_data_pandas = (
+                sample_input_data
+                if isinstance(sample_input_data, pd.DataFrame)
+                else snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sample_input_data)
+            )
+            data_blob_path = os.path.join(model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR)
+            os.makedirs(data_blob_path, exist_ok=True)
+            with open(os.path.join(data_blob_path, name + cls.BG_DATA_FILE_SUFFIX), "wb") as f:
+                sample_input_data_pandas.to_parquet(f)
+
         if not is_sub_model:
             target_methods = handlers_utils.get_target_methods(
                 model=model,
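
For sklearn models explainability stays opt-in, and opting in now requires sample input data, which is persisted as parquet background data next to the model blob. A hedged sketch of the calling side (`registry`, `sklearn_model`, and `X_background` are placeholders):

    # Sketch: requesting explainability for a sklearn model. Omitting
    # sample_input_data now raises the ValueError shown above.
    mv = registry.log_model(
        sklearn_model,
        model_name="sklearn_demo",
        version_name="v1",
        sample_input_data=X_background,  # pandas or Snowpark DataFrame
        options={"enable_explainability": True},
    )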
@@ -110,19 +159,36 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
             get_prediction_fn=get_prediction,
         )
 
+        if enable_explainability:
+            output_type = model_signature.DataType.DOUBLE
+            if cls.get_model_objective(model) == model_meta_schema.ModelObjective.MULTI_CLASSIFICATION:
+                output_type = model_signature.DataType.STRING
+            model_meta = handlers_utils.add_explain_method_signature(
+                model_meta=model_meta,
+                explain_method="explain",
+                target_method="predict",
+                output_return_type=output_type,
+            )
+
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        with open(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR), "wb") as f:
+        with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
             cloudpickle.dump(model, f)
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
 
+        if enable_explainability:
+            model_meta.env.include_if_absent(
+                [model_env.ModelDependency(requirement="shap", pip_name="shap")],
+                check_local_version=True,
+            )
+
         model_meta.env.include_if_absent(
             [model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn")], check_local_version=True
         )
@@ -153,6 +219,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
         cls,
         raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.SKLModelLoadOptions],
     ) -> custom_model.CustomModel:
         from snowflake.ml.model import custom_model
@@ -165,6 +232,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
             raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
             signature: model_signature.ModelSignature,
             target_method: str,
+            background_data: Optional[pd.DataFrame],
         ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
             @custom_model.inference_api
             def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
@@ -179,11 +247,26 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
 
                 return model_signature_utils.rename_pandas_df(df, signature.outputs)
 
+            @custom_model.inference_api
+            def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+                import shap
+
+                # TODO: if not resolved by explainer, we need to pass the callable function
+                try:
+                    explainer = shap.Explainer(raw_model, background_data)
+                    df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
+                except TypeError as e:
+                    raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
+                return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
+            if target_method == "explain":
+                return explain_fn
+
             return fn
 
         type_method_dict = {}
         for target_method_name, sig in model_meta.signatures.items():
-            type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
+            type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name, background_data)
 
         _SKLModel = type(
             "_SKLModel",
snowflake/ml/model/_packager/model_handlers/snowmlmodel.py:

@@ -73,6 +73,10 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.SNOWModelSaveOptions],
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for Snowpark ML model.")
+
         from snowflake.ml.modeling.framework.base import BaseEstimator
 
         assert isinstance(model, BaseEstimator)
@@ -103,13 +107,13 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        with open(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR), "wb") as f:
+        with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
             cloudpickle.dump(model, f)
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -146,6 +150,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         cls,
         raw_model: "BaseEstimator",
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.SNOWModelLoadOptions],
     ) -> custom_model.CustomModel:
         from snowflake.ml.model import custom_model
snowflake/ml/model/_packager/model_handlers/tensorflow.py:

@@ -36,7 +36,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
     _MIN_SNOWPARK_ML_VERSION = "1.0.12"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-    MODELE_BLOB_FILE_OR_DIR = "model"
+    MODEL_BLOB_FILE_OR_DIR = "model"
     DEFAULT_TARGET_METHODS = ["__call__"]
 
     @classmethod
@@ -68,6 +68,10 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.TensorflowSaveOptions],
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for Tensorflow model.")
+
         import tensorflow
 
         assert isinstance(model, tensorflow.Module)
@@ -114,15 +118,15 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
         if isinstance(model, tensorflow.keras.Model):
-            tensorflow.keras.models.save_model(model, os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR))
+            tensorflow.keras.models.save_model(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
         else:
-            tensorflow.saved_model.save(model, os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR))
+            tensorflow.saved_model.save(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
 
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -156,6 +160,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         cls,
         raw_model: "tensorflow.Module",
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.TensorflowLoadOptions],
     ) -> custom_model.CustomModel:
         import tensorflow
snowflake/ml/model/_packager/model_handlers/torchscript.py:

@@ -34,7 +34,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):  # t
     _MIN_SNOWPARK_ML_VERSION = "1.0.12"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-    MODELE_BLOB_FILE_OR_DIR = "model.pt"
+    MODEL_BLOB_FILE_OR_DIR = "model.pt"
     DEFAULT_TARGET_METHODS = ["forward"]
 
     @classmethod
@@ -66,6 +66,10 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):  # t
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.TorchScriptSaveOptions],
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for Torch Script model.")
+
         import torch
 
         assert isinstance(model, torch.jit.ScriptModule)  # type:ignore[attr-defined]
@@ -106,13 +110,13 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):  # t
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        with open(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR), "wb") as f:
+        with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
             torch.jit.save(model, f)  # type:ignore[attr-defined]
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -152,6 +156,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):  # t
         cls,
         raw_model: "torch.jit.ScriptModule",  # type:ignore[name-defined]
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.TorchScriptLoadOptions],
     ) -> custom_model.CustomModel:
         from snowflake.ml.model import custom_model