snowflake-ml-python 1.5.4__py3-none-any.whl → 1.6.1__py3-none-any.whl
This diff represents the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +2 -0
- snowflake/cortex/_classify_text.py +36 -0
- snowflake/cortex/_complete.py +66 -35
- snowflake/cortex/_util.py +4 -4
- snowflake/ml/_internal/env_utils.py +11 -5
- snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
- snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
- snowflake/ml/_internal/telemetry.py +26 -2
- snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
- snowflake/ml/data/_internal/arrow_ingestor.py +284 -0
- snowflake/ml/data/data_connector.py +186 -0
- snowflake/ml/data/data_ingestor.py +45 -0
- snowflake/ml/data/data_source.py +23 -0
- snowflake/ml/data/ingestor_utils.py +62 -0
- snowflake/ml/data/torch_dataset.py +33 -0
- snowflake/ml/dataset/dataset.py +1 -13
- snowflake/ml/dataset/dataset_metadata.py +3 -1
- snowflake/ml/dataset/dataset_reader.py +23 -117
- snowflake/ml/feature_store/access_manager.py +7 -1
- snowflake/ml/feature_store/entity.py +19 -2
- snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
- snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
- snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
- snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +37 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +30 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +7 -0
- snowflake/ml/feature_store/examples/example_helper.py +278 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/location_features.py +44 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +9 -0
- snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
- snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
- snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +36 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +24 -0
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +8 -0
- snowflake/ml/feature_store/feature_store.py +637 -76
- snowflake/ml/feature_store/feature_view.py +316 -9
- snowflake/ml/fileset/stage_fs.py +18 -10
- snowflake/ml/lineage/lineage_node.py +1 -1
- snowflake/ml/model/_client/model/model_impl.py +11 -2
- snowflake/ml/model/_client/model/model_version_impl.py +171 -20
- snowflake/ml/model/_client/ops/model_ops.py +105 -27
- snowflake/ml/model/_client/ops/service_ops.py +121 -0
- snowflake/ml/model/_client/service/model_deployment_spec.py +95 -0
- snowflake/ml/model/_client/service/model_deployment_spec_schema.py +31 -0
- snowflake/ml/model/_client/sql/model_version.py +13 -4
- snowflake/ml/model/_client/sql/service.py +129 -0
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +2 -3
- snowflake/ml/model/_model_composer/model_composer.py +14 -14
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +33 -17
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +5 -1
- snowflake/ml/model/_model_composer/model_method/function_generator.py +3 -3
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +3 -27
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/model_method.py +5 -2
- snowflake/ml/model/_packager/model_env/model_env.py +7 -2
- snowflake/ml/model/_packager/model_handlers/_base.py +30 -3
- snowflake/ml/model/_packager/model_handlers/_utils.py +58 -1
- snowflake/ml/model/_packager/model_handlers/catboost.py +52 -3
- snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +9 -5
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +80 -3
- snowflake/ml/model/_packager/model_handlers/llm.py +7 -3
- snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
- snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
- snowflake/ml/model/_packager/model_handlers/sklearn.py +87 -4
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +7 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
- snowflake/ml/model/_packager/model_handlers/torchscript.py +8 -3
- snowflake/ml/model/_packager/model_handlers/xgboost.py +71 -3
- snowflake/ml/model/_packager/model_meta/model_meta.py +32 -2
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +19 -0
- snowflake/ml/model/_packager/model_packager.py +2 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +7 -7
- snowflake/ml/model/model_signature.py +4 -4
- snowflake/ml/model/type_hints.py +2 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +1 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
- snowflake/ml/modeling/framework/base.py +28 -19
- snowflake/ml/modeling/impute/simple_imputer.py +26 -0
- snowflake/ml/modeling/pipeline/pipeline.py +7 -4
- snowflake/ml/registry/_manager/model_manager.py +16 -2
- snowflake/ml/registry/registry.py +100 -13
- snowflake/ml/utils/sql_client.py +22 -0
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/METADATA +81 -2
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/RECORD +99 -66
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/lineage/data_source.py +0 -10
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.4.dist-info → snowflake_ml_python-1.6.1.dist-info}/top_level.txt +0 -0
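The most significant behavioral change in this range is the new model explainability support threaded through the model handlers (`snowflake/ml/model/_packager/model_handlers/`), whose diffs are reproduced below: the tree-based handlers (catboost, lightgbm, xgboost, sklearn) gain a SHAP-backed `explain` target method, while every handler now reads an `enable_explainability` save option. As a rough usage sketch — only `enable_explainability` and the generated `explain` method are taken from the diffs; `session`, `clf`, and `X_sample` are assumed to exist, and the exact call surface should be checked against the 1.6.1 docs:

```python
# Hedged usage sketch for the new explainability flow; `session`, `clf`, and
# `X_sample` are assumed to exist here.
from snowflake.ml.registry import Registry

reg = Registry(session=session)  # session: an existing Snowpark Session

mv = reg.log_model(
    clf,  # e.g. a fitted xgboost/lightgbm/catboost/sklearn estimator
    model_name="my_model",
    version_name="v1",
    sample_input_data=X_sample,
    options={"enable_explainability": True},  # read via kwargs.get(...) in each handler
)

# add_explain_method_signature (in _utils.py below) gives the explain output one
# "<column>_explanation" feature per input column.
explanations = mv.run(X_sample, function_name="explain")
```

Handlers that cannot support explanations (custom, huggingface_pipeline, llm, mlflow, pytorch, sentence_transformers) now raise `NotImplementedError` when the option is set.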
```diff
--- a/snowflake/ml/model/_packager/model_handlers/_base.py
+++ b/snowflake/ml/model/_packager/model_handlers/_base.py
@@ -1,6 +1,8 @@
+import os
 from abc import abstractmethod
 from typing import Dict, Generic, Optional, Protocol, Type, final
 
+import pandas as pd
 from typing_extensions import TypeGuard, Unpack
 
 from snowflake.ml.model import custom_model, type_hints as model_types
@@ -16,7 +18,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
 
     @classmethod
     @abstractmethod
-    def can_handle(cls, model: model_types.
+    def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard[model_types._ModelType]:
         """Whether this handler could support the type of the `model`.
 
         Args:
@@ -96,6 +98,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
         cls,
         raw_model: model_types._ModelType,
         model_meta: model_meta.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.BaseModelLoadOption],
     ) -> custom_model.CustomModel:
         """Create a custom model class wrap for unified interface when being deployed. The predict method will be
@@ -104,6 +107,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
         Args:
             raw_model: original model object,
             model_meta: The model metadata.
+            background_data: The background data used for the model explanations.
             kwargs: Options when converting the model.
 
         Raises:
@@ -121,7 +125,8 @@ class BaseModelHandler(Generic[model_types._ModelType], _BaseModelHandlerProtoco
         _MIN_SNOWPARK_ML_VERSION: The minimal version of Snowpark ML library to use the current handler.
         _HANDLER_MIGRATOR_PLANS: Dict holding handler migrator plans.
 
-
+        MODEL_BLOB_FILE_OR_DIR: Relative path of the model blob file in the model subdir. Default to "model.pkl".
+        BG_DATA_FILE_SUFFIX: Suffix of the background data file. Default to "_background_data.pqt".
         MODEL_ARTIFACTS_DIR: Relative path of the model artifacts dir in the model subdir. Default to "artifacts"
         DEFAULT_TARGET_METHODS: Default target methods to be logged if not specified in this kind of model. Default to
             ["predict"]
@@ -129,8 +134,10 @@ class BaseModelHandler(Generic[model_types._ModelType], _BaseModelHandlerProtoco
             inputting sample data or model signature. Default to False.
     """
 
-
+    MODEL_BLOB_FILE_OR_DIR = "model.pkl"
+    BG_DATA_FILE_SUFFIX = "_background_data.pqt"
     MODEL_ARTIFACTS_DIR = "artifacts"
+    EXPLAIN_ARTIFACTS_DIR = "explain_artifacts"
     DEFAULT_TARGET_METHODS = ["predict"]
     IS_AUTO_SIGNATURE = False
 
@@ -159,3 +166,23 @@ class BaseModelHandler(Generic[model_types._ModelType], _BaseModelHandlerProtoco
             model_meta=model_meta,
             model_blobs_dir_path=model_blobs_dir_path,
         )
+
+    @classmethod
+    @final
+    def load_background_data(cls, name: str, model_blobs_dir_path: str) -> Optional[pd.DataFrame]:
+        """Load the model into memory.
+
+        Args:
+            name: Name of the model.
+            model_blobs_dir_path: Directory path to the whole model.
+
+        Returns:
+            Optional[pd.DataFrame], background data as pandas DataFrame, if exists.
+        """
+        data_blob_path = os.path.join(model_blobs_dir_path, cls.EXPLAIN_ARTIFACTS_DIR, name + cls.BG_DATA_FILE_SUFFIX)
+        if not os.path.exists(model_blobs_dir_path) or not os.path.isfile(data_blob_path):
+            return None
+        with open(data_blob_path, "rb") as f:
+            background_data = pd.read_parquet(f)
+
+        return background_data
```
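`load_background_data` establishes an on-disk convention: background data lives at `<model_blobs_dir>/explain_artifacts/<name>_background_data.pqt` as parquet. A minimal sketch of the save-side counterpart implied by those constants — the helper name `save_background_data` is hypothetical and not part of the package; only the directory, suffix, and parquet format come from the diff:

```python
# Hypothetical save-side counterpart to load_background_data.
import os

import pandas as pd

EXPLAIN_ARTIFACTS_DIR = "explain_artifacts"          # mirrors the diff's constant
BG_DATA_FILE_SUFFIX = "_background_data.pqt"         # mirrors the diff's constant


def save_background_data(model_blobs_dir_path: str, name: str, background_data: pd.DataFrame) -> None:
    data_dir = os.path.join(model_blobs_dir_path, EXPLAIN_ARTIFACTS_DIR)
    os.makedirs(data_dir, exist_ok=True)
    # Written as parquet so load_background_data's pd.read_parquet finds it at
    # <model_blobs_dir>/explain_artifacts/<name>_background_data.pqt.
    background_data.to_parquet(os.path.join(data_dir, name + BG_DATA_FILE_SUFFIX))
```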
```diff
--- a/snowflake/ml/model/_packager/model_handlers/_utils.py
+++ b/snowflake/ml/model/_packager/model_handlers/_utils.py
@@ -1,4 +1,9 @@
-
+import json
+from typing import Any, Callable, Iterable, Optional, Sequence, cast
+
+import numpy as np
+import numpy.typing as npt
+import pandas as pd
 
 from snowflake.ml.model import model_signature, type_hints as model_types
 from snowflake.ml.model._packager.model_meta import model_meta
@@ -40,6 +45,24 @@ def validate_signature(
     return model_meta
 
 
+def add_explain_method_signature(
+    model_meta: model_meta.ModelMetadata,
+    explain_method: str,
+    target_method: str,
+    output_return_type: model_signature.DataType = model_signature.DataType.DOUBLE,
+) -> model_meta.ModelMetadata:
+    if target_method not in model_meta.signatures:
+        raise ValueError(f"Signature for target method {target_method} is missing")
+    inputs = model_meta.signatures[target_method].inputs
+    model_meta.signatures[explain_method] = model_signature.ModelSignature(
+        inputs=inputs,
+        outputs=[
+            model_signature.FeatureSpec(dtype=output_return_type, name=f"{spec.name}_explanation") for spec in inputs
+        ],
+    )
+    return model_meta
+
+
 def get_target_methods(
     model: model_types.SupportedModelType,
     target_methods: Optional[Sequence[str]],
@@ -56,3 +79,37 @@ def validate_target_methods(model: model_types.SupportedModelType, target_method
     for method_name in target_methods:
         if not _is_callable(model, method_name):
             raise ValueError(f"Target method {method_name} is not callable or does not exist in the model.")
+
+
+def get_num_classes_if_exists(model: model_types.SupportedModelType) -> int:
+    num_classes = getattr(model, "classes_", [])
+    return len(num_classes)
+
+
+def convert_explanations_to_2D_df(
+    model: model_types.SupportedModelType, explanations: npt.NDArray[Any]
+) -> pd.DataFrame:
+    if explanations.ndim != 3:
+        return pd.DataFrame(explanations)
+
+    if hasattr(model, "classes_"):
+        classes_list = [cl for cl in model.classes_]  # type:ignore[union-attr]
+        len_classes = len(classes_list)
+        if explanations.shape[2] != len_classes:
+            raise ValueError(f"Model has {len_classes} classes but explanations have {explanations.shape[2]}")
+    else:
+        classes_list = [i for i in range(explanations.shape[2])]
+    exp_2d = []
+    # TODO (SNOW-1549044): Optimize this
+    for row in explanations:
+        col_list = []
+        for column in row:
+            class_explanations = {}
+            for cl, cl_exp in zip(classes_list, column):
+                if isinstance(cl, (int, np.integer)):
+                    cl = int(cl)
+                class_explanations[cl] = cl_exp
+            col_list.append(json.dumps(class_explanations))
+        exp_2d.append(col_list)
+
+    return pd.DataFrame(exp_2d)
```
```diff
--- a/snowflake/ml/model/_packager/model_handlers/catboost.py
+++ b/snowflake/ml/model/_packager/model_handlers/catboost.py
@@ -30,9 +30,25 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
     _MIN_SNOWPARK_ML_VERSION = "1.3.1"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-
+    MODEL_BLOB_FILE_OR_DIR = "model.bin"
     DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
 
+    @classmethod
+    def get_model_objective(cls, model: "catboost.CatBoost") -> model_meta_schema.ModelObjective:
+        import catboost
+
+        if isinstance(model, catboost.CatBoostClassifier):
+            num_classes = handlers_utils.get_num_classes_if_exists(model)
+            if num_classes == 2:
+                return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
+            return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
+        if isinstance(model, catboost.CatBoostRanker):
+            return model_meta_schema.ModelObjective.RANKING
+        if isinstance(model, catboost.CatBoostRegressor):
+            return model_meta_schema.ModelObjective.REGRESSION
+        # TODO: Find out model type from the generic Catboost Model
+        return model_meta_schema.ModelObjective.UNKNOWN
+
     @classmethod
     def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["catboost.CatBoost"]:
         return (type_utils.LazyType("catboost.CatBoost").isinstance(model)) and any(
@@ -89,10 +105,25 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
             sample_input_data=sample_input_data,
             get_prediction_fn=get_prediction,
         )
+        model_objective = cls.get_model_objective(model)
+        model_meta.model_objective = model_objective
+        if kwargs.get("enable_explainability", True):
+            output_type = model_signature.DataType.DOUBLE
+            if model_objective == model_meta_schema.ModelObjective.MULTI_CLASSIFICATION:
+                output_type = model_signature.DataType.STRING
+            model_meta = handlers_utils.add_explain_method_signature(
+                model_meta=model_meta,
+                explain_method="explain",
+                target_method="predict",
+                output_return_type=output_type,
+            )
+            model_meta.function_properties = {
+                "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
+            }
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        model_save_path = os.path.join(model_blob_path, cls.
+        model_save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
 
         model.save_model(model_save_path)
 
@@ -100,7 +131,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
             options=model_meta_schema.CatBoostModelBlobOptions({"catboost_estimator_type": model.__class__.__name__}),
         )
         model_meta.models[name] = base_meta
@@ -112,6 +143,12 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
             ],
             check_local_version=True,
         )
+        if kwargs.get("enable_explainability", True):
+            model_meta.env.include_if_absent(
+                [model_env.ModelDependency(requirement="shap", pip_name="shap")],
+                check_local_version=True,
+            )
+            model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
         model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)
 
         return None
@@ -157,6 +194,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
         cls,
         raw_model: "catboost.CatBoost",
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.CatBoostModelLoadOptions],
     ) -> custom_model.CustomModel:
         import catboost
@@ -186,6 +224,17 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
 
                 return model_signature_utils.rename_pandas_df(df, signature.outputs)
 
+            @custom_model.inference_api
+            def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+                import shap
+
+                explainer = shap.TreeExplainer(raw_model)
+                df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
+                return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
+            if target_method == "explain":
+                return explain_fn
+
             return fn
 
         type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
```
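The generated `explain_fn` above is a thin wrapper over `shap.TreeExplainer`. A local sketch of the same computation outside the registry, assuming `catboost` and `shap` are installed; the toy data and column names are made up for illustration:

```python
# Local sketch of what explain_fn computes for a CatBoost model.
import catboost
import pandas as pd
import shap

X = pd.DataFrame({"f0": [0.0, 1.0, 2.0, 3.0], "f1": [1.0, 0.0, 1.0, 0.0]})
y = [0, 1, 0, 1]

model = catboost.CatBoostClassifier(iterations=10, verbose=False).fit(X, y)

explainer = shap.TreeExplainer(model)  # same explainer explain_fn builds
values = explainer(X).values           # 2-D for binary; 3-D (rows x features x classes) for multiclass
print(values.shape)
```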
```diff
--- a/snowflake/ml/model/_packager/model_handlers/custom.py
+++ b/snowflake/ml/model/_packager/model_handlers/custom.py
@@ -51,6 +51,9 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
         **kwargs: Unpack[model_types.CustomModelSaveOption],
     ) -> None:
         assert isinstance(model, custom_model.CustomModel)
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for custom model.")
 
         def get_prediction(
             target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
@@ -108,13 +111,13 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
         # Make sure that the module where the model is defined get pickled by value as well.
         cloudpickle.register_pickle_by_value(sys.modules[model.__module__])
         pickled_obj = (model.__class__, model.context)
-        with open(os.path.join(model_blob_path, cls.
+        with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
             cloudpickle.dump(pickled_obj, f)
         # model meta will be saved by the context manager
         model_meta.models[name] = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
-            path=cls.
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
             handler_version=cls.HANDLER_VERSION,
             function_properties=model_meta.function_properties,
             artifacts={
@@ -183,6 +186,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
         cls,
         raw_model: custom_model.CustomModel,
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.CustomModelLoadOption],
     ) -> custom_model.CustomModel:
         return raw_model
```
```diff
--- a/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py
+++ b/snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py
@@ -89,7 +89,7 @@ class HuggingFacePipelineHandler(
     _MIN_SNOWPARK_ML_VERSION = "1.0.12"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-
+    MODEL_BLOB_FILE_OR_DIR = "model"
     ADDITIONAL_CONFIG_FILE = "pipeline_config.pt"
     DEFAULT_TARGET_METHODS = ["__call__"]
     IS_AUTO_SIGNATURE = True
@@ -133,6 +133,9 @@ class HuggingFacePipelineHandler(
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.HuggingFaceSaveOptions],
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for huggingface model.")
         if type_utils.LazyType("transformers.Pipeline").isinstance(model):
             task = model.task  # type:ignore[attr-defined]
             framework = model.framework  # type:ignore[attr-defined]
@@ -193,7 +196,7 @@ class HuggingFacePipelineHandler(
 
         if type_utils.LazyType("transformers.Pipeline").isinstance(model):
             model.save_pretrained(  # type:ignore[attr-defined]
-                os.path.join(model_blob_path, cls.
+                os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
             )
             pipeline_params = {
                 "_batch_size": model._batch_size,  # type:ignore[attr-defined]
@@ -205,7 +208,7 @@ class HuggingFacePipelineHandler(
             with open(
                 os.path.join(
                     model_blob_path,
-                    cls.
+                    cls.MODEL_BLOB_FILE_OR_DIR,
                     cls.ADDITIONAL_CONFIG_FILE,
                 ),
                 "wb",
@@ -213,7 +216,7 @@ class HuggingFacePipelineHandler(
                 cloudpickle.dump(pipeline_params, f)
         else:
             with open(
-                os.path.join(model_blob_path, cls.
+                os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR),
                 "wb",
             ) as f:
                 cloudpickle.dump(model, f)
@@ -222,7 +225,7 @@ class HuggingFacePipelineHandler(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
             options=model_meta_schema.HuggingFacePipelineModelBlobOptions(
                 {
                     "task": task,
@@ -329,6 +332,7 @@ class HuggingFacePipelineHandler(
         cls,
         raw_model: Union[huggingface_pipeline.HuggingFacePipelineModel, "transformers.Pipeline"],
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
     ) -> custom_model.CustomModel:
         import transformers
```
```diff
--- a/snowflake/ml/model/_packager/model_handlers/lightgbm.py
+++ b/snowflake/ml/model/_packager/model_handlers/lightgbm.py
@@ -41,8 +41,49 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
     _MIN_SNOWPARK_ML_VERSION = "1.3.1"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-
+    MODEL_BLOB_FILE_OR_DIR = "model.pkl"
     DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
+    _BINARY_CLASSIFICATION_OBJECTIVES = ["binary"]
+    _MULTI_CLASSIFICATION_OBJECTIVES = ["multiclass", "multiclassova"]
+    _RANKING_OBJECTIVES = ["lambdarank", "rank_xendcg"]
+    _REGRESSION_OBJECTIVES = [
+        "regression",
+        "regression_l1",
+        "huber",
+        "fair",
+        "poisson",
+        "quantile",
+        "tweedie",
+        "mape",
+        "gamma",
+    ]
+
+    @classmethod
+    def get_model_objective(
+        cls, model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]
+    ) -> model_meta_schema.ModelObjective:
+        import lightgbm
+
+        # does not account for cross-entropy and custom
+        if isinstance(model, lightgbm.LGBMClassifier):
+            num_classes = handlers_utils.get_num_classes_if_exists(model)
+            if num_classes == 2:
+                return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
+            return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
+        if isinstance(model, lightgbm.LGBMRanker):
+            return model_meta_schema.ModelObjective.RANKING
+        if isinstance(model, lightgbm.LGBMRegressor):
+            return model_meta_schema.ModelObjective.REGRESSION
+        model_objective = model.params["objective"]
+        if model_objective in cls._BINARY_CLASSIFICATION_OBJECTIVES:
+            return model_meta_schema.ModelObjective.BINARY_CLASSIFICATION
+        if model_objective in cls._MULTI_CLASSIFICATION_OBJECTIVES:
+            return model_meta_schema.ModelObjective.MULTI_CLASSIFICATION
+        if model_objective in cls._RANKING_OBJECTIVES:
+            return model_meta_schema.ModelObjective.RANKING
+        if model_objective in cls._REGRESSION_OBJECTIVES:
+            return model_meta_schema.ModelObjective.REGRESSION
+        return model_meta_schema.ModelObjective.UNKNOWN
 
     @classmethod
     def can_handle(
@@ -105,11 +146,29 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
             sample_input_data=sample_input_data,
             get_prediction_fn=get_prediction,
         )
+        model_objective = cls.get_model_objective(model)
+        model_meta.model_objective = model_objective
+        if kwargs.get("enable_explainability", True):
+            output_type = model_signature.DataType.DOUBLE
+            if model_objective in [
+                model_meta_schema.ModelObjective.BINARY_CLASSIFICATION,
+                model_meta_schema.ModelObjective.MULTI_CLASSIFICATION,
+            ]:
+                output_type = model_signature.DataType.STRING
+            model_meta = handlers_utils.add_explain_method_signature(
+                model_meta=model_meta,
+                explain_method="explain",
+                target_method="predict",
+                output_return_type=output_type,
+            )
+            model_meta.function_properties = {
+                "explain": {model_meta_schema.FunctionProperties.PARTITIONED.value: False}
+            }
 
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
 
-        model_save_path = os.path.join(model_blob_path, cls.
+        model_save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
         with open(model_save_path, "wb") as f:
             cloudpickle.dump(model, f)
 
@@ -117,7 +176,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
             options=model_meta_schema.LightGBMModelBlobOptions({"lightgbm_estimator_type": model.__class__.__name__}),
         )
         model_meta.models[name] = base_meta
@@ -130,6 +189,12 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
             ],
             check_local_version=True,
         )
+        if kwargs.get("enable_explainability", True):
+            model_meta.env.include_if_absent(
+                [model_env.ModelDependency(requirement="shap", pip_name="shap")],
+                check_local_version=True,
+            )
+            model_meta.explain_algorithm = model_meta_schema.ModelExplainAlgorithm.SHAP
 
         return None
 
@@ -169,6 +234,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
         cls,
         raw_model: Union["lightgbm.Booster", "lightgbm.XGBModel"],
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.LGBMModelLoadOptions],
     ) -> custom_model.CustomModel:
         import lightgbm
@@ -198,6 +264,17 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
 
                 return model_signature_utils.rename_pandas_df(df, signature.outputs)
 
+            @custom_model.inference_api
+            def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+                import shap
+
+                explainer = shap.TreeExplainer(raw_model)
+                df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
+                return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
+            if target_method == "explain":
+                return explain_fn
+
             return fn
 
         type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
```
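Unlike the sklearn-style `LGBM*` classes, a raw `lightgbm.Booster` exposes no class hierarchy to dispatch on, so `get_model_objective` falls back to the `objective` string recorded in `Booster.params` and maps it through the `_*_OBJECTIVES` lists above. A small sketch of that fallback path, assuming `lightgbm` is installed; the data is synthetic:

```python
# Sketch of the Booster branch of get_model_objective's dispatch.
import lightgbm
import numpy as np

rng = np.random.default_rng(0)
X = rng.random((100, 4))
y = rng.integers(0, 2, size=100)

booster = lightgbm.train({"objective": "binary"}, lightgbm.Dataset(X, label=y), num_boost_round=5)
print(booster.params["objective"])  # "binary" -> ModelObjective.BINARY_CLASSIFICATION
```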
```diff
--- a/snowflake/ml/model/_packager/model_handlers/llm.py
+++ b/snowflake/ml/model/_packager/model_handlers/llm.py
@@ -28,7 +28,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
     _MIN_SNOWPARK_ML_VERSION = "1.0.12"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-
+    MODEL_BLOB_FILE_OR_DIR = "model"
     LLM_META = "llm_meta"
     IS_AUTO_SIGNATURE = True
 
@@ -59,9 +59,12 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
         **kwargs: Unpack[model_types.LLMSaveOptions],
     ) -> None:
         assert not is_sub_model, "LLM can not be sub-model."
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for llm model.")
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        model_blob_dir_path = os.path.join(model_blob_path, cls.
+        model_blob_dir_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
 
         sig = model_signature.ModelSignature(
             inputs=[
@@ -86,7 +89,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
             options=model_meta_schema.LLMModelBlobOptions(
                 {
                     "batch_size": model.max_batch_size,
@@ -143,6 +146,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
         cls,
         raw_model: llm.LLM,
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.LLMLoadOptions],
     ) -> custom_model.CustomModel:
         import gc
```
```diff
--- a/snowflake/ml/model/_packager/model_handlers/mlflow.py
+++ b/snowflake/ml/model/_packager/model_handlers/mlflow.py
@@ -63,7 +63,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
     _MIN_SNOWPARK_ML_VERSION = "1.0.12"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-
+    MODEL_BLOB_FILE_OR_DIR = "model"
     _DEFAULT_TARGET_METHOD = "predict"
     DEFAULT_TARGET_METHODS = [_DEFAULT_TARGET_METHOD]
     IS_AUTO_SIGNATURE = True
@@ -97,6 +97,10 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.MLFlowSaveOptions],
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for MLFlow model.")
+
         import mlflow
 
         assert isinstance(model, mlflow.pyfunc.PyFuncModel)
@@ -142,13 +146,13 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
         except (mlflow.MlflowException, OSError):
             raise ValueError("Cannot load MLFlow model artifacts.")
 
-        file_utils.copy_file_or_tree(local_path, os.path.join(model_blob_path, cls.
+        file_utils.copy_file_or_tree(local_path, os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
 
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
             options=model_meta_schema.MLFlowModelBlobOptions({"artifact_path": model_info.artifact_path}),
         )
         model_meta.models[name] = base_meta
@@ -194,6 +198,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
         cls,
         raw_model: "mlflow.pyfunc.PyFuncModel",
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.MLFlowLoadOptions],
     ) -> custom_model.CustomModel:
         from snowflake.ml.model import custom_model
```
```diff
--- a/snowflake/ml/model/_packager/model_handlers/pytorch.py
+++ b/snowflake/ml/model/_packager/model_handlers/pytorch.py
@@ -37,7 +37,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
     _MIN_SNOWPARK_ML_VERSION = "1.0.12"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-
+    MODEL_BLOB_FILE_OR_DIR = "model.pt"
     DEFAULT_TARGET_METHODS = ["forward"]
 
     @classmethod
@@ -73,6 +73,10 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.PyTorchSaveOptions],
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for PyTorch model.")
+
         import torch
 
         assert isinstance(model, torch.nn.Module)
@@ -115,13 +119,13 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
         cloudpickle.register_pickle_by_value(sys.modules[model.__module__])
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        with open(os.path.join(model_blob_path, cls.
+        with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
             torch.save(model, f, pickle_module=cloudpickle)
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -156,6 +160,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
         cls,
         raw_model: "torch.nn.Module",
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.PyTorchLoadOptions],
     ) -> custom_model.CustomModel:
         import torch
```
```diff
--- a/snowflake/ml/model/_packager/model_handlers/sentence_transformers.py
+++ b/snowflake/ml/model/_packager/model_handlers/sentence_transformers.py
@@ -31,7 +31,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
     _MIN_SNOWPARK_ML_VERSION = "1.3.1"
     _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
 
-
+    MODEL_BLOB_FILE_OR_DIR = "model"
     DEFAULT_TARGET_METHODS = ["encode"]
 
     @classmethod
@@ -64,6 +64,10 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
         is_sub_model: Optional[bool] = False,
         **kwargs: Unpack[model_types.SentenceTransformersSaveOptions],  # registry.log_model(options={...})
     ) -> None:
+        enable_explainability = kwargs.get("enable_explainability", False)
+        if enable_explainability:
+            raise NotImplementedError("Explainability is not supported for Sentence Transformer model.")
+
         # Validate target methods and signature (if possible)
         if not is_sub_model:
             target_methods = handlers_utils.get_target_methods(
@@ -101,14 +105,14 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
         # save model
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
-        model.save(os.path.join(model_blob_path, cls.
+        model.save(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR))
 
         # save model metadata
         base_meta = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             handler_version=cls.HANDLER_VERSION,
-            path=cls.
+            path=cls.MODEL_BLOB_FILE_OR_DIR,
         )
         model_meta.models[name] = base_meta
         model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
@@ -154,6 +158,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
         cls,
         raw_model: "sentence_transformers.SentenceTransformer",
         model_meta: model_meta_api.ModelMetadata,
+        background_data: Optional[pd.DataFrame] = None,
         **kwargs: Unpack[model_types.SentenceTransformersLoadOptions],
     ) -> custom_model.CustomModel:
         import sentence_transformers
```