snowflake-ml-python 1.8.3__py3-none-any.whl → 1.8.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. snowflake/cortex/__init__.py +7 -1
  2. snowflake/ml/_internal/platform_capabilities.py +13 -11
  3. snowflake/ml/_internal/telemetry.py +42 -13
  4. snowflake/ml/_internal/utils/identifier.py +2 -2
  5. snowflake/ml/data/data_connector.py +1 -1
  6. snowflake/ml/jobs/_utils/constants.py +10 -1
  7. snowflake/ml/jobs/_utils/interop_utils.py +1 -1
  8. snowflake/ml/jobs/_utils/payload_utils.py +51 -34
  9. snowflake/ml/jobs/_utils/scripts/constants.py +6 -0
  10. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
  11. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +86 -3
  12. snowflake/ml/jobs/_utils/spec_utils.py +8 -6
  13. snowflake/ml/jobs/decorators.py +13 -3
  14. snowflake/ml/jobs/job.py +206 -26
  15. snowflake/ml/jobs/manager.py +78 -34
  16. snowflake/ml/model/_client/model/model_version_impl.py +1 -1
  17. snowflake/ml/model/_client/ops/service_ops.py +31 -17
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +351 -170
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +25 -0
  20. snowflake/ml/model/_client/sql/model_version.py +1 -1
  21. snowflake/ml/model/_client/sql/service.py +20 -32
  22. snowflake/ml/model/_model_composer/model_composer.py +44 -19
  23. snowflake/ml/model/_packager/model_handlers/_utils.py +32 -2
  24. snowflake/ml/model/_packager/model_handlers/custom.py +1 -1
  25. snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -2
  26. snowflake/ml/model/_packager/model_handlers/sklearn.py +100 -41
  27. snowflake/ml/model/_packager/model_handlers/tensorflow.py +7 -4
  28. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  29. snowflake/ml/model/_packager/model_handlers/xgboost.py +16 -7
  30. snowflake/ml/model/_packager/model_meta/model_meta.py +2 -1
  31. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
  32. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +5 -4
  33. snowflake/ml/model/_signatures/dmatrix_handler.py +15 -2
  34. snowflake/ml/model/custom_model.py +17 -4
  35. snowflake/ml/model/model_signature.py +3 -3
  36. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
  37. snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
  38. snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
  39. snowflake/ml/modeling/cluster/birch.py +9 -1
  40. snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
  41. snowflake/ml/modeling/cluster/dbscan.py +9 -1
  42. snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
  43. snowflake/ml/modeling/cluster/k_means.py +9 -1
  44. snowflake/ml/modeling/cluster/mean_shift.py +9 -1
  45. snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
  46. snowflake/ml/modeling/cluster/optics.py +9 -1
  47. snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
  48. snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
  49. snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
  50. snowflake/ml/modeling/compose/column_transformer.py +9 -1
  51. snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
  52. snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
  53. snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
  54. snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
  55. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
  56. snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
  57. snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
  58. snowflake/ml/modeling/covariance/oas.py +9 -1
  59. snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
  60. snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
  61. snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
  62. snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
  63. snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
  64. snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
  65. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
  66. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
  67. snowflake/ml/modeling/decomposition/pca.py +9 -1
  68. snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
  69. snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
  70. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
  71. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
  72. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
  73. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
  74. snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
  75. snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
  76. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
  77. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
  78. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
  79. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
  80. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
  81. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
  82. snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
  83. snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
  84. snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
  85. snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
  86. snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
  87. snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
  88. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
  89. snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
  90. snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
  91. snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
  92. snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
  93. snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
  94. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
  95. snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
  96. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
  97. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
  98. snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
  99. snowflake/ml/modeling/impute/knn_imputer.py +9 -1
  100. snowflake/ml/modeling/impute/missing_indicator.py +9 -1
  101. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
  102. snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
  103. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
  104. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
  105. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
  106. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
  107. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
  108. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
  109. snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
  110. snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
  111. snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
  112. snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
  113. snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
  114. snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
  115. snowflake/ml/modeling/linear_model/lars.py +9 -1
  116. snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
  117. snowflake/ml/modeling/linear_model/lasso.py +9 -1
  118. snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
  119. snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
  120. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
  121. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
  122. snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
  123. snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
  124. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
  125. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
  126. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
  127. snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
  128. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
  129. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
  130. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
  131. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
  132. snowflake/ml/modeling/linear_model/perceptron.py +9 -1
  133. snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
  134. snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
  135. snowflake/ml/modeling/linear_model/ridge.py +9 -1
  136. snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
  137. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
  138. snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
  139. snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
  140. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
  141. snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
  142. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
  143. snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
  144. snowflake/ml/modeling/manifold/isomap.py +9 -1
  145. snowflake/ml/modeling/manifold/mds.py +9 -1
  146. snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
  147. snowflake/ml/modeling/manifold/tsne.py +9 -1
  148. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
  149. snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
  150. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
  151. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
  152. snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
  153. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
  154. snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
  155. snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
  156. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
  157. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
  158. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
  159. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
  160. snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
  161. snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
  162. snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
  163. snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
  164. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
  165. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
  166. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
  167. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
  168. snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
  169. snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
  170. snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
  171. snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
  172. snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
  173. snowflake/ml/modeling/svm/linear_svc.py +9 -1
  174. snowflake/ml/modeling/svm/linear_svr.py +9 -1
  175. snowflake/ml/modeling/svm/nu_svc.py +9 -1
  176. snowflake/ml/modeling/svm/nu_svr.py +9 -1
  177. snowflake/ml/modeling/svm/svc.py +9 -1
  178. snowflake/ml/modeling/svm/svr.py +9 -1
  179. snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
  180. snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
  181. snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
  182. snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
  183. snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
  184. snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
  185. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
  186. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
  187. snowflake/ml/monitoring/explain_visualize.py +424 -0
  188. snowflake/ml/registry/_manager/model_manager.py +23 -2
  189. snowflake/ml/registry/registry.py +10 -9
  190. snowflake/ml/utils/connection_params.py +8 -2
  191. snowflake/ml/version.py +1 -1
  192. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/METADATA +58 -8
  193. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/RECORD +196 -195
  194. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/WHEEL +1 -1
  195. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/licenses/LICENSE.txt +0 -0
  196. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/top_level.txt +0 -0
@@ -1,10 +1,8 @@
 import enum
-import json
 import textwrap
 from typing import Any, Optional, Union
 
 from snowflake import snowpark
-from snowflake.ml._internal import platform_capabilities
 from snowflake.ml._internal.utils import (
     identifier,
     query_result_checker,
@@ -16,22 +14,25 @@ from snowflake.snowpark import dataframe, functions as F, row, types as spt
 from snowflake.snowpark._internal import utils as snowpark_utils
 
 
+# The enum comes from https://docs.snowflake.com/en/sql-reference/sql/show-service-containers-in-service#output
+# except UNKNOWN
 class ServiceStatus(enum.Enum):
     UNKNOWN = "UNKNOWN"  # status is unknown because we have not received enough data from K8s yet.
     PENDING = "PENDING"  # resource set is being created, can't be used yet
-    READY = "READY"  # resource set has been deployed.
     SUSPENDING = "SUSPENDING"  # the service is set to suspended but the resource set is still in deleting state
     SUSPENDED = "SUSPENDED"  # the service is suspended and the resource set is deleted
     DELETING = "DELETING"  # resource set is being deleted
     FAILED = "FAILED"  # resource set has failed and cannot be used anymore
     DONE = "DONE"  # resource set has finished running
-    NOT_FOUND = "NOT_FOUND"  # not found or deleted
     INTERNAL_ERROR = "INTERNAL_ERROR"  # there was an internal service error.
+    RUNNING = "RUNNING"
+    DELETED = "DELETED"
 
 
 class ServiceSQLClient(_base._BaseSQLClient):
     MODEL_INFERENCE_SERVICE_ENDPOINT_NAME_COL_NAME = "name"
     MODEL_INFERENCE_SERVICE_ENDPOINT_INGRESS_URL_COL_NAME = "ingress_url"
+    SERVICE_STATUS = "service_status"
 
     def build_model_container(
         self,
@@ -133,18 +134,10 @@ class ServiceSQLClient(_base._BaseSQLClient):
         input_args_sql = ", ".join(f"'{arg}', {arg.identifier()}" for arg in input_args)
         args_sql = f"object_construct_keep_null({input_args_sql})"
 
-        if platform_capabilities.PlatformCapabilities.get_instance().is_nested_function_enabled():
-            fully_qualified_service_name = self.fully_qualified_object_name(
-                actual_database_name, actual_schema_name, service_name
-            )
-            fully_qualified_function_name = f"{fully_qualified_service_name}!{method_name.identifier()}"
-        else:
-            function_name = identifier.concat_names([service_name.identifier(), "_", method_name.identifier()])
-            fully_qualified_function_name = identifier.get_schema_level_object_identifier(
-                actual_database_name.identifier(),
-                actual_schema_name.identifier(),
-                function_name,
-            )
+        fully_qualified_service_name = self.fully_qualified_object_name(
+            actual_database_name, actual_schema_name, service_name
+        )
+        fully_qualified_function_name = f"{fully_qualified_service_name}!{method_name.identifier()}"
 
         sql = textwrap.dedent(
             f"""{with_sql}
@@ -208,22 +201,17 @@ class ServiceSQLClient(_base._BaseSQLClient):
         include_message: bool = False,
         statement_params: Optional[dict[str, Any]] = None,
     ) -> tuple[ServiceStatus, Optional[str]]:
-        system_func = "SYSTEM$GET_SERVICE_STATUS"
-        rows = (
-            query_result_checker.SqlResultValidator(
-                self._session,
-                f"CALL {system_func}('{self.fully_qualified_object_name(database_name, schema_name, service_name)}')",
-                statement_params=statement_params,
-            )
-            .has_dimensions(expected_rows=1, expected_cols=1)
-            .validate()
-        )
-        metadata = json.loads(rows[0][system_func])[0]
-        if metadata and metadata["status"]:
-            service_status = ServiceStatus(metadata["status"])
-            message = metadata["message"] if include_message else None
-            return service_status, message
-        return ServiceStatus.UNKNOWN, None
+        fully_qualified_object_name = self.fully_qualified_object_name(database_name, schema_name, service_name)
+        query = f"SHOW SERVICE CONTAINERS IN SERVICE {fully_qualified_object_name}"
+        rows = self._session.sql(query).collect(statement_params=statement_params)
+        if len(rows) == 0:
+            return ServiceStatus.UNKNOWN, None
+        row = rows[0]
+        service_status = row[ServiceSQLClient.SERVICE_STATUS]
+        message = row["message"] if include_message else None
+        if not isinstance(service_status, ServiceStatus):
+            return ServiceStatus.UNKNOWN, message
+        return ServiceStatus(service_status), message
 
     def drop_service(
         self,
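Status polling switches from `SYSTEM$GET_SERVICE_STATUS` (parsed out of a JSON blob) to `SHOW SERVICE CONTAINERS IN SERVICE`, whose result set carries a `service_status` column; the enum gains the documented RUNNING and DELETED values to match. A standalone sketch of the new query path, assuming an existing Snowpark `session` and a hypothetical service name:

    # Sketch only: mirrors the new status check, not the library's public API.
    rows = session.sql(
        "SHOW SERVICE CONTAINERS IN SERVICE MY_DB.MY_SCHEMA.MY_SERVICE"
    ).collect()
    # An empty result means the status is unknown; otherwise read the
    # service_status column (e.g. PENDING, RUNNING, FAILED, DONE).
    status = rows[0]["service_status"] if rows else "UNKNOWN"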
@@ -142,30 +142,55 @@ class ModelComposer:
         conda_dep_dict = env_utils.validate_conda_dependency_string_list(
             conda_dependencies if conda_dependencies else []
         )
-        is_warehouse_runnable = (
-            not conda_dep_dict
-            or all(
-                chan == env_utils.DEFAULT_CHANNEL_NAME or chan == env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
-                for chan in conda_dep_dict
-            )
-        ) and (not pip_requirements)
-        disable_explainability = (
-            target_platforms and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
-        ) or (not is_warehouse_runnable)
-
-        if disable_explainability and options and options.get("enable_explainability", False):
-            warnings.warn(
-                ("The model can be deployed to Snowpark Container Services only if `enable_explainability=False`."),
-                category=UserWarning,
-                stacklevel=2,
+
+        enable_explainability = None
+
+        if options:
+            enable_explainability = options.get("enable_explainability", None)
+
+        # skip everything if user said False explicitly
+        if enable_explainability is None or enable_explainability is True:
+            is_warehouse_runnable = (
+                not conda_dep_dict
+                or all(
+                    chan == env_utils.DEFAULT_CHANNEL_NAME or chan == env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
+                    for chan in conda_dep_dict
+                )
+            ) and (not pip_requirements)
+
+            only_spcs = (
+                target_platforms
+                and len(target_platforms) == 1
+                and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
             )
+            if only_spcs or (not is_warehouse_runnable):
+                # if only SPCS and user asked for explainability we fail
+                if enable_explainability is True:
+                    raise ValueError(
+                        "`enable_explainability` cannot be set to True when the model is not runnable in WH "
+                        "or the target platforms include SPCS."
+                    )
+                elif not options:  # explicitly set flag to false in these cases if not specified
+                    options = model_types.BaseModelSaveOption()
+                    options["enable_explainability"] = False
+            elif (
+                target_platforms
+                and len(target_platforms) > 1
+                and model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES in target_platforms
+            ):  # if both then only available for WH
+                if enable_explainability is True:
+                    warnings.warn(
+                        ("Explain function will only be available for model deployed to warehouse."),
+                        category=UserWarning,
+                        stacklevel=2,
+                    )
 
         if not options:
             options = model_types.BaseModelSaveOption()
-        if disable_explainability:
-            options["enable_explainability"] = False
 
-        if not snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
+        if not snowpark_utils.is_in_stored_procedure() and target_platforms != [  # type: ignore[no-untyped-call]
+            model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES  # no information schema check for SPCS-only models
+        ]:
            snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
                 self.session,
                 reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_version.VERSION}")],
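The composer now resolves `enable_explainability` in three tiers: an explicit `True` with an SPCS-only target (or a model that cannot run in a warehouse) raises, an unset value in those cases is quietly defaulted to `False`, and mixed warehouse-plus-SPCS targets with `True` only warn. A condensed, self-contained sketch of that decision table (simplified types, not the library's code):

    # Simplified sketch of the new gating; platform names are plain strings here.
    def resolve_explainability(requested, warehouse_runnable, target_platforms):
        only_spcs = target_platforms == ["SNOWPARK_CONTAINER_SERVICES"]
        if requested is False:
            return False  # user opted out; skip all checks
        if only_spcs or not warehouse_runnable:
            if requested is True:
                raise ValueError("explainability requires a warehouse-runnable model")
            return False  # default quietly to disabled
        return requested  # None (decide later) or True, possibly with a warning

    resolve_explainability(None, True, ["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"])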
@@ -109,6 +109,35 @@ def get_input_signature(
     return input_sig
 
 
+def add_inferred_explain_method_signature(
+    model_meta: model_meta.ModelMetadata,
+    explain_method: str,
+    target_method: str,
+    background_data: model_types.SupportedDataType,
+    explain_fn: Callable[[model_types.SupportedLocalDataType], model_types.SupportedLocalDataType],
+    output_feature_names: Optional[Sequence[str]] = None,
+) -> model_meta.ModelMetadata:
+    inputs = get_input_signature(model_meta, target_method)
+    if output_feature_names is None:  # If not provided, assume output feature names are the same as input feature names
+        output_feature_names = [spec.name for spec in inputs]
+
+    if model_meta.model_type == "snowml":
+        suffixed_output_names = [identifier.concat_names([name, "_explanation"]) for name in output_feature_names]
+    else:
+        suffixed_output_names = [f"{name}_explanation" for name in output_feature_names]
+
+    truncated_background_data = get_truncated_sample_data(background_data, 5)
+    sig = model_signature.infer_signature(
+        input_data=truncated_background_data,
+        output_data=explain_fn(truncated_background_data),
+        input_feature_names=[spec.name for spec in inputs],
+        output_feature_names=suffixed_output_names,
+    )
+
+    model_meta.signatures[explain_method] = sig
+    return model_meta
+
+
 def add_explain_method_signature(
     model_meta: model_meta.ModelMetadata,
     explain_method: str,
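Instead of deriving the explain signature purely from a declared output type, the new helper actually runs the explain function on a truncated background sample and infers the signature from the result, suffixing each output with `_explanation`. The idea in miniature (pandas-only sketch; helper name is illustrative):

    import pandas as pd

    # Run the explain function on a small background sample and derive
    # "<column>_explanation" output names, as the new helper does.
    def infer_explain_names(background: pd.DataFrame, explain_fn) -> list:
        sample = background.head(5)  # mirrors get_truncated_sample_data(data, 5)
        outputs = explain_fn(sample)
        return [f"{col}_explanation" for col in outputs.columns]

    bg = pd.DataFrame({"f1": [1.0, 2.0, 3.0], "f2": [4.0, 5.0, 6.0]})
    print(infer_explain_names(bg, lambda df: df * 0.0))  # ['f1_explanation', 'f2_explanation']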
@@ -236,8 +265,9 @@ def validate_model_task(passed_model_task: model_types.Task, inferred_model_task
 def get_explain_target_method(
     model_metadata: model_meta.ModelMetadata, target_methods_list: list[str]
 ) -> Optional[str]:
-    for method in model_metadata.signatures.keys():
-        if method in target_methods_list:
+    """Returns the first target method that is found in the model metadata signatures."""
+    for method in target_methods_list:
+        if method in model_metadata.signatures.keys():
             return method
     return None
 
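The loop inversion is a behavioral fix, not a cleanup: iteration now follows the priority order of `target_methods_list` instead of the insertion order of the signatures dict, so `predict_proba` can win over `predict` (see the reordered `EXPLAIN_TARGET_METHODS` further below). A two-line demonstration of the difference:

    signatures = {"predict": object(), "predict_proba": object()}  # dict insertion order
    priority = ["predict_proba", "predict", "predict_log_proba"]

    old = next((m for m in signatures if m in priority), None)  # -> "predict"
    new = next((m for m in priority if m in signatures), None)  # -> "predict_proba"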
@@ -72,7 +72,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
             predictions_df = target_method(model, sample_input_data)
             return predictions_df
 
-        for func_name in model._get_partitioned_infer_methods():
+        for func_name in model._get_partitioned_methods():
             function_properties = model_meta.function_properties.get(func_name, {})
             function_properties[model_meta_schema.FunctionProperties.PARTITIONED.value] = True
             model_meta.function_properties[func_name] = function_properties
@@ -82,6 +82,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
         enable_explainability = kwargs.get("enable_explainability", False)
         if enable_explainability:
             raise NotImplementedError("Explainability is not supported for PyTorch model.")
+        multiple_inputs = kwargs.get("multiple_inputs", False)
 
         import torch
 
@@ -94,8 +95,6 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
             default_target_methods=cls.DEFAULT_TARGET_METHODS,
         )
 
-        multiple_inputs = kwargs.get("multiple_inputs", False)
-
         def get_prediction(
             target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
         ) -> model_types.SupportedLocalDataType:
@@ -1,6 +1,6 @@
 import os
 import warnings
-from typing import TYPE_CHECKING, Callable, Optional, Union, cast, final
+from typing import TYPE_CHECKING, Callable, Optional, Sequence, Union, cast, final
 
 import cloudpickle
 import numpy as np
@@ -38,6 +38,35 @@ def _unpack_container_runtime_pipeline(model: "sklearn.pipeline.Pipeline") -> "s
     return model
 
 
+def _apply_transforms_up_to_last_step(
+    model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
+    data: model_types.SupportedDataType,
+    input_feature_names: Optional[list[str]] = None,
+) -> pd.DataFrame:
+    """Apply all transformations in the sklearn pipeline model up to the last step."""
+    transformed_data = data
+    output_features_names = input_feature_names
+
+    if type_utils.LazyType("sklearn.pipeline.Pipeline").isinstance(model):
+        for step_name, step in model.steps[:-1]:  # type: ignore[attr-defined]
+            if not hasattr(step, "transform"):
+                raise ValueError(f"Step '{step_name}' does not have a 'transform' method.")
+            transformed_data = step.transform(transformed_data)
+            if output_features_names is None:
+                continue
+            elif hasattr(step, "get_feature_names_out"):
+                output_features_names = step.get_feature_names_out(output_features_names)
+            else:
+                raise ValueError(
+                    f"Step '{step_name}' in the pipeline does not have a 'get_feature_names_out' method. "
+                    "Feature names cannot be propagated."
+                )
+    if type_utils.LazyType("scipy.sparse.csr_matrix").isinstance(transformed_data):
+        # Convert to dense array if it's a sparse matrix
+        transformed_data = transformed_data.toarray()  # type: ignore[attr-defined]
+    return pd.DataFrame(transformed_data, columns=output_features_names)
+
+
 @final
 class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]]):
     """Handler for scikit-learn based model.
@@ -58,7 +87,9 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
         "decision_function",
         "score_samples",
     ]
-    EXPLAIN_TARGET_METHODS = ["predict", "predict_proba", "predict_log_proba"]
+
+    # Prioritize predict_proba as it gives multi-class probabilities
+    EXPLAIN_TARGET_METHODS = ["predict_proba", "predict", "predict_log_proba"]
 
     @classmethod
     def can_handle(
@@ -160,17 +191,38 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
                     stacklevel=1,
                 )
                 enable_explainability = False
-            elif model_meta.task == model_types.Task.UNKNOWN or explain_target_method is None:
+            elif model_meta.task == model_types.Task.UNKNOWN:
+                enable_explainability = False
+            elif explain_target_method is None:
                 enable_explainability = False
             else:
                 enable_explainability = True
             if enable_explainability:
-                model_meta = handlers_utils.add_explain_method_signature(
-                    model_meta=model_meta,
-                    explain_method="explain",
-                    target_method=explain_target_method,
-                    output_return_type=model_task_and_output_type.output_type,
+                explain_target_method = str(explain_target_method)  # mypy complains if we don't cast to str here
+
+                input_signature = handlers_utils.get_input_signature(model_meta, explain_target_method)
+                transformed_background_data = _apply_transforms_up_to_last_step(
+                    model=model,
+                    data=background_data,
+                    input_feature_names=[spec.name for spec in input_signature],
                 )
+
+                try:
+                    model_meta = handlers_utils.add_inferred_explain_method_signature(
+                        model_meta=model_meta,
+                        explain_method="explain",
+                        target_method=explain_target_method,
+                        background_data=background_data,
+                        explain_fn=cls._build_explain_fn(model, background_data, input_signature),
+                        output_feature_names=transformed_background_data.columns,
+                    )
+                except Exception:
+                    if kwargs.get("enable_explainability", None):
+                        # user explicitly enabled explainability, so we should raise the error
+                        raise ValueError(
+                            "Explainability for this model is not supported. Please set `enable_explainability=False`"
+                        )
+
                 handlers_utils.save_background_data(
                     model_blobs_dir_path,
                     cls.EXPLAIN_ARTIFACTS_DIR,
@@ -222,11 +274,13 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
             )
 
         if enable_explainability:
-            model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap", pip_name="shap")])
+            model_meta.env.include_if_absent([model_env.ModelDependency(requirement="shap>=0.46.0", pip_name="shap")])
             model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
 
         model_meta.env.include_if_absent(
-            [model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn")],
+            [
+                model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn"),
+            ],
             check_local_version=True,
         )
 
@@ -286,37 +340,8 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
 
         @custom_model.inference_api
         def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
-            import shap
-
-            try:
-                explainer = shap.Explainer(raw_model, background_data)
-                df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
-            except TypeError:
-                try:
-                    dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in signature.inputs}
-
-                    if isinstance(X, pd.DataFrame):
-                        X = X.astype(dtype_map, copy=False)
-                    if hasattr(raw_model, "predict_proba"):
-                        if isinstance(X, np.ndarray):
-                            explanations = shap.Explainer(
-                                raw_model.predict_proba, background_data.values  # type: ignore[union-attr]
-                            )(X).values
-                        else:
-                            explanations = shap.Explainer(raw_model.predict_proba, background_data)(X).values
-                    elif hasattr(raw_model, "predict"):
-                        if isinstance(X, np.ndarray):
-                            explanations = shap.Explainer(
-                                raw_model.predict, background_data.values  # type: ignore[union-attr]
-                            )(X).values
-                        else:
-                            explanations = shap.Explainer(raw_model.predict, background_data)(X).values
-                    else:
-                        raise ValueError("Missing any supported target method to explain.")
-                    df = handlers_utils.convert_explanations_to_2D_df(raw_model, explanations)
-                except TypeError as e:
-                    raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
-            return model_signature_utils.rename_pandas_df(df, signature.outputs)
+            fn = cls._build_explain_fn(raw_model, background_data, signature.inputs)
+            return model_signature_utils.rename_pandas_df(fn(X), signature.outputs)
 
         if target_method == "explain":
             return explain_fn
@@ -339,3 +364,37 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
         skl_model = _SKLModel(custom_model.ModelContext())
 
         return skl_model
+
+    @classmethod
+    def _build_explain_fn(
+        cls,
+        model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
+        background_data: model_types.SupportedDataType,
+        input_specs: Sequence[model_signature.BaseFeatureSpec],
+    ) -> Callable[[model_types.SupportedDataType], pd.DataFrame]:
+        import shap
+        import sklearn.pipeline
+
+        transformed_bg_data = _apply_transforms_up_to_last_step(model, background_data)
+
+        def explain_fn(data: model_types.SupportedDataType) -> pd.DataFrame:
+            transformed_data = _apply_transforms_up_to_last_step(model, data)
+            predictor = model[-1] if isinstance(model, sklearn.pipeline.Pipeline) else model
+            try:
+                explainer = shap.Explainer(predictor, transformed_bg_data)
+                return handlers_utils.convert_explanations_to_2D_df(model, explainer(transformed_data).values)
+            except TypeError:
+                if isinstance(data, pd.DataFrame):
+                    dtype_map = {spec.name: spec.as_dtype(force_numpy_dtype=True) for spec in input_specs}
+                    transformed_data = _apply_transforms_up_to_last_step(model, data.astype(dtype_map))
+                for explain_target_method in cls.EXPLAIN_TARGET_METHODS:
+                    if not hasattr(predictor, explain_target_method):
+                        continue
+                    explain_target_method_fn = getattr(predictor, explain_target_method)
+                    explanations = shap.Explainer(explain_target_method_fn, transformed_bg_data.values)(
+                        transformed_data.to_numpy()
+                    ).values
+                    return handlers_utils.convert_explanations_to_2D_df(model, explanations)
+                raise ValueError("Missing any supported target method to explain.")
+
+        return explain_fn
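The consolidated `_build_explain_fn` keeps the old two-phase strategy: try `shap.Explainer` on the estimator itself, and on `TypeError` fall back to wrapping a prediction method, now walking `EXPLAIN_TARGET_METHODS` in priority order. The pattern in isolation (requires shap; the model and data are toy stand-ins):

    import pandas as pd
    import shap
    from sklearn.ensemble import RandomForestClassifier

    X = pd.DataFrame({"a": [0.0, 1.0, 2.0, 3.0], "b": [1.0, 0.0, 1.0, 0.0]})
    model = RandomForestClassifier(n_estimators=5, random_state=0).fit(X, [0, 0, 1, 1])

    try:
        # Fast path: shap picks a native explainer (e.g. TreeExplainer) for the model.
        values = shap.Explainer(model, X)(X).values
    except TypeError:
        # Fallback: explain a prediction method against the raw numpy background.
        values = shap.Explainer(model.predict_proba, X.values)(X.to_numpy()).values
    print(values.shape)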
@@ -88,6 +88,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         import tensorflow
 
         assert isinstance(model, tensorflow.Module)
+        multiple_inputs = kwargs.get("multiple_inputs", False)
 
         is_keras_model = type_utils.LazyType("keras.Model").isinstance(model)
         is_tf_keras_model = type_utils.LazyType("tf_keras.Model").isinstance(model)
@@ -112,8 +113,6 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
             default_target_methods=default_target_methods,
         )
 
-        multiple_inputs = kwargs.get("multiple_inputs", False)
-
         if is_keras_model and len(target_methods) > 1:
             raise ValueError("Keras model can only have one target method.")
 
@@ -198,7 +197,6 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         model_blobs_dir_path: str,
         **kwargs: Unpack[model_types.TensorflowLoadOptions],
     ) -> "tensorflow.Module":
-        os.environ["TF_USE_LEGACY_KERAS"] = "1"
         import tensorflow
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
@@ -209,7 +207,12 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         load_path = os.path.join(model_blob_path, model_blob_filename)
         save_format = model_blob_options.get("save_format", "keras_tf")
         if save_format == "keras_tf":
-            m = tensorflow.keras.models.load_model(load_path)
+            if version.parse(tensorflow.keras.__version__) >= version.parse("3.0.0"):
+                import tf_keras
+
+                m = tf_keras.models.load_model(load_path)
+            else:
+                m = tensorflow.keras.models.load_model(load_path)
         else:
             m = tensorflow.saved_model.load(load_path)
 
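Loading `keras_tf`-format artifacts now branches on the bundled Keras version: under Keras 3 the legacy `tf_keras` compatibility package does the loading, and the blanket `TF_USE_LEGACY_KERAS=1` override above is dropped. The gate in isolation, assuming `tf_keras` is installed alongside a Keras 3 TensorFlow:

    import tensorflow
    from packaging import version

    if version.parse(tensorflow.keras.__version__) >= version.parse("3.0.0"):
        import tf_keras  # legacy Keras 2 implementation, shipped separately
        load_model = tf_keras.models.load_model
    else:
        load_model = tensorflow.keras.models.load_model
    # model = load_model("/path/to/saved/model")  # path is hypothetical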
@@ -76,6 +76,8 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
         if enable_explainability:
             raise NotImplementedError("Explainability is not supported for Torch Script model.")
 
+        multiple_inputs = kwargs.get("multiple_inputs", False)
+
         import torch
 
         assert isinstance(model, torch.jit.ScriptModule)
@@ -87,8 +89,6 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):
             default_target_methods=cls.DEFAULT_TARGET_METHODS,
         )
 
-        multiple_inputs = kwargs.get("multiple_inputs", False)
-
         def get_prediction(
             target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
         ) -> model_types.SupportedLocalDataType:
@@ -144,7 +144,12 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
             path=cls.MODEL_BLOB_FILE_OR_DIR,
-            options=model_meta_schema.XgboostModelBlobOptions({"xgb_estimator_type": model.__class__.__name__}),
+            options=model_meta_schema.XgboostModelBlobOptions(
+                {
+                    "xgb_estimator_type": model.__class__.__name__,
+                    "enable_categorical": getattr(model, "enable_categorical", False),
+                }
+            ),
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -152,11 +157,6 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
         model_meta.env.include_if_absent(
             [
                 model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn"),
-            ],
-            check_local_version=True,
-        )
-        model_meta.env.include_if_absent(
-            [
                 model_env.ModelDependency(requirement="xgboost", pip_name="xgboost"),
             ],
             check_local_version=True,
@@ -190,6 +190,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
             raise ValueError("Type of XGB estimator is illegal.")
         m = getattr(xgboost, xgb_estimator_type)()
         m.load_model(os.path.join(model_blob_path, model_blob_filename))
+        m.enable_categorical = model_blob_options.get("enable_categorical", False)
 
         if kwargs.get("use_gpu", False):
             assert type(kwargs.get("use_gpu", False)) == bool
@@ -225,8 +226,16 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
     ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
         @custom_model.inference_api
         def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+            enable_categorical = False
+            for col, d_type in X.dtypes.items():
+                if pd.api.extensions.ExtensionDtype.is_dtype(d_type):
+                    continue
+                if not np.issubdtype(d_type, np.number):
+                    # categorical columns are converted to numpy's str dtype
+                    X[col] = X[col].astype("category")
+                    enable_categorical = True
             if isinstance(raw_model, xgboost.Booster):
-                X = xgboost.DMatrix(X)
+                X = xgboost.DMatrix(X, enable_categorical=enable_categorical)
 
             res = getattr(raw_model, target_method)(X)
 
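The same dtype sweep appears both here and in the DMatrix handler below: non-numeric, non-extension columns are cast to pandas `category` and the flag is forwarded to `xgboost.DMatrix`. Standalone (toy frame, real xgboost API):

    import numpy as np
    import pandas as pd
    import xgboost

    X = pd.DataFrame({"num": [1.0, 2.0], "color": ["red", "blue"]})
    enable_categorical = False
    for col, d_type in X.dtypes.items():
        if pd.api.extensions.ExtensionDtype.is_dtype(d_type):
            continue  # leave pandas extension dtypes untouched
        if not np.issubdtype(d_type, np.number):
            X[col] = X[col].astype("category")
            enable_categorical = True
    dmat = xgboost.DMatrix(X, enable_categorical=enable_categorical)
    print(dmat.feature_types)  # e.g. ['float', 'c']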
@@ -65,7 +65,8 @@ def create_model_metadata(
         ext_modules: List of names of modules that need to be pickled with the model. Defaults to None.
         conda_dependencies: List of conda requirements for running the model. Defaults to None.
         pip_requirements: List of pip Python packages requirements for running the model. Defaults to None.
-        artifact_repository_map: A dict mapping from package channel to artifact repository name.
+        artifact_repository_map: A dict mapping from package channel to artifact repository name (e.g.
+            {'pip': 'snowflake.snowpark.pypi_shared_repository'}).
         resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
         target_platforms: List of target platforms to run the model.
         python_version: A string of python version where model is run. Used for user override. If specified as None,
@@ -63,6 +63,7 @@ class MLFlowModelBlobOptions(BaseModelBlobOptions):
 
 class XgboostModelBlobOptions(BaseModelBlobOptions):
     xgb_estimator_type: Required[str]
+    enable_categorical: NotRequired[bool]
 
 
 class PyTorchModelBlobOptions(BaseModelBlobOptions):
@@ -6,13 +6,13 @@ REQUIREMENTS = [
     "aiohttp!=4.0.0a0, !=4.0.0a1",
     "anyio>=3.5.0,<5",
    "cachetools>=3.1.1,<6",
-    "cloudpickle>=2.0.0,<3",
+    "cloudpickle>=2.0.0",
     "cryptography",
     "fsspec>=2024.6.1,<2026",
     "importlib_resources>=6.1.1, <7",
     "numpy>=1.23,<2",
     "packaging>=20.9,<25",
-    "pandas>=1.0.0,<3",
+    "pandas>=2.1.4,<3",
     "pyarrow",
     "pydantic>=2.8.2, <3",
     "pyjwt>=2.0.0, <3",
@@ -21,9 +21,10 @@ REQUIREMENTS = [
     "requests",
     "retrying>=1.3.3,<2",
     "s3fs>=2024.6.1,<2026",
-    "scikit-learn>=1.4,<1.6",
+    "scikit-learn<1.6",
     "scipy>=1.9,<2",
-    "snowflake-connector-python>=3.12.0,<4",
+    "shap>=0.46.0,<1",
+    "snowflake-connector-python>=3.15.0,<4",
     "snowflake-snowpark-python>=1.17.0,<2,!=1.26.0",
     "snowflake.core>=1.0.2,<2",
     "sqlparse>=0.4,<1",
@@ -81,8 +81,16 @@ class XGBoostDMatrixHandler(base_handler.BaseDataHandler["xgboost.DMatrix"]):
     ) -> "xgboost.DMatrix":
         import xgboost as xgb
 
+        enable_categorical = False
+        for col, d_type in df.dtypes.items():
+            if pd.api.extensions.ExtensionDtype.is_dtype(d_type):
+                continue
+            if not np.issubdtype(d_type, np.number):
+                df[col] = df[col].astype("category")
+                enable_categorical = True
+
         if not features:
-            return xgb.DMatrix(df)
+            return xgb.DMatrix(df, enable_categorical=enable_categorical)
         else:
             feature_names = []
             feature_types = []
@@ -95,4 +103,9 @@ class XGBoostDMatrixHandler(base_handler.BaseDataHandler["xgboost.DMatrix"]):
             assert isinstance(feature, core.FeatureSpec), "Invalid feature kind."
             feature_names.append(feature.name)
             feature_types.append(feature._dtype._numpy_type)
-        return xgb.DMatrix(df, feature_names=feature_names, feature_types=feature_types)
+        return xgb.DMatrix(
+            df,
+            feature_names=feature_names,
+            feature_types=feature_types,
+            enable_categorical=enable_categorical,
+        )