snowflake-ml-python 1.5.3__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff compares the contents of publicly available package versions as published to their respective public registries. It is provided for informational purposes only.
- snowflake/cortex/__init__.py +4 -1
- snowflake/cortex/_classify_text.py +36 -0
- snowflake/cortex/_complete.py +281 -21
- snowflake/cortex/_extract_answer.py +0 -1
- snowflake/cortex/_sentiment.py +0 -1
- snowflake/cortex/_summarize.py +0 -1
- snowflake/cortex/_translate.py +0 -1
- snowflake/cortex/_util.py +12 -85
- snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
- snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
- snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
- snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
- snowflake/ml/_internal/telemetry.py +38 -2
- snowflake/ml/_internal/utils/identifier.py +14 -0
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
- snowflake/ml/data/_internal/arrow_ingestor.py +228 -0
- snowflake/ml/data/_internal/ingestor_utils.py +58 -0
- snowflake/ml/data/data_connector.py +133 -0
- snowflake/ml/data/data_ingestor.py +28 -0
- snowflake/ml/data/data_source.py +23 -0
- snowflake/ml/dataset/dataset.py +39 -32
- snowflake/ml/dataset/dataset_reader.py +18 -118
- snowflake/ml/feature_store/access_manager.py +7 -1
- snowflake/ml/feature_store/entity.py +19 -2
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +31 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +24 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +4 -0
- snowflake/ml/feature_store/examples/example_helper.py +240 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/dropoff_features.py +39 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +58 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
- snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
- snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +29 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +21 -0
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +5 -0
- snowflake/ml/feature_store/feature_store.py +987 -264
- snowflake/ml/feature_store/feature_view.py +228 -13
- snowflake/ml/fileset/embedded_stage_fs.py +25 -21
- snowflake/ml/fileset/fileset.py +2 -2
- snowflake/ml/fileset/snowfs.py +4 -15
- snowflake/ml/fileset/stage_fs.py +24 -18
- snowflake/ml/lineage/__init__.py +3 -0
- snowflake/ml/lineage/lineage_node.py +139 -0
- snowflake/ml/model/_client/model/model_impl.py +47 -14
- snowflake/ml/model/_client/model/model_version_impl.py +82 -2
- snowflake/ml/model/_client/ops/model_ops.py +77 -5
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +45 -2
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_model_composer/model_composer.py +15 -17
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -17
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
- snowflake/ml/model/_model_composer/model_method/function_generator.py +20 -4
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +55 -0
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +5 -34
- snowflake/ml/model/_model_composer/model_method/model_method.py +10 -7
- snowflake/ml/model/_packager/model_handlers/_base.py +13 -3
- snowflake/ml/model/_packager/model_handlers/_utils.py +59 -1
- snowflake/ml/model/_packager/model_handlers/catboost.py +44 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +70 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
- snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +61 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
- snowflake/ml/model/_packager/model_packager.py +9 -4
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -5
- snowflake/ml/model/custom_model.py +22 -2
- snowflake/ml/model/model_signature.py +4 -4
- snowflake/ml/model/type_hints.py +77 -4
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +3 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +6 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +1 -0
- snowflake/ml/modeling/cluster/affinity_propagation.py +4 -2
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +4 -2
- snowflake/ml/modeling/cluster/birch.py +4 -2
- snowflake/ml/modeling/cluster/bisecting_k_means.py +4 -2
- snowflake/ml/modeling/cluster/dbscan.py +4 -2
- snowflake/ml/modeling/cluster/feature_agglomeration.py +4 -2
- snowflake/ml/modeling/cluster/k_means.py +4 -2
- snowflake/ml/modeling/cluster/mean_shift.py +4 -2
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +4 -2
- snowflake/ml/modeling/cluster/optics.py +4 -2
- snowflake/ml/modeling/cluster/spectral_biclustering.py +4 -2
- snowflake/ml/modeling/cluster/spectral_clustering.py +4 -2
- snowflake/ml/modeling/cluster/spectral_coclustering.py +4 -2
- snowflake/ml/modeling/compose/column_transformer.py +4 -2
- snowflake/ml/modeling/covariance/elliptic_envelope.py +4 -2
- snowflake/ml/modeling/covariance/empirical_covariance.py +4 -2
- snowflake/ml/modeling/covariance/graphical_lasso.py +4 -2
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +4 -2
- snowflake/ml/modeling/covariance/ledoit_wolf.py +4 -2
- snowflake/ml/modeling/covariance/min_cov_det.py +4 -2
- snowflake/ml/modeling/covariance/oas.py +4 -2
- snowflake/ml/modeling/covariance/shrunk_covariance.py +4 -2
- snowflake/ml/modeling/decomposition/dictionary_learning.py +4 -2
- snowflake/ml/modeling/decomposition/factor_analysis.py +4 -2
- snowflake/ml/modeling/decomposition/fast_ica.py +4 -2
- snowflake/ml/modeling/decomposition/incremental_pca.py +4 -2
- snowflake/ml/modeling/decomposition/kernel_pca.py +4 -2
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +4 -2
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +4 -2
- snowflake/ml/modeling/decomposition/pca.py +4 -2
- snowflake/ml/modeling/decomposition/sparse_pca.py +4 -2
- snowflake/ml/modeling/decomposition/truncated_svd.py +4 -2
- snowflake/ml/modeling/ensemble/isolation_forest.py +4 -2
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +4 -2
- snowflake/ml/modeling/feature_selection/variance_threshold.py +4 -2
- snowflake/ml/modeling/impute/iterative_imputer.py +4 -2
- snowflake/ml/modeling/impute/knn_imputer.py +4 -2
- snowflake/ml/modeling/impute/missing_indicator.py +4 -2
- snowflake/ml/modeling/impute/simple_imputer.py +26 -0
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +4 -2
- snowflake/ml/modeling/kernel_approximation/nystroem.py +4 -2
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +4 -2
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +4 -2
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +4 -2
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +4 -2
- snowflake/ml/modeling/manifold/isomap.py +4 -2
- snowflake/ml/modeling/manifold/mds.py +4 -2
- snowflake/ml/modeling/manifold/spectral_embedding.py +4 -2
- snowflake/ml/modeling/manifold/tsne.py +4 -2
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +4 -2
- snowflake/ml/modeling/mixture/gaussian_mixture.py +4 -2
- snowflake/ml/modeling/neighbors/kernel_density.py +4 -2
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +4 -2
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +4 -2
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +4 -2
- snowflake/ml/modeling/pipeline/pipeline.py +5 -4
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +43 -9
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +36 -8
- snowflake/ml/modeling/preprocessing/polynomial_features.py +4 -2
- snowflake/ml/registry/_manager/model_manager.py +16 -3
- snowflake/ml/registry/registry.py +100 -13
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/METADATA +81 -7
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/RECORD +165 -139
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/lineage/data_source.py +0 -10
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/top_level.txt +0 -0
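The LightGBM and XGBoost handler changes below add an optional `explain` target method backed by `shap`, gated by a new `enable_explainability` save option. A hedged usage sketch follows; the registry call and its parameters (`log_model`, `options`, `run(..., function_name=...)`) are assumptions based on the existing `snowflake.ml.registry` API and are not part of this diff.

# Hypothetical sketch: exercising the new enable_explainability save option.
from snowflake.ml.registry import Registry

reg = Registry(session=session)  # assumes an existing Snowpark session
mv = reg.log_model(
    lgbm_model,                               # e.g. a fitted lightgbm.LGBMClassifier
    model_name="MY_LGBM_MODEL",
    version_name="V1",
    sample_input_data=train_df.head(10),
    options={"enable_explainability": True},  # save option introduced in this release
)
predictions = mv.run(test_df, function_name="predict")
explanations = mv.run(test_df, function_name="explain")  # per-feature SHAP values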
snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py

@@ -3,6 +3,7 @@ import os
 import warnings
 from typing import (
     TYPE_CHECKING,
+    Any,
     Callable,
     Dict,
     List,
@@ -250,9 +251,18 @@ class HuggingFacePipelineHandler(
         model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)

     @staticmethod
-    def _get_device_config() -> Dict[str, str]:
-        device_config = {}
-
+    def _get_device_config(**kwargs: Unpack[model_types.HuggingFaceLoadOptions]) -> Dict[str, str]:
+        device_config: Dict[str, Any] = {}
+        if (
+            kwargs.get("use_gpu", False)
+            and kwargs.get("device_map", None) is None
+            and kwargs.get("device", None) is None
+        ):
+            device_config["device_map"] = "auto"
+        elif kwargs.get("device_map", None) is not None:
+            device_config["device_map"] = kwargs["device_map"]
+        elif kwargs.get("device", None) is not None:
+            device_config["device"] = kwargs["device"]

         return device_config

@@ -262,7 +272,7 @@ class HuggingFacePipelineHandler(
         name: str,
         model_meta: model_meta_api.ModelMetadata,
         model_blobs_dir_path: str,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
     ) -> Union[huggingface_pipeline.HuggingFacePipelineModel, "transformers.Pipeline"]:
         if snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
             # We need to redirect the some folders to a writable location in the sandbox.
@@ -292,10 +302,7 @@ class HuggingFacePipelineHandler(
             ) as f:
                 pipeline_params = cloudpickle.load(f)

-
-                device_config = cls._get_device_config()
-            else:
-                device_config = {}
+            device_config = cls._get_device_config(**kwargs)

             m = transformers.pipeline(
                 model_blob_options["task"],
@@ -310,12 +317,8 @@ class HuggingFacePipelineHandler(
             with open(model_blob_file_or_dir_path, "rb") as f:
                 m = cloudpickle.load(f)
             assert isinstance(m, huggingface_pipeline.HuggingFacePipelineModel)
-            if (
-
-                and getattr(m, "device_map", None) is None
-                and kwargs.get("use_gpu", False)
-            ):
-                m.__dict__.update(cls._get_device_config())
+            if getattr(m, "device", None) is None and getattr(m, "device_map", None) is None:
+                m.__dict__.update(cls._get_device_config(**kwargs))

             if getattr(m, "torch_dtype", None) is None and kwargs.get("use_gpu", False):
                 m.__dict__.update(torch_dtype="auto")
@@ -326,7 +329,7 @@ class HuggingFacePipelineHandler(
         cls,
         raw_model: Union[huggingface_pipeline.HuggingFacePipelineModel, "transformers.Pipeline"],
         model_meta: model_meta_api.ModelMetadata,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.HuggingFaceLoadOptions],
     ) -> custom_model.CustomModel:
         import transformers

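The reworked `_get_device_config` above resolves the transformers pipeline device with a fixed precedence: an explicit `device_map` or `device` from the load options wins, otherwise `use_gpu=True` falls back to `device_map="auto"`. A minimal standalone sketch of that precedence (the function name here is illustrative, not the handler's API):

from typing import Any, Dict

def resolve_device_config(use_gpu: bool = False, device_map: Any = None, device: Any = None) -> Dict[str, Any]:
    # Mirrors the branch order in the hunk above.
    config: Dict[str, Any] = {}
    if use_gpu and device_map is None and device is None:
        config["device_map"] = "auto"
    elif device_map is not None:
        config["device_map"] = device_map
    elif device is not None:
        config["device"] = device
    return config

assert resolve_device_config(use_gpu=True) == {"device_map": "auto"}
assert resolve_device_config(use_gpu=True, device=0) == {"device": 0}
assert resolve_device_config() == {}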
snowflake/ml/model/_packager/model_handlers/lightgbm.py

@@ -43,6 +43,45 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb

     MODELE_BLOB_FILE_OR_DIR = "model.pkl"
     DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
+    _BINARY_CLASSIFICATION_OBJECTIVES = ["binary"]
+    _MULTI_CLASSIFICATION_OBJECTIVES = ["multiclass", "multiclassova"]
+    _RANKING_OBJECTIVES = ["lambdarank", "rank_xendcg"]
+    _REGRESSION_OBJECTIVES = [
+        "regression",
+        "regression_l1",
+        "huber",
+        "fair",
+        "poisson",
+        "quantile",
+        "tweedie",
+        "mape",
+        "gamma",
+    ]
+
+    @classmethod
+    def get_model_objective(cls, model: Union["lightgbm.Booster", "lightgbm.LGBMModel"]) -> _base.ModelObjective:
+        import lightgbm
+
+        # does not account for cross-entropy and custom
+        if isinstance(model, lightgbm.LGBMClassifier):
+            num_classes = handlers_utils.get_num_classes_if_exists(model)
+            if num_classes == 2:
+                return _base.ModelObjective.BINARY_CLASSIFICATION
+            return _base.ModelObjective.MULTI_CLASSIFICATION
+        if isinstance(model, lightgbm.LGBMRanker):
+            return _base.ModelObjective.RANKING
+        if isinstance(model, lightgbm.LGBMRegressor):
+            return _base.ModelObjective.REGRESSION
+        model_objective = model.params["objective"]
+        if model_objective in cls._BINARY_CLASSIFICATION_OBJECTIVES:
+            return _base.ModelObjective.BINARY_CLASSIFICATION
+        if model_objective in cls._MULTI_CLASSIFICATION_OBJECTIVES:
+            return _base.ModelObjective.MULTI_CLASSIFICATION
+        if model_objective in cls._RANKING_OBJECTIVES:
+            return _base.ModelObjective.RANKING
+        if model_objective in cls._REGRESSION_OBJECTIVES:
+            return _base.ModelObjective.REGRESSION
+        return _base.ModelObjective.UNKNOWN

     @classmethod
     def can_handle(
@@ -105,6 +144,19 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
             sample_input_data=sample_input_data,
             get_prediction_fn=get_prediction,
         )
+        if kwargs.get("enable_explainability", False):
+            output_type = model_signature.DataType.DOUBLE
+            if cls.get_model_objective(model) in [
+                _base.ModelObjective.BINARY_CLASSIFICATION,
+                _base.ModelObjective.MULTI_CLASSIFICATION,
+            ]:
+                output_type = model_signature.DataType.STRING
+            model_meta = handlers_utils.add_explain_method_signature(
+                model_meta=model_meta,
+                explain_method="explain",
+                target_method="predict",
+                output_return_type=output_type,
+            )

         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
@@ -130,6 +182,11 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
             ],
             check_local_version=True,
         )
+        if kwargs.get("enable_explainability", False):
+            model_meta.env.include_if_absent(
+                [model_env.ModelDependency(requirement="shap", pip_name="shap")],
+                check_local_version=True,
+            )

         return None

@@ -139,7 +196,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
         name: str,
         model_meta: model_meta_api.ModelMetadata,
         model_blobs_dir_path: str,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.LGBMModelLoadOptions],
     ) -> Union["lightgbm.Booster", "lightgbm.LGBMModel"]:
         import lightgbm

@@ -169,7 +226,7 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
         cls,
         raw_model: Union["lightgbm.Booster", "lightgbm.XGBModel"],
         model_meta: model_meta_api.ModelMetadata,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.LGBMModelLoadOptions],
     ) -> custom_model.CustomModel:
         import lightgbm

@@ -198,6 +255,17 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb

             return model_signature_utils.rename_pandas_df(df, signature.outputs)

+        @custom_model.inference_api
+        def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+            import shap
+
+            explainer = shap.TreeExplainer(raw_model)
+            df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
+            return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
+        if target_method == "explain":
+            return explain_fn
+
         return fn

     type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
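The `explain_fn` added to the LightGBM handler above is a thin wrapper around `shap.TreeExplainer`. A minimal standalone sketch of the equivalent computation, outside the handler plumbing (the dataset and model below are illustrative only):

import lightgbm
import pandas as pd
import shap
from sklearn.datasets import load_breast_cancer

X, y = load_breast_cancer(return_X_y=True, as_frame=True)
model = lightgbm.LGBMClassifier(n_estimators=10).fit(X, y)

# Same explainer the handler uses; explainer(X).values holds one SHAP value per
# input feature (with an extra class axis for multi-class models), which the
# handler then flattens into a 2D DataFrame.
explainer = shap.TreeExplainer(model)
shap_values = explainer(X).values
explanations = pd.DataFrame(shap_values.reshape(shap_values.shape[0], -1))
print(explanations.shape)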
snowflake/ml/model/_packager/model_handlers/llm.py

@@ -118,7 +118,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
         name: str,
         model_meta: model_meta_api.ModelMetadata,
         model_blobs_dir_path: str,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.LLMLoadOptions],
     ) -> llm.LLM:
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         if not hasattr(model_meta, "models"):
@@ -143,7 +143,7 @@ class LLMHandler(_base.BaseModelHandler[llm.LLM]):
         cls,
         raw_model: llm.LLM,
         model_meta: model_meta_api.ModelMetadata,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.LLMLoadOptions],
     ) -> custom_model.CustomModel:
         import gc
         import tempfile
snowflake/ml/model/_packager/model_handlers/mlflow.py

@@ -160,7 +160,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
         name: str,
         model_meta: model_meta_api.ModelMetadata,
         model_blobs_dir_path: str,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.MLFlowLoadOptions],
     ) -> "mlflow.pyfunc.PyFuncModel":
         import mlflow

@@ -194,7 +194,7 @@ class MLFlowHandler(_base.BaseModelHandler["mlflow.pyfunc.PyFuncModel"]):
         cls,
         raw_model: "mlflow.pyfunc.PyFuncModel",
         model_meta: model_meta_api.ModelMetadata,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.MLFlowLoadOptions],
     ) -> custom_model.CustomModel:
         from snowflake.ml.model import custom_model

snowflake/ml/model/_packager/model_handlers/pytorch.py

@@ -137,7 +137,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
         name: str,
         model_meta: model_meta_api.ModelMetadata,
         model_blobs_dir_path: str,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.PyTorchLoadOptions],
     ) -> "torch.nn.Module":
         import torch

@@ -156,7 +156,7 @@ class PyTorchHandler(_base.BaseModelHandler["torch.nn.Module"]):
         cls,
         raw_model: "torch.nn.Module",
         model_meta: model_meta_api.ModelMetadata,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.PyTorchLoadOptions],
     ) -> custom_model.CustomModel:
         import torch

snowflake/ml/model/_packager/model_handlers/sentence_transformers.py

@@ -126,7 +126,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
         name: str,
         model_meta: model_meta_api.ModelMetadata,
         model_blobs_dir_path: str,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.SentenceTransformersLoadOptions],  # use_gpu
     ) -> "sentence_transformers.SentenceTransformer":
         import sentence_transformers

@@ -154,7 +154,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
         cls,
         raw_model: "sentence_transformers.SentenceTransformer",
         model_meta: model_meta_api.ModelMetadata,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.SentenceTransformersLoadOptions],
     ) -> custom_model.CustomModel:
         import sentence_transformers

snowflake/ml/model/_packager/model_handlers/sklearn.py

@@ -133,7 +133,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
         name: str,
         model_meta: model_meta_api.ModelMetadata,
         model_blobs_dir_path: str,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.SKLModelLoadOptions],
     ) -> Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"]:
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         model_blobs_metadata = model_meta.models
@@ -153,7 +153,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
         cls,
         raw_model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
         model_meta: model_meta_api.ModelMetadata,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.SKLModelLoadOptions],
     ) -> custom_model.CustomModel:
         from snowflake.ml.model import custom_model

snowflake/ml/model/_packager/model_handlers/snowmlmodel.py

@@ -127,7 +127,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         name: str,
         model_meta: model_meta_api.ModelMetadata,
         model_blobs_dir_path: str,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.SNOWModelLoadOptions],
     ) -> "BaseEstimator":
         model_blob_path = os.path.join(model_blobs_dir_path, name)
         model_blobs_metadata = model_meta.models
@@ -146,7 +146,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
         cls,
         raw_model: "BaseEstimator",
         model_meta: model_meta_api.ModelMetadata,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.SNOWModelLoadOptions],
     ) -> custom_model.CustomModel:
         from snowflake.ml.model import custom_model

snowflake/ml/model/_packager/model_handlers/tensorflow.py

@@ -138,7 +138,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         name: str,
         model_meta: model_meta_api.ModelMetadata,
         model_blobs_dir_path: str,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.TensorflowLoadOptions],
     ) -> "tensorflow.Module":
         import tensorflow

@@ -156,7 +156,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
         cls,
         raw_model: "tensorflow.Module",
         model_meta: model_meta_api.ModelMetadata,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.TensorflowLoadOptions],
     ) -> custom_model.CustomModel:
         import tensorflow

snowflake/ml/model/_packager/model_handlers/torchscript.py

@@ -128,7 +128,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):  # t
         name: str,
         model_meta: model_meta_api.ModelMetadata,
         model_blobs_dir_path: str,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.TorchScriptLoadOptions],
     ) -> "torch.jit.ScriptModule":  # type:ignore[name-defined]
         import torch

@@ -152,7 +152,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]):  # t
         cls,
         raw_model: "torch.jit.ScriptModule",  # type:ignore[name-defined]
         model_meta: model_meta_api.ModelMetadata,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.TorchScriptLoadOptions],
     ) -> custom_model.CustomModel:
         from snowflake.ml.model import custom_model

snowflake/ml/model/_packager/model_handlers/xgboost.py

@@ -1,4 +1,5 @@
 # mypy: disable-error-code="import"
+import json
 import os
 from typing import (
     TYPE_CHECKING,
@@ -46,6 +47,39 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X

     MODELE_BLOB_FILE_OR_DIR = "model.ubj"
     DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
+    _BINARY_CLASSIFICATION_OBJECTIVE_PREFIX = ["binary:"]
+    _MULTI_CLASSIFICATION_OBJECTIVE_PREFIX = ["multi:"]
+    _RANKING_OBJECTIVE_PREFIX = ["rank:"]
+    _REGRESSION_OBJECTIVE_PREFIX = ["reg:"]
+
+    @classmethod
+    def get_model_objective(cls, model: Union["xgboost.Booster", "xgboost.XGBModel"]) -> _base.ModelObjective:
+        import xgboost
+
+        if isinstance(model, xgboost.XGBClassifier) or isinstance(model, xgboost.XGBRFClassifier):
+            num_classes = handlers_utils.get_num_classes_if_exists(model)
+            if num_classes == 2:
+                return _base.ModelObjective.BINARY_CLASSIFICATION
+            return _base.ModelObjective.MULTI_CLASSIFICATION
+        if isinstance(model, xgboost.XGBRegressor) or isinstance(model, xgboost.XGBRFRegressor):
+            return _base.ModelObjective.REGRESSION
+        if isinstance(model, xgboost.XGBRanker):
+            return _base.ModelObjective.RANKING
+        model_params = json.loads(model.save_config())
+        model_objective = model_params["learner"]["objective"]
+        for classification_objective in cls._BINARY_CLASSIFICATION_OBJECTIVE_PREFIX:
+            if classification_objective in model_objective:
+                return _base.ModelObjective.BINARY_CLASSIFICATION
+        for classification_objective in cls._MULTI_CLASSIFICATION_OBJECTIVE_PREFIX:
+            if classification_objective in model_objective:
+                return _base.ModelObjective.MULTI_CLASSIFICATION
+        for ranking_objective in cls._RANKING_OBJECTIVE_PREFIX:
+            if ranking_objective in model_objective:
+                return _base.ModelObjective.RANKING
+        for regression_objective in cls._REGRESSION_OBJECTIVE_PREFIX:
+            if regression_objective in model_objective:
+                return _base.ModelObjective.REGRESSION
+        return _base.ModelObjective.UNKNOWN

     @classmethod
     def can_handle(
@@ -112,6 +146,16 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
             sample_input_data=sample_input_data,
             get_prediction_fn=get_prediction,
         )
+        if kwargs.get("enable_explainability", False):
+            output_type = model_signature.DataType.DOUBLE
+            if cls.get_model_objective(model) == _base.ModelObjective.MULTI_CLASSIFICATION:
+                output_type = model_signature.DataType.STRING
+            model_meta = handlers_utils.add_explain_method_signature(
+                model_meta=model_meta,
+                explain_method="explain",
+                target_method="predict",
+                output_return_type=output_type,
+            )

         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
@@ -133,6 +177,11 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
             ],
             check_local_version=True,
         )
+        if kwargs.get("enable_explainability", False):
+            model_meta.env.include_if_absent(
+                [model_env.ModelDependency(requirement="shap", pip_name="shap")],
+                check_local_version=True,
+            )
         model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)

     @classmethod
@@ -141,7 +190,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
         name: str,
         model_meta: model_meta_api.ModelMetadata,
         model_blobs_dir_path: str,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.XGBModelLoadOptions],
     ) -> Union["xgboost.Booster", "xgboost.XGBModel"]:
         import xgboost

@@ -175,7 +224,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
         cls,
         raw_model: Union["xgboost.Booster", "xgboost.XGBModel"],
         model_meta: model_meta_api.ModelMetadata,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.XGBModelLoadOptions],
     ) -> custom_model.CustomModel:
         import xgboost

@@ -206,6 +255,16 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X

             return model_signature_utils.rename_pandas_df(df, signature.outputs)

+        @custom_model.inference_api
+        def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+            import shap
+
+            explainer = shap.TreeExplainer(raw_model)
+            df = pd.DataFrame(explainer(X).values)
+            return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
+        if target_method == "explain":
+            return explain_fn
         return fn

     type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
snowflake/ml/model/_packager/model_meta/model_blob_meta.py

@@ -23,6 +23,7 @@ class ModelBlobMeta:
         self.model_type = kwargs["model_type"]
         self.path = kwargs["path"]
         self.handler_version = kwargs["handler_version"]
+        self.function_properties = kwargs.get("function_properties", {})

         self.artifacts: Dict[str, str] = {}
         artifacts = kwargs.get("artifacts", None)
@@ -39,6 +40,7 @@ class ModelBlobMeta:
             model_type=self.model_type,
             path=self.path,
             handler_version=self.handler_version,
+            function_properties=self.function_properties,
             artifacts=self.artifacts,
             options=self.options,
         )
snowflake/ml/model/_packager/model_meta/model_meta.py

@@ -7,11 +7,12 @@ import zipfile
 from contextlib import contextmanager
 from datetime import datetime
 from types import ModuleType
-from typing import Any, Dict, Generator, List, Optional
+from typing import Any, Dict, Generator, List, Optional, TypedDict

 import cloudpickle
 import yaml
 from packaging import requirements, version
+from typing_extensions import Required

 from snowflake.ml._internal import env as snowml_env, env_utils, file_utils
 from snowflake.ml.model import model_signature, type_hints as model_types
@@ -47,6 +48,7 @@ def create_model_metadata(
     name: str,
     model_type: model_types.SupportedModelHandlerType,
     signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
+    function_properties: Optional[Dict[str, Dict[str, Any]]] = None,
     metadata: Optional[Dict[str, str]] = None,
     code_paths: Optional[List[str]] = None,
     ext_modules: Optional[List[ModuleType]] = None,
@@ -64,6 +66,7 @@ def create_model_metadata(
         model_type: Type of the model.
         signatures: Signatures of the model. If None, it will be inferred after the model meta is created.
             Defaults to None.
+        function_properties: Dict mapping function names to a dict of properties, mapping property key to value.
         metadata: User provided key-value metadata of the model. Defaults to None.
         code_paths: List of paths to additional codes that needs to be packed with. Defaults to None.
         ext_modules: List of names of modules that need to be pickled with the model. Defaults to None.
@@ -127,6 +130,7 @@ def create_model_metadata(
         metadata=metadata,
         model_type=model_type,
         signatures=signatures,
+        function_properties=function_properties,
     )

     code_dir_path = os.path.join(model_dir_path, MODEL_CODE_DIR)
@@ -215,6 +219,12 @@ def load_code_path(model_dir_path: str) -> None:
         sys.path.insert(0, code_path)


+class ModelMetadataTelemetryDict(TypedDict):
+    model_name: Required[str]
+    framework_type: Required[model_types.SupportedModelHandlerType]
+    number_of_functions: Required[int]
+
+
 class ModelMetadata:
     """Model metadata for Snowflake native model packaged model.

@@ -224,10 +234,18 @@ class ModelMetadata:
         env: ModelEnv object containing all environment related object
         models: Dict of model blob metadata
         signatures: A dict mapping from target function name to input and output signatures.
+        function_properties: A dict mapping function names to dict mapping function property key to value.
         metadata: User provided key-value metadata of the model. Defaults to None.
         creation_timestamp: Unix timestamp when the model metadata is created.
     """

+    def telemetry_metadata(self) -> ModelMetadataTelemetryDict:
+        return ModelMetadataTelemetryDict(
+            model_name=self.name,
+            framework_type=self.model_type,
+            number_of_functions=len(self.signatures.keys()),
+        )
+
     def __init__(
         self,
         *,
@@ -236,6 +254,7 @@ class ModelMetadata:
         model_type: model_types.SupportedModelHandlerType,
         runtimes: Optional[Dict[str, model_runtime.ModelRuntime]] = None,
         signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
+        function_properties: Optional[Dict[str, Dict[str, Any]]] = None,
         metadata: Optional[Dict[str, str]] = None,
         creation_timestamp: Optional[str] = None,
         min_snowpark_ml_version: Optional[str] = None,
@@ -246,6 +265,7 @@ class ModelMetadata:
         self.signatures: Dict[str, model_signature.ModelSignature] = dict()
         if signatures:
             self.signatures = signatures
+        self.function_properties = function_properties or {}
         self.metadata = metadata
         self.model_type = model_type
         self.env = env
snowflake/ml/model/_packager/model_meta/model_meta_schema.py

@@ -1,6 +1,6 @@
 # This files contains schema definition of what will be written into model.yml
 # Changing this file should lead to a change of the schema version.
-
+from enum import Enum
 from typing import Any, Dict, List, Optional, TypedDict, Union

 from typing_extensions import NotRequired, Required
@@ -11,6 +11,10 @@ MODEL_METADATA_VERSION = "2023-12-01"
 MODEL_METADATA_MIN_SNOWPARK_ML_VERSION = "1.0.12"


+class FunctionProperties(Enum):
+    PARTITIONED = "PARTITIONED"
+
+
 class ModelRuntimeDependenciesDict(TypedDict):
     conda: Required[str]
     pip: Required[str]
@@ -72,6 +76,7 @@ class ModelBlobMetadataDict(TypedDict):
     model_type: Required[type_hints.SupportedModelHandlerType]
     path: Required[str]
     handler_version: Required[str]
+    function_properties: NotRequired[Dict[str, Dict[str, Any]]]
     artifacts: NotRequired[Dict[str, str]]
     options: NotRequired[ModelBlobOptions]

snowflake/ml/model/_packager/model_packager.py

@@ -47,12 +47,12 @@ class ModelPackager:
         ext_modules: Optional[List[ModuleType]] = None,
         code_paths: Optional[List[str]] = None,
         options: Optional[model_types.ModelSaveOption] = None,
-    ) ->
+    ) -> model_meta.ModelMetadata:
         if (signatures is None) and (sample_input_data is None) and not model_handler.is_auto_signature_model(model):
             raise snowml_exceptions.SnowflakeMLException(
                 error_code=error_codes.INVALID_ARGUMENT,
                 original_exception=ValueError(
-                    "
+                    "Either of `signatures` or `sample_input_data` must be provided for this kind of model."
                 ),
             )

@@ -103,6 +103,7 @@ class ModelPackager:

         self.model = model
         self.meta = meta
+        return meta

     def load(
         self,
@@ -110,7 +111,7 @@ class ModelPackager:
         meta_only: bool = False,
         as_custom_model: bool = False,
         options: Optional[model_types.ModelLoadOption] = None,
-    ) ->
+    ) -> model_meta.ModelMetadata:
         """Load the model into memory from directory. Used internal only.

         Args:
@@ -120,11 +121,14 @@ class ModelPackager:

         Raises:
             SnowflakeMLException: Raised if model is not native format.
+
+        Returns:
+            Metadata of loaded model.
         """

         self.meta = model_meta.ModelMetadata.load(self.local_dir_path)
         if meta_only:
-            return
+            return self.meta

         model_meta.load_code_path(self.local_dir_path)

@@ -146,3 +150,4 @@ class ModelPackager:
         assert isinstance(m, custom_model.CustomModel)

         self.model = m
+        return self.meta