snowflake-ml-python 1.8.4__py3-none-any.whl → 1.8.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41):
  1. snowflake/ml/_internal/telemetry.py +42 -16
  2. snowflake/ml/_internal/utils/connection_params.py +196 -0
  3. snowflake/ml/data/data_connector.py +1 -1
  4. snowflake/ml/jobs/__init__.py +2 -0
  5. snowflake/ml/jobs/_utils/constants.py +12 -2
  6. snowflake/ml/jobs/_utils/function_payload_utils.py +43 -0
  7. snowflake/ml/jobs/_utils/interop_utils.py +1 -1
  8. snowflake/ml/jobs/_utils/payload_utils.py +95 -39
  9. snowflake/ml/jobs/_utils/scripts/constants.py +22 -0
  10. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +67 -2
  11. snowflake/ml/jobs/_utils/spec_utils.py +30 -6
  12. snowflake/ml/jobs/_utils/stage_utils.py +119 -0
  13. snowflake/ml/jobs/_utils/types.py +5 -1
  14. snowflake/ml/jobs/decorators.py +10 -7
  15. snowflake/ml/jobs/job.py +176 -28
  16. snowflake/ml/jobs/manager.py +119 -26
  17. snowflake/ml/model/_client/model/model_impl.py +58 -0
  18. snowflake/ml/model/_client/model/model_version_impl.py +90 -0
  19. snowflake/ml/model/_client/ops/model_ops.py +6 -3
  20. snowflake/ml/model/_client/ops/service_ops.py +24 -7
  21. snowflake/ml/model/_client/service/model_deployment_spec.py +11 -0
  22. snowflake/ml/model/_client/sql/model_version.py +1 -1
  23. snowflake/ml/model/_client/sql/service.py +73 -28
  24. snowflake/ml/model/_client/sql/stage.py +5 -2
  25. snowflake/ml/model/_model_composer/model_composer.py +3 -1
  26. snowflake/ml/model/_packager/model_handlers/sklearn.py +1 -1
  27. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +103 -73
  28. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +3 -2
  29. snowflake/ml/model/_signatures/core.py +24 -0
  30. snowflake/ml/monitoring/explain_visualize.py +160 -22
  31. snowflake/ml/monitoring/model_monitor.py +0 -4
  32. snowflake/ml/registry/registry.py +34 -14
  33. snowflake/ml/utils/connection_params.py +9 -3
  34. snowflake/ml/utils/html_utils.py +263 -0
  35. snowflake/ml/version.py +1 -1
  36. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/METADATA +40 -13
  37. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/RECORD +40 -37
  38. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/WHEEL +1 -1
  39. snowflake/ml/monitoring/model_monitor_version.py +0 -1
  40. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/licenses/LICENSE.txt +0 -0
  41. {snowflake_ml_python-1.8.4.dist-info → snowflake_ml_python-1.8.6.dist-info}/top_level.txt +0 -0
@@ -12,7 +12,7 @@ REQUIREMENTS = [
12
12
  "importlib_resources>=6.1.1, <7",
13
13
  "numpy>=1.23,<2",
14
14
  "packaging>=20.9,<25",
15
- "pandas>=1.0.0,<3",
15
+ "pandas>=2.1.4,<3",
16
16
  "pyarrow",
17
17
  "pydantic>=2.8.2, <3",
18
18
  "pyjwt>=2.0.0, <3",
@@ -24,9 +24,10 @@ REQUIREMENTS = [
24
24
  "scikit-learn<1.6",
25
25
  "scipy>=1.9,<2",
26
26
  "shap>=0.46.0,<1",
27
- "snowflake-connector-python>=3.14.0,<4",
27
+ "snowflake-connector-python>=3.15.0,<4",
28
28
  "snowflake-snowpark-python>=1.17.0,<2,!=1.26.0",
29
29
  "snowflake.core>=1.0.2,<2",
30
30
  "sqlparse>=0.4,<1",
31
31
  "typing-extensions>=4.1.0,<5",
32
+ "xgboost>=1.7.3,<3",
32
33
  ]
@@ -559,6 +559,30 @@ class ModelSignature:
559
559
  )"""
560
560
  )
561
561
 
562
+ def _repr_html_(self) -> str:
563
+ """Generate an HTML representation of the model signature.
564
+
565
+ Returns:
566
+ str: HTML string containing formatted signature details.
567
+ """
568
+ from snowflake.ml.utils import html_utils
569
+
570
+ # Create collapsible sections for inputs and outputs
571
+ inputs_content = html_utils.create_features_html(self.inputs, "Input")
572
+ outputs_content = html_utils.create_features_html(self.outputs, "Output")
573
+
574
+ inputs_section = html_utils.create_collapsible_section("Inputs", inputs_content, open_by_default=True)
575
+ outputs_section = html_utils.create_collapsible_section("Outputs", outputs_content, open_by_default=True)
576
+
577
+ content = f"""
578
+ <div style="margin-top: 10px;">
579
+ {inputs_section}
580
+ {outputs_section}
581
+ </div>
582
+ """
583
+
584
+ return html_utils.create_base_container("Model Signature", content)
585
+
562
586
  @classmethod
563
587
  def from_mlflow_sig(cls, mlflow_sig: "mlflow.models.ModelSignature") -> "ModelSignature":
564
588
  return ModelSignature(
@@ -1,4 +1,4 @@
1
- from typing import Union, cast, overload
1
+ from typing import Any, Union, cast, overload
2
2
 
3
3
  import altair as alt
4
4
  import numpy as np
@@ -6,16 +6,22 @@ import pandas as pd
6
6
 
7
7
  import snowflake.snowpark.dataframe as sp_df
8
8
  from snowflake import snowpark
9
+ from snowflake.ml._internal.exceptions import error_codes, exceptions
9
10
  from snowflake.ml.model import model_signature, type_hints
10
11
  from snowflake.ml.model._signatures import snowpark_handler
11
12
 
13
+ DEFAULT_FIGSIZE = (1400, 500)
14
+ DEFAULT_VIOLIN_FIGSIZE = (1400, 100)
15
+ MAX_ANNOTATION_LENGTH = 20
16
+ MIN_DISTANCE = 10 # Increase minimum distance between labels for more spreading in plot_force
17
+
12
18
 
13
19
  @overload
14
20
  def plot_force(
15
21
  shap_row: snowpark.Row,
16
22
  features_row: snowpark.Row,
17
23
  base_value: float = 0.0,
18
- figsize: tuple[float, float] = (600, 200),
24
+ figsize: tuple[float, float] = DEFAULT_FIGSIZE,
19
25
  contribution_threshold: float = 0.05,
20
26
  ) -> alt.LayerChart:
21
27
  ...
@@ -26,7 +32,7 @@ def plot_force(
26
32
  shap_row: pd.Series,
27
33
  features_row: pd.Series,
28
34
  base_value: float = 0.0,
29
- figsize: tuple[float, float] = (600, 200),
35
+ figsize: tuple[float, float] = DEFAULT_FIGSIZE,
30
36
  contribution_threshold: float = 0.05,
31
37
  ) -> alt.LayerChart:
32
38
  ...
@@ -36,7 +42,7 @@ def plot_force(
36
42
  shap_row: Union[pd.Series, snowpark.Row],
37
43
  features_row: Union[pd.Series, snowpark.Row],
38
44
  base_value: float = 0.0,
39
- figsize: tuple[float, float] = (600, 200),
45
+ figsize: tuple[float, float] = DEFAULT_FIGSIZE,
40
46
  contribution_threshold: float = 0.05,
41
47
  ) -> alt.LayerChart:
42
48
  """
@@ -53,7 +59,17 @@ def plot_force(
53
59
 
54
60
  Returns:
55
61
  Altair chart object
62
+
63
+ Raises:
64
+ SnowflakeMLException: If the contribution threshold is not between 0 and 1,
65
+ or if no features with significant contributions are found.
56
66
  """
67
+ if not (0 < contribution_threshold and contribution_threshold < 1):
68
+ raise exceptions.SnowflakeMLException(
69
+ error_code=error_codes.INVALID_ARGUMENT,
70
+ original_exception=ValueError("contribution_threshold must be between 0 and 1."),
71
+ )
72
+
57
73
  if isinstance(shap_row, snowpark.Row):
58
74
  shap_row = pd.Series(shap_row.as_dict())
59
75
  if isinstance(features_row, snowpark.Row):
@@ -67,7 +83,7 @@ def plot_force(
67
83
  {
68
84
  "feature": feature,
69
85
  "feature_value": features_row.iloc[index],
70
- "feature_annotated": f"{feature}: {features_row.iloc[index]}",
86
+ "feature_annotated": f"{feature}: {features_row.iloc[index]}"[:MAX_ANNOTATION_LENGTH],
71
87
  "influence_value": shap_row.iloc[index],
72
88
  "bar_direction": positive_label if shap_row.iloc[index] >= 0 else negative_label,
73
89
  }
@@ -95,11 +111,11 @@ def plot_force(
95
111
 
96
112
  if row_influence_value >= 0:
97
113
  start = current_position_pos - spacing
98
- end = current_position_pos - row_influence_value
114
+ end = current_position_pos - row_influence_value - spacing
99
115
  current_position_pos = end
100
116
  else:
101
117
  start = current_position_neg + spacing
102
- end = current_position_neg + abs(row_influence_value)
118
+ end = current_position_neg + abs(row_influence_value) + spacing
103
119
  current_position_neg = end
104
120
 
105
121
  positions.append(
@@ -108,13 +124,23 @@ def plot_force(
108
124
  "end": end,
109
125
  "avg": (start + end) / 2,
110
126
  "influence_value": row_influence_value,
111
- "influence_annotated": f"Influence: {row_influence_value}",
112
127
  "feature_value": row["feature_value"],
113
128
  "feature_annotated": row["feature_annotated"],
114
129
  "bar_direction": row["bar_direction"],
130
+ "bar_y": 0,
131
+ "feature": row["feature"],
115
132
  }
116
133
  )
117
134
 
135
+ if len(positions) == 0:
136
+ raise exceptions.SnowflakeMLException(
137
+ error_code=error_codes.INVALID_ARGUMENT,
138
+ original_exception=ValueError(
139
+ "No features with significant contributions found. Try lowering the contribution_threshold,"
140
+ "and verify the input is non-empty."
141
+ ),
142
+ )
143
+
118
144
  position_df = pd.DataFrame(positions)
119
145
 
120
146
  # Create force plot using Altair
@@ -127,12 +153,13 @@ def plot_force(
127
153
  .encode(
128
154
  x=alt.X("start:Q", title="Feature Impact"),
129
155
  x2=alt.X2("end:Q"),
156
+ y=alt.Y("bar_y:Q", axis=None),
130
157
  color=alt.Color(
131
158
  "bar_direction:N",
132
159
  scale=alt.Scale(domain=[positive_label, negative_label], range=[red_color, blue_color]),
133
160
  legend=alt.Legend(title="Influence Direction"),
134
161
  ),
135
- tooltip=["influence_value", "feature_value"],
162
+ tooltip=["feature", "influence_value", "feature_value"],
136
163
  )
137
164
  .properties(title="Feature Influence (SHAP values)", width=width, height=height)
138
165
  ).interactive()
@@ -142,6 +169,7 @@ def plot_force(
142
169
  .mark_point(shape="triangle", filled=True, fillOpacity=1)
143
170
  .encode(
144
171
  x=alt.X("start:Q"),
172
+ y=alt.Y("bar_y:Q", axis=None),
145
173
  angle=alt.Angle("bar_direction:N", scale=alt.Scale(domain=["Positive", "Negative"], range=[90, -90])),
146
174
  color=alt.Color(
147
175
  "bar_direction:N", scale=alt.Scale(domain=["Positive", "Negative"], range=["#1f77b4", "#d62728"])
@@ -154,37 +182,147 @@ def plot_force(
154
182
  # Add a vertical line at the base value
155
183
  zero_line: alt.Chart = alt.Chart(pd.DataFrame({"x": [base_value]})).mark_rule(strokeDash=[3, 3]).encode(x="x:Q")
156
184
 
157
- # Add text labels on each bar
185
+ # Calculate label positions to avoid overlap and ensure labels are spread apart horizontally
186
+
187
+ # Sort by bar center (avg) for label placement
188
+ sorted_positions = sorted(positions, key=lambda x: x["avg"])
189
+
190
+ # Improved label spreading algorithm:
191
+ # Calculate the minimum and maximum x positions (avg) for the bars
192
+ min_x = min(pos["avg"] for pos in sorted_positions)
193
+ max_x = max(pos["avg"] for pos in sorted_positions)
194
+ n_labels = len(sorted_positions)
195
+ # Calculate the minimum required distance between labels
196
+ spread_width = max_x - min_x
197
+ if n_labels > 1:
198
+ space_per_label = spread_width / (n_labels - 1)
199
+ # If space_per_label is less than min_distance, use min_distance instead
200
+ effective_distance = max(space_per_label, MIN_DISTANCE)
201
+ else:
202
+ effective_distance = 0
203
+
204
+ # Start from min_x - offset, and assign label_x for each label from left to right
205
+ offset = -effective_distance # Start a bit to the left
206
+ label_positions = []
207
+ label_lines = []
208
+ placed_label_xs: list[float] = []
209
+ for i, pos in enumerate(sorted_positions):
210
+ if i == 0:
211
+ label_x = min_x + offset
212
+ else:
213
+ label_x = placed_label_xs[-1] + effective_distance
214
+ placed_label_xs.append(label_x)
215
+ label_positions.append(
216
+ {
217
+ "label_x": label_x,
218
+ "label_y": 1, # Place labels below the bars
219
+ "feature_annotated": pos["feature_annotated"],
220
+ "feature_value": pos["feature_value"],
221
+ }
222
+ )
223
+ # Draw a diagonal line from the bar to the label
224
+ label_lines.append(
225
+ {
226
+ "x": pos["avg"],
227
+ "x2": label_x,
228
+ "y": 0,
229
+ "y2": 1,
230
+ }
231
+ )
232
+
233
+ label_positions_df = pd.DataFrame(label_positions)
234
+ label_lines_df = pd.DataFrame(label_lines)
235
+
236
+ # Draw diagonal lines from bar to label
237
+ label_connectors = (
238
+ alt.Chart(label_lines_df)
239
+ .mark_rule(strokeDash=[2, 2], color="grey")
240
+ .encode(
241
+ x="x:Q",
242
+ x2="x2:Q",
243
+ y=alt.Y("y:Q", axis=None),
244
+ y2="y2:Q",
245
+ )
246
+ )
247
+
248
+ # Place labels at adjusted positions
158
249
  feature_labels = (
159
- alt.Chart(position_df)
160
- .mark_text(align="center", baseline="line-bottom", dy=30, fontSize=11)
250
+ alt.Chart(label_positions_df)
251
+ .mark_text(align="center", baseline="line-bottom", dy=0, fontSize=11)
161
252
  .encode(
162
- x=alt.X("avg:Q"),
163
- text=alt.Text("feature_annotated:N"), # Display with 2 decimal places
164
- color=alt.value("grey"), # Label color for positive values
253
+ x=alt.X("label_x:Q"),
254
+ y=alt.Y("label_y:Q", axis=None),
255
+ text=alt.Text("feature_annotated:N"),
256
+ color=alt.value("grey"),
165
257
  tooltip=["feature_value"],
166
258
  )
167
259
  )
168
260
 
169
- return cast(alt.LayerChart, bars + feature_labels + zero_line + arrow)
261
+ return cast(alt.LayerChart, bars + feature_labels + zero_line + arrow + label_connectors)
170
262
 
171
263
 
172
264
  def plot_influence_sensitivity(
173
- feature_values: pd.Series, shap_values: pd.Series, figsize: tuple[float, float] = (600, 400)
174
- ) -> alt.Chart:
265
+ shap_values: type_hints.SupportedDataType,
266
+ feature_values: type_hints.SupportedDataType,
267
+ figsize: tuple[float, float] = DEFAULT_FIGSIZE,
268
+ ) -> Any:
175
269
  """
176
- Create a SHAP dependence scatter plot for a specific feature.
270
+ Create a SHAP dependence scatter plot for a specific feature. If a DataFrame is provided, a select box
271
+ will be displayed to select the feature. This is only supported in Snowflake notebooks.
272
+ If Streamlit is not available and a DataFrame is passed in, an ImportError will be raised.
177
273
 
178
274
  Args:
179
- feature_values: pandas Series containing the feature values for a specific feature
180
- shap_values: pandas Series containing the SHAP values for the same feature
275
+ shap_values: pandas Series or 2D array containing the SHAP values for a specific feature
276
+ feature_values: pandas Series or 2D array containing the feature values for the same feature
181
277
  figsize: tuple of (width, height) for the plot
182
278
 
183
279
  Returns:
184
280
  Altair chart object
185
281
 
282
+ Raises:
283
+ ValueError: If the types of feature_values and shap_values are not the same
284
+
186
285
  """
187
286
 
287
+ use_streamlit = False
288
+ feature_values_df = _convert_to_pandas_df(feature_values)
289
+ shap_values_df = _convert_to_pandas_df(shap_values)
290
+
291
+ if len(shap_values_df.shape) > 1:
292
+ feature_values, shap_values, st = _prepare_feature_values_for_streamlit(feature_values_df, shap_values_df)
293
+ use_streamlit = True
294
+ elif feature_values_df.shape[0] != shap_values_df.shape[0]:
295
+ raise ValueError("Feature values and SHAP values must have the same number of rows.")
296
+
297
+ scatter = _create_scatter_plot(feature_values, shap_values, figsize)
298
+ return st.altair_chart(scatter) if use_streamlit else scatter
299
+
300
+
301
+ def _prepare_feature_values_for_streamlit(
302
+ feature_values_df: pd.DataFrame, shap_values: pd.DataFrame
303
+ ) -> tuple[pd.Series, pd.Series, Any]:
304
+ try:
305
+ from IPython import get_ipython
306
+ from snowbook.executor.python_transformer import IPythonProxy
307
+
308
+ assert isinstance(
309
+ get_ipython(), IPythonProxy
310
+ ), "Influence sensitivity plots for a DataFrame are not supported outside of Snowflake notebooks."
311
+ except ImportError:
312
+ raise RuntimeError(
313
+ "Influence sensitivity plots for a DataFrame are not supported outside of Snowflake notebooks."
314
+ )
315
+
316
+ import streamlit as st
317
+
318
+ feature_columns = feature_values_df.columns
319
+ chosen_ft: str = st.selectbox("Feature:", feature_columns)
320
+ feature_values = feature_values_df[chosen_ft]
321
+ shap_values = shap_values.iloc[:, feature_columns.get_loc(chosen_ft)]
322
+ return feature_values, shap_values, st
323
+
324
+
325
+ def _create_scatter_plot(feature_values: pd.Series, shap_values: pd.Series, figsize: tuple[float, float]) -> alt.Chart:
188
326
  unique_vals = np.sort(np.unique(feature_values.values))
189
327
  max_points_per_unique_value = float(np.max(np.bincount(np.searchsorted(unique_vals, feature_values.values))))
190
328
  points_per_value = len(feature_values.values) / len(unique_vals)
@@ -224,7 +362,7 @@ def plot_influence_sensitivity(
224
362
  def plot_violin(
225
363
  shap_df: type_hints.SupportedDataType,
226
364
  feature_df: type_hints.SupportedDataType,
227
- figsize: tuple[float, float] = (600, 200),
365
+ figsize: tuple[float, float] = DEFAULT_VIOLIN_FIGSIZE,
228
366
  ) -> alt.Chart:
229
367
  """
230
368
  Create a violin plot per feature showing the distribution of SHAP values.
@@ -1,7 +1,5 @@
1
- from snowflake import snowpark
2
1
  from snowflake.ml._internal import telemetry
3
2
  from snowflake.ml._internal.utils import sql_identifier
4
- from snowflake.ml.monitoring import model_monitor_version
5
3
  from snowflake.ml.monitoring._client import model_monitor_sql_client
6
4
 
7
5
 
@@ -29,7 +27,6 @@ class ModelMonitor:
29
27
  project=telemetry.TelemetryProject.MLOPS.value,
30
28
  subproject=telemetry.TelemetrySubProject.MONITORING.value,
31
29
  )
32
- @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
33
30
  def suspend(self) -> None:
34
31
  """Suspend the Model Monitor"""
35
32
  statement_params = telemetry.get_statement_params(
@@ -42,7 +39,6 @@ class ModelMonitor:
42
39
  project=telemetry.TelemetryProject.MLOPS.value,
43
40
  subproject=telemetry.TelemetrySubProject.MONITORING.value,
44
41
  )
45
- @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
46
42
  def resume(self) -> None:
47
43
  """Resume the Model Monitor"""
48
44
  statement_params = telemetry.get_statement_params(
@@ -14,7 +14,7 @@ from snowflake.ml.model import (
14
14
  type_hints as model_types,
15
15
  )
16
16
  from snowflake.ml.model._client.model import model_version_impl
17
- from snowflake.ml.monitoring import model_monitor, model_monitor_version
17
+ from snowflake.ml.monitoring import model_monitor
18
18
  from snowflake.ml.monitoring._manager import model_monitor_manager
19
19
  from snowflake.ml.monitoring.entities import model_monitor_config
20
20
  from snowflake.ml.registry._manager import model_manager
@@ -30,6 +30,7 @@ _MODEL_MONITORING_DISABLED_ERROR = (
30
30
 
31
31
 
32
32
  class Registry:
33
+ @telemetry.send_api_usage_telemetry(project=_TELEMETRY_PROJECT, subproject=_MODEL_TELEMETRY_SUBPROJECT)
33
34
  def __init__(
34
35
  self,
35
36
  session: session.Session,
@@ -74,6 +75,22 @@ class Registry:
74
75
  else sql_identifier.SqlIdentifier("PUBLIC")
75
76
  )
76
77
 
78
+ database_exists = session.sql(
79
+ f"""SELECT 1 FROM INFORMATION_SCHEMA.DATABASES WHERE DATABASE_NAME = '{self._database_name.resolved()}';"""
80
+ ).collect()
81
+
82
+ if not database_exists:
83
+ raise ValueError(f"Database {self._database_name} does not exist.")
84
+
85
+ schema_exists = session.sql(
86
+ f"""
87
+ SELECT 1 FROM {self._database_name.identifier()}.INFORMATION_SCHEMA.SCHEMATA
88
+ WHERE SCHEMA_NAME = '{self._schema_name.resolved()}';"""
89
+ ).collect()
90
+
91
+ if not schema_exists:
92
+ raise ValueError(f"Schema {self._schema_name} does not exist.")
93
+
77
94
  self._model_manager = model_manager.ModelManager(
78
95
  session,
79
96
  database_name=self._database_name,
@@ -155,7 +172,11 @@ class Registry:
155
172
  `snowflake.snowpark.pypi_shared_repository`.
156
173
  resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
157
174
  target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
158
- {"WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"}. Defaults to None.
175
+ "WAREHOUSE" and "SNOWPARK_CONTAINER_SERVICES":
176
+ - ["WAREHOUSE"] (Warehouse only)
177
+ - ["SNOWPARK_CONTAINER_SERVICES"] (Snowpark Container Services only)
178
+ - ["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"] (Both)
179
+ Defaults to None. When None, the target platforms will be both.
159
180
  python_version: Python version in which the model is run. Defaults to None.
160
181
  signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
161
182
  sample_input_data would be used to infer the signatures for those models that cannot automatically
@@ -295,8 +316,11 @@ class Registry:
295
316
  `snowflake.snowpark.pypi_shared_repository`.
296
317
  resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
297
318
  target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
298
- ["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"]. Defaults to None. When None, the target platforms will be
299
- both.
319
+ "WAREHOUSE" and "SNOWPARK_CONTAINER_SERVICES":
320
+ - ["WAREHOUSE"] (Warehouse only)
321
+ - ["SNOWPARK_CONTAINER_SERVICES"] (Snowpark Container Services only)
322
+ - ["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"] (Both)
323
+ Defaults to None. When None, the target platforms will be both.
300
324
  python_version: Python version in which the model is run. Defaults to None.
301
325
  signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
302
326
  sample_input_data would be used to infer the signatures for those models that cannot automatically
@@ -397,11 +421,11 @@ class Registry:
397
421
  if task is not model_types.Task.UNKNOWN:
398
422
  raise ValueError("`task` cannot be specified when calling log_model with a ModelVersion.")
399
423
 
400
- if pip_requirements:
424
+ if pip_requirements and not artifact_repository_map:
401
425
  warnings.warn(
402
- "Models logged specifying `pip_requirements` can not be executed "
403
- "in Snowflake Warehouse where all dependencies are required to be retrieved "
404
- "from Snowflake Anaconda Channel.",
426
+ "Models logged specifying `pip_requirements` cannot be executed in a Snowflake Warehouse "
427
+ "without specifying `artifact_repository_map`. This model can be run in Snowpark Container "
428
+ "Services. See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container.",
405
429
  category=UserWarning,
406
430
  stacklevel=1,
407
431
  )
@@ -500,7 +524,6 @@ class Registry:
500
524
  project=telemetry.TelemetryProject.MLOPS.value,
501
525
  subproject=telemetry.TelemetrySubProject.MONITORING.value,
502
526
  )
503
- @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
504
527
  def add_monitor(
505
528
  self,
506
529
  name: str,
@@ -525,7 +548,7 @@ class Registry:
525
548
  return self._model_monitor_manager.add_monitor(name, source_config, model_monitor_config)
526
549
 
527
550
  @overload
528
- def get_monitor(self, model_version: model_version_impl.ModelVersion) -> model_monitor.ModelMonitor:
551
+ def get_monitor(self, *, model_version: model_version_impl.ModelVersion) -> model_monitor.ModelMonitor:
529
552
  """Get a Model Monitor on a Model Version from the Registry.
530
553
 
531
554
  Args:
@@ -534,7 +557,7 @@ class Registry:
534
557
  ...
535
558
 
536
559
  @overload
537
- def get_monitor(self, name: str) -> model_monitor.ModelMonitor:
560
+ def get_monitor(self, *, name: str) -> model_monitor.ModelMonitor:
538
561
  """Get a Model Monitor by name from the Registry.
539
562
 
540
563
  Args:
@@ -546,7 +569,6 @@ class Registry:
546
569
  project=telemetry.TelemetryProject.MLOPS.value,
547
570
  subproject=telemetry.TelemetrySubProject.MONITORING.value,
548
571
  )
549
- @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
550
572
  def get_monitor(
551
573
  self, *, name: Optional[str] = None, model_version: Optional[model_version_impl.ModelVersion] = None
552
574
  ) -> model_monitor.ModelMonitor:
@@ -575,7 +597,6 @@ class Registry:
575
597
  project=telemetry.TelemetryProject.MLOPS.value,
576
598
  subproject=telemetry.TelemetrySubProject.MONITORING.value,
577
599
  )
578
- @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
579
600
  def show_model_monitors(self) -> list[snowpark.Row]:
580
601
  """Show all model monitors in the registry.
581
602
 
@@ -593,7 +614,6 @@ class Registry:
593
614
  project=telemetry.TelemetryProject.MLOPS.value,
594
615
  subproject=telemetry.TelemetrySubProject.MONITORING.value,
595
616
  )
596
- @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
597
617
  def delete_monitor(self, name: str) -> None:
598
618
  """Delete a Model Monitor by name from the Registry.
599
619
 
@@ -113,6 +113,10 @@ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -
113
113
 
114
114
  config = configparser.ConfigParser(inline_comment_prefixes="#")
115
115
 
116
+ snowflake_connection_name = os.getenv("SNOWFLAKE_CONNECTION_NAME")
117
+ if snowflake_connection_name is not None:
118
+ connection_name = snowflake_connection_name
119
+
116
120
  if connection_name:
117
121
  if not connection_name.startswith("connections."):
118
122
  connection_name = "connections." + connection_name
@@ -132,7 +136,7 @@ def _load_from_snowsql_config_file(connection_name: str, login_file: str = "") -
132
136
  return conn_params
133
137
 
134
138
 
135
- @snowpark._internal.utils.private_preview(version="0.2.0")
139
+ @snowpark._internal.utils.deprecated(version="1.8.5")
136
140
  def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] = None) -> dict[str, Union[str, bytes]]:
137
141
  """Returns a dict that can be used directly into snowflake python connector or Snowpark session config.
138
142
 
@@ -153,9 +157,11 @@ def SnowflakeLoginOptions(connection_name: str = "", login_file: Optional[str] =
153
157
  Ideally one should have a snowsql config file. Read more here:
154
158
  https://docs.snowflake.com/en/user-guide/snowsql-start.html#configuring-default-connection-settings
155
159
 
160
+ If snowsql config file does not exist, it tries auth from env variables.
161
+
156
162
  Args:
157
- connection_name: Name of the connection to look for inside the config file. If `connection_name` is NOT given,
158
- it tries auth from env variables.
163
+ connection_name: Name of the connection to look for inside the config file. If environment variable
164
+ SNOWFLAKE_CONNECTION_NAME is provided, it will override the input connection_name.
159
165
  login_file: If provided, this is used as config file instead of default one (_DEFAULT_CONNECTION_FILE).
160
166
 
161
167
  Returns: