snowflake-ml-python 1.5.4__py3-none-any.whl → 1.6.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (100)
  1. snowflake/cortex/__init__.py +2 -0
  2. snowflake/cortex/_classify_text.py +36 -0
  3. snowflake/cortex/_complete.py +66 -35
  4. snowflake/cortex/_util.py +4 -4
  5. snowflake/ml/_internal/env_utils.py +11 -5
  6. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  7. snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
  8. snowflake/ml/_internal/telemetry.py +26 -2
  9. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  10. snowflake/ml/data/_internal/arrow_ingestor.py +284 -0
  11. snowflake/ml/data/data_connector.py +186 -0
  12. snowflake/ml/data/data_ingestor.py +45 -0
  13. snowflake/ml/data/data_source.py +23 -0
  14. snowflake/ml/data/ingestor_utils.py +62 -0
  15. snowflake/ml/data/torch_dataset.py +33 -0
  16. snowflake/ml/dataset/dataset.py +1 -13
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +23 -117
  19. snowflake/ml/feature_store/access_manager.py +7 -1
  20. snowflake/ml/feature_store/entity.py +19 -2
  21. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  22. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  23. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  24. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
  26. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +37 -0
  27. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +30 -0
  28. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +7 -0
  29. snowflake/ml/feature_store/examples/example_helper.py +278 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
  31. snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +44 -0
  32. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  33. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +9 -0
  34. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  35. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
  36. snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
  37. snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
  38. snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
  39. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
  40. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +36 -0
  41. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +24 -0
  42. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +8 -0
  43. snowflake/ml/feature_store/feature_store.py +637 -76
  44. snowflake/ml/feature_store/feature_view.py +316 -9
  45. snowflake/ml/fileset/stage_fs.py +18 -10
  46. snowflake/ml/lineage/lineage_node.py +1 -1
  47. snowflake/ml/model/_client/model/model_impl.py +11 -2
  48. snowflake/ml/model/_client/model/model_version_impl.py +171 -20
  49. snowflake/ml/model/_client/ops/model_ops.py +105 -27
  50. snowflake/ml/model/_client/ops/service_ops.py +121 -0
  51. snowflake/ml/model/_client/service/model_deployment_spec.py +95 -0
  52. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +31 -0
  53. snowflake/ml/model/_client/sql/model_version.py +13 -4
  54. snowflake/ml/model/_client/sql/service.py +129 -0
  55. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
  56. snowflake/ml/model/_model_composer/model_composer.py +14 -14
  57. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +33 -17
  58. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -1
  59. snowflake/ml/model/_model_composer/model_method/function_generator.py +3 -3
  60. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
  61. snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +3 -27
  62. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -32
  63. snowflake/ml/model/_model_composer/model_method/model_method.py +5 -2
  64. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  65. snowflake/ml/model/_packager/model_handlers/_base.py +30 -3
  66. snowflake/ml/model/_packager/model_handlers/_utils.py +58 -1
  67. snowflake/ml/model/_packager/model_handlers/catboost.py +52 -3
  68. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  69. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +9 -5
  70. snowflake/ml/model/_packager/model_handlers/lightgbm.py +80 -3
  71. snowflake/ml/model/_packager/model_handlers/llm.py +7 -3
  72. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  73. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  74. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  75. snowflake/ml/model/_packager/model_handlers/sklearn.py +87 -4
  76. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -2
  77. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  78. snowflake/ml/model/_packager/model_handlers/torchscript.py +8 -3
  79. snowflake/ml/model/_packager/model_handlers/xgboost.py +71 -3
  80. snowflake/ml/model/_packager/model_meta/model_meta.py +32 -2
  81. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +19 -0
  82. snowflake/ml/model/_packager/model_packager.py +2 -1
  83. snowflake/ml/model/_packager/model_runtime/model_runtime.py +7 -7
  84. snowflake/ml/model/model_signature.py +4 -4
  85. snowflake/ml/model/type_hints.py +2 -0
  86. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +1 -1
  87. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
  88. snowflake/ml/modeling/framework/base.py +28 -19
  89. snowflake/ml/modeling/impute/simple_imputer.py +26 -0
  90. snowflake/ml/modeling/pipeline/pipeline.py +7 -4
  91. snowflake/ml/registry/_manager/model_manager.py +16 -2
  92. snowflake/ml/registry/registry.py +100 -13
  93. snowflake/ml/utils/sql_client.py +22 -0
  94. snowflake/ml/version.py +1 -1
  95. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA +81 -2
  96. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/RECORD +99 -66
  97. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/WHEEL +1 -1
  98. snowflake/ml/_internal/lineage/data_source.py +0 -10
  99. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/LICENSE.txt +0 -0
  100. {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/framework/base.py CHANGED

@@ -1,7 +1,6 @@
 #!/usr/bin/env python3
 import inspect
 from abc import abstractmethod
-from collections import defaultdict
 from datetime import datetime
 from typing import Any, Dict, Iterable, List, Mapping, Optional, Union, overload
 
@@ -18,6 +17,7 @@ from snowflake.ml._internal.exceptions import (
 )
 from snowflake.ml._internal.lineage import lineage_utils
 from snowflake.ml._internal.utils import identifier, parallelize
+from snowflake.ml.data import data_source
 from snowflake.ml.modeling.framework import _utils
 from snowflake.snowpark import functions as F
 
@@ -246,7 +246,7 @@ class Base:
 
     def get_params(self, deep: bool = True) -> Dict[str, Any]:
         """
-        Get parameters for this transformer.
+        Get the snowflake-ml parameters for this transformer.
 
         Args:
             deep: If True, will return the parameters for this transformer and
@@ -265,13 +265,13 @@ class Base:
                 out[key] = value
         return out
 
-    def set_params(self, **params: Dict[str, Any]) -> None:
+    def set_params(self, **params: Any) -> None:
         """
         Set the parameters of this transformer.
 
-        The method works on simple transformers as well as on nested objects.
-        The latter have parameters of the form ``<component>__<parameter>``
-        so that it's possible to update each component of a nested object.
+        The method works on simple transformers as well as on sklearn compatible pipelines with nested
+        objects, once the transformer has been fit. Nested objects have parameters of the form
+        ``<component>__<parameter>`` so that it's possible to update each component of a nested object.
 
         Args:
             **params: Transformer parameter names mapped to their values.
@@ -283,12 +283,28 @@ class Base:
            # simple optimization to gain speed (inspect is slow)
            return
        valid_params = self.get_params(deep=True)
+        valid_skl_params = {}
+        if hasattr(self, "_sklearn_object") and self._sklearn_object is not None:
+            valid_skl_params = self._sklearn_object.get_params()
 
-        nested_params: Dict[str, Any] = defaultdict(dict)  # grouped by prefix
        for key, value in params.items():
-            key, delim, sub_key = key.partition("__")
-            if key not in valid_params:
-                local_valid_params = self._get_param_names()
+            if valid_params.get("steps"):
+                # Recurse through pipeline steps
+                key, _, sub_key = key.partition("__")
+                for name, nested_object in valid_params["steps"]:
+                    if name == key:
+                        nested_object.set_params(**{sub_key: value})
+
+            elif key in valid_params:
+                setattr(self, key, value)
+                valid_params[key] = value
+            elif key in valid_skl_params:
+                # This dictionary would be empty if the following assert were not true, as specified above.
+                assert hasattr(self, "_sklearn_object") and self._sklearn_object is not None
+                setattr(self._sklearn_object, key, value)
+                valid_skl_params[key] = value
+            else:
+                local_valid_params = self._get_param_names() + list(valid_skl_params.keys())
                raise exceptions.SnowflakeMLException(
                    error_code=error_codes.INVALID_ARGUMENT,
                    original_exception=ValueError(
@@ -298,15 +314,6 @@ class Base:
                    ),
                )
 
-        if delim:
-            nested_params[key][sub_key] = value
-        else:
-            setattr(self, key, value)
-            valid_params[key] = value
-
-        for key, sub_params in nested_params.items():
-            valid_params[key].set_params(**sub_params)
-
    def get_sklearn_args(
        self,
        default_sklearn_obj: Optional[object] = None,
@@ -427,6 +434,8 @@ class BaseEstimator(Base):
    def fit(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> "BaseEstimator":
        """Runs universal logics for all fit implementations."""
        data_sources = lineage_utils.get_data_sources(dataset)
+        if not data_sources and isinstance(dataset, snowpark.DataFrame):
+            data_sources = [data_source.DataFrameInfo(dataset.queries["queries"][-1])]
        lineage_utils.set_data_sources(self, data_sources)
        return self._fit(dataset)
 
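In effect, `set_params` drops the old `defaultdict`-based grouping: nested keys are dispatched through pipeline `steps`, unknown keys fall through to the wrapped sklearn object's parameter set, and anything else still raises. A minimal usage sketch, assuming a fit snowflake-ml pipeline (the step names, columns, and `train_df` below are illustrative, not taken from this diff):

```python
# Sketch only: step names ("scaler", "xgb"), column names, and train_df are
# assumed for illustration.
from snowflake.ml.modeling.pipeline import Pipeline
from snowflake.ml.modeling.preprocessing import StandardScaler
from snowflake.ml.modeling.xgboost import XGBClassifier

pipe = Pipeline(
    steps=[
        ("scaler", StandardScaler(input_cols=["F1"], output_cols=["F1_SCALED"])),
        ("xgb", XGBClassifier(input_cols=["F1_SCALED"], label_cols=["LABEL"])),
    ]
)
pipe.fit(train_df)  # train_df: an existing Snowpark or pandas DataFrame

# Nested sklearn-style "<step>__<param>" keys are now routed to the matching
# pipeline step rather than through the removed defaultdict bookkeeping.
pipe.set_params(xgb__n_estimators=200)

# Keys not known to the snowflake-ml wrapper but valid on the fit underlying
# sklearn-compatible object are applied to that object; invalid names still
# raise SnowflakeMLException(INVALID_ARGUMENT).
```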
snowflake/ml/modeling/impute/simple_imputer.py CHANGED

@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 import copy
+import warnings
 from typing import Any, Dict, Iterable, Optional, Type, Union
 
 import numpy as np
@@ -10,6 +11,7 @@ from sklearn import impute
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions
+from snowflake.ml._internal.utils import formatting
 from snowflake.ml.modeling.framework import _utils, base
 from snowflake.snowpark import functions as F, types as T
 from snowflake.snowpark._internal import utils as snowpark_utils
@@ -171,6 +173,14 @@ class SimpleImputer(base.BaseTransformer):
         self.set_output_cols(output_cols)
         self.set_passthrough_cols(passthrough_cols)
 
+    def _is_integer_type(self, column_type: T.DataType) -> bool:
+        return (
+            isinstance(column_type, T.ByteType)
+            or isinstance(column_type, T.ShortType)
+            or isinstance(column_type, T.IntegerType)
+            or isinstance(column_type, T.LongType)
+        )
+
     def _reset(self) -> None:
         """
         Reset internal data-dependent state of the imputer, if necessary.
@@ -389,6 +399,22 @@ class SimpleImputer(base.BaseTransformer):
                # Use `fillna` for replacing nans. Check if the column has a string data type, or coerce a float.
                if not isinstance(input_col_datatypes[input_col], T.StringType):
                    statistic = float(statistic)
+
+                    if self._is_integer_type(input_col_datatypes[input_col]):
+                        if statistic.is_integer():
+                            statistic = int(statistic)
+                        else:
+                            warnings.warn(
+                                formatting.unwrap(
+                                    f"""
+                                    Integer column may not be imputed with a non-integer value {statistic}.
+                                    In order to impute a non-integer value, convert the column to FloatType before imputing.
+                                    """
+                                ),
+                                category=UserWarning,
+                                stacklevel=1,
+                            )
+
                transformed_dataset = transformed_dataset.na.fill({output_col: statistic})
            else:
                transformed_dataset = transformed_dataset.na.replace(
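What this hunk means for callers, as a short sketch (the `AGE` column and `df` are assumed): with `strategy="mean"` on an integer column, an integral statistic such as 37.0 is now cast back to `int` before `na.fill()`, while a fractional statistic triggers the warning above, since Snowpark's `na.fill` ignores values whose type does not match the column:

```python
# Sketch; AGE is an IntegerType column containing NULLs, df a Snowpark DataFrame.
from snowflake.ml.modeling.impute import SimpleImputer

imputer = SimpleImputer(input_cols=["AGE"], output_cols=["AGE"], strategy="mean")
imputed_df = imputer.fit(df).transform(df)
# mean(AGE) == 37.0 -> NULLs are filled with the integer 37.
# mean(AGE) == 36.5 -> UserWarning ("Integer column may not be imputed with a
# non-integer value ..."); cast AGE to FloatType first to impute 36.5.
```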
snowflake/ml/modeling/pipeline/pipeline.py CHANGED

@@ -19,6 +19,7 @@ from snowflake.ml._internal import file_utils, telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml._internal.lineage import lineage_utils
 from snowflake.ml._internal.utils import snowpark_dataframe_utils, temp_file_utils
+from snowflake.ml.data import data_source
 from snowflake.ml.model.model_signature import ModelSignature, _infer_signature
 from snowflake.ml.modeling._internal.model_transformer_builder import (
     ModelTransformerBuilder,
@@ -99,10 +100,6 @@ class Pipeline(base.BaseTransformer):
        must implement `fit` and `transform` methods.
        The final step can be a transform or estimator, that is, it must implement
        `fit` and `transform`/`predict` methods.
-        TODO: SKLearn pipeline expects last step(and only the last step) to be an estimator obj or a dummy
-        estimator(like None or passthrough). Currently this Pipeline class works with a list of all
-        transforms or a list of transforms ending with an estimator. Should we change this implementation
-        to only work with list of steps ending with an estimator or a dummy estimator like SKLearn?
 
        Args:
            steps: List of (name, transform) tuples (implementing `fit`/`transform`) that
@@ -111,6 +108,10 @@ class Pipeline(base.BaseTransformer):
        """
        super().__init__()
        self.steps = steps
+        # TODO: SKLearn pipeline expects last step(and only the last step) to be an estimator obj or a dummy
+        # estimator(like None or passthrough). Currently this Pipeline class works with a list of all
+        # transforms or a list of transforms ending with an estimator. Should we change this implementation
+        # to only work with list of steps ending with an estimator or a dummy estimator like SKLearn?
        self._is_final_step_estimator = Pipeline._is_estimator(steps[-1][1])
        self._is_fitted = False
        self._feature_names_in: List[np.ndarray[Any, np.dtype[Any]]] = []
@@ -431,6 +432,8 @@ class Pipeline(base.BaseTransformer):
 
        # Extract lineage information here since we're overriding fit() directly
        data_sources = lineage_utils.get_data_sources(dataset)
+        if not data_sources and isinstance(dataset, snowpark.DataFrame):
+            data_sources = [data_source.DataFrameInfo(dataset.queries["queries"][-1])]
        lineage_utils.set_data_sources(self, data_sources)
 
        if self._can_be_trained_in_ml_runtime(dataset):
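Both `BaseEstimator.fit` (earlier hunk) and `Pipeline.fit` now fall back to recording the Snowpark `DataFrame`'s final compiled query as a `DataFrameInfo` lineage source when no explicit sources are attached. Roughly what gets captured (`session` and the table name assumed):

```python
from snowflake.snowpark import functions as F

# session: an existing snowflake.snowpark.Session (assumed)
df = session.table("TRAIN_DATA").filter(F.col("LABEL").is_not_null())

# DataFrame.queries maps "queries" to the SQL statements the frame compiles
# to; the last entry is the one that produces the frame's rows, and it is
# what data_source.DataFrameInfo records as the training lineage here.
print(df.queries["queries"][-1])
# e.g. SELECT * FROM TRAIN_DATA WHERE "LABEL" IS NOT NULL   (illustrative)
```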
snowflake/ml/registry/_manager/model_manager.py CHANGED

@@ -9,7 +9,7 @@ from snowflake.ml._internal.human_readable_id import hrid_generator
 from snowflake.ml._internal.utils import sql_identifier
 from snowflake.ml.model import model_signature, type_hints as model_types
 from snowflake.ml.model._client.model import model_impl, model_version_impl
-from snowflake.ml.model._client.ops import metadata_ops, model_ops
+from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
 from snowflake.ml.model._model_composer import model_composer
 from snowflake.ml.model._packager.model_meta import model_meta
 from snowflake.snowpark import session
@@ -30,6 +30,9 @@ class ModelManager:
        self._model_ops = model_ops.ModelOperator(
            session, database_name=self._database_name, schema_name=self._schema_name
        )
+        self._service_ops = service_ops.ServiceOperator(
+            session, database_name=self._database_name, schema_name=self._schema_name
+        )
        self._hrid_generator = hrid_generator.HRID16()
 
    def log_model(
@@ -173,11 +176,16 @@ class ModelManager:
        )
 
        mv = model_version_impl.ModelVersion._ref(
-            model_ops.ModelOperator(
+            model_ops=model_ops.ModelOperator(
                self._model_ops._session,
                database_name=database_name_id or self._database_name,
                schema_name=schema_name_id or self._schema_name,
            ),
+            service_ops=service_ops.ServiceOperator(
+                self._service_ops._session,
+                database_name=database_name_id or self._database_name,
+                schema_name=schema_name_id or self._schema_name,
+            ),
            model_name=model_name_id,
            version_name=version_name_id,
        )
@@ -216,6 +224,11 @@ class ModelManager:
                database_name=database_name_id or self._database_name,
                schema_name=schema_name_id or self._schema_name,
            ),
+            service_ops=service_ops.ServiceOperator(
+                self._service_ops._session,
+                database_name=database_name_id or self._database_name,
+                schema_name=schema_name_id or self._schema_name,
+            ),
            model_name=model_name_id,
        )
    else:
@@ -234,6 +247,7 @@ class ModelManager:
        return [
            model_impl.Model._ref(
                self._model_ops,
+                service_ops=self._service_ops,
                model_name=model_name,
            )
            for model_name in model_names
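This `ServiceOperator` threading is internal plumbing for the container-serving surface added elsewhere in this release (service_ops.py, model_deployment_spec.py, and sql/service.py in the file list): every `Model` and `ModelVersion` handed out by the manager now carries service operations. As a hypothetical sketch only — `create_service` and its parameter names are assumptions based on those new files, not confirmed by the hunks shown here — the user-facing flow would look something like:

```python
# Hypothetical usage; parameter names below are assumptions, not from this diff.
mv = reg.get_model("MY_MODEL").version("V1")  # reg: an existing Registry
mv.create_service(
    service_name="MY_MODEL_SVC",
    service_compute_pool="MY_COMPUTE_POOL",
    image_repo="MY_DB.MY_SCHEMA.MY_IMAGE_REPO",
    ingress_enabled=True,
)
```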
snowflake/ml/registry/registry.py CHANGED

@@ -1,5 +1,6 @@
+import warnings
 from types import ModuleType
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union, overload
 
 import pandas as pd
 
@@ -68,6 +69,90 @@ class Registry:
        """Get the location (database.schema) of the registry."""
        return ".".join([self._database_name.identifier(), self._schema_name.identifier()])
 
+    @overload
+    def log_model(
+        self,
+        model: model_types.SupportedModelType,
+        *,
+        model_name: str,
+        version_name: Optional[str] = None,
+        comment: Optional[str] = None,
+        metrics: Optional[Dict[str, Any]] = None,
+        conda_dependencies: Optional[List[str]] = None,
+        pip_requirements: Optional[List[str]] = None,
+        python_version: Optional[str] = None,
+        signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
+        sample_input_data: Optional[model_types.SupportedDataType] = None,
+        code_paths: Optional[List[str]] = None,
+        ext_modules: Optional[List[ModuleType]] = None,
+        options: Optional[model_types.ModelSaveOption] = None,
+    ) -> ModelVersion:
+        """
+        Log a model with various parameters and metadata.
+
+        Args:
+            model: Model object of supported types such as Scikit-learn, XGBoost, LightGBM, Snowpark ML,
+                PyTorch, TorchScript, Tensorflow, Tensorflow Keras, MLFlow, HuggingFace Pipeline,
+                Sentence Transformers, Peft-finetuned LLM, or Custom Model.
+            model_name: Name to identify the model.
+            version_name: Version identifier for the model. Combination of model_name and version_name must be unique.
+                If not specified, a random name will be generated.
+            comment: Comment associated with the model version. Defaults to None.
+            metrics: A JSON serializable dictionary containing metrics linked to the model version. Defaults to None.
+            signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
+                sample_input_data would be used to infer the signatures for those models that cannot automatically
+                infer the signature. If not None, sample_input_data should not be specified. Defaults to None.
+            sample_input_data: Sample input data to infer model signatures from. Defaults to None.
+            conda_dependencies: List of Conda package specifications. Use "[channel::]package [operator version]" syntax
+                to specify a dependency. It is a recommended way to specify your dependencies using conda. When channel
+                is not specified, Snowflake Anaconda Channel will be used. Defaults to None.
+            pip_requirements: List of Pip package specifications. Defaults to None.
+                Currently it is not supported since Model can only executed in Snowflake Warehouse where all
+                dependencies are required to be retrieved from Snowflake Anaconda Channel.
+            python_version: Python version in which the model is run. Defaults to None.
+            code_paths: List of directories containing code to import. Defaults to None.
+            ext_modules: List of external modules to pickle with the model object.
+                Only supported when logging the following types of model:
+                Scikit-learn, Snowpark ML, PyTorch, TorchScript and Custom Model. Defaults to None.
+            options (Dict[str, Any], optional): Additional model saving options.
+                Model Saving Options include:
+                - embed_local_ml_library: Embed local Snowpark ML into the code directory or folder.
+                    Override to True if the local Snowpark ML version is not available in the Snowflake Anaconda
+                    Channel. Otherwise, defaults to False
+                - relax_version: Whether or not relax the version constraints of the dependencies.
+                    It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to True.
+                - function_type: Set the method function type globally. To set method function types individually see
+                    function_type in model_options.
+                - method_options: Per-method saving options including:
+                    - case_sensitive: Indicates whether the method and its signature should be case sensitive.
+                        This means when you refer the method in the SQL, you need to double quote it.
+                        This will be helpful if you need case to tell apart your methods or features, or you have
+                        non-alphabetic characters in your method or feature name. Defaults to False.
+                    - max_batch_size: Maximum batch size that the method could accept in the Snowflake Warehouse.
+                        Defaults to None, determined automatically by Snowflake.
+                    - function_type: One of supported model method function types (FUNCTION or TABLE_FUNCTION).
+        """
+        ...
+
+    @overload
+    def log_model(
+        self,
+        model: ModelVersion,
+        *,
+        model_name: str,
+        version_name: Optional[str] = None,
+    ) -> ModelVersion:
+        """
+        Log a model with a ModelVersion object.
+
+        Args:
+            model: Source ModelVersion object used to create the new ModelVersion object.
+            model_name: Name to identify the model.
+            version_name: Version identifier for the model. Combination of model_name and version_name must be unique.
+                If not specified, a random name will be generated.
+        """
+        ...
+
    @telemetry.send_api_usage_telemetry(
        project=_TELEMETRY_PROJECT,
        subproject=_MODEL_TELEMETRY_SUBPROJECT,
@@ -84,7 +169,7 @@ class Registry:
    )
    def log_model(
        self,
-        model: model_types.SupportedModelType,
+        model: Union[model_types.SupportedModelType, ModelVersion],
        *,
        model_name: str,
        version_name: Optional[str] = None,
@@ -100,12 +185,14 @@ class Registry:
        options: Optional[model_types.ModelSaveOption] = None,
    ) -> ModelVersion:
        """
-        Log a model with various parameters and metadata.
+        Log a model with various parameters and metadata, or a ModelVersion object.
 
        Args:
-            model: Model object of supported types such as Scikit-learn, XGBoost, LightGBM, Snowpark ML,
-                PyTorch, TorchScript, Tensorflow, Tensorflow Keras, MLFlow, HuggingFace Pipeline,
-                Sentence Transformers, Peft-finetuned LLM, or Custom Model.
+            model: Supported model or ModelVersion object.
+                - Supported model: Model object of supported types such as Scikit-learn, XGBoost, LightGBM, Snowpark ML,
+                  PyTorch, TorchScript, Tensorflow, Tensorflow Keras, MLFlow, HuggingFace Pipeline, Sentence Transformers,
+                  Peft-finetuned LLM, or Custom Model.
+                - ModelVersion: Source ModelVersion object used to create the new ModelVersion object.
            model_name: Name to identify the model.
            version_name: Version identifier for the model. Combination of model_name and version_name must be unique.
                If not specified, a random name will be generated.
@@ -146,9 +233,6 @@ class Registry:
                        Defaults to None, determined automatically by Snowflake.
                    - function_type: One of supported model method function types (FUNCTION or TABLE_FUNCTION).
 
-        Raises:
-            NotImplementedError: `pip_requirements` is not supported.
-
        Returns:
            ModelVersion: ModelVersion object corresponding to the model just logged.
        """
@@ -157,10 +241,13 @@ class Registry:
            subproject=_MODEL_TELEMETRY_SUBPROJECT,
        )
        if pip_requirements:
-            raise NotImplementedError(
-                "Currently `pip_requirements` is not supported since Model can only executed "
+            warnings.warn(
+                "Models logged specifying `pip_requirements` can not be executed "
                "in Snowflake Warehouse where all dependencies are required to be retrieved "
-                "from Snowflake Anaconda Channel."
+                "from Snowflake Anaconda Channel. Specify model save option `include_pip_dependencies`"
+                "to log model with pip dependencies.",
+                category=UserWarning,
+                stacklevel=1,
            )
        return self._model_manager.log_model(
            model=model,
@@ -169,7 +256,7 @@ class Registry:
            comment=comment,
            metrics=metrics,
            conda_dependencies=conda_dependencies,
-            pip_requirements=None,
+            pip_requirements=pip_requirements,
            python_version=python_version,
            signatures=signatures,
            sample_input_data=sample_input_data,
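Taken together, the overloads and the softened `pip_requirements` check give `log_model` two call shapes. A sketch (`reg`, `clf`, and `X_sample` are assumed to exist):

```python
# Shape 1: log a supported model object, as before.
mv = reg.log_model(
    clf,                      # e.g. a fit sklearn/XGBoost/LightGBM model (assumed)
    model_name="MY_MODEL",
    version_name="V1",
    sample_input_data=X_sample,
)

# Shape 2 (new): re-log an existing ModelVersion under a new name/version.
mv_copy = reg.log_model(mv, model_name="MY_MODEL_COPY", version_name="V1")

# Passing pip_requirements now emits a UserWarning instead of raising
# NotImplementedError, and the value is forwarded to the model manager.
```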
snowflake/ml/utils/sql_client.py ADDED

@@ -0,0 +1,22 @@
+from enum import Enum
+from typing import Dict
+
+
+class CreationOption(Enum):
+    FAIL_IF_NOT_EXIST = 1
+    CREATE_IF_NOT_EXIST = 2
+    OR_REPLACE = 3
+
+
+class CreationMode:
+    def __init__(self, *, if_not_exists: bool = False, or_replace: bool = False) -> None:
+        self.if_not_exists = if_not_exists
+        self.or_replace = or_replace
+
+    def get_ddl_phrases(self) -> Dict[CreationOption, str]:
+        if_not_exists_sql = " IF NOT EXISTS" if self.if_not_exists else ""
+        or_replace_sql = " OR REPLACE" if self.or_replace else ""
+        return {
+            CreationOption.CREATE_IF_NOT_EXIST: if_not_exists_sql,
+            CreationOption.OR_REPLACE: or_replace_sql,
+        }
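A small usage sketch for the new helper (the schema name and DDL statement are illustrative): `get_ddl_phrases()` returns fragments keyed by `CreationOption`, each empty when its flag is off, so they can be spliced directly into DDL:

```python
from snowflake.ml.utils.sql_client import CreationMode, CreationOption

mode = CreationMode(if_not_exists=True)
phrases = mode.get_ddl_phrases()
ddl = (
    f"CREATE{phrases[CreationOption.OR_REPLACE]} SCHEMA"
    f"{phrases[CreationOption.CREATE_IF_NOT_EXIST]} FS_SCHEMA"
)
print(ddl)  # CREATE SCHEMA IF NOT EXISTS FS_SCHEMA
```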
snowflake/ml/version.py CHANGED

@@ -1 +1 @@
-VERSION="1.5.4"
+VERSION="1.6.1"
{snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: snowflake-ml-python
-Version: 1.5.4
+Version: 1.6.1
 Summary: The machine learning client library that is used for interacting with Snowflake to build machine learning solutions.
 Author-email: "Snowflake, Inc" <support@snowflake.com>
 License:
@@ -373,7 +373,86 @@ be compatibility issues. Server-side functionality that `snowflake-ml-python` de
 
 # Release History
 
-## 1.5.4
+## 1.6.1 (TBD)
+
+### Bug Fixes
+
+- Feature Store: Support large metadata blob when generating dataset
+- Feature Store: Added a hidden knob in FeatureView as kargs for setting customized
+  refresh_mode
+- Registry: Fix an error message in Model Version `run` when `function_name` is not mentioned and model has multiple
+  target methods.
+- Cortex inference: snowflake.cortex.Complete now only uses the REST API for streaming and the use_rest_api_experimental
+  is no longer needed.
+- Feature Store: Add a new API: FeatureView.list_columns() which list all column information.
+- Data: Fix `DataFrame` ingestion with `ArrowIngestor`.
+
+### New Features
+
+- Enable `set_params` to set the parameters of the underlying sklearn estimator, if the snowflake-ml model has been fit.
+- Data: Add top-level exports for `DataConnector` and `DataSource` to `snowflake.ml.data`.
+- Data: Add `snowflake.ml.data.ingestor_utils` module with utility functions helpful for `DataIngestor` implementations.
+- Data: Add new `to_torch_dataset()` connector to `DataConnector` to replace deprecated DataPipe.
+- Registry: Option to `enable_explainability` set to True by default for XGBoost, LightGBM and CatBoost as PuPr feature.
+- Registry: Option to `enable_explainability` when registering SHAP supported sklearn models.
+
+### Behavior Changes
+
+## 1.6.0 (2024-07-29)
+
+### Bug Fixes
+
+- Modeling: `SimpleImputer` can impute integer columns with integer values.
+- Registry: Fix an issue when providing a pandas Dataframe whose index is not starting from 0 as the input to
+  the `ModelVersion.run`.
+
+### New Features
+
+- Feature Store: Add overloads to APIs accept both object and name/version. Impacted APIs include read_feature_view(),
+  refresh_feature_view(), get_refresh_history(), resume_feature_view(), suspend_feature_view(), delete_feature_view().
+- Feature Store: Add docstring inline examples for all public APIs.
+- Feature Store: Add new utility class `ExampleHelper` to help with load source data to simplify public notebooks.
+- Registry: Option to `enable_explainability` when registering XGBoost models as a pre-PuPr feature.
+- Feature Store: add new API `update_entity()`.
+- Registry: Option to `enable_explainability` when registering Catboost models as a pre-PuPr feature.
+- Feature Store: Add new argument warehouse to FeatureView constructor to overwrite the default warehouse. Also add
+  a new column 'warehouse' to the output of list_feature_views().
+- Registry: Add support for logging model from a model version.
+- Modeling: Distributed Hyperparameter Optimization now announce GA refresh version. The latest memory efficient version
+  will not have the 10GB training limitation for dataset any more. To turn off, please run
+  `
+  from snowflake.ml.modeling._internal.snowpark_implementations import (
+      distributed_hpo_trainer,
+  )
+  distributed_hpo_trainer.ENABLE_EFFICIENT_MEMORY_USAGE = False
+  `
+- Registry: Option to `enable_explainability` when registering LightGBM models as a pre-PuPr feature.
+- Data: Add new `snowflake.ml.data` preview module which contains data reading utilities like `DataConnector`
+  - `DataConnector` provides efficient connectors from Snowpark `DataFrame`
+    and Snowpark ML `Dataset` to external frameworks like PyTorch, TensorFlow, and Pandas. Create `DataConnector`
+    instances using the classmethod constructors `DataConnector.from_dataset()` and `DataConnector.from_dataframe()`.
+- Data: Add new `DataConnector.from_sources()` classmethod constructor for constructing from `DataSource` objects.
+- Data: Add new `ingestor_class` arg to `DataConnector` classmethod constructors for easier `DataIngestor` injection.
+- Dataset: `DatasetReader` now subclasses new `DataConnector` class.
+  - Add optional `limit` arg to `DatasetReader.to_pandas()`
+
+### Behavior Changes
+
+- Feature Store: change some positional parameters to keyword arguments in following APIs:
+  - Entity(): desc.
+  - FeatureView(): timestamp_col, refresh_freq, desc.
+  - FeatureStore(): creation_mode.
+  - update_entity(): desc.
+  - register_feature_view(): block, overwrite.
+  - list_feature_views(): entity_name, feature_view_name.
+  - get_refresh_history(): verbose.
+  - retrieve_feature_values(): spine_timestamp_col, exclude_columns, include_feature_view_timestamp_col.
+  - generate_training_set(): save_as, spine_timestamp_col, spine_label_cols, exclude_columns,
+    include_feature_view_timestamp_col.
+  - generate_dataset(): version, spine_timestamp_col, spine_label_cols, exclude_columns,
+    include_feature_view_timestamp_col, desc, output_type.
+
+## 1.5.4 (2024-07-11)
 
 ### Bug Fixes