snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,116 @@
1
+ import json
2
+ from dataclasses import dataclass
3
+ from typing import TYPE_CHECKING, Any, Union
4
+
5
+ from snowflake.ml.model import model_signature, type_hints
6
+ from snowflake.ml.model._packager.model_handlers import _utils as handlers_utils
7
+
8
+ if TYPE_CHECKING:
9
+ import lightgbm
10
+ import xgboost
11
+
12
+
13
@dataclass
class ModelObjectiveAndOutputType:
    """Result bundle returned by ``get_model_objective_and_output_type``.

    Groups the learning objective detected for a model with the data type
    chosen for that model's output column.
    """

    # Learning objective detected for the model (may be ModelObjective.UNKNOWN).
    objective: type_hints.ModelObjective
    # Data type selected for the model output given that objective.
    output_type: model_signature.DataType
17
+
18
+
19
def get_model_objective_lightgbm(model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]) -> type_hints.ModelObjective:
    """Infer the learning objective of a LightGBM model.

    The scikit-learn style wrappers (classifier, ranker, regressor) are
    classified by their class. Any other object is classified from the
    ``objective`` entry of its ``params`` mapping.

    Args:
        model: A LightGBM booster or scikit-learn style LightGBM estimator.

    Returns:
        The matching ``ModelObjective`` member. ``ModelObjective.UNKNOWN`` is
        returned when the objective is absent, is not a string (for example a
        custom callable), or is not one of the recognised names. Cross-entropy
        objectives are not recognised.
    """
    import lightgbm

    if isinstance(model, lightgbm.LGBMClassifier):
        num_classes = handlers_utils.get_num_classes_if_exists(model)
        if num_classes == 2:
            return type_hints.ModelObjective.BINARY_CLASSIFICATION
        return type_hints.ModelObjective.MULTI_CLASSIFICATION
    if isinstance(model, lightgbm.LGBMRanker):
        return type_hints.ModelObjective.RANKING
    if isinstance(model, lightgbm.LGBMRegressor):
        return type_hints.ModelObjective.REGRESSION

    # Recognised objective names; cross-entropy and custom objectives are not covered.
    objective_by_name = {
        "binary": type_hints.ModelObjective.BINARY_CLASSIFICATION,
        "multiclass": type_hints.ModelObjective.MULTI_CLASSIFICATION,
        "multiclassova": type_hints.ModelObjective.MULTI_CLASSIFICATION,
        "lambdarank": type_hints.ModelObjective.RANKING,
        "rank_xendcg": type_hints.ModelObjective.RANKING,
        "regression": type_hints.ModelObjective.REGRESSION,
        "regression_l1": type_hints.ModelObjective.REGRESSION,
        "huber": type_hints.ModelObjective.REGRESSION,
        "fair": type_hints.ModelObjective.REGRESSION,
        "poisson": type_hints.ModelObjective.REGRESSION,
        "quantile": type_hints.ModelObjective.REGRESSION,
        "tweedie": type_hints.ModelObjective.REGRESSION,
        "mape": type_hints.ModelObjective.REGRESSION,
        "gamma": type_hints.ModelObjective.REGRESSION,
    }

    # A missing ``params`` attribute or a missing "objective" key must fall
    # through to UNKNOWN instead of raising AttributeError / KeyError.
    params = getattr(model, "params", None)
    model_objective = params.get("objective") if isinstance(params, dict) else None
    if not isinstance(model_objective, str):
        return type_hints.ModelObjective.UNKNOWN
    return objective_by_name.get(model_objective, type_hints.ModelObjective.UNKNOWN)
58
+
59
+
60
def get_model_objective_xgb(model: Union["xgboost.Booster", "xgboost.XGBModel"]) -> type_hints.ModelObjective:
    """Infer the learning objective of an XGBoost model.

    For a ``Booster`` the objective is read from the JSON produced by
    ``save_config()`` (``learner`` -> ``objective``). For any other object
    exposing ``get_params()`` it is read from the ``objective`` parameter.
    The objective is then classified by its ``binary:``, ``multi:``,
    ``rank:`` or ``reg:`` prefix.

    Args:
        model: An XGBoost booster or scikit-learn style XGBoost estimator.

    Returns:
        The matching ``ModelObjective`` member, or ``ModelObjective.UNKNOWN``
        when no objective string is available or no prefix matches.
    """
    import xgboost

    # Checked in order; the first matching prefix wins.
    prefix_to_objective = (
        ("binary:", type_hints.ModelObjective.BINARY_CLASSIFICATION),
        ("multi:", type_hints.ModelObjective.MULTI_CLASSIFICATION),
        ("rank:", type_hints.ModelObjective.RANKING),
        ("reg:", type_hints.ModelObjective.REGRESSION),
    )

    model_objective: Any = ""
    if isinstance(model, xgboost.Booster):
        model_params = json.loads(model.save_config())
        model_objective = model_params.get("learner", {}).get("objective", "")
    elif hasattr(model, "get_params"):
        model_objective = model.get_params().get("objective", "")

    # The saved booster config stores the objective as a mapping with a "name" entry.
    if isinstance(model_objective, dict):
        model_objective = model_objective.get("name", "")

    # The objective parameter may be unset (None) or a custom callable; neither
    # can be matched against a prefix, so report UNKNOWN instead of raising TypeError.
    if not isinstance(model_objective, str):
        return type_hints.ModelObjective.UNKNOWN

    for prefix, objective in prefix_to_objective:
        if model_objective.startswith(prefix):
            return objective
    return type_hints.ModelObjective.UNKNOWN
92
+
93
+
94
def get_model_objective_and_output_type(model: Any) -> ModelObjectiveAndOutputType:
    """Detect the objective of an XGBoost or LightGBM model and pick its output type.

    XGBoost models get a ``STRING`` output type only for multi-class
    classification. LightGBM models get ``STRING`` for both binary and
    multi-class classification. Every other objective maps to ``DOUBLE``.

    Args:
        model: The model object to inspect.

    Returns:
        A ``ModelObjectiveAndOutputType`` holding the objective and output type.

    Raises:
        ValueError: If ``model`` is neither an XGBoost nor a LightGBM model.
    """
    import xgboost

    if isinstance(model, (xgboost.Booster, xgboost.XGBModel)):
        xgb_objective = get_model_objective_xgb(model)
        xgb_output_type = (
            model_signature.DataType.STRING
            if xgb_objective == type_hints.ModelObjective.MULTI_CLASSIFICATION
            else model_signature.DataType.DOUBLE
        )
        return ModelObjectiveAndOutputType(objective=xgb_objective, output_type=xgb_output_type)

    # lightgbm is imported only once the model is known not to be an XGBoost one.
    import lightgbm

    if isinstance(model, (lightgbm.Booster, lightgbm.LGBMModel)):
        lgbm_objective = get_model_objective_lightgbm(model)
        classification_objectives = (
            type_hints.ModelObjective.BINARY_CLASSIFICATION,
            type_hints.ModelObjective.MULTI_CLASSIFICATION,
        )
        lgbm_output_type = (
            model_signature.DataType.STRING
            if lgbm_objective in classification_objectives
            else model_signature.DataType.DOUBLE
        )
        return ModelObjectiveAndOutputType(objective=lgbm_objective, output_type=lgbm_output_type)

    raise ValueError(f"Model type {type(model)} is not supported")
@@ -45,23 +45,23 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
45
45
  @classmethod
46
46
  def get_model_objective(
47
47
  cls, model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]
48
- ) -> model_meta_schema.ModelObjective:
48
+ ) -> model_types.ModelObjective:
49
49
  import sklearn.pipeline
50
50
  from sklearn.base import is_classifier, is_regressor
51
51
 
52
52
  if isinstance(model, sklearn.pipeline.Pipeline):
53
- return model_meta_schema.ModelObjective.UNKNOWN
53
+ return model_types.ModelObjective.UNKNOWN
54
54
  if is_regressor(model):
55
- return model_meta_schema.ModelObjective.REGRESSION
55
+ return model_types.ModelObjective.REGRESSION
56
56
  if is_classifier(model):
57
57
  classes_list = getattr(model, "classes_", [])
58
58
  num_classes = getattr(model, "n_classes_", None) or len(classes_list)
59
59
  if isinstance(num_classes, int):
60
60
  if num_classes > 2:
61
- return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
62
- return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
63
- return model_meta_schema.ModelObjective.UNKNOWN
64
- return model_meta_schema.ModelObjective.UNKNOWN
61
+ return model_types.ModelObjective.MULTI_CLASSIFICATION
62
+ return model_types.ModelObjective.BINARY_CLASSIFICATION
63
+ return model_types.ModelObjective.UNKNOWN
64
+ return model_types.ModelObjective.UNKNOWN
65
65
 
66
66
  @classmethod
67
67
  def can_handle(
@@ -95,6 +95,18 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
95
95
 
96
96
  return cast(Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"], model)
97
97
 
98
@staticmethod
def get_explainability_supported_background(
    sample_input_data: Optional[model_types.SupportedDataType] = None,
) -> Optional[pd.DataFrame]:
    """Return the sample input as a pandas DataFrame when it can serve as background data.

    Args:
        sample_input_data: Sample input supplied when saving the model, if any.

    Returns:
        The input itself when it is a pandas DataFrame, the converted pandas
        DataFrame when it is a Snowpark DataFrame, and ``None`` for any other
        input (including no input at all).
    """
    # pandas input is already in the required form.
    if isinstance(sample_input_data, pd.DataFrame):
        return sample_input_data
    # Snowpark input is converted to pandas.
    if isinstance(sample_input_data, sp_df.DataFrame):
        return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sample_input_data)
    return None
109
+
98
110
  @classmethod
99
111
  def save_model(
100
112
  cls,
@@ -106,32 +118,30 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
106
118
  is_sub_model: Optional[bool] = False,
107
119
  **kwargs: Unpack[model_types.SKLModelSaveOptions],
108
120
  ) -> None:
109
- enable_explainability = kwargs.get("enable_explainability", False)
121
+ # setting None by default to distinguish if users did not set it
122
+ enable_explainability = kwargs.get("enable_explainability", None)
110
123
 
111
124
  import sklearn.base
112
125
  import sklearn.pipeline
113
126
 
114
127
  assert isinstance(model, sklearn.base.BaseEstimator) or isinstance(model, sklearn.pipeline.Pipeline)
115
128
 
116
- enable_explainability = kwargs.get("enable_explainability", False)
129
+ background_data = cls.get_explainability_supported_background(sample_input_data)
130
+
131
+ # if users did not ask then we enable if we have background data
132
+ if enable_explainability is None and background_data is not None:
133
+ enable_explainability = True
117
134
  if enable_explainability:
118
- # TODO: Currently limited to pandas df, need to extend to other types.
119
- if sample_input_data is None or not (
120
- isinstance(sample_input_data, pd.DataFrame) or isinstance(sample_input_data, sp_df.DataFrame)
121
- ):
135
+ # if users set it explicitly but no background data then error out
136
+ if background_data is None:
122
137
  raise ValueError(
123
138
  "Sample input data is required to enable explainability. Currently we only support this for "
124
139
  + "`pandas.DataFrame` and `snowflake.snowpark.dataframe.DataFrame`."
125
140
  )
126
- sample_input_data_pandas = (
127
- sample_input_data
128
- if isinstance(sample_input_data, pd.DataFrame)
129
- else snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sample_input_data)
130
- )
131
141
  data_blob_path = os.path.join(model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR)
132
142
  os.makedirs(data_blob_path, exist_ok=True)
133
143
  with open(os.path.join(data_blob_path, name + cls.BG_DATA_FILE_SUFFIX), "wb") as f:
134
- sample_input_data_pandas.to_parquet(f)
144
+ background_data.to_parquet(f)
135
145
 
136
146
  if not is_sub_model:
137
147
  target_methods = handlers_utils.get_target_methods(
@@ -159,9 +169,13 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
159
169
  get_prediction_fn=get_prediction,
160
170
  )
161
171
 
172
+ model_objective = cls.get_model_objective(model)
173
+ model_meta.model_objective = model_objective
174
+
162
175
  if enable_explainability:
163
176
  output_type = model_signature.DataType.DOUBLE
164
- if cls.get_model_objective(model) == model_meta_schema.ModelObjective.MULTI_CLASSIFICATION:
177
+
178
+ if model_objective == model_types.ModelObjective.MULTI_CLASSIFICATION:
165
179
  output_type = model_signature.DataType.STRING
166
180
  model_meta = handlers_utils.add_explain_method_signature(
167
181
  model_meta=model_meta,
@@ -184,10 +198,8 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
184
198
  model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
185
199
 
186
200
  if enable_explainability:
187
- model_meta.env.include_if_absent(
188
- [model_env.ModelDependency(requirement="shap", pip_name="shap")],
189
- check_local_version=True,
190
- )
201
+ model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
202
+ model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
191
203
 
192
204
  model_meta.env.include_if_absent(
193
205
  [model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn")], check_local_version=True
@@ -1,20 +1,27 @@
1
1
  import os
2
2
  import warnings
3
- from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
3
+ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Type, cast, final
4
4
 
5
5
  import cloudpickle
6
6
  import numpy as np
7
7
  import pandas as pd
8
+ from packaging import version
8
9
  from typing_extensions import TypeGuard, Unpack
9
10
 
10
11
  from snowflake.ml._internal import type_utils
12
+ from snowflake.ml._internal.exceptions import exceptions
11
13
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
12
14
  from snowflake.ml.model._packager.model_env import model_env
13
- from snowflake.ml.model._packager.model_handlers import _base
15
+ from snowflake.ml.model._packager.model_handlers import (
16
+ _base,
17
+ _utils as handlers_utils,
18
+ model_objective_utils,
19
+ )
14
20
  from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
15
21
  from snowflake.ml.model._packager.model_meta import (
16
22
  model_blob_meta,
17
23
  model_meta as model_meta_api,
24
+ model_meta_schema,
18
25
  )
19
26
  from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
20
27
 
@@ -62,6 +69,53 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
62
69
 
63
70
  return cast("BaseEstimator", model)
64
71
 
72
+ @classmethod
73
+ def _get_local_version_package(cls, pkg_name: str) -> Optional[version.Version]:
74
+ import importlib_metadata
75
+ from packaging import version
76
+
77
+ local_version = None
78
+
79
+ try:
80
+ local_dist = importlib_metadata.distribution(pkg_name) # type: ignore[no-untyped-call]
81
+ local_version = version.parse(local_dist.version)
82
+ except importlib_metadata.PackageNotFoundError:
83
+ pass
84
+
85
+ return local_version
86
+
87
@classmethod
def _can_support_xgb(cls, enable_explainability: Optional[bool]) -> bool:
    """Report whether the locally installed xgboost can be used for explanations.

    Args:
        enable_explainability: The user's explainability setting. A warning is
            emitted only when this is truthy and xgboost is too new.

    Returns:
        ``False`` when the local xgboost version is 2.1.0 or newer, ``True``
        otherwise (including when xgboost is not found locally).
    """
    local_xgb_version = cls._get_local_version_package("xgboost")

    if local_xgb_version is None or local_xgb_version < version.parse("2.1.0"):
        return True

    if enable_explainability:
        # The two sentences need a separating space; without it the message
        # reads "...shap 0.42.1.If you want...".
        warnings.warn(
            f"This version of xgboost {local_xgb_version} does not work with shap 0.42.1. "
            "If you want model explanations, lower the xgboost version to <2.1.0.",
            category=UserWarning,
            stacklevel=1,
        )
    return False
102
+
103
@classmethod
def _get_supported_object_for_explainability(
    cls, estimator: "BaseEstimator", enable_explainability: Optional[bool]
) -> Any:
    """Extract the native xgboost/lightgbm object backing an estimator, if any.

    Tries ``to_xgboost`` first and ``to_lightgbm`` second. A converter that
    raises ``SnowflakeMLException`` is skipped.

    Args:
        estimator: The Snowpark ML estimator being saved.
        enable_explainability: The user's explainability setting, forwarded to
            the xgboost version check.

    Returns:
        The object returned by the first converter that succeeds, or ``None``
        when no converter succeeds or when the xgboost conversion succeeds but
        the xgboost version check fails.
    """
    for converter_name in ("to_xgboost", "to_lightgbm"):
        if not hasattr(estimator, converter_name):
            continue
        try:
            native_model = getattr(estimator, converter_name)()
            xgb_unsupported = converter_name == "to_xgboost" and not cls._can_support_xgb(enable_explainability)
            return None if xgb_unsupported else native_model
        except exceptions.SnowflakeMLException:
            # This converter does not apply; try the next one.
            continue
    return None
118
+
65
119
  @classmethod
66
120
  def save_model(
67
121
  cls,
@@ -73,9 +127,8 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
73
127
  is_sub_model: Optional[bool] = False,
74
128
  **kwargs: Unpack[model_types.SNOWModelSaveOptions],
75
129
  ) -> None:
76
- enable_explainability = kwargs.get("enable_explainability", False)
77
- if enable_explainability:
78
- raise NotImplementedError("Explainability is not supported for Snowpark ML model.")
130
+
131
+ enable_explainability = kwargs.get("enable_explainability", None)
79
132
 
80
133
  from snowflake.ml.modeling.framework.base import BaseEstimator
81
134
 
@@ -105,6 +158,26 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
105
158
  raise ValueError(f"Target method {method_name} does not exist in the model.")
106
159
  model_meta.signatures = temp_model_signature_dict
107
160
 
161
+ if enable_explainability or enable_explainability is None:
162
+ python_base_obj = cls._get_supported_object_for_explainability(model, enable_explainability)
163
+ if python_base_obj is None:
164
+ if enable_explainability: # if user set enable_explainability to True, throw error else silently skip
165
+ raise ValueError("Explain only support for xgboost or lightgbm Snowpark ML models.")
166
+ # set None to False so we don't include shap in the environment
167
+ enable_explainability = False
168
+ else:
169
+ model_objective_and_output_type = model_objective_utils.get_model_objective_and_output_type(
170
+ python_base_obj
171
+ )
172
+ model_meta.model_objective = model_objective_and_output_type.objective
173
+ model_meta = handlers_utils.add_explain_method_signature(
174
+ model_meta=model_meta,
175
+ explain_method="explain",
176
+ target_method="predict",
177
+ output_return_type=model_objective_and_output_type.output_type,
178
+ )
179
+ enable_explainability = True
180
+
108
181
  model_blob_path = os.path.join(model_blobs_dir_path, name)
109
182
  os.makedirs(model_blob_path, exist_ok=True)
110
183
  with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
@@ -122,7 +195,29 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
122
195
  model_dependencies = model._get_dependencies()
123
196
  for dep in model_dependencies:
124
197
  pkg_name = dep.split("==")[0]
125
- _include_if_absent_pkgs.append(model_env.ModelDependency(requirement=pkg_name, pip_name=pkg_name))
198
+ if pkg_name != "xgboost":
199
+ _include_if_absent_pkgs.append(model_env.ModelDependency(requirement=pkg_name, pip_name=pkg_name))
200
+ continue
201
+
202
+ local_xgb_version = cls._get_local_version_package("xgboost")
203
+ if local_xgb_version and local_xgb_version >= version.parse("2.0.0") and enable_explainability:
204
+ model_meta.env.include_if_absent(
205
+ [
206
+ model_env.ModelDependency(requirement="xgboost==2.0.*", pip_name="xgboost"),
207
+ ],
208
+ check_local_version=False,
209
+ )
210
+ else:
211
+ model_meta.env.include_if_absent(
212
+ [
213
+ model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
214
+ ],
215
+ check_local_version=True,
216
+ )
217
+
218
+ if enable_explainability:
219
+ model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
220
+ model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
126
221
  model_meta.env.include_if_absent(_include_if_absent_pkgs, check_local_version=True)
127
222
 
128
223
  @classmethod
@@ -177,6 +272,24 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
177
272
 
178
273
  return model_signature_utils.rename_pandas_df(df, signature.outputs)
179
274
 
275
@custom_model.inference_api
def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
    """Return SHAP values for ``X`` computed on the wrapped native tree model.

    The estimator captured from the enclosing scope as ``raw_model`` is
    converted to a native model by trying ``to_xgboost`` and then
    ``to_lightgbm``; the first conversion that does not raise
    ``SnowflakeMLException`` is handed to ``shap.TreeExplainer``.

    Args:
        X: Input rows to explain.

    Returns:
        The explainer's ``values`` wrapped in a DataFrame whose columns are
        renamed according to ``signature.outputs``.

    Raises:
        ValueError: If both conversions raise ``SnowflakeMLException``
            (presumably meaning the estimator is neither xgboost- nor
            lightgbm-based).
    """
    # Function-scope import: shap is only needed when this method actually runs.
    import shap

    for method_name in ("to_xgboost", "to_lightgbm"):
        try:
            # Guard ONLY the conversion. Any SnowflakeMLException raised later
            # (while explaining or renaming columns) must propagate unchanged
            # instead of being misreported as an unsupported estimator.
            base_model = getattr(raw_model, method_name)()
        except exceptions.SnowflakeMLException:
            continue  # Conversion not available for this estimator; try the next one.

        explainer = shap.TreeExplainer(base_model)
        df = pd.DataFrame(explainer(X).values)
        return model_signature_utils.rename_pandas_df(df, signature.outputs)

    raise ValueError("The model must be an xgboost or lightgbm estimator.")
289
+
290
+ if target_method == "explain":
291
+ return explain_fn
292
+
180
293
  return fn
181
294
 
182
295
  type_method_dict = {}
@@ -111,7 +111,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
111
111
  model_blob_path = os.path.join(model_blobs_dir_path, name)
112
112
  os.makedirs(model_blob_path, exist_ok=True)
113
113
  with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
114
- torch.jit.save(model, f) # type:ignore[attr-defined]
114
+ torch.jit.save(model, f) # type:ignore[no-untyped-call, attr-defined]
115
115
  base_meta = model_blob_meta.ModelBlobMeta(
116
116
  name=name,
117
117
  model_type=cls.HANDLER_TYPE,
@@ -141,7 +141,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
141
141
  model_blob_metadata = model_blobs_metadata[name]
142
142
  model_blob_filename = model_blob_metadata.path
143
143
  with open(os.path.join(model_blob_path, model_blob_filename), "rb") as f:
144
- m = torch.jit.load( # type:ignore[attr-defined]
144
+ m = torch.jit.load( # type:ignore[no-untyped-call, attr-defined]
145
145
  f, map_location="cuda" if kwargs.get("use_gpu", False) else "cpu"
146
146
  )
147
147
  assert isinstance(m, torch.jit.ScriptModule) # type:ignore[attr-defined]
@@ -1,6 +1,6 @@
1
1
  # mypy: disable-error-code="import"
2
- import json
3
2
  import os
3
+ import warnings
4
4
  from typing import (
5
5
  TYPE_CHECKING,
6
6
  Any,
@@ -13,14 +13,20 @@ from typing import (
13
13
  final,
14
14
  )
15
15
 
16
+ import importlib_metadata
16
17
  import numpy as np
17
18
  import pandas as pd
19
+ from packaging import version
18
20
  from typing_extensions import TypeGuard, Unpack
19
21
 
20
22
  from snowflake.ml._internal import type_utils
21
23
  from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
22
24
  from snowflake.ml.model._packager.model_env import model_env
23
- from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
25
+ from snowflake.ml.model._packager.model_handlers import (
26
+ _base,
27
+ _utils as handlers_utils,
28
+ model_objective_utils,
29
+ )
24
30
  from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
25
31
  from snowflake.ml.model._packager.model_meta import (
26
32
  model_blob_meta,
@@ -47,41 +53,6 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
47
53
 
48
54
  MODEL_BLOB_FILE_OR_DIR = "model.ubj"
49
55
  DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
50
- _BINARY_CLASSIFICATION_OBJECTIVE_PREFIX = ["binary:"]
51
- _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX = ["multi:"]
52
- _RANKING_OBJECTIVE_PREFIX = ["rank:"]
53
- _REGRESSION_OBJECTIVE_PREFIX = ["reg:"]
54
-
55
- @classmethod
56
- def get_model_objective(
57
- cls, model: Union["xgboost.Booster", "xgboost.XGBModel"]
58
- ) -> model_meta_schema.ModelObjective:
59
- import xgboost
60
-
61
- if isinstance(model, xgboost.XGBClassifier) or isinstance(model, xgboost.XGBRFClassifier):
62
- num_classes = handlers_utils.get_num_classes_if_exists(model)
63
- if num_classes == 2:
64
- return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
65
- return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
66
- if isinstance(model, xgboost.XGBRegressor) or isinstance(model, xgboost.XGBRFRegressor):
67
- return model_meta_schema.ModelObjective.REGRESSION
68
- if isinstance(model, xgboost.XGBRanker):
69
- return model_meta_schema.ModelObjective.RANKING
70
- model_params = json.loads(model.save_config())
71
- model_objective = model_params["learner"]["objective"]
72
- for classification_objective in cls._BINARY_CLASSIFICATION_OBJECTIVE_PREFIX:
73
- if classification_objective in model_objective:
74
- return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
75
- for classification_objective in cls._MULTI_CLASSIFICATION_OBJECTIVE_PREFIX:
76
- if classification_objective in model_objective:
77
- return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
78
- for ranking_objective in cls._RANKING_OBJECTIVE_PREFIX:
79
- if ranking_objective in model_objective:
80
- return model_meta_schema.ModelObjective.RANKING
81
- for regression_objective in cls._REGRESSION_OBJECTIVE_PREFIX:
82
- if regression_objective in model_objective:
83
- return model_meta_schema.ModelObjective.REGRESSION
84
- return model_meta_schema.ModelObjective.UNKNOWN
85
56
 
86
57
  @classmethod
87
58
  def can_handle(
@@ -116,10 +87,29 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
116
87
  is_sub_model: Optional[bool] = False,
117
88
  **kwargs: Unpack[model_types.XGBModelSaveOptions],
118
89
  ) -> None:
90
+ enable_explainability = kwargs.get("enable_explainability", True)
91
+
119
92
  import xgboost
120
93
 
121
94
  assert isinstance(model, xgboost.Booster) or isinstance(model, xgboost.XGBModel)
122
95
 
96
+ local_xgb_version = None
97
+
98
+ try:
99
+ local_dist = importlib_metadata.distribution("xgboost") # type: ignore[no-untyped-call]
100
+ local_xgb_version = version.parse(local_dist.version)
101
+ except importlib_metadata.PackageNotFoundError:
102
+ pass
103
+
104
+ if local_xgb_version and local_xgb_version >= version.parse("2.1.0") and enable_explainability:
105
+ warnings.warn(
106
+ f"This version of xgboost {local_xgb_version} does not work with shap 0.42.1."
107
+ + "If you want model explanations, lower the xgboost version to <2.1.0.",
108
+ category=UserWarning,
109
+ stacklevel=1,
110
+ )
111
+ enable_explainability = False
112
+
123
113
  if not is_sub_model:
124
114
  target_methods = handlers_utils.get_target_methods(
125
115
  model=model,
@@ -148,17 +138,16 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
148
138
  sample_input_data=sample_input_data,
149
139
  get_prediction_fn=get_prediction,
150
140
  )
151
- model_objective = cls.get_model_objective(model)
152
- model_meta.model_objective = model_objective
153
- if kwargs.get("enable_explainability", True):
154
- output_type = model_signature.DataType.DOUBLE
155
- if model_objective == model_meta_schema.ModelObjective.MULTI_CLASSIFICATION:
156
- output_type = model_signature.DataType.STRING
141
+ model_objective_and_output = model_objective_utils.get_model_objective_and_output_type(model)
142
+ model_meta.model_objective = handlers_utils.validate_model_objective(
143
+ model_meta.model_objective, model_objective_and_output.objective
144
+ )
145
+ if enable_explainability:
157
146
  model_meta = handlers_utils.add_explain_method_signature(
158
147
  model_meta=model_meta,
159
148
  explain_method="explain",
160
149
  target_method="predict",
161
- output_return_type=output_type,
150
+ output_return_type=model_objective_and_output.output_type,
162
151
  )
163
152
  model_meta.function_properties = {
164
153
  "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
@@ -180,15 +169,26 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
180
169
  model_meta.env.include_if_absent(
181
170
  [
182
171
  model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn"),
183
- model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
184
172
  ],
185
173
  check_local_version=True,
186
174
  )
187
- if kwargs.get("enable_explainability", True):
175
+ if local_xgb_version and local_xgb_version >= version.parse("2.0.0") and enable_explainability:
188
176
  model_meta.env.include_if_absent(
189
- [model_env.ModelDependency(requirement="shap", pip_name="shap")],
177
+ [
178
+ model_env.ModelDependency(requirement="xgboost==2.0.*", pip_name="xgboost"),
179
+ ],
180
+ check_local_version=False,
181
+ )
182
+ else:
183
+ model_meta.env.include_if_absent(
184
+ [
185
+ model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
186
+ ],
190
187
  check_local_version=True,
191
188
  )
189
+
190
+ if enable_explainability:
191
+ model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
192
192
  model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
193
193
  model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
194
194
 
@@ -269,7 +269,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
269
269
  import shap
270
270
 
271
271
  explainer = shap.TreeExplainer(raw_model)
272
- df = pd.DataFrame(explainer(X).values)
272
+ df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
273
273
  return model_signature_utils.rename_pandas_df(df, signature.outputs)
274
274
 
275
275
  if target_method == "explain":
@@ -55,6 +55,7 @@ def create_model_metadata(
55
55
  conda_dependencies: Optional[List[str]] = None,
56
56
  pip_requirements: Optional[List[str]] = None,
57
57
  python_version: Optional[str] = None,
58
+ model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN,
58
59
  **kwargs: Any,
59
60
  ) -> Generator["ModelMetadata", None, None]:
60
61
  """Create a generator for model metadata object. Use generator to ensure correct register and unregister for
@@ -74,6 +75,9 @@ def create_model_metadata(
74
75
  pip_requirements: List of pip Python packages requirements for running the model. Defaults to None.
75
76
  python_version: A string of python version where model is run. Used for user override. If specified as None,
76
77
  current version would be captured. Defaults to None.
78
+ model_objective: The objective of the Model Version. It is an enum class ModelObjective with values REGRESSION,
79
+ BINARY_CLASSIFICATION, MULTI_CLASSIFICATION, RANKING, or UNKNOWN. By default it is set to
80
+ ModelObjective.UNKNOWN and may be overridden by inferring from the Model Object.
77
81
  **kwargs: Dict of attributes and values of the metadata. Used when loading from file.
78
82
 
79
83
  Raises:
@@ -131,6 +135,7 @@ def create_model_metadata(
131
135
  model_type=model_type,
132
136
  signatures=signatures,
133
137
  function_properties=function_properties,
138
+ model_objective=model_objective,
134
139
  )
135
140
 
136
141
  code_dir_path = os.path.join(model_dir_path, MODEL_CODE_DIR)
@@ -261,7 +266,7 @@ class ModelMetadata:
261
266
  min_snowpark_ml_version: Optional[str] = None,
262
267
  models: Optional[Dict[str, model_blob_meta.ModelBlobMeta]] = None,
263
268
  original_metadata_version: Optional[str] = model_meta_schema.MODEL_METADATA_VERSION,
264
- model_objective: Optional[model_meta_schema.ModelObjective] = model_meta_schema.ModelObjective.UNKNOWN,
269
+ model_objective: model_types.ModelObjective = model_types.ModelObjective.UNKNOWN,
265
270
  explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = None,
266
271
  ) -> None:
267
272
  self.name = name
@@ -287,9 +292,7 @@ class ModelMetadata:
287
292
 
288
293
  self.original_metadata_version = original_metadata_version
289
294
 
290
- self.model_objective: model_meta_schema.ModelObjective = (
291
- model_objective or model_meta_schema.ModelObjective.UNKNOWN
292
- )
295
+ self.model_objective: model_types.ModelObjective = model_objective
293
296
  self.explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = explain_algorithm
294
297
 
295
298
  @property
@@ -387,7 +390,7 @@ class ModelMetadata:
387
390
  signatures=loaded_meta["signatures"],
388
391
  version=original_loaded_meta_version,
389
392
  min_snowpark_ml_version=loaded_meta_min_snowpark_ml_version,
390
- model_objective=loaded_meta.get("model_objective", model_meta_schema.ModelObjective.UNKNOWN.value),
393
+ model_objective=loaded_meta.get("model_objective", model_types.ModelObjective.UNKNOWN.value),
391
394
  explainability=loaded_meta.get("explainability", None),
392
395
  function_properties=loaded_meta.get("function_properties", {}),
393
396
  )
@@ -442,8 +445,8 @@ class ModelMetadata:
442
445
  min_snowpark_ml_version=model_dict["min_snowpark_ml_version"],
443
446
  models=models,
444
447
  original_metadata_version=model_dict["version"],
445
- model_objective=model_meta_schema.ModelObjective(
446
- model_dict.get("model_objective", model_meta_schema.ModelObjective.UNKNOWN.value)
448
+ model_objective=model_types.ModelObjective(
449
+ model_dict.get("model_objective", model_types.ModelObjective.UNKNOWN.value)
447
450
  ),
448
451
  explain_algorithm=explanation_algorithm,
449
452
  function_properties=model_dict.get("function_properties", {}),