snowflake-ml-python 1.8.6__py3-none-any.whl → 1.9.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. snowflake/ml/_internal/env_utils.py +44 -3
  2. snowflake/ml/_internal/platform_capabilities.py +52 -2
  3. snowflake/ml/_internal/type_utils.py +1 -1
  4. snowflake/ml/_internal/utils/identifier.py +1 -1
  5. snowflake/ml/_internal/utils/mixins.py +71 -0
  6. snowflake/ml/_internal/utils/service_logger.py +4 -2
  7. snowflake/ml/data/_internal/arrow_ingestor.py +11 -1
  8. snowflake/ml/data/data_connector.py +43 -2
  9. snowflake/ml/data/data_ingestor.py +8 -0
  10. snowflake/ml/data/torch_utils.py +1 -1
  11. snowflake/ml/dataset/dataset.py +3 -2
  12. snowflake/ml/dataset/dataset_reader.py +22 -6
  13. snowflake/ml/experiment/_client/experiment_tracking_sql_client.py +98 -0
  14. snowflake/ml/experiment/_entities/__init__.py +4 -0
  15. snowflake/ml/experiment/_entities/experiment.py +10 -0
  16. snowflake/ml/experiment/_entities/run.py +62 -0
  17. snowflake/ml/experiment/_entities/run_metadata.py +68 -0
  18. snowflake/ml/experiment/_experiment_info.py +63 -0
  19. snowflake/ml/experiment/experiment_tracking.py +319 -0
  20. snowflake/ml/jobs/_utils/constants.py +1 -1
  21. snowflake/ml/jobs/_utils/interop_utils.py +63 -4
  22. snowflake/ml/jobs/_utils/payload_utils.py +5 -3
  23. snowflake/ml/jobs/_utils/query_helper.py +20 -0
  24. snowflake/ml/jobs/_utils/scripts/mljob_launcher.py +5 -1
  25. snowflake/ml/jobs/_utils/spec_utils.py +21 -4
  26. snowflake/ml/jobs/decorators.py +18 -25
  27. snowflake/ml/jobs/job.py +137 -37
  28. snowflake/ml/jobs/manager.py +228 -153
  29. snowflake/ml/lineage/lineage_node.py +2 -2
  30. snowflake/ml/model/_client/model/model_version_impl.py +16 -4
  31. snowflake/ml/model/_client/ops/model_ops.py +12 -3
  32. snowflake/ml/model/_client/ops/service_ops.py +324 -138
  33. snowflake/ml/model/_client/service/model_deployment_spec.py +1 -1
  34. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +3 -1
  35. snowflake/ml/model/_model_composer/model_composer.py +6 -1
  36. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +55 -13
  37. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  38. snowflake/ml/model/_packager/model_env/model_env.py +35 -27
  39. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +41 -2
  40. snowflake/ml/model/_packager/model_handlers/pytorch.py +5 -1
  41. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -1
  42. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -1
  43. snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -3
  44. snowflake/ml/model/_signatures/snowpark_handler.py +55 -3
  45. snowflake/ml/model/event_handler.py +117 -0
  46. snowflake/ml/model/model_signature.py +9 -9
  47. snowflake/ml/model/models/huggingface_pipeline.py +170 -1
  48. snowflake/ml/model/target_platform.py +11 -0
  49. snowflake/ml/model/task.py +9 -0
  50. snowflake/ml/model/type_hints.py +5 -13
  51. snowflake/ml/modeling/framework/base.py +1 -1
  52. snowflake/ml/modeling/metrics/classification.py +14 -14
  53. snowflake/ml/modeling/metrics/correlation.py +19 -8
  54. snowflake/ml/modeling/metrics/metrics_utils.py +2 -0
  55. snowflake/ml/modeling/metrics/ranking.py +6 -6
  56. snowflake/ml/modeling/metrics/regression.py +9 -9
  57. snowflake/ml/monitoring/explain_visualize.py +12 -5
  58. snowflake/ml/registry/_manager/model_manager.py +47 -15
  59. snowflake/ml/registry/registry.py +109 -64
  60. snowflake/ml/version.py +1 -1
  61. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/METADATA +118 -18
  62. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/RECORD +65 -53
  63. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/WHEEL +0 -0
  64. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/licenses/LICENSE.txt +0 -0
  65. {snowflake_ml_python-1.8.6.dist-info → snowflake_ml_python-1.9.1.dist-info}/top_level.txt +0 -0
@@ -29,7 +29,7 @@ def d2_absolute_error_score(
29
29
  y_pred_col_names: Union[str, list[str]],
30
30
  sample_weight_col_name: Optional[str] = None,
31
31
  multioutput: Union[str, npt.ArrayLike] = "uniform_average",
32
- ) -> Union[float, npt.NDArray[np.float_]]:
32
+ ) -> Union[float, npt.NDArray[np.float64]]:
33
33
  """
34
34
  :math:`D^2` regression score function, \
35
35
  fraction of absolute error explained.
@@ -111,7 +111,7 @@ def d2_absolute_error_score(
111
111
 
112
112
  kwargs = telemetry.get_sproc_statement_params_kwargs(d2_absolute_error_score_anon_sproc, statement_params)
113
113
  result_object = result.deserialize(session, d2_absolute_error_score_anon_sproc(session, **kwargs))
114
- score: Union[float, npt.NDArray[np.float_]] = result_object
114
+ score: Union[float, npt.NDArray[np.float64]] = result_object
115
115
  return score
116
116
 
117
117
 
@@ -124,7 +124,7 @@ def d2_pinball_score(
124
124
  sample_weight_col_name: Optional[str] = None,
125
125
  alpha: float = 0.5,
126
126
  multioutput: Union[str, npt.ArrayLike] = "uniform_average",
127
- ) -> Union[float, npt.NDArray[np.float_]]:
127
+ ) -> Union[float, npt.NDArray[np.float64]]:
128
128
  """
129
129
  :math:`D^2` regression score function, fraction of pinball loss explained.
130
130
 
@@ -211,7 +211,7 @@ def d2_pinball_score(
211
211
  kwargs = telemetry.get_sproc_statement_params_kwargs(d2_pinball_score_anon_sproc, statement_params)
212
212
  result_object = result.deserialize(session, d2_pinball_score_anon_sproc(session, **kwargs))
213
213
 
214
- score: Union[float, npt.NDArray[np.float_]] = result_object
214
+ score: Union[float, npt.NDArray[np.float64]] = result_object
215
215
  return score
216
216
 
217
217
 
@@ -224,7 +224,7 @@ def explained_variance_score(
224
224
  sample_weight_col_name: Optional[str] = None,
225
225
  multioutput: Union[str, npt.ArrayLike] = "uniform_average",
226
226
  force_finite: bool = True,
227
- ) -> Union[float, npt.NDArray[np.float_]]:
227
+ ) -> Union[float, npt.NDArray[np.float64]]:
228
228
  """
229
229
  Explained variance regression score function.
230
230
 
@@ -326,7 +326,7 @@ def explained_variance_score(
326
326
 
327
327
  kwargs = telemetry.get_sproc_statement_params_kwargs(explained_variance_score_anon_sproc, statement_params)
328
328
  result_object = result.deserialize(session, explained_variance_score_anon_sproc(session, **kwargs))
329
- score: Union[float, npt.NDArray[np.float_]] = result_object
329
+ score: Union[float, npt.NDArray[np.float64]] = result_object
330
330
  return score
331
331
 
332
332
 
@@ -338,7 +338,7 @@ def mean_absolute_error(
338
338
  y_pred_col_names: Union[str, list[str]],
339
339
  sample_weight_col_name: Optional[str] = None,
340
340
  multioutput: Union[str, npt.ArrayLike] = "uniform_average",
341
- ) -> Union[float, npt.NDArray[np.float_]]:
341
+ ) -> Union[float, npt.NDArray[np.float64]]:
342
342
  """
343
343
  Mean absolute error regression loss.
344
344
 
@@ -411,7 +411,7 @@ def mean_absolute_percentage_error(
411
411
  y_pred_col_names: Union[str, list[str]],
412
412
  sample_weight_col_name: Optional[str] = None,
413
413
  multioutput: Union[str, npt.ArrayLike] = "uniform_average",
414
- ) -> Union[float, npt.NDArray[np.float_]]:
414
+ ) -> Union[float, npt.NDArray[np.float64]]:
415
415
  """
416
416
  Mean absolute percentage error (MAPE) regression loss.
417
417
 
@@ -495,7 +495,7 @@ def mean_squared_error(
495
495
  sample_weight_col_name: Optional[str] = None,
496
496
  multioutput: Union[str, npt.ArrayLike] = "uniform_average",
497
497
  squared: bool = True,
498
- ) -> Union[float, npt.NDArray[np.float_]]:
498
+ ) -> Union[float, npt.NDArray[np.float64]]:
499
499
  """
500
500
  Mean squared error regression loss.
501
501
 
@@ -264,6 +264,7 @@ def plot_force(
264
264
  def plot_influence_sensitivity(
265
265
  shap_values: type_hints.SupportedDataType,
266
266
  feature_values: type_hints.SupportedDataType,
267
+ infer_is_categorical: bool = True,
267
268
  figsize: tuple[float, float] = DEFAULT_FIGSIZE,
268
269
  ) -> Any:
269
270
  """
@@ -274,6 +275,8 @@ def plot_influence_sensitivity(
274
275
  Args:
275
276
  shap_values: pandas Series or 2D array containing the SHAP values for a specific feature
276
277
  feature_values: pandas Series or 2D array containing the feature values for the same feature
278
+ infer_is_categorical: If True, the function will infer if the feature is categorical
279
+ based on the number of unique values.
277
280
  figsize: tuple of (width, height) for the plot
278
281
 
279
282
  Returns:
@@ -294,7 +297,7 @@ def plot_influence_sensitivity(
294
297
  elif feature_values_df.shape[0] != shap_values_df.shape[0]:
295
298
  raise ValueError("Feature values and SHAP values must have the same number of rows.")
296
299
 
297
- scatter = _create_scatter_plot(feature_values, shap_values, figsize)
300
+ scatter = _create_scatter_plot(feature_values, shap_values, infer_is_categorical, figsize)
298
301
  return st.altair_chart(scatter) if use_streamlit else scatter
299
302
 
300
303
 
@@ -322,11 +325,13 @@ def _prepare_feature_values_for_streamlit(
322
325
  return feature_values, shap_values, st
323
326
 
324
327
 
325
- def _create_scatter_plot(feature_values: pd.Series, shap_values: pd.Series, figsize: tuple[float, float]) -> alt.Chart:
328
+ def _create_scatter_plot(
329
+ feature_values: pd.Series, shap_values: pd.Series, infer_is_categorical: bool, figsize: tuple[float, float]
330
+ ) -> alt.Chart:
326
331
  unique_vals = np.sort(np.unique(feature_values.values))
327
332
  max_points_per_unique_value = float(np.max(np.bincount(np.searchsorted(unique_vals, feature_values.values))))
328
333
  points_per_value = len(feature_values.values) / len(unique_vals)
329
- is_categorical = float(max(max_points_per_unique_value, points_per_value)) > 10
334
+ is_categorical = float(max(max_points_per_unique_value, points_per_value)) > 10 if infer_is_categorical else False
330
335
 
331
336
  kwargs = (
332
337
  {
@@ -403,9 +408,11 @@ def plot_violin(
403
408
  .transform_density(density="shap_value", groupby=["feature_name"], as_=["shap_value", "density"])
404
409
  .mark_area(orient="vertical")
405
410
  .encode(
406
- y=alt.Y("density:Q", title=None).stack("center").impute(None).axis(labels=False, grid=False, ticks=True),
411
+ y=alt.Y("density:Q", title=None).stack("center").impute(None).axis(labels=False, grid=False, ticks=False),
407
412
  x=alt.X("shap_value:Q", title="SHAP Value"),
408
- row=alt.Row("feature_name:N", sort=column_sort_order).spacing(0),
413
+ row=alt.Row(
414
+ "feature_name:N", sort=column_sort_order, header=alt.Header(labelAngle=0, labelAlign="left")
415
+ ).spacing(0),
409
416
  color=alt.Color("feature_name:N", legend=None),
410
417
  tooltip=["feature_name", "shap_value"],
411
418
  )
@@ -1,5 +1,5 @@
1
1
  from types import ModuleType
2
- from typing import Any, Optional, Union
2
+ from typing import TYPE_CHECKING, Any, Optional, Union
3
3
 
4
4
  import pandas as pd
5
5
  from absl.logging import logging
@@ -8,7 +8,7 @@ from snowflake.ml._internal import env, platform_capabilities, telemetry
8
8
  from snowflake.ml._internal.exceptions import error_codes, exceptions
9
9
  from snowflake.ml._internal.human_readable_id import hrid_generator
10
10
  from snowflake.ml._internal.utils import sql_identifier
11
- from snowflake.ml.model import model_signature, type_hints as model_types
11
+ from snowflake.ml.model import model_signature, target_platform, task, type_hints
12
12
  from snowflake.ml.model._client.model import model_impl, model_version_impl
13
13
  from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
14
14
  from snowflake.ml.model._model_composer import model_composer
@@ -17,6 +17,9 @@ from snowflake.ml.model._packager.model_meta import model_meta
17
17
  from snowflake.snowpark import exceptions as snowpark_exceptions, session
18
18
  from snowflake.snowpark._internal import utils as snowpark_utils
19
19
 
20
+ if TYPE_CHECKING:
21
+ from snowflake.ml.experiment._experiment_info import ExperimentInfo
22
+
20
23
  logger = logging.getLogger(__name__)
21
24
 
22
25
 
@@ -41,7 +44,7 @@ class ModelManager:
41
44
  def log_model(
42
45
  self,
43
46
  *,
44
- model: Union[model_types.SupportedModelType, model_version_impl.ModelVersion],
47
+ model: Union[type_hints.SupportedModelType, model_version_impl.ModelVersion],
45
48
  model_name: str,
46
49
  version_name: Optional[str] = None,
47
50
  comment: Optional[str] = None,
@@ -50,16 +53,18 @@ class ModelManager:
50
53
  pip_requirements: Optional[list[str]] = None,
51
54
  artifact_repository_map: Optional[dict[str, str]] = None,
52
55
  resource_constraint: Optional[dict[str, str]] = None,
53
- target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
56
+ target_platforms: Optional[list[type_hints.SupportedTargetPlatformType]] = None,
54
57
  python_version: Optional[str] = None,
55
58
  signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
56
- sample_input_data: Optional[model_types.SupportedDataType] = None,
59
+ sample_input_data: Optional[type_hints.SupportedDataType] = None,
57
60
  user_files: Optional[dict[str, list[str]]] = None,
58
61
  code_paths: Optional[list[str]] = None,
59
62
  ext_modules: Optional[list[ModuleType]] = None,
60
- task: model_types.Task = model_types.Task.UNKNOWN,
61
- options: Optional[model_types.ModelSaveOption] = None,
63
+ task: type_hints.Task = task.Task.UNKNOWN,
64
+ experiment_info: Optional["ExperimentInfo"] = None,
65
+ options: Optional[type_hints.ModelSaveOption] = None,
62
66
  statement_params: Optional[dict[str, Any]] = None,
67
+ progress_status: Optional[Any] = None,
63
68
  ) -> model_version_impl.ModelVersion:
64
69
 
65
70
  database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
@@ -141,13 +146,15 @@ class ModelManager:
141
146
  code_paths=code_paths,
142
147
  ext_modules=ext_modules,
143
148
  task=task,
149
+ experiment_info=experiment_info,
144
150
  options=options,
145
151
  statement_params=statement_params,
152
+ progress_status=progress_status,
146
153
  )
147
154
 
148
155
  def _log_model(
149
156
  self,
150
- model: model_types.SupportedModelType,
157
+ model: type_hints.SupportedModelType,
151
158
  *,
152
159
  model_name: str,
153
160
  version_name: str,
@@ -157,16 +164,18 @@ class ModelManager:
157
164
  pip_requirements: Optional[list[str]] = None,
158
165
  artifact_repository_map: Optional[dict[str, str]] = None,
159
166
  resource_constraint: Optional[dict[str, str]] = None,
160
- target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
167
+ target_platforms: Optional[list[type_hints.SupportedTargetPlatformType]] = None,
161
168
  python_version: Optional[str] = None,
162
169
  signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
163
- sample_input_data: Optional[model_types.SupportedDataType] = None,
170
+ sample_input_data: Optional[type_hints.SupportedDataType] = None,
164
171
  user_files: Optional[dict[str, list[str]]] = None,
165
172
  code_paths: Optional[list[str]] = None,
166
173
  ext_modules: Optional[list[ModuleType]] = None,
167
- task: model_types.Task = model_types.Task.UNKNOWN,
168
- options: Optional[model_types.ModelSaveOption] = None,
174
+ task: type_hints.Task = task.Task.UNKNOWN,
175
+ experiment_info: Optional["ExperimentInfo"] = None,
176
+ options: Optional[type_hints.ModelSaveOption] = None,
169
177
  statement_params: Optional[dict[str, Any]] = None,
178
+ progress_status: Optional[Any] = None,
170
179
  ) -> model_version_impl.ModelVersion:
171
180
  database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
172
181
  version_name_id = sql_identifier.SqlIdentifier(version_name)
@@ -215,7 +224,7 @@ class ModelManager:
215
224
  # User specified target platforms are defaulted to None and will not show up in the generated manifest.
216
225
  if target_platforms:
217
226
  # Convert any string target platforms to TargetPlatform objects
218
- platforms = [model_types.TargetPlatform(platform) for platform in target_platforms]
227
+ platforms = [type_hints.TargetPlatform(platform) for platform in target_platforms]
219
228
  else:
220
229
  # Default the target platform to warehouse if not specified and any table function exists
221
230
  if options and (
@@ -231,7 +240,7 @@ class ModelManager:
231
240
  "Logging a partitioned model with a table function without specifying `target_platforms`. "
232
241
  'Default to `target_platforms=["WAREHOUSE"]`.'
233
242
  )
234
- platforms = [model_types.TargetPlatform.WAREHOUSE]
243
+ platforms = [target_platform.TargetPlatform.WAREHOUSE]
235
244
 
236
245
  # Default the target platform to SPCS if not specified when running in ML runtime
237
246
  if not platforms and env.IN_ML_RUNTIME:
@@ -239,7 +248,7 @@ class ModelManager:
239
248
  "Logging the model on Container Runtime for ML without specifying `target_platforms`. "
240
249
  'Default to `target_platforms=["SNOWPARK_CONTAINER_SERVICES"]`.'
241
250
  )
242
- platforms = [model_types.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
251
+ platforms = [target_platform.TargetPlatform.SNOWPARK_CONTAINER_SERVICES]
243
252
 
244
253
  if artifact_repository_map:
245
254
  for channel, artifact_repository_name in artifact_repository_map.items():
@@ -254,6 +263,9 @@ class ModelManager:
254
263
  )
255
264
 
256
265
  logger.info("Start packaging and uploading your model. It might take some time based on the size of the model.")
266
+ if progress_status:
267
+ progress_status.update("packaging model...")
268
+ progress_status.increment()
257
269
 
258
270
  # Extract save_location from options if present
259
271
  save_location = None
@@ -267,6 +279,11 @@ class ModelManager:
267
279
  statement_params=statement_params,
268
280
  save_location=save_location,
269
281
  )
282
+
283
+ if progress_status:
284
+ progress_status.update("creating model manifest...")
285
+ progress_status.increment()
286
+
270
287
  model_metadata: model_meta.ModelMetadata = mc.save(
271
288
  name=model_name_id.resolved(),
272
289
  model=model,
@@ -283,7 +300,12 @@ class ModelManager:
283
300
  ext_modules=ext_modules,
284
301
  options=options,
285
302
  task=task,
303
+ experiment_info=experiment_info,
286
304
  )
305
+
306
+ if progress_status:
307
+ progress_status.update("uploading model files...")
308
+ progress_status.increment()
287
309
  statement_params = telemetry.add_statement_params_custom_tags(
288
310
  statement_params, model_metadata.telemetry_metadata()
289
311
  )
@@ -292,6 +314,9 @@ class ModelManager:
292
314
  )
293
315
 
294
316
  logger.info("Start creating MODEL object for you in the Snowflake.")
317
+ if progress_status:
318
+ progress_status.update("creating model object in Snowflake...")
319
+ progress_status.increment()
295
320
 
296
321
  self._model_ops.create_from_stage(
297
322
  composed_model=mc,
@@ -318,6 +343,10 @@ class ModelManager:
318
343
  version_name=version_name_id,
319
344
  )
320
345
 
346
+ if progress_status:
347
+ progress_status.update("setting model metadata...")
348
+ progress_status.increment()
349
+
321
350
  if comment:
322
351
  mv.comment = comment
323
352
 
@@ -331,6 +360,9 @@ class ModelManager:
331
360
  statement_params=statement_params,
332
361
  )
333
362
 
363
+ if progress_status:
364
+ progress_status.update("model logged successfully!")
365
+
334
366
  return mv
335
367
 
336
368
  def get_model(
@@ -6,12 +6,15 @@ import pandas as pd
6
6
 
7
7
  from snowflake import snowpark
8
8
  from snowflake.ml._internal import telemetry
9
- from snowflake.ml._internal.utils import sql_identifier
9
+ from snowflake.ml._internal.utils import query_result_checker, sql_identifier
10
10
  from snowflake.ml.model import (
11
11
  Model,
12
12
  ModelVersion,
13
+ event_handler,
13
14
  model_signature,
14
- type_hints as model_types,
15
+ target_platform,
16
+ task,
17
+ type_hints,
15
18
  )
16
19
  from snowflake.ml.model._client.model import model_version_impl
17
20
  from snowflake.ml.monitoring import model_monitor
@@ -75,20 +78,30 @@ class Registry:
75
78
  else sql_identifier.SqlIdentifier("PUBLIC")
76
79
  )
77
80
 
78
- database_exists = session.sql(
79
- f"""SELECT 1 FROM INFORMATION_SCHEMA.DATABASES WHERE DATABASE_NAME = '{self._database_name.resolved()}';"""
80
- ).collect()
81
+ database_results = (
82
+ query_result_checker.SqlResultValidator(
83
+ session, f"""SHOW DATABASES LIKE '{self._database_name.resolved()}';"""
84
+ )
85
+ .has_column("name", allow_empty=True)
86
+ .validate()
87
+ )
81
88
 
82
- if not database_exists:
89
+ db_names = [row["name"] for row in database_results]
90
+ if not self._database_name.resolved() in db_names:
83
91
  raise ValueError(f"Database {self._database_name} does not exist.")
84
92
 
85
- schema_exists = session.sql(
86
- f"""
87
- SELECT 1 FROM {self._database_name.identifier()}.INFORMATION_SCHEMA.SCHEMATA
88
- WHERE SCHEMA_NAME = '{self._schema_name.resolved()}';"""
89
- ).collect()
93
+ schema_results = (
94
+ query_result_checker.SqlResultValidator(
95
+ session,
96
+ f"""SHOW SCHEMAS LIKE '{self._schema_name.resolved()}'
97
+ IN DATABASE {self._database_name.identifier()};""",
98
+ )
99
+ .has_column("name", allow_empty=True)
100
+ .validate()
101
+ )
90
102
 
91
- if not schema_exists:
103
+ schema_names = [row["name"] for row in schema_results]
104
+ if not self._schema_name.resolved() in schema_names:
92
105
  raise ValueError(f"Schema {self._schema_name} does not exist.")
93
106
 
94
107
  self._model_manager = model_manager.ModelManager(
@@ -119,7 +132,7 @@ class Registry:
119
132
  @overload
120
133
  def log_model(
121
134
  self,
122
- model: model_types.SupportedModelType,
135
+ model: type_hints.SupportedModelType,
123
136
  *,
124
137
  model_name: str,
125
138
  version_name: Optional[str] = None,
@@ -129,15 +142,15 @@ class Registry:
129
142
  pip_requirements: Optional[list[str]] = None,
130
143
  artifact_repository_map: Optional[dict[str, str]] = None,
131
144
  resource_constraint: Optional[dict[str, str]] = None,
132
- target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
145
+ target_platforms: Optional[list[Union[target_platform.TargetPlatform, str]]] = None,
133
146
  python_version: Optional[str] = None,
134
147
  signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
135
- sample_input_data: Optional[model_types.SupportedDataType] = None,
148
+ sample_input_data: Optional[type_hints.SupportedDataType] = None,
136
149
  user_files: Optional[dict[str, list[str]]] = None,
137
150
  code_paths: Optional[list[str]] = None,
138
151
  ext_modules: Optional[list[ModuleType]] = None,
139
- task: model_types.Task = model_types.Task.UNKNOWN,
140
- options: Optional[model_types.ModelSaveOption] = None,
152
+ task: task.Task = task.Task.UNKNOWN,
153
+ options: Optional[type_hints.ModelSaveOption] = None,
141
154
  ) -> ModelVersion:
142
155
  """
143
156
  Log a model with various parameters and metadata, or a ModelVersion object.
@@ -159,12 +172,12 @@ class Registry:
159
172
  to specify a dependency. It is a recommended way to specify your dependencies using conda. When channel
160
173
  is not specified, Snowflake Anaconda Channel will be used. Defaults to None.
161
174
  pip_requirements: List of Pip package specifications. Defaults to None.
162
- Models with pip requirements are currently only runnable in Snowpark Container Services.
163
- See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container for more.
164
- Models with pip requirements specified will not be executable in Snowflake Warehouse where all
165
- dependencies must be retrieved from Snowflake Anaconda Channel.
175
+ Models running in a Snowflake Warehouse must also specify a pip artifact repository (see
176
+ `artifact_repository_map`). Otherwise, models with pip requirements are runnable only in Snowpark
177
+ Container Services. See
178
+ https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container for more.
166
179
  artifact_repository_map: Specifies a mapping of package channels or platforms to custom artifact
167
- repositories. Defaults to None. Currently, the mapping applies only to warehouse execution.
180
+ repositories. Defaults to None. Currently, the mapping applies only to Warehouse execution.
168
181
  Note : This feature is currently in Public Preview.
169
182
  Format: {channel_name: artifact_repository_name}, where:
170
183
  - channel_name: Currently must be 'pip'.
@@ -172,10 +185,13 @@ class Registry:
172
185
  `snowflake.snowpark.pypi_shared_repository`.
173
186
  resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
174
187
  target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
175
- "WAREHOUSE" and "SNOWPARK_CONTAINER_SERVICES":
176
- - ["WAREHOUSE"] (Warehouse only)
177
- - ["SNOWPARK_CONTAINER_SERVICES"] (Snowpark Container Services only)
178
- - ["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"] (Both)
188
+ "WAREHOUSE" and "SNOWPARK_CONTAINER_SERVICES", or a target platform constant:
189
+ - ["WAREHOUSE"] or snowflake.ml.model.target_platform.WAREHOUSE_ONLY (Warehouse only)
190
+ - ["SNOWPARK_CONTAINER_SERVICES"] or
191
+ snowflake.ml.model.target_platform.SNOWPARK_CONTAINER_SERVICES_ONLY
192
+ (Snowpark Container Services only)
193
+ - ["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"] or
194
+ snowflake.ml.model.target_platform.BOTH_WAREHOUSE_AND_SNOWPARK_CONTAINER_SERVICES (Both)
179
195
  Defaults to None. When None, the target platforms will be both.
180
196
  python_version: Python version in which the model is run. Defaults to None.
181
197
  signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
@@ -206,7 +222,8 @@ class Registry:
206
222
  - target_methods: List of target methods to register when logging the model.
207
223
  This option is not used in MLFlow models. Defaults to None, in which case the model handler's
208
224
  default target methods will be used.
209
- - save_location: Location to save the model and metadata.
225
+ - save_location: Local directory to save the serialized model files first before
226
+ uploading to Snowflake. This is useful when the default tmp directory is not writable.
210
227
  - method_options: Per-method saving options. This dictionary has method names as keys and dictionary
211
228
  values with the desired options.
212
229
 
@@ -263,7 +280,7 @@ class Registry:
263
280
  )
264
281
  def log_model(
265
282
  self,
266
- model: Union[model_types.SupportedModelType, ModelVersion],
283
+ model: Union[type_hints.SupportedModelType, ModelVersion],
267
284
  *,
268
285
  model_name: str,
269
286
  version_name: Optional[str] = None,
@@ -273,15 +290,15 @@ class Registry:
273
290
  pip_requirements: Optional[list[str]] = None,
274
291
  artifact_repository_map: Optional[dict[str, str]] = None,
275
292
  resource_constraint: Optional[dict[str, str]] = None,
276
- target_platforms: Optional[list[model_types.SupportedTargetPlatformType]] = None,
293
+ target_platforms: Optional[list[Union[target_platform.TargetPlatform, str]]] = None,
277
294
  python_version: Optional[str] = None,
278
295
  signatures: Optional[dict[str, model_signature.ModelSignature]] = None,
279
- sample_input_data: Optional[model_types.SupportedDataType] = None,
296
+ sample_input_data: Optional[type_hints.SupportedDataType] = None,
280
297
  user_files: Optional[dict[str, list[str]]] = None,
281
298
  code_paths: Optional[list[str]] = None,
282
299
  ext_modules: Optional[list[ModuleType]] = None,
283
- task: model_types.Task = model_types.Task.UNKNOWN,
284
- options: Optional[model_types.ModelSaveOption] = None,
300
+ task: task.Task = task.Task.UNKNOWN,
301
+ options: Optional[type_hints.ModelSaveOption] = None,
285
302
  ) -> ModelVersion:
286
303
  """
287
304
  Log a model with various parameters and metadata, or a ModelVersion object.
@@ -303,12 +320,12 @@ class Registry:
303
320
  to specify a dependency. It is a recommended way to specify your dependencies using conda. When channel
304
321
  is not specified, Snowflake Anaconda Channel will be used. Defaults to None.
305
322
  pip_requirements: List of Pip package specifications. Defaults to None.
306
- Models with pip requirements are currently only runnable in Snowpark Container Services.
307
- See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container for more.
308
- Models with pip requirements specified will not be executable in Snowflake Warehouse where all
309
- dependencies must be retrieved from Snowflake Anaconda Channel.
323
+ Models running in a Snowflake Warehouse must also specify a pip artifact repository (see
324
+ `artifact_repository_map`). Otherwise, models with pip requirements are runnable only in Snowpark
325
+ Container Services. See
326
+ https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container for more.
310
327
  artifact_repository_map: Specifies a mapping of package channels or platforms to custom artifact
311
- repositories. Defaults to None. Currently, the mapping applies only to warehouse execution.
328
+ repositories. Defaults to None. Currently, the mapping applies only to Warehouse execution.
312
329
  Note : This feature is currently in Public Preview.
313
330
  Format: {channel_name: artifact_repository_name}, where:
314
331
  - channel_name: Currently must be 'pip'.
@@ -316,10 +333,13 @@ class Registry:
316
333
  `snowflake.snowpark.pypi_shared_repository`.
317
334
  resource_constraint: Mapping of resource constraint keys and values, e.g. {"architecture": "x86"}.
318
335
  target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
319
- "WAREHOUSE" and "SNOWPARK_CONTAINER_SERVICES":
320
- - ["WAREHOUSE"] (Warehouse only)
321
- - ["SNOWPARK_CONTAINER_SERVICES"] (Snowpark Container Services only)
322
- - ["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"] (Both)
336
+ "WAREHOUSE" and "SNOWPARK_CONTAINER_SERVICES", or a target platform constant:
337
+ - ["WAREHOUSE"] or snowflake.ml.model.target_platform.WAREHOUSE_ONLY (Warehouse only)
338
+ - ["SNOWPARK_CONTAINER_SERVICES"] or
339
+ snowflake.ml.model.target_platform.SNOWPARK_CONTAINER_SERVICES_ONLY
340
+ (Snowpark Container Services only)
341
+ - ["WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"] or
342
+ snowflake.ml.model.target_platform.BOTH_WAREHOUSE_AND_SNOWPARK_CONTAINER_SERVICES (Both)
323
343
  Defaults to None. When None, the target platforms will be both.
324
344
  python_version: Python version in which the model is run. Defaults to None.
325
345
  signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
@@ -366,6 +386,7 @@ class Registry:
366
386
 
367
387
  Raises:
368
388
  ValueError: If extra arguments are specified when a ModelVersion is provided.
389
+ Exception: If the model logging fails.
369
390
 
370
391
  Returns:
371
392
  ModelVersion: ModelVersion object corresponding to the model just logged.
@@ -418,10 +439,10 @@ class Registry:
418
439
  raise ValueError(
419
440
  "When calling log_model with a ModelVersion, only model_name and version_name may be specified."
420
441
  )
421
- if task is not model_types.Task.UNKNOWN:
442
+ if task is not type_hints.Task.UNKNOWN:
422
443
  raise ValueError("`task` cannot be specified when calling log_model with a ModelVersion.")
423
444
 
424
- if pip_requirements and not artifact_repository_map:
445
+ if pip_requirements and not artifact_repository_map and self._targets_warehouse(target_platforms):
425
446
  warnings.warn(
426
447
  "Models logged specifying `pip_requirements` cannot be executed in a Snowflake Warehouse "
427
448
  "without specifying `artifact_repository_map`. This model can be run in Snowpark Container "
@@ -429,27 +450,42 @@ class Registry:
429
450
  category=UserWarning,
430
451
  stacklevel=1,
431
452
  )
432
- return self._model_manager.log_model(
433
- model=model,
434
- model_name=model_name,
435
- version_name=version_name,
436
- comment=comment,
437
- metrics=metrics,
438
- conda_dependencies=conda_dependencies,
439
- pip_requirements=pip_requirements,
440
- artifact_repository_map=artifact_repository_map,
441
- resource_constraint=resource_constraint,
442
- target_platforms=target_platforms,
443
- python_version=python_version,
444
- signatures=signatures,
445
- sample_input_data=sample_input_data,
446
- user_files=user_files,
447
- code_paths=code_paths,
448
- ext_modules=ext_modules,
449
- task=task,
450
- options=options,
451
- statement_params=statement_params,
452
- )
453
+
454
+ registry_event_handler = event_handler.ModelEventHandler()
455
+ with registry_event_handler.status("Logging model", total=6) as status:
456
+ # Step 1: Validation and setup
457
+ status.update("validating model and dependencies...")
458
+ status.increment()
459
+
460
+ # Perform the actual model logging
461
+ try:
462
+ result = self._model_manager.log_model(
463
+ model=model,
464
+ model_name=model_name,
465
+ version_name=version_name,
466
+ comment=comment,
467
+ metrics=metrics,
468
+ conda_dependencies=conda_dependencies,
469
+ pip_requirements=pip_requirements,
470
+ artifact_repository_map=artifact_repository_map,
471
+ resource_constraint=resource_constraint,
472
+ target_platforms=target_platforms,
473
+ python_version=python_version,
474
+ signatures=signatures,
475
+ sample_input_data=sample_input_data,
476
+ user_files=user_files,
477
+ code_paths=code_paths,
478
+ ext_modules=ext_modules,
479
+ task=task,
480
+ options=options,
481
+ statement_params=statement_params,
482
+ progress_status=status,
483
+ )
484
+ status.update(label="Model logged successfully.", state="complete", expanded=False)
485
+ return result
486
+ except Exception as e:
487
+ status.update(label="Model logging failed.", state="error", expanded=False)
488
+ raise e
453
489
 
454
490
  @telemetry.send_api_usage_telemetry(
455
491
  project=_TELEMETRY_PROJECT,
@@ -626,3 +662,12 @@ class Registry:
626
662
  if not self.enable_monitoring:
627
663
  raise ValueError(_MODEL_MONITORING_DISABLED_ERROR)
628
664
  self._model_monitor_manager.delete_monitor(name)
665
+
666
+ @staticmethod
667
+ def _targets_warehouse(target_platforms: Optional[list[type_hints.SupportedTargetPlatformType]]) -> bool:
668
+ """Returns True if warehouse is a target platform (None defaults to True)."""
669
+ return (
670
+ target_platforms is None
671
+ or type_hints.TargetPlatform.WAREHOUSE in target_platforms
672
+ or "WAREHOUSE" in target_platforms
673
+ )
snowflake/ml/version.py CHANGED
@@ -1,2 +1,2 @@
1
1
  # This is parsed by regex in conda recipe meta file. Make sure not to break it.
2
- VERSION = "1.8.6"
2
+ VERSION = "1.9.1"