snowflake-ml-python 1.5.4__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +2 -0
- snowflake/cortex/_classify_text.py +36 -0
- snowflake/cortex/_complete.py +66 -35
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
- snowflake/ml/_internal/telemetry.py +26 -2
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/data/_internal/arrow_ingestor.py +284 -0
- snowflake/ml/data/data_connector.py +186 -0
- snowflake/ml/data/data_ingestor.py +45 -0
- snowflake/ml/data/data_source.py +23 -0
- snowflake/ml/data/ingestor_utils.py +62 -0
- snowflake/ml/data/torch_dataset.py +33 -0
- snowflake/ml/dataset/dataset.py +1 -13
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +23 -117
- snowflake/ml/feature_store/access_manager.py +7 -1
- snowflake/ml/feature_store/entity.py +19 -2
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +37 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +30 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/example_helper.py +278 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +44 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +9 -0
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
- snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
- snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +36 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +24 -0
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +8 -0
- snowflake/ml/feature_store/feature_store.py +637 -76
- snowflake/ml/feature_store/feature_view.py +316 -9
- snowflake/ml/fileset/stage_fs.py +18 -10
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +171 -20
- snowflake/ml/model/_client/ops/model_ops.py +105 -27
- snowflake/ml/model/_client/ops/service_ops.py +121 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +95 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +31 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +129 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
- snowflake/ml/model/_model_composer/model_composer.py +14 -14
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +33 -17
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -1
- snowflake/ml/model/_model_composer/model_method/function_generator.py +3 -3
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +3 -27
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/model_method.py +5 -2
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +30 -3
- snowflake/ml/model/_packager/model_handlers/_utils.py +58 -1
- snowflake/ml/model/_packager/model_handlers/catboost.py +52 -3
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +9 -5
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +80 -3
- snowflake/ml/model/_packager/model_handlers/llm.py +7 -3
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +87 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +8 -3
- snowflake/ml/model/_packager/model_handlers/xgboost.py +71 -3
- snowflake/ml/model/_packager/model_meta/model_meta.py +32 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +19 -0
- snowflake/ml/model/_packager/model_packager.py +2 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +7 -7
- snowflake/ml/model/model_signature.py +4 -4
- snowflake/ml/model/type_hints.py +2 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +1 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/impute/simple_imputer.py +26 -0
- snowflake/ml/modeling/pipeline/pipeline.py +7 -4
- snowflake/ml/registry/_manager/model_manager.py +16 -2
- snowflake/ml/registry/registry.py +100 -13
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA +81 -2
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/RECORD +99 -66
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/lineage/data_source.py +0 -10
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/framework/base.py CHANGED

```diff
@@ -1,7 +1,6 @@
 #!/usr/bin/env python3
 import inspect
 from abc import abstractmethod
-from collections import defaultdict
 from datetime import datetime
 from typing import Any, Dict, Iterable, List, Mapping, Optional, Union, overload
 
@@ -18,6 +17,7 @@ from snowflake.ml._internal.exceptions import (
 )
 from snowflake.ml._internal.lineage import lineage_utils
 from snowflake.ml._internal.utils import identifier, parallelize
+from snowflake.ml.data import data_source
 from snowflake.ml.modeling.framework import _utils
 from snowflake.snowpark import functions as F
 
@@ -246,7 +246,7 @@ class Base:
 
     def get_params(self, deep: bool = True) -> Dict[str, Any]:
         """
-        Get parameters for this transformer.
+        Get the snowflake-ml parameters for this transformer.
 
         Args:
             deep: If True, will return the parameters for this transformer and
@@ -265,13 +265,13 @@ class Base:
                 out[key] = value
         return out
 
-    def set_params(self, **params:
+    def set_params(self, **params: Any) -> None:
         """
         Set the parameters of this transformer.
 
-        The method works on simple transformers as well as on nested
-
-        so that it's possible to update each component of a nested object.
+        The method works on simple transformers as well as on sklearn compatible pipelines with nested
+        objects, once the transformer has been fit. Nested objects have parameters of the form
+        ``<component>__<parameter>`` so that it's possible to update each component of a nested object.
 
         Args:
             **params: Transformer parameter names mapped to their values.
@@ -283,12 +283,28 @@ class Base:
             # simple optimization to gain speed (inspect is slow)
             return
         valid_params = self.get_params(deep=True)
+        valid_skl_params = {}
+        if hasattr(self, "_sklearn_object") and self._sklearn_object is not None:
+            valid_skl_params = self._sklearn_object.get_params()
 
-        nested_params: Dict[str, Any] = defaultdict(dict)  # grouped by prefix
         for key, value in params.items():
-
-
-
+            if valid_params.get("steps"):
+                # Recurse through pipeline steps
+                key, _, sub_key = key.partition("__")
+                for name, nested_object in valid_params["steps"]:
+                    if name == key:
+                        nested_object.set_params(**{sub_key: value})
+
+            elif key in valid_params:
+                setattr(self, key, value)
+                valid_params[key] = value
+            elif key in valid_skl_params:
+                # This dictionary would be empty if the following assert were not true, as specified above.
+                assert hasattr(self, "_sklearn_object") and self._sklearn_object is not None
+                setattr(self._sklearn_object, key, value)
+                valid_skl_params[key] = value
+            else:
+                local_valid_params = self._get_param_names() + list(valid_skl_params.keys())
                 raise exceptions.SnowflakeMLException(
                     error_code=error_codes.INVALID_ARGUMENT,
                     original_exception=ValueError(
@@ -298,15 +314,6 @@ class Base:
                     ),
                 )
 
-            if delim:
-                nested_params[key][sub_key] = value
-            else:
-                setattr(self, key, value)
-                valid_params[key] = value
-
-        for key, sub_params in nested_params.items():
-            valid_params[key].set_params(**sub_params)
-
     def get_sklearn_args(
         self,
         default_sklearn_obj: Optional[object] = None,
@@ -427,6 +434,8 @@ class BaseEstimator(Base):
     def fit(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> "BaseEstimator":
         """Runs universal logics for all fit implementations."""
        data_sources = lineage_utils.get_data_sources(dataset)
+        if not data_sources and isinstance(dataset, snowpark.DataFrame):
+            data_sources = [data_source.DataFrameInfo(dataset.queries["queries"][-1])]
        lineage_utils.set_data_sources(self, data_sources)
        return self._fit(dataset)
 
```
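Net effect of the `set_params` hunks: the sklearn-style `defaultdict` bookkeeping is gone; nested `<step>__<param>` keys are routed directly to the matching pipeline step, plain keys are set on the wrapper itself, and keys known only to the fitted `_sklearn_object` are forwarded to it. A minimal illustrative sketch (estimator classes, column and parameter names are placeholders; an active Snowpark `session` is assumed):

```python
# Illustrative sketch of the new set_params dispatch (snowflake-ml-python >= 1.6.1).
from snowflake.ml.modeling.pipeline import Pipeline
from snowflake.ml.modeling.preprocessing import StandardScaler
from snowflake.ml.modeling.xgboost import XGBClassifier

pipe = Pipeline(
    steps=[
        ("scaler", StandardScaler(input_cols=["F1"], output_cols=["F1_SCALED"])),
        ("clf", XGBClassifier(input_cols=["F1_SCALED"], label_cols=["LABEL"])),
    ]
)

# "steps" is a valid parameter of the pipeline, so the key is split on "__"
# and routed to the step named "clf".
pipe.set_params(clf__n_estimators=50)

# After fit(), a parameter known only to the underlying sklearn/xgboost
# object is forwarded to self._sklearn_object instead of raising:
# pipe.fit(session.table("TRAINING_DATA"))
# pipe.set_params(clf__max_depth=4)
```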
snowflake/ml/modeling/impute/simple_imputer.py CHANGED

```diff
@@ -1,5 +1,6 @@
 #!/usr/bin/env python3
 import copy
+import warnings
 from typing import Any, Dict, Iterable, Optional, Type, Union
 
 import numpy as np
@@ -10,6 +11,7 @@ from sklearn import impute
 from snowflake import snowpark
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions
+from snowflake.ml._internal.utils import formatting
 from snowflake.ml.modeling.framework import _utils, base
 from snowflake.snowpark import functions as F, types as T
 from snowflake.snowpark._internal import utils as snowpark_utils
@@ -171,6 +173,14 @@ class SimpleImputer(base.BaseTransformer):
         self.set_output_cols(output_cols)
         self.set_passthrough_cols(passthrough_cols)
 
+    def _is_integer_type(self, column_type: T.DataType) -> bool:
+        return (
+            isinstance(column_type, T.ByteType)
+            or isinstance(column_type, T.ShortType)
+            or isinstance(column_type, T.IntegerType)
+            or isinstance(column_type, T.LongType)
+        )
+
     def _reset(self) -> None:
         """
         Reset internal data-dependent state of the imputer, if necessary.
@@ -389,6 +399,22 @@ class SimpleImputer(base.BaseTransformer):
                 # Use `fillna` for replacing nans. Check if the column has a string data type, or coerce a float.
                 if not isinstance(input_col_datatypes[input_col], T.StringType):
                     statistic = float(statistic)
+
+                    if self._is_integer_type(input_col_datatypes[input_col]):
+                        if statistic.is_integer():
+                            statistic = int(statistic)
+                        else:
+                            warnings.warn(
+                                formatting.unwrap(
+                                    f"""
+                                    Integer column may not be imputed with a non-integer value {statistic}.
+                                    In order to impute a non-integer value, convert the column to FloatType before imputing.
+                                    """
+                                ),
+                                category=UserWarning,
+                                stacklevel=1,
+                            )
+
                     transformed_dataset = transformed_dataset.na.fill({output_col: statistic})
                 else:
                     transformed_dataset = transformed_dataset.na.replace(
```
snowflake/ml/modeling/pipeline/pipeline.py CHANGED

```diff
@@ -19,6 +19,7 @@ from snowflake.ml._internal import file_utils, telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions
 from snowflake.ml._internal.lineage import lineage_utils
 from snowflake.ml._internal.utils import snowpark_dataframe_utils, temp_file_utils
+from snowflake.ml.data import data_source
 from snowflake.ml.model.model_signature import ModelSignature, _infer_signature
 from snowflake.ml.modeling._internal.model_transformer_builder import (
     ModelTransformerBuilder,
@@ -99,10 +100,6 @@ class Pipeline(base.BaseTransformer):
             must implement `fit` and `transform` methods.
         The final step can be a transform or estimator, that is, it must implement
         `fit` and `transform`/`predict` methods.
-    TODO: SKLearn pipeline expects last step(and only the last step) to be an estimator obj or a dummy
-        estimator(like None or passthrough). Currently this Pipeline class works with a list of all
-        transforms or a list of transforms ending with an estimator. Should we change this implementation
-        to only work with list of steps ending with an estimator or a dummy estimator like SKLearn?
 
     Args:
         steps: List of (name, transform) tuples (implementing `fit`/`transform`) that
@@ -111,6 +108,10 @@ class Pipeline(base.BaseTransformer):
         """
         super().__init__()
         self.steps = steps
+        # TODO(snandamuri): SKLearn pipeline expects last step(and only the last step) to be an estimator obj or a dummy
+        # estimator(like None or passthrough). Currently this Pipeline class works with a list of all
+        # transforms or a list of transforms ending with an estimator. Should we change this implementation
+        # to only work with list of steps ending with an estimator or a dummy estimator like SKLearn?
         self._is_final_step_estimator = Pipeline._is_estimator(steps[-1][1])
         self._is_fitted = False
         self._feature_names_in: List[np.ndarray[Any, np.dtype[Any]]] = []
@@ -431,6 +432,8 @@ class Pipeline(base.BaseTransformer):
 
         # Extract lineage information here since we're overriding fit() directly
         data_sources = lineage_utils.get_data_sources(dataset)
+        if not data_sources and isinstance(dataset, snowpark.DataFrame):
+            data_sources = [data_source.DataFrameInfo(dataset.queries["queries"][-1])]
         lineage_utils.set_data_sources(self, data_sources)
 
         if self._can_be_trained_in_ml_runtime(dataset):
```
snowflake/ml/registry/_manager/model_manager.py CHANGED

```diff
@@ -9,7 +9,7 @@ from snowflake.ml._internal.human_readable_id import hrid_generator
 from snowflake.ml._internal.utils import sql_identifier
 from snowflake.ml.model import model_signature, type_hints as model_types
 from snowflake.ml.model._client.model import model_impl, model_version_impl
-from snowflake.ml.model._client.ops import metadata_ops, model_ops
+from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
 from snowflake.ml.model._model_composer import model_composer
 from snowflake.ml.model._packager.model_meta import model_meta
 from snowflake.snowpark import session
@@ -30,6 +30,9 @@ class ModelManager:
         self._model_ops = model_ops.ModelOperator(
             session, database_name=self._database_name, schema_name=self._schema_name
         )
+        self._service_ops = service_ops.ServiceOperator(
+            session, database_name=self._database_name, schema_name=self._schema_name
+        )
         self._hrid_generator = hrid_generator.HRID16()
 
     def log_model(
@@ -173,11 +176,16 @@ class ModelManager:
         )
 
         mv = model_version_impl.ModelVersion._ref(
-            model_ops.ModelOperator(
+            model_ops=model_ops.ModelOperator(
                 self._model_ops._session,
                 database_name=database_name_id or self._database_name,
                 schema_name=schema_name_id or self._schema_name,
             ),
+            service_ops=service_ops.ServiceOperator(
+                self._service_ops._session,
+                database_name=database_name_id or self._database_name,
+                schema_name=schema_name_id or self._schema_name,
+            ),
             model_name=model_name_id,
             version_name=version_name_id,
         )
@@ -216,6 +224,11 @@ class ModelManager:
                 database_name=database_name_id or self._database_name,
                 schema_name=schema_name_id or self._schema_name,
             ),
+            service_ops=service_ops.ServiceOperator(
+                self._service_ops._session,
+                database_name=database_name_id or self._database_name,
+                schema_name=schema_name_id or self._schema_name,
+            ),
             model_name=model_name_id,
         )
     else:
@@ -234,6 +247,7 @@ class ModelManager:
         return [
             model_impl.Model._ref(
                 self._model_ops,
+                service_ops=self._service_ops,
                 model_name=model_name,
             )
             for model_name in model_names
```
snowflake/ml/registry/registry.py CHANGED

```diff
@@ -1,5 +1,6 @@
+import warnings
 from types import ModuleType
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Union, overload
 
 import pandas as pd
 
@@ -68,6 +69,90 @@ class Registry:
         """Get the location (database.schema) of the registry."""
         return ".".join([self._database_name.identifier(), self._schema_name.identifier()])
 
+    @overload
+    def log_model(
+        self,
+        model: model_types.SupportedModelType,
+        *,
+        model_name: str,
+        version_name: Optional[str] = None,
+        comment: Optional[str] = None,
+        metrics: Optional[Dict[str, Any]] = None,
+        conda_dependencies: Optional[List[str]] = None,
+        pip_requirements: Optional[List[str]] = None,
+        python_version: Optional[str] = None,
+        signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
+        sample_input_data: Optional[model_types.SupportedDataType] = None,
+        code_paths: Optional[List[str]] = None,
+        ext_modules: Optional[List[ModuleType]] = None,
+        options: Optional[model_types.ModelSaveOption] = None,
+    ) -> ModelVersion:
+        """
+        Log a model with various parameters and metadata.
+
+        Args:
+            model: Model object of supported types such as Scikit-learn, XGBoost, LightGBM, Snowpark ML,
+                PyTorch, TorchScript, Tensorflow, Tensorflow Keras, MLFlow, HuggingFace Pipeline,
+                Sentence Transformers, Peft-finetuned LLM, or Custom Model.
+            model_name: Name to identify the model.
+            version_name: Version identifier for the model. Combination of model_name and version_name must be unique.
+                If not specified, a random name will be generated.
+            comment: Comment associated with the model version. Defaults to None.
+            metrics: A JSON serializable dictionary containing metrics linked to the model version. Defaults to None.
+            signatures: Model data signatures for inputs and outputs for various target methods. If it is None,
+                sample_input_data would be used to infer the signatures for those models that cannot automatically
+                infer the signature. If not None, sample_input_data should not be specified. Defaults to None.
+            sample_input_data: Sample input data to infer model signatures from. Defaults to None.
+            conda_dependencies: List of Conda package specifications. Use "[channel::]package [operator version]" syntax
+                to specify a dependency. It is a recommended way to specify your dependencies using conda. When channel
+                is not specified, Snowflake Anaconda Channel will be used. Defaults to None.
+            pip_requirements: List of Pip package specifications. Defaults to None.
+                Currently it is not supported since Model can only executed in Snowflake Warehouse where all
+                dependencies are required to be retrieved from Snowflake Anaconda Channel.
+            python_version: Python version in which the model is run. Defaults to None.
+            code_paths: List of directories containing code to import. Defaults to None.
+            ext_modules: List of external modules to pickle with the model object.
+                Only supported when logging the following types of model:
+                Scikit-learn, Snowpark ML, PyTorch, TorchScript and Custom Model. Defaults to None.
+            options (Dict[str, Any], optional): Additional model saving options.
+                Model Saving Options include:
+                - embed_local_ml_library: Embed local Snowpark ML into the code directory or folder.
+                    Override to True if the local Snowpark ML version is not available in the Snowflake Anaconda
+                    Channel. Otherwise, defaults to False
+                - relax_version: Whether or not relax the version constraints of the dependencies.
+                    It detects any ==x.y.z in specifiers and replaced with >=x.y, <(x+1). Defaults to True.
+                - function_type: Set the method function type globally. To set method function types individually see
+                    function_type in model_options.
+                - method_options: Per-method saving options including:
+                    - case_sensitive: Indicates whether the method and its signature should be case sensitive.
+                        This means when you refer the method in the SQL, you need to double quote it.
+                        This will be helpful if you need case to tell apart your methods or features, or you have
+                        non-alphabetic characters in your method or feature name. Defaults to False.
+                    - max_batch_size: Maximum batch size that the method could accept in the Snowflake Warehouse.
+                        Defaults to None, determined automatically by Snowflake.
+                    - function_type: One of supported model method function types (FUNCTION or TABLE_FUNCTION).
+        """
+        ...
+
+    @overload
+    def log_model(
+        self,
+        model: ModelVersion,
+        *,
+        model_name: str,
+        version_name: Optional[str] = None,
+    ) -> ModelVersion:
+        """
+        Log a model with a ModelVersion object.
+
+        Args:
+            model: Source ModelVersion object used to create the new ModelVersion object.
+            model_name: Name to identify the model.
+            version_name: Version identifier for the model. Combination of model_name and version_name must be unique.
+                If not specified, a random name will be generated.
+        """
+        ...
+
     @telemetry.send_api_usage_telemetry(
         project=_TELEMETRY_PROJECT,
         subproject=_MODEL_TELEMETRY_SUBPROJECT,
@@ -84,7 +169,7 @@ class Registry:
     )
     def log_model(
         self,
-        model: model_types.SupportedModelType,
+        model: Union[model_types.SupportedModelType, ModelVersion],
         *,
         model_name: str,
         version_name: Optional[str] = None,
@@ -100,12 +185,14 @@ class Registry:
         options: Optional[model_types.ModelSaveOption] = None,
     ) -> ModelVersion:
         """
-        Log a model with various parameters and metadata.
+        Log a model with various parameters and metadata, or a ModelVersion object.
 
         Args:
-            model: Model object of supported types such as Scikit-learn, XGBoost, LightGBM, Snowpark ML,
-                PyTorch, TorchScript, Tensorflow, Tensorflow Keras, MLFlow, HuggingFace Pipeline,
-                Sentence Transformers, Peft-finetuned LLM, or Custom Model.
+            model: Supported model or ModelVersion object.
+                - Supported model: Model object of supported types such as Scikit-learn, XGBoost, LightGBM, Snowpark ML,
+                  PyTorch, TorchScript, Tensorflow, Tensorflow Keras, MLFlow, HuggingFace Pipeline, Sentence Transformers,
+                  Peft-finetuned LLM, or Custom Model.
+                - ModelVersion: Source ModelVersion object used to create the new ModelVersion object.
             model_name: Name to identify the model.
             version_name: Version identifier for the model. Combination of model_name and version_name must be unique.
                 If not specified, a random name will be generated.
@@ -146,9 +233,6 @@ class Registry:
                         Defaults to None, determined automatically by Snowflake.
                     - function_type: One of supported model method function types (FUNCTION or TABLE_FUNCTION).
 
-        Raises:
-            NotImplementedError: `pip_requirements` is not supported.
-
         Returns:
             ModelVersion: ModelVersion object corresponding to the model just logged.
         """
@@ -157,10 +241,13 @@ class Registry:
             subproject=_MODEL_TELEMETRY_SUBPROJECT,
         )
         if pip_requirements:
-
-            "
+            warnings.warn(
+                "Models logged specifying `pip_requirements` can not be executed "
                 "in Snowflake Warehouse where all dependencies are required to be retrieved "
-                "from Snowflake Anaconda Channel."
+                "from Snowflake Anaconda Channel. Specify model save option `include_pip_dependencies`"
+                "to log model with pip dependencies.",
+                category=UserWarning,
+                stacklevel=1,
             )
         return self._model_manager.log_model(
             model=model,
@@ -169,7 +256,7 @@ class Registry:
             comment=comment,
             metrics=metrics,
             conda_dependencies=conda_dependencies,
-            pip_requirements=
+            pip_requirements=pip_requirements,
             python_version=python_version,
             signatures=signatures,
             sample_input_data=sample_input_data,
```
snowflake/ml/utils/sql_client.py ADDED

```diff
@@ -0,0 +1,22 @@
+from enum import Enum
+from typing import Dict
+
+
+class CreationOption(Enum):
+    FAIL_IF_NOT_EXIST = 1
+    CREATE_IF_NOT_EXIST = 2
+    OR_REPLACE = 3
+
+
+class CreationMode:
+    def __init__(self, *, if_not_exists: bool = False, or_replace: bool = False) -> None:
+        self.if_not_exists = if_not_exists
+        self.or_replace = or_replace
+
+    def get_ddl_phrases(self) -> Dict[CreationOption, str]:
+        if_not_exists_sql = " IF NOT EXISTS" if self.if_not_exists else ""
+        or_replace_sql = " OR REPLACE" if self.or_replace else ""
+        return {
+            CreationOption.CREATE_IF_NOT_EXIST: if_not_exists_sql,
+            CreationOption.OR_REPLACE: or_replace_sql,
+        }
```
snowflake/ml/version.py CHANGED

```diff
@@ -1 +1 @@
-VERSION="1.5.4"
+VERSION="1.6.1"
```
{snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA CHANGED

```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: snowflake-ml-python
-Version: 1.5.4
+Version: 1.6.1
 Summary: The machine learning client library that is used for interacting with Snowflake to build machine learning solutions.
 Author-email: "Snowflake, Inc" <support@snowflake.com>
 License:
@@ -373,7 +373,86 @@ be compatibility issues. Server-side functionality that `snowflake-ml-python` de
 
 # Release History
 
-## 1.5.4 (2024-07-11)
+## 1.6.1 (TBD)
+
+### Bug Fixes
+
+- Feature Store: Support large metadata blob when generating dataset
+- Feature Store: Added a hidden knob in FeatureView as kargs for setting customized
+  refresh_mode
+- Registry: Fix an error message in Model Version `run` when `function_name` is not mentioned and model has multiple
+  target methods.
+- Cortex inference: snowflake.cortex.Complete now only uses the REST API for streaming and the use_rest_api_experimental
+  is no longer needed.
+- Feature Store: Add a new API: FeatureView.list_columns() which list all column information.
+- Data: Fix `DataFrame` ingestion with `ArrowIngestor`.
+
+### New Features
+
+- Enable `set_params` to set the parameters of the underlying sklearn estimator, if the snowflake-ml model has been fit.
+- Data: Add top-level exports for `DataConnector` and `DataSource` to `snowflake.ml.data`.
+- Data: Add `snowflake.ml.data.ingestor_utils` module with utility functions helpful for `DataIngestor` implementations.
+- Data: Add new `to_torch_dataset()` connector to `DataConnector` to replace deprecated DataPipe.
+- Registry: Option to `enable_explainability` set to True by default for XGBoost, LightGBM and CatBoost as PuPr feature.
+- Registry: Option to `enable_explainability` when registering SHAP supported sklearn models.
+
+### Behavior Changes
+
+## 1.6.0 (2024-07-29)
+
+### Bug Fixes
+
+- Modeling: `SimpleImputer` can impute integer columns with integer values.
+- Registry: Fix an issue when providing a pandas Dataframe whose index is not starting from 0 as the input to
+  the `ModelVersion.run`.
+
+### New Features
+
+- Feature Store: Add overloads to APIs accept both object and name/version. Impacted APIs include read_feature_view(),
+  refresh_feature_view(), get_refresh_history(), resume_feature_view(), suspend_feature_view(), delete_feature_view().
+- Feature Store: Add docstring inline examples for all public APIs.
+- Feature Store: Add new utility class `ExampleHelper` to help with load source data to simplify public notebooks.
+- Registry: Option to `enable_explainability` when registering XGBoost models as a pre-PuPr feature.
+- Feature Store: add new API `update_entity()`.
+- Registry: Option to `enable_explainability` when registering Catboost models as a pre-PuPr feature.
+- Feature Store: Add new argument warehouse to FeatureView constructor to overwrite the default warehouse. Also add
+  a new column 'warehouse' to the output of list_feature_views().
+- Registry: Add support for logging model from a model version.
+- Modeling: Distributed Hyperparameter Optimization now announce GA refresh version. The latest memory efficient version
+  will not have the 10GB training limitation for dataset any more. To turn off, please run
+  `
+  from snowflake.ml.modeling._internal.snowpark_implementations import (
+      distributed_hpo_trainer,
+  )
+  distributed_hpo_trainer.ENABLE_EFFICIENT_MEMORY_USAGE = False
+  `
+- Registry: Option to `enable_explainability` when registering LightGBM models as a pre-PuPr feature.
+- Data: Add new `snowflake.ml.data` preview module which contains data reading utilities like `DataConnector`
+  - `DataConnector` provides efficient connectors from Snowpark `DataFrame`
+    and Snowpark ML `Dataset` to external frameworks like PyTorch, TensorFlow, and Pandas. Create `DataConnector`
+    instances using the classmethod constructors `DataConnector.from_dataset()` and `DataConnector.from_dataframe()`.
+- Data: Add new `DataConnector.from_sources()` classmethod constructor for constructing from `DataSource` objects.
+- Data: Add new `ingestor_class` arg to `DataConnector` classmethod constructors for easier `DataIngestor` injection.
+- Dataset: `DatasetReader` now subclasses new `DataConnector` class.
+  - Add optional `limit` arg to `DatasetReader.to_pandas()`
+
+### Behavior Changes
+
+- Feature Store: change some positional parameters to keyword arguments in following APIs:
+  - Entity(): desc.
+  - FeatureView(): timestamp_col, refresh_freq, desc.
+  - FeatureStore(): creation_mode.
+  - update_entity(): desc.
+  - register_feature_view(): block, overwrite.
+  - list_feature_views(): entity_name, feature_view_name.
+  - get_refresh_history(): verbose.
+  - retrieve_feature_values(): spine_timestamp_col, exclude_columns, include_feature_view_timestamp_col.
+  - generate_training_set(): save_as, spine_timestamp_col, spine_label_cols, exclude_columns,
+    include_feature_view_timestamp_col.
+  - generate_dataset(): version, spine_timestamp_col, spine_label_cols, exclude_columns,
+    include_feature_view_timestamp_col, desc, output_type.
+
+## 1.5.4 (2024-07-11)
 
 ### Bug Fixes
 
```
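Several of the changelog entries above describe the new `snowflake.ml.data` module. A sketch of the `DataConnector` surface they reference (table name is a placeholder; an active Snowpark `session` and an installed torch are assumed, and the top-level import path follows the 1.6.1 note):

```python
from snowflake.ml.data import DataConnector

dc = DataConnector.from_dataframe(session.table("TRAINING_DATA"))

pdf = dc.to_pandas()              # pandas DataFrame
torch_ds = dc.to_torch_dataset()  # new in 1.6.1, replaces the deprecated DataPipe path
```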