snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/_complete.py +7 -33
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/telemetry.py +14 -0
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
- snowflake/ml/data/data_connector.py +59 -6
- snowflake/ml/data/data_ingestor.py +18 -1
- snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
- snowflake/ml/data/torch_dataset.py +33 -0
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +9 -3
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
- snowflake/ml/feature_store/examples/example_helper.py +69 -31
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
- snowflake/ml/feature_store/feature_store.py +59 -24
- snowflake/ml/feature_store/feature_view.py +148 -4
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +171 -20
- snowflake/ml/model/_client/ops/model_ops.py +105 -27
- snowflake/ml/model/_client/ops/service_ops.py +121 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +95 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +31 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +129 -0
- snowflake/ml/model/_model_composer/model_composer.py +3 -0
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +10 -2
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
- snowflake/ml/model/_packager/model_handlers/catboost.py +19 -12
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +9 -5
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +27 -18
- snowflake/ml/model/_packager/model_handlers/llm.py +7 -3
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +87 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +8 -3
- snowflake/ml/model/_packager/model_handlers/xgboost.py +25 -16
- snowflake/ml/model/_packager/model_meta/model_meta.py +32 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +19 -0
- snowflake/ml/model/_packager/model_packager.py +2 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
- snowflake/ml/model/type_hints.py +1 -3
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/pipeline/pipeline.py +3 -0
- snowflake/ml/registry/_manager/model_manager.py +16 -2
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA +35 -2
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/RECORD +73 -62
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.1.dist-info}/top_level.txt +0 -0
@@ -45,7 +45,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
45
45
|
_MIN_SNOWPARK_ML_VERSION = "1.0.12"
|
46
46
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
47
47
|
|
48
|
-
|
48
|
+
MODEL_BLOB_FILE_OR_DIR = "model.ubj"
|
49
49
|
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
50
50
|
_BINARY_CLASSIFICATION_OBJECTIVE_PREFIX = ["binary:"]
|
51
51
|
_MULTI_CLASSIFICATION_OBJECTIVE_PREFIX = ["multi:"]
|
@@ -53,33 +53,35 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
53
53
|
_REGRESSION_OBJECTIVE_PREFIX = ["reg:"]
|
54
54
|
|
55
55
|
@classmethod
|
56
|
-
def get_model_objective(
|
56
|
+
def get_model_objective(
|
57
|
+
cls, model: Union["xgboost.Booster", "xgboost.XGBModel"]
|
58
|
+
) -> model_meta_schema.ModelObjective:
|
57
59
|
import xgboost
|
58
60
|
|
59
61
|
if isinstance(model, xgboost.XGBClassifier) or isinstance(model, xgboost.XGBRFClassifier):
|
60
62
|
num_classes = handlers_utils.get_num_classes_if_exists(model)
|
61
63
|
if num_classes == 2:
|
62
|
-
return
|
63
|
-
return
|
64
|
+
return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
|
65
|
+
return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
|
64
66
|
if isinstance(model, xgboost.XGBRegressor) or isinstance(model, xgboost.XGBRFRegressor):
|
65
|
-
return
|
67
|
+
return model_meta_schema.ModelObjective.REGRESSION
|
66
68
|
if isinstance(model, xgboost.XGBRanker):
|
67
|
-
return
|
69
|
+
return model_meta_schema.ModelObjective.RANKING
|
68
70
|
model_params = json.loads(model.save_config())
|
69
71
|
model_objective = model_params["learner"]["objective"]
|
70
72
|
for classification_objective in cls._BINARY_CLASSIFICATION_OBJECTIVE_PREFIX:
|
71
73
|
if classification_objective in model_objective:
|
72
|
-
return
|
74
|
+
return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
|
73
75
|
for classification_objective in cls._MULTI_CLASSIFICATION_OBJECTIVE_PREFIX:
|
74
76
|
if classification_objective in model_objective:
|
75
|
-
return
|
77
|
+
return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
|
76
78
|
for ranking_objective in cls._RANKING_OBJECTIVE_PREFIX:
|
77
79
|
if ranking_objective in model_objective:
|
78
|
-
return
|
80
|
+
return model_meta_schema.ModelObjective.RANKING
|
79
81
|
for regression_objective in cls._REGRESSION_OBJECTIVE_PREFIX:
|
80
82
|
if regression_objective in model_objective:
|
81
|
-
return
|
82
|
-
return
|
83
|
+
return model_meta_schema.ModelObjective.REGRESSION
|
84
|
+
return model_meta_schema.ModelObjective.UNKNOWN
|
83
85
|
|
84
86
|
@classmethod
|
85
87
|
def can_handle(
|
@@ -146,9 +148,11 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
146
148
|
sample_input_data=sample_input_data,
|
147
149
|
get_prediction_fn=get_prediction,
|
148
150
|
)
|
149
|
-
|
151
|
+
model_objective = cls.get_model_objective(model)
|
152
|
+
model_meta.model_objective = model_objective
|
153
|
+
if kwargs.get("enable_explainability", True):
|
150
154
|
output_type = model_signature.DataType.DOUBLE
|
151
|
-
if
|
155
|
+
if model_objective == model_meta_schema.ModelObjective.MULTI_CLASSIFICATION:
|
152
156
|
output_type = model_signature.DataType.STRING
|
153
157
|
model_meta = handlers_utils.add_explain_method_signature(
|
154
158
|
model_meta=model_meta,
|
@@ -156,15 +160,18 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
156
160
|
target_method="predict",
|
157
161
|
output_return_type=output_type,
|
158
162
|
)
|
163
|
+
model_meta.function_properties = {
|
164
|
+
"explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
|
165
|
+
}
|
159
166
|
|
160
167
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
161
168
|
os.makedirs(model_blob_path, exist_ok=True)
|
162
|
-
model.save_model(os.path.join(model_blob_path, cls.
|
169
|
+
model.save_model(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
|
163
170
|
base_meta = model_blob_meta.ModelBlobMeta(
|
164
171
|
name=name,
|
165
172
|
model_type=cls.HANDLER_TYPE,
|
166
173
|
handler_version=cls.HANDLER_VERSION,
|
167
|
-
path=cls.
|
174
|
+
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
168
175
|
options=model_meta_schema.XgboostModelBlobOptions({"xgb_estimator_type": model.__class__.__name__}),
|
169
176
|
)
|
170
177
|
model_meta.models[name] = base_meta
|
@@ -177,11 +184,12 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
177
184
|
],
|
178
185
|
check_local_version=True,
|
179
186
|
)
|
180
|
-
if kwargs.get("enable_explainability",
|
187
|
+
if kwargs.get("enable_explainability", True):
|
181
188
|
model_meta.env.include_if_absent(
|
182
189
|
[model_env.ModelDependency(requirement="shap", pip_name="shap")],
|
183
190
|
check_local_version=True,
|
184
191
|
)
|
192
|
+
model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
|
185
193
|
model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
|
186
194
|
|
187
195
|
@classmethod
|
@@ -224,6 +232,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
224
232
|
cls,
|
225
233
|
raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
|
226
234
|
model_meta: model_meta_api.ModelMetadata,
|
235
|
+
background_data: Optional[pd.DataFrame] = None,
|
227
236
|
**kwargs: Unpack[model_types.XGBModelLoadOptions],
|
228
237
|
) -> custom_model.CustomModel:
|
229
238
|
import xgboost
|
@@ -237,6 +237,7 @@ class ModelMetadata:
|
|
237
237
|
function_properties: A dict mapping function names to dict mapping function property key to value.
|
238
238
|
metadata: User provided key-value metadata of the model. Defaults to None.
|
239
239
|
creation_timestamp: Unix timestamp when the model metadata is created.
|
240
|
+
model_objective: Model objective like regression, classification etc.
|
240
241
|
"""
|
241
242
|
|
242
243
|
def telemetry_metadata(self) -> ModelMetadataTelemetryDict:
|
@@ -260,6 +261,8 @@ class ModelMetadata:
|
|
260
261
|
min_snowpark_ml_version: Optional[str] = None,
|
261
262
|
models: Optional[Dict[str, model_blob_meta.ModelBlobMeta]] = None,
|
262
263
|
original_metadata_version: Optional[str] = model_meta_schema.MODEL_METADATA_VERSION,
|
264
|
+
model_objective: Optional[model_meta_schema.ModelObjective] = model_meta_schema.ModelObjective.UNKNOWN,
|
265
|
+
explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = None,
|
263
266
|
) -> None:
|
264
267
|
self.name = name
|
265
268
|
self.signatures: Dict[str, model_signature.ModelSignature] = dict()
|
@@ -284,6 +287,11 @@ class ModelMetadata:
|
|
284
287
|
|
285
288
|
self.original_metadata_version = original_metadata_version
|
286
289
|
|
290
|
+
self.model_objective: model_meta_schema.ModelObjective = (
|
291
|
+
model_objective or model_meta_schema.ModelObjective.UNKNOWN
|
292
|
+
)
|
293
|
+
self.explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = explain_algorithm
|
294
|
+
|
287
295
|
@property
|
288
296
|
def min_snowpark_ml_version(self) -> str:
|
289
297
|
return self._min_snowpark_ml_version.base_version
|
@@ -321,9 +329,11 @@ class ModelMetadata:
|
|
321
329
|
model_dict = model_meta_schema.ModelMetadataDict(
|
322
330
|
{
|
323
331
|
"creation_timestamp": self.creation_timestamp,
|
324
|
-
"env": self.env.save_as_dict(
|
332
|
+
"env": self.env.save_as_dict(
|
333
|
+
pathlib.Path(model_dir_path), default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
|
334
|
+
),
|
325
335
|
"runtimes": {
|
326
|
-
runtime_name: runtime.save(pathlib.Path(model_dir_path))
|
336
|
+
runtime_name: runtime.save(pathlib.Path(model_dir_path), default_channel_override="conda-forge")
|
327
337
|
for runtime_name, runtime in self.runtimes.items()
|
328
338
|
},
|
329
339
|
"metadata": self.metadata,
|
@@ -333,6 +343,13 @@ class ModelMetadata:
|
|
333
343
|
"signatures": {func_name: sig.to_dict() for func_name, sig in self.signatures.items()},
|
334
344
|
"version": model_meta_schema.MODEL_METADATA_VERSION,
|
335
345
|
"min_snowpark_ml_version": self.min_snowpark_ml_version,
|
346
|
+
"model_objective": self.model_objective.value,
|
347
|
+
"explainability": (
|
348
|
+
model_meta_schema.ExplainabilityMetadataDict(algorithm=self.explain_algorithm.value)
|
349
|
+
if self.explain_algorithm
|
350
|
+
else None
|
351
|
+
),
|
352
|
+
"function_properties": self.function_properties,
|
336
353
|
}
|
337
354
|
)
|
338
355
|
|
@@ -370,6 +387,9 @@ class ModelMetadata:
|
|
370
387
|
signatures=loaded_meta["signatures"],
|
371
388
|
version=original_loaded_meta_version,
|
372
389
|
min_snowpark_ml_version=loaded_meta_min_snowpark_ml_version,
|
390
|
+
model_objective=loaded_meta.get("model_objective", model_meta_schema.ModelObjective.UNKNOWN.value),
|
391
|
+
explainability=loaded_meta.get("explainability", None),
|
392
|
+
function_properties=loaded_meta.get("function_properties", {}),
|
373
393
|
)
|
374
394
|
|
375
395
|
@classmethod
|
@@ -406,6 +426,11 @@ class ModelMetadata:
|
|
406
426
|
else:
|
407
427
|
runtimes = None
|
408
428
|
|
429
|
+
explanation_algorithm_dict = model_dict.get("explainability", None)
|
430
|
+
explanation_algorithm = None
|
431
|
+
if explanation_algorithm_dict:
|
432
|
+
explanation_algorithm = model_meta_schema.ModelExplainAlgorithm(explanation_algorithm_dict["algorithm"])
|
433
|
+
|
409
434
|
return cls(
|
410
435
|
name=model_dict["name"],
|
411
436
|
model_type=model_dict["model_type"],
|
@@ -417,4 +442,9 @@ class ModelMetadata:
|
|
417
442
|
min_snowpark_ml_version=model_dict["min_snowpark_ml_version"],
|
418
443
|
models=models,
|
419
444
|
original_metadata_version=model_dict["version"],
|
445
|
+
model_objective=model_meta_schema.ModelObjective(
|
446
|
+
model_dict.get("model_objective", model_meta_schema.ModelObjective.UNKNOWN.value)
|
447
|
+
),
|
448
|
+
explain_algorithm=explanation_algorithm,
|
449
|
+
function_properties=model_dict.get("function_properties", {}),
|
420
450
|
)
|
@@ -71,6 +71,10 @@ ModelBlobOptions = Union[
|
|
71
71
|
]
|
72
72
|
|
73
73
|
|
74
|
+
class ExplainabilityMetadataDict(TypedDict):
|
75
|
+
algorithm: Required[str]
|
76
|
+
|
77
|
+
|
74
78
|
class ModelBlobMetadataDict(TypedDict):
|
75
79
|
name: Required[str]
|
76
80
|
model_type: Required[type_hints.SupportedModelHandlerType]
|
@@ -92,3 +96,18 @@ class ModelMetadataDict(TypedDict):
|
|
92
96
|
signatures: Required[Dict[str, Dict[str, Any]]]
|
93
97
|
version: Required[str]
|
94
98
|
min_snowpark_ml_version: Required[str]
|
99
|
+
model_objective: Required[str]
|
100
|
+
explainability: NotRequired[Optional[ExplainabilityMetadataDict]]
|
101
|
+
function_properties: NotRequired[Dict[str, Dict[str, Any]]]
|
102
|
+
|
103
|
+
|
104
|
+
class ModelObjective(Enum):
|
105
|
+
UNKNOWN = "unknown"
|
106
|
+
BINARY_CLASSIFICATION = "binary_classification"
|
107
|
+
MULTI_CLASSIFICATION = "multi_classification"
|
108
|
+
REGRESSION = "regression"
|
109
|
+
RANKING = "ranking"
|
110
|
+
|
111
|
+
|
112
|
+
class ModelExplainAlgorithm(Enum):
|
113
|
+
SHAP = "shap"
|
@@ -146,7 +146,8 @@ class ModelPackager:
|
|
146
146
|
m = handler.load_model(self.meta.name, self.meta, model_blobs_path, **options)
|
147
147
|
|
148
148
|
if as_custom_model:
|
149
|
-
|
149
|
+
background_data = handler.load_background_data(self.meta.name, model_blobs_path)
|
150
|
+
m = handler.convert_as_custom_model(m, self.meta, background_data, **options)
|
150
151
|
assert isinstance(m, custom_model.CustomModel)
|
151
152
|
|
152
153
|
self.model = m
|
@@ -67,7 +67,9 @@ class ModelRuntime:
|
|
67
67
|
def runtime_rel_path(self) -> pathlib.PurePosixPath:
|
68
68
|
return pathlib.PurePosixPath(ModelRuntime.RUNTIME_DIR_REL_PATH) / self.name
|
69
69
|
|
70
|
-
def save(
|
70
|
+
def save(
|
71
|
+
self, packager_path: pathlib.Path, default_channel_override: str = env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
|
72
|
+
) -> model_meta_schema.ModelRuntimeDict:
|
71
73
|
runtime_base_path = packager_path / self.runtime_rel_path
|
72
74
|
runtime_base_path.mkdir(parents=True, exist_ok=True)
|
73
75
|
|
@@ -80,7 +82,7 @@ class ModelRuntime:
|
|
80
82
|
self.runtime_env.conda_env_rel_path = self.runtime_rel_path / self.runtime_env.conda_env_rel_path
|
81
83
|
self.runtime_env.pip_requirements_rel_path = self.runtime_rel_path / self.runtime_env.pip_requirements_rel_path
|
82
84
|
|
83
|
-
env_dict = self.runtime_env.save_as_dict(packager_path)
|
85
|
+
env_dict = self.runtime_env.save_as_dict(packager_path, default_channel_override=default_channel_override)
|
84
86
|
|
85
87
|
return model_meta_schema.ModelRuntimeDict(
|
86
88
|
imports=list(map(str, self.imports)),
|
snowflake/ml/model/type_hints.py
CHANGED
@@ -233,12 +233,12 @@ class BaseModelSaveOption(TypedDict):
|
|
233
233
|
function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
|
234
234
|
method_options: NotRequired[Dict[str, ModelMethodSaveOptions]]
|
235
235
|
include_pip_dependencies: NotRequired[bool]
|
236
|
+
enable_explainability: NotRequired[bool]
|
236
237
|
|
237
238
|
|
238
239
|
class CatBoostModelSaveOptions(BaseModelSaveOption):
|
239
240
|
target_methods: NotRequired[Sequence[str]]
|
240
241
|
cuda_version: NotRequired[str]
|
241
|
-
enable_explainability: NotRequired[bool]
|
242
242
|
|
243
243
|
|
244
244
|
class CustomModelSaveOption(BaseModelSaveOption):
|
@@ -252,12 +252,10 @@ class SKLModelSaveOptions(BaseModelSaveOption):
|
|
252
252
|
class XGBModelSaveOptions(BaseModelSaveOption):
|
253
253
|
target_methods: NotRequired[Sequence[str]]
|
254
254
|
cuda_version: NotRequired[str]
|
255
|
-
enable_explainability: NotRequired[bool]
|
256
255
|
|
257
256
|
|
258
257
|
class LGBMModelSaveOptions(BaseModelSaveOption):
|
259
258
|
target_methods: NotRequired[Sequence[str]]
|
260
|
-
enable_explainability: NotRequired[bool]
|
261
259
|
|
262
260
|
|
263
261
|
class SNOWModelSaveOptions(BaseModelSaveOption):
|
@@ -1,7 +1,6 @@
|
|
1
1
|
#!/usr/bin/env python3
|
2
2
|
import inspect
|
3
3
|
from abc import abstractmethod
|
4
|
-
from collections import defaultdict
|
5
4
|
from datetime import datetime
|
6
5
|
from typing import Any, Dict, Iterable, List, Mapping, Optional, Union, overload
|
7
6
|
|
@@ -18,6 +17,7 @@ from snowflake.ml._internal.exceptions import (
|
|
18
17
|
)
|
19
18
|
from snowflake.ml._internal.lineage import lineage_utils
|
20
19
|
from snowflake.ml._internal.utils import identifier, parallelize
|
20
|
+
from snowflake.ml.data import data_source
|
21
21
|
from snowflake.ml.modeling.framework import _utils
|
22
22
|
from snowflake.snowpark import functions as F
|
23
23
|
|
@@ -246,7 +246,7 @@ class Base:
|
|
246
246
|
|
247
247
|
def get_params(self, deep: bool = True) -> Dict[str, Any]:
|
248
248
|
"""
|
249
|
-
Get parameters for this transformer.
|
249
|
+
Get the snowflake-ml parameters for this transformer.
|
250
250
|
|
251
251
|
Args:
|
252
252
|
deep: If True, will return the parameters for this transformer and
|
@@ -265,13 +265,13 @@ class Base:
|
|
265
265
|
out[key] = value
|
266
266
|
return out
|
267
267
|
|
268
|
-
def set_params(self, **params:
|
268
|
+
def set_params(self, **params: Any) -> None:
|
269
269
|
"""
|
270
270
|
Set the parameters of this transformer.
|
271
271
|
|
272
|
-
The method works on simple transformers as well as on nested
|
273
|
-
|
274
|
-
so that it's possible to update each component of a nested object.
|
272
|
+
The method works on simple transformers as well as on sklearn compatible pipelines with nested
|
273
|
+
objects, once the transformer has been fit. Nested objects have parameters of the form
|
274
|
+
``<component>__<parameter>`` so that it's possible to update each component of a nested object.
|
275
275
|
|
276
276
|
Args:
|
277
277
|
**params: Transformer parameter names mapped to their values.
|
@@ -283,12 +283,28 @@ class Base:
|
|
283
283
|
# simple optimization to gain speed (inspect is slow)
|
284
284
|
return
|
285
285
|
valid_params = self.get_params(deep=True)
|
286
|
+
valid_skl_params = {}
|
287
|
+
if hasattr(self, "_sklearn_object") and self._sklearn_object is not None:
|
288
|
+
valid_skl_params = self._sklearn_object.get_params()
|
286
289
|
|
287
|
-
nested_params: Dict[str, Any] = defaultdict(dict) # grouped by prefix
|
288
290
|
for key, value in params.items():
|
289
|
-
|
290
|
-
|
291
|
-
|
291
|
+
if valid_params.get("steps"):
|
292
|
+
# Recurse through pipeline steps
|
293
|
+
key, _, sub_key = key.partition("__")
|
294
|
+
for name, nested_object in valid_params["steps"]:
|
295
|
+
if name == key:
|
296
|
+
nested_object.set_params(**{sub_key: value})
|
297
|
+
|
298
|
+
elif key in valid_params:
|
299
|
+
setattr(self, key, value)
|
300
|
+
valid_params[key] = value
|
301
|
+
elif key in valid_skl_params:
|
302
|
+
# This dictionary would be empty if the following assert were not true, as specified above.
|
303
|
+
assert hasattr(self, "_sklearn_object") and self._sklearn_object is not None
|
304
|
+
setattr(self._sklearn_object, key, value)
|
305
|
+
valid_skl_params[key] = value
|
306
|
+
else:
|
307
|
+
local_valid_params = self._get_param_names() + list(valid_skl_params.keys())
|
292
308
|
raise exceptions.SnowflakeMLException(
|
293
309
|
error_code=error_codes.INVALID_ARGUMENT,
|
294
310
|
original_exception=ValueError(
|
@@ -298,15 +314,6 @@ class Base:
|
|
298
314
|
),
|
299
315
|
)
|
300
316
|
|
301
|
-
if delim:
|
302
|
-
nested_params[key][sub_key] = value
|
303
|
-
else:
|
304
|
-
setattr(self, key, value)
|
305
|
-
valid_params[key] = value
|
306
|
-
|
307
|
-
for key, sub_params in nested_params.items():
|
308
|
-
valid_params[key].set_params(**sub_params)
|
309
|
-
|
310
317
|
def get_sklearn_args(
|
311
318
|
self,
|
312
319
|
default_sklearn_obj: Optional[object] = None,
|
@@ -427,6 +434,8 @@ class BaseEstimator(Base):
|
|
427
434
|
def fit(self, dataset: Union[snowpark.DataFrame, pd.DataFrame]) -> "BaseEstimator":
|
428
435
|
"""Runs universal logics for all fit implementations."""
|
429
436
|
data_sources = lineage_utils.get_data_sources(dataset)
|
437
|
+
if not data_sources and isinstance(dataset, snowpark.DataFrame):
|
438
|
+
data_sources = [data_source.DataFrameInfo(dataset.queries["queries"][-1])]
|
430
439
|
lineage_utils.set_data_sources(self, data_sources)
|
431
440
|
return self._fit(dataset)
|
432
441
|
|
@@ -19,6 +19,7 @@ from snowflake.ml._internal import file_utils, telemetry
|
|
19
19
|
from snowflake.ml._internal.exceptions import error_codes, exceptions
|
20
20
|
from snowflake.ml._internal.lineage import lineage_utils
|
21
21
|
from snowflake.ml._internal.utils import snowpark_dataframe_utils, temp_file_utils
|
22
|
+
from snowflake.ml.data import data_source
|
22
23
|
from snowflake.ml.model.model_signature import ModelSignature, _infer_signature
|
23
24
|
from snowflake.ml.modeling._internal.model_transformer_builder import (
|
24
25
|
ModelTransformerBuilder,
|
@@ -431,6 +432,8 @@ class Pipeline(base.BaseTransformer):
|
|
431
432
|
|
432
433
|
# Extract lineage information here since we're overriding fit() directly
|
433
434
|
data_sources = lineage_utils.get_data_sources(dataset)
|
435
|
+
if not data_sources and isinstance(dataset, snowpark.DataFrame):
|
436
|
+
data_sources = [data_source.DataFrameInfo(dataset.queries["queries"][-1])]
|
434
437
|
lineage_utils.set_data_sources(self, data_sources)
|
435
438
|
|
436
439
|
if self._can_be_trained_in_ml_runtime(dataset):
|
@@ -9,7 +9,7 @@ from snowflake.ml._internal.human_readable_id import hrid_generator
|
|
9
9
|
from snowflake.ml._internal.utils import sql_identifier
|
10
10
|
from snowflake.ml.model import model_signature, type_hints as model_types
|
11
11
|
from snowflake.ml.model._client.model import model_impl, model_version_impl
|
12
|
-
from snowflake.ml.model._client.ops import metadata_ops, model_ops
|
12
|
+
from snowflake.ml.model._client.ops import metadata_ops, model_ops, service_ops
|
13
13
|
from snowflake.ml.model._model_composer import model_composer
|
14
14
|
from snowflake.ml.model._packager.model_meta import model_meta
|
15
15
|
from snowflake.snowpark import session
|
@@ -30,6 +30,9 @@ class ModelManager:
|
|
30
30
|
self._model_ops = model_ops.ModelOperator(
|
31
31
|
session, database_name=self._database_name, schema_name=self._schema_name
|
32
32
|
)
|
33
|
+
self._service_ops = service_ops.ServiceOperator(
|
34
|
+
session, database_name=self._database_name, schema_name=self._schema_name
|
35
|
+
)
|
33
36
|
self._hrid_generator = hrid_generator.HRID16()
|
34
37
|
|
35
38
|
def log_model(
|
@@ -173,11 +176,16 @@ class ModelManager:
|
|
173
176
|
)
|
174
177
|
|
175
178
|
mv = model_version_impl.ModelVersion._ref(
|
176
|
-
model_ops.ModelOperator(
|
179
|
+
model_ops=model_ops.ModelOperator(
|
177
180
|
self._model_ops._session,
|
178
181
|
database_name=database_name_id or self._database_name,
|
179
182
|
schema_name=schema_name_id or self._schema_name,
|
180
183
|
),
|
184
|
+
service_ops=service_ops.ServiceOperator(
|
185
|
+
self._service_ops._session,
|
186
|
+
database_name=database_name_id or self._database_name,
|
187
|
+
schema_name=schema_name_id or self._schema_name,
|
188
|
+
),
|
181
189
|
model_name=model_name_id,
|
182
190
|
version_name=version_name_id,
|
183
191
|
)
|
@@ -216,6 +224,11 @@ class ModelManager:
|
|
216
224
|
database_name=database_name_id or self._database_name,
|
217
225
|
schema_name=schema_name_id or self._schema_name,
|
218
226
|
),
|
227
|
+
service_ops=service_ops.ServiceOperator(
|
228
|
+
self._service_ops._session,
|
229
|
+
database_name=database_name_id or self._database_name,
|
230
|
+
schema_name=schema_name_id or self._schema_name,
|
231
|
+
),
|
219
232
|
model_name=model_name_id,
|
220
233
|
)
|
221
234
|
else:
|
@@ -234,6 +247,7 @@ class ModelManager:
|
|
234
247
|
return [
|
235
248
|
model_impl.Model._ref(
|
236
249
|
self._model_ops,
|
250
|
+
service_ops=self._service_ops,
|
237
251
|
model_name=model_name,
|
238
252
|
)
|
239
253
|
for model_name in model_names
|
@@ -0,0 +1,22 @@
|
|
1
|
+
from enum import Enum
|
2
|
+
from typing import Dict
|
3
|
+
|
4
|
+
|
5
|
+
class CreationOption(Enum):
|
6
|
+
FAIL_IF_NOT_EXIST = 1
|
7
|
+
CREATE_IF_NOT_EXIST = 2
|
8
|
+
OR_REPLACE = 3
|
9
|
+
|
10
|
+
|
11
|
+
class CreationMode:
|
12
|
+
def __init__(self, *, if_not_exists: bool = False, or_replace: bool = False) -> None:
|
13
|
+
self.if_not_exists = if_not_exists
|
14
|
+
self.or_replace = or_replace
|
15
|
+
|
16
|
+
def get_ddl_phrases(self) -> Dict[CreationOption, str]:
|
17
|
+
if_not_exists_sql = " IF NOT EXISTS" if self.if_not_exists else ""
|
18
|
+
or_replace_sql = " OR REPLACE" if self.or_replace else ""
|
19
|
+
return {
|
20
|
+
CreationOption.CREATE_IF_NOT_EXIST: if_not_exists_sql,
|
21
|
+
CreationOption.OR_REPLACE: or_replace_sql,
|
22
|
+
}
|
snowflake/ml/version.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
VERSION="1.6.0"
|
1
|
+
VERSION="1.6.1"
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: snowflake-ml-python
|
3
|
-
Version: 1.6.0
|
3
|
+
Version: 1.6.1
|
4
4
|
Summary: The machine learning client library that is used for interacting with Snowflake to build machine learning solutions.
|
5
5
|
Author-email: "Snowflake, Inc" <support@snowflake.com>
|
6
6
|
License:
|
@@ -373,7 +373,32 @@ be compatibility issues. Server-side functionality that `snowflake-ml-python` de
|
|
373
373
|
|
374
374
|
# Release History
|
375
375
|
|
376
|
-
## 1.6.
|
376
|
+
## 1.6.1 (TBD)
|
377
|
+
|
378
|
+
### Bug Fixes
|
379
|
+
|
380
|
+
- Feature Store: Support large metadata blob when generating dataset
|
381
|
+
- Feature Store: Added a hidden knob in FeatureView as kwargs for setting customized
|
382
|
+
refresh_mode
|
383
|
+
- Registry: Fix an error message in Model Version `run` when `function_name` is not mentioned and model has multiple
|
384
|
+
target methods.
|
385
|
+
- Cortex inference: snowflake.cortex.Complete now only uses the REST API for streaming and the use_rest_api_experimental
|
386
|
+
is no longer needed.
|
387
|
+
- Feature Store: Add a new API: FeatureView.list_columns() which lists all column information.
|
388
|
+
- Data: Fix `DataFrame` ingestion with `ArrowIngestor`.
|
389
|
+
|
390
|
+
### New Features
|
391
|
+
|
392
|
+
- Enable `set_params` to set the parameters of the underlying sklearn estimator, if the snowflake-ml model has been fit.
|
393
|
+
- Data: Add top-level exports for `DataConnector` and `DataSource` to `snowflake.ml.data`.
|
394
|
+
- Data: Add `snowflake.ml.data.ingestor_utils` module with utility functions helpful for `DataIngestor` implementations.
|
395
|
+
- Data: Add new `to_torch_dataset()` connector to `DataConnector` to replace deprecated DataPipe.
|
396
|
+
- Registry: Option to `enable_explainability` set to True by default for XGBoost, LightGBM and CatBoost as PuPr feature.
|
397
|
+
- Registry: Option to `enable_explainability` when registering SHAP supported sklearn models.
|
398
|
+
|
399
|
+
### Behavior Changes
|
400
|
+
|
401
|
+
## 1.6.0 (2024-07-29)
|
377
402
|
|
378
403
|
### Bug Fixes
|
379
404
|
|
@@ -402,6 +427,14 @@ be compatibility issues. Server-side functionality that `snowflake-ml-python` de
|
|
402
427
|
distributed_hpo_trainer.ENABLE_EFFICIENT_MEMORY_USAGE = False
|
403
428
|
`
|
404
429
|
- Registry: Option to `enable_explainability` when registering LightGBM models as a pre-PuPr feature.
|
430
|
+
- Data: Add new `snowflake.ml.data` preview module which contains data reading utilities like `DataConnector`
|
431
|
+
- `DataConnector` provides efficient connectors from Snowpark `DataFrame`
|
432
|
+
and Snowpark ML `Dataset` to external frameworks like PyTorch, TensorFlow, and Pandas. Create `DataConnector`
|
433
|
+
instances using the classmethod constructors `DataConnector.from_dataset()` and `DataConnector.from_dataframe()`.
|
434
|
+
- Data: Add new `DataConnector.from_sources()` classmethod constructor for constructing from `DataSource` objects.
|
435
|
+
- Data: Add new `ingestor_class` arg to `DataConnector` classmethod constructors for easier `DataIngestor` injection.
|
436
|
+
- Dataset: `DatasetReader` now subclasses new `DataConnector` class.
|
437
|
+
- Add optional `limit` arg to `DatasetReader.to_pandas()`
|
405
438
|
|
406
439
|
### Behavior Changes
|
407
440
|
|