snowflake-ml-python 1.8.3__py3-none-any.whl → 1.8.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (196)
  1. snowflake/cortex/__init__.py +7 -1
  2. snowflake/ml/_internal/platform_capabilities.py +13 -11
  3. snowflake/ml/_internal/telemetry.py +42 -13
  4. snowflake/ml/_internal/utils/identifier.py +2 -2
  5. snowflake/ml/data/data_connector.py +1 -1
  6. snowflake/ml/jobs/_utils/constants.py +10 -1
  7. snowflake/ml/jobs/_utils/interop_utils.py +1 -1
  8. snowflake/ml/jobs/_utils/payload_utils.py +51 -34
  9. snowflake/ml/jobs/_utils/scripts/constants.py +6 -0
  10. snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
  11. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +86 -3
  12. snowflake/ml/jobs/_utils/spec_utils.py +8 -6
  13. snowflake/ml/jobs/decorators.py +13 -3
  14. snowflake/ml/jobs/job.py +206 -26
  15. snowflake/ml/jobs/manager.py +78 -34
  16. snowflake/ml/model/_client/model/model_version_impl.py +1 -1
  17. snowflake/ml/model/_client/ops/service_ops.py +31 -17
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +351 -170
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +25 -0
  20. snowflake/ml/model/_client/sql/model_version.py +1 -1
  21. snowflake/ml/model/_client/sql/service.py +20 -32
  22. snowflake/ml/model/_model_composer/model_composer.py +44 -19
  23. snowflake/ml/model/_packager/model_handlers/_utils.py +32 -2
  24. snowflake/ml/model/_packager/model_handlers/custom.py +1 -1
  25. snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -2
  26. snowflake/ml/model/_packager/model_handlers/sklearn.py +100 -41
  27. snowflake/ml/model/_packager/model_handlers/tensorflow.py +7 -4
  28. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  29. snowflake/ml/model/_packager/model_handlers/xgboost.py +16 -7
  30. snowflake/ml/model/_packager/model_meta/model_meta.py +2 -1
  31. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
  32. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +5 -4
  33. snowflake/ml/model/_signatures/dmatrix_handler.py +15 -2
  34. snowflake/ml/model/custom_model.py +17 -4
  35. snowflake/ml/model/model_signature.py +3 -3
  36. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
  37. snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
  38. snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
  39. snowflake/ml/modeling/cluster/birch.py +9 -1
  40. snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
  41. snowflake/ml/modeling/cluster/dbscan.py +9 -1
  42. snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
  43. snowflake/ml/modeling/cluster/k_means.py +9 -1
  44. snowflake/ml/modeling/cluster/mean_shift.py +9 -1
  45. snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
  46. snowflake/ml/modeling/cluster/optics.py +9 -1
  47. snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
  48. snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
  49. snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
  50. snowflake/ml/modeling/compose/column_transformer.py +9 -1
  51. snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
  52. snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
  53. snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
  54. snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
  55. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
  56. snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
  57. snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
  58. snowflake/ml/modeling/covariance/oas.py +9 -1
  59. snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
  60. snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
  61. snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
  62. snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
  63. snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
  64. snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
  65. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
  66. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
  67. snowflake/ml/modeling/decomposition/pca.py +9 -1
  68. snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
  69. snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
  70. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
  71. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
  72. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
  73. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
  74. snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
  75. snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
  76. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
  77. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
  78. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
  79. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
  80. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
  81. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
  82. snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
  83. snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
  84. snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
  85. snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
  86. snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
  87. snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
  88. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
  89. snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
  90. snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
  91. snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
  92. snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
  93. snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
  94. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
  95. snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
  96. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
  97. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
  98. snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
  99. snowflake/ml/modeling/impute/knn_imputer.py +9 -1
  100. snowflake/ml/modeling/impute/missing_indicator.py +9 -1
  101. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
  102. snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
  103. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
  104. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
  105. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
  106. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
  107. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
  108. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
  109. snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
  110. snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
  111. snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
  112. snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
  113. snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
  114. snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
  115. snowflake/ml/modeling/linear_model/lars.py +9 -1
  116. snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
  117. snowflake/ml/modeling/linear_model/lasso.py +9 -1
  118. snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
  119. snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
  120. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
  121. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
  122. snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
  123. snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
  124. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
  125. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
  126. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
  127. snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
  128. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
  129. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
  130. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
  131. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
  132. snowflake/ml/modeling/linear_model/perceptron.py +9 -1
  133. snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
  134. snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
  135. snowflake/ml/modeling/linear_model/ridge.py +9 -1
  136. snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
  137. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
  138. snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
  139. snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
  140. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
  141. snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
  142. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
  143. snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
  144. snowflake/ml/modeling/manifold/isomap.py +9 -1
  145. snowflake/ml/modeling/manifold/mds.py +9 -1
  146. snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
  147. snowflake/ml/modeling/manifold/tsne.py +9 -1
  148. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
  149. snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
  150. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
  151. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
  152. snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
  153. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
  154. snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
  155. snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
  156. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
  157. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
  158. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
  159. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
  160. snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
  161. snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
  162. snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
  163. snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
  164. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
  165. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
  166. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
  167. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
  168. snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
  169. snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
  170. snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
  171. snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
  172. snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
  173. snowflake/ml/modeling/svm/linear_svc.py +9 -1
  174. snowflake/ml/modeling/svm/linear_svr.py +9 -1
  175. snowflake/ml/modeling/svm/nu_svc.py +9 -1
  176. snowflake/ml/modeling/svm/nu_svr.py +9 -1
  177. snowflake/ml/modeling/svm/svc.py +9 -1
  178. snowflake/ml/modeling/svm/svr.py +9 -1
  179. snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
  180. snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
  181. snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
  182. snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
  183. snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
  184. snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
  185. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
  186. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
  187. snowflake/ml/monitoring/explain_visualize.py +424 -0
  188. snowflake/ml/registry/_manager/model_manager.py +23 -2
  189. snowflake/ml/registry/registry.py +10 -9
  190. snowflake/ml/utils/connection_params.py +8 -2
  191. snowflake/ml/version.py +1 -1
  192. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/METADATA +58 -8
  193. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/RECORD +196 -195
  194. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/WHEEL +1 -1
  195. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/licenses/LICENSE.txt +0 -0
  196. {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,424 @@
1
+ from typing import Any, Union, cast, overload
2
+
3
+ import altair as alt
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+ import snowflake.snowpark.dataframe as sp_df
8
+ from snowflake import snowpark
9
+ from snowflake.ml._internal.exceptions import error_codes, exceptions
10
+ from snowflake.ml.model import model_signature, type_hints
11
+ from snowflake.ml.model._signatures import snowpark_handler
12
+
13
# Default (width, height) used by plot_force and plot_influence_sensitivity.
DEFAULT_FIGSIZE: tuple[int, int] = (1400, 500)
# Default (width, height) passed to the faceted violin chart in plot_violin.
DEFAULT_VIOLIN_FIGSIZE: tuple[int, int] = (1400, 100)
# plot_force truncates each "feature: value" annotation to this many characters.
MAX_ANNOTATION_LENGTH: int = 20
# Lower bound on the horizontal gap between neighbouring labels in plot_force;
# a larger value spreads the labels further apart.
MIN_DISTANCE: int = 10
17
+
18
+
19
@overload
def plot_force(
    shap_row: snowpark.Row,
    features_row: snowpark.Row,
    base_value: float = 0.0,
    figsize: tuple[float, float] = DEFAULT_FIGSIZE,
    contribution_threshold: float = 0.05,
) -> alt.LayerChart:
    ...


@overload
def plot_force(
    shap_row: pd.Series,
    features_row: pd.Series,
    base_value: float = 0.0,
    figsize: tuple[float, float] = DEFAULT_FIGSIZE,
    contribution_threshold: float = 0.05,
) -> alt.LayerChart:
    ...


def plot_force(
    shap_row: Union[pd.Series, snowpark.Row],
    features_row: Union[pd.Series, snowpark.Row],
    base_value: float = 0.0,
    figsize: tuple[float, float] = DEFAULT_FIGSIZE,
    contribution_threshold: float = 0.05,
) -> alt.LayerChart:
    """
    Create a force plot for SHAP values with stacked bars based on influence direction.

    Bars are stacked outward from the model output ``base_value + sum(SHAP values)``. Features with a
    non-negative SHAP value ("Positive", red) extend to the left of it. Features with a negative SHAP
    value ("Negative", blue) extend to the right. A dashed vertical rule marks ``base_value``.

    Args:
        shap_row: pandas Series or snowpark Row containing SHAP values for a specific instance.
            Entries are paired with ``features_row`` by position.
        features_row: pandas Series or snowpark Row containing the feature values for the same instance
        base_value: base value of the predictions. Defaults to 0, but is usually the model's average prediction
        figsize: tuple of (width, height) for the plot
        contribution_threshold:
            Only features whose magnitude is at least contribution_threshold as a fraction of the
            total absolute SHAP values will be plotted. Defaults to 0.05 (5%)

    Returns:
        Altair chart object

    Raises:
        SnowflakeMLException: If the contribution threshold is not between 0 and 1, if shap_row and
            features_row have a different number of entries, or if no features with significant
            contributions are found (including empty or all-zero input).
    """
    if not 0 < contribution_threshold < 1:
        raise exceptions.SnowflakeMLException(
            error_code=error_codes.INVALID_ARGUMENT,
            original_exception=ValueError("contribution_threshold must be between 0 and 1."),
        )

    if isinstance(shap_row, snowpark.Row):
        shap_row = pd.Series(shap_row.as_dict())
    if isinstance(features_row, snowpark.Row):
        features_row = pd.Series(features_row.as_dict())

    # SHAP values are matched to feature values by position, so both rows must have the same length.
    if len(shap_row) != len(features_row):
        raise exceptions.SnowflakeMLException(
            error_code=error_codes.INVALID_ARGUMENT,
            original_exception=ValueError(
                "shap_row and features_row must have the same number of entries, "
                f"but got {len(shap_row)} and {len(features_row)}."
            ),
        )

    # Create a dataframe for plotting. Explicit columns keep the schema for empty input, so the
    # "no significant contributions" error below is raised instead of a KeyError.
    positive_label = "Positive"
    negative_label = "Negative"
    plot_df = pd.DataFrame(
        [
            {
                "feature": feature,
                "feature_value": features_row.iloc[index],
                "feature_annotated": f"{feature}: {features_row.iloc[index]}"[:MAX_ANNOTATION_LENGTH],
                "influence_value": shap_row.iloc[index],
                "bar_direction": positive_label if shap_row.iloc[index] >= 0 else negative_label,
            }
            for index, feature in enumerate(features_row.index)
        ],
        columns=["feature", "feature_value", "feature_annotated", "influence_value", "bar_direction"],
    )

    # Anchor both stacks at the model output f(x) = base_value + sum(SHAP values). This puts the bars
    # in the same units as the dashed base_value rule (with base_value=0 it is just the SHAP sum).
    output_value = base_value + np.sum(shap_row)
    current_position_pos = output_value
    current_position_neg = output_value
    positions = []

    total_abs_value_sum = np.sum(plot_df["influence_value"].abs())
    max_abs_value = plot_df["influence_value"].abs().max()
    spacing = max_abs_value * 0.07  # Use 7% of the largest |SHAP value| as spacing between bars

    # Sort by absolute value to have largest impacts first
    plot_df = plot_df.reindex(plot_df["influence_value"].abs().sort_values(ascending=False).index)
    for _, row in plot_df.iterrows():
        row_influence_value = row["influence_value"]
        # Skip features with small contributions. An all-zero row has no significant contribution.
        # Checking the total first also avoids a 0/0 (NaN) ratio that would never compare below the threshold.
        if total_abs_value_sum == 0 or abs(row_influence_value) / total_abs_value_sum < contribution_threshold:
            continue

        if row_influence_value >= 0:
            # Positive contributions stack leftwards from the output value.
            start = current_position_pos - spacing
            end = current_position_pos - row_influence_value - spacing
            current_position_pos = end
        else:
            # Negative contributions stack rightwards from the output value.
            start = current_position_neg + spacing
            end = current_position_neg + abs(row_influence_value) + spacing
            current_position_neg = end

        positions.append(
            {
                "start": start,
                "end": end,
                "avg": (start + end) / 2,
                "influence_value": row_influence_value,
                "feature_value": row["feature_value"],
                "feature_annotated": row["feature_annotated"],
                "bar_direction": row["bar_direction"],
                "bar_y": 0,
                "feature": row["feature"],
            }
        )

    if len(positions) == 0:
        raise exceptions.SnowflakeMLException(
            error_code=error_codes.INVALID_ARGUMENT,
            original_exception=ValueError(
                "No features with significant contributions found. Try lowering the contribution_threshold, "
                "and verify the input is non-empty."
            ),
        )

    position_df = pd.DataFrame(positions)

    # Create force plot using Altair
    blue_color = "#1f77b4"
    red_color = "#d62728"
    # One scale shared by bars and arrows so both layers use red for Positive and blue for Negative.
    direction_colors = alt.Scale(domain=[positive_label, negative_label], range=[red_color, blue_color])
    width, height = figsize
    bars: alt.Chart = (
        alt.Chart(position_df)
        .mark_bar(size=10)
        .encode(
            x=alt.X("start:Q", title="Feature Impact"),
            x2=alt.X2("end:Q"),
            y=alt.Y("bar_y:Q", axis=None),
            color=alt.Color(
                "bar_direction:N",
                scale=direction_colors,
                legend=alt.Legend(title="Influence Direction"),
            ),
            tooltip=["feature", "influence_value", "feature_value"],
        )
        .properties(title="Feature Influence (SHAP values)", width=width, height=height)
    ).interactive()

    # Triangle markers at the start of each bar; rotated to point right (Positive) or left (Negative).
    arrow: alt.Chart = (
        alt.Chart(position_df)
        .mark_point(shape="triangle", filled=True, fillOpacity=1)
        .encode(
            x=alt.X("start:Q"),
            y=alt.Y("bar_y:Q", axis=None),
            angle=alt.Angle(
                "bar_direction:N", scale=alt.Scale(domain=[positive_label, negative_label], range=[90, -90])
            ),
            color=alt.Color("bar_direction:N", scale=direction_colors),
            size=alt.SizeValue(300),
            tooltip=alt.value(None),
        )
    )

    # Add a vertical line at the base value
    zero_line: alt.Chart = alt.Chart(pd.DataFrame({"x": [base_value]})).mark_rule(strokeDash=[3, 3]).encode(x="x:Q")

    # Calculate label positions to avoid overlap and ensure labels are spread apart horizontally

    # Sort by bar center (avg) for label placement
    sorted_positions = sorted(positions, key=lambda x: x["avg"])

    # Label spreading: labels are laid out left to right at a fixed step that is the even spacing
    # across the bars' centers, but never less than MIN_DISTANCE.
    min_x = min(pos["avg"] for pos in sorted_positions)
    max_x = max(pos["avg"] for pos in sorted_positions)
    n_labels = len(sorted_positions)
    spread_width = max_x - min_x
    if n_labels > 1:
        space_per_label = spread_width / (n_labels - 1)
        # If space_per_label is less than MIN_DISTANCE, use MIN_DISTANCE instead
        effective_distance = max(space_per_label, MIN_DISTANCE)
    else:
        effective_distance = 0

    # Start one step left of min_x, and assign label_x for each label from left to right
    offset = -effective_distance
    label_positions = []
    label_lines = []
    placed_label_xs: list[float] = []
    for i, pos in enumerate(sorted_positions):
        if i == 0:
            label_x = min_x + offset
        else:
            label_x = placed_label_xs[-1] + effective_distance
        placed_label_xs.append(label_x)
        label_positions.append(
            {
                "label_x": label_x,
                "label_y": 1,  # Bars sit at y=0, so y=1 places labels above the bars
                "feature_annotated": pos["feature_annotated"],
                "feature_value": pos["feature_value"],
            }
        )
        # Draw a diagonal line from the bar to the label
        label_lines.append(
            {
                "x": pos["avg"],
                "x2": label_x,
                "y": 0,
                "y2": 1,
            }
        )

    label_positions_df = pd.DataFrame(label_positions)
    label_lines_df = pd.DataFrame(label_lines)

    # Draw diagonal lines from bar to label
    label_connectors = (
        alt.Chart(label_lines_df)
        .mark_rule(strokeDash=[2, 2], color="grey")
        .encode(
            x="x:Q",
            x2="x2:Q",
            y=alt.Y("y:Q", axis=None),
            y2="y2:Q",
        )
    )

    # Place labels at adjusted positions
    feature_labels = (
        alt.Chart(label_positions_df)
        .mark_text(align="center", baseline="line-bottom", dy=0, fontSize=11)
        .encode(
            x=alt.X("label_x:Q"),
            y=alt.Y("label_y:Q", axis=None),
            text=alt.Text("feature_annotated:N"),
            color=alt.value("grey"),
            tooltip=["feature_value"],
        )
    )

    return cast(alt.LayerChart, bars + feature_labels + zero_line + arrow + label_connectors)
262
+
263
+
264
def plot_influence_sensitivity(
    shap_values: type_hints.SupportedDataType,
    feature_values: type_hints.SupportedDataType,
    figsize: tuple[float, float] = DEFAULT_FIGSIZE,
) -> Any:
    """
    Create a SHAP dependence scatter plot for a specific feature. If the inputs contain more than one
    feature column, a select box will be displayed to select the feature. This is only supported in
    Snowflake notebooks. If Streamlit is not available in that case, an ImportError will be raised.

    Args:
        shap_values: pandas Series or 2D array containing the SHAP values for the same feature(s)
        feature_values: pandas Series or 2D array containing the feature values for a specific feature
            (or one column per feature)
        figsize: tuple of (width, height) for the plot

    Returns:
        Altair chart object for single-feature input. For multi-feature input, the value returned by
        Streamlit's ``altair_chart``.

    Raises:
        ValueError: If feature_values and shap_values do not have the same number of rows or columns.
        RuntimeError: If multi-feature input is plotted outside of a Snowflake notebook.
    """
    feature_values_df = _convert_to_pandas_df(feature_values)
    shap_values_df = _convert_to_pandas_df(shap_values)

    if feature_values_df.shape[0] != shap_values_df.shape[0]:
        raise ValueError("Feature values and SHAP values must have the same number of rows.")
    # SHAP columns are matched to feature columns by position.
    if feature_values_df.shape[1] != shap_values_df.shape[1]:
        raise ValueError("Feature values and SHAP values must have the same number of columns.")

    # The converted inputs are always 2D DataFrames, so branch on the column count: only input with
    # several features needs the Streamlit feature picker.
    if shap_values_df.shape[1] > 1:
        feature_series, shap_series, st = _prepare_feature_values_for_streamlit(feature_values_df, shap_values_df)
        return st.altair_chart(_create_scatter_plot(feature_series, shap_series, figsize))

    # Single feature: plot the only column directly, without Streamlit.
    feature_series = feature_values_df.iloc[:, 0]
    shap_series = shap_values_df.iloc[:, 0]
    return _create_scatter_plot(feature_series, shap_series, figsize)
299
+
300
+
301
+ def _prepare_feature_values_for_streamlit(
302
+ feature_values_df: pd.DataFrame, shap_values: pd.DataFrame
303
+ ) -> tuple[pd.Series, pd.Series, Any]:
304
+ try:
305
+ from IPython import get_ipython
306
+ from snowbook.executor.python_transformer import IPythonProxy
307
+
308
+ assert isinstance(
309
+ get_ipython(), IPythonProxy
310
+ ), "Influence sensitivity plots for a DataFrame are not supported outside of Snowflake notebooks."
311
+ except ImportError:
312
+ raise RuntimeError(
313
+ "Influence sensitivity plots for a DataFrame are not supported outside of Snowflake notebooks."
314
+ )
315
+
316
+ import streamlit as st
317
+
318
+ feature_columns = feature_values_df.columns
319
+ chosen_ft: str = st.selectbox("Feature:", feature_columns)
320
+ feature_values = feature_values_df[chosen_ft]
321
+ shap_values = shap_values.iloc[:, feature_columns.get_loc(chosen_ft)]
322
+ return feature_values, shap_values, st
323
+
324
+
325
def _create_scatter_plot(feature_values: pd.Series, shap_values: pd.Series, figsize: tuple[float, float]) -> alt.Chart:
    """Build a scatter plot of SHAP value (y) against feature value (x) for one feature.

    Args:
        feature_values: Values of the feature, paired with ``shap_values`` by position.
        shap_values: SHAP values of that feature for the same rows.
        figsize: tuple of (width, height) for the plot.

    Returns:
        Altair chart object. If the most frequent feature value occurs more than 10 times, the feature is
        drawn on a nominal x axis with per-value colors and random horizontal jitter. Otherwise it is
        drawn on a quantitative x axis.
    """
    raw_feature_values = feature_values.to_numpy()
    # The count of the most frequent value is never below the mean count per distinct value, so it alone
    # decides the heuristic. Empty input has no counts and is treated as continuous.
    if raw_feature_values.size > 0:
        _, value_counts = np.unique(raw_feature_values, return_counts=True)
        is_categorical = bool(value_counts.max() > 10)
    else:
        is_categorical = False

    kwargs = (
        {
            "x": alt.X("feature_value:N", title="Feature Value"),
            "color": alt.Color("feature_value:N").legend(None),
            "xOffset": "jitter:Q",
        }
        if is_categorical
        else {"x": alt.X("feature_value:Q", title="Feature Value")}
    )

    # Pair values by position; building from Series would align on their (possibly different) indexes.
    plot_df = pd.DataFrame({"feature_value": raw_feature_values, "shap_value": shap_values.to_numpy()})

    width, height = figsize

    # Create scatter plot
    scatter = (
        alt.Chart(plot_df)
        .transform_calculate(jitter="random()")
        .mark_circle(size=60, opacity=0.7)
        .encode(
            y=alt.Y("shap_value:Q", title="SHAP Value"),
            tooltip=["feature_value", "shap_value"],
            **kwargs,
        )
        .properties(title="SHAP Dependence Scatter Plot", width=width, height=height)
    )

    return cast(alt.Chart, scatter)
360
+
361
+
362
def plot_violin(
    shap_df: type_hints.SupportedDataType,
    feature_df: type_hints.SupportedDataType,
    figsize: tuple[float, float] = DEFAULT_VIOLIN_FIGSIZE,
) -> alt.Chart:
    """
    Create a violin plot per feature showing the distribution of SHAP values.

    Features are drawn one per row, ordered by descending total absolute SHAP value. Row labels are
    taken positionally from the columns of ``feature_df``.

    Args:
        shap_df: 2D array containing SHAP values for multiple features
        feature_df: 2D array containing the corresponding feature values
        figsize: tuple of (width, height) for the plot

    Returns:
        Altair chart object

    Raises:
        ValueError: If either input is not 2D, or if they do not have the same number of columns.
    """

    shap_df_pd = _convert_to_pandas_df(shap_df)
    feature_df_pd = _convert_to_pandas_df(feature_df)

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if len(shap_df_pd.shape) != 2:
        raise ValueError(f"shap_df must be 2D, but got shape {shap_df_pd.shape}")
    if len(feature_df_pd.shape) != 2:
        raise ValueError(f"feature_df must be 2D, but got shape {feature_df_pd.shape}")
    # Each SHAP column is labelled with the feature column at the same position.
    if shap_df_pd.shape[1] != feature_df_pd.shape[1]:
        raise ValueError(
            "shap_df and feature_df must have the same number of columns, "
            f"but got {shap_df_pd.shape[1]} and {feature_df_pd.shape[1]}"
        )

    # Long format: the transpose is flattened column by column, so each feature name repeats once per row.
    plot_data = pd.DataFrame(
        {
            "feature_name": feature_df_pd.columns.repeat(shap_df_pd.shape[0]),
            "shap_value": shap_df_pd.transpose().values.flatten(),
        }
    )

    # Order the rows by the absolute sum of SHAP values per feature
    feature_abs_sum = shap_df_pd.abs().sum(axis=0)
    sorted_features = feature_abs_sum.sort_values(ascending=False).index
    column_sort_order = [feature_df_pd.columns[shap_df_pd.columns.get_loc(col)] for col in sorted_features]

    # Create the violin plot
    width, height = figsize
    violin = (
        alt.Chart(plot_data)
        .transform_density(density="shap_value", groupby=["feature_name"], as_=["shap_value", "density"])
        .mark_area(orient="vertical")
        .encode(
            y=alt.Y("density:Q", title=None).stack("center").impute(None).axis(labels=False, grid=False, ticks=True),
            x=alt.X("shap_value:Q", title="SHAP Value"),
            row=alt.Row("feature_name:N", sort=column_sort_order).spacing(0),
            color=alt.Color("feature_name:N", legend=None),
            tooltip=["feature_name", "shap_value"],
        )
        .properties(width=width, height=height)
    ).interactive()

    return cast(alt.Chart, violin)
416
+
417
+
418
def _convert_to_pandas_df(
    data: type_hints.SupportedDataType,
) -> pd.DataFrame:
    """Return ``data`` as a pandas DataFrame, whatever supported input type it arrives as.

    Snowpark DataFrames are converted by the Snowpark data handler. Every other input goes through
    the local-data conversion used by model signatures.
    """
    if not isinstance(data, sp_df.DataFrame):
        return model_signature._convert_local_data_to_df(data)
    return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(data)
@@ -12,8 +12,10 @@ from snowflake.ml.model import model_signature, type_hints as model_types
12
12
  from snowflake.ml.model._client.model import model_impl, model_version_impl
13
13
  from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
14
14
  from snowflake.ml.model._model_composer import model_composer
15
+ from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
15
16
  from snowflake.ml.model._packager.model_meta import model_meta
16
17
  from snowflake.snowpark import exceptions as snowpark_exceptions, session
18
+ from snowflake.snowpark._internal import utils as snowpark_utils
17
19
 
18
20
  logger = logging.getLogger(__name__)
19
21
 
@@ -169,7 +171,10 @@ class ModelManager:
169
171
  database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
170
172
  version_name_id = sql_identifier.SqlIdentifier(version_name)
171
173
 
172
- use_live_commit = platform_capabilities.PlatformCapabilities.get_instance().is_live_commit_enabled()
174
+ # TODO(SNOW-2091317): Remove this when the snowpark enables file PUT operation for snowurls
175
+ use_live_commit = (
176
+ not snowpark_utils.is_in_stored_procedure() # type: ignore[no-untyped-call]
177
+ ) and platform_capabilities.PlatformCapabilities.get_instance().is_live_commit_enabled()
173
178
  if use_live_commit:
174
179
  logger.info("Using live commit model version")
175
180
  else:
@@ -212,8 +217,24 @@ class ModelManager:
212
217
  # Convert any string target platforms to TargetPlatform objects
213
218
  platforms = [model_types.TargetPlatform(platform) for platform in target_platforms]
214
219
  else:
220
+ # Default the target platform to warehouse if not specified and any table function exists
221
+ if options and (
222
+ options.get("function_type") == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
223
+ or (
224
+ any(
225
+ opt.get("function_type") == "TABLE_FUNCTION"
226
+ for opt in options.get("method_options", {}).values()
227
+ )
228
+ )
229
+ ):
230
+ logger.info(
231
+ "Logging a partitioned model with a table function without specifying `target_platforms`. "
232
+ 'Default to `target_platforms=["WAREHOUSE"]`.'
233
+ )
234
+ platforms = [model_types.TargetPlatform.WAREHOUSE]
235
+
215
236
  # Default the target platform to SPCS if not specified when running in ML runtime
216
- if env.IN_ML_RUNTIME:
237
+ if not platforms and env.IN_ML_RUNTIME:
217
238
  logger.info(
218
239
  "Logging the model on Container Runtime for ML without specifying `target_platforms`. "
219
240
  'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.'
@@ -148,11 +148,11 @@ class Registry:
148
148
  dependencies must be retrieved from Snowflake Anaconda Channel.
149
149
  artifact_repository_map: Specifies a mapping of package channels or platforms to custom artifact
150
150
  repositories. Defaults to None. Currently, the mapping applies only to warehouse execution.
151
- Note : This feature is currently in Private Preview; please contact your Snowflake account team
152
- to enable it.
151
+ Note : This feature is currently in Public Preview.
153
152
  Format: {channel_name: artifact_repository_name}, where:
154
- - channel_name: The name of the Conda package channel (e.g., 'condaforge') or 'pip' for pip packages.
155
- - artifact_repository_name: The name or URL of the repository to fetch packages from.
153
+ - channel_name: Currently must be 'pip'.
154
+ - artifact_repository_name: The identifier of the artifact repository to fetch packages from, e.g.
155
+ `snowflake.snowpark.pypi_shared_repository`.
156
156
  resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
157
157
  target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
158
158
  {"WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"}. Defaults to None.
@@ -288,14 +288,15 @@ class Registry:
288
288
  dependencies must be retrieved from Snowflake Anaconda Channel.
289
289
  artifact_repository_map: Specifies a mapping of package channels or platforms to custom artifact
290
290
  repositories. Defaults to None. Currently, the mapping applies only to warehouse execution.
291
- Note : This feature is currently in Private Preview; please contact your Snowflake account team to
292
- enable it.
291
+ Note : This feature is currently in Public Preview.
293
292
  Format: {channel_name: artifact_repository_name}, where:
294
- - channel_name: The name of the Conda package channel (e.g., 'condaforge') or 'pip' for pip packages.
295
- - artifact_repository_name: The name or URL of the repository to fetch packages from.
293
+ - channel_name: Currently must be 'pip'.
294
+ - artifact_repository_name: The identifier of the artifact repository to fetch packages from, e.g.
295
+ `snowflake.snowpark.pypi_shared_repository`.
296
296
  resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
297
297
  target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
298
- {"WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"}. Defaults to None.
298
+ ["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"]. Defaults to None. When None, the target platforms will be
299
+ both.
299
300
  python_version: Python version in which the model is run. Defaults to None.
300
301
  signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
301
302
  sample_input_data would be used to infer the signatures for those models that cannot automatically
@@ -113,6 +113,10 @@ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -
113
113
 
114
114
  config = configparser.ConfigParser(inline_comment_prefixes="#")
115
115
 
116
+ snowflake_connection_name = os.getenv("SNOWFLAKE_CONNECTION_NAME")
117
+ if snowflake_connection_name is not None:
118
+ connection_name = snowflake_connection_name
119
+
116
120
  if connection_name:
117
121
  if not connection_name.startswith("connections."):
118
122
  connection_name = "connections." + connection_name
@@ -153,9 +157,11 @@ def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] =
153
157
  Ideally one should have a snowsql config file. Read more here:
154
158
  https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings
155
159
 
160
+ If snowsql config file does not exist, it tries auth from env variables.
161
+
156
162
  Args:
157
- connection_name: Name of the connection to look for inside the config file. If `connection_name` is NOT given,
158
- it tries auth from env variables.
163
+ connection_name: Name of the connection to look for inside the config file. If environment variable
164
+ SNOWFLAKE_CONNECTION_NAME is provided, it will override the input connection_name.
159
165
  login_file: If provided, this is used as config file instead of default one (_DEFAULT_CONNECTION_FILE).
160
166
 
161
167
  Returns:
snowflake/ml/version.py CHANGED
@@ -1,2 +1,2 @@
1
1
  # This is parsed by regex in conda recipe meta file. Make sure not to break it.
2
- VERSION = "1.8.3"
2
+ VERSION = "1.8.5"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: snowflake-ml-python
3
- Version: 1.8.3
3
+ Version: 1.8.5
4
4
  Summary: The machine learning client library that is used for interacting with Snowflake to build machine learning solutions.
5
5
  Author-email: "Snowflake, Inc" <support@snowflake.com>
6
6
  License:
@@ -236,13 +236,13 @@ License-File: LICENSE.txt
236
236
  Requires-Dist: absl-py<2,>=0.15
237
237
  Requires-Dist: anyio<5,>=3.5.0
238
238
  Requires-Dist: cachetools<6,>=3.1.1
239
- Requires-Dist: cloudpickle<3,>=2.0.0
239
+ Requires-Dist: cloudpickle>=2.0.0
240
240
  Requires-Dist: cryptography
241
241
  Requires-Dist: fsspec[http]<2026,>=2024.6.1
242
242
  Requires-Dist: importlib_resources<7,>=6.1.1
243
243
  Requires-Dist: numpy<2,>=1.23
244
244
  Requires-Dist: packaging<25,>=20.9
245
- Requires-Dist: pandas<3,>=1.0.0
245
+ Requires-Dist: pandas<3,>=2.1.4
246
246
  Requires-Dist: pyarrow
247
247
  Requires-Dist: pydantic<3,>=2.8.2
248
248
  Requires-Dist: pyjwt<3,>=2.0.0
@@ -250,27 +250,31 @@ Requires-Dist: pytimeparse<2,>=1.1.8
250
250
  Requires-Dist: pyyaml<7,>=6.0
251
251
  Requires-Dist: retrying<2,>=1.3.3
252
252
  Requires-Dist: s3fs<2026,>=2024.6.1
253
- Requires-Dist: scikit-learn<1.6,>=1.4
253
+ Requires-Dist: scikit-learn<1.6
254
254
  Requires-Dist: scipy<2,>=1.9
255
- Requires-Dist: snowflake-connector-python[pandas]<4,>=3.12.0
255
+ Requires-Dist: shap<1,>=0.46.0
256
+ Requires-Dist: snowflake-connector-python[pandas]<4,>=3.15.0
256
257
  Requires-Dist: snowflake-snowpark-python!=1.26.0,<2,>=1.17.0
257
258
  Requires-Dist: snowflake.core<2,>=1.0.2
258
259
  Requires-Dist: sqlparse<1,>=0.4
259
260
  Requires-Dist: typing-extensions<5,>=4.1.0
260
261
  Requires-Dist: xgboost<3,>=1.7.3
261
262
  Provides-Extra: all
263
+ Requires-Dist: altair<6,>=5; extra == "all"
262
264
  Requires-Dist: catboost<2,>=1.2.0; extra == "all"
263
265
  Requires-Dist: keras<4,>=2.0.0; extra == "all"
264
266
  Requires-Dist: lightgbm<5,>=4.1.0; extra == "all"
265
267
  Requires-Dist: mlflow<3,>=2.16.0; extra == "all"
266
268
  Requires-Dist: sentence-transformers<4,>=2.7.0; extra == "all"
267
269
  Requires-Dist: sentencepiece<0.2.0,>=0.1.95; extra == "all"
268
- Requires-Dist: shap<1,>=0.46.0; extra == "all"
270
+ Requires-Dist: streamlit<2,>=1.30.0; extra == "all"
269
271
  Requires-Dist: tensorflow<3,>=2.17.0; extra == "all"
270
272
  Requires-Dist: tokenizers<1,>=0.15.1; extra == "all"
271
273
  Requires-Dist: torch<3,>=2.0.1; extra == "all"
272
274
  Requires-Dist: torchdata<1,>=0.4; extra == "all"
273
275
  Requires-Dist: transformers<5,>=4.39.3; extra == "all"
276
+ Provides-Extra: altair
277
+ Requires-Dist: altair<6,>=5; extra == "altair"
274
278
  Provides-Extra: catboost
275
279
  Requires-Dist: catboost<2,>=1.2.0; extra == "catboost"
276
280
  Provides-Extra: keras
@@ -281,8 +285,8 @@ Provides-Extra: lightgbm
281
285
  Requires-Dist: lightgbm<5,>=4.1.0; extra == "lightgbm"
282
286
  Provides-Extra: mlflow
283
287
  Requires-Dist: mlflow<3,>=2.16.0; extra == "mlflow"
284
- Provides-Extra: shap
285
- Requires-Dist: shap<1,>=0.46.0; extra == "shap"
288
+ Provides-Extra: streamlit
289
+ Requires-Dist: streamlit<2,>=1.30.0; extra == "streamlit"
286
290
  Provides-Extra: tensorflow
287
291
  Requires-Dist: tensorflow<3,>=2.17.0; extra == "tensorflow"
288
292
  Provides-Extra: torch
@@ -404,6 +408,51 @@ NOTE: Version 1.7.0 is used as example here. Please choose the the latest versio
404
408
 
405
409
  # Release History
406
410
 
411
+ ## 1.8.5
412
+
413
+ ### Bug Fixes
414
+
415
+ - Registry: Fixed a bug when listing and deleting container services.
416
+ - Registry: Fixed explainability issue with scikit-learn pipelines, skipping explain function creation.
417
+ - Explainability: Lowered the minimum streamlit version to 1.30.
418
+ - Modeling: Make XGBoost a required dependency (xgboost is not a required dependency in snowflake-ml-python 1.8.4).
419
+
420
+ ### Breaking change
421
+
422
+ - ML Job: Rename argument `num_instances` to `target_instances` in job submission APIs and
423
+ change type from `Optional[int]` to `int`
424
+
425
+ ### New Features
426
+
427
+ - Registry: No longer checks if the snowflake-ml-python version is available in the Snowflake Conda channel when logging
428
+ an SPCS-only model.
429
+ - ML Job: Add `min_instances` argument to the job decorator to allow waiting for workers to be ready.
430
+
431
+ ## 1.8.4 (2025-05-12)
432
+
433
+ ### Bug Fixes
434
+
435
+ - Registry: Default `enable_explainability` to True when the model can be deployed to Warehouse.
436
+ - Registry: Add `custom_model.partitioned_api` decorator and deprecate `partitioned_inference_api`.
437
+ - Registry: Fixed a bug when logging pytorch and tensorflow models that caused
438
+ `UnboundLocalError: local variable 'multiple_inputs' referenced before assignment`.
439
+
440
+ ### Breaking change
441
+
442
+ - ML Job: Updated property `id` to be fully qualified name; Introduced new property `name` to represent the ML Job name
443
+ - ML Job: Modified `list_jobs()` to return ML Job `name` instead of `id`
444
+ - Registry: Error in `log_model` if `enable_explainability` is True and model is only deployed to
445
+ Snowpark Container Services, instead of just user warning.
446
+
447
+ ### New Features
448
+
449
+ - ML Job: Extend `@remote` function decorator, `submit_file()` and `submit_directory()` to accept `database` and
450
+ `schema` parameters
451
+ - ML Job: Support querying by fully qualified name in `get_job()`
452
+ - Explainability: Added visualization functions to `snowflake.ml.monitoring` to plot explanations in notebooks.
453
+ - Explainability: Support explain for categorical transforms for sklearn pipeline
454
+ - Support categorical type for `xgboost.DMatrix` inputs.
455
+
407
456
  ## 1.8.3
408
457
 
409
458
  ### Bug Fixes
@@ -417,6 +466,7 @@ NOTE: Version 1.7.0 is used as example here. Please choose the the latest versio
417
466
  as a list of strings
418
467
  - Registry: Support `ModelVersion.run_job` to run inference with a single-node Snowpark Container Services job.
419
468
  - DataConnector: Removed PrPr decorators
469
+ - Registry: Default the target platform to warehouse when logging a partitioned model.
420
470
 
421
471
  ## 1.8.2
422
472