snowflake-ml-python 1.8.3__py3-none-any.whl → 1.8.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (190)
  1. snowflake/cortex/__init__.py +7 -1
  2. snowflake/ml/_internal/platform_capabilities.py +13 -11
  3. snowflake/ml/_internal/utils/identifier.py +2 -2
  4. snowflake/ml/jobs/_utils/constants.py +1 -1
  5. snowflake/ml/jobs/_utils/payload_utils.py +39 -30
  6. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
  7. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +1 -1
  8. snowflake/ml/jobs/_utils/spec_utils.py +1 -1
  9. snowflake/ml/jobs/decorators.py +6 -0
  10. snowflake/ml/jobs/job.py +63 -16
  11. snowflake/ml/jobs/manager.py +50 -16
  12. snowflake/ml/model/_client/model/model_version_impl.py +1 -1
  13. snowflake/ml/model/_client/ops/service_ops.py +26 -14
  14. snowflake/ml/model/_client/service/model_deployment_spec.py +340 -170
  15. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +25 -0
  16. snowflake/ml/model/_client/sql/service.py +4 -13
  17. snowflake/ml/model/_model_composer/model_composer.py +41 -18
  18. snowflake/ml/model/_packager/model_handlers/_utils.py +32 -2
  19. snowflake/ml/model/_packager/model_handlers/custom.py +1 -1
  20. snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -2
  21. snowflake/ml/model/_packager/model_handlers/sklearn.py +100 -41
  22. snowflake/ml/model/_packager/model_handlers/tensorflow.py +7 -4
  23. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  24. snowflake/ml/model/_packager/model_handlers/xgboost.py +16 -7
  25. snowflake/ml/model/_packager/model_meta/model_meta.py +2 -1
  26. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
  27. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +4 -4
  28. snowflake/ml/model/_signatures/dmatrix_handler.py +15 -2
  29. snowflake/ml/model/custom_model.py +17 -4
  30. snowflake/ml/model/model_signature.py +3 -3
  31. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
  32. snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
  33. snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
  34. snowflake/ml/modeling/cluster/birch.py +9 -1
  35. snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
  36. snowflake/ml/modeling/cluster/dbscan.py +9 -1
  37. snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
  38. snowflake/ml/modeling/cluster/k_means.py +9 -1
  39. snowflake/ml/modeling/cluster/mean_shift.py +9 -1
  40. snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
  41. snowflake/ml/modeling/cluster/optics.py +9 -1
  42. snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
  43. snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
  44. snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
  45. snowflake/ml/modeling/compose/column_transformer.py +9 -1
  46. snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
  47. snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
  48. snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
  49. snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
  50. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
  51. snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
  52. snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
  53. snowflake/ml/modeling/covariance/oas.py +9 -1
  54. snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
  55. snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
  56. snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
  57. snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
  58. snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
  59. snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
  60. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
  61. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
  62. snowflake/ml/modeling/decomposition/pca.py +9 -1
  63. snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
  64. snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
  65. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
  66. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
  67. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
  68. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
  69. snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
  70. snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
  71. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
  72. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
  73. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
  74. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
  75. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
  76. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
  77. snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
  78. snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
  79. snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
  80. snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
  81. snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
  82. snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
  83. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
  84. snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
  85. snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
  86. snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
  87. snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
  88. snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
  89. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
  90. snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
  91. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
  92. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
  93. snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
  94. snowflake/ml/modeling/impute/knn_imputer.py +9 -1
  95. snowflake/ml/modeling/impute/missing_indicator.py +9 -1
  96. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
  97. snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
  98. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
  99. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
  100. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
  101. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
  102. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
  103. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
  104. snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
  105. snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
  106. snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
  107. snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
  108. snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
  109. snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
  110. snowflake/ml/modeling/linear_model/lars.py +9 -1
  111. snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
  112. snowflake/ml/modeling/linear_model/lasso.py +9 -1
  113. snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
  114. snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
  115. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
  116. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
  117. snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
  118. snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
  119. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
  120. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
  121. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
  122. snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
  123. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
  124. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
  125. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
  126. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
  127. snowflake/ml/modeling/linear_model/perceptron.py +9 -1
  128. snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
  129. snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
  130. snowflake/ml/modeling/linear_model/ridge.py +9 -1
  131. snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
  132. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
  133. snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
  134. snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
  135. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
  136. snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
  137. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
  138. snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
  139. snowflake/ml/modeling/manifold/isomap.py +9 -1
  140. snowflake/ml/modeling/manifold/mds.py +9 -1
  141. snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
  142. snowflake/ml/modeling/manifold/tsne.py +9 -1
  143. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
  144. snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
  145. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
  146. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
  147. snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
  148. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
  149. snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
  150. snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
  151. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
  152. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
  153. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
  154. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
  155. snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
  156. snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
  157. snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
  158. snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
  159. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
  160. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
  161. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
  162. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
  163. snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
  164. snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
  165. snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
  166. snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
  167. snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
  168. snowflake/ml/modeling/svm/linear_svc.py +9 -1
  169. snowflake/ml/modeling/svm/linear_svr.py +9 -1
  170. snowflake/ml/modeling/svm/nu_svc.py +9 -1
  171. snowflake/ml/modeling/svm/nu_svr.py +9 -1
  172. snowflake/ml/modeling/svm/svc.py +9 -1
  173. snowflake/ml/modeling/svm/svr.py +9 -1
  174. snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
  175. snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
  176. snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
  177. snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
  178. snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
  179. snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
  180. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
  181. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
  182. snowflake/ml/monitoring/explain_visualize.py +286 -0
  183. snowflake/ml/registry/_manager/model_manager.py +23 -2
  184. snowflake/ml/registry/registry.py +10 -9
  185. snowflake/ml/version.py +1 -1
  186. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/METADATA +40 -8
  187. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/RECORD +190 -189
  188. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/WHEEL +1 -1
  189. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/licenses/LICENSE.txt +0 -0
  190. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.4.dist-info}/top_level.txt +0 -0

snowflake/ml/model/_model_composer/model_composer.py

@@ -142,28 +142,51 @@ class ModelComposer:
         conda_dep_dict = env_utils.validate_conda_dependency_string_list(
             conda_dependencies if conda_dependencies else []
         )
-        is_warehouse_runnable = (
-            not conda_dep_dict
-            or all(
-                chan == env_utils.DEFAULT_CHANNEL_NAME or chan == env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
-                for chan in conda_dep_dict
-            )
-        ) and (not pip_requirements)
-        disable_explainability = (
-            target_platforms and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
-        ) or (not is_warehouse_runnable)
-
-        if disable_explainability and options and options.get("enable_explainability", False):
-            warnings.warn(
-                ("The model can be deployed to Snowpark Container Services only if `enable_explainability=False`."),
-                category=UserWarning,
-                stacklevel=2,
+
+        enable_explainability = None
+
+        if options:
+            enable_explainability = options.get("enable_explainability", None)
+
+        # skip everything if user said False explicitly
+        if enable_explainability is None or enable_explainability is True:
+            is_warehouse_runnable = (
+                not conda_dep_dict
+                or all(
+                    chan == env_utils.DEFAULT_CHANNEL_NAME or chan == env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
+                    for chan in conda_dep_dict
+                )
+            ) and (not pip_requirements)
+
+            only_spcs = (
+                target_platforms
+                and len(target_platforms) == 1
+                and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
             )
+            if only_spcs or (not is_warehouse_runnable):
+                # if only SPCS and user asked for explainability we fail
+                if enable_explainability is True:
+                    raise ValueError(
+                        "`enable_explainability` cannot be set to True when the model is not runnable in WH "
+                        "or the target platforms include SPCS."
+                    )
+                elif not options:  # explicitly set flag to false in these cases if not specified
+                    options = model_types.BaseModelSaveOption()
+                    options["enable_explainability"] = False
+            elif (
+                target_platforms
+                and len(target_platforms) > 1
+                and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
+            ):  # if both then only available for WH
+                if enable_explainability is True:
+                    warnings.warn(
+                        ("Explain function will only be available for model deployed to warehouse."),
+                        category=UserWarning,
+                        stacklevel=2,
+                    )
 
         if not options:
             options = model_types.BaseModelSaveOption()
-        if disable_explainability:
-            options["enable_explainability"] = False
 
         if not snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
             snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
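
Net effect of this hunk: requesting explanations for a model that cannot run in a warehouse is now a hard error instead of a warning, and the flag is only defaulted to False when the user left it unset. A minimal sketch of the new log-time behavior, assuming an existing Registry instance `reg` and a fitted model `clf` (both names illustrative):

    from snowflake.ml.model import type_hints as model_types

    # SPCS-only deployment plus explicit explainability now raises in 1.8.4.
    try:
        reg.log_model(
            clf,
            model_name="my_model",
            version_name="v1",
            target_platforms=[model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES],
            options={"enable_explainability": True},
        )
    except ValueError as e:
        print(e)  # "`enable_explainability` cannot be set to True when ..."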

snowflake/ml/model/_packager/model_handlers/_utils.py

@@ -109,6 +109,35 @@ def get_input_signature(
     return input_sig
 
 
+def add_inferred_explain_method_signature(
+    model_meta: model_meta.ModelMetadata,
+    explain_method: str,
+    target_method: str,
+    background_data: model_types.SupportedDataType,
+    explain_fn: Callable[[model_types.SupportedLocalDataType], model_types.SupportedLocalDataType],
+    output_feature_names: Optional[Sequence[str]] = None,
+) -> model_meta.ModelMetadata:
+    inputs = get_input_signature(model_meta, target_method)
+    if output_feature_names is None:  # If not provided, assume output feature names are the same as input feature names
+        output_feature_names = [spec.name for spec in inputs]
+
+    if model_meta.model_type == "snowml":
+        suffixed_output_names = [identifier.concat_names([name, "_explanation"]) for name in output_feature_names]
+    else:
+        suffixed_output_names = [f"{name}_explanation" for name in output_feature_names]
+
+    truncated_background_data = get_truncated_sample_data(background_data, 5)
+    sig = model_signature.infer_signature(
+        input_data=truncated_background_data,
+        output_data=explain_fn(truncated_background_data),
+        input_feature_names=[spec.name for spec in inputs],
+        output_feature_names=suffixed_output_names,
+    )
+
+    model_meta.signatures[explain_method] = sig
+    return model_meta
+
+
 def add_explain_method_signature(
     model_meta: model_meta.ModelMetadata,
     explain_method: str,
@@ -236,8 +265,9 @@ def validate_model_task(passed_model_task: model_types.Task, inferred_model_task
 def get_explain_target_method(
     model_metadata: model_meta.ModelMetadata, target_methods_list: list[str]
 ) -> Optional[str]:
-    for method in model_metadata.signatures.keys():
-        if method in target_methods_list:
+    """Returns the first target method that is found in the model metadata signatures."""
+    for method in target_methods_list:
+        if method in model_metadata.signatures.keys():
             return method
     return None
 
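
The reordered loop makes `target_methods_list` the priority order, rather than whatever insertion order the signatures dict happens to have. A tiny self-contained illustration of the difference:

    signatures = {"predict": ..., "predict_proba": ...}  # dict order: predict first
    priority = ["predict_proba", "predict"]              # caller's preference

    old = next((m for m in signatures if m in priority), None)   # -> "predict"
    new = next((m for m in priority if m in signatures), None)   # -> "predict_proba"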

snowflake/ml/model/_packager/model_handlers/custom.py

@@ -72,7 +72,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
             predictions_df = target_method(model, sample_input_data)
             return predictions_df
 
-        for func_name in model._get_partitioned_infer_methods():
+        for func_name in model._get_partitioned_methods():
             function_properties = model_meta.function_properties.get(func_name, {})
             function_properties[model_meta_schema.FunctionProperties.PARTITIONED.value] = True
             model_meta.function_properties[func_name] = function_properties

snowflake/ml/model/_packager/model_handlers/pytorch.py

@@ -82,6 +82,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
         enable_explainability = kwargs.get("enable_explainability", False)
         if enable_explainability:
             raise NotImplementedError("Explainability is not supported for PyTorch model.")
+        multiple_inputs = kwargs.get("multiple_inputs", False)
 
         import torch
 
@@ -94,8 +95,6 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
             default_target_methods=cls.DEFAULT_TARGET_METHODS,
         )
 
-        multiple_inputs = kwargs.get("multiple_inputs", False)
-
         def get_prediction(
             target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
         ) -> model_types.SupportedLocalDataType:

snowflake/ml/model/_packager/model_handlers/sklearn.py

@@ -1,6 +1,6 @@
 import os
 import warnings
-from typing import TYPE_CHECKING, Callable, Optional, Union, cast, final
+from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union, cast, final
 
 import cloudpickle
 import numpy as np
@@ -38,6 +38,35 @@ def _unpack_container_runtime_pipeline(model: "sklearn.pipeline.Pipeline") -> "s
     return model
 
 
+def _apply_transforms_up_to_last_step(
+    model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
+    data: model_types.SupportedDataType,
+    input_feature_names: Optional[list[str]] = None,
+) -> pd.DataFrame:
+    """Apply all transformations in the sklearn pipeline model up to the last step."""
+    transformed_data = data
+    output_features_names = input_feature_names
+
+    if type_utils.LazyType("sklearn.pipeline.Pipeline").isinstance(model):
+        for step_name, step in model.steps[:-1]:  # type: ignore[attr-defined]
+            if not hasattr(step, "transform"):
+                raise ValueError(f"Step '{step_name}' does not have a 'transform' method.")
+            transformed_data = step.transform(transformed_data)
+            if output_features_names is None:
+                continue
+            elif hasattr(step, "get_feature_names_out"):
+                output_features_names = step.get_feature_names_out(output_features_names)
+            else:
+                raise ValueError(
+                    f"Step '{step_name}' in the pipeline does not have a 'get_feature_names_out' method. "
+                    "Feature names cannot be propagated."
+                )
+    if type_utils.LazyType("scipy.sparse.csr_matrix").isinstance(transformed_data):
+        # Convert to dense array if it's a sparse matrix
+        transformed_data = transformed_data.toarray()  # type: ignore[attr-defined]
+    return pd.DataFrame(transformed_data, columns=output_features_names)
+
+
 @final
 class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]]):
     """Handler for scikit-learn based model.
@@ -58,7 +87,9 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
         "decision_function",
         "score_samples",
     ]
-    EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"]
+
+    # Prioritize predict_proba as it gives multi-class probabilities
+    EXPLAIN_TARGET_METHODS = ["predict_proba", "predict", "predict_log_proba"]
 
     @classmethod
     def can_handle(
@@ -160,17 +191,38 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
                     stacklevel=1,
                 )
                 enable_explainability = False
-            elif model_meta.task == model_types.Task.UNKNOWN or explain_target_method is None:
+            elif model_meta.task == model_types.Task.UNKNOWN:
+                enable_explainability = False
+            elif explain_target_method is None:
                 enable_explainability = False
             else:
                 enable_explainability = True
             if enable_explainability:
-                model_meta = handlers_utils.add_explain_method_signature(
-                    model_meta=model_meta,
-                    explain_method="explain",
-                    target_method=explain_target_method,
-                    output_return_type=model_task_and_output_type.output_type,
+                explain_target_method = str(explain_target_method)  # mypy complains if we don't cast to str here
+
+                input_signature = handlers_utils.get_input_signature(model_meta, explain_target_method)
+                transformed_background_data = _apply_transforms_up_to_last_step(
+                    model=model,
+                    data=background_data,
+                    input_feature_names=[spec.name for spec in input_signature],
                 )
+
+                try:
+                    model_meta = handlers_utils.add_inferred_explain_method_signature(
+                        model_meta=model_meta,
+                        explain_method="explain",
+                        target_method=explain_target_method,
+                        background_data=background_data,
+                        explain_fn=cls._build_explain_fn(model, background_data, input_signature),
+                        output_feature_names=transformed_background_data.columns,
+                    )
+                except ValueError:
+                    if kwargs.get("enable_explainability", None):
+                        # user explicitly enabled explainability, so we should raise the error
+                        raise ValueError(
+                            "Explainability for this model is not supported. Please set `enable_explainability=False`"
+                        )
+
                 handlers_utils.save_background_data(
                     model_blobs_dir_path,
                     cls.EXPLAIN_ARTIFACTS_DIR,
@@ -222,11 +274,13 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
             )
 
         if enable_explainability:
-            model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
+            model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap>=0.46.0", pip_name="shap")])
             model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
 
         model_meta.env.include_if_absent(
-            [model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn")],
+            [
+                model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn"),
+            ],
             check_local_version=True,
         )
 
@@ -286,37 +340,8 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
 
         @custom_model.inference_api
         def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
-            import shap
-
-            try:
-                explainer = shap.Explainer(raw_model, background_data)
-                df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
-            except TypeError:
-                try:
-                    dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs}
-
-                    if isinstance(X, pd.DataFrame):
-                        X = X.astype(dtype_map, copy=False)
-                    if hasattr(raw_model, "predict_proba"):
-                        if isinstance(X, np.ndarray):
-                            explanations = shap.Explainer(
-                                raw_model.predict_proba, background_data.values  # type: ignore[union-attr]
-                            )(X).values
-                        else:
-                            explanations = shap.Explainer(raw_model.predict_proba, background_data)(X).values
-                    elif hasattr(raw_model, "predict"):
-                        if isinstance(X, np.ndarray):
-                            explanations = shap.Explainer(
-                                raw_model.predict, background_data.values  # type: ignore[union-attr]
-                            )(X).values
-                        else:
-                            explanations = shap.Explainer(raw_model.predict, background_data)(X).values
-                    else:
-                        raise ValueError("Missing any supported target method to explain.")
-                    df = handlers_utils.convert_explanations_to_2D_df(raw_model, explanations)
-                except TypeError as e:
-                    raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
-            return model_signature_utils.rename_pandas_df(df, signature.outputs)
+            fn = cls._build_explain_fn(raw_model, background_data, signature.inputs)
+            return model_signature_utils.rename_pandas_df(fn(X), signature.outputs)
 
         if target_method == "explain":
             return explain_fn
@@ -339,3 +364,37 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
         skl_model = _SKLModel(custom_model.ModelContext())
 
         return skl_model
+
+    @classmethod
+    def _build_explain_fn(
+        cls,
+        model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
+        background_data: model_types.SupportedDataType,
+        input_specs: Sequence[model_signature.BaseFeatureSpec],
+    ) -> Callable[[model_types.SupportedDataType], pd.DataFrame]:
+        import shap
+        import sklearn.pipeline
+
+        transformed_bg_data = _apply_transforms_up_to_last_step(model, background_data)
+
+        def explain_fn(data: model_types.SupportedDataType) -> pd.DataFrame:
+            transformed_data = _apply_transforms_up_to_last_step(model, data)
+            predictor = model[-1] if isinstance(model, sklearn.pipeline.Pipeline) else model
+            try:
+                explainer = shap.Explainer(predictor, transformed_bg_data)
+                return handlers_utils.convert_explanations_to_2D_df(model, explainer(transformed_data).values)
+            except TypeError:
+                if isinstance(data, pd.DataFrame):
+                    dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in input_specs}
+                    transformed_data = _apply_transforms_up_to_last_step(model, data.astype(dtype_map))
+                for explain_target_method in cls.EXPLAIN_TARGET_METHODS:
+                    if not hasattr(predictor, explain_target_method):
+                        continue
+                    explain_target_method_fn = getattr(predictor, explain_target_method)
+                    explanations = shap.Explainer(explain_target_method_fn, transformed_bg_data.values)(
+                        transformed_data.to_numpy()
+                    ).values
+                    return handlers_utils.convert_explanations_to_2D_df(model, explanations)
+                raise ValueError("Missing any supported target method to explain.")
+
+        return explain_fn
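
`_build_explain_fn` folds the old inline try/except into one place: first attempt the estimator-based `shap.Explainer`, and if SHAP raises `TypeError`, retry with a prediction-method wrapper picked from `EXPLAIN_TARGET_METHODS`. A rough standalone sketch of that fallback pattern (not the handler's exact code; assumes `shap` and scikit-learn are installed):

    import numpy as np
    import shap
    from sklearn.linear_model import LogisticRegression

    rng = np.random.RandomState(0)
    X_bg = rng.normal(size=(50, 3))
    model = LogisticRegression().fit(X_bg, (X_bg.sum(axis=1) > 0).astype(int))

    try:
        explainer = shap.Explainer(model, X_bg)                 # estimator-based
    except TypeError:
        explainer = shap.Explainer(model.predict_proba, X_bg)   # function-based fallback

    print(explainer(X_bg[:5]).values.shape)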

snowflake/ml/model/_packager/model_handlers/tensorflow.py

@@ -88,6 +88,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         import tensorflow
 
         assert isinstance(model, tensorflow.Module)
+        multiple_inputs = kwargs.get("multiple_inputs", False)
 
         is_keras_model = type_utils.LazyType("keras.Model").isinstance(model)
         is_tf_keras_model = type_utils.LazyType("tf_keras.Model").isinstance(model)
@@ -112,8 +113,6 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
             default_target_methods=default_target_methods,
         )
 
-        multiple_inputs = kwargs.get("multiple_inputs", False)
-
         if is_keras_model and len(target_methods) > 1:
             raise ValueError("Keras model can only have one target method.")
 
@@ -198,7 +197,6 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         model_blobs_dir_path: str,
         **kwargs: Unpack[model_types.TensorflowLoadOptions],
     ) -> "tensorflow.Module":
-        os.environ["TF_USE_LEGACY_KERAS"] = "1"
        import tensorflow
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
@@ -209,7 +207,12 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         load_path = os.path.join(model_blob_path, model_blob_filename)
         save_format = model_blob_options.get("save_format", "keras_tf")
         if save_format == "keras_tf":
-            m = tensorflow.keras.models.load_model(load_path)
+            if version.parse(tensorflow.keras.__version__) >= version.parse("3.0.0"):
+                import tf_keras
+
+                m = tf_keras.models.load_model(load_path)
+            else:
+                m = tensorflow.keras.models.load_model(load_path)
         else:
             m = tensorflow.saved_model.load(load_path)
 

snowflake/ml/model/_packager/model_handlers/torchscript.py

@@ -76,6 +76,8 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
         if enable_explainability:
             raise NotImplementedError("Explainability is not supported for Torch Script model.")
 
+        multiple_inputs = kwargs.get("multiple_inputs", False)
+
         import torch
 
         assert isinstance(model, torch.jit.ScriptModule)
@@ -87,8 +89,6 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
             default_target_methods=cls.DEFAULT_TARGET_METHODS,
         )
 
-        multiple_inputs = kwargs.get("multiple_inputs", False)
-
         def get_prediction(
             target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
         ) -> model_types.SupportedLocalDataType:

snowflake/ml/model/_packager/model_handlers/xgboost.py

@@ -144,7 +144,12 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
             path=cls.MODEL_BLOB_FILE_OR_DIR,
-            options=model_meta_schema.XgboostModelBlobOptions({"xgb_estimator_type": model.__class__.__name__}),
+            options=model_meta_schema.XgboostModelBlobOptions(
+                {
+                    "xgb_estimator_type": model.__class__.__name__,
+                    "enable_categorical": getattr(model, "enable_categorical", False),
+                }
+            ),
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -152,11 +157,6 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
         model_meta.env.include_if_absent(
             [
                 model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn"),
-            ],
-            check_local_version=True,
-        )
-        model_meta.env.include_if_absent(
-            [
                 model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
             ],
             check_local_version=True,
@@ -190,6 +190,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
             raise ValueError("Type of XGB estimator is illegal.")
         m = getattr(xgboost, xgb_estimator_type)()
         m.load_model(os.path.join(model_blob_path, model_blob_filename))
+        m.enable_categorical = model_blob_options.get("enable_categorical", False)
 
         if kwargs.get("use_gpu", False):
             assert type(kwargs.get("use_gpu", False)) == bool
@@ -225,8 +226,16 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
         ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
             @custom_model.inference_api
             def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+                enable_categorical = False
+                for col, d_type in X.dtypes.items():
+                    if pd.api.extensions.ExtensionDtype.is_dtype(d_type):
+                        continue
+                    if not np.issubdtype(d_type, np.number):
+                        # categorical columns are converted to numpy's str dtype
+                        X[col] = X[col].astype("category")
+                        enable_categorical = True
                 if isinstance(raw_model, xgboost.Booster):
-                    X = xgboost.DMatrix(X)
+                    X = xgboost.DMatrix(X, enable_categorical=enable_categorical)
 
                 res = getattr(raw_model, target_method)(X)
 
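
The same dtype scan appears both here and in the DMatrix signature handler further below: non-numeric, non-extension columns are coerced to pandas `category` dtype and `enable_categorical` is forwarded to `xgboost.DMatrix`. A runnable sketch with illustrative data:

    import numpy as np
    import pandas as pd
    import xgboost as xgb

    df = pd.DataFrame({"num": [1.0, 2.0, 3.0], "color": ["red", "green", "red"]})

    enable_categorical = False
    for col, d_type in df.dtypes.items():
        if pd.api.extensions.ExtensionDtype.is_dtype(d_type):
            continue  # already a pandas extension dtype (e.g. category)
        if not np.issubdtype(d_type, np.number):
            df[col] = df[col].astype("category")
            enable_categorical = True

    dmat = xgb.DMatrix(df, enable_categorical=enable_categorical)
    print(dmat.feature_types)  # 'color' is treated as categorical ('c')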

snowflake/ml/model/_packager/model_meta/model_meta.py

@@ -65,7 +65,8 @@ def create_model_metadata(
         ext_modules: List of names of modules that need to be pickled with the model. Defaults to None.
         conda_dependencies: List of conda requirements for running the model. Defaults to None.
         pip_requirements: List of pip Python packages requirements for running the model. Defaults to None.
-        artifact_repository_map: A dict mapping from package channel to artifact repository name.
+        artifact_repository_map: A dict mapping from package channel to artifact repository name (e.g.
+            {'pip': 'snowflake.snowpark.pypi_shared_repository'}).
         resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
         target_platforms: List of target platforms to run the model.
         python_version: A string of python version where model is run. Used for user override. If specified as None,

snowflake/ml/model/_packager/model_meta/model_meta_schema.py

@@ -63,6 +63,7 @@ class MLFlowModelBlobOptions(BaseModelBlobOptions):
 
 class XgboostModelBlobOptions(BaseModelBlobOptions):
     xgb_estimator_type: Required[str]
+    enable_categorical: NotRequired[bool]
 
 
 class PyTorchModelBlobOptions(BaseModelBlobOptions):

snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py

@@ -6,7 +6,7 @@ REQUIREMENTS = [
     "aiohttp!=4.0.0a0, !=4.0.0a1",
     "anyio>=3.5.0,<5",
     "cachetools>=3.1.1,<6",
-    "cloudpickle>=2.0.0,<3",
+    "cloudpickle>=2.0.0",
     "cryptography",
     "fsspec>=2024.6.1,<2026",
     "importlib_resources>=6.1.1, <7",
@@ -21,12 +21,12 @@ REQUIREMENTS = [
     "requests",
     "retrying>=1.3.3,<2",
     "s3fs>=2024.6.1,<2026",
-    "scikit-learn>=1.4,<1.6",
+    "scikit-learn<1.6",
     "scipy>=1.9,<2",
-    "snowflake-connector-python>=3.12.0,<4",
+    "shap>=0.46.0,<1",
+    "snowflake-connector-python>=3.14.0,<4",
     "snowflake-snowpark-python>=1.17.0,<2,!=1.26.0",
     "snowflake.core>=1.0.2,<2",
     "sqlparse>=0.4,<1",
     "typing-extensions>=4.1.0,<5",
-    "xgboost>=1.7.3,<3",
 ]

snowflake/ml/model/_signatures/dmatrix_handler.py

@@ -81,8 +81,16 @@ class XGBoostDMatrixHandler(base_handler.BaseDataHandler["xgboost.DMatrix"]):
     ) -> "xgboost.DMatrix":
         import xgboost as xgb
 
+        enable_categorical = False
+        for col, d_type in df.dtypes.items():
+            if pd.api.extensions.ExtensionDtype.is_dtype(d_type):
+                continue
+            if not np.issubdtype(d_type, np.number):
+                df[col] = df[col].astype("category")
+                enable_categorical = True
+
         if not features:
-            return xgb.DMatrix(df)
+            return xgb.DMatrix(df, enable_categorical=enable_categorical)
         else:
             feature_names = []
             feature_types = []
@@ -95,4 +103,9 @@ class XGBoostDMatrixHandler(base_handler.BaseDataHandler["xgboost.DMatrix"]):
                 assert isinstance(feature, core.FeatureSpec), "Invalid feature kind."
                 feature_names.append(feature.name)
                 feature_types.append(feature._dtype._numpy_type)
-            return xgb.DMatrix(df, feature_names=feature_names, feature_types=feature_types)
+            return xgb.DMatrix(
+                df,
+                feature_names=feature_names,
+                feature_types=feature_types,
+                enable_categorical=enable_categorical,
+            )

snowflake/ml/model/custom_model.py

@@ -4,6 +4,7 @@ from typing import Any, Callable, Coroutine, Generator, Optional, Union
 
 import anyio
 import pandas as pd
+from typing_extensions import deprecated
 
 from snowflake.ml.model import type_hints as model_types
 
@@ -226,12 +227,12 @@ class CustomModel:
         else:
             raise TypeError("A non-method inference API function is not supported.")
 
-    def _get_partitioned_infer_methods(self) -> list[str]:
-        """Returns all methods in CLS with `partitioned_inference_api` as the outermost decorator."""
+    def _get_partitioned_methods(self) -> list[str]:
+        """Returns all methods in CLS with `partitioned_api` as the outermost decorator."""
         rv = []
         for cls_method_str in dir(self):
             cls_method = getattr(self, cls_method_str)
-            if getattr(cls_method, "_is_partitioned_inference_api", False):
+            if getattr(cls_method, "_is_partitioned_api", False):
                 if inspect.ismethod(cls_method):
                     rv.append(cls_method_str)
                 else:
@@ -282,9 +283,21 @@ def inference_api(
     return func
 
 
+def partitioned_api(
+    func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame],
+) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
+    func.__dict__["_is_inference_api"] = True
+    func.__dict__["_is_partitioned_api"] = True
+    return func
+
+
+@deprecated(
+    "snowflake.ml.custom_model.partitioned_inference_api is deprecated and will be removed in a future release."
+    " Use snowflake.ml.custom_model.partitioned_api instead."
+)
 def partitioned_inference_api(
     func: Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame],
 ) -> Callable[[model_types.CustomModelType, pd.DataFrame], pd.DataFrame]:
     func.__dict__["_is_inference_api"] = True
-    func.__dict__["_is_partitioned_inference_api"] = True
+    func.__dict__["_is_partitioned_api"] = True
     return func
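
New code should use the `partitioned_api` name; the old decorator keeps working but now carries a deprecation notice via `typing_extensions.deprecated`. A minimal sketch of a partitioned custom model (class and column names are illustrative):

    import pandas as pd
    from snowflake.ml.model import custom_model

    class MyPartitionedModel(custom_model.CustomModel):
        @custom_model.partitioned_api
        def predict(self, input: pd.DataFrame) -> pd.DataFrame:
            # Invoked once per partition of the input data.
            return pd.DataFrame({"mean_value": [input["value"].mean()]})

    m = MyPartitionedModel(custom_model.ModelContext())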

snowflake/ml/model/model_signature.py

@@ -71,9 +71,9 @@ def _truncate_data(
         warnings.warn(
             formatting.unwrap(
                 f"""
-                The sample input has {row_count} rows, thus a truncation happened before inferring signature.
-                This might cause inaccurate signature inference.
-                If that happens, consider specifying signature manually.
+                The sample input has {row_count} rows. Using the first 100 rows to define the inputs and outputs
+                of the model and the data types of each. Use `signatures` parameter to specify model inputs and
+                outputs manually if the automatic inference is not correct.
                 """
             ),
             category=UserWarning,
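
The reworded warning points users at the `signatures` argument when the 100-row truncation makes automatic inference unreliable. A hedged sketch of supplying a signature manually (assuming a fitted `model`, a sample frame `df`, and a `Registry` instance `reg`, all illustrative):

    from snowflake.ml.model import model_signature

    sig = model_signature.infer_signature(
        input_data=df.head(100),
        output_data=model.predict(df.head(100)),
    )
    # reg.log_model(model, model_name="my_model", version_name="v1",
    #               signatures={"predict": sig})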

snowflake/ml/modeling/calibration/calibrated_classifier_cv.py

@@ -11,7 +11,7 @@ import cloudpickle as cp
 import numpy as np
 import pandas as pd
 from numpy import typing as npt
-
+from packaging import version
 
 import numpy
 import sklearn
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
 INFER_SIGNATURE_MAX_ROWS = 100
 
+SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
+# Modeling library estimators require a smaller sklearn version range.
+if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
+    raise Exception(
+        f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
+    )
+
+
 class CalibratedClassifierCV(BaseTransformer):
     r"""Probability calibration with isotonic regression or logistic regression
     For more details on this class, see [sklearn.calibration.CalibratedClassifierCV]

snowflake/ml/modeling/cluster/affinity_propagation.py

@@ -11,7 +11,7 @@ import cloudpickle as cp
 import numpy as np
 import pandas as pd
 from numpy import typing as npt
-
+from packaging import version
 
 import numpy
 import sklearn
@@ -60,6 +60,14 @@ DATAFRAME_TYPE = Union[DataFrame, pd.DataFrame]
 
 INFER_SIGNATURE_MAX_ROWS = 100
 
+SKLEARN_LOWER, SKLEARN_UPPER = ('1.4', '1.6')
+# Modeling library estimators require a smaller sklearn version range.
+if not version.Version(SKLEARN_LOWER) <= version.Version(sklearn.__version__) < version.Version(SKLEARN_UPPER):
+    raise Exception(
+        f"To use the modeling library, install scikit-learn version >= {SKLEARN_LOWER} and < {SKLEARN_UPPER}"
+    )
+
+
 class AffinityPropagation(BaseTransformer):
     r"""Perform Affinity Propagation Clustering of data
     For more details on this class, see [sklearn.cluster.AffinityPropagation]