snowflake-ml-python 1.7.1__py3-none-any.whl → 1.7.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. snowflake/cortex/__init__.py +16 -8
  2. snowflake/cortex/_classify_text.py +12 -1
  3. snowflake/cortex/_complete.py +82 -13
  4. snowflake/cortex/_embed_text_1024.py +9 -2
  5. snowflake/cortex/_embed_text_768.py +9 -2
  6. snowflake/cortex/_extract_answer.py +9 -2
  7. snowflake/cortex/_sentiment.py +9 -2
  8. snowflake/cortex/_summarize.py +9 -2
  9. snowflake/cortex/_translate.py +9 -2
  10. snowflake/ml/_internal/env_utils.py +7 -52
  11. snowflake/ml/_internal/utils/identifier.py +4 -2
  12. snowflake/ml/_internal/utils/jwt_generator.py +141 -0
  13. snowflake/ml/data/__init__.py +3 -0
  14. snowflake/ml/data/_internal/arrow_ingestor.py +4 -4
  15. snowflake/ml/data/data_connector.py +53 -11
  16. snowflake/ml/data/data_ingestor.py +2 -1
  17. snowflake/ml/data/torch_utils.py +18 -5
  18. snowflake/ml/feature_store/examples/example_helper.py +2 -1
  19. snowflake/ml/fileset/fileset.py +18 -18
  20. snowflake/ml/model/_client/model/model_version_impl.py +24 -8
  21. snowflake/ml/model/_client/ops/model_ops.py +2 -6
  22. snowflake/ml/model/_client/ops/service_ops.py +12 -7
  23. snowflake/ml/model/_client/sql/model_version.py +11 -0
  24. snowflake/ml/model/_client/sql/stage.py +1 -1
  25. snowflake/ml/model/_model_composer/model_composer.py +8 -3
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -1
  27. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
  28. snowflake/ml/model/_model_composer/model_method/constants.py +1 -0
  29. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -0
  30. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +1 -1
  31. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +1 -1
  32. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
  33. snowflake/ml/model/_model_composer/model_method/model_method.py +9 -1
  34. snowflake/ml/model/_model_composer/model_user_file/model_user_file.py +27 -0
  35. snowflake/ml/model/_packager/model_handlers/_utils.py +27 -2
  36. snowflake/ml/model/_packager/model_handlers/catboost.py +3 -3
  37. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +5 -1
  38. snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -3
  39. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +55 -20
  40. snowflake/ml/model/_packager/model_handlers/sklearn.py +10 -9
  41. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +66 -28
  42. snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -17
  43. snowflake/ml/model/_packager/model_handlers/xgboost.py +3 -3
  44. snowflake/ml/model/_packager/model_meta/model_meta.py +3 -0
  45. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
  46. snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
  47. snowflake/ml/model/_packager/model_task/model_task_utils.py +3 -2
  48. snowflake/ml/model/_signatures/pandas_handler.py +1 -1
  49. snowflake/ml/model/_signatures/snowpark_handler.py +8 -2
  50. snowflake/ml/model/_signatures/utils.py +0 -1
  51. snowflake/ml/model/type_hints.py +1 -0
  52. snowflake/ml/modeling/_internal/model_trainer_builder.py +0 -8
  53. snowflake/ml/modeling/_internal/model_transformer_builder.py +0 -13
  54. snowflake/ml/modeling/pipeline/pipeline.py +6 -176
  55. snowflake/ml/modeling/xgboost/xgb_classifier.py +161 -88
  56. snowflake/ml/modeling/xgboost/xgb_regressor.py +160 -85
  57. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +160 -85
  58. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +160 -85
  59. snowflake/ml/monitoring/_client/model_monitor_sql_client.py +5 -170
  60. snowflake/ml/monitoring/_manager/model_monitor_manager.py +9 -9
  61. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -2
  62. snowflake/ml/monitoring/model_monitor.py +26 -11
  63. snowflake/ml/registry/_manager/model_manager.py +70 -33
  64. snowflake/ml/registry/registry.py +53 -34
  65. snowflake/ml/utils/authentication.py +75 -0
  66. snowflake/ml/version.py +1 -1
  67. {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/METADATA +120 -53
  68. {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/RECORD +71 -74
  69. {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/WHEEL +1 -1
  70. snowflake/ml/_internal/utils/retryable_http.py +0 -39
  71. snowflake/ml/fileset/parquet_parser.py +0 -170
  72. snowflake/ml/fileset/tf_dataset.py +0 -88
  73. snowflake/ml/fileset/torch_datapipe.py +0 -57
  74. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +0 -151
  75. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py +0 -66
  76. snowflake/ml/monitoring/entities/output_score_type.py +0 -90
  77. {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/LICENSE.txt +0 -0
  78. {snowflake_ml_python-1.7.1.dist-info → snowflake_ml_python-1.7.3.dist-info}/top_level.txt +0 -0
@@ -6,23 +6,49 @@ from snowflake.ml.model._client.model import model_version_impl
6
6
 
7
7
  @dataclass
8
8
  class ModelMonitorSourceConfig:
9
+ """Configuration for the source of data to be monitored."""
10
+
9
11
  source: str
12
+ """Name of table or view containing monitoring data."""
13
+
10
14
  timestamp_column: str
15
+ """Name of column in the source containing timestamp."""
16
+
11
17
  id_columns: List[str]
18
+ """List of columns in the source containing unique identifiers."""
19
+
12
20
  prediction_score_columns: Optional[List[str]] = None
21
+ """List of columns in the source containing prediction scores.
22
+ Can be regression scores for regression models and probability scores for classification models."""
23
+
13
24
  prediction_class_columns: Optional[List[str]] = None
25
+ """List of columns in the source containing prediction classes for classification models."""
26
+
14
27
  actual_score_columns: Optional[List[str]] = None
28
+ """List of columns in the source containing actual scores."""
29
+
15
30
  actual_class_columns: Optional[List[str]] = None
31
+ """List of columns in the source containing actual classes for classification models."""
32
+
16
33
  baseline: Optional[str] = None
34
+ """Name of table containing the baseline data."""
17
35
 
18
36
 
19
37
  @dataclass
20
38
  class ModelMonitorConfig:
39
+ """Configuration for the Model Monitor."""
40
+
21
41
  model_version: model_version_impl.ModelVersion
42
+ """Model version to monitor."""
22
43
 
23
- # Python model function name
24
44
  model_function_name: str
45
+ """Function name in the model to monitor."""
46
+
25
47
  background_compute_warehouse_name: str
26
- # TODO: Add support for pythonic notion of time.
48
+ """Name of the warehouse to use for background compute."""
49
+
27
50
  refresh_interval: str = "1 hour"
51
+ """Interval at which to refresh the monitoring data."""
52
+
28
53
  aggregation_window: str = "1 day"
54
+ """Window for aggregating monitoring data."""
@@ -1,5 +1,7 @@
1
+ from snowflake import snowpark
1
2
  from snowflake.ml._internal import telemetry
2
3
  from snowflake.ml._internal.utils import sql_identifier
4
+ from snowflake.ml.monitoring import model_monitor_version
3
5
  from snowflake.ml.monitoring._client import model_monitor_sql_client
4
6
 
5
7
 
@@ -9,13 +11,8 @@ class ModelMonitor:
9
11
  name: sql_identifier.SqlIdentifier
10
12
  _model_monitor_client: model_monitor_sql_client.ModelMonitorSQLClient
11
13
 
12
- statement_params = telemetry.get_statement_params(
13
- telemetry.TelemetryProject.MLOPS.value,
14
- telemetry.TelemetrySubProject.MONITORING.value,
15
- )
16
-
17
14
  def __init__(self) -> None:
18
- raise RuntimeError("ModelMonitor's initializer is not meant to be used.")
15
+ raise RuntimeError("Model Monitor's initializer is not meant to be used.")
19
16
 
20
17
  @classmethod
21
18
  def _ref(
@@ -28,10 +25,28 @@ class ModelMonitor:
28
25
  self._model_monitor_client = model_monitor_client
29
26
  return self
30
27
 
28
+ @telemetry.send_api_usage_telemetry(
29
+ project=telemetry.TelemetryProject.MLOPS.value,
30
+ subproject=telemetry.TelemetrySubProject.MONITORING.value,
31
+ )
32
+ @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
31
33
  def suspend(self) -> None:
32
- """Suspend pipeline for ModelMonitor"""
33
- self._model_monitor_client.suspend_monitor(self.name, statement_params=self.statement_params)
34
-
34
+ """Suspend the Model Monitor"""
35
+ statement_params = telemetry.get_statement_params(
36
+ telemetry.TelemetryProject.MLOPS.value,
37
+ telemetry.TelemetrySubProject.MONITORING.value,
38
+ )
39
+ self._model_monitor_client.suspend_monitor(self.name, statement_params=statement_params)
40
+
41
+ @telemetry.send_api_usage_telemetry(
42
+ project=telemetry.TelemetryProject.MLOPS.value,
43
+ subproject=telemetry.TelemetrySubProject.MONITORING.value,
44
+ )
45
+ @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
35
46
  def resume(self) -> None:
36
- """Resume pipeline for ModelMonitor"""
37
- self._model_monitor_client.resume_monitor(self.name, statement_params=self.statement_params)
47
+ """Resume the Model Monitor"""
48
+ statement_params = telemetry.get_statement_params(
49
+ telemetry.TelemetryProject.MLOPS.value,
50
+ telemetry.TelemetrySubProject.MONITORING.value,
51
+ )
52
+ self._model_monitor_client.resume_monitor(self.name, statement_params=statement_params)
@@ -1,13 +1,13 @@
1
1
  from types import ModuleType
2
- from typing import Any, Dict, List, Optional, Union
2
+ from typing import Any, Dict, List, Optional, Tuple, Union
3
3
 
4
4
  import pandas as pd
5
5
  from absl.logging import logging
6
- from packaging import version
7
6
 
8
7
  from snowflake.ml._internal import telemetry
8
+ from snowflake.ml._internal.exceptions import error_codes, exceptions
9
9
  from snowflake.ml._internal.human_readable_id import hrid_generator
10
- from snowflake.ml._internal.utils import snowflake_env, sql_identifier
10
+ from snowflake.ml._internal.utils import sql_identifier
11
11
  from snowflake.ml.model import model_signature, type_hints as model_types
12
12
  from snowflake.ml.model._client.model import model_impl, model_version_impl
13
13
  from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
@@ -50,14 +50,40 @@ class ModelManager:
50
50
  python_version: Optional[str] = None,
51
51
  signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
52
52
  sample_input_data: Optional[model_types.SupportedDataType] = None,
53
+ user_files: Optional[Dict[str, List[str]]] = None,
53
54
  code_paths: Optional[List[str]] = None,
54
55
  ext_modules: Optional[List[ModuleType]] = None,
55
56
  task: model_types.Task = model_types.Task.UNKNOWN,
56
57
  options: Optional[model_types.ModelSaveOption] = None,
57
58
  statement_params: Optional[Dict[str, Any]] = None,
58
59
  ) -> model_version_impl.ModelVersion:
59
- if not version_name:
60
- version_name = self._hrid_generator.generate()[1]
60
+
61
+ database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
62
+
63
+ model_exists = self._model_ops.validate_existence(
64
+ database_name=database_name_id,
65
+ schema_name=schema_name_id,
66
+ model_name=model_name_id,
67
+ statement_params=statement_params,
68
+ )
69
+
70
+ if version_name is None:
71
+ if model_exists:
72
+ versions = self._model_ops.list_models_or_versions(
73
+ database_name=database_name_id,
74
+ schema_name=schema_name_id,
75
+ model_name=model_name_id,
76
+ statement_params=statement_params,
77
+ )
78
+ for _ in range(1000):
79
+ hrid = self._hrid_generator.generate()[1]
80
+ if sql_identifier.SqlIdentifier(hrid) not in versions:
81
+ version_name = hrid
82
+ break
83
+ if version_name is None:
84
+ raise RuntimeError("Random version name generation failed.")
85
+ else:
86
+ version_name = self._hrid_generator.generate()[1]
61
87
 
62
88
  if isinstance(model, model_version_impl.ModelVersion):
63
89
  (
@@ -75,10 +101,24 @@ class ModelManager:
75
101
  schema_name=None,
76
102
  model_name=sql_identifier.SqlIdentifier(model_name),
77
103
  version_name=sql_identifier.SqlIdentifier(version_name),
104
+ model_exists=model_exists,
78
105
  statement_params=statement_params,
79
106
  )
80
107
  return self.get_model(model_name=model_name, statement_params=statement_params).version(version_name)
81
108
 
109
+ version_name_id = sql_identifier.SqlIdentifier(version_name)
110
+ if model_exists and self._model_ops.validate_existence(
111
+ database_name=database_name_id,
112
+ schema_name=schema_name_id,
113
+ model_name=model_name_id,
114
+ version_name=version_name_id,
115
+ statement_params=statement_params,
116
+ ):
117
+ raise ValueError(
118
+ f"Model {model_name} version {version_name} already existed. "
119
+ + "To auto-generate `version_name`, skip that argument."
120
+ )
121
+
82
122
  return self._log_model(
83
123
  model=model,
84
124
  model_name=model_name,
@@ -91,6 +131,7 @@ class ModelManager:
91
131
  python_version=python_version,
92
132
  signatures=signatures,
93
133
  sample_input_data=sample_input_data,
134
+ user_files=user_files,
94
135
  code_paths=code_paths,
95
136
  ext_modules=ext_modules,
96
137
  task=task,
@@ -103,7 +144,7 @@ class ModelManager:
103
144
  model: model_types.SupportedModelType,
104
145
  *,
105
146
  model_name: str,
106
- version_name: Optional[str] = None,
147
+ version_name: str,
107
148
  comment: Optional[str] = None,
108
149
  metrics: Optional[Dict[str, Any]] = None,
109
150
  conda_dependencies: Optional[List[str]] = None,
@@ -112,6 +153,7 @@ class ModelManager:
112
153
  python_version: Optional[str] = None,
113
154
  signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
114
155
  sample_input_data: Optional[model_types.SupportedDataType] = None,
156
+ user_files: Optional[Dict[str, List[str]]] = None,
115
157
  code_paths: Optional[List[str]] = None,
116
158
  ext_modules: Optional[List[ModuleType]] = None,
117
159
  task: model_types.Task = model_types.Task.UNKNOWN,
@@ -119,28 +161,8 @@ class ModelManager:
119
161
  statement_params: Optional[Dict[str, Any]] = None,
120
162
  ) -> model_version_impl.ModelVersion:
121
163
  database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
122
-
123
- if not version_name:
124
- version_name = self._hrid_generator.generate()[1]
125
164
  version_name_id = sql_identifier.SqlIdentifier(version_name)
126
165
 
127
- if self._model_ops.validate_existence(
128
- database_name=database_name_id,
129
- schema_name=schema_name_id,
130
- model_name=model_name_id,
131
- statement_params=statement_params,
132
- ) and self._model_ops.validate_existence(
133
- database_name=database_name_id,
134
- schema_name=schema_name_id,
135
- model_name=model_name_id,
136
- version_name=version_name_id,
137
- statement_params=statement_params,
138
- ):
139
- raise ValueError(
140
- f"Model {model_name} version {version_name} already existed. "
141
- + "To auto-generate `version_name`, skip that argument."
142
- )
143
-
144
166
  stage_path = self._model_ops.prepare_model_stage_path(
145
167
  database_name=database_name_id,
146
168
  schema_name=schema_name_id,
@@ -148,13 +170,10 @@ class ModelManager:
148
170
  )
149
171
 
150
172
  platforms = None
151
- # TODO(jbahk): Remove the version check after Snowflake 8.40.0 release
152
173
  # User specified target platforms are defaulted to None and will not show up in the generated manifest.
153
- # In the backend, we attempt to create a model for all platforms (WH, SPCS) regardless by default.
154
- if snowflake_env.get_current_snowflake_version(self._model_ops._session) >= version.parse("8.40.0"):
174
+ if target_platforms:
155
175
  # Convert any string target platforms to TargetPlatform objects
156
- if target_platforms:
157
- platforms = [model_types.TargetPlatform(platform) for platform in target_platforms]
176
+ platforms = [model_types.TargetPlatform(platform) for platform in target_platforms]
158
177
 
159
178
  logger.info("Start packaging and uploading your model. It might take some time based on the size of the model.")
160
179
 
@@ -170,6 +189,7 @@ class ModelManager:
170
189
  pip_requirements=pip_requirements,
171
190
  target_platforms=platforms,
172
191
  python_version=python_version,
192
+ user_files=user_files,
173
193
  code_paths=code_paths,
174
194
  ext_modules=ext_modules,
175
195
  options=options,
@@ -229,7 +249,7 @@ class ModelManager:
229
249
  *,
230
250
  statement_params: Optional[Dict[str, Any]] = None,
231
251
  ) -> model_impl.Model:
232
- database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
252
+ database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
233
253
  if self._model_ops.validate_existence(
234
254
  database_name=database_name_id,
235
255
  schema_name=schema_name_id,
@@ -289,7 +309,7 @@ class ModelManager:
289
309
  *,
290
310
  statement_params: Optional[Dict[str, Any]] = None,
291
311
  ) -> None:
292
- database_name_id, schema_name_id, model_name_id = sql_identifier.parse_fully_qualified_name(model_name)
312
+ database_name_id, schema_name_id, model_name_id = self._parse_fully_qualified_name(model_name)
293
313
 
294
314
  self._model_ops.delete_model_or_version(
295
315
  database_name=database_name_id,
@@ -297,3 +317,20 @@ class ModelManager:
297
317
  model_name=model_name_id,
298
318
  statement_params=statement_params,
299
319
  )
320
+
321
+ def _parse_fully_qualified_name(
322
+ self, model_name: str
323
+ ) -> Tuple[
324
+ Optional[sql_identifier.SqlIdentifier], Optional[sql_identifier.SqlIdentifier], sql_identifier.SqlIdentifier
325
+ ]:
326
+ try:
327
+ return sql_identifier.parse_fully_qualified_name(model_name)
328
+ except ValueError:
329
+ raise exceptions.SnowflakeMLException(
330
+ error_code=error_codes.INVALID_ARGUMENT,
331
+ original_exception=ValueError(
332
+ f"The model_name `{model_name}` cannot be parsed as a SQL identifier. Alphanumeric characters and "
333
+ "underscores are permitted. See https://docs.snowflake.com/en/sql-reference/identifiers-syntax for "
334
+ "more information."
335
+ ),
336
+ )
@@ -117,41 +117,49 @@ class Registry:
117
117
  options: Optional[model_types.ModelSaveOption] = None,
118
118
  ) -> ModelVersion:
119
119
  """
120
- Log a model with various parameters and metadata.
120
+ Log a model with various parameters and metadata, or a ModelVersion object.
121
121
 
122
122
  Args:
123
- model: Model object of supported types such as Scikit-learn, XGBoost, LightGBM, Snowpark ML,
124
- PyTorch, TorchScript, Tensorflow, Tensorflow Keras, MLFlow, HuggingFace Pipeline,
125
- Sentence Transformers, or Custom Model.
126
- model_name: Name to identify the model.
123
+ model: Supported model or ModelVersion object.
124
+ - Supported model: Model object of supported types such as Scikit-learn, XGBoost, LightGBM, Snowpark ML,
125
+ PyTorch, TorchScript, Tensorflow, Tensorflow Keras, MLFlow, HuggingFace Pipeline, Sentence Transformers,
126
+ or Custom Model.
127
+ - ModelVersion: Source ModelVersion object used to create the new ModelVersion object.
128
+ model_name: Name to identify the model. This must be a valid Snowflake SQL Identifier. Alphanumeric
129
+ characters and underscores are permitted.
130
+ See https://docs.snowflake.com/en/sql-reference/identifiers-syntax for more.
127
131
  version_name: Version identifier for the model. Combination of model_name and version_name must be unique.
128
132
  If not specified, a random name will be generated.
129
133
  comment: Comment associated with the model version. Defaults to None.
130
134
  metrics: A JSON serializable dictionary containing metrics linked to the model version. Defaults to None.
131
- signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
132
- sample_input_data would be used to infer the signatures for those models that cannot automatically
133
- infer the signature. Defaults to None.
134
- sample_input_data: Sample input data to infer model signatures from.
135
- It would also be used as background data in explanation and to capture data lineage. Defaults to None.
136
135
  conda_dependencies: List of Conda package specifications. Use "[channel::]package [operator version]" syntax
137
136
  to specify a dependency. It is a recommended way to specify your dependencies using conda. When channel
138
137
  is not specified, Snowflake Anaconda Channel will be used. Defaults to None.
139
138
  pip_requirements: List of Pip package specifications. Defaults to None.
140
- Currently it is not supported since Model can only executed in Snowflake Warehouse where all
141
- dependencies are required to be retrieved from Snowflake Anaconda Channel.
139
+ Models with pip requirements are currently only runnable in Snowpark Container Services.
140
+ See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container for more.
141
+ Models with pip requirements specified will not be executable in Snowflake Warehouse where all
142
+ dependencies must be retrieved from Snowflake Anaconda Channel.
142
143
  target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
143
144
  {"WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"}. Defaults to None.
144
145
  python_version: Python version in which the model is run. Defaults to None.
146
+ signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
147
+ sample_input_data would be used to infer the signatures for those models that cannot automatically
148
+ infer the signature. If not None, sample_input_data should not be specified. Defaults to None.
149
+ sample_input_data: Sample input data to infer model signatures from.
150
+ It would also be used as background data in explanation and to capture data lineage. Defaults to None.
145
151
  code_paths: List of directories containing code to import. Defaults to None.
146
152
  ext_modules: List of external modules to pickle with the model object.
147
153
  Only supported when logging the following types of model:
148
154
  Scikit-learn, Snowpark ML, PyTorch, TorchScript and Custom Model. Defaults to None.
149
155
  options (Dict[str, Any], optional): Additional model saving options.
156
+
150
157
  Model Saving Options include:
158
+
151
159
  - embed_local_ml_library: Embed local Snowpark ML into the code directory or folder.
152
160
  Override to True if the local Snowpark ML version is not available in the Snowflake Anaconda
153
161
  Channel. Otherwise, defaults to False
154
- - relax_version: Whether or not relax the version constraints of the dependencies when running in the
162
+ - relax_version: Whether to relax the version constraints of the dependencies when running in the
155
163
  Warehouse. It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to True.
156
164
  - function_type: Set the method function type globally. To set method function types individually see
157
165
  function_type in model_options.
@@ -163,7 +171,10 @@ class Registry:
163
171
  - max_batch_size: Maximum batch size that the method could accept in the Snowflake Warehouse.
164
172
  Defaults to None, determined automatically by Snowflake.
165
173
  - function_type: One of supported model method function types (FUNCTION or TABLE_FUNCTION).
174
+ Returns:
175
+ ModelVersion: ModelVersion object corresponding to the model just logged.
166
176
  """
177
+
167
178
  ...
168
179
 
169
180
  @overload
@@ -214,6 +225,7 @@ class Registry:
214
225
  python_version: Optional[str] = None,
215
226
  signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
216
227
  sample_input_data: Optional[model_types.SupportedDataType] = None,
228
+ user_files: Optional[Dict[str, List[str]]] = None,
217
229
  code_paths: Optional[List[str]] = None,
218
230
  ext_modules: Optional[List[ModuleType]] = None,
219
231
  task: model_types.Task = model_types.Task.UNKNOWN,
@@ -228,25 +240,31 @@ class Registry:
228
240
  PyTorch, TorchScript, Tensorflow, Tensorflow Keras, MLFlow, HuggingFace Pipeline, Sentence Transformers,
229
241
  or Custom Model.
230
242
  - ModelVersion: Source ModelVersion object used to create the new ModelVersion object.
231
- model_name: Name to identify the model.
243
+ model_name: Name to identify the model. This must be a valid Snowflake SQL Identifier. Alphanumeric
244
+ characters and underscores are permitted.
245
+ See https://docs.snowflake.com/en/sql-reference/identifiers-syntax for more.
232
246
  version_name: Version identifier for the model. Combination of model_name and version_name must be unique.
233
247
  If not specified, a random name will be generated.
234
248
  comment: Comment associated with the model version. Defaults to None.
235
249
  metrics: A JSON serializable dictionary containing metrics linked to the model version. Defaults to None.
236
- signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
237
- sample_input_data would be used to infer the signatures for those models that cannot automatically
238
- infer the signature. If not None, sample_input_data should not be specified. Defaults to None.
239
- sample_input_data: Sample input data to infer model signatures from.
240
- It would also be used as background data in explanation and to capture data lineage. Defaults to None.
241
250
  conda_dependencies: List of Conda package specifications. Use "[channel::]package [operator version]" syntax
242
251
  to specify a dependency. It is a recommended way to specify your dependencies using conda. When channel
243
252
  is not specified, Snowflake Anaconda Channel will be used. Defaults to None.
244
253
  pip_requirements: List of Pip package specifications. Defaults to None.
245
- Currently it is not supported since Model can only executed in Snowflake Warehouse where all
246
- dependencies are required to be retrieved from Snowflake Anaconda Channel.
254
+ Models with pip requirements are currently only runnable in Snowpark Container Services.
255
+ See https://docs.snowflake.com/en/developer-guide/snowflake-ml/model-registry/container for more.
256
+ Models with pip requirements specified will not be executable in Snowflake Warehouse where all
257
+ dependencies must be retrieved from Snowflake Anaconda Channel.
247
258
  target_platforms: List of target platforms to run the model. The only acceptable inputs are a combination of
248
259
  {"WAREHOUSE", "SNOWPARK_CONTAINER_SERVICES"}. Defaults to None.
249
260
  python_version: Python version in which the model is run. Defaults to None.
261
+ signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
262
+ sample_input_data would be used to infer the signatures for those models that cannot automatically
263
+ infer the signature. If not None, sample_input_data should not be specified. Defaults to None.
264
+ sample_input_data: Sample input data to infer model signatures from.
265
+ It would also be used as background data in explanation and to capture data lineage. Defaults to None.
266
+ user_files: Dictionary where the keys are subdirectories, and values are lists of local file name
267
+ strings. The local file name strings can include wildcards (? or *) for matching multiple files.
250
268
  code_paths: List of directories containing code to import. Defaults to None.
251
269
  ext_modules: List of external modules to pickle with the model object.
252
270
  Only supported when logging the following types of model:
@@ -261,7 +279,7 @@ class Registry:
261
279
  - embed_local_ml_library: Embed local Snowpark ML into the code directory or folder.
262
280
  Override to True if the local Snowpark ML version is not available in the Snowflake Anaconda
263
281
  Channel. Otherwise, defaults to False
264
- - relax_version: Whether or not relax the version constraints of the dependencies when running in the
282
+ - relax_version: Whether to relax the version constraints of the dependencies when running in the
265
283
  Warehouse. It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to True.
266
284
  - function_type: Set the method function type globally. To set method function types individually see
267
285
  function_type in model_options.
@@ -301,6 +319,7 @@ class Registry:
301
319
  python_version=python_version,
302
320
  signatures=signatures,
303
321
  sample_input_data=sample_input_data,
322
+ user_files=user_files,
304
323
  code_paths=code_paths,
305
324
  ext_modules=ext_modules,
306
325
  task=task,
@@ -388,15 +407,15 @@ class Registry:
388
407
  source_config: model_monitor_config.ModelMonitorSourceConfig,
389
408
  model_monitor_config: model_monitor_config.ModelMonitorConfig,
390
409
  ) -> model_monitor.ModelMonitor:
391
- """Add a Model Monitor to the Registry
410
+ """Add a Model Monitor to the Registry.
392
411
 
393
412
  Args:
394
- name: Name of Model Monitor to create
395
- source_config: Configuration options of table for ModelMonitor.
396
- model_monitor_config: Configuration options of ModelMonitor.
413
+ name: Name of Model Monitor to create.
414
+ source_config: Configuration options of table for Model Monitor.
415
+ model_monitor_config: Configuration options of Model Monitor.
397
416
 
398
417
  Returns:
399
- The newly added ModelMonitor object.
418
+ The newly added Model Monitor object.
400
419
 
401
420
  Raises:
402
421
  ValueError: If monitoring is not enabled in the Registry.
@@ -407,16 +426,16 @@ class Registry:
407
426
 
408
427
  @overload
409
428
  def get_monitor(self, model_version: model_version_impl.ModelVersion) -> model_monitor.ModelMonitor:
410
- """Get a Model Monitor on a ModelVersion from the Registry
429
+ """Get a Model Monitor on a Model Version from the Registry.
411
430
 
412
431
  Args:
413
- model_version: ModelVersion for which to retrieve the ModelMonitor.
432
+ model_version: Model Version for which to retrieve the Model Monitor.
414
433
  """
415
434
  ...
416
435
 
417
436
  @overload
418
437
  def get_monitor(self, name: str) -> model_monitor.ModelMonitor:
419
- """Get a Model Monitor from the Registry
438
+ """Get a Model Monitor by name from the Registry.
420
439
 
421
440
  Args:
422
441
  name: Name of Model Monitor to retrieve.
@@ -431,14 +450,14 @@ class Registry:
431
450
  def get_monitor(
432
451
  self, *, name: Optional[str] = None, model_version: Optional[model_version_impl.ModelVersion] = None
433
452
  ) -> model_monitor.ModelMonitor:
434
- """Get a Model Monitor from the Registry
453
+ """Get a Model Monitor from the Registry.
435
454
 
436
455
  Args:
437
456
  name: Name of Model Monitor to retrieve.
438
- model_version: ModelVersion for which to retrieve the ModelMonitor.
457
+ model_version: Model Version for which to retrieve the Model Monitor.
439
458
 
440
459
  Returns:
441
- The fetched ModelMonitor.
460
+ The fetched Model Monitor.
442
461
 
443
462
  Raises:
444
463
  ValueError: If monitoring is not enabled in the Registry.
@@ -476,7 +495,7 @@ class Registry:
476
495
  )
477
496
  @snowpark._internal.utils.private_preview(version=model_monitor_version.SNOWFLAKE_ML_MONITORING_MIN_VERSION)
478
497
  def delete_monitor(self, name: str) -> None:
479
- """Delete a Model Monitor from the Registry
498
+ """Delete a Model Monitor by name from the Registry.
480
499
 
481
500
  Args:
482
501
  name: Name of the Model Monitor to delete.
@@ -0,0 +1,75 @@
1
+ import http
2
+ import logging
3
+ from datetime import timedelta
4
+ from typing import Dict, Optional
5
+
6
+ import requests
7
+ from cryptography.hazmat.primitives.asymmetric import types
8
+ from requests import auth
9
+
10
+ from snowflake.ml._internal.utils import jwt_generator
11
+
12
+ logger = logging.getLogger(__name__)
13
+ _JWT_TOKEN_CACHE: Dict[str, Dict[int, str]] = {}
14
+
15
+
16
+ def get_jwt_token_generator(
17
+ account: str,
18
+ user: str,
19
+ private_key: types.PRIVATE_KEY_TYPES,
20
+ lifetime: Optional[timedelta] = None,
21
+ renewal_delay: Optional[timedelta] = None,
22
+ ) -> jwt_generator.JWTGenerator:
23
+ return jwt_generator.JWTGenerator(account, user, private_key, lifetime=lifetime, renewal_delay=renewal_delay)
24
+
25
+
26
+ def _get_snowflake_token_by_jwt(
27
+ jwt_token_generator: jwt_generator.JWTGenerator,
28
+ account: Optional[str] = None,
29
+ role: Optional[str] = None,
30
+ endpoint: Optional[str] = None,
31
+ snowflake_account_url: Optional[str] = None,
32
+ ) -> str:
33
+ scope_role = f"session:role:{role}" if role is not None else None
34
+ scope = " ".join(filter(None, [scope_role, endpoint]))
35
+ data = {
36
+ "grant_type": "urn:ietf:params:oauth:grant-type:jwt-bearer",
37
+ "scope": scope or None,
38
+ "assertion": jwt_token_generator.get_token(),
39
+ }
40
+ account = account or jwt_token_generator.account
41
+ url = f"https://{account}.snowflakecomputing.com/oauth/token"
42
+ if snowflake_account_url:
43
+ url = f"{snowflake_account_url}/oauth/token"
44
+
45
+ cache_key = hash(frozenset(data.items()))
46
+ if url in _JWT_TOKEN_CACHE:
47
+ if cache_key in _JWT_TOKEN_CACHE[url]:
48
+ return _JWT_TOKEN_CACHE[url][cache_key]
49
+ else:
50
+ _JWT_TOKEN_CACHE[url] = {}
51
+
52
+ response = requests.post(url, data=data)
53
+ if response.status_code != http.HTTPStatus.OK:
54
+ raise RuntimeError(f"Failed to get snowflake token: {response.status_code} {response.content!r}")
55
+ auth_token = response.text
56
+ _JWT_TOKEN_CACHE[url][cache_key] = auth_token
57
+ return auth_token
58
+
59
+
60
+ class SnowflakeJWTTokenAuth(auth.AuthBase):
61
+ def __init__(
62
+ self,
63
+ jwt_token_generator: jwt_generator.JWTGenerator,
64
+ account: Optional[str] = None,
65
+ role: Optional[str] = None,
66
+ endpoint: Optional[str] = None,
67
+ snowflake_account_url: Optional[str] = None,
68
+ ) -> None:
69
+ self.snowflake_token = _get_snowflake_token_by_jwt(
70
+ jwt_token_generator, account, role, endpoint, snowflake_account_url
71
+ )
72
+
73
+ def __call__(self, r: requests.PreparedRequest) -> requests.PreparedRequest:
74
+ r.headers["Authorization"] = f'Snowflake Token="{self.snowflake_token}"'
75
+ return r
snowflake/ml/version.py CHANGED
@@ -1 +1 @@
1
- VERSION="1.7.1"
1
+ VERSION="1.7.3"