snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252) hide show
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -1,20 +1,27 @@
1
1
  import os
2
2
  import warnings
3
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
3
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, final
4
4
 
5
5
  import cloudpickle
6
6
  import numpy as np
7
7
  import pandas as pd
8
+ from packaging import version
8
9
  from typing_extensions import TypeGuard, Unpack
9
10
 
10
11
  from snowflake.ml._internal import type_utils
12
+ from snowflake.ml._internal.exceptions import exceptions
11
13
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
12
14
  from snowflake.ml.model._packager.model_env import model_env
13
- from snowflake.ml.model._packager.model_handlers import _base
15
+ from snowflake.ml.model._packager.model_handlers import (
16
+ _base,
17
+ _utils as handlers_utils,
18
+ model_objective_utils,
19
+ )
14
20
  from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
15
21
  from snowflake.ml.model._packager.model_meta import (
16
22
  model_blob_meta,
17
23
  model_meta as model_meta_api,
24
+ model_meta_schema,
18
25
  )
19
26
  from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
20
27
 
@@ -62,6 +69,53 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
62
69
 
63
70
  return cast("BaseEstimator", model)
64
71
 
72
+ @classmethod
73
+ def _get_local_version_package(cls, pkg_name: str) -> Optional[version.Version]:
74
+ import importlib_metadata
75
+ from packaging import version
76
+
77
+ local_version = None
78
+
79
+ try:
80
+ local_dist = importlib_metadata.distribution(pkg_name) # type: ignore[no-untyped-call]
81
+ local_version = version.parse(local_dist.version)
82
+ except importlib_metadata.PackageNotFoundError:
83
+ pass
84
+
85
+ return local_version
86
+
87
+ @classmethod
88
+ def _can_support_xgb(cls, enable_explainability: Optional[bool]) -> bool:
89
+
90
+ local_xgb_version = cls._get_local_version_package("xgboost")
91
+
92
+ if local_xgb_version and local_xgb_version >= version.parse("2.1.0"):
93
+ if enable_explainability:
94
+ warnings.warn(
95
+ f"This version of xgboost {local_xgb_version} does not work with shap 0.42.1."
96
+ + "If you want model explanations, lower the xgboost version to <2.1.0.",
97
+ category=UserWarning,
98
+ stacklevel=1,
99
+ )
100
+ return False
101
+ return True
102
+
103
+ @classmethod
104
+ def _get_supported_object_for_explainability(
105
+ cls, estimator: "BaseEstimator", enable_explainability: Optional[bool]
106
+ ) -> Any:
107
+ methods = ["to_xgboost", "to_lightgbm"]
108
+ for method_name in methods:
109
+ if hasattr(estimator, method_name):
110
+ try:
111
+ result = getattr(estimator, method_name)()
112
+ if method_name == "to_xgboost" and not cls._can_support_xgb(enable_explainability):
113
+ return None
114
+ return result
115
+ except exceptions.SnowflakeMLException:
116
+ pass # Do nothing and continue to the next method
117
+ return None
118
+
65
119
  @classmethod
66
120
  def save_model(
67
121
  cls,
@@ -73,6 +127,9 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
73
127
  is_sub_model: Optional[bool] = False,
74
128
  **kwargs: Unpack[model_types.SNOWModelSaveOptions],
75
129
  ) -> None:
130
+
131
+ enable_explainability = kwargs.get("enable_explainability", None)
132
+
76
133
  from snowflake.ml.modeling.framework.base import BaseEstimator
77
134
 
78
135
  assert isinstance(model, BaseEstimator)
@@ -101,15 +158,35 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
101
158
  raise ValueError(f"Target method {method_name} does not exist in the model.")
102
159
  model_meta.signatures = temp_model_signature_dict
103
160
 
161
+ if enable_explainability or enable_explainability is None:
162
+ python_base_obj = cls._get_supported_object_for_explainability(model, enable_explainability)
163
+ if python_base_obj is None:
164
+ if enable_explainability: # if user set enable_explainability to True, throw error else silently skip
165
+ raise ValueError("Explain only support for xgboost or lightgbm Snowpark ML models.")
166
+ # set None to False so we don't include shap in the environment
167
+ enable_explainability = False
168
+ else:
169
+ model_objective_and_output_type = model_objective_utils.get_model_objective_and_output_type(
170
+ python_base_obj
171
+ )
172
+ model_meta.model_objective = model_objective_and_output_type.objective
173
+ model_meta = handlers_utils.add_explain_method_signature(
174
+ model_meta=model_meta,
175
+ explain_method="explain",
176
+ target_method="predict",
177
+ output_return_type=model_objective_and_output_type.output_type,
178
+ )
179
+ enable_explainability = True
180
+
104
181
  model_blob_path = os.path.join(model_blobs_dir_path, name)
105
182
  os.makedirs(model_blob_path, exist_ok=True)
106
- with open(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR), "wb") as f:
183
+ with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
107
184
  cloudpickle.dump(model, f)
108
185
  base_meta = model_blob_meta.ModelBlobMeta(
109
186
  name=name,
110
187
  model_type=cls.HANDLER_TYPE,
111
188
  handler_version=cls.HANDLER_VERSION,
112
- path=cls.MODELE_BLOB_FILE_OR_DIR,
189
+ path=cls.MODEL_BLOB_FILE_OR_DIR,
113
190
  )
114
191
  model_meta.models[name] = base_meta
115
192
  model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -118,7 +195,29 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
118
195
  model_dependencies = model._get_dependencies()
119
196
  for dep in model_dependencies:
120
197
  pkg_name = dep.split("==")[0]
121
- _include_if_absent_pkgs.append(model_env.ModelDependency(requirement=pkg_name, pip_name=pkg_name))
198
+ if pkg_name != "xgboost":
199
+ _include_if_absent_pkgs.append(model_env.ModelDependency(requirement=pkg_name, pip_name=pkg_name))
200
+ continue
201
+
202
+ local_xgb_version = cls._get_local_version_package("xgboost")
203
+ if local_xgb_version and local_xgb_version >= version.parse("2.0.0") and enable_explainability:
204
+ model_meta.env.include_if_absent(
205
+ [
206
+ model_env.ModelDependency(requirement="xgboost==2.0.*", pip_name="xgboost"),
207
+ ],
208
+ check_local_version=False,
209
+ )
210
+ else:
211
+ model_meta.env.include_if_absent(
212
+ [
213
+ model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
214
+ ],
215
+ check_local_version=True,
216
+ )
217
+
218
+ if enable_explainability:
219
+ model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
220
+ model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
122
221
  model_meta.env.include_if_absent(_include_if_absent_pkgs, check_local_version=True)
123
222
 
124
223
  @classmethod
@@ -146,6 +245,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
146
245
  cls,
147
246
  raw_model: "BaseEstimator",
148
247
  model_meta: model_meta_api.ModelMetadata,
248
+ background_data: Optional[pd.DataFrame] = None,
149
249
  **kwargs: Unpack[model_types.SNOWModelLoadOptions],
150
250
  ) -> custom_model.CustomModel:
151
251
  from snowflake.ml.model import custom_model
@@ -172,6 +272,24 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
172
272
 
173
273
  return model_signature_utils.rename_pandas_df(df, signature.outputs)
174
274
 
275
+ @custom_model.inference_api
276
+ def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
277
+ import shap
278
+
279
+ methods = ["to_xgboost", "to_lightgbm"]
280
+ for method_name in methods:
281
+ try:
282
+ base_model = getattr(raw_model, method_name)()
283
+ explainer = shap.TreeExplainer(base_model)
284
+ df = pd.DataFrame(explainer(X).values)
285
+ return model_signature_utils.rename_pandas_df(df, signature.outputs)
286
+ except exceptions.SnowflakeMLException:
287
+ pass # Do nothing and continue to the next method
288
+ raise ValueError("The model must be an xgboost or lightgbm estimator.")
289
+
290
+ if target_method == "explain":
291
+ return explain_fn
292
+
175
293
  return fn
176
294
 
177
295
  type_method_dict = {}
@@ -36,7 +36,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
36
36
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
37
37
  _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
38
38
 
39
- MODELE_BLOB_FILE_OR_DIR = "model"
39
+ MODEL_BLOB_FILE_OR_DIR = "model"
40
40
  DEFAULT_TARGET_METHODS = ["__call__"]
41
41
 
42
42
  @classmethod
@@ -68,6 +68,10 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
68
68
  is_sub_model: Optional[bool] = False,
69
69
  **kwargs: Unpack[model_types.TensorflowSaveOptions],
70
70
  ) -> None:
71
+ enable_explainability = kwargs.get("enable_explainability", False)
72
+ if enable_explainability:
73
+ raise NotImplementedError("Explainability is not supported for Tensorflow model.")
74
+
71
75
  import tensorflow
72
76
 
73
77
  assert isinstance(model, tensorflow.Module)
@@ -114,15 +118,15 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
114
118
  model_blob_path = os.path.join(model_blobs_dir_path, name)
115
119
  os.makedirs(model_blob_path, exist_ok=True)
116
120
  if isinstance(model, tensorflow.keras.Model):
117
- tensorflow.keras.models.save_model(model, os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR))
121
+ tensorflow.keras.models.save_model(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
118
122
  else:
119
- tensorflow.saved_model.save(model, os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR))
123
+ tensorflow.saved_model.save(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
120
124
 
121
125
  base_meta = model_blob_meta.ModelBlobMeta(
122
126
  name=name,
123
127
  model_type=cls.HANDLER_TYPE,
124
128
  handler_version=cls.HANDLER_VERSION,
125
- path=cls.MODELE_BLOB_FILE_OR_DIR,
129
+ path=cls.MODEL_BLOB_FILE_OR_DIR,
126
130
  )
127
131
  model_meta.models[name] = base_meta
128
132
  model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -156,6 +160,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
156
160
  cls,
157
161
  raw_model: "tensorflow.Module",
158
162
  model_meta: model_meta_api.ModelMetadata,
163
+ background_data: Optional[pd.DataFrame] = None,
159
164
  **kwargs: Unpack[model_types.TensorflowLoadOptions],
160
165
  ) -> custom_model.CustomModel:
161
166
  import tensorflow
@@ -34,7 +34,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
34
34
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
35
35
  _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
36
36
 
37
- MODELE_BLOB_FILE_OR_DIR = "model.pt"
37
+ MODEL_BLOB_FILE_OR_DIR = "model.pt"
38
38
  DEFAULT_TARGET_METHODS = ["forward"]
39
39
 
40
40
  @classmethod
@@ -66,6 +66,10 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
66
66
  is_sub_model: Optional[bool] = False,
67
67
  **kwargs: Unpack[model_types.TorchScriptSaveOptions],
68
68
  ) -> None:
69
+ enable_explainability = kwargs.get("enable_explainability", False)
70
+ if enable_explainability:
71
+ raise NotImplementedError("Explainability is not supported for Torch Script model.")
72
+
69
73
  import torch
70
74
 
71
75
  assert isinstance(model, torch.jit.ScriptModule) # type:ignore[attr-defined]
@@ -106,13 +110,13 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
106
110
 
107
111
  model_blob_path = os.path.join(model_blobs_dir_path, name)
108
112
  os.makedirs(model_blob_path, exist_ok=True)
109
- with open(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR), "wb") as f:
110
- torch.jit.save(model, f) # type:ignore[attr-defined]
113
+ with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
114
+ torch.jit.save(model, f) # type:ignore[no-untyped-call, attr-defined]
111
115
  base_meta = model_blob_meta.ModelBlobMeta(
112
116
  name=name,
113
117
  model_type=cls.HANDLER_TYPE,
114
118
  handler_version=cls.HANDLER_VERSION,
115
- path=cls.MODELE_BLOB_FILE_OR_DIR,
119
+ path=cls.MODEL_BLOB_FILE_OR_DIR,
116
120
  )
117
121
  model_meta.models[name] = base_meta
118
122
  model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -137,7 +141,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
137
141
  model_blob_metadata = model_blobs_metadata[name]
138
142
  model_blob_filename = model_blob_metadata.path
139
143
  with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
140
- m = torch.jit.load( # type:ignore[attr-defined]
144
+ m = torch.jit.load( # type:ignore[no-untyped-call, attr-defined]
141
145
  f, map_location="cuda" if kwargs.get("use_gpu", False) else "cpu"
142
146
  )
143
147
  assert isinstance(m, torch.jit.ScriptModule) # type:ignore[attr-defined]
@@ -152,6 +156,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
152
156
  cls,
153
157
  raw_model: "torch.jit.ScriptModule", # type:ignore[name-defined]
154
158
  model_meta: model_meta_api.ModelMetadata,
159
+ background_data: Optional[pd.DataFrame] = None,
155
160
  **kwargs: Unpack[model_types.TorchScriptLoadOptions],
156
161
  ) -> custom_model.CustomModel:
157
162
  from snowflake.ml.model import custom_model
@@ -1,6 +1,6 @@
1
1
  # mypy: disable-error-code="import"
2
- import json
3
2
  import os
3
+ import warnings
4
4
  from typing import (
5
5
  TYPE_CHECKING,
6
6
  Any,
@@ -13,14 +13,20 @@ from typing import (
13
13
  final,
14
14
  )
15
15
 
16
+ import importlib_metadata
16
17
  import numpy as np
17
18
  import pandas as pd
19
+ from packaging import version
18
20
  from typing_extensions import TypeGuard, Unpack
19
21
 
20
22
  from snowflake.ml._internal import type_utils
21
23
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
22
24
  from snowflake.ml.model._packager.model_env import model_env
23
- from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
25
+ from snowflake.ml.model._packager.model_handlers import (
26
+ _base,
27
+ _utils as handlers_utils,
28
+ model_objective_utils,
29
+ )
24
30
  from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
25
31
  from snowflake.ml.model._packager.model_meta import (
26
32
  model_blob_meta,
@@ -45,41 +51,8 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
45
51
  _MIN_SNOWPARK_ML_VERSION = "1.0.12"
46
52
  _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
47
53
 
48
- MODELE_BLOB_FILE_OR_DIR = "model.ubj"
54
+ MODEL_BLOB_FILE_OR_DIR = "model.ubj"
49
55
  DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
50
- _BINARY_CLASSIFICATION_OBJECTIVE_PREFIX = ["binary:"]
51
- _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX = ["multi:"]
52
- _RANKING_OBJECTIVE_PREFIX = ["rank:"]
53
- _REGRESSION_OBJECTIVE_PREFIX = ["reg:"]
54
-
55
- @classmethod
56
- def get_model_objective(cls, model: Union["xgboost.Booster", "xgboost.XGBModel"]) -> _base.ModelObjective:
57
- import xgboost
58
-
59
- if isinstance(model, xgboost.XGBClassifier) or isinstance(model, xgboost.XGBRFClassifier):
60
- num_classes = handlers_utils.get_num_classes_if_exists(model)
61
- if num_classes == 2:
62
- return _base.ModelObjective.BINARY_CLASSIFICATION
63
- return _base.ModelObjective.MULTI_CLASSIFICATION
64
- if isinstance(model, xgboost.XGBRegressor) or isinstance(model, xgboost.XGBRFRegressor):
65
- return _base.ModelObjective.REGRESSION
66
- if isinstance(model, xgboost.XGBRanker):
67
- return _base.ModelObjective.RANKING
68
- model_params = json.loads(model.save_config())
69
- model_objective = model_params["learner"]["objective"]
70
- for classification_objective in cls._BINARY_CLASSIFICATION_OBJECTIVE_PREFIX:
71
- if classification_objective in model_objective:
72
- return _base.ModelObjective.BINARY_CLASSIFICATION
73
- for classification_objective in cls._MULTI_CLASSIFICATION_OBJECTIVE_PREFIX:
74
- if classification_objective in model_objective:
75
- return _base.ModelObjective.MULTI_CLASSIFICATION
76
- for ranking_objective in cls._RANKING_OBJECTIVE_PREFIX:
77
- if ranking_objective in model_objective:
78
- return _base.ModelObjective.RANKING
79
- for regression_objective in cls._REGRESSION_OBJECTIVE_PREFIX:
80
- if regression_objective in model_objective:
81
- return _base.ModelObjective.REGRESSION
82
- return _base.ModelObjective.UNKNOWN
83
56
 
84
57
  @classmethod
85
58
  def can_handle(
@@ -114,10 +87,29 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
114
87
  is_sub_model: Optional[bool] = False,
115
88
  **kwargs: Unpack[model_types.XGBModelSaveOptions],
116
89
  ) -> None:
90
+ enable_explainability = kwargs.get("enable_explainability", True)
91
+
117
92
  import xgboost
118
93
 
119
94
  assert isinstance(model, xgboost.Booster) or isinstance(model, xgboost.XGBModel)
120
95
 
96
+ local_xgb_version = None
97
+
98
+ try:
99
+ local_dist = importlib_metadata.distribution("xgboost") # type: ignore[no-untyped-call]
100
+ local_xgb_version = version.parse(local_dist.version)
101
+ except importlib_metadata.PackageNotFoundError:
102
+ pass
103
+
104
+ if local_xgb_version and local_xgb_version >= version.parse("2.1.0") and enable_explainability:
105
+ warnings.warn(
106
+ f"This version of xgboost {local_xgb_version} does not work with shap 0.42.1."
107
+ + "If you want model explanations, lower the xgboost version to <2.1.0.",
108
+ category=UserWarning,
109
+ stacklevel=1,
110
+ )
111
+ enable_explainability = False
112
+
121
113
  if not is_sub_model:
122
114
  target_methods = handlers_utils.get_target_methods(
123
115
  model=model,
@@ -146,25 +138,29 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
146
138
  sample_input_data=sample_input_data,
147
139
  get_prediction_fn=get_prediction,
148
140
  )
149
- if kwargs.get("enable_explainability", False):
150
- output_type = model_signature.DataType.DOUBLE
151
- if cls.get_model_objective(model) == _base.ModelObjective.MULTI_CLASSIFICATION:
152
- output_type = model_signature.DataType.STRING
141
+ model_objective_and_output = model_objective_utils.get_model_objective_and_output_type(model)
142
+ model_meta.model_objective = handlers_utils.validate_model_objective(
143
+ model_meta.model_objective, model_objective_and_output.objective
144
+ )
145
+ if enable_explainability:
153
146
  model_meta = handlers_utils.add_explain_method_signature(
154
147
  model_meta=model_meta,
155
148
  explain_method="explain",
156
149
  target_method="predict",
157
- output_return_type=output_type,
150
+ output_return_type=model_objective_and_output.output_type,
158
151
  )
152
+ model_meta.function_properties = {
153
+ "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
154
+ }
159
155
 
160
156
  model_blob_path = os.path.join(model_blobs_dir_path, name)
161
157
  os.makedirs(model_blob_path, exist_ok=True)
162
- model.save_model(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR))
158
+ model.save_model(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
163
159
  base_meta = model_blob_meta.ModelBlobMeta(
164
160
  name=name,
165
161
  model_type=cls.HANDLER_TYPE,
166
162
  handler_version=cls.HANDLER_VERSION,
167
- path=cls.MODELE_BLOB_FILE_OR_DIR,
163
+ path=cls.MODEL_BLOB_FILE_OR_DIR,
168
164
  options=model_meta_schema.XgboostModelBlobOptions({"xgb_estimator_type": model.__class__.__name__}),
169
165
  )
170
166
  model_meta.models[name] = base_meta
@@ -173,15 +169,27 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
173
169
  model_meta.env.include_if_absent(
174
170
  [
175
171
  model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn"),
176
- model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
177
172
  ],
178
173
  check_local_version=True,
179
174
  )
180
- if kwargs.get("enable_explainability", False):
175
+ if local_xgb_version and local_xgb_version >= version.parse("2.0.0") and enable_explainability:
181
176
  model_meta.env.include_if_absent(
182
- [model_env.ModelDependency(requirement="shap", pip_name="shap")],
177
+ [
178
+ model_env.ModelDependency(requirement="xgboost==2.0.*", pip_name="xgboost"),
179
+ ],
180
+ check_local_version=False,
181
+ )
182
+ else:
183
+ model_meta.env.include_if_absent(
184
+ [
185
+ model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
186
+ ],
183
187
  check_local_version=True,
184
188
  )
189
+
190
+ if enable_explainability:
191
+ model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
192
+ model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
185
193
  model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
186
194
 
187
195
  @classmethod
@@ -224,6 +232,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
224
232
  cls,
225
233
  raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
226
234
  model_meta: model_meta_api.ModelMetadata,
235
+ background_data: Optional[pd.DataFrame] = None,
227
236
  **kwargs: Unpack[model_types.XGBModelLoadOptions],
228
237
  ) -> custom_model.CustomModel:
229
238
  import xgboost
@@ -260,7 +269,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
260
269
  import shap
261
270
 
262
271
  explainer = shap.TreeExplainer(raw_model)
263
- df = pd.DataFrame(explainer(X).values)
272
+ df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
264
273
  return model_signature_utils.rename_pandas_df(df, signature.outputs)
265
274
 
266
275
  if target_method == "explain":
@@ -55,6 +55,7 @@ def create_model_metadata(
55
55
  conda_dependencies: Optional[List[str]] = None,
56
56
  pip_requirements: Optional[List[str]] = None,
57
57
  python_version: Optional[str] = None,
58
+ model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN,
58
59
  **kwargs: Any,
59
60
  ) -> Generator["ModelMetadata", None, None]:
60
61
  """Create a generator for model metadata object. Use generator to ensure correct register and unregister for
@@ -74,6 +75,9 @@ def create_model_metadata(
74
75
  pip_requirements: List of pip Python packages requirements for running the model. Defaults to None.
75
76
  python_version: A string of python version where model is run. Used for user override. If specified as None,
76
77
  current version would be captured. Defaults to None.
78
+ model_objective: The objective of the Model Version. It is an enum class ModelObjective with values REGRESSION,
79
+ BINARY_CLASSIFICATION, MULTI_CLASSIFICATION, RANKING, or UNKNOWN. By default it is set to
80
+ ModelObjective.UNKNOWN and may be overridden by inferring from the Model Object.
77
81
  **kwargs: Dict of attributes and values of the metadata. Used when loading from file.
78
82
 
79
83
  Raises:
@@ -131,6 +135,7 @@ def create_model_metadata(
131
135
  model_type=model_type,
132
136
  signatures=signatures,
133
137
  function_properties=function_properties,
138
+ model_objective=model_objective,
134
139
  )
135
140
 
136
141
  code_dir_path = os.path.join(model_dir_path, MODEL_CODE_DIR)
@@ -237,6 +242,7 @@ class ModelMetadata:
237
242
  function_properties: A dict mapping function names to dict mapping function property key to value.
238
243
  metadata: User provided key-value metadata of the model. Defaults to None.
239
244
  creation_timestamp: Unix timestamp when the model metadata is created.
245
+ model_objective: Model objective like regression, classification etc.
240
246
  """
241
247
 
242
248
  def telemetry_metadata(self) -> ModelMetadataTelemetryDict:
@@ -260,6 +266,8 @@ class ModelMetadata:
260
266
  min_snowpark_ml_version: Optional[str] = None,
261
267
  models: Optional[Dict[str, model_blob_meta.ModelBlobMeta]] = None,
262
268
  original_metadata_version: Optional[str] = model_meta_schema.MODEL_METADATA_VERSION,
269
+ model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN,
270
+ explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = None,
263
271
  ) -> None:
264
272
  self.name = name
265
273
  self.signatures: Dict[str, model_signature.ModelSignature] = dict()
@@ -284,6 +292,9 @@ class ModelMetadata:
284
292
 
285
293
  self.original_metadata_version = original_metadata_version
286
294
 
295
+ self.model_objective: model_types.ModelObjective = model_objective
296
+ self.explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = explain_algorithm
297
+
287
298
  @property
288
299
  def min_snowpark_ml_version(self) -> str:
289
300
  return self._min_snowpark_ml_version.base_version
@@ -321,9 +332,11 @@ class ModelMetadata:
321
332
  model_dict = model_meta_schema.ModelMetadataDict(
322
333
  {
323
334
  "creation_timestamp": self.creation_timestamp,
324
- "env": self.env.save_as_dict(pathlib.Path(model_dir_path)),
335
+ "env": self.env.save_as_dict(
336
+ pathlib.Path(model_dir_path), default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
337
+ ),
325
338
  "runtimes": {
326
- runtime_name: runtime.save(pathlib.Path(model_dir_path))
339
+ runtime_name: runtime.save(pathlib.Path(model_dir_path), default_channel_override="conda-forge")
327
340
  for runtime_name, runtime in self.runtimes.items()
328
341
  },
329
342
  "metadata": self.metadata,
@@ -333,6 +346,13 @@ class ModelMetadata:
333
346
  "signatures": {func_name: sig.to_dict() for func_name, sig in self.signatures.items()},
334
347
  "version": model_meta_schema.MODEL_METADATA_VERSION,
335
348
  "min_snowpark_ml_version": self.min_snowpark_ml_version,
349
+ "model_objective": self.model_objective.value,
350
+ "explainability": (
351
+ model_meta_schema.ExplainabilityMetadataDict(algorithm=self.explain_algorithm.value)
352
+ if self.explain_algorithm
353
+ else None
354
+ ),
355
+ "function_properties": self.function_properties,
336
356
  }
337
357
  )
338
358
 
@@ -370,6 +390,9 @@ class ModelMetadata:
370
390
  signatures=loaded_meta["signatures"],
371
391
  version=original_loaded_meta_version,
372
392
  min_snowpark_ml_version=loaded_meta_min_snowpark_ml_version,
393
+ model_objective=loaded_meta.get("model_objective", model_types.ModelObjective.UNKNOWN.value),
394
+ explainability=loaded_meta.get("explainability", None),
395
+ function_properties=loaded_meta.get("function_properties", {}),
373
396
  )
374
397
 
375
398
  @classmethod
@@ -406,6 +429,11 @@ class ModelMetadata:
406
429
  else:
407
430
  runtimes = None
408
431
 
432
+ explanation_algorithm_dict = model_dict.get("explainability", None)
433
+ explanation_algorithm = None
434
+ if explanation_algorithm_dict:
435
+ explanation_algorithm = model_meta_schema.ModelExplainAlgorithm(explanation_algorithm_dict["algorithm"])
436
+
409
437
  return cls(
410
438
  name=model_dict["name"],
411
439
  model_type=model_dict["model_type"],
@@ -417,4 +445,9 @@ class ModelMetadata:
417
445
  min_snowpark_ml_version=model_dict["min_snowpark_ml_version"],
418
446
  models=models,
419
447
  original_metadata_version=model_dict["version"],
448
+ model_objective=model_types.ModelObjective(
449
+ model_dict.get("model_objective", model_types.ModelObjective.UNKNOWN.value)
450
+ ),
451
+ explain_algorithm=explanation_algorithm,
452
+ function_properties=model_dict.get("function_properties", {}),
420
453
  )
@@ -71,6 +71,10 @@ ModelBlobOptions = Union[
71
71
  ]
72
72
 
73
73
 
74
+ class ExplainabilityMetadataDict(TypedDict):
75
+ algorithm: Required[str]
76
+
77
+
74
78
  class ModelBlobMetadataDict(TypedDict):
75
79
  name: Required[str]
76
80
  model_type: Required[type_hints.SupportedModelHandlerType]
@@ -92,3 +96,10 @@ class ModelMetadataDict(TypedDict):
92
96
  signatures: Required[Dict[str, Dict[str, Any]]]
93
97
  version: Required[str]
94
98
  min_snowpark_ml_version: Required[str]
99
+ model_objective: Required[str]
100
+ explainability: NotRequired[Optional[ExplainabilityMetadataDict]]
101
+ function_properties: NotRequired[Dict[str, Dict[str, Any]]]
102
+
103
+
104
+ class ModelExplainAlgorithm(Enum):
105
+ SHAP = "shap"
@@ -47,6 +47,7 @@ class ModelPackager:
47
47
  ext_modules: Optional[List[ModuleType]] = None,
48
48
  code_paths: Optional[List[str]] = None,
49
49
  options: Optional[model_types.ModelSaveOption] = None,
50
+ model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN,
50
51
  ) -> model_meta.ModelMetadata:
51
52
  if (signatures is None) and (sample_input_data is None) and not model_handler.is_auto_signature_model(model):
52
53
  raise snowml_exceptions.SnowflakeMLException(
@@ -84,6 +85,7 @@ class ModelPackager:
84
85
  conda_dependencies=conda_dependencies,
85
86
  pip_requirements=pip_requirements,
86
87
  python_version=python_version,
88
+ model_objective=model_objective,
87
89
  **options,
88
90
  ) as meta:
89
91
  model_blobs_path = os.path.join(self.local_dir_path, ModelPackager.MODEL_BLOBS_DIR)
@@ -146,7 +148,8 @@ class ModelPackager:
146
148
  m = handler.load_model(self.meta.name, self.meta, model_blobs_path, **options)
147
149
 
148
150
  if as_custom_model:
149
- m = handler.convert_as_custom_model(m, self.meta, **options)
151
+ background_data = handler.load_background_data(self.meta.name, model_blobs_path)
152
+ m = handler.convert_as_custom_model(m, self.meta, background_data, **options)
150
153
  assert isinstance(m, custom_model.CustomModel)
151
154
 
152
155
  self.model = m
@@ -67,7 +67,9 @@ class ModelRuntime:
67
67
  def runtime_rel_path(self) -> pathlib.PurePosixPath:
68
68
  return pathlib.PurePosixPath(ModelRuntime.RUNTIME_DIR_REL_PATH) / self.name
69
69
 
70
- def save(self, packager_path: pathlib.Path) -> model_meta_schema.ModelRuntimeDict:
70
+ def save(
71
+ self, packager_path: pathlib.Path, default_channel_override: str = env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
72
+ ) -> model_meta_schema.ModelRuntimeDict:
71
73
  runtime_base_path = packager_path / self.runtime_rel_path
72
74
  runtime_base_path.mkdir(parents=True, exist_ok=True)
73
75
 
@@ -80,7 +82,7 @@ class ModelRuntime:
80
82
  self.runtime_env.conda_env_rel_path = self.runtime_rel_path / self.runtime_env.conda_env_rel_path
81
83
  self.runtime_env.pip_requirements_rel_path = self.runtime_rel_path / self.runtime_env.pip_requirements_rel_path
82
84
 
83
- env_dict = self.runtime_env.save_as_dict(packager_path)
85
+ env_dict = self.runtime_env.save_as_dict(packager_path, default_channel_override=default_channel_override)
84
86
 
85
87
  return model_meta_schema.ModelRuntimeDict(
86
88
  imports=list(map(str, self.imports)),
@@ -30,7 +30,7 @@ class SeqOfPyTorchTensorHandler(base_handler.BaseDataHandler[Sequence["torch.Ten
30
30
 
31
31
  @staticmethod
32
32
  def count(data: Sequence["torch.Tensor"]) -> int:
33
- return min(data_col.shape[0] for data_col in data)
33
+ return min(data_col.shape[0] for data_col in data) # type: ignore[no-any-return]
34
34
 
35
35
  @staticmethod
36
36
  def truncate(data: Sequence["torch.Tensor"]) -> Sequence["torch.Tensor"]: