snowflake-ml-python 1.8.4__py3-none-any.whl → 1.8.6__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.
- snowflake/ml/_internal/telemetry.py +42 -16
- snowflake/ml/_internal/utils/connection_params.py +196 -0
- snowflake/ml/data/data_connector.py +1 -1
- snowflake/ml/jobs/__init__.py +2 -0
- snowflake/ml/jobs/_utils/constants.py +12 -2
- snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
- snowflake/ml/jobs/_utils/interop_utils.py +1 -1
- snowflake/ml/jobs/_utils/payload_utils.py +95 -39
- snowflake/ml/jobs/_utils/scripts/constants.py +22 -0
- snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +67 -2
- snowflake/ml/jobs/_utils/spec_utils.py +30 -6
- snowflake/ml/jobs/_utils/stage_utils.py +119 -0
- snowflake/ml/jobs/_utils/types.py +5 -1
- snowflake/ml/jobs/decorators.py +10 -7
- snowflake/ml/jobs/job.py +176 -28
- snowflake/ml/jobs/manager.py +119 -26
- snowflake/ml/model/_client/model/model_impl.py +58 -0
- snowflake/ml/model/_client/model/model_version_impl.py +90 -0
- snowflake/ml/model/_client/ops/model_ops.py +6 -3
- snowflake/ml/model/_client/ops/service_ops.py +24 -7
- snowflake/ml/model/_client/service/model_deployment_spec.py +11 -0
- snowflake/ml/model/_client/sql/model_version.py +1 -1
- snowflake/ml/model/_client/sql/service.py +73 -28
- snowflake/ml/model/_client/sql/stage.py +5 -2
- snowflake/ml/model/_model_composer/model_composer.py +3 -1
- snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -1
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -2
- snowflake/ml/model/_signatures/core.py +24 -0
- snowflake/ml/monitoring/explain_visualize.py +160 -22
- snowflake/ml/monitoring/model_monitor.py +0 -4
- snowflake/ml/registry/registry.py +34 -14
- snowflake/ml/utils/connection_params.py +9 -3
- snowflake/ml/utils/html_utils.py +263 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/METADATA +40 -13
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/RECORD +40 -37
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/WHEEL +1 -1
- snowflake/ml/monitoring/model_monitor_version.py +0 -1
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/licenses/LICENSE.txt +0 -0
- {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/top_level.txt +0 -0
```diff
--- a/snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py
+++ b/snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py
@@ -12,7 +12,7 @@ REQUIREMENTS = [
     "importlib_resources>=6.1.1, <7",
     "numpy>=1.23,<2",
     "packaging>=20.9,<25",
-    "pandas>=1.
+    "pandas>=2.1.4,<3",
     "pyarrow",
     "pydantic>=2.8.2, <3",
     "pyjwt>=2.0.0, <3",
@@ -24,9 +24,10 @@ REQUIREMENTS = [
     "scikit-learn<1.6",
     "scipy>=1.9,<2",
     "shap>=0.46.0,<1",
-    "snowflake-connector-python>=3.
+    "snowflake-connector-python>=3.15.0,<4",
     "snowflake-snowpark-python>=1.17.0,<2,!=1.26.0",
     "snowflake.core>=1.0.2,<2",
     "sqlparse>=0.4,<1",
     "typing-extensions>=4.1.0,<5",
+    "xgboost>=1.7.3,<3",
 ]
```
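The notable pin changes: the pandas floor moves to 2.1.4, the snowflake-connector-python floor to 3.15.0, and xgboost joins the inference requirements. To sanity-check an existing environment against the new pins before upgrading, here is a minimal sketch using only `importlib.metadata` and the already-required `packaging` distribution (the pin list is copied from the diff above):

```python
# Minimal sketch: check the current environment against the new 1.8.6 pins.
from importlib.metadata import PackageNotFoundError, version

from packaging.requirements import Requirement

PINS = [
    "pandas>=2.1.4,<3",
    "snowflake-connector-python>=3.15.0,<4",
    "xgboost>=1.7.3,<3",
]

for pin in PINS:
    req = Requirement(pin)
    try:
        installed = version(req.name)
    except PackageNotFoundError:
        print(f"{req.name}: not installed (needs {req.specifier})")
        continue
    status = "OK" if req.specifier.contains(installed) else f"outside {req.specifier}"
    print(f"{req.name} {installed}: {status}")
```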
```diff
--- a/snowflake/ml/model/_signatures/core.py
+++ b/snowflake/ml/model/_signatures/core.py
@@ -559,6 +559,30 @@ class ModelSignature:
         )"""
         )
 
+    def _repr_html_(self) -> str:
+        """Generate an HTML representation of the model signature.
+
+        Returns:
+            str: HTML string containing formatted signature details.
+        """
+        from snowflake.ml.utils import html_utils
+
+        # Create collapsible sections for inputs and outputs
+        inputs_content = html_utils.create_features_html(self.inputs, "Input")
+        outputs_content = html_utils.create_features_html(self.outputs, "Output")
+
+        inputs_section = html_utils.create_collapsible_section("Inputs", inputs_content, open_by_default=True)
+        outputs_section = html_utils.create_collapsible_section("Outputs", outputs_content, open_by_default=True)
+
+        content = f"""
+        <div style="margin-top: 10px;">
+            {inputs_section}
+            {outputs_section}
+        </div>
+        """
+
+        return html_utils.create_base_container("Model Signature", content)
+
     @classmethod
     def from_mlflow_sig(cls, mlflow_sig: "mlflow.models.ModelSignature") -> "ModelSignature":
         return ModelSignature(
```
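`_repr_html_` is the standard IPython rich-display hook: when a `ModelSignature` is the last expression in a notebook cell, Jupyter calls this method and renders the returned HTML instead of the plain `repr`. A minimal sketch of the mechanism — the class and markup below are illustrative, not the library's `html_utils` API:

```python
# Illustrative sketch of the IPython rich-display hook (not the library's html_utils).
class Signature:
    def __init__(self, inputs: list[str], outputs: list[str]) -> None:
        self.inputs = inputs
        self.outputs = outputs

    def __repr__(self) -> str:
        return f"Signature(inputs={self.inputs}, outputs={self.outputs})"

    def _repr_html_(self) -> str:
        # Jupyter prefers this hook over __repr__ when rendering cell output.
        items = "".join(f"<li><code>{name}</code></li>" for name in self.inputs + self.outputs)
        return f"<details open><summary>Signature</summary><ul>{items}</ul></details>"


# In a notebook, ending a cell with `Signature(["X1"], ["Y"])` renders the HTML.
```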
```diff
--- a/snowflake/ml/monitoring/explain_visualize.py
+++ b/snowflake/ml/monitoring/explain_visualize.py
@@ -1,4 +1,4 @@
-from typing import Union, cast, overload
+from typing import Any, Union, cast, overload
 
 import altair as alt
 import numpy as np
@@ -6,16 +6,22 @@ import pandas as pd
 
 import snowflake.snowpark.dataframe as sp_df
 from snowflake import snowpark
+from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml.model import model_signature, type_hints
 from snowflake.ml.model._signatures import snowpark_handler
 
+DEFAULT_FIGSIZE = (1400, 500)
+DEFAULT_VIOLIN_FIGSIZE = (1400, 100)
+MAX_ANNOTATION_LENGTH = 20
+MIN_DISTANCE = 10  # Increase minimum distance between labels for more spreading in plot_force
+
 
 @overload
 def plot_force(
     shap_row: snowpark.Row,
     features_row: snowpark.Row,
     base_value: float = 0.0,
-    figsize: tuple[float, float] =
+    figsize: tuple[float, float] = DEFAULT_FIGSIZE,
     contribution_threshold: float = 0.05,
 ) -> alt.LayerChart:
     ...
@@ -26,7 +32,7 @@ def plot_force(
     shap_row: pd.Series,
     features_row: pd.Series,
     base_value: float = 0.0,
-    figsize: tuple[float, float] =
+    figsize: tuple[float, float] = DEFAULT_FIGSIZE,
     contribution_threshold: float = 0.05,
 ) -> alt.LayerChart:
     ...
@@ -36,7 +42,7 @@ def plot_force(
     shap_row: Union[pd.Series, snowpark.Row],
     features_row: Union[pd.Series, snowpark.Row],
     base_value: float = 0.0,
-    figsize: tuple[float, float] =
+    figsize: tuple[float, float] = DEFAULT_FIGSIZE,
     contribution_threshold: float = 0.05,
 ) -> alt.LayerChart:
     """
@@ -53,7 +59,17 @@ def plot_force(
 
     Returns:
         Altair chart object
+
+    Raises:
+        SnowflakeMLException: If the contribution threshold is not between 0 and 1,
+            or if no features with significant contributions are found.
     """
+    if not (0 < contribution_threshold and contribution_threshold < 1):
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_ARGUMENT,
+            original_exception=ValueError("contribution_threshold must be between 0 and 1."),
+        )
+
     if isinstance(shap_row, snowpark.Row):
         shap_row = pd.Series(shap_row.as_dict())
     if isinstance(features_row, snowpark.Row):
@@ -67,7 +83,7 @@ def plot_force(
             {
                 "feature": feature,
                 "feature_value": features_row.iloc[index],
-                "feature_annotated": f"{feature}: {features_row.iloc[index]}",
+                "feature_annotated": f"{feature}: {features_row.iloc[index]}"[:MAX_ANNOTATION_LENGTH],
                 "influence_value": shap_row.iloc[index],
                 "bar_direction": positive_label if shap_row.iloc[index] >= 0 else negative_label,
             }
@@ -95,11 +111,11 @@ def plot_force(
 
         if row_influence_value >= 0:
             start = current_position_pos - spacing
-            end = current_position_pos - row_influence_value
+            end = current_position_pos - row_influence_value - spacing
             current_position_pos = end
         else:
             start = current_position_neg + spacing
-            end = current_position_neg + abs(row_influence_value)
+            end = current_position_neg + abs(row_influence_value) + spacing
             current_position_neg = end
 
         positions.append(
@@ -108,13 +124,23 @@ def plot_force(
                 "end": end,
                 "avg": (start + end) / 2,
                 "influence_value": row_influence_value,
-                "influence_annotated": f"Influence: {row_influence_value}",
                 "feature_value": row["feature_value"],
                 "feature_annotated": row["feature_annotated"],
                 "bar_direction": row["bar_direction"],
+                "bar_y": 0,
+                "feature": row["feature"],
             }
         )
 
+    if len(positions) == 0:
+        raise exceptions.SnowflakeMLException(
+            error_code=error_codes.INVALID_ARGUMENT,
+            original_exception=ValueError(
+                "No features with significant contributions found. Try lowering the contribution_threshold,"
+                "and verify the input is non-empty."
+            ),
+        )
+
     position_df = pd.DataFrame(positions)
 
     # Create force plot using Altair
@@ -127,12 +153,13 @@ def plot_force(
         .encode(
             x=alt.X("start:Q", title="Feature Impact"),
             x2=alt.X2("end:Q"),
+            y=alt.Y("bar_y:Q", axis=None),
             color=alt.Color(
                 "bar_direction:N",
                 scale=alt.Scale(domain=[positive_label, negative_label], range=[red_color, blue_color]),
                 legend=alt.Legend(title="Influence Direction"),
             ),
-            tooltip=["influence_value", "feature_value"],
+            tooltip=["feature", "influence_value", "feature_value"],
         )
         .properties(title="Feature Influence (SHAP values)", width=width, height=height)
     ).interactive()
@@ -142,6 +169,7 @@ def plot_force(
         .mark_point(shape="triangle", filled=True, fillOpacity=1)
        .encode(
             x=alt.X("start:Q"),
+            y=alt.Y("bar_y:Q", axis=None),
             angle=alt.Angle("bar_direction:N", scale=alt.Scale(domain=["Positive", "Negative"], range=[90, -90])),
             color=alt.Color(
                 "bar_direction:N", scale=alt.Scale(domain=["Positive", "Negative"], range=["#1f77b4", "#d62728"])
```
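Both layers now pin their marks to a constant `bar_y` column with the axis hidden, which is what keeps the bars, arrows, and (below) label connectors on one shared row once the charts are layered. A standalone sketch of that pattern with made-up data:

```python
import altair as alt
import pandas as pd

# Toy data: horizontal segments that should all render on a single row.
df = pd.DataFrame({"start": [0, 2, 5], "end": [2, 5, 6], "bar_y": [0, 0, 0]})

bars = (
    alt.Chart(df)
    .mark_bar(size=20)
    .encode(
        x=alt.X("start:Q"),
        x2=alt.X2("end:Q"),
        # A constant y with axis=None pins every mark to one row and hides the
        # scale, so layered charts (bars + arrows + rules) line up vertically.
        y=alt.Y("bar_y:Q", axis=None),
    )
)
bars.save("force_row.html")  # writes a standalone HTML file; or display in a notebook
```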
```diff
@@ -154,37 +182,147 @@ def plot_force(
     # Add a vertical line at the base value
     zero_line: alt.Chart = alt.Chart(pd.DataFrame({"x": [base_value]})).mark_rule(strokeDash=[3, 3]).encode(x="x:Q")
 
-    #
+    # Calculate label positions to avoid overlap and ensure labels are spread apart horizontally
+
+    # Sort by bar center (avg) for label placement
+    sorted_positions = sorted(positions, key=lambda x: x["avg"])
+
+    # Improved label spreading algorithm:
+    # Calculate the minimum and maximum x positions (avg) for the bars
+    min_x = min(pos["avg"] for pos in sorted_positions)
+    max_x = max(pos["avg"] for pos in sorted_positions)
+    n_labels = len(sorted_positions)
+    # Calculate the minimum required distance between labels
+    spread_width = max_x - min_x
+    if n_labels > 1:
+        space_per_label = spread_width / (n_labels - 1)
+        # If space_per_label is less than min_distance, use min_distance instead
+        effective_distance = max(space_per_label, MIN_DISTANCE)
+    else:
+        effective_distance = 0
+
+    # Start from min_x - offset, and assign label_x for each label from left to right
+    offset = -effective_distance  # Start a bit to the left
+    label_positions = []
+    label_lines = []
+    placed_label_xs: list[float] = []
+    for i, pos in enumerate(sorted_positions):
+        if i == 0:
+            label_x = min_x + offset
+        else:
+            label_x = placed_label_xs[-1] + effective_distance
+        placed_label_xs.append(label_x)
+        label_positions.append(
+            {
+                "label_x": label_x,
+                "label_y": 1,  # Place labels below the bars
+                "feature_annotated": pos["feature_annotated"],
+                "feature_value": pos["feature_value"],
+            }
+        )
+        # Draw a diagonal line from the bar to the label
+        label_lines.append(
+            {
+                "x": pos["avg"],
+                "x2": label_x,
+                "y": 0,
+                "y2": 1,
+            }
+        )
+
+    label_positions_df = pd.DataFrame(label_positions)
+    label_lines_df = pd.DataFrame(label_lines)
+
+    # Draw diagonal lines from bar to label
+    label_connectors = (
+        alt.Chart(label_lines_df)
+        .mark_rule(strokeDash=[2, 2], color="grey")
+        .encode(
+            x="x:Q",
+            x2="x2:Q",
+            y=alt.Y("y:Q", axis=None),
+            y2="y2:Q",
+        )
+    )
+
+    # Place labels at adjusted positions
     feature_labels = (
-        alt.Chart(
-        .mark_text(align="center", baseline="line-bottom", dy=
+        alt.Chart(label_positions_df)
+        .mark_text(align="center", baseline="line-bottom", dy=0, fontSize=11)
         .encode(
-            x=alt.X("
-
-
+            x=alt.X("label_x:Q"),
+            y=alt.Y("label_y:Q", axis=None),
+            text=alt.Text("feature_annotated:N"),
+            color=alt.value("grey"),
             tooltip=["feature_value"],
         )
     )
 
-    return cast(alt.LayerChart, bars + feature_labels + zero_line + arrow)
+    return cast(alt.LayerChart, bars + feature_labels + zero_line + arrow + label_connectors)
 
 
 def plot_influence_sensitivity(
-
-
+    shap_values: type_hints.SupportedDataType,
+    feature_values: type_hints.SupportedDataType,
+    figsize: tuple[float, float] = DEFAULT_FIGSIZE,
+) -> Any:
     """
-    Create a SHAP dependence scatter plot for a specific feature.
+    Create a SHAP dependence scatter plot for a specific feature. If a DataFrame is provided, a select box
+    will be displayed to select the feature. This is only supported in Snowflake notebooks.
+    If Streamlit is not available and a DataFrame is passed in, an ImportError will be raised.
 
     Args:
-
-
+        shap_values: pandas Series or 2D array containing the SHAP values for a specific feature
+        feature_values: pandas Series or 2D array containing the feature values for the same feature
         figsize: tuple of (width, height) for the plot
 
     Returns:
         Altair chart object
+
+    Raises:
+        ValueError: If the types of feature_values and shap_values are not the same
+
     """
 
+    use_streamlit = False
+    feature_values_df = _convert_to_pandas_df(feature_values)
+    shap_values_df = _convert_to_pandas_df(shap_values)
+
+    if len(shap_values_df.shape) > 1:
+        feature_values, shap_values, st = _prepare_feature_values_for_streamlit(feature_values_df, shap_values_df)
+        use_streamlit = True
+    elif feature_values_df.shape[0] != shap_values_df.shape[0]:
+        raise ValueError("Feature values and SHAP values must have the same number of rows.")
+
+    scatter = _create_scatter_plot(feature_values, shap_values, figsize)
+    return st.altair_chart(scatter) if use_streamlit else scatter
+
+
+def _prepare_feature_values_for_streamlit(
+    feature_values_df: pd.DataFrame, shap_values: pd.DataFrame
+) -> tuple[pd.Series, pd.Series, Any]:
+    try:
+        from IPython import get_ipython
+        from snowbook.executor.python_transformer import IPythonProxy
+
+        assert isinstance(
+            get_ipython(), IPythonProxy
+        ), "Influence sensitivity plots for a DataFrame are not supported outside of Snowflake notebooks."
+    except ImportError:
+        raise RuntimeError(
+            "Influence sensitivity plots for a DataFrame are not supported outside of Snowflake notebooks."
+        )
+
+    import streamlit as st
+
+    feature_columns = feature_values_df.columns
+    chosen_ft: str = st.selectbox("Feature:", feature_columns)
+    feature_values = feature_values_df[chosen_ft]
+    shap_values = shap_values.iloc[:, feature_columns.get_loc(chosen_ft)]
+    return feature_values, shap_values, st
+
+
+def _create_scatter_plot(feature_values: pd.Series, shap_values: pd.Series, figsize: tuple[float, float]) -> alt.Chart:
     unique_vals = np.sort(np.unique(feature_values.values))
     max_points_per_unique_value = float(np.max(np.bincount(np.searchsorted(unique_vals, feature_values.values))))
     points_per_value = len(feature_values.values) / len(unique_vals)
```
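After the rewrite, a pair of Series still plots directly, while DataFrame inputs route through the new Streamlit select-box path (Snowflake notebooks only). A hedged usage sketch for the Series case, with made-up values:

```python
import pandas as pd

from snowflake.ml.monitoring import explain_visualize

# Made-up SHAP and feature values for a single feature.
shap_vals = pd.Series([-0.12, 0.03, 0.40, -0.05])
feat_vals = pd.Series([1.0, 2.0, 3.0, 4.0])

# Series in, Altair chart out; DataFrame inputs instead trigger the
# Streamlit feature picker, which only works inside Snowflake notebooks.
chart = explain_visualize.plot_influence_sensitivity(shap_vals, feat_vals)
```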
```diff
@@ -224,7 +362,7 @@ def plot_influence_sensitivity(
 def plot_violin(
     shap_df: type_hints.SupportedDataType,
     feature_df: type_hints.SupportedDataType,
-    figsize: tuple[float, float] =
+    figsize: tuple[float, float] = DEFAULT_VIOLIN_FIGSIZE,
 ) -> alt.Chart:
     """
     Create a violin plot per feature showing the distribution of SHAP values.
```
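The label-placement code added to `plot_force` is a greedy left-to-right sweep: sort labels by bar center, then place each at least `effective_distance` to the right of the previous one. The same idea extracted as a pure function — a sketch for clarity, not the library's code:

```python
def spread_labels(centers: list[float], min_distance: float = 10.0) -> list[float]:
    """Greedy left-to-right label spreading, as in plot_force above.

    Labels are sorted by bar center, then each is placed at least `step`
    to the right of the previous one, where `step` is the larger of
    `min_distance` and the average available space.
    """
    xs = sorted(centers)
    step = max((xs[-1] - xs[0]) / (len(xs) - 1), min_distance) if len(xs) > 1 else 0.0
    placed: list[float] = []
    for i in range(len(xs)):
        # The first label starts one step left of the leftmost bar center.
        placed.append(xs[0] - step if i == 0 else placed[-1] + step)
    return placed


# spread_labels([0.0, 1.0, 1.5, 8.0], min_distance=10.0) -> [-10.0, 0.0, 10.0, 20.0]
```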
```diff
--- a/snowflake/ml/monitoring/model_monitor.py
+++ b/snowflake/ml/monitoring/model_monitor.py
@@ -1,7 +1,5 @@
-from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.utils import sql_identifier
-from snowflake.ml.monitoring import model_monitor_version
 from snowflake.ml.monitoring._client import model_monitor_sql_client
 
 
@@ -29,7 +27,6 @@ class ModelMonitor:
         project=telemetry.TelemetryProject.MLOPS.value,
         subproject=telemetry.TelemetrySubProject.MONITORING.value,
     )
-    @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
     def suspend(self) -> None:
         """Suspend the Model Monitor"""
         statement_params = telemetry.get_statement_params(
@@ -42,7 +39,6 @@ class ModelMonitor:
         project=telemetry.TelemetryProject.MLOPS.value,
         subproject=telemetry.TelemetrySubProject.MONITORING.value,
     )
-    @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
     def resume(self) -> None:
         """Resume the Model Monitor"""
         statement_params = telemetry.get_statement_params(
```
```diff
--- a/snowflake/ml/registry/registry.py
+++ b/snowflake/ml/registry/registry.py
@@ -14,7 +14,7 @@ from snowflake.ml.model import (
     type_hints as model_types,
 )
 from snowflake.ml.model._client.model import model_version_impl
-from snowflake.ml.monitoring import model_monitor
+from snowflake.ml.monitoring import model_monitor
 from snowflake.ml.monitoring._manager import model_monitor_manager
 from snowflake.ml.monitoring.entities import model_monitor_config
 from snowflake.ml.registry._manager import model_manager
@@ -30,6 +30,7 @@ _MODEL_MONITORING_DISABLED_ERROR = (
 
 
 class Registry:
+    @telemetry.send_api_usage_telemetry(project=_TELEMETRY_PROJECT, subproject=_MODEL_TELEMETRY_SUBPROJECT)
     def __init__(
         self,
         session: session.Session,
@@ -74,6 +75,22 @@ class Registry:
             else sql_identifier.SqlIdentifier("PUBLIC")
         )
 
+        database_exists = session.sql(
+            f"""SELECT 1 FROM INFORMATION_SCHEMA.DATABASES WHERE DATABASE_NAME = '{self._database_name.resolved()}';"""
+        ).collect()
+
+        if not database_exists:
+            raise ValueError(f"Database {self._database_name} does not exist.")
+
+        schema_exists = session.sql(
+            f"""
+            SELECT 1 FROM {self._database_name.identifier()}.INFORMATION_SCHEMA.SCHEMATA
+            WHERE SCHEMA_NAME = '{self._schema_name.resolved()}';"""
+        ).collect()
+
+        if not schema_exists:
+            raise ValueError(f"Schema {self._schema_name} does not exist.")
+
         self._model_manager = model_manager.ModelManager(
             session,
             database_name=self._database_name,
```
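Registry construction now fails fast when the target database or schema is missing, rather than surfacing an error on first use. A hedged usage sketch (the session is an existing snowpark Session; names are made up):

```python
from snowflake.ml.registry import Registry

# `session` is an existing snowflake.snowpark.Session; names are made up.
try:
    registry = Registry(session, database_name="ML_DB", schema_name="MODELS")
except ValueError as exc:
    # As of 1.8.6, a missing database or schema fails here, at construction.
    print(f"Registry target missing: {exc}")
```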
```diff
@@ -155,7 +172,11 @@ class Registry:
                 `snowflake.snowpark.pypi_shared_repository`.
             resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
             target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
-
+                "WAREHOUSE" and "SNOWPARK_CONTAINER_SERVICES":
+                - ["WAREHOUSE"] (Warehouse only)
+                - ["SNOWPARK_CONTAINER_SERVICES"] (Snowpark Container Services only)
+                - ["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"] (Both)
+                Defaults to None. When None, the target platforms will be both.
             python_version: Python version in which the model is run. Defaults to None.
             signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
                 sample_input_data would be used to infer the signatures for those models that cannot automatically
@@ -295,8 +316,11 @@ class Registry:
                 `snowflake.snowpark.pypi_shared_repository`.
             resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
             target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
-
-
+                "WAREHOUSE" and "SNOWPARK_CONTAINER_SERVICES":
+                - ["WAREHOUSE"] (Warehouse only)
+                - ["SNOWPARK_CONTAINER_SERVICES"] (Snowpark Container Services only)
+                - ["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"] (Both)
+                Defaults to None. When None, the target platforms will be both.
             python_version: Python version in which the model is run. Defaults to None.
             signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
                 sample_input_data would be used to infer the signatures for those models that cannot automatically
```
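A hedged sketch of restricting a model to warehouse execution via the documented `target_platforms` values (the model object and names are made up):

```python
# `registry` as above; `model` is any supported model object (made up here).
mv = registry.log_model(
    model,
    model_name="MY_MODEL",
    version_name="V1",
    # Warehouse-only deployment; omit target_platforms to allow both.
    target_platforms=["WAREHOUSE"],
)
```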
```diff
@@ -397,11 +421,11 @@ class Registry:
         if task is not model_types.Task.UNKNOWN:
             raise ValueError("`task` cannot be specified when calling log_model with a ModelVersion.")
 
-        if pip_requirements:
+        if pip_requirements and not artifact_repository_map:
             warnings.warn(
-                "Models logged specifying `pip_requirements`
-                "
-                "
+                "Models logged specifying `pip_requirements` cannot be executed in a Snowflake Warehouse "
+                "without specifying `artifact_repository_map`. This model can be run in Snowpark Container "
+                "Services. See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container.",
                 category=UserWarning,
                 stacklevel=1,
             )
```
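The UserWarning is now skipped when an artifact repository is mapped, since pip dependencies can then be resolved for warehouse execution. A hedged sketch — the channel key and repository name here are assumptions, not taken from this diff:

```python
mv = registry.log_model(
    model,
    model_name="MY_MODEL",
    version_name="V2",
    pip_requirements=["lightgbm==4.3.0"],
    # Hypothetical repository mapping: with this set, 1.8.6 no longer warns,
    # because pip packages can be resolved for warehouse execution.
    artifact_repository_map={"pip": "MY_DB.MY_SCHEMA.MY_PIP_REPO"},
)
```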
```diff
@@ -500,7 +524,6 @@ class Registry:
         project=telemetry.TelemetryProject.MLOPS.value,
         subproject=telemetry.TelemetrySubProject.MONITORING.value,
     )
-    @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
     def add_monitor(
         self,
         name: str,
@@ -525,7 +548,7 @@ class Registry:
         return self._model_monitor_manager.add_monitor(name, source_config, model_monitor_config)
 
     @overload
-    def get_monitor(self, model_version: model_version_impl.ModelVersion) -> model_monitor.ModelMonitor:
+    def get_monitor(self, *, model_version: model_version_impl.ModelVersion) -> model_monitor.ModelMonitor:
         """Get a Model Monitor on a Model Version from the Registry.
 
         Args:
@@ -534,7 +557,7 @@ class Registry:
         ...
 
     @overload
-    def get_monitor(self, name: str) -> model_monitor.ModelMonitor:
+    def get_monitor(self, *, name: str) -> model_monitor.ModelMonitor:
         """Get a Model Monitor by name from the Registry.
 
         Args:
```
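Both overloads are now keyword-only, so positional calls need updating:

```python
# 1.8.6 makes these arguments keyword-only:
monitor = registry.get_monitor(name="MY_MONITOR")
# or, for a specific model version `mv`:
monitor = registry.get_monitor(model_version=mv)
```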
```diff
@@ -546,7 +569,6 @@ class Registry:
         project=telemetry.TelemetryProject.MLOPS.value,
         subproject=telemetry.TelemetrySubProject.MONITORING.value,
     )
-    @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
     def get_monitor(
         self, *, name: Optional[str] = None, model_version: Optional[model_version_impl.ModelVersion] = None
     ) -> model_monitor.ModelMonitor:
@@ -575,7 +597,6 @@ class Registry:
         project=telemetry.TelemetryProject.MLOPS.value,
         subproject=telemetry.TelemetrySubProject.MONITORING.value,
     )
-    @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
     def show_model_monitors(self) -> list[snowpark.Row]:
         """Show all model monitors in the registry.
 
@@ -593,7 +614,6 @@ class Registry:
         project=telemetry.TelemetryProject.MLOPS.value,
         subproject=telemetry.TelemetrySubProject.MONITORING.value,
     )
-    @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
     def delete_monitor(self, name: str) -> None:
         """Delete a Model Monitor by name from the Registry.
 
```
```diff
--- a/snowflake/ml/utils/connection_params.py
+++ b/snowflake/ml/utils/connection_params.py
@@ -113,6 +113,10 @@ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -
 
     config = configparser.ConfigParser(inline_comment_prefixes="#")
 
+    snowflake_connection_name = os.getenv("SNOWFLAKE_CONNECTION_NAME")
+    if snowflake_connection_name is not None:
+        connection_name = snowflake_connection_name
+
     if connection_name:
         if not connection_name.startswith("connections."):
             connection_name = "connections." + connection_name
```
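With this change the SNOWFLAKE_CONNECTION_NAME environment variable takes precedence over the argument, so deployments can switch connections without code changes (note that SnowflakeLoginOptions itself is now marked deprecated as of 1.8.5, per the hunk below). A hedged sketch with a made-up connection name:

```python
import os

from snowflake.ml.utils.connection_params import SnowflakeLoginOptions

# With the variable set, the "dev" argument below is overridden.
os.environ["SNOWFLAKE_CONNECTION_NAME"] = "prod"
params = SnowflakeLoginOptions("dev")  # resolves [connections.prod] from the SnowSQL config
```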
```diff
@@ -132,7 +136,7 @@ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -
     return conn_params
 
 
-@snowpark._internal.utils.
+@snowpark._internal.utils.deprecated(version="1.8.5")
 def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] = None) -> dict[str, Union[str, bytes]]:
     """Returns a dict that can be used directly into snowflake python connector or Snowpark session config.
 
@@ -153,9 +157,11 @@ def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] =
     Ideally one should have a snowsql config file. Read more here:
     https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings
 
+    If snowsql config file does not exist, it tries auth from env variables.
+
     Args:
-        connection_name: Name of the connection to look for inside the config file. If
-            it
+        connection_name: Name of the connection to look for inside the config file. If environment variable
+            SNOWFLAKE_CONNECTION_NAME is provided, it will override the input connection_name.
         login_file: If provided, this is used as config file instead of default one (_DEFAULT_CONNECTION_FILE).
 
     Returns:
```