snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (284) hide show
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,5 @@
1
1
  import os
2
+ import warnings
2
3
  from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union, cast, final
3
4
 
4
5
  import cloudpickle
@@ -6,22 +7,21 @@ import numpy as np
6
7
  import pandas as pd
7
8
  from typing_extensions import TypeGuard, Unpack
8
9
 
9
- import snowflake.snowpark.dataframe as sp_df
10
10
  from snowflake.ml._internal import type_utils
11
11
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
12
12
  from snowflake.ml.model._packager.model_env import model_env
13
- from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
13
+ from snowflake.ml.model._packager.model_handlers import (
14
+ _base,
15
+ _utils as handlers_utils,
16
+ model_objective_utils,
17
+ )
14
18
  from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
15
19
  from snowflake.ml.model._packager.model_meta import (
16
20
  model_blob_meta,
17
21
  model_meta as model_meta_api,
18
22
  model_meta_schema,
19
23
  )
20
- from snowflake.ml.model._signatures import (
21
- numpy_handler,
22
- snowpark_handler,
23
- utils as model_signature_utils,
24
- )
24
+ from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
25
25
 
26
26
  if TYPE_CHECKING:
27
27
  import sklearn.base
@@ -40,28 +40,14 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
40
40
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
41
41
  _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
42
42
 
43
- DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"]
44
-
45
- @classmethod
46
- def get_model_objective(
47
- cls, model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]
48
- ) -> model_meta_schema.ModelObjective:
49
- import sklearn.pipeline
50
- from sklearn.base import is_classifier, is_regressor
51
-
52
- if isinstance(model, sklearn.pipeline.Pipeline):
53
- return model_meta_schema.ModelObjective.UNKNOWN
54
- if is_regressor(model):
55
- return model_meta_schema.ModelObjective.REGRESSION
56
- if is_classifier(model):
57
- classes_list = getattr(model, "classes_", [])
58
- num_classes = getattr(model, "n_classes_", None) or len(classes_list)
59
- if isinstance(num_classes, int):
60
- if num_classes > 2:
61
- return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
62
- return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
63
- return model_meta_schema.ModelObjective.UNKNOWN
64
- return model_meta_schema.ModelObjective.UNKNOWN
43
+ DEFAULT_TARGET_METHODS = [
44
+ "predict",
45
+ "transform",
46
+ "predict_proba",
47
+ "predict_log_proba",
48
+ "decision_function",
49
+ ]
50
+ EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"]
65
51
 
66
52
  @classmethod
67
53
  def can_handle(
@@ -106,32 +92,17 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
106
92
  is_sub_model: Optional[bool] = False,
107
93
  **kwargs: Unpack[model_types.SKLModelSaveOptions],
108
94
  ) -> None:
109
- enable_explainability = kwargs.get("enable_explainability", False)
95
+ # setting None by default to distinguish if users did not set it
96
+ enable_explainability = kwargs.get("enable_explainability", None)
110
97
 
111
98
  import sklearn.base
112
99
  import sklearn.pipeline
113
100
 
114
101
  assert isinstance(model, sklearn.base.BaseEstimator) or isinstance(model, sklearn.pipeline.Pipeline)
115
-
116
- enable_explainability = kwargs.get("enable_explainability", False)
117
102
  if enable_explainability:
118
- # TODO: Currently limited to pandas df, need to extend to other types.
119
- if sample_input_data is None or not (
120
- isinstance(sample_input_data, pd.DataFrame) or isinstance(sample_input_data, sp_df.DataFrame)
121
- ):
122
- raise ValueError(
123
- "Sample input data is required to enable explainability. Currently we only support this for "
124
- + "`pandas.DataFrame` and `snowflake.snowpark.dataframe.DataFrame`."
125
- )
126
- sample_input_data_pandas = (
127
- sample_input_data
128
- if isinstance(sample_input_data, pd.DataFrame)
129
- else snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sample_input_data)
130
- )
131
- data_blob_path = os.path.join(model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR)
132
- os.makedirs(data_blob_path, exist_ok=True)
133
- with open(os.path.join(data_blob_path, name + cls.BG_DATA_FILE_SUFFIX), "wb") as f:
134
- sample_input_data_pandas.to_parquet(f)
103
+ # if users set it explicitly but no sample_input_data then error out
104
+ if sample_input_data is None:
105
+ raise ValueError("Sample input data is required to enable explainability.")
135
106
 
136
107
  if not is_sub_model:
137
108
  target_methods = handlers_utils.get_target_methods(
@@ -141,7 +112,8 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
141
112
  )
142
113
 
143
114
  def get_prediction(
144
- target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
115
+ target_method_name: str,
116
+ sample_input_data: model_types.SupportedLocalDataType,
145
117
  ) -> model_types.SupportedLocalDataType:
146
118
  if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
147
119
  sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
@@ -159,15 +131,40 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
159
131
  get_prediction_fn=get_prediction,
160
132
  )
161
133
 
134
+ explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
135
+
136
+ background_data = handlers_utils.get_explainability_supported_background(
137
+ sample_input_data, model_meta, explain_target_method
138
+ )
139
+
140
+ model_task_and_output_type = model_objective_utils.get_model_task_and_output_type(model)
141
+ model_meta.task = model_task_and_output_type.task
142
+
143
+ # if users did not ask then we enable if we have background data
144
+ if enable_explainability is None:
145
+ if background_data is None:
146
+ warnings.warn(
147
+ "sample_input_data should be provided to enable explainability by default",
148
+ category=UserWarning,
149
+ stacklevel=1,
150
+ )
151
+ enable_explainability = False
152
+ else:
153
+ enable_explainability = True
162
154
  if enable_explainability:
163
- output_type = model_signature.DataType.DOUBLE
164
- if cls.get_model_objective(model) == model_meta_schema.ModelObjective.MULTI_CLASSIFICATION:
165
- output_type = model_signature.DataType.STRING
155
+ handlers_utils.save_background_data(
156
+ model_blobs_dir_path,
157
+ cls.EXPLAIN_ARTIFACTS_DIR,
158
+ cls.BG_DATA_FILE_SUFFIX,
159
+ name,
160
+ background_data,
161
+ )
162
+
166
163
  model_meta = handlers_utils.add_explain_method_signature(
167
164
  model_meta=model_meta,
168
165
  explain_method="explain",
169
- target_method="predict",
170
- output_return_type=output_type,
166
+ target_method=explain_target_method,
167
+ output_return_type=model_task_and_output_type.output_type,
171
168
  )
172
169
 
173
170
  model_blob_path = os.path.join(model_blobs_dir_path, name)
@@ -184,13 +181,12 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
184
181
  model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
185
182
 
186
183
  if enable_explainability:
187
- model_meta.env.include_if_absent(
188
- [model_env.ModelDependency(requirement="shap", pip_name="shap")],
189
- check_local_version=True,
190
- )
184
+ model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
185
+ model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
191
186
 
192
187
  model_meta.env.include_if_absent(
193
- [model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn")], check_local_version=True
188
+ [model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn")],
189
+ check_local_version=True,
194
190
  )
195
191
 
196
192
  @classmethod
@@ -1,20 +1,27 @@
1
1
  import os
2
2
  import warnings
3
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
3
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, final
4
4
 
5
5
  import cloudpickle
6
6
  import numpy as np
7
7
  import pandas as pd
8
+ from packaging import version
8
9
  from typing_extensions import TypeGuard, Unpack
9
10
 
10
11
  from snowflake.ml._internal import type_utils
12
+ from snowflake.ml._internal.exceptions import exceptions
11
13
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
12
14
  from snowflake.ml.model._packager.model_env import model_env
13
- from snowflake.ml.model._packager.model_handlers import _base
15
+ from snowflake.ml.model._packager.model_handlers import (
16
+ _base,
17
+ _utils as handlers_utils,
18
+ model_objective_utils,
19
+ )
14
20
  from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
15
21
  from snowflake.ml.model._packager.model_meta import (
16
22
  model_blob_meta,
17
23
  model_meta as model_meta_api,
24
+ model_meta_schema,
18
25
  )
19
26
  from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
20
27
 
@@ -36,6 +43,8 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
36
43
  _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
37
44
 
38
45
  DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"]
46
+ EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"]
47
+
39
48
  IS_AUTO_SIGNATURE = True
40
49
 
41
50
  @classmethod
@@ -62,6 +71,60 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
62
71
 
63
72
  return cast("BaseEstimator", model)
64
73
 
74
+ @classmethod
75
+ def _get_local_version_package(cls, pkg_name: str) -> Optional[version.Version]:
76
+ from importlib import metadata as importlib_metadata
77
+
78
+ from packaging import version
79
+
80
+ local_version = None
81
+
82
+ try:
83
+ local_dist = importlib_metadata.distribution(pkg_name)
84
+ local_version = version.parse(local_dist.version)
85
+ except importlib_metadata.PackageNotFoundError:
86
+ pass
87
+
88
+ return local_version
89
+
90
+ @classmethod
91
+ def _can_support_xgb(cls, enable_explainability: Optional[bool]) -> bool:
92
+
93
+ local_xgb_version = cls._get_local_version_package("xgboost")
94
+
95
+ if local_xgb_version and local_xgb_version >= version.parse("2.1.0"):
96
+ if enable_explainability:
97
+ warnings.warn(
98
+ f"This version of xgboost {local_xgb_version} does not work with shap 0.42.1."
99
+ + "If you want model explanations, lower the xgboost version to <2.1.0.",
100
+ category=UserWarning,
101
+ stacklevel=1,
102
+ )
103
+ return False
104
+ return True
105
+
106
+ @classmethod
107
+ def _get_supported_object_for_explainability(
108
+ cls, estimator: "BaseEstimator", enable_explainability: Optional[bool]
109
+ ) -> Any:
110
+ from snowflake.ml.modeling import pipeline as snowml_pipeline
111
+
112
+ # handle pipeline objects separately
113
+ if isinstance(estimator, snowml_pipeline.Pipeline): # type: ignore[attr-defined]
114
+ return None
115
+
116
+ methods = ["to_xgboost", "to_lightgbm", "to_sklearn"]
117
+ for method_name in methods:
118
+ if hasattr(estimator, method_name):
119
+ try:
120
+ result = getattr(estimator, method_name)()
121
+ if method_name == "to_xgboost" and not cls._can_support_xgb(enable_explainability):
122
+ return None
123
+ return result
124
+ except exceptions.SnowflakeMLException:
125
+ pass # Do nothing and continue to the next method
126
+ return None
127
+
65
128
  @classmethod
66
129
  def save_model(
67
130
  cls,
@@ -73,9 +136,8 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
73
136
  is_sub_model: Optional[bool] = False,
74
137
  **kwargs: Unpack[model_types.SNOWModelSaveOptions],
75
138
  ) -> None:
76
- enable_explainability = kwargs.get("enable_explainability", False)
77
- if enable_explainability:
78
- raise NotImplementedError("Explainability is not supported for Snowpark ML model.")
139
+
140
+ enable_explainability = kwargs.get("enable_explainability", None)
79
141
 
80
142
  from snowflake.ml.modeling.framework.base import BaseEstimator
81
143
 
@@ -83,9 +145,9 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
83
145
  # Pipeline is inherited from BaseEstimator, so no need to add one more check
84
146
 
85
147
  if not is_sub_model:
86
- if sample_input_data is not None or model_meta.signatures:
148
+ if model_meta.signatures:
87
149
  warnings.warn(
88
- "Inferring model signature from sample input or providing model signature for Snowpark ML "
150
+ "Providing model signature for Snowpark ML "
89
151
  + "Modeling model is not required. Model signature will automatically be inferred during fitting. ",
90
152
  UserWarning,
91
153
  stacklevel=2,
@@ -105,6 +167,35 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
105
167
  raise ValueError(f"Target method {method_name} does not exist in the model.")
106
168
  model_meta.signatures = temp_model_signature_dict
107
169
 
170
+ if enable_explainability or enable_explainability is None:
171
+ python_base_obj = cls._get_supported_object_for_explainability(model, enable_explainability)
172
+ if python_base_obj is None:
173
+ if enable_explainability: # if user set enable_explainability to True, throw error else silently skip
174
+ raise ValueError(
175
+ "Explain only supported for xgboost, lightgbm and sklearn (not pipeline) Snowpark ML models."
176
+ )
177
+ # set None to False so we don't include shap in the environment
178
+ enable_explainability = False
179
+ else:
180
+ model_task_and_output_type = model_objective_utils.get_model_task_and_output_type(python_base_obj)
181
+ model_meta.task = model_task_and_output_type.task
182
+ explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
183
+ model_meta = handlers_utils.add_explain_method_signature(
184
+ model_meta=model_meta,
185
+ explain_method="explain",
186
+ target_method=explain_target_method,
187
+ output_return_type=model_task_and_output_type.output_type,
188
+ )
189
+ enable_explainability = True
190
+
191
+ background_data = handlers_utils.get_explainability_supported_background(
192
+ sample_input_data, model_meta, explain_target_method
193
+ )
194
+ if background_data is not None:
195
+ handlers_utils.save_background_data(
196
+ model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
197
+ )
198
+
108
199
  model_blob_path = os.path.join(model_blobs_dir_path, name)
109
200
  os.makedirs(model_blob_path, exist_ok=True)
110
201
  with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
@@ -122,7 +213,29 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
122
213
  model_dependencies = model._get_dependencies()
123
214
  for dep in model_dependencies:
124
215
  pkg_name = dep.split("==")[0]
125
- _include_if_absent_pkgs.append(model_env.ModelDependency(requirement=pkg_name, pip_name=pkg_name))
216
+ if pkg_name != "xgboost":
217
+ _include_if_absent_pkgs.append(model_env.ModelDependency(requirement=pkg_name, pip_name=pkg_name))
218
+ continue
219
+
220
+ local_xgb_version = cls._get_local_version_package("xgboost")
221
+ if local_xgb_version and local_xgb_version >= version.parse("2.0.0") and enable_explainability:
222
+ model_meta.env.include_if_absent(
223
+ [
224
+ model_env.ModelDependency(requirement="xgboost==2.0.*", pip_name="xgboost"),
225
+ ],
226
+ check_local_version=False,
227
+ )
228
+ else:
229
+ model_meta.env.include_if_absent(
230
+ [
231
+ model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
232
+ ],
233
+ check_local_version=True,
234
+ )
235
+
236
+ if enable_explainability:
237
+ model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
238
+ model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
126
239
  model_meta.env.include_if_absent(_include_if_absent_pkgs, check_local_version=True)
127
240
 
128
241
  @classmethod
@@ -163,6 +276,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
163
276
  raw_model: "BaseEstimator",
164
277
  signature: model_signature.ModelSignature,
165
278
  target_method: str,
279
+ background_data: Optional[pd.DataFrame] = None,
166
280
  ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
167
281
  @custom_model.inference_api
168
282
  def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
@@ -177,11 +291,29 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
177
291
 
178
292
  return model_signature_utils.rename_pandas_df(df, signature.outputs)
179
293
 
294
@custom_model.inference_api
def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
    """Compute explanations for ``X`` from the wrapped estimator using shap.

    The converters ``to_xgboost``, ``to_lightgbm`` and ``to_sklearn`` are tried
    in that order on the enclosing ``raw_model``. The first attempt that does
    not raise ``SnowflakeMLException`` produces the result, using the enclosing
    ``background_data`` as the shap masker.

    Args:
        X: Input rows to explain.

    Returns:
        The explanation values as a DataFrame renamed to ``signature.outputs``.

    Raises:
        ValueError: If every conversion attempt raised ``SnowflakeMLException``.
    """
    import shap

    for converter_name in ("to_xgboost", "to_lightgbm", "to_sklearn"):
        try:
            native_model = getattr(raw_model, converter_name)()
            shap_explainer = shap.Explainer(native_model, masker=background_data)
            explanation_values = shap_explainer(X).values
            explanations_df = handlers_utils.convert_explanations_to_2D_df(raw_model, explanation_values)
            return model_signature_utils.rename_pandas_df(explanations_df, signature.outputs)
        except exceptions.SnowflakeMLException:
            # This converter does not apply to the model; try the next one.
            continue
    raise ValueError("The model must be an xgboost, lightgbm or sklearn (not pipeline) estimator.")
308
+
309
+ if target_method == "explain":
310
+ return explain_fn
311
+
180
312
  return fn
181
313
 
182
314
  type_method_dict = {}
183
315
  for target_method_name, sig in model_meta.signatures.items():
184
- type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
316
+ type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name, background_data)
185
317
 
186
318
  _SnowMLModel = type(
187
319
  "_SnowMLModel",
@@ -111,7 +111,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
111
111
  model_blob_path = os.path.join(model_blobs_dir_path, name)
112
112
  os.makedirs(model_blob_path, exist_ok=True)
113
113
  with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
114
- torch.jit.save(model, f) # type:ignore[attr-defined]
114
+ torch.jit.save(model, f) # type:ignore[no-untyped-call, attr-defined]
115
115
  base_meta = model_blob_meta.ModelBlobMeta(
116
116
  name=name,
117
117
  model_type=cls.HANDLER_TYPE,
@@ -141,7 +141,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
141
141
  model_blob_metadata = model_blobs_metadata[name]
142
142
  model_blob_filename = model_blob_metadata.path
143
143
  with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
144
- m = torch.jit.load( # type:ignore[attr-defined]
144
+ m = torch.jit.load( # type:ignore[no-untyped-call, attr-defined]
145
145
  f, map_location="cuda" if kwargs.get("use_gpu", False) else "cpu"
146
146
  )
147
147
  assert isinstance(m, torch.jit.ScriptModule) # type:ignore[attr-defined]
@@ -1,6 +1,7 @@
1
1
  # mypy: disable-error-code="import"
2
- import json
3
2
  import os
3
+ import warnings
4
+ from importlib import metadata as importlib_metadata
4
5
  from typing import (
5
6
  TYPE_CHECKING,
6
7
  Any,
@@ -15,12 +16,17 @@ from typing import (
15
16
 
16
17
  import numpy as np
17
18
  import pandas as pd
19
+ from packaging import version
18
20
  from typing_extensions import TypeGuard, Unpack
19
21
 
20
22
  from snowflake.ml._internal import type_utils
21
23
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
22
24
  from snowflake.ml.model._packager.model_env import model_env
23
- from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
25
+ from snowflake.ml.model._packager.model_handlers import (
26
+ _base,
27
+ _utils as handlers_utils,
28
+ model_objective_utils,
29
+ )
24
30
  from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
25
31
  from snowflake.ml.model._packager.model_meta import (
26
32
  model_blob_meta,
@@ -47,41 +53,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
47
53
 
48
54
  MODEL_BLOB_FILE_OR_DIR = "model.ubj"
49
55
  DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
50
- _BINARY_CLASSIFICATION_OBJECTIVE_PREFIX = ["binary:"]
51
- _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX = ["multi:"]
52
- _RANKING_OBJECTIVE_PREFIX = ["rank:"]
53
- _REGRESSION_OBJECTIVE_PREFIX = ["reg:"]
54
-
55
- @classmethod
56
- def get_model_objective(
57
- cls, model: Union["xgboost.Booster", "xgboost.XGBModel"]
58
- ) -> model_meta_schema.ModelObjective:
59
- import xgboost
60
-
61
- if isinstance(model, xgboost.XGBClassifier) or isinstance(model, xgboost.XGBRFClassifier):
62
- num_classes = handlers_utils.get_num_classes_if_exists(model)
63
- if num_classes == 2:
64
- return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
65
- return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
66
- if isinstance(model, xgboost.XGBRegressor) or isinstance(model, xgboost.XGBRFRegressor):
67
- return model_meta_schema.ModelObjective.REGRESSION
68
- if isinstance(model, xgboost.XGBRanker):
69
- return model_meta_schema.ModelObjective.RANKING
70
- model_params = json.loads(model.save_config())
71
- model_objective = model_params["learner"]["objective"]
72
- for classification_objective in cls._BINARY_CLASSIFICATION_OBJECTIVE_PREFIX:
73
- if classification_objective in model_objective:
74
- return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
75
- for classification_objective in cls._MULTI_CLASSIFICATION_OBJECTIVE_PREFIX:
76
- if classification_objective in model_objective:
77
- return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
78
- for ranking_objective in cls._RANKING_OBJECTIVE_PREFIX:
79
- if ranking_objective in model_objective:
80
- return model_meta_schema.ModelObjective.RANKING
81
- for regression_objective in cls._REGRESSION_OBJECTIVE_PREFIX:
82
- if regression_objective in model_objective:
83
- return model_meta_schema.ModelObjective.REGRESSION
84
- return model_meta_schema.ModelObjective.UNKNOWN
56
+ EXPLAIN_TARGET_METHODS = ["predict", "predict_proba"]
85
57
 
86
58
  @classmethod
87
59
  def can_handle(
@@ -116,10 +88,29 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
116
88
  is_sub_model: Optional[bool] = False,
117
89
  **kwargs: Unpack[model_types.XGBModelSaveOptions],
118
90
  ) -> None:
91
+ enable_explainability = kwargs.get("enable_explainability", True)
92
+
119
93
  import xgboost
120
94
 
121
95
  assert isinstance(model, xgboost.Booster) or isinstance(model, xgboost.XGBModel)
122
96
 
97
+ local_xgb_version = None
98
+
99
+ try:
100
+ local_dist = importlib_metadata.distribution("xgboost")
101
+ local_xgb_version = version.parse(local_dist.version)
102
+ except importlib_metadata.PackageNotFoundError:
103
+ pass
104
+
105
+ if local_xgb_version and local_xgb_version >= version.parse("2.1.0") and enable_explainability:
106
+ warnings.warn(
107
+ f"This version of xgboost {local_xgb_version} does not work with shap 0.42.1."
108
+ + "If you want model explanations, lower the xgboost version to <2.1.0.",
109
+ category=UserWarning,
110
+ stacklevel=1,
111
+ )
112
+ enable_explainability = False
113
+
123
114
  if not is_sub_model:
124
115
  target_methods = handlers_utils.get_target_methods(
125
116
  model=model,
@@ -148,22 +139,35 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
148
139
  sample_input_data=sample_input_data,
149
140
  get_prediction_fn=get_prediction,
150
141
  )
151
- model_objective = cls.get_model_objective(model)
152
- model_meta.model_objective = model_objective
153
- if kwargs.get("enable_explainability", True):
154
- output_type = model_signature.DataType.DOUBLE
155
- if model_objective == model_meta_schema.ModelObjective.MULTI_CLASSIFICATION:
156
- output_type = model_signature.DataType.STRING
142
+ model_task_and_output = model_objective_utils.get_model_task_and_output_type(model)
143
+ model_meta.task = handlers_utils.validate_model_task(model_meta.task, model_task_and_output.task)
144
+ if enable_explainability:
157
145
  model_meta = handlers_utils.add_explain_method_signature(
158
146
  model_meta=model_meta,
159
147
  explain_method="explain",
160
148
  target_method="predict",
161
- output_return_type=output_type,
149
+ output_return_type=model_task_and_output.output_type,
162
150
  )
163
151
  model_meta.function_properties = {
164
152
  "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
165
153
  }
166
154
 
155
+ explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
156
+
157
+ background_data = handlers_utils.get_explainability_supported_background(
158
+ sample_input_data, model_meta, explain_target_method
159
+ )
160
+ if background_data is not None:
161
+ handlers_utils.save_background_data(
162
+ model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, cls.BG_DATA_FILE_SUFFIX, name, background_data
163
+ )
164
+ else:
165
+ warnings.warn(
166
+ "sample_input_data should be provided for better explainability results",
167
+ category=UserWarning,
168
+ stacklevel=1,
169
+ )
170
+
167
171
  model_blob_path = os.path.join(model_blobs_dir_path, name)
168
172
  os.makedirs(model_blob_path, exist_ok=True)
169
173
  model.save_model(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
@@ -180,15 +184,26 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
180
184
  model_meta.env.include_if_absent(
181
185
  [
182
186
  model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn"),
183
- model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
184
187
  ],
185
188
  check_local_version=True,
186
189
  )
187
- if kwargs.get("enable_explainability", True):
190
+ if local_xgb_version and local_xgb_version >= version.parse("2.0.0") and enable_explainability:
191
+ model_meta.env.include_if_absent(
192
+ [
193
+ model_env.ModelDependency(requirement="xgboost==2.0.*", pip_name="xgboost"),
194
+ ],
195
+ check_local_version=False,
196
+ )
197
+ else:
188
198
  model_meta.env.include_if_absent(
189
- [model_env.ModelDependency(requirement="shap", pip_name="shap")],
199
+ [
200
+ model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
201
+ ],
190
202
  check_local_version=True,
191
203
  )
204
+
205
+ if enable_explainability:
206
+ model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
192
207
  model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
193
208
  model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
194
209
 
@@ -269,7 +284,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
269
284
  import shap
270
285
 
271
286
  explainer = shap.TreeExplainer(raw_model)
272
- df = pd.DataFrame(explainer(X).values)
287
+ df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
273
288
  return model_signature_utils.rename_pandas_df(df, signature.outputs)
274
289
 
275
290
  if target_method == "explain":