snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -19,7 +19,11 @@ from typing_extensions import TypeGuard, Unpack
19
19
  from snowflake.ml._internal import type_utils
20
20
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
21
21
  from snowflake.ml.model._packager.model_env import model_env
22
- from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
22
+ from snowflake.ml.model._packager.model_handlers import (
23
+ _base,
24
+ _utils as handlers_utils,
25
+ model_objective_utils,
26
+ )
23
27
  from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
24
28
  from snowflake.ml.model._packager.model_meta import (
25
29
  model_blob_meta,
@@ -41,47 +45,8 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
41
45
  _MIN_SNOWPARK_ML_VERSION = "1.3.1"
42
46
  _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
43
47
 
44
- MODELE_BLOB_FILE_OR_DIR = "model.pkl"
48
+ MODEL_BLOB_FILE_OR_DIR = "model.pkl"
45
49
  DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
46
- _BINARY_CLASSIFICATION_OBJECTIVES = ["binary"]
47
- _MULTI_CLASSIFICATION_OBJECTIVES = ["multiclass", "multiclassova"]
48
- _RANKING_OBJECTIVES = ["lambdarank", "rank_xendcg"]
49
- _REGRESSION_OBJECTIVES = [
50
- "regression",
51
- "regression_l1",
52
- "huber",
53
- "fair",
54
- "poisson",
55
- "quantile",
56
- "tweedie",
57
- "mape",
58
- "gamma",
59
- ]
60
-
61
- @classmethod
62
- def get_model_objective(cls, model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]) -> _base.ModelObjective:
63
- import lightgbm
64
-
65
- # does not account for cross-entropy and custom
66
- if isinstance(model, lightgbm.LGBMClassifier):
67
- num_classes = handlers_utils.get_num_classes_if_exists(model)
68
- if num_classes == 2:
69
- return _base.ModelObjective.BINARY_CLASSIFICATION
70
- return _base.ModelObjective.MULTI_CLASSIFICATION
71
- if isinstance(model, lightgbm.LGBMRanker):
72
- return _base.ModelObjective.RANKING
73
- if isinstance(model, lightgbm.LGBMRegressor):
74
- return _base.ModelObjective.REGRESSION
75
- model_objective = model.params["objective"]
76
- if model_objective in cls._BINARY_CLASSIFICATION_OBJECTIVES:
77
- return _base.ModelObjective.BINARY_CLASSIFICATION
78
- if model_objective in cls._MULTI_CLASSIFICATION_OBJECTIVES:
79
- return _base.ModelObjective.MULTI_CLASSIFICATION
80
- if model_objective in cls._RANKING_OBJECTIVES:
81
- return _base.ModelObjective.RANKING
82
- if model_objective in cls._REGRESSION_OBJECTIVES:
83
- return _base.ModelObjective.REGRESSION
84
- return _base.ModelObjective.UNKNOWN
85
50
 
86
51
  @classmethod
87
52
  def can_handle(
@@ -116,6 +81,8 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
116
81
  is_sub_model: Optional[bool] = False,
117
82
  **kwargs: Unpack[model_types.LGBMModelSaveOptions],
118
83
  ) -> None:
84
+ enable_explainability = kwargs.get("enable_explainability", True)
85
+
119
86
  import lightgbm
120
87
 
121
88
  assert isinstance(model, lightgbm.Booster) or isinstance(model, lightgbm.LGBMModel)
@@ -144,24 +111,25 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
144
111
  sample_input_data=sample_input_data,
145
112
  get_prediction_fn=get_prediction,
146
113
  )
147
- if kwargs.get("enable_explainability", False):
148
- output_type = model_signature.DataType.DOUBLE
149
- if cls.get_model_objective(model) in [
150
- _base.ModelObjective.BINARY_CLASSIFICATION,
151
- _base.ModelObjective.MULTI_CLASSIFICATION,
152
- ]:
153
- output_type = model_signature.DataType.STRING
114
+ model_objective_and_output = model_objective_utils.get_model_objective_and_output_type(model)
115
+ model_meta.model_objective = handlers_utils.validate_model_objective(
116
+ model_meta.model_objective, model_objective_and_output.objective
117
+ )
118
+ if enable_explainability:
154
119
  model_meta = handlers_utils.add_explain_method_signature(
155
120
  model_meta=model_meta,
156
121
  explain_method="explain",
157
122
  target_method="predict",
158
- output_return_type=output_type,
123
+ output_return_type=model_objective_and_output.output_type,
159
124
  )
125
+ model_meta.function_properties = {
126
+ "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
127
+ }
160
128
 
161
129
  model_blob_path = os.path.join(model_blobs_dir_path, name)
162
130
  os.makedirs(model_blob_path, exist_ok=True)
163
131
 
164
- model_save_path = os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR)
132
+ model_save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
165
133
  with open(model_save_path, "wb") as f:
166
134
  cloudpickle.dump(model, f)
167
135
 
@@ -169,7 +137,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
169
137
  name=name,
170
138
  model_type=cls.HANDLER_TYPE,
171
139
  handler_version=cls.HANDLER_VERSION,
172
- path=cls.MODELE_BLOB_FILE_OR_DIR,
140
+ path=cls.MODEL_BLOB_FILE_OR_DIR,
173
141
  options=model_meta_schema.LightGBMModelBlobOptions({"lightgbm_estimator_type": model.__class__.__name__}),
174
142
  )
175
143
  model_meta.models[name] = base_meta
@@ -182,11 +150,9 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
182
150
  ],
183
151
  check_local_version=True,
184
152
  )
185
- if kwargs.get("enable_explainability", False):
186
- model_meta.env.include_if_absent(
187
- [model_env.ModelDependency(requirement="shap", pip_name="shap")],
188
- check_local_version=True,
189
- )
153
+ if enable_explainability:
154
+ model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
155
+ model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
190
156
 
191
157
  return None
192
158
 
@@ -226,6 +192,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
226
192
  cls,
227
193
  raw_model: Union["lightgbm.Booster", "lightgbm.XGBModel"],
228
194
  model_meta: model_meta_api.ModelMetadata,
195
+ background_data: Optional[pd.DataFrame] = None,
229
196
  **kwargs: Unpack[model_types.LGBMModelLoadOptions],
230
197
  ) -> custom_model.CustomModel:
231
198
  import lightgbm
@@ -28,7 +28,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
28
28
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
29
29
  _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
30
30
 
31
- MODELE_BLOB_FILE_OR_DIR = "model"
31
+ MODEL_BLOB_FILE_OR_DIR = "model"
32
32
  LLM_META = "llm_meta"
33
33
  IS_AUTO_SIGNATURE = True
34
34
 
@@ -59,9 +59,12 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
59
59
  **kwargs: Unpack[model_types.LLMSaveOptions],
60
60
  ) -> None:
61
61
  assert not is_sub_model, "LLM can not be sub-model."
62
+ enable_explainability = kwargs.get("enable_explainability", False)
63
+ if enable_explainability:
64
+ raise NotImplementedError("Explainability is not supported for llm model.")
62
65
  model_blob_path = os.path.join(model_blobs_dir_path, name)
63
66
  os.makedirs(model_blob_path, exist_ok=True)
64
- model_blob_dir_path = os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR)
67
+ model_blob_dir_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
65
68
 
66
69
  sig = model_signature.ModelSignature(
67
70
  inputs=[
@@ -86,7 +89,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
86
89
  name=name,
87
90
  model_type=cls.HANDLER_TYPE,
88
91
  handler_version=cls.HANDLER_VERSION,
89
- path=cls.MODELE_BLOB_FILE_OR_DIR,
92
+ path=cls.MODEL_BLOB_FILE_OR_DIR,
90
93
  options=model_meta_schema.LLMModelBlobOptions(
91
94
  {
92
95
  "batch_size": model.max_batch_size,
@@ -143,6 +146,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
143
146
  cls,
144
147
  raw_model: llm.LLM,
145
148
  model_meta: model_meta_api.ModelMetadata,
149
+ background_data: Optional[pd.DataFrame] = None,
146
150
  **kwargs: Unpack[model_types.LLMLoadOptions],
147
151
  ) -> custom_model.CustomModel:
148
152
  import gc
@@ -201,7 +205,9 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
201
205
  "token": raw_model.token,
202
206
  }
203
207
  model_dir_path = raw_model.model_id_or_path
204
- peft_config = peft.PeftConfig.from_pretrained(model_dir_path) # type: ignore[attr-defined]
208
+ peft_config = peft.PeftConfig.from_pretrained( # type: ignore[no-untyped-call, attr-defined]
209
+ model_dir_path
210
+ )
205
211
  base_model_path = peft_config.base_model_name_or_path
206
212
  tokenizer = transformers.AutoTokenizer.from_pretrained(
207
213
  base_model_path,
@@ -217,7 +223,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
217
223
  model_dir_path,
218
224
  device_map="auto",
219
225
  torch_dtype="auto",
220
- **hub_kwargs,
226
+ **hub_kwargs, # type: ignore[arg-type]
221
227
  )
222
228
  hf_model.eval()
223
229
  hf_model = hf_model.merge_and_unload()
@@ -63,7 +63,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
63
63
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
64
64
  _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
65
65
 
66
- MODELE_BLOB_FILE_OR_DIR = "model"
66
+ MODEL_BLOB_FILE_OR_DIR = "model"
67
67
  _DEFAULT_TARGET_METHOD = "predict"
68
68
  DEFAULT_TARGET_METHODS = [_DEFAULT_TARGET_METHOD]
69
69
  IS_AUTO_SIGNATURE = True
@@ -97,6 +97,10 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
97
97
  is_sub_model: Optional[bool] = False,
98
98
  **kwargs: Unpack[model_types.MLFlowSaveOptions],
99
99
  ) -> None:
100
+ enable_explainability = kwargs.get("enable_explainability", False)
101
+ if enable_explainability:
102
+ raise NotImplementedError("Explainability is not supported for MLFlow model.")
103
+
100
104
  import mlflow
101
105
 
102
106
  assert isinstance(model, mlflow.pyfunc.PyFuncModel)
@@ -142,13 +146,13 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
142
146
  except (mlflow.MlflowException, OSError):
143
147
  raise ValueError("Cannot load MLFlow model artifacts.")
144
148
 
145
- file_utils.copy_file_or_tree(local_path, os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR))
149
+ file_utils.copy_file_or_tree(local_path, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
146
150
 
147
151
  base_meta = model_blob_meta.ModelBlobMeta(
148
152
  name=name,
149
153
  model_type=cls.HANDLER_TYPE,
150
154
  handler_version=cls.HANDLER_VERSION,
151
- path=cls.MODELE_BLOB_FILE_OR_DIR,
155
+ path=cls.MODEL_BLOB_FILE_OR_DIR,
152
156
  options=model_meta_schema.MLFlowModelBlobOptions({"artifact_path": model_info.artifact_path}),
153
157
  )
154
158
  model_meta.models[name] = base_meta
@@ -194,6 +198,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
194
198
  cls,
195
199
  raw_model: "mlflow.pyfunc.PyFuncModel",
196
200
  model_meta: model_meta_api.ModelMetadata,
201
+ background_data: Optional[pd.DataFrame] = None,
197
202
  **kwargs: Unpack[model_types.MLFlowLoadOptions],
198
203
  ) -> custom_model.CustomModel:
199
204
  from snowflake.ml.model import custom_model
@@ -0,0 +1,116 @@
1
+ import json
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING, Any, Union
4
+
5
+ from snowflake.ml.model import model_signature, type_hints
6
+ from snowflake.ml.model._packager.model_handlers import _utils as handlers_utils
7
+
8
+ if TYPE_CHECKING:
9
+ import lightgbm
10
+ import xgboost
11
+
12
+
13
+ @dataclass
14
+ class ModelObjectiveAndOutputType:
15
+ objective: type_hints.ModelObjective
16
+ output_type: model_signature.DataType
17
+
18
+
19
+ def get_model_objective_lightgbm(model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]) -> type_hints.ModelObjective:
20
+
21
+ import lightgbm
22
+
23
+ _BINARY_CLASSIFICATION_OBJECTIVES = ["binary"]
24
+ _MULTI_CLASSIFICATION_OBJECTIVES = ["multiclass", "multiclassova"]
25
+ _RANKING_OBJECTIVES = ["lambdarank", "rank_xendcg"]
26
+ _REGRESSION_OBJECTIVES = [
27
+ "regression",
28
+ "regression_l1",
29
+ "huber",
30
+ "fair",
31
+ "poisson",
32
+ "quantile",
33
+ "tweedie",
34
+ "mape",
35
+ "gamma",
36
+ ]
37
+
38
+ # does not account for cross-entropy and custom
39
+ if isinstance(model, lightgbm.LGBMClassifier):
40
+ num_classes = handlers_utils.get_num_classes_if_exists(model)
41
+ if num_classes == 2:
42
+ return type_hints.ModelObjective.BINARY_CLASSIFICATION
43
+ return type_hints.ModelObjective.MULTI_CLASSIFICATION
44
+ if isinstance(model, lightgbm.LGBMRanker):
45
+ return type_hints.ModelObjective.RANKING
46
+ if isinstance(model, lightgbm.LGBMRegressor):
47
+ return type_hints.ModelObjective.REGRESSION
48
+ model_objective = model.params["objective"]
49
+ if model_objective in _BINARY_CLASSIFICATION_OBJECTIVES:
50
+ return type_hints.ModelObjective.BINARY_CLASSIFICATION
51
+ if model_objective in _MULTI_CLASSIFICATION_OBJECTIVES:
52
+ return type_hints.ModelObjective.MULTI_CLASSIFICATION
53
+ if model_objective in _RANKING_OBJECTIVES:
54
+ return type_hints.ModelObjective.RANKING
55
+ if model_objective in _REGRESSION_OBJECTIVES:
56
+ return type_hints.ModelObjective.REGRESSION
57
+ return type_hints.ModelObjective.UNKNOWN
58
+
59
+
60
+ def get_model_objective_xgb(model: Union["xgboost.Booster", "xgboost.XGBModel"]) -> type_hints.ModelObjective:
61
+
62
+ import xgboost
63
+
64
+ _BINARY_CLASSIFICATION_OBJECTIVE_PREFIX = ["binary:"]
65
+ _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX = ["multi:"]
66
+ _RANKING_OBJECTIVE_PREFIX = ["rank:"]
67
+ _REGRESSION_OBJECTIVE_PREFIX = ["reg:"]
68
+
69
+ model_objective = ""
70
+ if isinstance(model, xgboost.Booster):
71
+ model_params = json.loads(model.save_config())
72
+ model_objective = model_params.get("learner", {}).get("objective", "")
73
+ else:
74
+ if hasattr(model, "get_params"):
75
+ model_objective = model.get_params().get("objective", "")
76
+
77
+ if isinstance(model_objective, dict):
78
+ model_objective = model_objective.get("name", "")
79
+ for classification_objective in _BINARY_CLASSIFICATION_OBJECTIVE_PREFIX:
80
+ if classification_objective in model_objective:
81
+ return type_hints.ModelObjective.BINARY_CLASSIFICATION
82
+ for classification_objective in _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX:
83
+ if classification_objective in model_objective:
84
+ return type_hints.ModelObjective.MULTI_CLASSIFICATION
85
+ for ranking_objective in _RANKING_OBJECTIVE_PREFIX:
86
+ if ranking_objective in model_objective:
87
+ return type_hints.ModelObjective.RANKING
88
+ for regression_objective in _REGRESSION_OBJECTIVE_PREFIX:
89
+ if regression_objective in model_objective:
90
+ return type_hints.ModelObjective.REGRESSION
91
+ return type_hints.ModelObjective.UNKNOWN
92
+
93
+
94
+ def get_model_objective_and_output_type(model: Any) -> ModelObjectiveAndOutputType:
95
+ import xgboost
96
+
97
+ if isinstance(model, xgboost.Booster) or isinstance(model, xgboost.XGBModel):
98
+ model_objective = get_model_objective_xgb(model)
99
+ output_type = model_signature.DataType.DOUBLE
100
+ if model_objective == type_hints.ModelObjective.MULTI_CLASSIFICATION:
101
+ output_type = model_signature.DataType.STRING
102
+ return ModelObjectiveAndOutputType(objective=model_objective, output_type=output_type)
103
+
104
+ import lightgbm
105
+
106
+ if isinstance(model, lightgbm.Booster) or isinstance(model, lightgbm.LGBMModel):
107
+ model_objective = get_model_objective_lightgbm(model)
108
+ output_type = model_signature.DataType.DOUBLE
109
+ if model_objective in [
110
+ type_hints.ModelObjective.BINARY_CLASSIFICATION,
111
+ type_hints.ModelObjective.MULTI_CLASSIFICATION,
112
+ ]:
113
+ output_type = model_signature.DataType.STRING
114
+ return ModelObjectiveAndOutputType(objective=model_objective, output_type=output_type)
115
+
116
+ raise ValueError(f"Model type {type(model)} is not supported")
@@ -37,7 +37,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
37
37
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
38
38
  _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
39
39
 
40
- MODELE_BLOB_FILE_OR_DIR = "model.pt"
40
+ MODEL_BLOB_FILE_OR_DIR = "model.pt"
41
41
  DEFAULT_TARGET_METHODS = ["forward"]
42
42
 
43
43
  @classmethod
@@ -73,6 +73,10 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
73
73
  is_sub_model: Optional[bool] = False,
74
74
  **kwargs: Unpack[model_types.PyTorchSaveOptions],
75
75
  ) -> None:
76
+ enable_explainability = kwargs.get("enable_explainability", False)
77
+ if enable_explainability:
78
+ raise NotImplementedError("Explainability is not supported for PyTorch model.")
79
+
76
80
  import torch
77
81
 
78
82
  assert isinstance(model, torch.nn.Module)
@@ -115,13 +119,13 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
115
119
  cloudpickle.register_pickle_by_value(sys.modules[model.__module__])
116
120
  model_blob_path = os.path.join(model_blobs_dir_path, name)
117
121
  os.makedirs(model_blob_path, exist_ok=True)
118
- with open(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR), "wb") as f:
122
+ with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
119
123
  torch.save(model, f, pickle_module=cloudpickle)
120
124
  base_meta = model_blob_meta.ModelBlobMeta(
121
125
  name=name,
122
126
  model_type=cls.HANDLER_TYPE,
123
127
  handler_version=cls.HANDLER_VERSION,
124
- path=cls.MODELE_BLOB_FILE_OR_DIR,
128
+ path=cls.MODEL_BLOB_FILE_OR_DIR,
125
129
  )
126
130
  model_meta.models[name] = base_meta
127
131
  model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -156,6 +160,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
156
160
  cls,
157
161
  raw_model: "torch.nn.Module",
158
162
  model_meta: model_meta_api.ModelMetadata,
163
+ background_data: Optional[pd.DataFrame] = None,
159
164
  **kwargs: Unpack[model_types.PyTorchLoadOptions],
160
165
  ) -> custom_model.CustomModel:
161
166
  import torch
@@ -31,7 +31,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
31
31
  _MIN_SNOWPARK_ML_VERSION = "1.3.1"
32
32
  _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
33
33
 
34
- MODELE_BLOB_FILE_OR_DIR = "model"
34
+ MODEL_BLOB_FILE_OR_DIR = "model"
35
35
  DEFAULT_TARGET_METHODS = ["encode"]
36
36
 
37
37
  @classmethod
@@ -64,6 +64,10 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
64
64
  is_sub_model: Optional[bool] = False,
65
65
  **kwargs: Unpack[model_types.SentenceTransformersSaveOptions], # registry.log_model(options={...})
66
66
  ) -> None:
67
+ enable_explainability = kwargs.get("enable_explainability", False)
68
+ if enable_explainability:
69
+ raise NotImplementedError("Explainability is not supported for Sentence Transformer model.")
70
+
67
71
  # Validate target methods and signature (if possible)
68
72
  if not is_sub_model:
69
73
  target_methods = handlers_utils.get_target_methods(
@@ -101,14 +105,14 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
101
105
  # save model
102
106
  model_blob_path = os.path.join(model_blobs_dir_path, name)
103
107
  os.makedirs(model_blob_path, exist_ok=True)
104
- model.save(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR))
108
+ model.save(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
105
109
 
106
110
  # save model metadata
107
111
  base_meta = model_blob_meta.ModelBlobMeta(
108
112
  name=name,
109
113
  model_type=cls.HANDLER_TYPE,
110
114
  handler_version=cls.HANDLER_VERSION,
111
- path=cls.MODELE_BLOB_FILE_OR_DIR,
115
+ path=cls.MODEL_BLOB_FILE_OR_DIR,
112
116
  )
113
117
  model_meta.models[name] = base_meta
114
118
  model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -154,6 +158,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
154
158
  cls,
155
159
  raw_model: "sentence_transformers.SentenceTransformer",
156
160
  model_meta: model_meta_api.ModelMetadata,
161
+ background_data: Optional[pd.DataFrame] = None,
157
162
  **kwargs: Unpack[model_types.SentenceTransformersLoadOptions],
158
163
  ) -> custom_model.CustomModel:
159
164
  import sentence_transformers
@@ -6,6 +6,7 @@ import numpy as np
6
6
  import pandas as pd
7
7
  from typing_extensions import TypeGuard, Unpack
8
8
 
9
+ import snowflake.snowpark.dataframe as sp_df
9
10
  from snowflake.ml._internal import type_utils
10
11
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
11
12
  from snowflake.ml.model._packager.model_env import model_env
@@ -14,8 +15,13 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
14
15
  from snowflake.ml.model._packager.model_meta import (
15
16
  model_blob_meta,
16
17
  model_meta as model_meta_api,
18
+ model_meta_schema,
19
+ )
20
+ from snowflake.ml.model._signatures import (
21
+ numpy_handler,
22
+ snowpark_handler,
23
+ utils as model_signature_utils,
17
24
  )
18
- from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
19
25
 
20
26
  if TYPE_CHECKING:
21
27
  import sklearn.base
@@ -36,6 +42,27 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
36
42
 
37
43
  DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"]
38
44
 
45
@classmethod
def get_model_objective(
    cls, model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]
) -> model_types.ModelObjective:
    """Classify a scikit-learn model as regression, binary or multi-class classification.

    Args:
        model: The estimator or pipeline to inspect.

    Returns:
        ``UNKNOWN`` for pipelines, for estimators that are neither regressors nor
        classifiers, and for classifiers whose class count is not a single integer
        (e.g. a per-output list or array). ``REGRESSION`` for regressors. For
        classifiers, ``MULTI_CLASSIFICATION`` when more than two classes are reported,
        otherwise ``BINARY_CLASSIFICATION``.
    """
    import numbers

    import sklearn.pipeline
    from sklearn.base import is_classifier, is_regressor

    if isinstance(model, sklearn.pipeline.Pipeline):
        return model_types.ModelObjective.UNKNOWN
    if is_regressor(model):
        return model_types.ModelObjective.REGRESSION
    if not is_classifier(model):
        return model_types.ModelObjective.UNKNOWN

    # Prefer `n_classes_`; fall back to the length of `classes_` when it is absent or zero.
    # The value is never truth-tested directly: it may be a multi-element numpy array,
    # whose boolean evaluation raises.
    num_classes = getattr(model, "n_classes_", None)
    if num_classes is None or (isinstance(num_classes, numbers.Integral) and num_classes == 0):
        num_classes = len(getattr(model, "classes_", []))

    # numbers.Integral also accepts numpy integer scalars, which are not `int` instances.
    if not isinstance(num_classes, numbers.Integral):
        return model_types.ModelObjective.UNKNOWN

    # NOTE(review): a classifier exposing neither attribute ends up with a count of 0 and is
    # therefore reported as binary — confirm this is the intended default.
    if num_classes > 2:
        return model_types.ModelObjective.MULTI_CLASSIFICATION
    return model_types.ModelObjective.BINARY_CLASSIFICATION
65
+
39
66
  @classmethod
40
67
  def can_handle(
41
68
  cls,
@@ -68,6 +95,18 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
68
95
 
69
96
  return cast(Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"], model)
70
97
 
98
@staticmethod
def get_explainability_supported_background(
    sample_input_data: Optional[model_types.SupportedDataType] = None,
) -> Optional[pd.DataFrame]:
    """Turn the sample input into a pandas DataFrame when its type allows it.

    Args:
        sample_input_data: The sample input supplied when saving the model, if any.

    Returns:
        The input itself when it is already a ``pandas.DataFrame``, the result of
        ``SnowparkDataFrameHandler.convert_to_df`` when it is a Snowpark DataFrame,
        and ``None`` for every other input (including ``None``).
    """
    # pandas input is usable as-is.
    if isinstance(sample_input_data, pd.DataFrame):
        return sample_input_data
    # Snowpark input goes through the Snowpark handler's conversion.
    if isinstance(sample_input_data, sp_df.DataFrame):
        return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sample_input_data)
    return None
109
+
71
110
  @classmethod
72
111
  def save_model(
73
112
  cls,
@@ -79,11 +118,31 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
79
118
  is_sub_model: Optional[bool] = False,
80
119
  **kwargs: Unpack[model_types.SKLModelSaveOptions],
81
120
  ) -> None:
121
+ # setting None by default to distinguish if users did not set it
122
+ enable_explainability = kwargs.get("enable_explainability", None)
123
+
82
124
  import sklearn.base
83
125
  import sklearn.pipeline
84
126
 
85
127
  assert isinstance(model, sklearn.base.BaseEstimator) or isinstance(model, sklearn.pipeline.Pipeline)
86
128
 
129
+ background_data = cls.get_explainability_supported_background(sample_input_data)
130
+
131
+ # if users did not ask then we enable if we have background data
132
+ if enable_explainability is None and background_data is not None:
133
+ enable_explainability = True
134
+ if enable_explainability:
135
+ # if users set it explicitly but no background data then error out
136
+ if background_data is None:
137
+ raise ValueError(
138
+ "Sample input data is required to enable explainability. Currently we only support this for "
139
+ + "`pandas.DataFrame` and `snowflake.snowpark.dataframe.DataFrame`."
140
+ )
141
+ data_blob_path = os.path.join(model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR)
142
+ os.makedirs(data_blob_path, exist_ok=True)
143
+ with open(os.path.join(data_blob_path, name + cls.BG_DATA_FILE_SUFFIX), "wb") as f:
144
+ background_data.to_parquet(f)
145
+
87
146
  if not is_sub_model:
88
147
  target_methods = handlers_utils.get_target_methods(
89
148
  model=model,
@@ -110,19 +169,38 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
110
169
  get_prediction_fn=get_prediction,
111
170
  )
112
171
 
172
+ model_objective = cls.get_model_objective(model)
173
+ model_meta.model_objective = model_objective
174
+
175
+ if enable_explainability:
176
+ output_type = model_signature.DataType.DOUBLE
177
+
178
+ if model_objective == model_types.ModelObjective.MULTI_CLASSIFICATION:
179
+ output_type = model_signature.DataType.STRING
180
+ model_meta = handlers_utils.add_explain_method_signature(
181
+ model_meta=model_meta,
182
+ explain_method="explain",
183
+ target_method="predict",
184
+ output_return_type=output_type,
185
+ )
186
+
113
187
  model_blob_path = os.path.join(model_blobs_dir_path, name)
114
188
  os.makedirs(model_blob_path, exist_ok=True)
115
- with open(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR), "wb") as f:
189
+ with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
116
190
  cloudpickle.dump(model, f)
117
191
  base_meta = model_blob_meta.ModelBlobMeta(
118
192
  name=name,
119
193
  model_type=cls.HANDLER_TYPE,
120
194
  handler_version=cls.HANDLER_VERSION,
121
- path=cls.MODELE_BLOB_FILE_OR_DIR,
195
+ path=cls.MODEL_BLOB_FILE_OR_DIR,
122
196
  )
123
197
  model_meta.models[name] = base_meta
124
198
  model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
125
199
 
200
+ if enable_explainability:
201
+ model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
202
+ model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
203
+
126
204
  model_meta.env.include_if_absent(
127
205
  [model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn")], check_local_version=True
128
206
  )
@@ -153,6 +231,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
153
231
  cls,
154
232
  raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
155
233
  model_meta: model_meta_api.ModelMetadata,
234
+ background_data: Optional[pd.DataFrame] = None,
156
235
  **kwargs: Unpack[model_types.SKLModelLoadOptions],
157
236
  ) -> custom_model.CustomModel:
158
237
  from snowflake.ml.model import custom_model
@@ -165,6 +244,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
165
244
  raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
166
245
  signature: model_signature.ModelSignature,
167
246
  target_method: str,
247
+ background_data: Optional[pd.DataFrame],
168
248
  ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
169
249
  @custom_model.inference_api
170
250
  def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
@@ -179,11 +259,26 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
179
259
 
180
260
  return model_signature_utils.rename_pandas_df(df, signature.outputs)
181
261
 
262
@custom_model.inference_api
def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
    """Compute SHAP explanations for ``X`` and name the columns after the signature outputs.

    Args:
        X: Input rows to explain.

    Returns:
        The explanation values converted to a 2D DataFrame by
        ``handlers_utils.convert_explanations_to_2D_df`` and renamed according to
        ``signature.outputs``.

    Raises:
        ValueError: If building or running the explainer raises ``TypeError``; the
            original error is attached as the cause.
    """
    import shap

    # TODO: if not resolved by explainer, we need to pass the callable function
    try:
        explainer = shap.Explainer(raw_model, background_data)
        df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
    except TypeError as e:
        # Chain explicitly so the underlying TypeError is reported as the direct cause.
        raise ValueError(f"Explanation for this model type not supported yet: {str(e)}") from e
    return model_signature_utils.rename_pandas_df(df, signature.outputs)
273
+
274
+ if target_method == "explain":
275
+ return explain_fn
276
+
182
277
  return fn
183
278
 
184
279
  type_method_dict = {}
185
280
  for target_method_name, sig in model_meta.signatures.items():
186
- type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
281
+ type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name, background_data)
187
282
 
188
283
  _SKLModel = type(
189
284
  "_SKLModel",