snowflake-ml-python 1.5.4__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. snowflake/cortex/__init__.py +2 -0
  2. snowflake/cortex/_classify_text.py +36 -0
  3. snowflake/cortex/_complete.py +66 -35
  4. snowflake/cortex/_util.py +4 -4
  5. snowflake/ml/_internal/env_utils.py +11 -5
  6. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  7. snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
  8. snowflake/ml/_internal/telemetry.py +26 -2
  9. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  10. snowflake/ml/data/_internal/arrow_ingestor.py +284 -0
  11. snowflake/ml/data/data_connector.py +186 -0
  12. snowflake/ml/data/data_ingestor.py +45 -0
  13. snowflake/ml/data/data_source.py +23 -0
  14. snowflake/ml/data/ingestor_utils.py +62 -0
  15. snowflake/ml/data/torch_dataset.py +33 -0
  16. snowflake/ml/dataset/dataset.py +1 -13
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +23 -117
  19. snowflake/ml/feature_store/access_manager.py +7 -1
  20. snowflake/ml/feature_store/entity.py +19 -2
  21. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  22. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  23. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  24. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
  26. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +37 -0
  27. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +30 -0
  28. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +7 -0
  29. snowflake/ml/feature_store/examples/example_helper.py +278 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
  31. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +44 -0
  32. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  33. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +9 -0
  34. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  35. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
  36. snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
  37. snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
  38. snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
  39. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
  40. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +36 -0
  41. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +24 -0
  42. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +8 -0
  43. snowflake/ml/feature_store/feature_store.py +637 -76
  44. snowflake/ml/feature_store/feature_view.py +316 -9
  45. snowflake/ml/fileset/stage_fs.py +18 -10
  46. snowflake/ml/lineage/lineage_node.py +1 -1
  47. snowflake/ml/model/_client/model/model_impl.py +11 -2
  48. snowflake/ml/model/_client/model/model_version_impl.py +171 -20
  49. snowflake/ml/model/_client/ops/model_ops.py +105 -27
  50. snowflake/ml/model/_client/ops/service_ops.py +121 -0
  51. snowflake/ml/model/_client/service/model_deployment_spec.py +95 -0
  52. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +31 -0
  53. snowflake/ml/model/_client/sql/model_version.py +13 -4
  54. snowflake/ml/model/_client/sql/service.py +129 -0
  55. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
  56. snowflake/ml/model/_model_composer/model_composer.py +14 -14
  57. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +33 -17
  58. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -1
  59. snowflake/ml/model/_model_composer/model_method/function_generator.py +3 -3
  60. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
  61. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +3 -27
  62. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -32
  63. snowflake/ml/model/_model_composer/model_method/model_method.py +5 -2
  64. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  65. snowflake/ml/model/_packager/model_handlers/_base.py +30 -3
  66. snowflake/ml/model/_packager/model_handlers/_utils.py +58 -1
  67. snowflake/ml/model/_packager/model_handlers/catboost.py +52 -3
  68. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  69. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +9 -5
  70. snowflake/ml/model/_packager/model_handlers/lightgbm.py +80 -3
  71. snowflake/ml/model/_packager/model_handlers/llm.py +7 -3
  72. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  73. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  74. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  75. snowflake/ml/model/_packager/model_handlers/sklearn.py +87 -4
  76. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -2
  77. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  78. snowflake/ml/model/_packager/model_handlers/torchscript.py +8 -3
  79. snowflake/ml/model/_packager/model_handlers/xgboost.py +71 -3
  80. snowflake/ml/model/_packager/model_meta/model_meta.py +32 -2
  81. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +19 -0
  82. snowflake/ml/model/_packager/model_packager.py +2 -1
  83. snowflake/ml/model/_packager/model_runtime/model_runtime.py +7 -7
  84. snowflake/ml/model/model_signature.py +4 -4
  85. snowflake/ml/model/type_hints.py +2 -0
  86. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +1 -1
  87. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
  88. snowflake/ml/modeling/framework/base.py +28 -19
  89. snowflake/ml/modeling/impute/simple_imputer.py +26 -0
  90. snowflake/ml/modeling/pipeline/pipeline.py +7 -4
  91. snowflake/ml/registry/_manager/model_manager.py +16 -2
  92. snowflake/ml/registry/registry.py +100 -13
  93. snowflake/ml/utils/sql_client.py +22 -0
  94. snowflake/ml/version.py +1 -1
  95. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA +81 -2
  96. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/RECORD +99 -66
  97. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/WHEEL +1 -1
  98. snowflake/ml/_internal/lineage/data_source.py +0 -10
  99. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/LICENSE.txt +0 -0
  100. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/top_level.txt +0 -0
--- a/snowflake/ml/model/_packager/model_handlers/_base.py
+++ b/snowflake/ml/model/_packager/model_handlers/_base.py
@@ -1,6 +1,8 @@
+import os
 from abc import abstractmethod
 from typing import Dict, Generic, Optional, Protocol, Type, final

+import pandas as pd
 from typing_extensions import TypeGuard, Unpack

 from snowflake.ml.model import custom_model, type_hints as model_types
@@ -16,7 +18,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):

     @classmethod
     @abstractmethod
-    def can_handle(cls, model: model_types.SupportedDataType) -> TypeGuard[model_types._ModelType]:
+    def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard[model_types._ModelType]:
         """Whether this handler could support the type of the `model`.

         Args:
@@ -96,6 +98,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
         cls,
         raw_model: model_types._ModelType,
         model_meta: model_meta.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.BaseModelLoadOption],
     ) -> custom_model.CustomModel:
         """Create a custom model class wrap for unified interface when being deployed. The predict method will be
@@ -104,6 +107,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
         Args:
             raw_model: original model object,
             model_meta: The model metadata.
+            background_data: The background data used for the model explanations.
             kwargs: Options when converting the model.

         Raises:
@@ -121,7 +125,8 @@ class BaseModelHandler(Generic[model_types._ModelType], _BaseModelHandlerProtoco
    _MIN_SNOWPARK_ML_VERSION: The minimal version of Snowpark ML library to use the current handler.
    _HANDLER_MIGRATOR_PLANS: Dict holding handler migrator plans.

-    MODELE_BLOB_FILE_OR_DIR: Relative path of the model blob file in the model subdir. Default to "model.pkl".
+    MODEL_BLOB_FILE_OR_DIR: Relative path of the model blob file in the model subdir. Default to "model.pkl".
+    BG_DATA_FILE_SUFFIX: Suffix of the background data file. Default to "_background_data.pqt".
    MODEL_ARTIFACTS_DIR: Relative path of the model artifacts dir in the model subdir. Default to "artifacts"
    DEFAULT_TARGET_METHODS: Default target methods to be logged if not specified in this kind of model. Default to
        ["predict"]
@@ -129,8 +134,10 @@ class BaseModelHandler(Generic[model_types._ModelType], _BaseModelHandlerProtoco
        inputting sample data or model signature. Default to False.
    """

-    MODELE_BLOB_FILE_OR_DIR = "model.pkl"
+    MODEL_BLOB_FILE_OR_DIR = "model.pkl"
+    BG_DATA_FILE_SUFFIX = "_background_data.pqt"
     MODEL_ARTIFACTS_DIR = "artifacts"
+    EXPLAIN_ARTIFACTS_DIR = "explain_artifacts"
     DEFAULT_TARGET_METHODS = ["predict"]
     IS_AUTO_SIGNATURE = False

@@ -159,3 +166,23 @@ class BaseModelHandler(Generic[model_types._ModelType], _BaseModelHandlerProtoco
             model_meta=model_meta,
             model_blobs_dir_path=model_blobs_dir_path,
         )
+
+    @classmethod
+    @final
+    def load_background_data(cls, name: str, model_blobs_dir_path: str) -> Optional[pd.DataFrame]:
+        """Load the model into memory.
+
+        Args:
+            name: Name of the model.
+            model_blobs_dir_path: Directory path to the whole model.
+
+        Returns:
+            Optional[pd.DataFrame], background data as pandas DataFrame, if exists.
+        """
+        data_blob_path = os.path.join(model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, name + cls.BG_DATA_FILE_SUFFIX)
+        if not os.path.exists(model_blobs_dir_path) or not os.path.isfile(data_blob_path):
+            return None
+        with open(data_blob_path, "rb") as f:
+            background_data = pd.read_parquet(f)
+
+        return background_data
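
The new BG_DATA_FILE_OR_DIR constants and load_background_data imply a simple on-disk layout for explanation background data. A minimal sketch of that layout, assuming a local staging directory (the directory and model name below are placeholders, not part of the package):

    import os
    import pandas as pd

    model_blobs_dir = "/tmp/model_blobs"   # placeholder staging directory
    name = "my_model"                      # placeholder model name

    # A handler that supports explanations persists its sample data as Parquet under
    # <model_blobs_dir>/explain_artifacts/<name>_background_data.pqt ...
    explain_dir = os.path.join(model_blobs_dir, "explain_artifacts")
    os.makedirs(explain_dir, exist_ok=True)
    pd.DataFrame({"f1": [1.0, 2.0], "f2": [0.1, 0.2]}).to_parquet(
        os.path.join(explain_dir, name + "_background_data.pqt")
    )

    # ... which is exactly the path load_background_data reassembles at load time.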
--- a/snowflake/ml/model/_packager/model_handlers/_utils.py
+++ b/snowflake/ml/model/_packager/model_handlers/_utils.py
@@ -1,4 +1,9 @@
-from typing import Callable, Iterable, Optional, Sequence, cast
+import json
+from typing import Any, Callable, Iterable, Optional, Sequence, cast
+
+import numpy as np
+import numpy.typing as npt
+import pandas as pd

 from snowflake.ml.model import model_signature, type_hints as model_types
 from snowflake.ml.model._packager.model_meta import model_meta
@@ -40,6 +45,24 @@ def validate_signature(
     return model_meta


+def add_explain_method_signature(
+    model_meta: model_meta.ModelMetadata,
+    explain_method: str,
+    target_method: str,
+    output_return_type: model_signature.DataType = model_signature.DataType.DOUBLE,
+) -> model_meta.ModelMetadata:
+    if target_method not in model_meta.signatures:
+        raise ValueError(f"Signature for target method {target_method} is missing")
+    inputs = model_meta.signatures[target_method].inputs
+    model_meta.signatures[explain_method] = model_signature.ModelSignature(
+        inputs=inputs,
+        outputs=[
+            model_signature.FeatureSpec(dtype=output_return_type, name=f"{spec.name}_explanation") for spec in inputs
+        ],
+    )
+    return model_meta
+
+
 def get_target_methods(
     model: model_types.SupportedModelType,
     target_methods: Optional[Sequence[str]],
@@ -56,3 +79,37 @@ def validate_target_methods(model: model_types.SupportedModelType, target_method
     for method_name in target_methods:
         if not _is_callable(model, method_name):
             raise ValueError(f"Target method {method_name} is not callable or does not exist in the model.")
+
+
+def get_num_classes_if_exists(model: model_types.SupportedModelType) -> int:
+    num_classes = getattr(model, "classes_", [])
+    return len(num_classes)
+
+
+def convert_explanations_to_2D_df(
+    model: model_types.SupportedModelType, explanations: npt.NDArray[Any]
+) -> pd.DataFrame:
+    if explanations.ndim != 3:
+        return pd.DataFrame(explanations)
+
+    if hasattr(model, "classes_"):
+        classes_list = [cl for cl in model.classes_]  # type:ignore[union-attr]
+        len_classes = len(classes_list)
+        if explanations.shape[2] != len_classes:
+            raise ValueError(f"Model has {len_classes} classes but explanations have {explanations.shape[2]}")
+    else:
+        classes_list = [i for i in range(explanations.shape[2])]
+    exp_2d = []
+    # TODO (SNOW-1549044): Optimize this
+    for row in explanations:
+        col_list = []
+        for column in row:
+            class_explanations = {}
+            for cl, cl_exp in zip(classes_list, column):
+                if isinstance(cl, (int, np.integer)):
+                    cl = int(cl)
+                class_explanations[cl] = cl_exp
+            col_list.append(json.dumps(class_explanations))
+        exp_2d.append(col_list)
+
+    return pd.DataFrame(exp_2d)
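
For a multi-class model, convert_explanations_to_2D_df collapses the (rows x features x classes) SHAP array into a 2-D frame whose cells are JSON objects keyed by class. A standalone replica for illustration (this is a sketch of the behavior, not the library function itself):

    import json
    import numpy as np
    import pandas as pd

    def to_2d(explanations: np.ndarray, classes: list) -> pd.DataFrame:
        rows = []
        for row in explanations:              # one entry per input row
            cols = []
            for contribs in row:              # one entry per feature
                # mirror the library's int-key normalization before JSON-encoding
                cols.append(json.dumps({int(c) if isinstance(c, (int, np.integer)) else c: float(v)
                                        for c, v in zip(classes, contribs)}))
            rows.append(cols)
        return pd.DataFrame(rows)

    df = to_2d(np.arange(12, dtype=float).reshape(2, 2, 3), classes=[0, 1, 2])
    print(df.iloc[0, 0])  # {"0": 0.0, "1": 1.0, "2": 2.0}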
--- a/snowflake/ml/model/_packager/model_handlers/catboost.py
+++ b/snowflake/ml/model/_packager/model_handlers/catboost.py
@@ -30,9 +30,25 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
     _MIN_SNOWPARK_ML_VERSION = "1.3.1"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}

-    MODELE_BLOB_FILE_OR_DIR = "model.bin"
+    MODEL_BLOB_FILE_OR_DIR = "model.bin"
     DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]

+    @classmethod
+    def get_model_objective(cls, model: "catboost.CatBoost") -> model_meta_schema.ModelObjective:
+        import catboost
+
+        if isinstance(model, catboost.CatBoostClassifier):
+            num_classes = handlers_utils.get_num_classes_if_exists(model)
+            if num_classes == 2:
+                return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
+            return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
+        if isinstance(model, catboost.CatBoostRanker):
+            return model_meta_schema.ModelObjective.RANKING
+        if isinstance(model, catboost.CatBoostRegressor):
+            return model_meta_schema.ModelObjective.REGRESSION
+        # TODO: Find out model type from the generic Catboost Model
+        return model_meta_schema.ModelObjective.UNKNOWN
+
     @classmethod
     def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["catboost.CatBoost"]:
         return (type_utils.LazyType("catboost.CatBoost").isinstance(model)) and any(
@@ -89,10 +105,25 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
                 sample_input_data=sample_input_data,
                 get_prediction_fn=get_prediction,
             )
+            model_objective = cls.get_model_objective(model)
+            model_meta.model_objective = model_objective
+            if kwargs.get("enable_explainability", True):
+                output_type = model_signature.DataType.DOUBLE
+                if model_objective == model_meta_schema.ModelObjective.MULTI_CLASSIFICATION:
+                    output_type = model_signature.DataType.STRING
+                model_meta = handlers_utils.add_explain_method_signature(
+                    model_meta=model_meta,
+                    explain_method="explain",
+                    target_method="predict",
+                    output_return_type=output_type,
+                )
+                model_meta.function_properties = {
+                    "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
+                }

         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        model_save_path = os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR)
+        model_save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)

         model.save_model(model_save_path)

@@ -100,7 +131,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
             options=model_meta_schema.CatBoostModelBlobOptions({"catboost_estimator_type": model.__class__.__name__}),
         )
         model_meta.models[name] = base_meta
@@ -112,6 +143,12 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
             ],
             check_local_version=True,
         )
+        if kwargs.get("enable_explainability", True):
+            model_meta.env.include_if_absent(
+                [model_env.ModelDependency(requirement="shap", pip_name="shap")],
+                check_local_version=True,
+            )
+            model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
         model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)

         return None
@@ -157,6 +194,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
         cls,
         raw_model: "catboost.CatBoost",
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.CatBoostModelLoadOptions],
     ) -> custom_model.CustomModel:
         import catboost
@@ -186,6 +224,17 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):

                 return model_signature_utils.rename_pandas_df(df, signature.outputs)

+            @custom_model.inference_api
+            def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+                import shap
+
+                explainer = shap.TreeExplainer(raw_model)
+                df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
+                return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
+            if target_method == "explain":
+                return explain_fn
+
             return fn

         type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
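
With this handler change, a logged CatBoost model gains an explain target method backed by shap.TreeExplainer, and the handler defaults enable_explainability to on. A rough usage sketch against the registry API (the session setup, names, and data below are placeholders, not values from this diff):

    from snowflake.ml.registry import Registry

    # session = ...  # an existing snowflake.snowpark.Session
    reg = Registry(session=session)
    mv = reg.log_model(
        catboost_clf,                  # a fitted catboost.CatBoostClassifier
        model_name="MY_CATBOOST",
        version_name="v1",
        sample_input_data=X_sample,    # pandas DataFrame of representative rows
        options={"enable_explainability": True},
    )
    # The new "explain" function returns one <feature>_explanation column per input:
    shap_df = mv.run(X_sample, function_name="explain")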
--- a/snowflake/ml/model/_packager/model_handlers/custom.py
+++ b/snowflake/ml/model/_packager/model_handlers/custom.py
@@ -51,6 +51,9 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
         **kwargs: Unpack[model_types.CustomModelSaveOption],
     ) -> None:
         assert isinstance(model, custom_model.CustomModel)
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for custom model.")

         def get_prediction(
             target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
@@ -108,13 +111,13 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
         # Make sure that the module where the model is defined get pickled by value as well.
         cloudpickle.register_pickle_by_value(sys.modules[model.__module__])
         pickled_obj = (model.__class__, model.context)
-        with open(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR), "wb") as f:
+        with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
             cloudpickle.dump(pickled_obj, f)
         # model meta will be saved by the context manager
         model_meta.models[name] = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
             handler_version=cls.HANDLER_VERSION,
             function_properties=model_meta.function_properties,
             artifacts={
@@ -183,6 +186,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
         cls,
         raw_model: custom_model.CustomModel,
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.CustomModelLoadOption],
     ) -> custom_model.CustomModel:
         return raw_model
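
The same enable_explainability guard appears in the HuggingFace, LLM, MLFlow, PyTorch, and SentenceTransformer handlers below; only the tree-model handlers implement explanations in this release. A sketch of the failure mode, reusing the hypothetical reg and X_sample from the CatBoost example above:

    from snowflake.ml.model import custom_model

    class Echo(custom_model.CustomModel):
        @custom_model.inference_api
        def predict(self, X):
            return X

    try:
        reg.log_model(
            Echo(custom_model.ModelContext()),
            model_name="ECHO",
            version_name="v1",
            sample_input_data=X_sample,
            options={"enable_explainability": True},
        )
    except NotImplementedError as err:
        print(err)  # Explainability is not supported for custom model.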
--- a/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py
+++ b/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py
@@ -89,7 +89,7 @@ class HuggingFacePipelineHandler(
     _MIN_SNOWPARK_ML_VERSION = "1.0.12"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}

-    MODELE_BLOB_FILE_OR_DIR = "model"
+    MODEL_BLOB_FILE_OR_DIR = "model"
     ADDITIONAL_CONFIG_FILE = "pipeline_config.pt"
     DEFAULT_TARGET_METHODS = ["__call__"]
     IS_AUTO_SIGNATURE = True
@@ -133,6 +133,9 @@ class HuggingFacePipelineHandler(
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.HuggingFaceSaveOptions],
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for huggingface model.")
         if type_utils.LazyType("transformers.Pipeline").isinstance(model):
             task = model.task  # type:ignore[attr-defined]
             framework = model.framework  # type:ignore[attr-defined]
@@ -193,7 +196,7 @@ class HuggingFacePipelineHandler(

         if type_utils.LazyType("transformers.Pipeline").isinstance(model):
             model.save_pretrained(  # type:ignore[attr-defined]
-                os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR)
+                os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
             )
             pipeline_params = {
                 "_batch_size": model._batch_size,  # type:ignore[attr-defined]
@@ -205,7 +208,7 @@ class HuggingFacePipelineHandler(
             with open(
                 os.path.join(
                     model_blob_path,
-                    cls.MODELE_BLOB_FILE_OR_DIR,
+                    cls.MODEL_BLOB_FILE_OR_DIR,
                     cls.ADDITIONAL_CONFIG_FILE,
                 ),
                 "wb",
@@ -213,7 +216,7 @@ class HuggingFacePipelineHandler(
                 cloudpickle.dump(pipeline_params, f)
         else:
             with open(
-                os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR),
+                os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR),
                 "wb",
             ) as f:
                 cloudpickle.dump(model, f)
@@ -222,7 +225,7 @@ class HuggingFacePipelineHandler(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
             options=model_meta_schema.HuggingFacePipelineModelBlobOptions(
                 {
                     "task": task,
@@ -329,6 +332,7 @@ class HuggingFacePipelineHandler(
         cls,
         raw_model: Union[huggingface_pipeline.HuggingFacePipelineModel, "transformers.Pipeline"],
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
     ) -> custom_model.CustomModel:
         import transformers
--- a/snowflake/ml/model/_packager/model_handlers/lightgbm.py
+++ b/snowflake/ml/model/_packager/model_handlers/lightgbm.py
@@ -41,8 +41,49 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
     _MIN_SNOWPARK_ML_VERSION = "1.3.1"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}

-    MODELE_BLOB_FILE_OR_DIR = "model.pkl"
+    MODEL_BLOB_FILE_OR_DIR = "model.pkl"
     DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
+    _BINARY_CLASSIFICATION_OBJECTIVES = ["binary"]
+    _MULTI_CLASSIFICATION_OBJECTIVES = ["multiclass", "multiclassova"]
+    _RANKING_OBJECTIVES = ["lambdarank", "rank_xendcg"]
+    _REGRESSION_OBJECTIVES = [
+        "regression",
+        "regression_l1",
+        "huber",
+        "fair",
+        "poisson",
+        "quantile",
+        "tweedie",
+        "mape",
+        "gamma",
+    ]
+
+    @classmethod
+    def get_model_objective(
+        cls, model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]
+    ) -> model_meta_schema.ModelObjective:
+        import lightgbm
+
+        # does not account for cross-entropy and custom
+        if isinstance(model, lightgbm.LGBMClassifier):
+            num_classes = handlers_utils.get_num_classes_if_exists(model)
+            if num_classes == 2:
+                return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
+            return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
+        if isinstance(model, lightgbm.LGBMRanker):
+            return model_meta_schema.ModelObjective.RANKING
+        if isinstance(model, lightgbm.LGBMRegressor):
+            return model_meta_schema.ModelObjective.REGRESSION
+        model_objective = model.params["objective"]
+        if model_objective in cls._BINARY_CLASSIFICATION_OBJECTIVES:
+            return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
+        if model_objective in cls._MULTI_CLASSIFICATION_OBJECTIVES:
+            return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
+        if model_objective in cls._RANKING_OBJECTIVES:
+            return model_meta_schema.ModelObjective.RANKING
+        if model_objective in cls._REGRESSION_OBJECTIVES:
+            return model_meta_schema.ModelObjective.REGRESSION
+        return model_meta_schema.ModelObjective.UNKNOWN

     @classmethod
     def can_handle(
@@ -105,11 +146,29 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
                 sample_input_data=sample_input_data,
                 get_prediction_fn=get_prediction,
             )
+            model_objective = cls.get_model_objective(model)
+            model_meta.model_objective = model_objective
+            if kwargs.get("enable_explainability", True):
+                output_type = model_signature.DataType.DOUBLE
+                if model_objective in [
+                    model_meta_schema.ModelObjective.BINARY_CLASSIFICATION,
+                    model_meta_schema.ModelObjective.MULTI_CLASSIFICATION,
+                ]:
+                    output_type = model_signature.DataType.STRING
+                model_meta = handlers_utils.add_explain_method_signature(
+                    model_meta=model_meta,
+                    explain_method="explain",
+                    target_method="predict",
+                    output_return_type=output_type,
+                )
+                model_meta.function_properties = {
+                    "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
+                }

         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)

-        model_save_path = os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR)
+        model_save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
         with open(model_save_path, "wb") as f:
             cloudpickle.dump(model, f)

@@ -117,7 +176,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
             options=model_meta_schema.LightGBMModelBlobOptions({"lightgbm_estimator_type": model.__class__.__name__}),
         )
         model_meta.models[name] = base_meta
@@ -130,6 +189,12 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
             ],
             check_local_version=True,
         )
+        if kwargs.get("enable_explainability", True):
+            model_meta.env.include_if_absent(
+                [model_env.ModelDependency(requirement="shap", pip_name="shap")],
+                check_local_version=True,
+            )
+            model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP

         return None

@@ -169,6 +234,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
         cls,
         raw_model: Union["lightgbm.Booster", "lightgbm.XGBModel"],
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.LGBMModelLoadOptions],
     ) -> custom_model.CustomModel:
         import lightgbm
@@ -198,6 +264,17 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb

                 return model_signature_utils.rename_pandas_df(df, signature.outputs)

+            @custom_model.inference_api
+            def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+                import shap
+
+                explainer = shap.TreeExplainer(raw_model)
+                df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
+                return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
+            if target_method == "explain":
+                return explain_fn
+
             return fn

         type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
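
For a raw lightgbm.Booster (no sklearn wrapper), get_model_objective falls back to the objective string in the training parameters. A small runnable check of what that branch sees:

    import lightgbm as lgb
    import numpy as np

    X = np.random.rand(100, 4)
    y = np.random.randint(0, 2, 100)
    booster = lgb.train({"objective": "binary", "verbosity": -1},
                        lgb.Dataset(X, label=y), num_boost_round=5)

    # "binary" is in _BINARY_CLASSIFICATION_OBJECTIVES, so the handler would map
    # this Booster to ModelObjective.BINARY_CLASSIFICATION.
    print(booster.params["objective"])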
--- a/snowflake/ml/model/_packager/model_handlers/llm.py
+++ b/snowflake/ml/model/_packager/model_handlers/llm.py
@@ -28,7 +28,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
     _MIN_SNOWPARK_ML_VERSION = "1.0.12"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}

-    MODELE_BLOB_FILE_OR_DIR = "model"
+    MODEL_BLOB_FILE_OR_DIR = "model"
     LLM_META = "llm_meta"
     IS_AUTO_SIGNATURE = True

@@ -59,9 +59,12 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
         **kwargs: Unpack[model_types.LLMSaveOptions],
     ) -> None:
         assert not is_sub_model, "LLM can not be sub-model."
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for llm model.")
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        model_blob_dir_path = os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR)
+        model_blob_dir_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)

         sig = model_signature.ModelSignature(
             inputs=[
@@ -86,7 +89,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
             options=model_meta_schema.LLMModelBlobOptions(
                 {
                     "batch_size": model.max_batch_size,
@@ -143,6 +146,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
         cls,
         raw_model: llm.LLM,
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.LLMLoadOptions],
     ) -> custom_model.CustomModel:
         import gc
--- a/snowflake/ml/model/_packager/model_handlers/mlflow.py
+++ b/snowflake/ml/model/_packager/model_handlers/mlflow.py
@@ -63,7 +63,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
     _MIN_SNOWPARK_ML_VERSION = "1.0.12"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}

-    MODELE_BLOB_FILE_OR_DIR = "model"
+    MODEL_BLOB_FILE_OR_DIR = "model"
     _DEFAULT_TARGET_METHOD = "predict"
     DEFAULT_TARGET_METHODS = [_DEFAULT_TARGET_METHOD]
     IS_AUTO_SIGNATURE = True
@@ -97,6 +97,10 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.MLFlowSaveOptions],
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for MLFlow model.")
+
         import mlflow

         assert isinstance(model, mlflow.pyfunc.PyFuncModel)
@@ -142,13 +146,13 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
         except (mlflow.MlflowException, OSError):
             raise ValueError("Cannot load MLFlow model artifacts.")

-        file_utils.copy_file_or_tree(local_path, os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR))
+        file_utils.copy_file_or_tree(local_path, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))

         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
             options=model_meta_schema.MLFlowModelBlobOptions({"artifact_path": model_info.artifact_path}),
         )
         model_meta.models[name] = base_meta
@@ -194,6 +198,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
         cls,
         raw_model: "mlflow.pyfunc.PyFuncModel",
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.MLFlowLoadOptions],
     ) -> custom_model.CustomModel:
         from snowflake.ml.model import custom_model
--- a/snowflake/ml/model/_packager/model_handlers/pytorch.py
+++ b/snowflake/ml/model/_packager/model_handlers/pytorch.py
@@ -37,7 +37,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
     _MIN_SNOWPARK_ML_VERSION = "1.0.12"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}

-    MODELE_BLOB_FILE_OR_DIR = "model.pt"
+    MODEL_BLOB_FILE_OR_DIR = "model.pt"
     DEFAULT_TARGET_METHODS = ["forward"]

     @classmethod
@@ -73,6 +73,10 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.PyTorchSaveOptions],
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for PyTorch model.")
+
         import torch

         assert isinstance(model, torch.nn.Module)
@@ -115,13 +119,13 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
         cloudpickle.register_pickle_by_value(sys.modules[model.__module__])
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        with open(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR), "wb") as f:
+        with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
             torch.save(model, f, pickle_module=cloudpickle)
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -156,6 +160,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
         cls,
         raw_model: "torch.nn.Module",
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.PyTorchLoadOptions],
     ) -> custom_model.CustomModel:
         import torch
--- a/snowflake/ml/model/_packager/model_handlers/sentence_transformers.py
+++ b/snowflake/ml/model/_packager/model_handlers/sentence_transformers.py
@@ -31,7 +31,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
     _MIN_SNOWPARK_ML_VERSION = "1.3.1"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}

-    MODELE_BLOB_FILE_OR_DIR = "model"
+    MODEL_BLOB_FILE_OR_DIR = "model"
     DEFAULT_TARGET_METHODS = ["encode"]

     @classmethod
@@ -64,6 +64,10 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.SentenceTransformersSaveOptions],  # registry.log_model(options={...})
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for Sentence Transformer model.")
+
         # Validate target methods and signature (if possible)
         if not is_sub_model:
             target_methods = handlers_utils.get_target_methods(
@@ -101,14 +105,14 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
         # save model
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        model.save(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR))
+        model.save(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))

         # save model metadata
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -154,6 +158,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
         cls,
         raw_model: "sentence_transformers.SentenceTransformer",
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.SentenceTransformersLoadOptions],
     ) -> custom_model.CustomModel:
         import sentence_transformers