snowflake-ml-python 1.8.3__py3-none-any.whl → 1.8.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +7 -1
- snowflake/ml/_internal/platform_capabilities.py +13 -11
- snowflake/ml/_internal/telemetry.py +42 -13
- snowflake/ml/_internal/utils/identifier.py +2 -2
- snowflake/ml/data/data_connector.py +1 -1
- snowflake/ml/jobs/_utils/constants.py +10 -1
- snowflake/ml/jobs/_utils/interop_utils.py +1 -1
- snowflake/ml/jobs/_utils/payload_utils.py +51 -34
- snowflake/ml/jobs/_utils/scripts/constants.py +6 -0
- snowflake/ml/jobs/_utils/scripts/get_instance_ip.py +4 -4
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +86 -3
- snowflake/ml/jobs/_utils/spec_utils.py +8 -6
- snowflake/ml/jobs/decorators.py +13 -3
- snowflake/ml/jobs/job.py +206 -26
- snowflake/ml/jobs/manager.py +78 -34
- snowflake/ml/model/_client/model/model_version_impl.py +1 -1
- snowflake/ml/model/_client/ops/service_ops.py +31 -17
- snowflake/ml/model/_client/service/model_deployment_spec.py +351 -170
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +25 -0
- snowflake/ml/model/_client/sql/model_version.py +1 -1
- snowflake/ml/model/_client/sql/service.py +20 -32
- snowflake/ml/model/_model_composer/model_composer.py +44 -19
- snowflake/ml/model/_packager/model_handlers/_utils.py +32 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +1 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +100 -41
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +7 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +16 -7
- snowflake/ml/model/_packager/model_meta/model_meta.py +2 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -0
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +5 -4
- snowflake/ml/model/_signatures/dmatrix_handler.py +15 -2
- snowflake/ml/model/custom_model.py +17 -4
- snowflake/ml/model/model_signature.py +3 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +9 -1
- snowflake/ml/modeling/cluster/affinity_propagation.py +9 -1
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +9 -1
- snowflake/ml/modeling/cluster/birch.py +9 -1
- snowflake/ml/modeling/cluster/bisecting_k_means.py +9 -1
- snowflake/ml/modeling/cluster/dbscan.py +9 -1
- snowflake/ml/modeling/cluster/feature_agglomeration.py +9 -1
- snowflake/ml/modeling/cluster/k_means.py +9 -1
- snowflake/ml/modeling/cluster/mean_shift.py +9 -1
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +9 -1
- snowflake/ml/modeling/cluster/optics.py +9 -1
- snowflake/ml/modeling/cluster/spectral_biclustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_clustering.py +9 -1
- snowflake/ml/modeling/cluster/spectral_coclustering.py +9 -1
- snowflake/ml/modeling/compose/column_transformer.py +9 -1
- snowflake/ml/modeling/compose/transformed_target_regressor.py +9 -1
- snowflake/ml/modeling/covariance/elliptic_envelope.py +9 -1
- snowflake/ml/modeling/covariance/empirical_covariance.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso.py +9 -1
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +9 -1
- snowflake/ml/modeling/covariance/ledoit_wolf.py +9 -1
- snowflake/ml/modeling/covariance/min_cov_det.py +9 -1
- snowflake/ml/modeling/covariance/oas.py +9 -1
- snowflake/ml/modeling/covariance/shrunk_covariance.py +9 -1
- snowflake/ml/modeling/decomposition/dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/factor_analysis.py +9 -1
- snowflake/ml/modeling/decomposition/fast_ica.py +9 -1
- snowflake/ml/modeling/decomposition/incremental_pca.py +9 -1
- snowflake/ml/modeling/decomposition/kernel_pca.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +9 -1
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/pca.py +9 -1
- snowflake/ml/modeling/decomposition/sparse_pca.py +9 -1
- snowflake/ml/modeling/decomposition/truncated_svd.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/bagging_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/isolation_forest.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/stacking_regressor.py +9 -1
- snowflake/ml/modeling/ensemble/voting_classifier.py +9 -1
- snowflake/ml/modeling/ensemble/voting_regressor.py +9 -1
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fdr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fpr.py +9 -1
- snowflake/ml/modeling/feature_selection/select_fwe.py +9 -1
- snowflake/ml/modeling/feature_selection/select_k_best.py +9 -1
- snowflake/ml/modeling/feature_selection/select_percentile.py +9 -1
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +9 -1
- snowflake/ml/modeling/feature_selection/variance_threshold.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +9 -1
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +9 -1
- snowflake/ml/modeling/impute/iterative_imputer.py +9 -1
- snowflake/ml/modeling/impute/knn_imputer.py +9 -1
- snowflake/ml/modeling/impute/missing_indicator.py +9 -1
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/nystroem.py +9 -1
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +9 -1
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +9 -1
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +9 -1
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +9 -1
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ard_regression.py +9 -1
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/gamma_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/huber_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/lars.py +9 -1
- snowflake/ml/modeling/linear_model/lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +9 -1
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +9 -1
- snowflake/ml/modeling/linear_model/linear_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression.py +9 -1
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +9 -1
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +9 -1
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/perceptron.py +9 -1
- snowflake/ml/modeling/linear_model/poisson_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ransac_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/ridge.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +9 -1
- snowflake/ml/modeling/linear_model/ridge_cv.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_classifier.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +9 -1
- snowflake/ml/modeling/linear_model/sgd_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +9 -1
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +9 -1
- snowflake/ml/modeling/manifold/isomap.py +9 -1
- snowflake/ml/modeling/manifold/mds.py +9 -1
- snowflake/ml/modeling/manifold/spectral_embedding.py +9 -1
- snowflake/ml/modeling/manifold/tsne.py +9 -1
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +9 -1
- snowflake/ml/modeling/mixture/gaussian_mixture.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +9 -1
- snowflake/ml/modeling/multiclass/output_code_classifier.py +9 -1
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/complement_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +9 -1
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neighbors/kernel_density.py +9 -1
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_centroid.py +9 -1
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +9 -1
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +9 -1
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +9 -1
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_classifier.py +9 -1
- snowflake/ml/modeling/neural_network/mlp_regressor.py +9 -1
- snowflake/ml/modeling/preprocessing/polynomial_features.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_propagation.py +9 -1
- snowflake/ml/modeling/semi_supervised/label_spreading.py +9 -1
- snowflake/ml/modeling/svm/linear_svc.py +9 -1
- snowflake/ml/modeling/svm/linear_svr.py +9 -1
- snowflake/ml/modeling/svm/nu_svc.py +9 -1
- snowflake/ml/modeling/svm/nu_svr.py +9 -1
- snowflake/ml/modeling/svm/svc.py +9 -1
- snowflake/ml/modeling/svm/svr.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/decision_tree_regressor.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_classifier.py +9 -1
- snowflake/ml/modeling/tree/extra_tree_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgb_regressor.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +9 -1
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +9 -1
- snowflake/ml/monitoring/explain_visualize.py +424 -0
- snowflake/ml/registry/_manager/model_manager.py +23 -2
- snowflake/ml/registry/registry.py +10 -9
- snowflake/ml/utils/connection_params.py +8 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/METADATA +58 -8
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/RECORD +196 -195
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/WHEEL +1 -1
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/top_level.txt +0 -0
snowflake/ml/monitoring/explain_visualize.py
ADDED
@@ -0,0 +1,424 @@
from typing import Any, Union, cast, overload

import altair as alt
import numpy as np
import pandas as pd

import snowflake.snowpark.dataframe as sp_df
from snowflake import snowpark
from snowflake.ml._internal.exceptions import error_codes, exceptions
from snowflake.ml.model import model_signature, type_hints
from snowflake.ml.model._signatures import snowpark_handler

DEFAULT_FIGSIZE = (1400, 500)
DEFAULT_VIOLIN_FIGSIZE = (1400, 100)
MAX_ANNOTATION_LENGTH = 20
MIN_DISTANCE = 10  # Increase minimum distance between labels for more spreading in plot_force


@overload
def plot_force(
    shap_row: snowpark.Row,
    features_row: snowpark.Row,
    base_value: float = 0.0,
    figsize: tuple[float, float] = DEFAULT_FIGSIZE,
    contribution_threshold: float = 0.05,
) -> alt.LayerChart:
    ...


@overload
def plot_force(
    shap_row: pd.Series,
    features_row: pd.Series,
    base_value: float = 0.0,
    figsize: tuple[float, float] = DEFAULT_FIGSIZE,
    contribution_threshold: float = 0.05,
) -> alt.LayerChart:
    ...


def plot_force(
    shap_row: Union[pd.Series, snowpark.Row],
    features_row: Union[pd.Series, snowpark.Row],
    base_value: float = 0.0,
    figsize: tuple[float, float] = DEFAULT_FIGSIZE,
    contribution_threshold: float = 0.05,
) -> alt.LayerChart:
    """
    Create a force plot for SHAP values with stacked bars based on influence direction.

    Args:
        shap_row: pandas Series or snowpark Row containing SHAP values for a specific instance
        features_row: pandas Series or snowpark Row containing the feature values for the same instance
        base_value: base value of the predictions. Defaults to 0, but is usually the model's average prediction
        figsize: tuple of (width, height) for the plot
        contribution_threshold:
            Only features with magnitude greater than contribution_threshold as a percentage of the
            total absolute SHAP values will be plotted. Defaults to 0.05 (5%)

    Returns:
        Altair chart object

    Raises:
        SnowflakeMLException: If the contribution threshold is not between 0 and 1,
            or if no features with significant contributions are found.
    """
    if not (0 < contribution_threshold and contribution_threshold < 1):
        raise exceptions.SnowflakeMLException(
            error_code=error_codes.INVALID_ARGUMENT,
            original_exception=ValueError("contribution_threshold must be between 0 and 1."),
        )

    if isinstance(shap_row, snowpark.Row):
        shap_row = pd.Series(shap_row.as_dict())
    if isinstance(features_row, snowpark.Row):
        features_row = pd.Series(features_row.as_dict())

    # Create a dataframe for plotting
    positive_label = "Positive"
    negative_label = "Negative"
    plot_df = pd.DataFrame(
        [
            {
                "feature": feature,
                "feature_value": features_row.iloc[index],
                "feature_annotated": f"{feature}: {features_row.iloc[index]}"[:MAX_ANNOTATION_LENGTH],
                "influence_value": shap_row.iloc[index],
                "bar_direction": positive_label if shap_row.iloc[index] >= 0 else negative_label,
            }
            for index, feature in enumerate(features_row.index)
        ]
    )

    # Calculate cumulative positions for the stacked bars
    shap_sum = np.sum(shap_row)
    current_position_pos = shap_sum
    current_position_neg = shap_sum
    positions = []

    total_abs_value_sum = np.sum(plot_df["influence_value"].abs())
    max_abs_value = plot_df["influence_value"].abs().max()
    spacing = max_abs_value * 0.07  # Use 7% of max value as spacing between bars

    # Sort by absolute value to have largest impacts first
    plot_df = plot_df.reindex(plot_df["influence_value"].abs().sort_values(ascending=False).index)
    for _, row in plot_df.iterrows():
        # Skip features with small contributions
        row_influence_value = row["influence_value"]
        if abs(row_influence_value) / total_abs_value_sum < contribution_threshold:
            continue

        if row_influence_value >= 0:
            start = current_position_pos - spacing
            end = current_position_pos - row_influence_value - spacing
            current_position_pos = end
        else:
            start = current_position_neg + spacing
            end = current_position_neg + abs(row_influence_value) + spacing
            current_position_neg = end

        positions.append(
            {
                "start": start,
                "end": end,
                "avg": (start + end) / 2,
                "influence_value": row_influence_value,
                "feature_value": row["feature_value"],
                "feature_annotated": row["feature_annotated"],
                "bar_direction": row["bar_direction"],
                "bar_y": 0,
                "feature": row["feature"],
            }
        )

    if len(positions) == 0:
        raise exceptions.SnowflakeMLException(
            error_code=error_codes.INVALID_ARGUMENT,
            original_exception=ValueError(
                "No features with significant contributions found. Try lowering the contribution_threshold, "
                "and verify the input is non-empty."
            ),
        )

    position_df = pd.DataFrame(positions)

    # Create force plot using Altair
    blue_color = "#1f77b4"
    red_color = "#d62728"
    width, height = figsize
    bars: alt.Chart = (
        alt.Chart(position_df)
        .mark_bar(size=10)
        .encode(
            x=alt.X("start:Q", title="Feature Impact"),
            x2=alt.X2("end:Q"),
            y=alt.Y("bar_y:Q", axis=None),
            color=alt.Color(
                "bar_direction:N",
                scale=alt.Scale(domain=[positive_label, negative_label], range=[red_color, blue_color]),
                legend=alt.Legend(title="Influence Direction"),
            ),
            tooltip=["feature", "influence_value", "feature_value"],
        )
        .properties(title="Feature Influence (SHAP values)", width=width, height=height)
    ).interactive()

    arrow: alt.Chart = (
        alt.Chart(position_df)
        .mark_point(shape="triangle", filled=True, fillOpacity=1)
        .encode(
            x=alt.X("start:Q"),
            y=alt.Y("bar_y:Q", axis=None),
            angle=alt.Angle("bar_direction:N", scale=alt.Scale(domain=["Positive", "Negative"], range=[90, -90])),
            color=alt.Color(
                "bar_direction:N", scale=alt.Scale(domain=["Positive", "Negative"], range=["#1f77b4", "#d62728"])
            ),
            size=alt.SizeValue(300),
            tooltip=alt.value(None),
        )
    )

    # Add a vertical line at the base value
    zero_line: alt.Chart = alt.Chart(pd.DataFrame({"x": [base_value]})).mark_rule(strokeDash=[3, 3]).encode(x="x:Q")

    # Calculate label positions to avoid overlap and ensure labels are spread apart horizontally

    # Sort by bar center (avg) for label placement
    sorted_positions = sorted(positions, key=lambda x: x["avg"])

    # Improved label spreading algorithm:
    # Calculate the minimum and maximum x positions (avg) for the bars
    min_x = min(pos["avg"] for pos in sorted_positions)
    max_x = max(pos["avg"] for pos in sorted_positions)
    n_labels = len(sorted_positions)
    # Calculate the minimum required distance between labels
    spread_width = max_x - min_x
    if n_labels > 1:
        space_per_label = spread_width / (n_labels - 1)
        # If space_per_label is less than min_distance, use min_distance instead
        effective_distance = max(space_per_label, MIN_DISTANCE)
    else:
        effective_distance = 0

    # Start from min_x - offset, and assign label_x for each label from left to right
    offset = -effective_distance  # Start a bit to the left
    label_positions = []
    label_lines = []
    placed_label_xs: list[float] = []
    for i, pos in enumerate(sorted_positions):
        if i == 0:
            label_x = min_x + offset
        else:
            label_x = placed_label_xs[-1] + effective_distance
        placed_label_xs.append(label_x)
        label_positions.append(
            {
                "label_x": label_x,
                "label_y": 1,  # Place labels below the bars
                "feature_annotated": pos["feature_annotated"],
                "feature_value": pos["feature_value"],
            }
        )
        # Draw a diagonal line from the bar to the label
        label_lines.append(
            {
                "x": pos["avg"],
                "x2": label_x,
                "y": 0,
                "y2": 1,
            }
        )

    label_positions_df = pd.DataFrame(label_positions)
    label_lines_df = pd.DataFrame(label_lines)

    # Draw diagonal lines from bar to label
    label_connectors = (
        alt.Chart(label_lines_df)
        .mark_rule(strokeDash=[2, 2], color="grey")
        .encode(
            x="x:Q",
            x2="x2:Q",
            y=alt.Y("y:Q", axis=None),
            y2="y2:Q",
        )
    )

    # Place labels at adjusted positions
    feature_labels = (
        alt.Chart(label_positions_df)
        .mark_text(align="center", baseline="line-bottom", dy=0, fontSize=11)
        .encode(
            x=alt.X("label_x:Q"),
            y=alt.Y("label_y:Q", axis=None),
            text=alt.Text("feature_annotated:N"),
            color=alt.value("grey"),
            tooltip=["feature_value"],
        )
    )

    return cast(alt.LayerChart, bars + feature_labels + zero_line + arrow + label_connectors)


def plot_influence_sensitivity(
    shap_values: type_hints.SupportedDataType,
    feature_values: type_hints.SupportedDataType,
    figsize: tuple[float, float] = DEFAULT_FIGSIZE,
) -> Any:
    """
    Create a SHAP dependence scatter plot for a specific feature. If a DataFrame is provided, a select box
    will be displayed to select the feature. This is only supported in Snowflake notebooks.
    If Streamlit is not available and a DataFrame is passed in, an ImportError will be raised.

    Args:
        feature_values: pandas Series or 2D array containing the feature values for a specific feature
        shap_values: pandas Series or 2D array containing the SHAP values for the same feature
        figsize: tuple of (width, height) for the plot

    Returns:
        Altair chart object

    Raises:
        ValueError: If the types of feature_values and shap_values are not the same

    """

    use_streamlit = False
    feature_values_df = _convert_to_pandas_df(feature_values)
    shap_values_df = _convert_to_pandas_df(shap_values)

    if len(shap_values_df.shape) > 1:
        feature_values, shap_values, st = _prepare_feature_values_for_streamlit(feature_values_df, shap_values_df)
        use_streamlit = True
    elif feature_values_df.shape[0] != shap_values_df.shape[0]:
        raise ValueError("Feature values and SHAP values must have the same number of rows.")

    scatter = _create_scatter_plot(feature_values, shap_values, figsize)
    return st.altair_chart(scatter) if use_streamlit else scatter


def _prepare_feature_values_for_streamlit(
    feature_values_df: pd.DataFrame, shap_values: pd.DataFrame
) -> tuple[pd.Series, pd.Series, Any]:
    try:
        from IPython import get_ipython
        from snowbook.executor.python_transformer import IPythonProxy

        assert isinstance(
            get_ipython(), IPythonProxy
        ), "Influence sensitivity plots for a DataFrame are not supported outside of Snowflake notebooks."
    except ImportError:
        raise RuntimeError(
            "Influence sensitivity plots for a DataFrame are not supported outside of Snowflake notebooks."
        )

    import streamlit as st

    feature_columns = feature_values_df.columns
    chosen_ft: str = st.selectbox("Feature:", feature_columns)
    feature_values = feature_values_df[chosen_ft]
    shap_values = shap_values.iloc[:, feature_columns.get_loc(chosen_ft)]
    return feature_values, shap_values, st


def _create_scatter_plot(feature_values: pd.Series, shap_values: pd.Series, figsize: tuple[float, float]) -> alt.Chart:
    unique_vals = np.sort(np.unique(feature_values.values))
    max_points_per_unique_value = float(np.max(np.bincount(np.searchsorted(unique_vals, feature_values.values))))
    points_per_value = len(feature_values.values) / len(unique_vals)
    is_categorical = float(max(max_points_per_unique_value, points_per_value)) > 10

    kwargs = (
        {
            "x": alt.X("feature_value:N", title="Feature Value"),
            "color": alt.Color("feature_value:N").legend(None),
            "xOffset": "jitter:Q",
        }
        if is_categorical
        else {"x": alt.X("feature_value:Q", title="Feature Value")}
    )

    # Create a dataframe for plotting
    plot_df = pd.DataFrame({"feature_value": feature_values, "shap_value": shap_values})

    width, height = figsize

    # Create scatter plot
    scatter = (
        alt.Chart(plot_df)
        .transform_calculate(jitter="random()")
        .mark_circle(size=60, opacity=0.7)
        .encode(
            y=alt.Y("shap_value:Q", title="SHAP Value"),
            tooltip=["feature_value", "shap_value"],
            **kwargs,
        )
        .properties(title="SHAP Dependence Scatter Plot", width=width, height=height)
    )

    return cast(alt.Chart, scatter)


def plot_violin(
    shap_df: type_hints.SupportedDataType,
    feature_df: type_hints.SupportedDataType,
    figsize: tuple[float, float] = DEFAULT_VIOLIN_FIGSIZE,
) -> alt.Chart:
    """
    Create a violin plot per feature showing the distribution of SHAP values.

    Args:
        shap_df: 2D array containing SHAP values for multiple features
        feature_df: 2D array containing the corresponding feature values
        figsize: tuple of (width, height) for the plot

    Returns:
        Altair chart object
    """

    shap_df_pd = _convert_to_pandas_df(shap_df)
    feature_df_pd = _convert_to_pandas_df(feature_df)

    # Assert that the input dataframes are 2D
    assert len(shap_df_pd.shape) == 2, f"shap_df must be 2D, but got shape {shap_df_pd.shape}"
    assert len(feature_df_pd.shape) == 2, f"feature_df must be 2D, but got shape {feature_df_pd.shape}"

    # Prepare data for plotting
    plot_data = pd.DataFrame(
        {
            "feature_name": feature_df_pd.columns.repeat(shap_df_pd.shape[0]),
            "shap_value": shap_df_pd.transpose().values.flatten(),
        }
    )

    # Order the rows by the absolute sum of SHAP values per feature
    feature_abs_sum = shap_df_pd.abs().sum(axis=0)
    sorted_features = feature_abs_sum.sort_values(ascending=False).index
    column_sort_order = [feature_df_pd.columns[shap_df_pd.columns.get_loc(col)] for col in sorted_features]

    # Create the violin plot
    width, height = figsize
    violin = (
        alt.Chart(plot_data)
        .transform_density(density="shap_value", groupby=["feature_name"], as_=["shap_value", "density"])
        .mark_area(orient="vertical")
        .encode(
            y=alt.Y("density:Q", title=None).stack("center").impute(None).axis(labels=False, grid=False, ticks=True),
            x=alt.X("shap_value:Q", title="SHAP Value"),
            row=alt.Row("feature_name:N", sort=column_sort_order).spacing(0),
            color=alt.Color("feature_name:N", legend=None),
            tooltip=["feature_name", "shap_value"],
        )
        .properties(width=width, height=height)
    ).interactive()

    return cast(alt.Chart, violin)


def _convert_to_pandas_df(
    data: type_hints.SupportedDataType,
) -> pd.DataFrame:
    if isinstance(data, sp_df.DataFrame):
        return snowpark_handler.SnowparkDataFrameHandler.convert_to_df(data)

    return model_signature._convert_local_data_to_df(data)
snowflake/ml/registry/_manager/model_manager.py
CHANGED
@@ -12,8 +12,10 @@ from snowflake.ml.model import model_signature, type_hints as model_types
 from snowflake.ml.model._client.model import model_impl, model_version_impl
 from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
 from snowflake.ml.model._model_composer import model_composer
+from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
 from snowflake.ml.model._packager.model_meta import model_meta
 from snowflake.snowpark import exceptions as snowpark_exceptions, session
+from snowflake.snowpark._internal import utils as snowpark_utils
 
 logger = logging.getLogger(__name__)
 
@@ -169,7 +171,10 @@ class ModelManager:
         database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
         version_name_id = sql_identifier.SqlIdentifier(version_name)
 
-
+        # TODO(SNOW-2091317): Remove this when the snowpark enables file PUT operation for snowurls
+        use_live_commit = (
+            not snowpark_utils.is_in_stored_procedure()  # type: ignore[no-untyped-call]
+        ) and platform_capabilities.PlatformCapabilities.get_instance().is_live_commit_enabled()
         if use_live_commit:
             logger.info("Using live commit model version")
         else:
@@ -212,8 +217,24 @@ class ModelManager:
             # Convert any string target platforms to TargetPlatform objects
             platforms = [model_types.TargetPlatform(platform) for platform in target_platforms]
         else:
+            # Default the target platform to warehouse if not specified and any table function exists
+            if options and (
+                options.get("function_type") == model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
+                or (
+                    any(
+                        opt.get("function_type") == "TABLE_FUNCTION"
+                        for opt in options.get("method_options", {}).values()
+                    )
+                )
+            ):
+                logger.info(
+                    "Logging a partitioned model with a table function without specifying `target_platforms`. "
+                    'Default to `target_platforms=["WAREHOUSE"]`.'
+                )
+                platforms = [model_types.TargetPlatform.WAREHOUSE]
+
             # Default the target platform to SPCS if not specified when running in ML runtime
-            if env.IN_ML_RUNTIME:
+            if not platforms and env.IN_ML_RUNTIME:
                 logger.info(
                     "Logging the model on Container Runtime for ML without specifying `target_platforms`. "
                     'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.'
snowflake/ml/registry/registry.py
CHANGED
@@ -148,11 +148,11 @@ class Registry:
                 dependencies must be retrieved from Snowflake Anaconda Channel.
             artifact_repository_map: Specifies a mapping of package channels or platforms to custom artifact
                 repositories. Defaults to None. Currently, the mapping applies only to warehouse execution.
-                Note : This feature is currently in
-                to enable it.
+                Note : This feature is currently in Public Preview.
                 Format: {channel_name: artifact_repository_name}, where:
-                - channel_name:
-                - artifact_repository_name: The
+                - channel_name: Currently must be 'pip'.
+                - artifact_repository_name: The identifier of the artifact repository to fetch packages from, e.g.
+                  `snowflake.snowpark.pypi_shared_repository`.
             resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
             target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
                 {"WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"}. Defaults to None.
@@ -288,14 +288,15 @@ class Registry:
                 dependencies must be retrieved from Snowflake Anaconda Channel.
             artifact_repository_map: Specifies a mapping of package channels or platforms to custom artifact
                 repositories. Defaults to None. Currently, the mapping applies only to warehouse execution.
-                Note : This feature is currently in
-                enable it.
+                Note : This feature is currently in Public Preview.
                 Format: {channel_name: artifact_repository_name}, where:
-                - channel_name:
-                - artifact_repository_name: The
+                - channel_name: Currently must be 'pip'.
+                - artifact_repository_name: The identifier of the artifact repository to fetch packages from, e.g.
+                  `snowflake.snowpark.pypi_shared_repository`.
             resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
             target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
-
+                ["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"]. Defaults to None. When None, the target platforms will be
+                both.
             python_version: Python version in which the model is run. Defaults to None.
             signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
                 sample_input_data would be used to infer the signatures for those models that cannot automatically
snowflake/ml/utils/connection_params.py
CHANGED
@@ -113,6 +113,10 @@ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -
 
     config = configparser.ConfigParser(inline_comment_prefixes="#")
 
+    snowflake_connection_name = os.getenv("SNOWFLAKE_CONNECTION_NAME")
+    if snowflake_connection_name is not None:
+        connection_name = snowflake_connection_name
+
     if connection_name:
         if not connection_name.startswith("connections."):
             connection_name = "connections." + connection_name
@@ -153,9 +157,11 @@ def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] =
     Ideally one should have a snowsql config file. Read more here:
     https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings
 
+    If snowsql config file does not exist, it tries auth from env variables.
+
     Args:
-        connection_name: Name of the connection to look for inside the config file. If
-        it
+        connection_name: Name of the connection to look for inside the config file. If environment variable
+            SNOWFLAKE_CONNECTION_NAME is provided, it will override the input connection_name.
         login_file: If provided, this is used as config file instead of default one (_DEFAULT_CONNECTION_FILE).
 
     Returns:
snowflake/ml/version.py
CHANGED
@@ -1,2 +1,2 @@
 # This is parsed by regex in conda recipe meta file. Make sure not to break it.
-VERSION = "1.8.3"
+VERSION = "1.8.5"
{snowflake_ml_python-1.8.3.dist-info → snowflake_ml_python-1.8.5.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: snowflake-ml-python
-Version: 1.8.3
+Version: 1.8.5
 Summary: The machine learning client library that is used for interacting with Snowflake to build machine learning solutions.
 Author-email: "Snowflake, Inc" <support@snowflake.com>
 License:
@@ -236,13 +236,13 @@ License-File: LICENSE.txt
 Requires-Dist: absl-py<2,>=0.15
 Requires-Dist: anyio<5,>=3.5.0
 Requires-Dist: cachetools<6,>=3.1.1
-Requires-Dist: cloudpickle
+Requires-Dist: cloudpickle>=2.0.0
 Requires-Dist: cryptography
 Requires-Dist: fsspec[http]<2026,>=2024.6.1
 Requires-Dist: importlib_resources<7,>=6.1.1
 Requires-Dist: numpy<2,>=1.23
 Requires-Dist: packaging<25,>=20.9
-Requires-Dist: pandas<3,>=1.
+Requires-Dist: pandas<3,>=2.1.4
 Requires-Dist: pyarrow
 Requires-Dist: pydantic<3,>=2.8.2
 Requires-Dist: pyjwt<3,>=2.0.0
@@ -250,27 +250,31 @@ Requires-Dist: pytimeparse<2,>=1.1.8
 Requires-Dist: pyyaml<7,>=6.0
 Requires-Dist: retrying<2,>=1.3.3
 Requires-Dist: s3fs<2026,>=2024.6.1
-Requires-Dist: scikit-learn<1.6
+Requires-Dist: scikit-learn<1.6
 Requires-Dist: scipy<2,>=1.9
-Requires-Dist:
+Requires-Dist: shap<1,>=0.46.0
+Requires-Dist: snowflake-connector-python[pandas]<4,>=3.15.0
 Requires-Dist: snowflake-snowpark-python!=1.26.0,<2,>=1.17.0
 Requires-Dist: snowflake.core<2,>=1.0.2
 Requires-Dist: sqlparse<1,>=0.4
 Requires-Dist: typing-extensions<5,>=4.1.0
 Requires-Dist: xgboost<3,>=1.7.3
 Provides-Extra: all
+Requires-Dist: altair<6,>=5; extra == "all"
 Requires-Dist: catboost<2,>=1.2.0; extra == "all"
 Requires-Dist: keras<4,>=2.0.0; extra == "all"
 Requires-Dist: lightgbm<5,>=4.1.0; extra == "all"
 Requires-Dist: mlflow<3,>=2.16.0; extra == "all"
 Requires-Dist: sentence-transformers<4,>=2.7.0; extra == "all"
 Requires-Dist: sentencepiece<0.2.0,>=0.1.95; extra == "all"
-Requires-Dist:
+Requires-Dist: streamlit<2,>=1.30.0; extra == "all"
 Requires-Dist: tensorflow<3,>=2.17.0; extra == "all"
 Requires-Dist: tokenizers<1,>=0.15.1; extra == "all"
 Requires-Dist: torch<3,>=2.0.1; extra == "all"
 Requires-Dist: torchdata<1,>=0.4; extra == "all"
 Requires-Dist: transformers<5,>=4.39.3; extra == "all"
+Provides-Extra: altair
+Requires-Dist: altair<6,>=5; extra == "altair"
 Provides-Extra: catboost
 Requires-Dist: catboost<2,>=1.2.0; extra == "catboost"
 Provides-Extra: keras
@@ -281,8 +285,8 @@ Provides-Extra: lightgbm
 Requires-Dist: lightgbm<5,>=4.1.0; extra == "lightgbm"
 Provides-Extra: mlflow
 Requires-Dist: mlflow<3,>=2.16.0; extra == "mlflow"
-Provides-Extra:
-Requires-Dist:
+Provides-Extra: streamlit
+Requires-Dist: streamlit<2,>=1.30.0; extra == "streamlit"
 Provides-Extra: tensorflow
 Requires-Dist: tensorflow<3,>=2.17.0; extra == "tensorflow"
 Provides-Extra: torch
@@ -404,6 +408,51 @@ NOTE: Version 1.7.0 is used as example here. Please choose the the latest versio
 
 # Release History
 
+## 1.8.5
+
+### Bug Fixes
+
+- Registry: Fixed a bug when listing and deleting container services.
+- Registry: Fixed explainability issue with scikit-learn pipelines, skipping explain function creation.
+- Explainability: bump minimum streamlit version down to 1.30
+- Modeling: Make XGBoost a required dependency (xgboost is not a required dependency in snowflake-ml-python 1.8.4).
+
+### Breaking change
+
+- ML Job: Rename argument `num_instances` to `target_instances` in job submission APIs and
+  change type from `Optional[int]` to `int`
+
+### New Features
+
+- Registry: No longer checks if the snowflake-ml-python version is available in the Snowflake Conda channel when logging
+  an SPCS-only model.
+- ML Job: Add `min_instances` argument to the job decorator to allow waiting for workers to be ready.
+
+## 1.8.4 (2025-05-12)
+
+### Bug Fixes
+
+- Registry: Default `enable_explainability` to True when the model can be deployed to Warehouse.
+- Registry: Add `custom_model.partitioned_api` decorator and deprecate `partitioned_inference_api`.
+- Registry: Fixed a bug when logging pytorch and tensorflow models that caused
+  `UnboundLocalError: local variable 'multiple_inputs' referenced before assignment`.
+
+### Breaking change
+
+- ML Job: Updated property `id` to be fully qualified name; Introduced new property `name` to represent the ML Job name
+- ML Job: Modified `list_jobs()` to return ML Job `name` instead of `id`
+- Registry: Error in `log_model` if `enable_explainability` is True and model is only deployed to
+  Snowpark Container Services, instead of just user warning.
+
+### New Features
+
+- ML Job: Extend `@remote` function decorator, `submit_file()` and `submit_directory()` to accept `database` and
+  `schema` parameters
+- ML Job: Support querying by fully qualified name in `get_job()`
+- Explainability: Added visualization functions to `snowflake.ml.monitoring` to plot explanations in notebooks.
+- Explainability: Support explain for categorical transforms for sklearn pipeline
+- Support categorical type for `xgboost.DMatrix` inputs.
+
 ## 1.8.3
 
 ### Bug Fixes
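Based on the ML Job notes above, a hedged sketch of the renamed and new arguments; the compute pool, stage name, and everything besides `target_instances` and `min_instances` are placeholders following the `snowflake.ml.jobs` API, not part of this diff:

# Sketch per the release notes: num_instances -> target_instances (now int),
# and min_instances lets the job wait for that many workers to be ready.
from snowflake.ml import jobs

@jobs.remote("MY_COMPUTE_POOL", stage_name="payload_stage", target_instances=4, min_instances=2)
def train() -> None:
    ...  # training body runs inside the ML Job

job = train()  # submits the job and returns a job handle
job.wait()     # assumed blocking call; see snowflake/ml/jobs/job.py for the actual API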
@@ -417,6 +466,7 @@ NOTE: Version 1.7.0 is used as example here. Please choose the the latest versio
   as a list of strings
 - Registry: Support `ModelVersion.run_job` to run inference with a single-node Snowpark Container Services job.
 - DataConnector: Removed PrPr decorators
+- Registry: Default the target platform to warehouse when logging a partitioned model.
 
 ## 1.8.2
 