snowflake-ml-python 1.5.4__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
- snowflake/cortex/__init__.py +2 -0
- snowflake/cortex/_classify_text.py +36 -0
- snowflake/cortex/_complete.py +66 -35
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
- snowflake/ml/_internal/telemetry.py +26 -2
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/data/_internal/arrow_ingestor.py +284 -0
- snowflake/ml/data/data_connector.py +186 -0
- snowflake/ml/data/data_ingestor.py +45 -0
- snowflake/ml/data/data_source.py +23 -0
- snowflake/ml/data/ingestor_utils.py +62 -0
- snowflake/ml/data/torch_dataset.py +33 -0
- snowflake/ml/dataset/dataset.py +1 -13
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +23 -117
- snowflake/ml/feature_store/access_manager.py +7 -1
- snowflake/ml/feature_store/entity.py +19 -2
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +37 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +30 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/example_helper.py +278 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +44 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +9 -0
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
- snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
- snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +36 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +24 -0
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +8 -0
- snowflake/ml/feature_store/feature_store.py +637 -76
- snowflake/ml/feature_store/feature_view.py +316 -9
- snowflake/ml/fileset/stage_fs.py +18 -10
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +171 -20
- snowflake/ml/model/_client/ops/model_ops.py +105 -27
- snowflake/ml/model/_client/ops/service_ops.py +121 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +95 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +31 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +129 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
- snowflake/ml/model/_model_composer/model_composer.py +14 -14
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +33 -17
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -1
- snowflake/ml/model/_model_composer/model_method/function_generator.py +3 -3
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +3 -27
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/model_method.py +5 -2
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +30 -3
- snowflake/ml/model/_packager/model_handlers/_utils.py +58 -1
- snowflake/ml/model/_packager/model_handlers/catboost.py +52 -3
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +9 -5
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +80 -3
- snowflake/ml/model/_packager/model_handlers/llm.py +7 -3
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +87 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +8 -3
- snowflake/ml/model/_packager/model_handlers/xgboost.py +71 -3
- snowflake/ml/model/_packager/model_meta/model_meta.py +32 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +19 -0
- snowflake/ml/model/_packager/model_packager.py +2 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +7 -7
- snowflake/ml/model/model_signature.py +4 -4
- snowflake/ml/model/type_hints.py +2 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +1 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/impute/simple_imputer.py +26 -0
- snowflake/ml/modeling/pipeline/pipeline.py +7 -4
- snowflake/ml/registry/_manager/model_manager.py +16 -2
- snowflake/ml/registry/registry.py +100 -13
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA +81 -2
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/RECORD +99 -66
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/lineage/data_source.py +0 -10
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/top_level.txt +0 -0
snowflake/ml/model/_packager/model_handlers/sklearn.py
CHANGED

```diff
@@ -6,6 +6,7 @@ import numpy as np
 import pandas as pd
 from typing_extensions import TypeGuard, Unpack
 
+import snowflake.snowpark.dataframe as sp_df
 from snowflake.ml._internal import type_utils
 from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
 from snowflake.ml.model._packager.model_env import model_env
@@ -14,8 +15,13 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
 from snowflake.ml.model._packager.model_meta import (
     model_blob_meta,
     model_meta as model_meta_api,
+    model_meta_schema,
+)
+from snowflake.ml.model._signatures import (
+    numpy_handler,
+    snowpark_handler,
+    utils as model_signature_utils,
 )
-from snowflake.ml.model._signatures import numpy_handler, utils as model_signature_utils
 
 if TYPE_CHECKING:
     import sklearn.base
```
```diff
@@ -36,6 +42,27 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
 
     DEFAULT_TARGET_METHODS = ["predict", "transform", "predict_proba", "predict_log_proba", "decision_function"]
 
+    @classmethod
+    def get_model_objective(
+        cls, model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]
+    ) -> model_meta_schema.ModelObjective:
+        import sklearn.pipeline
+        from sklearn.base import is_classifier, is_regressor
+
+        if isinstance(model, sklearn.pipeline.Pipeline):
+            return model_meta_schema.ModelObjective.UNKNOWN
+        if is_regressor(model):
+            return model_meta_schema.ModelObjective.REGRESSION
+        if is_classifier(model):
+            classes_list = getattr(model, "classes_", [])
+            num_classes = getattr(model, "n_classes_", None) or len(classes_list)
+            if isinstance(num_classes, int):
+                if num_classes > 2:
+                    return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
+                return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
+            return model_meta_schema.ModelObjective.UNKNOWN
+        return model_meta_schema.ModelObjective.UNKNOWN
+
     @classmethod
     def can_handle(
         cls,
```
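The new `get_model_objective` duck-types the estimator rather than trusting user input. A minimal standalone sketch of the same decision rule, with plain strings standing in for `model_meta_schema.ModelObjective` so it runs without snowflake-ml-python:

```python
# Sketch of the objective-inference rule above (strings replace the enum).
from sklearn.base import is_classifier, is_regressor
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler


def infer_objective(model) -> str:
    if isinstance(model, Pipeline):
        return "unknown"  # pipelines are opaque, so the handler gives up
    if is_regressor(model):
        return "regression"
    if is_classifier(model):
        # classes_ only exists after fit(); an unfitted classifier stays unknown
        num_classes = getattr(model, "n_classes_", None) or len(getattr(model, "classes_", []))
        if isinstance(num_classes, int) and num_classes > 0:
            return "multi_classification" if num_classes > 2 else "binary_classification"
    return "unknown"


X, y = [[0.0], [1.0], [2.0], [3.0]], [0, 0, 1, 1]
print(infer_objective(LogisticRegression().fit(X, y)))                    # binary_classification
print(infer_objective(LinearRegression().fit(X, [0.1, 0.4, 0.9, 1.3])))   # regression
print(infer_objective(Pipeline([("s", StandardScaler()), ("m", LogisticRegression())])))  # unknown
```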
```diff
@@ -79,11 +106,33 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.SKLModelSaveOptions],
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+
         import sklearn.base
         import sklearn.pipeline
 
         assert isinstance(model, sklearn.base.BaseEstimator) or isinstance(model, sklearn.pipeline.Pipeline)
 
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            # TODO: Currently limited to pandas df, need to extend to other types.
+            if sample_input_data is None or not (
+                isinstance(sample_input_data, pd.DataFrame) or isinstance(sample_input_data, sp_df.DataFrame)
+            ):
+                raise ValueError(
+                    "Sample input data is required to enable explainability. Currently we only support this for "
+                    + "`pandas.DataFrame` and `snowflake.snowpark.dataframe.DataFrame`."
+                )
+            sample_input_data_pandas = (
+                sample_input_data
+                if isinstance(sample_input_data, pd.DataFrame)
+                else snowpark_handler.SnowparkDataFrameHandler.convert_to_df(sample_input_data)
+            )
+            data_blob_path = os.path.join(model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR)
+            os.makedirs(data_blob_path, exist_ok=True)
+            with open(os.path.join(data_blob_path, name + cls.BG_DATA_FILE_SUFFIX), "wb") as f:
+                sample_input_data_pandas.to_parquet(f)
+
         if not is_sub_model:
             target_methods = handlers_utils.get_target_methods(
                 model=model,
```
```diff
@@ -110,19 +159,36 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
             get_prediction_fn=get_prediction,
         )
 
+        if enable_explainability:
+            output_type = model_signature.DataType.DOUBLE
+            if cls.get_model_objective(model) == model_meta_schema.ModelObjective.MULTI_CLASSIFICATION:
+                output_type = model_signature.DataType.STRING
+            model_meta = handlers_utils.add_explain_method_signature(
+                model_meta=model_meta,
+                explain_method="explain",
+                target_method="predict",
+                output_return_type=output_type,
+            )
+
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        with open(os.path.join(model_blob_path, cls.
+        with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
             cloudpickle.dump(model, f)
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
 
+        if enable_explainability:
+            model_meta.env.include_if_absent(
+                [model_env.ModelDependency(requirement="shap", pip_name="shap")],
+                check_local_version=True,
+            )
+
         model_meta.env.include_if_absent(
             [model_env.ModelDependency(requirement="scikit-learn", pip_name="scikit-learn")], check_local_version=True
         )
```
```diff
@@ -153,6 +219,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
         cls,
         raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.SKLModelLoadOptions],
     ) -> custom_model.CustomModel:
         from snowflake.ml.model import custom_model
```
```diff
@@ -165,6 +232,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
             raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
             signature: model_signature.ModelSignature,
             target_method: str,
+            background_data: Optional[pd.DataFrame],
         ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
             @custom_model.inference_api
             def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
```
```diff
@@ -179,11 +247,26 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
 
                 return model_signature_utils.rename_pandas_df(df, signature.outputs)
 
+            @custom_model.inference_api
+            def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+                import shap
+
+                # TODO: if not resolved by explainer, we need to pass the callable function
+                try:
+                    explainer = shap.Explainer(raw_model, background_data)
+                    df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
+                except TypeError as e:
+                    raise ValueError(f"Explanation for this model type not supported yet: {str(e)}")
+                return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
+            if target_method == "explain":
+                return explain_fn
+
             return fn
 
         type_method_dict = {}
         for target_method_name, sig in model_meta.signatures.items():
-            type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name)
+            type_method_dict[target_method_name] = fn_factory(raw_model, sig, target_method_name, background_data)
 
         _SKLModel = type(
             "_SKLModel",
```
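What the new `explain` target method boils down to: a generic `shap.Explainer` built over the saved background data. A hedged sketch of the same idea, using the callable fallback the TODO above alludes to (the model and column names are illustrative):

```python
# Illustrative sketch: shap.Explainer over background data, attributions as a DataFrame.
import pandas as pd
import shap
from sklearn.linear_model import LinearRegression

background = pd.DataFrame({"x1": [0.0, 1.0, 2.0, 3.0], "x2": [1.0, 0.0, 1.0, 0.0]})
model = LinearRegression().fit(background, [0.1, 0.9, 2.1, 2.9])

# Passing the callable (model.predict) always resolves to a permutation explainer;
# passing the model object directly can raise TypeError for unsupported types,
# which the handler surfaces as ValueError.
explainer = shap.Explainer(model.predict, background)
attributions = explainer(background).values  # shape: (n_rows, n_features)
print(pd.DataFrame(attributions, columns=[f"{c}_explanation" for c in background.columns]))
```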
snowflake/ml/model/_packager/model_handlers/snowmlmodel.py
CHANGED

```diff
@@ -73,6 +73,10 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.SNOWModelSaveOptions],
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for Snowpark ML model.")
+
         from snowflake.ml.modeling.framework.base import BaseEstimator
 
         assert isinstance(model, BaseEstimator)
@@ -103,13 +107,13 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        with open(os.path.join(model_blob_path, cls.
+        with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
             cloudpickle.dump(model, f)
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -146,6 +150,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         cls,
         raw_model: "BaseEstimator",
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.SNOWModelLoadOptions],
     ) -> custom_model.CustomModel:
         from snowflake.ml.model import custom_model
```
snowflake/ml/model/_packager/model_handlers/tensorflow.py
CHANGED

```diff
@@ -36,7 +36,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
     _MIN_SNOWPARK_ML_VERSION = "1.0.12"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-
+    MODEL_BLOB_FILE_OR_DIR = "model"
     DEFAULT_TARGET_METHODS = ["__call__"]
 
     @classmethod
@@ -68,6 +68,10 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.TensorflowSaveOptions],
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for Tensorflow model.")
+
         import tensorflow
 
         assert isinstance(model, tensorflow.Module)
@@ -114,15 +118,15 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
         if isinstance(model, tensorflow.keras.Model):
-            tensorflow.keras.models.save_model(model, os.path.join(model_blob_path, cls.
+            tensorflow.keras.models.save_model(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
         else:
-            tensorflow.saved_model.save(model, os.path.join(model_blob_path, cls.
+            tensorflow.saved_model.save(model, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
 
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -156,6 +160,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         cls,
         raw_model: "tensorflow.Module",
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.TensorflowLoadOptions],
     ) -> custom_model.CustomModel:
         import tensorflow
```
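The save branch above picks the serializer by type: Keras models go through `tensorflow.keras.models.save_model`, while a plain `tf.Module` takes the generic SavedModel path. A minimal sketch of the non-Keras branch (the `Doubler` module and the temp path are illustrative):

```python
# Sketch of the else-branch above: generic SavedModel save/load for a tf.Module.
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
    @tf.function(input_signature=[tf.TensorSpec([None], tf.float32)])
    def __call__(self, x):
        return x * 2.0


path = tempfile.mkdtemp()
tf.saved_model.save(Doubler(), path)
reloaded = tf.saved_model.load(path)
print(reloaded(tf.constant([1.0, 2.0])).numpy())  # [2. 4.]
```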
snowflake/ml/model/_packager/model_handlers/torchscript.py
CHANGED

```diff
@@ -34,7 +34,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):  # t
     _MIN_SNOWPARK_ML_VERSION = "1.0.12"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-
+    MODEL_BLOB_FILE_OR_DIR = "model.pt"
     DEFAULT_TARGET_METHODS = ["forward"]
 
     @classmethod
@@ -66,6 +66,10 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):  # t
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.TorchScriptSaveOptions],
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for Torch Script model.")
+
         import torch
 
         assert isinstance(model, torch.jit.ScriptModule)  # type:ignore[attr-defined]
@@ -106,13 +110,13 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):  # t
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        with open(os.path.join(model_blob_path, cls.
+        with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
             torch.jit.save(model, f)  # type:ignore[attr-defined]
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -152,6 +156,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):  # t
         cls,
         raw_model: "torch.jit.ScriptModule",  # type:ignore[name-defined]
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.TorchScriptLoadOptions],
     ) -> custom_model.CustomModel:
         from snowflake.ml.model import custom_model
```
snowflake/ml/model/_packager/model_handlers/xgboost.py
CHANGED

```diff
@@ -1,4 +1,5 @@
 # mypy: disable-error-code="import"
+import json
 import os
 from typing import (
     TYPE_CHECKING,
@@ -44,8 +45,43 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
     _MIN_SNOWPARK_ML_VERSION = "1.0.12"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-
+    MODEL_BLOB_FILE_OR_DIR = "model.ubj"
     DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
+    _BINARY_CLASSIFICATION_OBJECTIVE_PREFIX = ["binary:"]
+    _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX = ["multi:"]
+    _RANKING_OBJECTIVE_PREFIX = ["rank:"]
+    _REGRESSION_OBJECTIVE_PREFIX = ["reg:"]
+
+    @classmethod
+    def get_model_objective(
+        cls, model: Union["xgboost.Booster", "xgboost.XGBModel"]
+    ) -> model_meta_schema.ModelObjective:
+        import xgboost
+
+        if isinstance(model, xgboost.XGBClassifier) or isinstance(model, xgboost.XGBRFClassifier):
+            num_classes = handlers_utils.get_num_classes_if_exists(model)
+            if num_classes == 2:
+                return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
+            return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
+        if isinstance(model, xgboost.XGBRegressor) or isinstance(model, xgboost.XGBRFRegressor):
+            return model_meta_schema.ModelObjective.REGRESSION
+        if isinstance(model, xgboost.XGBRanker):
+            return model_meta_schema.ModelObjective.RANKING
+        model_params = json.loads(model.save_config())
+        model_objective = model_params["learner"]["objective"]
+        for classification_objective in cls._BINARY_CLASSIFICATION_OBJECTIVE_PREFIX:
+            if classification_objective in model_objective:
+                return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
+        for classification_objective in cls._MULTI_CLASSIFICATION_OBJECTIVE_PREFIX:
+            if classification_objective in model_objective:
+                return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
+        for ranking_objective in cls._RANKING_OBJECTIVE_PREFIX:
+            if ranking_objective in model_objective:
+                return model_meta_schema.ModelObjective.RANKING
+        for regression_objective in cls._REGRESSION_OBJECTIVE_PREFIX:
+            if regression_objective in model_objective:
+                return model_meta_schema.ModelObjective.REGRESSION
+        return model_meta_schema.ModelObjective.UNKNOWN
 
     @classmethod
     def can_handle(
```
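For a raw `xgboost.Booster` there is no sklearn-style wrapper class to inspect, so the objective string is pulled out of the booster's JSON training config. A standalone sketch of that fallback; note that in recent xgboost releases `learner.objective` is a nested object, so the string itself lives under its `"name"` key:

```python
# Sketch of the Booster fallback: read the objective out of save_config() JSON.
import json

import numpy as np
import xgboost

X = np.array([[0.0], [1.0], [2.0], [3.0]])
dtrain = xgboost.DMatrix(X, label=np.array([0, 0, 1, 1]))
booster = xgboost.train({"objective": "binary:logistic"}, dtrain, num_boost_round=2)

config = json.loads(booster.save_config())
objective = config["learner"]["objective"]["name"]
print(objective, objective.startswith("binary:"))  # binary:logistic True
```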
```diff
@@ -112,15 +148,30 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
             sample_input_data=sample_input_data,
             get_prediction_fn=get_prediction,
         )
+        model_objective = cls.get_model_objective(model)
+        model_meta.model_objective = model_objective
+        if kwargs.get("enable_explainability", True):
+            output_type = model_signature.DataType.DOUBLE
+            if model_objective == model_meta_schema.ModelObjective.MULTI_CLASSIFICATION:
+                output_type = model_signature.DataType.STRING
+            model_meta = handlers_utils.add_explain_method_signature(
+                model_meta=model_meta,
+                explain_method="explain",
+                target_method="predict",
+                output_return_type=output_type,
+            )
+            model_meta.function_properties = {
+                "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
+            }
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        model.save_model(os.path.join(model_blob_path, cls.
+        model.save_model(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
             options=model_meta_schema.XgboostModelBlobOptions({"xgb_estimator_type": model.__class__.__name__}),
         )
         model_meta.models[name] = base_meta
@@ -133,6 +184,12 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
             ],
             check_local_version=True,
         )
+        if kwargs.get("enable_explainability", True):
+            model_meta.env.include_if_absent(
+                [model_env.ModelDependency(requirement="shap", pip_name="shap")],
+                check_local_version=True,
+            )
+            model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
         model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
 
     @classmethod
```
```diff
@@ -175,6 +232,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
         cls,
         raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.XGBModelLoadOptions],
     ) -> custom_model.CustomModel:
         import xgboost
```
```diff
@@ -206,6 +264,16 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
 
             return model_signature_utils.rename_pandas_df(df, signature.outputs)
 
+        @custom_model.inference_api
+        def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+            import shap
+
+            explainer = shap.TreeExplainer(raw_model)
+            df = pd.DataFrame(explainer(X).values)
+            return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
+        if target_method == "explain":
+            return explain_fn
         return fn
 
         type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
```
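The XGBoost explain path is simpler than sklearn's because tree ensembles have an exact, fast explainer and need no background data. A hedged sketch of the same idea (the model and column names are illustrative):

```python
# Sketch of the explain path above: shap.TreeExplainer on a fitted xgboost model.
import numpy as np
import pandas as pd
import shap
import xgboost

X = pd.DataFrame({"x1": np.arange(8.0), "x2": np.tile([0.0, 1.0], 4)})
y = (X["x1"] > 3).astype(int)
model = xgboost.XGBClassifier(n_estimators=5, max_depth=2).fit(X, y)

explainer = shap.TreeExplainer(model)  # exact for gradient-boosted trees
df = pd.DataFrame(explainer(X).values, columns=[f"{c}_explanation" for c in X.columns])
print(df.round(3))
```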
snowflake/ml/model/_packager/model_meta/model_meta.py
CHANGED

```diff
@@ -237,6 +237,7 @@ class ModelMetadata:
         function_properties: A dict mapping function names to dict mapping function property key to value.
         metadata: User provided key-value metadata of the model. Defaults to None.
        creation_timestamp: Unix timestamp when the model metadata is created.
+        model_objective: Model objective like regression, classification etc.
     """
 
     def telemetry_metadata(self) -> ModelMetadataTelemetryDict:
@@ -260,6 +261,8 @@ class ModelMetadata:
         min_snowpark_ml_version: Optional[str] = None,
         models: Optional[Dict[str, model_blob_meta.ModelBlobMeta]] = None,
         original_metadata_version: Optional[str] = model_meta_schema.MODEL_METADATA_VERSION,
+        model_objective: Optional[model_meta_schema.ModelObjective] = model_meta_schema.ModelObjective.UNKNOWN,
+        explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = None,
     ) -> None:
         self.name = name
         self.signatures: Dict[str, model_signature.ModelSignature] = dict()
@@ -284,6 +287,11 @@ class ModelMetadata:
 
         self.original_metadata_version = original_metadata_version
 
+        self.model_objective: model_meta_schema.ModelObjective = (
+            model_objective or model_meta_schema.ModelObjective.UNKNOWN
+        )
+        self.explain_algorithm: Optional[model_meta_schema.ModelExplainAlgorithm] = explain_algorithm
+
     @property
     def min_snowpark_ml_version(self) -> str:
         return self._min_snowpark_ml_version.base_version
@@ -321,9 +329,11 @@ class ModelMetadata:
         model_dict = model_meta_schema.ModelMetadataDict(
             {
                 "creation_timestamp": self.creation_timestamp,
-                "env": self.env.save_as_dict(
+                "env": self.env.save_as_dict(
+                    pathlib.Path(model_dir_path), default_channel_override=env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
+                ),
                 "runtimes": {
-                    runtime_name: runtime.save(pathlib.Path(model_dir_path))
+                    runtime_name: runtime.save(pathlib.Path(model_dir_path), default_channel_override="conda-forge")
                     for runtime_name, runtime in self.runtimes.items()
                 },
                 "metadata": self.metadata,
@@ -333,6 +343,13 @@ class ModelMetadata:
                 "signatures": {func_name: sig.to_dict() for func_name, sig in self.signatures.items()},
                 "version": model_meta_schema.MODEL_METADATA_VERSION,
                 "min_snowpark_ml_version": self.min_snowpark_ml_version,
+                "model_objective": self.model_objective.value,
+                "explainability": (
+                    model_meta_schema.ExplainabilityMetadataDict(algorithm=self.explain_algorithm.value)
+                    if self.explain_algorithm
+                    else None
+                ),
+                "function_properties": self.function_properties,
             }
         )
 
@@ -370,6 +387,9 @@ class ModelMetadata:
             signatures=loaded_meta["signatures"],
             version=original_loaded_meta_version,
             min_snowpark_ml_version=loaded_meta_min_snowpark_ml_version,
+            model_objective=loaded_meta.get("model_objective", model_meta_schema.ModelObjective.UNKNOWN.value),
+            explainability=loaded_meta.get("explainability", None),
+            function_properties=loaded_meta.get("function_properties", {}),
         )
 
     @classmethod
@@ -406,6 +426,11 @@ class ModelMetadata:
         else:
             runtimes = None
 
+        explanation_algorithm_dict = model_dict.get("explainability", None)
+        explanation_algorithm = None
+        if explanation_algorithm_dict:
+            explanation_algorithm = model_meta_schema.ModelExplainAlgorithm(explanation_algorithm_dict["algorithm"])
+
         return cls(
             name=model_dict["name"],
             model_type=model_dict["model_type"],
@@ -417,4 +442,9 @@ class ModelMetadata:
             min_snowpark_ml_version=model_dict["min_snowpark_ml_version"],
             models=models,
             original_metadata_version=model_dict["version"],
+            model_objective=model_meta_schema.ModelObjective(
+                model_dict.get("model_objective", model_meta_schema.ModelObjective.UNKNOWN.value)
+            ),
+            explain_algorithm=explanation_algorithm,
+            function_properties=model_dict.get("function_properties", {}),
         )
```
snowflake/ml/model/_packager/model_meta/model_meta_schema.py
CHANGED

```diff
@@ -71,6 +71,10 @@ ModelBlobOptions = Union[
 ]
 
 
+class ExplainabilityMetadataDict(TypedDict):
+    algorithm: Required[str]
+
+
 class ModelBlobMetadataDict(TypedDict):
     name: Required[str]
     model_type: Required[type_hints.SupportedModelHandlerType]
@@ -92,3 +96,18 @@ class ModelMetadataDict(TypedDict):
     signatures: Required[Dict[str, Dict[str, Any]]]
     version: Required[str]
     min_snowpark_ml_version: Required[str]
+    model_objective: Required[str]
+    explainability: NotRequired[Optional[ExplainabilityMetadataDict]]
+    function_properties: NotRequired[Dict[str, Dict[str, Any]]]
+
+
+class ModelObjective(Enum):
+    UNKNOWN = "unknown"
+    BINARY_CLASSIFICATION = "binary_classification"
+    MULTI_CLASSIFICATION = "multi_classification"
+    REGRESSION = "regression"
+    RANKING = "ranking"
+
+
+class ModelExplainAlgorithm(Enum):
+    SHAP = "shap"
```
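The new keys degrade gracefully: `model_objective` and `explainability` serialize to plain strings and are read back with `.get(...)` defaults, so metadata written before 1.6 still loads. A minimal sketch of that round-trip, with the enums restated locally so the snippet is self-contained:

```python
# Sketch of the load-side defaults mirrored from ModelMetadata above.
from enum import Enum
from typing import Optional


class ModelObjective(Enum):
    UNKNOWN = "unknown"
    BINARY_CLASSIFICATION = "binary_classification"
    MULTI_CLASSIFICATION = "multi_classification"
    REGRESSION = "regression"
    RANKING = "ranking"


class ModelExplainAlgorithm(Enum):
    SHAP = "shap"


def load(meta: dict) -> tuple:
    objective = ModelObjective(meta.get("model_objective", ModelObjective.UNKNOWN.value))
    explain_dict = meta.get("explainability", None)
    algorithm: Optional[ModelExplainAlgorithm] = (
        ModelExplainAlgorithm(explain_dict["algorithm"]) if explain_dict else None
    )
    return objective, algorithm


print(load({"model_objective": "regression", "explainability": {"algorithm": "shap"}}))
print(load({}))  # pre-1.6 metadata: (ModelObjective.UNKNOWN, None)
```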
snowflake/ml/model/_packager/model_packager.py
CHANGED

```diff
@@ -146,7 +146,8 @@ class ModelPackager:
         m = handler.load_model(self.meta.name, self.meta, model_blobs_path, **options)
 
         if as_custom_model:
-
+            background_data = handler.load_background_data(self.meta.name, model_blobs_path)
+            m = handler.convert_as_custom_model(m, self.meta, background_data, **options)
             assert isinstance(m, custom_model.CustomModel)
 
         self.model = m
```
snowflake/ml/model/_packager/model_runtime/model_runtime.py
CHANGED

```diff
@@ -35,7 +35,7 @@ class ModelRuntime:
         self,
         name: str,
         env: model_env.ModelEnv,
-        imports: Optional[List[
+        imports: Optional[List[str]] = None,
         is_gpu: bool = False,
         loading_from_file: bool = False,
     ) -> None:
@@ -67,7 +67,9 @@ class ModelRuntime:
     def runtime_rel_path(self) -> pathlib.PurePosixPath:
         return pathlib.PurePosixPath(ModelRuntime.RUNTIME_DIR_REL_PATH) / self.name
 
-    def save(
+    def save(
+        self, packager_path: pathlib.Path, default_channel_override: str = env_utils.SNOWFLAKE_CONDA_CHANNEL_URL
+    ) -> model_meta_schema.ModelRuntimeDict:
         runtime_base_path = packager_path / self.runtime_rel_path
         runtime_base_path.mkdir(parents=True, exist_ok=True)
 
@@ -75,12 +77,12 @@ class ModelRuntime:
         snowpark_ml_lib_path = runtime_base_path / "snowflake-ml-python.zip"
         file_utils.zip_python_package(str(snowpark_ml_lib_path), "snowflake.ml")
         snowpark_ml_lib_rel_path = pathlib.PurePosixPath(snowpark_ml_lib_path.relative_to(packager_path).as_posix())
-        self.imports.append(snowpark_ml_lib_rel_path)
+        self.imports.append(str(snowpark_ml_lib_rel_path))
 
         self.runtime_env.conda_env_rel_path = self.runtime_rel_path / self.runtime_env.conda_env_rel_path
         self.runtime_env.pip_requirements_rel_path = self.runtime_rel_path / self.runtime_env.pip_requirements_rel_path
 
-        env_dict = self.runtime_env.save_as_dict(packager_path)
+        env_dict = self.runtime_env.save_as_dict(packager_path, default_channel_override=default_channel_override)
 
         return model_meta_schema.ModelRuntimeDict(
             imports=list(map(str, self.imports)),
@@ -108,6 +110,4 @@ class ModelRuntime:
             warnings.simplefilter("ignore")
             env.load_from_conda_file(packager_path / conda_env_rel_path)
             env.load_from_pip_file(packager_path / pip_requirements_rel_path)
-        return ModelRuntime(
-            name=name, env=env, imports=list(map(pathlib.PurePosixPath, loaded_dict["imports"])), loading_from_file=True
-        )
+        return ModelRuntime(name=name, env=env, imports=loaded_dict["imports"], loading_from_file=True)
```
snowflake/ml/model/model_signature.py
CHANGED

```diff
@@ -232,7 +232,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
                 ),
             )
         else:
-            if isinstance(data_col[0], list):
+            if isinstance(data_col.iloc[0], list):
                 if not ft_shape:
                     raise snowml_exceptions.SnowflakeMLException(
                         error_code=error_codes.INVALID_DATA,
@@ -266,7 +266,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
                 ),
             )
 
-            elif isinstance(data_col[0], np.ndarray):
+            elif isinstance(data_col.iloc[0], np.ndarray):
                 if not ft_shape:
                     raise snowml_exceptions.SnowflakeMLException(
                         error_code=error_codes.INVALID_DATA,
@@ -297,7 +297,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
                 ),
             )
 
-            elif isinstance(data_col[0], str):
+            elif isinstance(data_col.iloc[0], str):
                 if ft_shape is not None:
                     raise snowml_exceptions.SnowflakeMLException(
                         error_code=error_codes.INVALID_DATA,
@@ -316,7 +316,7 @@ def _validate_pandas_df(data: pd.DataFrame, features: Sequence[core.BaseFeatureS
                 ),
             )
 
-            elif isinstance(data_col[0], bytes):
+            elif isinstance(data_col.iloc[0], bytes):
                 if ft_shape is not None:
                     raise snowml_exceptions.SnowflakeMLException(
                         error_code=error_codes.INVALID_DATA,
```
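The four one-line changes above all fix the same latent bug: `data_col[0]` is label-based indexing on a pandas Series, which raises whenever the DataFrame does not carry the default 0..n-1 index. A small reproduction:

```python
# Why .iloc[0] is needed: label vs. positional indexing on a non-default index.
import pandas as pd

df = pd.DataFrame({"a": [[1, 2], [3, 4], [5, 6]]}, index=[10, 20, 30])
col = df["a"]

print(col.iloc[0])  # positional: always the first element -> [1, 2]
try:
    col[0]          # label-based: there is no label 0 in the index
except KeyError as e:
    print("KeyError:", e)
```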
snowflake/ml/model/type_hints.py
CHANGED

```diff
@@ -232,6 +232,8 @@ class BaseModelSaveOption(TypedDict):
     _legacy_save: NotRequired[bool]
     function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
     method_options: NotRequired[Dict[str, ModelMethodSaveOptions]]
+    include_pip_dependencies: NotRequired[bool]
+    enable_explainability: NotRequired[bool]
 
 
 class CatBoostModelSaveOptions(BaseModelSaveOption):
```
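The new `enable_explainability` key in `BaseModelSaveOption` is what the handlers above read via `kwargs.get("enable_explainability", ...)`. A hedged sketch of how it would surface at the registry API; `session`, `model`, and `X_sample` are placeholders for an existing Snowpark session, a fitted supported model, and sample input data:

```python
# Illustrative only: assumes an existing snowflake.snowpark.Session (`session`),
# a fitted supported model (`model`), and sample input data (`X_sample`).
from snowflake.ml.registry import Registry

reg = Registry(session=session)
mv = reg.log_model(
    model,
    model_name="MY_MODEL",
    version_name="V1",
    sample_input_data=X_sample,  # the sklearn handler requires this when explainability is on
    options={"enable_explainability": True},  # the new BaseModelSaveOption key
)
```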
snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py
CHANGED

```diff
@@ -41,7 +41,7 @@ cp.register_pickle_by_value(inspect.getmodule(snowpark_dataframe_utils.cast_snow
 
 _PROJECT = "ModelDevelopment"
 DEFAULT_UDTF_NJOBS = 3
-ENABLE_EFFICIENT_MEMORY_USAGE =
+ENABLE_EFFICIENT_MEMORY_USAGE = True
 _UDTF_STAGE_NAME = f"MEMORY_EFFICIENT_UDTF_{str(uuid.uuid4()).replace('-', '_')}"
 
 
```
snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py
CHANGED

```diff
@@ -83,7 +83,19 @@ def _load_data_into_udf() -> Tuple[
     with open(local_fit_and_score_kwargs_file_path, mode="rb") as local_fit_and_score_kwargs_file_obj:
         fit_and_score_kwargs = cp.load(local_fit_and_score_kwargs_file_obj)
 
-    #
+    # Convert dataframe to numpy would save memory consumption
+    # Except for Pipeline, we need to keep the dataframe for the column names
+    from sklearn.pipeline import Pipeline
+    if isinstance(base_estimator, Pipeline):
+        return (
+            df[CONSTANTS['input_cols']],
+            df[CONSTANTS['label_cols']].squeeze(),
+            indices,
+            params_to_evaluate,
+            base_estimator,
+            fit_and_score_kwargs,
+            CONSTANTS
+        )
     return (
         df[CONSTANTS['input_cols']].to_numpy(),
         df[CONSTANTS['label_cols']].squeeze().to_numpy(),
```
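A quick illustration of the trade-off the new branch encodes: `.to_numpy()` drops the index and column labels, trimming per-frame overhead inside the UDF, but a `Pipeline` whose steps select columns by name needs the DataFrame kept intact:

```python
# DataFrame vs. numpy view: same data, but column labels survive only in the former.
import numpy as np
import pandas as pd

df = pd.DataFrame(np.random.rand(1000, 10), columns=[f"c{i}" for i in range(10)])
arr = df.to_numpy()

print(df.memory_usage(deep=True).sum())  # data plus index/column overhead
print(arr.nbytes)                        # raw data only
print(arr.shape, arr.dtype)              # column names are gone, so a
                                         # name-based Pipeline step would break
```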