snowflake-ml-python 1.5.4__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (100)
  1. snowflake/cortex/__init__.py +2 -0
  2. snowflake/cortex/_classify_text.py +36 -0
  3. snowflake/cortex/_complete.py +66 -35
  4. snowflake/cortex/_util.py +4 -4
  5. snowflake/ml/_internal/env_utils.py +11 -5
  6. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  7. snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
  8. snowflake/ml/_internal/telemetry.py +26 -2
  9. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  10. snowflake/ml/data/_internal/arrow_ingestor.py +284 -0
  11. snowflake/ml/data/data_connector.py +186 -0
  12. snowflake/ml/data/data_ingestor.py +45 -0
  13. snowflake/ml/data/data_source.py +23 -0
  14. snowflake/ml/data/ingestor_utils.py +62 -0
  15. snowflake/ml/data/torch_dataset.py +33 -0
  16. snowflake/ml/dataset/dataset.py +1 -13
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +23 -117
  19. snowflake/ml/feature_store/access_manager.py +7 -1
  20. snowflake/ml/feature_store/entity.py +19 -2
  21. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  22. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  23. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  24. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
  26. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +37 -0
  27. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +30 -0
  28. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +7 -0
  29. snowflake/ml/feature_store/examples/example_helper.py +278 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
  31. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +44 -0
  32. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  33. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +9 -0
  34. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  35. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
  36. snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
  37. snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
  38. snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
  39. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
  40. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +36 -0
  41. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +24 -0
  42. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +8 -0
  43. snowflake/ml/feature_store/feature_store.py +637 -76
  44. snowflake/ml/feature_store/feature_view.py +316 -9
  45. snowflake/ml/fileset/stage_fs.py +18 -10
  46. snowflake/ml/lineage/lineage_node.py +1 -1
  47. snowflake/ml/model/_client/model/model_impl.py +11 -2
  48. snowflake/ml/model/_client/model/model_version_impl.py +171 -20
  49. snowflake/ml/model/_client/ops/model_ops.py +105 -27
  50. snowflake/ml/model/_client/ops/service_ops.py +121 -0
  51. snowflake/ml/model/_client/service/model_deployment_spec.py +95 -0
  52. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +31 -0
  53. snowflake/ml/model/_client/sql/model_version.py +13 -4
  54. snowflake/ml/model/_client/sql/service.py +129 -0
  55. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
  56. snowflake/ml/model/_model_composer/model_composer.py +14 -14
  57. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +33 -17
  58. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -1
  59. snowflake/ml/model/_model_composer/model_method/function_generator.py +3 -3
  60. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
  61. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +3 -27
  62. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -32
  63. snowflake/ml/model/_model_composer/model_method/model_method.py +5 -2
  64. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  65. snowflake/ml/model/_packager/model_handlers/_base.py +30 -3
  66. snowflake/ml/model/_packager/model_handlers/_utils.py +58 -1
  67. snowflake/ml/model/_packager/model_handlers/catboost.py +52 -3
  68. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  69. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +9 -5
  70. snowflake/ml/model/_packager/model_handlers/lightgbm.py +80 -3
  71. snowflake/ml/model/_packager/model_handlers/llm.py +7 -3
  72. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  73. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  74. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  75. snowflake/ml/model/_packager/model_handlers/sklearn.py +87 -4
  76. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -2
  77. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  78. snowflake/ml/model/_packager/model_handlers/torchscript.py +8 -3
  79. snowflake/ml/model/_packager/model_handlers/xgboost.py +71 -3
  80. snowflake/ml/model/_packager/model_meta/model_meta.py +32 -2
  81. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +19 -0
  82. snowflake/ml/model/_packager/model_packager.py +2 -1
  83. snowflake/ml/model/_packager/model_runtime/model_runtime.py +7 -7
  84. snowflake/ml/model/model_signature.py +4 -4
  85. snowflake/ml/model/type_hints.py +2 -0
  86. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +1 -1
  87. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
  88. snowflake/ml/modeling/framework/base.py +28 -19
  89. snowflake/ml/modeling/impute/simple_imputer.py +26 -0
  90. snowflake/ml/modeling/pipeline/pipeline.py +7 -4
  91. snowflake/ml/registry/_manager/model_manager.py +16 -2
  92. snowflake/ml/registry/registry.py +100 -13
  93. snowflake/ml/utils/sql_client.py +22 -0
  94. snowflake/ml/version.py +1 -1
  95. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA +81 -2
  96. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/RECORD +99 -66
  97. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/WHEEL +1 -1
  98. snowflake/ml/_internal/lineage/data_source.py +0 -10
  99. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/LICENSE.txt +0 -0
  100. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/top_level.txt +0 -0
--- a/snowflake/ml/model/_packager/model_handlers/sklearn.py
+++ b/snowflake/ml/model/_packager/model_handlers/sklearn.py
@@ -6,6 +6,7 @@ import numpy as np
 import pandas as pd
 from typing_extensions import TypeGuard, Unpack
 
+import snowflake.snowpark.dataframe as sp_df
 from snowflake.ml._internal import type_utils
 from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
 from snowflake.ml.model._packager.model_env import model_env
@@ -14,8 +15,13 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
 from snowflake.ml.model._packager.model_meta import (
     model_blob_meta,
     model_meta as model_meta_api,
+    model_meta_schema,
+)
+from snowflake.ml.model._signatures import (
+    numpy_handler,
+    snowpark_handler,
+    utils as model_signature_utils,
 )
-from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
 
 if TYPE_CHECKING:
     import sklearn.base
@@ -36,6 +42,27 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
 
     DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"]
 
+    @classmethod
+    def get_model_objective(
+        cls, model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]
+    ) -> model_meta_schema.ModelObjective:
+        import sklearn.pipeline
+        from sklearn.base import is_classifier, is_regressor
+
+        if isinstance(model, sklearn.pipeline.Pipeline):
+            return model_meta_schema.ModelObjective.UNKNOWN
+        if is_regressor(model):
+            return model_meta_schema.ModelObjective.REGRESSION
+        if is_classifier(model):
+            classes_list = getattr(model, "classes_", [])
+            num_classes = getattr(model, "n_classes_", None) or len(classes_list)
+            if isinstance(num_classes, int):
+                if num_classes > 2:
+                    return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
+                return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
+            return model_meta_schema.ModelObjective.UNKNOWN
+        return model_meta_schema.ModelObjective.UNKNOWN
+
     @classmethod
     def can_handle(
         cls,
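The new objective heuristic relies only on public scikit-learn introspection (`is_regressor`, `is_classifier`, fitted attributes), so it can be sanity-checked outside the handler. A minimal sketch, assuming scikit-learn is installed locally; the estimator and dataset are illustrative:

    # Mirror of the objective heuristic above, run against an example classifier.
    from sklearn.base import is_classifier, is_regressor
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    X, y = make_classification(n_classes=3, n_informative=4, random_state=0)
    model = LogisticRegression(max_iter=500).fit(X, y)

    assert is_classifier(model) and not is_regressor(model)
    # classes_ is set by fit(); n_classes_ only exists on some estimator types.
    num_classes = getattr(model, "n_classes_", None) or len(getattr(model, "classes_", []))
    print("multi" if num_classes > 2 else "binary")  # prints "multi"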
@@ -79,11 +106,33 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.SKLModelSaveOptions],
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+
         import sklearn.base
         import sklearn.pipeline
 
         assert isinstance(model, sklearn.base.BaseEstimator) or isinstance(model, sklearn.pipeline.Pipeline)
 
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            # TODO: Currently limited to pandas df, need to extend to other types.
+            if sample_input_data is None or not (
+                isinstance(sample_input_data, pd.DataFrame) or isinstance(sample_input_data, sp_df.DataFrame)
+            ):
+                raise ValueError(
+                    "Sample input data is required to enable explainability. Currently we only support this for "
+                    + "`pandas.DataFrame` and `snowflake.snowpark.dataframe.DataFrame`."
+                )
+            sample_input_data_pandas = (
+                sample_input_data
+                if isinstance(sample_input_data, pd.DataFrame)
+                else snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sample_input_data)
+            )
+            data_blob_path = os.path.join(model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR)
+            os.makedirs(data_blob_path, exist_ok=True)
+            with open(os.path.join(data_blob_path, name + cls.BG_DATA_FILE_SUFFIX), "wb") as f:
+                sample_input_data_pandas.to_parquet(f)
+
         if not is_sub_model:
             target_methods = handlers_utils.get_target_methods(
                 model=model,
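From the caller's side this path is reached through the save options. A hedged sketch of logging a scikit-learn model with explainability enabled, assuming the public `Registry` API of this release; `session`, the estimator, and the sample frame are placeholders:

    # Sketch: the `enable_explainability` option read by the handler above is
    # passed through `options`; sample_input_data becomes the background data.
    from snowflake.ml.registry import Registry

    reg = Registry(session=session)  # an existing snowflake.snowpark.Session
    mv = reg.log_model(
        model,                                  # fitted sklearn estimator or Pipeline
        model_name="MY_SKLEARN_MODEL",
        version_name="V1",
        sample_input_data=train_df,             # pandas or Snowpark DataFrame, required here
        options={"enable_explainability": True},
    )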
@@ -110,19 +159,36 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
             get_prediction_fn=get_prediction,
         )
 
+        if enable_explainability:
+            output_type = model_signature.DataType.DOUBLE
+            if cls.get_model_objective(model) == model_meta_schema.ModelObjective.MULTI_CLASSIFICATION:
+                output_type = model_signature.DataType.STRING
+            model_meta = handlers_utils.add_explain_method_signature(
+                model_meta=model_meta,
+                explain_method="explain",
+                target_method="predict",
+                output_return_type=output_type,
+            )
+
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        with open(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR), "wb") as f:
+        with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
             cloudpickle.dump(model, f)
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
 
+        if enable_explainability:
+            model_meta.env.include_if_absent(
+                [model_env.ModelDependency(requirement="shap", pip_name="shap")],
+                check_local_version=True,
+            )
+
         model_meta.env.include_if_absent(
             [model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn")], check_local_version=True
         )
@@ -153,6 +219,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
         cls,
         raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.SKLModelLoadOptions],
     ) -> custom_model.CustomModel:
         from snowflake.ml.model import custom_model
@@ -165,6 +232,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
             raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
             signature: model_signature.ModelSignature,
             target_method: str,
+            background_data: Optional[pd.DataFrame],
         ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
             @custom_model.inference_api
             def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
@@ -179,11 +247,26 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
 
                 return model_signature_utils.rename_pandas_df(df, signature.outputs)
 
+            @custom_model.inference_api
+            def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+                import shap
+
+                # TODO: if not resolved by explainer, we need to pass the callable function
+                try:
+                    explainer = shap.Explainer(raw_model, background_data)
+                    df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
+                except TypeError as e:
+                    raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
+                return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
+            if target_method == "explain":
+                return explain_fn
+
             return fn
 
         type_method_dict = {}
         for target_method_name, sig in model_meta.signatures.items():
-            type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
+            type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name, background_data)
 
         _SKLModel = type(
             "_SKLModel",
--- a/snowflake/ml/model/_packager/model_handlers/snowmlmodel.py
+++ b/snowflake/ml/model/_packager/model_handlers/snowmlmodel.py
@@ -73,6 +73,10 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.SNOWModelSaveOptions],
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for Snowpark ML model.")
+
         from snowflake.ml.modeling.framework.base import BaseEstimator
 
         assert isinstance(model, BaseEstimator)
@@ -103,13 +107,13 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        with open(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR), "wb") as f:
+        with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
             cloudpickle.dump(model, f)
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -146,6 +150,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         cls,
         raw_model: "BaseEstimator",
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.SNOWModelLoadOptions],
     ) -> custom_model.CustomModel:
         from snowflake.ml.model import custom_model
--- a/snowflake/ml/model/_packager/model_handlers/tensorflow.py
+++ b/snowflake/ml/model/_packager/model_handlers/tensorflow.py
@@ -36,7 +36,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
     _MIN_SNOWPARK_ML_VERSION = "1.0.12"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-    MODELE_BLOB_FILE_OR_DIR = "model"
+    MODEL_BLOB_FILE_OR_DIR = "model"
     DEFAULT_TARGET_METHODS = ["__call__"]
 
     @classmethod
@@ -68,6 +68,10 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.TensorflowSaveOptions],
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for Tensorflow model.")
+
         import tensorflow
 
         assert isinstance(model, tensorflow.Module)
@@ -114,15 +118,15 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
         if isinstance(model, tensorflow.keras.Model):
-            tensorflow.keras.models.save_model(model, os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR))
+            tensorflow.keras.models.save_model(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
         else:
-            tensorflow.saved_model.save(model, os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR))
+            tensorflow.saved_model.save(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
 
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -156,6 +160,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         cls,
         raw_model: "tensorflow.Module",
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.TensorflowLoadOptions],
     ) -> custom_model.CustomModel:
         import tensorflow
--- a/snowflake/ml/model/_packager/model_handlers/torchscript.py
+++ b/snowflake/ml/model/_packager/model_handlers/torchscript.py
@@ -34,7 +34,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):  # t
     _MIN_SNOWPARK_ML_VERSION = "1.0.12"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-    MODELE_BLOB_FILE_OR_DIR = "model.pt"
+    MODEL_BLOB_FILE_OR_DIR = "model.pt"
     DEFAULT_TARGET_METHODS = ["forward"]
 
     @classmethod
@@ -66,6 +66,10 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):  # t
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.TorchScriptSaveOptions],
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for Torch Script model.")
+
         import torch
 
         assert isinstance(model, torch.jit.ScriptModule)  # type:ignore[attr-defined]
@@ -106,13 +110,13 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):  # t
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        with open(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR), "wb") as f:
+        with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
             torch.jit.save(model, f)  # type:ignore[attr-defined]
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -152,6 +156,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):  # t
         cls,
         raw_model: "torch.jit.ScriptModule",  # type:ignore[name-defined]
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.TorchScriptLoadOptions],
     ) -> custom_model.CustomModel:
         from snowflake.ml.model import custom_model
--- a/snowflake/ml/model/_packager/model_handlers/xgboost.py
+++ b/snowflake/ml/model/_packager/model_handlers/xgboost.py
@@ -1,4 +1,5 @@
 # mypy: disable-error-code="import"
+import json
 import os
 from typing import (
     TYPE_CHECKING,
@@ -44,8 +45,43 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
     _MIN_SNOWPARK_ML_VERSION = "1.0.12"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-    MODELE_BLOB_FILE_OR_DIR = "model.ubj"
+    MODEL_BLOB_FILE_OR_DIR = "model.ubj"
     DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
+    _BINARY_CLASSIFICATION_OBJECTIVE_PREFIX = ["binary:"]
+    _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX = ["multi:"]
+    _RANKING_OBJECTIVE_PREFIX = ["rank:"]
+    _REGRESSION_OBJECTIVE_PREFIX = ["reg:"]
+
+    @classmethod
+    def get_model_objective(
+        cls, model: Union["xgboost.Booster", "xgboost.XGBModel"]
+    ) -> model_meta_schema.ModelObjective:
+        import xgboost
+
+        if isinstance(model, xgboost.XGBClassifier) or isinstance(model, xgboost.XGBRFClassifier):
+            num_classes = handlers_utils.get_num_classes_if_exists(model)
+            if num_classes == 2:
+                return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
+            return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
+        if isinstance(model, xgboost.XGBRegressor) or isinstance(model, xgboost.XGBRFRegressor):
+            return model_meta_schema.ModelObjective.REGRESSION
+        if isinstance(model, xgboost.XGBRanker):
+            return model_meta_schema.ModelObjective.RANKING
+        model_params = json.loads(model.save_config())
+        model_objective = model_params["learner"]["objective"]
+        for classification_objective in cls._BINARY_CLASSIFICATION_OBJECTIVE_PREFIX:
+            if classification_objective in model_objective:
+                return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
+        for classification_objective in cls._MULTI_CLASSIFICATION_OBJECTIVE_PREFIX:
+            if classification_objective in model_objective:
+                return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
+        for ranking_objective in cls._RANKING_OBJECTIVE_PREFIX:
+            if ranking_objective in model_objective:
+                return model_meta_schema.ModelObjective.RANKING
+        for regression_objective in cls._REGRESSION_OBJECTIVE_PREFIX:
+            if regression_objective in model_objective:
+                return model_meta_schema.ModelObjective.REGRESSION
+        return model_meta_schema.ModelObjective.UNKNOWN
 
     @classmethod
     def can_handle(
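For a raw `Booster`, the class checks above cannot apply, so the handler falls back to the objective string recorded in the model's own saved config. A small sketch on synthetic data; only the "binary:" prefix check is the point here:

    # The Booster fallback: the saved config JSON carries the training
    # objective (e.g. "binary:logistic"), which the prefix tables above
    # map to a ModelObjective value.
    import numpy as np
    import xgboost

    X = np.random.rand(20, 3)
    y = np.random.randint(0, 2, size=20)
    booster = xgboost.train(
        {"objective": "binary:logistic"}, xgboost.DMatrix(X, label=y), num_boost_round=2
    )

    config = booster.save_config()  # JSON string
    print("binary:" in config)      # True -> BINARY_CLASSIFICATION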
@@ -112,15 +148,30 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
             sample_input_data=sample_input_data,
             get_prediction_fn=get_prediction,
         )
+        model_objective = cls.get_model_objective(model)
+        model_meta.model_objective = model_objective
+        if kwargs.get("enable_explainability", True):
+            output_type = model_signature.DataType.DOUBLE
+            if model_objective == model_meta_schema.ModelObjective.MULTI_CLASSIFICATION:
+                output_type = model_signature.DataType.STRING
+            model_meta = handlers_utils.add_explain_method_signature(
+                model_meta=model_meta,
+                explain_method="explain",
+                target_method="predict",
+                output_return_type=output_type,
+            )
+            model_meta.function_properties = {
+                "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
+            }
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        model.save_model(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR))
+        model.save_model(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.MODELE_BLOB_FILE_OR_DIR,
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
             options=model_meta_schema.XgboostModelBlobOptions({"xgb_estimator_type": model.__class__.__name__}),
         )
         model_meta.models[name] = base_meta
@@ -133,6 +184,12 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
             ],
             check_local_version=True,
         )
+        if kwargs.get("enable_explainability", True):
+            model_meta.env.include_if_absent(
+                [model_env.ModelDependency(requirement="shap", pip_name="shap")],
+                check_local_version=True,
+            )
+            model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
        model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
 
     @classmethod
@@ -175,6 +232,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
         cls,
         raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.XGBModelLoadOptions],
     ) -> custom_model.CustomModel:
         import xgboost
@@ -206,6 +264,16 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
 
                 return model_signature_utils.rename_pandas_df(df, signature.outputs)
 
+            @custom_model.inference_api
+            def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+                import shap
+
+                explainer = shap.TreeExplainer(raw_model)
+                df = pd.DataFrame(explainer(X).values)
+                return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
+            if target_method == "explain":
+                return explain_fn
             return fn
 
         type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
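Unlike the scikit-learn handler, the XGBoost path uses `shap.TreeExplainer`, which derives attributions from the tree structure itself and therefore needs no background frame. A minimal sketch on synthetic data:

    # TreeExplainer works directly off the boosted trees; no background data.
    import numpy as np
    import pandas as pd
    import shap
    import xgboost

    X = pd.DataFrame(np.random.rand(50, 4), columns=["f0", "f1", "f2", "f3"])
    y = np.random.randint(0, 2, size=50)
    model = xgboost.XGBClassifier(n_estimators=5).fit(X, y)

    values = shap.TreeExplainer(model)(X).values  # (n_rows, n_features) for binary
    print(pd.DataFrame(values, columns=X.columns).head())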
--- a/snowflake/ml/model/_packager/model_meta/model_meta.py
+++ b/snowflake/ml/model/_packager/model_meta/model_meta.py
@@ -237,6 +237,7 @@ class ModelMetadata:
         function_properties: A dict mapping function names to dict mapping function property key to value.
         metadata: User provided key-value metadata of the model. Defaults to None.
         creation_timestamp: Unix timestamp when the model metadata is created.
+        model_objective: Model objective like regression, classification etc.
     """
 
     def telemetry_metadata(self) -> ModelMetadataTelemetryDict:
@@ -260,6 +261,8 @@ class ModelMetadata:
         min_snowpark_ml_version: Optional[str] = None,
         models: Optional[Dict[str, model_blob_meta.ModelBlobMeta]] = None,
         original_metadata_version: Optional[str] = model_meta_schema.MODEL_METADATA_VERSION,
+        model_objective: Optional[model_meta_schema.ModelObjective] = model_meta_schema.ModelObjective.UNKNOWN,
+        explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = None,
     ) -> None:
         self.name = name
         self.signatures: Dict[str, model_signature.ModelSignature] = dict()
@@ -284,6 +287,11 @@ class ModelMetadata:
 
         self.original_metadata_version = original_metadata_version
 
+        self.model_objective: model_meta_schema.ModelObjective = (
+            model_objective or model_meta_schema.ModelObjective.UNKNOWN
+        )
+        self.explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = explain_algorithm
+
     @property
     def min_snowpark_ml_version(self) -> str:
         return self._min_snowpark_ml_version.base_version
@@ -321,9 +329,11 @@ class ModelMetadata:
         model_dict = model_meta_schema.ModelMetadataDict(
             {
                 "creation_timestamp": self.creation_timestamp,
-                "env": self.env.save_as_dict(pathlib.Path(model_dir_path)),
+                "env": self.env.save_as_dict(
+                    pathlib.Path(model_dir_path), default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
+                ),
                 "runtimes": {
-                    runtime_name: runtime.save(pathlib.Path(model_dir_path))
+                    runtime_name: runtime.save(pathlib.Path(model_dir_path), default_channel_override="conda-forge")
                     for runtime_name, runtime in self.runtimes.items()
                 },
                 "metadata": self.metadata,
@@ -333,6 +343,13 @@ class ModelMetadata:
                 "signatures": {func_name: sig.to_dict() for func_name, sig in self.signatures.items()},
                 "version": model_meta_schema.MODEL_METADATA_VERSION,
                 "min_snowpark_ml_version": self.min_snowpark_ml_version,
+                "model_objective": self.model_objective.value,
+                "explainability": (
+                    model_meta_schema.ExplainabilityMetadataDict(algorithm=self.explain_algorithm.value)
+                    if self.explain_algorithm
+                    else None
+                ),
+                "function_properties": self.function_properties,
             }
         )
 
@@ -370,6 +387,9 @@ class ModelMetadata:
             signatures=loaded_meta["signatures"],
             version=original_loaded_meta_version,
             min_snowpark_ml_version=loaded_meta_min_snowpark_ml_version,
+            model_objective=loaded_meta.get("model_objective", model_meta_schema.ModelObjective.UNKNOWN.value),
+            explainability=loaded_meta.get("explainability", None),
+            function_properties=loaded_meta.get("function_properties", {}),
         )
 
     @classmethod
@@ -406,6 +426,11 @@ class ModelMetadata:
         else:
             runtimes = None
 
+        explanation_algorithm_dict = model_dict.get("explainability", None)
+        explanation_algorithm = None
+        if explanation_algorithm_dict:
+            explanation_algorithm = model_meta_schema.ModelExplainAlgorithm(explanation_algorithm_dict["algorithm"])
+
         return cls(
             name=model_dict["name"],
             model_type=model_dict["model_type"],
@@ -417,4 +442,9 @@ class ModelMetadata:
             min_snowpark_ml_version=model_dict["min_snowpark_ml_version"],
             models=models,
             original_metadata_version=model_dict["version"],
+            model_objective=model_meta_schema.ModelObjective(
+                model_dict.get("model_objective", model_meta_schema.ModelObjective.UNKNOWN.value)
+            ),
+            explain_algorithm=explanation_algorithm,
+            function_properties=model_dict.get("function_properties", {}),
         )
--- a/snowflake/ml/model/_packager/model_meta/model_meta_schema.py
+++ b/snowflake/ml/model/_packager/model_meta/model_meta_schema.py
@@ -71,6 +71,10 @@ ModelBlobOptions = Union[
 ]
 
 
+class ExplainabilityMetadataDict(TypedDict):
+    algorithm: Required[str]
+
+
 class ModelBlobMetadataDict(TypedDict):
     name: Required[str]
     model_type: Required[type_hints.SupportedModelHandlerType]
@@ -92,3 +96,18 @@ class ModelMetadataDict(TypedDict):
     signatures: Required[Dict[str, Dict[str, Any]]]
     version: Required[str]
     min_snowpark_ml_version: Required[str]
+    model_objective: Required[str]
+    explainability: NotRequired[Optional[ExplainabilityMetadataDict]]
+    function_properties: NotRequired[Dict[str, Dict[str, Any]]]
+
+
+class ModelObjective(Enum):
+    UNKNOWN = "unknown"
+    BINARY_CLASSIFICATION = "binary_classification"
+    MULTI_CLASSIFICATION = "multi_classification"
+    REGRESSION = "regression"
+    RANKING = "ranking"
+
+
+class ModelExplainAlgorithm(Enum):
+    SHAP = "shap"
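Taken together, the new schema fields land in the serialized model metadata roughly as follows. This is an illustrative fragment, not output captured from the package:

    # Hypothetical excerpt of a saved XGBoost classifier's metadata dict under
    # the extended ModelMetadataDict schema; keys and values are illustrative.
    meta_fragment = {
        "model_objective": "binary_classification",   # a ModelObjective value
        "explainability": {"algorithm": "shap"},      # ExplainabilityMetadataDict
        "function_properties": {"explain": {"PARTITIONED": False}},
    }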
--- a/snowflake/ml/model/_packager/model_packager.py
+++ b/snowflake/ml/model/_packager/model_packager.py
@@ -146,7 +146,8 @@ class ModelPackager:
         m = handler.load_model(self.meta.name, self.meta, model_blobs_path, **options)
 
         if as_custom_model:
-            m = handler.convert_as_custom_model(m, self.meta, **options)
+            background_data = handler.load_background_data(self.meta.name, model_blobs_path)
+            m = handler.convert_as_custom_model(m, self.meta, background_data, **options)
             assert isinstance(m, custom_model.CustomModel)
 
         self.model = m
--- a/snowflake/ml/model/_packager/model_runtime/model_runtime.py
+++ b/snowflake/ml/model/_packager/model_runtime/model_runtime.py
@@ -35,7 +35,7 @@ class ModelRuntime:
         self,
         name: str,
         env: model_env.ModelEnv,
-        imports: Optional[List[pathlib.PurePosixPath]] = None,
+        imports: Optional[List[str]] = None,
         is_gpu: bool = False,
         loading_from_file: bool = False,
     ) -> None:
@@ -67,7 +67,9 @@ class ModelRuntime:
     def runtime_rel_path(self) -> pathlib.PurePosixPath:
         return pathlib.PurePosixPath(ModelRuntime.RUNTIME_DIR_REL_PATH) / self.name
 
-    def save(self, packager_path: pathlib.Path) -> model_meta_schema.ModelRuntimeDict:
+    def save(
+        self, packager_path: pathlib.Path, default_channel_override: str = env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
+    ) -> model_meta_schema.ModelRuntimeDict:
         runtime_base_path = packager_path / self.runtime_rel_path
         runtime_base_path.mkdir(parents=True, exist_ok=True)
 
@@ -75,12 +77,12 @@ class ModelRuntime:
             snowpark_ml_lib_path = runtime_base_path / "snowflake-ml-python.zip"
             file_utils.zip_python_package(str(snowpark_ml_lib_path), "snowflake.ml")
             snowpark_ml_lib_rel_path = pathlib.PurePosixPath(snowpark_ml_lib_path.relative_to(packager_path).as_posix())
-            self.imports.append(snowpark_ml_lib_rel_path)
+            self.imports.append(str(snowpark_ml_lib_rel_path))
 
         self.runtime_env.conda_env_rel_path = self.runtime_rel_path / self.runtime_env.conda_env_rel_path
         self.runtime_env.pip_requirements_rel_path = self.runtime_rel_path / self.runtime_env.pip_requirements_rel_path
 
-        env_dict = self.runtime_env.save_as_dict(packager_path)
+        env_dict = self.runtime_env.save_as_dict(packager_path, default_channel_override=default_channel_override)
 
         return model_meta_schema.ModelRuntimeDict(
             imports=list(map(str, self.imports)),
@@ -108,6 +110,4 @@ class ModelRuntime:
             warnings.simplefilter("ignore")
             env.load_from_conda_file(packager_path / conda_env_rel_path)
             env.load_from_pip_file(packager_path / pip_requirements_rel_path)
-        return ModelRuntime(
-            name=name, env=env, imports=list(map(pathlib.PurePosixPath, loaded_dict["imports"])), loading_from_file=True
-        )
+        return ModelRuntime(name=name, env=env, imports=loaded_dict["imports"], loading_from_file=True)
--- a/snowflake/ml/model/model_signature.py
+++ b/snowflake/ml/model/model_signature.py
@@ -232,7 +232,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
                 ),
             )
         else:
-            if isinstance(data_col[0], list):
+            if isinstance(data_col.iloc[0], list):
                 if not ft_shape:
                     raise snowml_exceptions.SnowflakeMLException(
                         error_code=error_codes.INVALID_DATA,
@@ -266,7 +266,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
                     ),
                 )
 
-            elif isinstance(data_col[0], np.ndarray):
+            elif isinstance(data_col.iloc[0], np.ndarray):
                 if not ft_shape:
                     raise snowml_exceptions.SnowflakeMLException(
                         error_code=error_codes.INVALID_DATA,
@@ -297,7 +297,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
                     ),
                 )
 
-            elif isinstance(data_col[0], str):
+            elif isinstance(data_col.iloc[0], str):
                 if ft_shape is not None:
                     raise snowml_exceptions.SnowflakeMLException(
                         error_code=error_codes.INVALID_DATA,
@@ -316,7 +316,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
                     ),
                 )
 
-            elif isinstance(data_col[0], bytes):
+            elif isinstance(data_col.iloc[0], bytes):
                 if ft_shape is not None:
                     raise snowml_exceptions.SnowflakeMLException(
                         error_code=error_codes.INVALID_DATA,
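All four hunks are the same fix: `data_col[0]` is a label lookup, which raises `KeyError` whenever the validated frame has a non-default index (e.g. after filtering or sampling), while `.iloc[0]` is positional and always reads the first row. A minimal reproduction, assuming only pandas:

    # Label vs. positional indexing: the reason for the .iloc[0] change.
    import pandas as pd

    col = pd.DataFrame({"a": [[1, 2], [3, 4]]}, index=[7, 9])["a"]
    print(col.iloc[0])   # [1, 2], positional and index-agnostic
    try:
        col[0]           # label lookup: there is no index label 0
    except KeyError:
        print("KeyError under a non-default index")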
--- a/snowflake/ml/model/type_hints.py
+++ b/snowflake/ml/model/type_hints.py
@@ -232,6 +232,8 @@ class BaseModelSaveOption(TypedDict):
     _legacy_save: NotRequired[bool]
     function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
     method_options: NotRequired[Dict[str, ModelMethodSaveOptions]]
+    include_pip_dependencies: NotRequired[bool]
+    enable_explainability: NotRequired[bool]
 
 
 class CatBoostModelSaveOptions(BaseModelSaveOption):
--- a/snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py
+++ b/snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py
@@ -41,7 +41,7 @@ cp.register_pickle_by_value(inspect.getmodule(snowpark_dataframe_utils.cast_snow
 
 _PROJECT = "ModelDevelopment"
 DEFAULT_UDTF_NJOBS = 3
-ENABLE_EFFICIENT_MEMORY_USAGE = False
+ENABLE_EFFICIENT_MEMORY_USAGE = True
 _UDTF_STAGE_NAME = f"MEMORY_EFFICIENT_UDTF_{str(uuid.uuid4()).replace('-', '_')}"
 
 
--- a/snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py
+++ b/snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py
@@ -83,7 +83,19 @@ def _load_data_into_udf() -> Tuple[
     with open(local_fit_and_score_kwargs_file_path, mode="rb") as local_fit_and_score_kwargs_file_obj:
         fit_and_score_kwargs = cp.load(local_fit_and_score_kwargs_file_obj)
 
-    # convert dataframe to numpy would save memory consumption
+    # Convert dataframe to numpy would save memory consumption
+    # Except for Pipeline, we need to keep the dataframe for the column names
+    from sklearn.pipeline import Pipeline
+    if isinstance(base_estimator, Pipeline):
+        return (
+            df[CONSTANTS['input_cols']],
+            df[CONSTANTS['label_cols']].squeeze(),
+            indices,
+            params_to_evaluate,
+            base_estimator,
+            fit_and_score_kwargs,
+            CONSTANTS
+        )
     return (
         df[CONSTANTS['input_cols']].to_numpy(),
         df[CONSTANTS['label_cols']].squeeze().to_numpy(),
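The `Pipeline` carve-out exists because steps such as `ColumnTransformer` select input columns by name, which a numpy array no longer carries; converting to numpy would break fitting inside the UDTF. A small illustration with hypothetical column names:

    # Why Pipelines keep the DataFrame: name-based selection fails on ndarray.
    import pandas as pd
    from sklearn.compose import ColumnTransformer
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler

    df = pd.DataFrame({"age": [20, 30, 40], "income": [1.0, 2.0, 3.0]})
    pipe = Pipeline([("ct", ColumnTransformer([("scale", StandardScaler(), ["age"])]))])

    pipe.fit_transform(df)               # fine: "age" resolved by column name
    try:
        pipe.fit_transform(df.to_numpy())
    except ValueError as e:              # ndarray has no column names to select by
        print("fails on ndarray:", e)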