snowflake-ml-python 1.6.2__py3-none-any.whl → 1.6.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (262) hide show
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/utils/db_utils.py +50 -0
  13. snowflake/ml/_internal/utils/service_logger.py +63 -0
  14. snowflake/ml/_internal/utils/sql_identifier.py +25 -1
  15. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  16. snowflake/ml/data/ingestor_utils.py +20 -10
  17. snowflake/ml/feature_store/access_manager.py +3 -3
  18. snowflake/ml/feature_store/feature_store.py +19 -2
  19. snowflake/ml/feature_store/feature_view.py +82 -28
  20. snowflake/ml/fileset/stage_fs.py +2 -1
  21. snowflake/ml/lineage/lineage_node.py +7 -2
  22. snowflake/ml/model/__init__.py +1 -2
  23. snowflake/ml/model/_client/model/model_version_impl.py +78 -9
  24. snowflake/ml/model/_client/ops/model_ops.py +89 -7
  25. snowflake/ml/model/_client/ops/service_ops.py +200 -91
  26. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -0
  27. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -0
  28. snowflake/ml/model/_client/sql/_base.py +5 -0
  29. snowflake/ml/model/_client/sql/model.py +1 -0
  30. snowflake/ml/model/_client/sql/model_version.py +9 -5
  31. snowflake/ml/model/_client/sql/service.py +35 -13
  32. snowflake/ml/model/_model_composer/model_composer.py +11 -41
  33. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +29 -4
  34. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  35. snowflake/ml/model/_packager/model_handlers/_utils.py +106 -32
  36. snowflake/ml/model/_packager/model_handlers/catboost.py +26 -27
  37. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -3
  38. snowflake/ml/model/_packager/model_handlers/lightgbm.py +21 -6
  39. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  40. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +111 -58
  41. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  42. snowflake/ml/model/_packager/model_handlers/sklearn.py +50 -66
  43. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +36 -17
  44. snowflake/ml/model/_packager/model_handlers/xgboost.py +22 -7
  45. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -45
  46. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -6
  47. snowflake/ml/model/_packager/model_packager.py +14 -10
  48. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  49. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  50. snowflake/ml/model/type_hints.py +11 -152
  51. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +0 -2
  52. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +17 -6
  53. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  54. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +1 -0
  55. snowflake/ml/modeling/cluster/affinity_propagation.py +1 -0
  56. snowflake/ml/modeling/cluster/agglomerative_clustering.py +1 -0
  57. snowflake/ml/modeling/cluster/birch.py +1 -0
  58. snowflake/ml/modeling/cluster/bisecting_k_means.py +1 -0
  59. snowflake/ml/modeling/cluster/dbscan.py +1 -0
  60. snowflake/ml/modeling/cluster/feature_agglomeration.py +1 -0
  61. snowflake/ml/modeling/cluster/k_means.py +1 -0
  62. snowflake/ml/modeling/cluster/mean_shift.py +1 -0
  63. snowflake/ml/modeling/cluster/mini_batch_k_means.py +1 -0
  64. snowflake/ml/modeling/cluster/optics.py +1 -0
  65. snowflake/ml/modeling/cluster/spectral_biclustering.py +1 -0
  66. snowflake/ml/modeling/cluster/spectral_clustering.py +1 -0
  67. snowflake/ml/modeling/cluster/spectral_coclustering.py +1 -0
  68. snowflake/ml/modeling/compose/column_transformer.py +1 -0
  69. snowflake/ml/modeling/compose/transformed_target_regressor.py +1 -0
  70. snowflake/ml/modeling/covariance/elliptic_envelope.py +1 -0
  71. snowflake/ml/modeling/covariance/empirical_covariance.py +1 -0
  72. snowflake/ml/modeling/covariance/graphical_lasso.py +1 -0
  73. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +1 -0
  74. snowflake/ml/modeling/covariance/ledoit_wolf.py +1 -0
  75. snowflake/ml/modeling/covariance/min_cov_det.py +1 -0
  76. snowflake/ml/modeling/covariance/oas.py +1 -0
  77. snowflake/ml/modeling/covariance/shrunk_covariance.py +1 -0
  78. snowflake/ml/modeling/decomposition/dictionary_learning.py +1 -0
  79. snowflake/ml/modeling/decomposition/factor_analysis.py +1 -0
  80. snowflake/ml/modeling/decomposition/fast_ica.py +1 -0
  81. snowflake/ml/modeling/decomposition/incremental_pca.py +1 -0
  82. snowflake/ml/modeling/decomposition/kernel_pca.py +1 -0
  83. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +1 -0
  84. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +1 -0
  85. snowflake/ml/modeling/decomposition/pca.py +1 -0
  86. snowflake/ml/modeling/decomposition/sparse_pca.py +1 -0
  87. snowflake/ml/modeling/decomposition/truncated_svd.py +1 -0
  88. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +1 -0
  89. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +1 -0
  90. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +1 -0
  91. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +1 -0
  92. snowflake/ml/modeling/ensemble/bagging_classifier.py +1 -0
  93. snowflake/ml/modeling/ensemble/bagging_regressor.py +1 -0
  94. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +1 -0
  95. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +1 -0
  96. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +1 -0
  97. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +1 -0
  98. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +1 -0
  99. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +1 -0
  100. snowflake/ml/modeling/ensemble/isolation_forest.py +1 -0
  101. snowflake/ml/modeling/ensemble/random_forest_classifier.py +1 -0
  102. snowflake/ml/modeling/ensemble/random_forest_regressor.py +1 -0
  103. snowflake/ml/modeling/ensemble/stacking_regressor.py +1 -0
  104. snowflake/ml/modeling/ensemble/voting_classifier.py +1 -0
  105. snowflake/ml/modeling/ensemble/voting_regressor.py +1 -0
  106. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +1 -0
  107. snowflake/ml/modeling/feature_selection/select_fdr.py +1 -0
  108. snowflake/ml/modeling/feature_selection/select_fpr.py +1 -0
  109. snowflake/ml/modeling/feature_selection/select_fwe.py +1 -0
  110. snowflake/ml/modeling/feature_selection/select_k_best.py +1 -0
  111. snowflake/ml/modeling/feature_selection/select_percentile.py +1 -0
  112. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +1 -0
  113. snowflake/ml/modeling/feature_selection/variance_threshold.py +1 -0
  114. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +1 -0
  115. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +1 -0
  116. snowflake/ml/modeling/impute/iterative_imputer.py +1 -0
  117. snowflake/ml/modeling/impute/knn_imputer.py +1 -0
  118. snowflake/ml/modeling/impute/missing_indicator.py +1 -0
  119. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +1 -0
  120. snowflake/ml/modeling/kernel_approximation/nystroem.py +1 -0
  121. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +1 -0
  122. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +1 -0
  123. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +1 -0
  124. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +1 -0
  125. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +1 -0
  126. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +1 -0
  127. snowflake/ml/modeling/linear_model/ard_regression.py +1 -0
  128. snowflake/ml/modeling/linear_model/bayesian_ridge.py +1 -0
  129. snowflake/ml/modeling/linear_model/elastic_net.py +1 -0
  130. snowflake/ml/modeling/linear_model/elastic_net_cv.py +1 -0
  131. snowflake/ml/modeling/linear_model/gamma_regressor.py +1 -0
  132. snowflake/ml/modeling/linear_model/huber_regressor.py +1 -0
  133. snowflake/ml/modeling/linear_model/lars.py +1 -0
  134. snowflake/ml/modeling/linear_model/lars_cv.py +1 -0
  135. snowflake/ml/modeling/linear_model/lasso.py +1 -0
  136. snowflake/ml/modeling/linear_model/lasso_cv.py +1 -0
  137. snowflake/ml/modeling/linear_model/lasso_lars.py +1 -0
  138. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +1 -0
  139. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +1 -0
  140. snowflake/ml/modeling/linear_model/linear_regression.py +1 -0
  141. snowflake/ml/modeling/linear_model/logistic_regression.py +1 -0
  142. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +1 -0
  143. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +1 -0
  144. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +1 -0
  145. snowflake/ml/modeling/linear_model/multi_task_lasso.py +1 -0
  146. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +1 -0
  147. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +1 -0
  148. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +1 -0
  149. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +1 -0
  150. snowflake/ml/modeling/linear_model/perceptron.py +1 -0
  151. snowflake/ml/modeling/linear_model/poisson_regressor.py +1 -0
  152. snowflake/ml/modeling/linear_model/ransac_regressor.py +1 -0
  153. snowflake/ml/modeling/linear_model/ridge.py +1 -0
  154. snowflake/ml/modeling/linear_model/ridge_classifier.py +1 -0
  155. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +1 -0
  156. snowflake/ml/modeling/linear_model/ridge_cv.py +1 -0
  157. snowflake/ml/modeling/linear_model/sgd_classifier.py +1 -0
  158. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +1 -0
  159. snowflake/ml/modeling/linear_model/sgd_regressor.py +1 -0
  160. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +1 -0
  161. snowflake/ml/modeling/linear_model/tweedie_regressor.py +1 -0
  162. snowflake/ml/modeling/manifold/isomap.py +1 -0
  163. snowflake/ml/modeling/manifold/mds.py +1 -0
  164. snowflake/ml/modeling/manifold/spectral_embedding.py +1 -0
  165. snowflake/ml/modeling/manifold/tsne.py +1 -0
  166. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  167. snowflake/ml/modeling/metrics/ranking.py +0 -3
  168. snowflake/ml/modeling/metrics/regression.py +0 -3
  169. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +1 -0
  170. snowflake/ml/modeling/mixture/gaussian_mixture.py +1 -0
  171. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +1 -0
  172. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +1 -0
  173. snowflake/ml/modeling/multiclass/output_code_classifier.py +1 -0
  174. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +1 -0
  175. snowflake/ml/modeling/naive_bayes/categorical_nb.py +1 -0
  176. snowflake/ml/modeling/naive_bayes/complement_nb.py +1 -0
  177. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +1 -0
  178. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +1 -0
  179. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +1 -0
  180. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +1 -0
  181. snowflake/ml/modeling/neighbors/kernel_density.py +1 -0
  182. snowflake/ml/modeling/neighbors/local_outlier_factor.py +1 -0
  183. snowflake/ml/modeling/neighbors/nearest_centroid.py +1 -0
  184. snowflake/ml/modeling/neighbors/nearest_neighbors.py +1 -0
  185. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +1 -0
  186. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +1 -0
  187. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +1 -0
  188. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +1 -0
  189. snowflake/ml/modeling/neural_network/mlp_classifier.py +1 -0
  190. snowflake/ml/modeling/neural_network/mlp_regressor.py +1 -0
  191. snowflake/ml/modeling/pipeline/pipeline.py +0 -1
  192. snowflake/ml/modeling/preprocessing/polynomial_features.py +1 -0
  193. snowflake/ml/modeling/semi_supervised/label_propagation.py +1 -0
  194. snowflake/ml/modeling/semi_supervised/label_spreading.py +1 -0
  195. snowflake/ml/modeling/svm/linear_svc.py +1 -0
  196. snowflake/ml/modeling/svm/linear_svr.py +1 -0
  197. snowflake/ml/modeling/svm/nu_svc.py +1 -0
  198. snowflake/ml/modeling/svm/nu_svr.py +1 -0
  199. snowflake/ml/modeling/svm/svc.py +1 -0
  200. snowflake/ml/modeling/svm/svr.py +1 -0
  201. snowflake/ml/modeling/tree/decision_tree_classifier.py +1 -0
  202. snowflake/ml/modeling/tree/decision_tree_regressor.py +1 -0
  203. snowflake/ml/modeling/tree/extra_tree_classifier.py +1 -0
  204. snowflake/ml/modeling/tree/extra_tree_regressor.py +1 -0
  205. snowflake/ml/modeling/xgboost/xgb_classifier.py +1 -0
  206. snowflake/ml/modeling/xgboost/xgb_regressor.py +1 -0
  207. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +1 -0
  208. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +1 -0
  209. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  210. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  211. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  212. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  213. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  214. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  215. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  216. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  217. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  218. snowflake/ml/registry/_manager/model_manager.py +4 -4
  219. snowflake/ml/registry/registry.py +165 -6
  220. snowflake/ml/version.py +1 -1
  221. {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.4.dist-info}/METADATA +30 -9
  222. {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.4.dist-info}/RECORD +225 -249
  223. {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.4.dist-info}/WHEEL +1 -1
  224. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  225. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  226. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  227. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  228. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  229. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  230. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  231. snowflake/ml/_internal/utils/uri.py +0 -77
  232. snowflake/ml/model/_api.py +0 -568
  233. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  234. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  235. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  236. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  237. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  238. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  239. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  240. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  241. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  242. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  243. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  244. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  245. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  246. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  247. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  248. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  249. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  250. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  251. snowflake/ml/model/_packager/model_handlers/llm.py +0 -269
  252. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  253. snowflake/ml/model/deploy_platforms.py +0 -6
  254. snowflake/ml/model/models/llm.py +0 -106
  255. snowflake/ml/monitoring/monitor.py +0 -203
  256. snowflake/ml/registry/_initial_schema.py +0 -142
  257. snowflake/ml/registry/_schema.py +0 -82
  258. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  259. snowflake/ml/registry/_schema_version_manager.py +0 -163
  260. snowflake/ml/registry/model_registry.py +0 -2048
  261. {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.4.dist-info}/LICENSE.txt +0 -0
  262. {snowflake_ml_python-1.6.2.dist-info → snowflake_ml_python-1.6.4.dist-info}/top_level.txt +0 -0
@@ -2,23 +2,67 @@ import json
2
2
  from dataclasses import dataclass
3
3
  from typing import TYPE_CHECKING, Any, Union
4
4
 
5
+ from snowflake.ml._internal import type_utils
5
6
  from snowflake.ml.model import model_signature, type_hints
6
7
  from snowflake.ml.model._packager.model_handlers import _utils as handlers_utils
7
8
 
8
9
  if TYPE_CHECKING:
10
+ import catboost
9
11
  import lightgbm
12
+ import sklearn
13
+ import sklearn.pipeline
10
14
  import xgboost
11
15
 
12
16
 
13
17
  @dataclass
14
- class ModelObjectiveAndOutputType:
15
- objective: type_hints.ModelObjective
18
+ class ModelTaskAndOutputType:
19
+ task: type_hints.Task
16
20
  output_type: model_signature.DataType
17
21
 
18
22
 
19
- def get_model_objective_lightgbm(model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]) -> type_hints.ModelObjective:
23
+ def get_task_skl(model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]) -> type_hints.Task:
24
+ from sklearn.base import is_classifier, is_regressor
25
+
26
+ if type_utils.LazyType("sklearn.pipeline.Pipeline").isinstance(model):
27
+ return type_hints.Task.UNKNOWN
28
+ if is_regressor(model):
29
+ return type_hints.Task.TABULAR_REGRESSION
30
+ if is_classifier(model):
31
+ classes_list = getattr(model, "classes_", [])
32
+ num_classes = getattr(model, "n_classes_", None) or len(classes_list)
33
+ if isinstance(num_classes, int):
34
+ if num_classes > 2:
35
+ return type_hints.Task.TABULAR_MULTI_CLASSIFICATION
36
+ return type_hints.Task.TABULAR_BINARY_CLASSIFICATION
37
+ return type_hints.Task.UNKNOWN
38
+ return type_hints.Task.UNKNOWN
39
+
40
+
41
+ def get_model_task_catboost(model: "catboost.CatBoost") -> type_hints.Task:
42
+ loss_function = None
43
+ if type_utils.LazyType("catboost.CatBoost").isinstance(model):
44
+ loss_function = model.get_all_params()["loss_function"] # type: ignore[attr-defined]
45
+
46
+ if (type_utils.LazyType("catboost.CatBoostClassifier").isinstance(model)) or model._is_classification_objective(
47
+ loss_function
48
+ ):
49
+ num_classes = handlers_utils.get_num_classes_if_exists(model)
50
+ if num_classes == 0:
51
+ return type_hints.Task.UNKNOWN
52
+ if num_classes <= 2:
53
+ return type_hints.Task.TABULAR_BINARY_CLASSIFICATION
54
+ return type_hints.Task.TABULAR_MULTI_CLASSIFICATION
55
+ if (type_utils.LazyType("catboost.CatBoostRanker").isinstance(model)) or model._is_ranking_objective(loss_function):
56
+ return type_hints.Task.TABULAR_RANKING
57
+ if (type_utils.LazyType("catboost.CatBoostRegressor").isinstance(model)) or model._is_regression_objective(
58
+ loss_function
59
+ ):
60
+ return type_hints.Task.TABULAR_REGRESSION
20
61
 
21
- import lightgbm
62
+ return type_hints.Task.UNKNOWN
63
+
64
+
65
+ def get_model_task_lightgbm(model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]) -> type_hints.Task:
22
66
 
23
67
  _BINARY_CLASSIFICATION_OBJECTIVES = ["binary"]
24
68
  _MULTI_CLASSIFICATION_OBJECTIVES = ["multiclass", "multiclassova"]
@@ -36,81 +80,90 @@ def get_model_objective_lightgbm(model: Union["lightgbm.Booster", "lightgbm.LGBM
36
80
  ]
37
81
 
38
82
  # does not account for cross-entropy and custom
39
- if isinstance(model, lightgbm.LGBMClassifier):
40
- num_classes = handlers_utils.get_num_classes_if_exists(model)
41
- if num_classes == 2:
42
- return type_hints.ModelObjective.BINARY_CLASSIFICATION
43
- return type_hints.ModelObjective.MULTI_CLASSIFICATION
44
- if isinstance(model, lightgbm.LGBMRanker):
45
- return type_hints.ModelObjective.RANKING
46
- if isinstance(model, lightgbm.LGBMRegressor):
47
- return type_hints.ModelObjective.REGRESSION
48
- model_objective = model.params["objective"]
49
- if model_objective in _BINARY_CLASSIFICATION_OBJECTIVES:
50
- return type_hints.ModelObjective.BINARY_CLASSIFICATION
51
- if model_objective in _MULTI_CLASSIFICATION_OBJECTIVES:
52
- return type_hints.ModelObjective.MULTI_CLASSIFICATION
53
- if model_objective in _RANKING_OBJECTIVES:
54
- return type_hints.ModelObjective.RANKING
55
- if model_objective in _REGRESSION_OBJECTIVES:
56
- return type_hints.ModelObjective.REGRESSION
57
- return type_hints.ModelObjective.UNKNOWN
58
-
59
-
60
- def get_model_objective_xgb(model: Union["xgboost.Booster", "xgboost.XGBModel"]) -> type_hints.ModelObjective:
61
-
62
- import xgboost
83
+ model_task = ""
84
+ if type_utils.LazyType("lightgbm.Booster").isinstance(model):
85
+ model_task = model.params["objective"] # type: ignore[attr-defined]
86
+ elif hasattr(model, "objective_"):
87
+ model_task = model.objective_
88
+ if model_task in _BINARY_CLASSIFICATION_OBJECTIVES:
89
+ return type_hints.Task.TABULAR_BINARY_CLASSIFICATION
90
+ if model_task in _MULTI_CLASSIFICATION_OBJECTIVES:
91
+ return type_hints.Task.TABULAR_MULTI_CLASSIFICATION
92
+ if model_task in _RANKING_OBJECTIVES:
93
+ return type_hints.Task.TABULAR_RANKING
94
+ if model_task in _REGRESSION_OBJECTIVES:
95
+ return type_hints.Task.TABULAR_REGRESSION
96
+ return type_hints.Task.UNKNOWN
97
+
98
+
99
+ def get_model_task_xgb(model: Union["xgboost.Booster", "xgboost.XGBModel"]) -> type_hints.Task:
63
100
 
64
101
  _BINARY_CLASSIFICATION_OBJECTIVE_PREFIX = ["binary:"]
65
102
  _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX = ["multi:"]
66
103
  _RANKING_OBJECTIVE_PREFIX = ["rank:"]
67
104
  _REGRESSION_OBJECTIVE_PREFIX = ["reg:"]
68
105
 
69
- model_objective = ""
70
- if isinstance(model, xgboost.Booster):
71
- model_params = json.loads(model.save_config())
72
- model_objective = model_params.get("learner", {}).get("objective", "")
106
+ model_task = ""
107
+ if type_utils.LazyType("xgboost.Booster").isinstance(model):
108
+ model_params = json.loads(model.save_config()) # type: ignore[attr-defined]
109
+ model_task = model_params.get("learner", {}).get("objective", "")
73
110
  else:
74
111
  if hasattr(model, "get_params"):
75
- model_objective = model.get_params().get("objective", "")
112
+ model_task = model.get_params().get("objective", "")
76
113
 
77
- if isinstance(model_objective, dict):
78
- model_objective = model_objective.get("name", "")
114
+ if isinstance(model_task, dict):
115
+ model_task = model_task.get("name", "")
79
116
  for classification_objective in _BINARY_CLASSIFICATION_OBJECTIVE_PREFIX:
80
- if classification_objective in model_objective:
81
- return type_hints.ModelObjective.BINARY_CLASSIFICATION
117
+ if classification_objective in model_task:
118
+ return type_hints.Task.TABULAR_BINARY_CLASSIFICATION
82
119
  for classification_objective in _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX:
83
- if classification_objective in model_objective:
84
- return type_hints.ModelObjective.MULTI_CLASSIFICATION
120
+ if classification_objective in model_task:
121
+ return type_hints.Task.TABULAR_MULTI_CLASSIFICATION
85
122
  for ranking_objective in _RANKING_OBJECTIVE_PREFIX:
86
- if ranking_objective in model_objective:
87
- return type_hints.ModelObjective.RANKING
123
+ if ranking_objective in model_task:
124
+ return type_hints.Task.TABULAR_RANKING
88
125
  for regression_objective in _REGRESSION_OBJECTIVE_PREFIX:
89
- if regression_objective in model_objective:
90
- return type_hints.ModelObjective.REGRESSION
91
- return type_hints.ModelObjective.UNKNOWN
126
+ if regression_objective in model_task:
127
+ return type_hints.Task.TABULAR_REGRESSION
128
+ return type_hints.Task.UNKNOWN
92
129
 
93
130
 
94
- def get_model_objective_and_output_type(model: Any) -> ModelObjectiveAndOutputType:
95
- import xgboost
131
+ def get_model_task_and_output_type(model: Any) -> ModelTaskAndOutputType:
132
+ if type_utils.LazyType("xgboost.Booster").isinstance(model) or type_utils.LazyType("xgboost.XGBModel").isinstance(
133
+ model
134
+ ):
135
+ task = get_model_task_xgb(model)
136
+ output_type = model_signature.DataType.DOUBLE
137
+ if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
138
+ output_type = model_signature.DataType.STRING
139
+ return ModelTaskAndOutputType(task=task, output_type=output_type)
96
140
 
97
- if isinstance(model, xgboost.Booster) or isinstance(model, xgboost.XGBModel):
98
- model_objective = get_model_objective_xgb(model)
141
+ if type_utils.LazyType("lightgbm.Booster").isinstance(model) or type_utils.LazyType(
142
+ "lightgbm.LGBMModel"
143
+ ).isinstance(model):
144
+ task = get_model_task_lightgbm(model)
99
145
  output_type = model_signature.DataType.DOUBLE
100
- if model_objective == type_hints.ModelObjective.MULTI_CLASSIFICATION:
146
+ if task in [
147
+ type_hints.Task.TABULAR_BINARY_CLASSIFICATION,
148
+ type_hints.Task.TABULAR_MULTI_CLASSIFICATION,
149
+ ]:
101
150
  output_type = model_signature.DataType.STRING
102
- return ModelObjectiveAndOutputType(objective=model_objective, output_type=output_type)
151
+ return ModelTaskAndOutputType(task=task, output_type=output_type)
103
152
 
104
- import lightgbm
153
+ if type_utils.LazyType("catboost.CatBoost").isinstance(model):
154
+ task = get_model_task_catboost(model)
155
+ output_type = model_signature.DataType.DOUBLE
156
+ if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
157
+ output_type = model_signature.DataType.STRING
158
+ return ModelTaskAndOutputType(task=task, output_type=output_type)
105
159
 
106
- if isinstance(model, lightgbm.Booster) or isinstance(model, lightgbm.LGBMModel):
107
- model_objective = get_model_objective_lightgbm(model)
160
+ if type_utils.LazyType("sklearn.base.BaseEstimator").isinstance(model) or type_utils.LazyType(
161
+ "sklearn.pipeline.Pipeline"
162
+ ).isinstance(model):
163
+ task = get_task_skl(model)
108
164
  output_type = model_signature.DataType.DOUBLE
109
- if model_objective in [
110
- type_hints.ModelObjective.BINARY_CLASSIFICATION,
111
- type_hints.ModelObjective.MULTI_CLASSIFICATION,
112
- ]:
165
+ if task == type_hints.Task.TABULAR_MULTI_CLASSIFICATION:
113
166
  output_type = model_signature.DataType.STRING
114
- return ModelObjectiveAndOutputType(objective=model_objective, output_type=output_type)
167
+ return ModelTaskAndOutputType(task=task, output_type=output_type)
115
168
 
116
169
  raise ValueError(f"Model type {type(model)} is not supported")
@@ -2,7 +2,6 @@ import logging
2
2
  import os
3
3
  from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
4
4
 
5
- import cloudpickle
6
5
  import pandas as pd
7
6
  from typing_extensions import TypeGuard, Unpack
8
7
 
@@ -120,9 +119,21 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
120
119
  model_meta.env.include_if_absent(
121
120
  [
122
121
  model_env.ModelDependency(requirement="sentence-transformers", pip_name="sentence-transformers"),
122
+ model_env.ModelDependency(requirement="transformers", pip_name="transformers"),
123
+ model_env.ModelDependency(requirement="pytorch", pip_name="torch"),
123
124
  ],
124
125
  check_local_version=True,
125
126
  )
127
+ model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
128
+
129
+ @staticmethod
130
+ def _get_device_config(**kwargs: Unpack[model_types.SentenceTransformersLoadOptions]) -> Optional[str]:
131
+ if kwargs.get("device", None) is not None:
132
+ return kwargs["device"]
133
+ elif kwargs.get("use_gpu", False):
134
+ return "cuda"
135
+
136
+ return None
126
137
 
127
138
  @classmethod
128
139
  def load_model(
@@ -144,13 +155,9 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
144
155
  model_blob_filename = model_blob_metadata.path
145
156
  model_blob_file_or_dir_path = os.path.join(model_blob_path, model_blob_filename)
146
157
 
147
- if os.path.isdir(model_blob_file_or_dir_path): # if the saved model is a directory
148
- model = sentence_transformers.SentenceTransformer(model_blob_file_or_dir_path)
149
- else:
150
- assert os.path.isfile(model_blob_file_or_dir_path) # if the saved model is a file
151
- with open(model_blob_file_or_dir_path, "rb") as f:
152
- model = cloudpickle.load(f)
153
- assert isinstance(model, sentence_transformers.SentenceTransformer)
158
+ model = sentence_transformers.SentenceTransformer(
159
+ model_blob_file_or_dir_path, device=cls._get_device_config(**kwargs)
160
+ )
154
161
  return model
155
162
 
156
163
  @classmethod
@@ -1,4 +1,5 @@
1
1
  import os
2
+ import warnings
2
3
  from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union, cast, final
3
4
 
4
5
  import cloudpickle
@@ -6,22 +7,21 @@ import numpy as np
6
7
  import pandas as pd
7
8
  from typing_extensions import TypeGuard, Unpack
8
9
 
9
- import snowflake.snowpark.dataframe as sp_df
10
10
  from snowflake.ml._internal import type_utils
11
11
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
12
12
  from snowflake.ml.model._packager.model_env import model_env
13
- from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
13
+ from snowflake.ml.model._packager.model_handlers import (
14
+ _base,
15
+ _utils as handlers_utils,
16
+ model_objective_utils,
17
+ )
14
18
  from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
15
19
  from snowflake.ml.model._packager.model_meta import (
16
20
  model_blob_meta,
17
21
  model_meta as model_meta_api,
18
22
  model_meta_schema,
19
23
  )
20
- from snowflake.ml.model._signatures import (
21
- numpy_handler,
22
- snowpark_handler,
23
- utils as model_signature_utils,
24
- )
24
+ from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
25
25
 
26
26
  if TYPE_CHECKING:
27
27
  import sklearn.base
@@ -40,28 +40,14 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
40
40
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
41
41
  _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
42
42
 
43
- DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"]
44
-
45
- @classmethod
46
- def get_model_objective(
47
- cls, model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]
48
- ) -> model_types.ModelObjective:
49
- import sklearn.pipeline
50
- from sklearn.base import is_classifier, is_regressor
51
-
52
- if isinstance(model, sklearn.pipeline.Pipeline):
53
- return model_types.ModelObjective.UNKNOWN
54
- if is_regressor(model):
55
- return model_types.ModelObjective.REGRESSION
56
- if is_classifier(model):
57
- classes_list = getattr(model, "classes_", [])
58
- num_classes = getattr(model, "n_classes_", None) or len(classes_list)
59
- if isinstance(num_classes, int):
60
- if num_classes > 2:
61
- return model_types.ModelObjective.MULTI_CLASSIFICATION
62
- return model_types.ModelObjective.BINARY_CLASSIFICATION
63
- return model_types.ModelObjective.UNKNOWN
64
- return model_types.ModelObjective.UNKNOWN
43
+ DEFAULT_TARGET_METHODS = [
44
+ "predict",
45
+ "transform",
46
+ "predict_proba",
47
+ "predict_log_proba",
48
+ "decision_function",
49
+ ]
50
+ EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"]
65
51
 
66
52
  @classmethod
67
53
  def can_handle(
@@ -95,18 +81,6 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
95
81
 
96
82
  return cast(Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"], model)
97
83
 
98
- @staticmethod
99
- def get_explainability_supported_background(
100
- sample_input_data: Optional[model_types.SupportedDataType] = None,
101
- ) -> Optional[pd.DataFrame]:
102
- if isinstance(sample_input_data, pd.DataFrame) or isinstance(sample_input_data, sp_df.DataFrame):
103
- return (
104
- sample_input_data
105
- if isinstance(sample_input_data, pd.DataFrame)
106
- else snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sample_input_data)
107
- )
108
- return None
109
-
110
84
  @classmethod
111
85
  def save_model(
112
86
  cls,
@@ -125,23 +99,10 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
125
99
  import sklearn.pipeline
126
100
 
127
101
  assert isinstance(model, sklearn.base.BaseEstimator) or isinstance(model, sklearn.pipeline.Pipeline)
128
-
129
- background_data = cls.get_explainability_supported_background(sample_input_data)
130
-
131
- # if users did not ask then we enable if we have background data
132
- if enable_explainability is None and background_data is not None:
133
- enable_explainability = True
134
102
  if enable_explainability:
135
- # if users set it explicitly but no background data then error out
136
- if background_data is None:
137
- raise ValueError(
138
- "Sample input data is required to enable explainability. Currently we only support this for "
139
- + "`pandas.DataFrame` and `snowflake.snowpark.dataframe.DataFrame`."
140
- )
141
- data_blob_path = os.path.join(model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR)
142
- os.makedirs(data_blob_path, exist_ok=True)
143
- with open(os.path.join(data_blob_path, name + cls.BG_DATA_FILE_SUFFIX), "wb") as f:
144
- background_data.to_parquet(f)
103
+ # if users set it explicitly but no sample_input_data then error out
104
+ if sample_input_data is None:
105
+ raise ValueError("Sample input data is required to enable explainability.")
145
106
 
146
107
  if not is_sub_model:
147
108
  target_methods = handlers_utils.get_target_methods(
@@ -151,7 +112,8 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
151
112
  )
152
113
 
153
114
  def get_prediction(
154
- target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
115
+ target_method_name: str,
116
+ sample_input_data: model_types.SupportedLocalDataType,
155
117
  ) -> model_types.SupportedLocalDataType:
156
118
  if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
157
119
  sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
@@ -169,19 +131,40 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
169
131
  get_prediction_fn=get_prediction,
170
132
  )
171
133
 
172
- model_objective = cls.get_model_objective(model)
173
- model_meta.model_objective = model_objective
134
+ explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
174
135
 
136
+ background_data = handlers_utils.get_explainability_supported_background(
137
+ sample_input_data, model_meta, explain_target_method
138
+ )
139
+
140
+ model_task_and_output_type = model_objective_utils.get_model_task_and_output_type(model)
141
+ model_meta.task = model_task_and_output_type.task
142
+
143
+ # if users did not ask then we enable if we have background data
144
+ if enable_explainability is None:
145
+ if background_data is None:
146
+ warnings.warn(
147
+ "sample_input_data should be provided to enable explainability by default",
148
+ category=UserWarning,
149
+ stacklevel=1,
150
+ )
151
+ enable_explainability = False
152
+ else:
153
+ enable_explainability = True
175
154
  if enable_explainability:
176
- output_type = model_signature.DataType.DOUBLE
155
+ handlers_utils.save_background_data(
156
+ model_blobs_dir_path,
157
+ cls.EXPLAIN_ARTIFACTS_DIR,
158
+ cls.BG_DATA_FILE_SUFFIX,
159
+ name,
160
+ background_data,
161
+ )
177
162
 
178
- if model_objective == model_types.ModelObjective.MULTI_CLASSIFICATION:
179
- output_type = model_signature.DataType.STRING
180
163
  model_meta = handlers_utils.add_explain_method_signature(
181
164
  model_meta=model_meta,
182
165
  explain_method="explain",
183
- target_method="predict",
184
- output_return_type=output_type,
166
+ target_method=explain_target_method,
167
+ output_return_type=model_task_and_output_type.output_type,
185
168
  )
186
169
 
187
170
  model_blob_path = os.path.join(model_blobs_dir_path, name)
@@ -202,7 +185,8 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
202
185
  model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
203
186
 
204
187
  model_meta.env.include_if_absent(
205
- [model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn")], check_local_version=True
188
+ [model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn")],
189
+ check_local_version=True,
206
190
  )
207
191
 
208
192
  @classmethod
@@ -43,6 +43,8 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
43
43
  _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
44
44
 
45
45
  DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"]
46
+ EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"]
47
+
46
48
  IS_AUTO_SIGNATURE = True
47
49
 
48
50
  @classmethod
@@ -71,13 +73,14 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
71
73
 
72
74
  @classmethod
73
75
  def _get_local_version_package(cls, pkg_name: str) -> Optional[version.Version]:
74
- import importlib_metadata
76
+ from importlib import metadata as importlib_metadata
77
+
75
78
  from packaging import version
76
79
 
77
80
  local_version = None
78
81
 
79
82
  try:
80
- local_dist = importlib_metadata.distribution(pkg_name) # type: ignore[no-untyped-call]
83
+ local_dist = importlib_metadata.distribution(pkg_name)
81
84
  local_version = version.parse(local_dist.version)
82
85
  except importlib_metadata.PackageNotFoundError:
83
86
  pass
@@ -104,7 +107,13 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
104
107
  def _get_supported_object_for_explainability(
105
108
  cls, estimator: "BaseEstimator", enable_explainability: Optional[bool]
106
109
  ) -> Any:
107
- methods = ["to_xgboost", "to_lightgbm"]
110
+ from snowflake.ml.modeling import pipeline as snowml_pipeline
111
+
112
+ # handle pipeline objects separately
113
+ if isinstance(estimator, snowml_pipeline.Pipeline): # type: ignore[attr-defined]
114
+ return None
115
+
116
+ methods = ["to_xgboost", "to_lightgbm", "to_sklearn"]
108
117
  for method_name in methods:
109
118
  if hasattr(estimator, method_name):
110
119
  try:
@@ -136,9 +145,9 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
136
145
  # Pipeline is inherited from BaseEstimator, so no need to add one more check
137
146
 
138
147
  if not is_sub_model:
139
- if sample_input_data is not None or model_meta.signatures:
148
+ if model_meta.signatures:
140
149
  warnings.warn(
141
- "Inferring model signature from sample input or providing model signature for Snowpark ML "
150
+ "Providing model signature for Snowpark ML "
142
151
  + "Modeling model is not required. Model signature will automatically be inferred during fitting. ",
143
152
  UserWarning,
144
153
  stacklevel=2,
@@ -162,22 +171,31 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
162
171
  python_base_obj = cls._get_supported_object_for_explainability(model, enable_explainability)
163
172
  if python_base_obj is None:
164
173
  if enable_explainability: # if user set enable_explainability to True, throw error else silently skip
165
- raise ValueError("Explain only support for xgboost or lightgbm Snowpark ML models.")
174
+ raise ValueError(
175
+ "Explain only supported for xgboost, lightgbm and sklearn (not pipeline) Snowpark ML models."
176
+ )
166
177
  # set None to False so we don't include shap in the environment
167
178
  enable_explainability = False
168
179
  else:
169
- model_objective_and_output_type = model_objective_utils.get_model_objective_and_output_type(
170
- python_base_obj
171
- )
172
- model_meta.model_objective = model_objective_and_output_type.objective
180
+ model_task_and_output_type = model_objective_utils.get_model_task_and_output_type(python_base_obj)
181
+ model_meta.task = model_task_and_output_type.task
182
+ explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
173
183
  model_meta = handlers_utils.add_explain_method_signature(
174
184
  model_meta=model_meta,
175
185
  explain_method="explain",
176
- target_method="predict",
177
- output_return_type=model_objective_and_output_type.output_type,
186
+ target_method=explain_target_method,
187
+ output_return_type=model_task_and_output_type.output_type,
178
188
  )
179
189
  enable_explainability = True
180
190
 
191
+ background_data = handlers_utils.get_explainability_supported_background(
192
+ sample_input_data, model_meta, explain_target_method
193
+ )
194
+ if background_data is not None:
195
+ handlers_utils.save_background_data(
196
+ model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
197
+ )
198
+
181
199
  model_blob_path = os.path.join(model_blobs_dir_path, name)
182
200
  os.makedirs(model_blob_path, exist_ok=True)
183
201
  with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
@@ -258,6 +276,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
258
276
  raw_model: "BaseEstimator",
259
277
  signature: model_signature.ModelSignature,
260
278
  target_method: str,
279
+ background_data: Optional[pd.DataFrame] = None,
261
280
  ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
262
281
  @custom_model.inference_api
263
282
  def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
@@ -276,16 +295,16 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
276
295
  def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
277
296
  import shap
278
297
 
279
- methods = ["to_xgboost", "to_lightgbm"]
298
+ methods = ["to_xgboost", "to_lightgbm", "to_sklearn"]
280
299
  for method_name in methods:
281
300
  try:
282
301
  base_model = getattr(raw_model, method_name)()
283
- explainer = shap.TreeExplainer(base_model)
284
- df = pd.DataFrame(explainer(X).values)
302
+ explainer = shap.Explainer(base_model, masker=background_data)
303
+ df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
285
304
  return model_signature_utils.rename_pandas_df(df, signature.outputs)
286
305
  except exceptions.SnowflakeMLException:
287
306
  pass # Do nothing and continue to the next method
288
- raise ValueError("The model must be an xgboost or lightgbm estimator.")
307
+ raise ValueError("The model must be an xgboost, lightgbm or sklearn (not pipeline) estimator.")
289
308
 
290
309
  if target_method == "explain":
291
310
  return explain_fn
@@ -294,7 +313,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
294
313
 
295
314
  type_method_dict = {}
296
315
  for target_method_name, sig in model_meta.signatures.items():
297
- type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
316
+ type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name, background_data)
298
317
 
299
318
  _SnowMLModel = type(
300
319
  "_SnowMLModel",
@@ -1,6 +1,7 @@
1
1
  # mypy: disable-error-code="import"
2
2
  import os
3
3
  import warnings
4
+ from importlib import metadata as importlib_metadata
4
5
  from typing import (
5
6
  TYPE_CHECKING,
6
7
  Any,
@@ -13,7 +14,6 @@ from typing import (
13
14
  final,
14
15
  )
15
16
 
16
- import importlib_metadata
17
17
  import numpy as np
18
18
  import pandas as pd
19
19
  from packaging import version
@@ -53,6 +53,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
53
53
 
54
54
  MODEL_BLOB_FILE_OR_DIR = "model.ubj"
55
55
  DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
56
+ EXPLAIN_TARGET_METHODS = ["predict", "predict_proba"]
56
57
 
57
58
  @classmethod
58
59
  def can_handle(
@@ -96,7 +97,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
96
97
  local_xgb_version = None
97
98
 
98
99
  try:
99
- local_dist = importlib_metadata.distribution("xgboost") # type: ignore[no-untyped-call]
100
+ local_dist = importlib_metadata.distribution("xgboost")
100
101
  local_xgb_version = version.parse(local_dist.version)
101
102
  except importlib_metadata.PackageNotFoundError:
102
103
  pass
@@ -138,21 +139,35 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
138
139
  sample_input_data=sample_input_data,
139
140
  get_prediction_fn=get_prediction,
140
141
  )
141
- model_objective_and_output = model_objective_utils.get_model_objective_and_output_type(model)
142
- model_meta.model_objective = handlers_utils.validate_model_objective(
143
- model_meta.model_objective, model_objective_and_output.objective
144
- )
142
+ model_task_and_output = model_objective_utils.get_model_task_and_output_type(model)
143
+ model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
145
144
  if enable_explainability:
146
145
  model_meta = handlers_utils.add_explain_method_signature(
147
146
  model_meta=model_meta,
148
147
  explain_method="explain",
149
148
  target_method="predict",
150
- output_return_type=model_objective_and_output.output_type,
149
+ output_return_type=model_task_and_output.output_type,
151
150
  )
152
151
  model_meta.function_properties = {
153
152
  "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
154
153
  }
155
154
 
155
+ explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
156
+
157
+ background_data = handlers_utils.get_explainability_supported_background(
158
+ sample_input_data, model_meta, explain_target_method
159
+ )
160
+ if background_data is not None:
161
+ handlers_utils.save_background_data(
162
+ model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
163
+ )
164
+ else:
165
+ warnings.warn(
166
+ "sample_input_data should be provided for better explainability results",
167
+ category=UserWarning,
168
+ stacklevel=1,
169
+ )
170
+
156
171
  model_blob_path = os.path.join(model_blobs_dir_path, name)
157
172
  os.makedirs(model_blob_path, exist_ok=True)
158
173
  model.save_model(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))