snowflake-ml-python 1.5.3__py3-none-any.whl → 1.6.0__py3-none-any.whl
This diff shows the content of publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +4 -1
- snowflake/cortex/_classify_text.py +36 -0
- snowflake/cortex/_complete.py +281 -21
- snowflake/cortex/_extract_answer.py +0 -1
- snowflake/cortex/_sentiment.py +0 -1
- snowflake/cortex/_summarize.py +0 -1
- snowflake/cortex/_translate.py +0 -1
- snowflake/cortex/_util.py +12 -85
- snowflake/ml/_internal/container_services/image_registry/http_client.py +10 -3
- snowflake/ml/_internal/container_services/image_registry/imagelib.py +23 -10
- snowflake/ml/_internal/container_services/image_registry/registry_client.py +7 -1
- snowflake/ml/_internal/exceptions/dataset_errors.py +7 -7
- snowflake/ml/_internal/exceptions/fileset_errors.py +3 -3
- snowflake/ml/_internal/exceptions/sql_error_codes.py +6 -0
- snowflake/ml/_internal/lineage/lineage_utils.py +4 -4
- snowflake/ml/_internal/telemetry.py +38 -2
- snowflake/ml/_internal/utils/identifier.py +14 -0
- snowflake/ml/_internal/utils/snowpark_dataframe_utils.py +15 -4
- snowflake/ml/data/_internal/arrow_ingestor.py +228 -0
- snowflake/ml/data/_internal/ingestor_utils.py +58 -0
- snowflake/ml/data/data_connector.py +133 -0
- snowflake/ml/data/data_ingestor.py +28 -0
- snowflake/ml/data/data_source.py +23 -0
- snowflake/ml/dataset/dataset.py +39 -32
- snowflake/ml/dataset/dataset_reader.py +18 -118
- snowflake/ml/feature_store/access_manager.py +7 -1
- snowflake/ml/feature_store/entity.py +19 -2
- snowflake/ml/feature_store/examples/citibike_trip_features/entities.py +20 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +31 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +24 -0
- snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +4 -0
- snowflake/ml/feature_store/examples/example_helper.py +240 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +12 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/dropoff_features.py +39 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +58 -0
- snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -0
- snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +36 -0
- snowflake/ml/feature_store/examples/source_data/fraud_transactions.yaml +29 -0
- snowflake/ml/feature_store/examples/source_data/nyc_yellow_trips.yaml +4 -0
- snowflake/ml/feature_store/examples/source_data/winequality_red.yaml +32 -0
- snowflake/ml/feature_store/examples/wine_quality_features/entities.py +14 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +29 -0
- snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +21 -0
- snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +5 -0
- snowflake/ml/feature_store/feature_store.py +987 -264
- snowflake/ml/feature_store/feature_view.py +228 -13
- snowflake/ml/fileset/embedded_stage_fs.py +25 -21
- snowflake/ml/fileset/fileset.py +2 -2
- snowflake/ml/fileset/snowfs.py +4 -15
- snowflake/ml/fileset/stage_fs.py +24 -18
- snowflake/ml/lineage/__init__.py +3 -0
- snowflake/ml/lineage/lineage_node.py +139 -0
- snowflake/ml/model/_client/model/model_impl.py +47 -14
- snowflake/ml/model/_client/model/model_version_impl.py +82 -2
- snowflake/ml/model/_client/ops/model_ops.py +77 -5
- snowflake/ml/model/_client/sql/model.py +1 -0
- snowflake/ml/model/_client/sql/model_version.py +45 -2
- snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
- snowflake/ml/model/_model_composer/model_composer.py +15 -17
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -17
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -1
- snowflake/ml/model/_model_composer/model_method/function_generator.py +20 -4
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +3 -32
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +55 -0
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +5 -34
- snowflake/ml/model/_model_composer/model_method/model_method.py +10 -7
- snowflake/ml/model/_packager/model_handlers/_base.py +13 -3
- snowflake/ml/model/_packager/model_handlers/_utils.py +59 -1
- snowflake/ml/model/_packager/model_handlers/catboost.py +44 -2
- snowflake/ml/model/_packager/model_handlers/custom.py +12 -4
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +18 -15
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +70 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +2 -2
- snowflake/ml/model/_packager/model_handlers/mlflow.py +2 -2
- snowflake/ml/model/_packager/model_handlers/pytorch.py +2 -2
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +2 -2
- snowflake/ml/model/_packager/model_handlers/sklearn.py +2 -2
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +2 -2
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +2 -2
- snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
- snowflake/ml/model/_packager/model_handlers/xgboost.py +61 -2
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_meta/model_blob_meta.py +2 -0
- snowflake/ml/model/_packager/model_meta/model_meta.py +21 -1
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
- snowflake/ml/model/_packager/model_packager.py +9 -4
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +1 -1
- snowflake/ml/model/_packager/model_runtime/model_runtime.py +3 -5
- snowflake/ml/model/custom_model.py +22 -2
- snowflake/ml/model/model_signature.py +4 -4
- snowflake/ml/model/type_hints.py +77 -4
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +3 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_search_udf_file.py +13 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +1 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +6 -0
- snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +1 -0
- snowflake/ml/modeling/cluster/affinity_propagation.py +4 -2
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +4 -2
- snowflake/ml/modeling/cluster/birch.py +4 -2
- snowflake/ml/modeling/cluster/bisecting_k_means.py +4 -2
- snowflake/ml/modeling/cluster/dbscan.py +4 -2
- snowflake/ml/modeling/cluster/feature_agglomeration.py +4 -2
- snowflake/ml/modeling/cluster/k_means.py +4 -2
- snowflake/ml/modeling/cluster/mean_shift.py +4 -2
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +4 -2
- snowflake/ml/modeling/cluster/optics.py +4 -2
- snowflake/ml/modeling/cluster/spectral_biclustering.py +4 -2
- snowflake/ml/modeling/cluster/spectral_clustering.py +4 -2
- snowflake/ml/modeling/cluster/spectral_coclustering.py +4 -2
- snowflake/ml/modeling/compose/column_transformer.py +4 -2
- snowflake/ml/modeling/covariance/elliptic_envelope.py +4 -2
- snowflake/ml/modeling/covariance/empirical_covariance.py +4 -2
- snowflake/ml/modeling/covariance/graphical_lasso.py +4 -2
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +4 -2
- snowflake/ml/modeling/covariance/ledoit_wolf.py +4 -2
- snowflake/ml/modeling/covariance/min_cov_det.py +4 -2
- snowflake/ml/modeling/covariance/oas.py +4 -2
- snowflake/ml/modeling/covariance/shrunk_covariance.py +4 -2
- snowflake/ml/modeling/decomposition/dictionary_learning.py +4 -2
- snowflake/ml/modeling/decomposition/factor_analysis.py +4 -2
- snowflake/ml/modeling/decomposition/fast_ica.py +4 -2
- snowflake/ml/modeling/decomposition/incremental_pca.py +4 -2
- snowflake/ml/modeling/decomposition/kernel_pca.py +4 -2
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +4 -2
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +4 -2
- snowflake/ml/modeling/decomposition/pca.py +4 -2
- snowflake/ml/modeling/decomposition/sparse_pca.py +4 -2
- snowflake/ml/modeling/decomposition/truncated_svd.py +4 -2
- snowflake/ml/modeling/ensemble/isolation_forest.py +4 -2
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +4 -2
- snowflake/ml/modeling/feature_selection/variance_threshold.py +4 -2
- snowflake/ml/modeling/impute/iterative_imputer.py +4 -2
- snowflake/ml/modeling/impute/knn_imputer.py +4 -2
- snowflake/ml/modeling/impute/missing_indicator.py +4 -2
- snowflake/ml/modeling/impute/simple_imputer.py +26 -0
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +4 -2
- snowflake/ml/modeling/kernel_approximation/nystroem.py +4 -2
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +4 -2
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +4 -2
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +4 -2
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +4 -2
- snowflake/ml/modeling/manifold/isomap.py +4 -2
- snowflake/ml/modeling/manifold/mds.py +4 -2
- snowflake/ml/modeling/manifold/spectral_embedding.py +4 -2
- snowflake/ml/modeling/manifold/tsne.py +4 -2
- snowflake/ml/modeling/metrics/ranking.py +3 -0
- snowflake/ml/modeling/metrics/regression.py +3 -0
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +4 -2
- snowflake/ml/modeling/mixture/gaussian_mixture.py +4 -2
- snowflake/ml/modeling/neighbors/kernel_density.py +4 -2
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +4 -2
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +4 -2
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +4 -2
- snowflake/ml/modeling/pipeline/pipeline.py +5 -4
- snowflake/ml/modeling/preprocessing/one_hot_encoder.py +43 -9
- snowflake/ml/modeling/preprocessing/ordinal_encoder.py +36 -8
- snowflake/ml/modeling/preprocessing/polynomial_features.py +4 -2
- snowflake/ml/registry/_manager/model_manager.py +16 -3
- snowflake/ml/registry/registry.py +100 -13
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/METADATA +81 -7
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/RECORD +165 -139
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/lineage/data_source.py +0 -10
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.5.3.dist-info → snowflake_ml_python-1.6.0.dist-info}/top_level.txt +0 -0

snowflake/ml/model/_model_composer/model_manifest/model_manifest.py

@@ -1,19 +1,22 @@
 import collections
 import copy
 import pathlib
+import warnings
 from typing import List, Optional, cast

 import yaml

-from snowflake.ml.
+from snowflake.ml.data import data_source
 from snowflake.ml.model import type_hints
 from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
 from snowflake.ml.model._model_composer.model_method import (
     function_generator,
     model_method,
 )
-from snowflake.ml.model._packager.model_meta import
-
+from snowflake.ml.model._packager.model_meta import (
+    model_meta as model_meta_api,
+    model_meta_schema,
+)


 class ModelManifest:
@@ -33,9 +36,8 @@ class ModelManifest:

     def save(
         self,
-        session: Session,
         model_meta: model_meta_api.ModelMetadata,
-
+        model_rel_path: pathlib.PurePosixPath,
         options: Optional[type_hints.ModelSaveOption] = None,
         data_sources: Optional[List[data_source.DataSource]] = None,
     ) -> None:
@@ -44,10 +46,10 @@ class ModelManifest:

         runtime_to_use = copy.deepcopy(model_meta.runtimes["cpu"])
         runtime_to_use.name = self._DEFAULT_RUNTIME_NAME
-        runtime_to_use.imports.append(
+        runtime_to_use.imports.append(str(model_rel_path) + "/")
         runtime_dict = runtime_to_use.save(self.workspace_path)

-        self.function_generator = function_generator.FunctionGenerator(
+        self.function_generator = function_generator.FunctionGenerator(model_dir_rel_path=model_rel_path)
         self.methods: List[model_method.ModelMethod] = []
         for target_method in model_meta.signatures.keys():
             method = model_method.ModelMethod(
@@ -55,6 +57,9 @@ class ModelManifest:
                 target_method=target_method,
                 runtime_name=self._DEFAULT_RUNTIME_NAME,
                 function_generator=self.function_generator,
+                is_partitioned_function=model_meta.function_properties.get(target_method, {}).get(
+                    model_meta_schema.FunctionProperties.PARTITIONED.value, False
+                ),
                 options=model_method.get_model_method_options_from_options(options, target_method),
             )

@@ -69,6 +74,16 @@ class ModelManifest:
                 "In this case, set case_sensitive as True for those methods to distinguish them."
             )

+        dependencies = model_manifest_schema.ModelRuntimeDependenciesDict(conda=runtime_dict["dependencies"]["conda"])
+        if options.get("include_pip_dependencies"):
+            warnings.warn(
+                "`include_pip_dependencies` specified as True: pip dependencies will be included and may not"
+                "be warehouse-compabible. The model may need to be run in SPCS.",
+                category=UserWarning,
+                stacklevel=1,
+            )
+            dependencies["pip"] = runtime_dict["dependencies"]["pip"]
+
         manifest_dict = model_manifest_schema.ModelManifestDict(
             manifest_version=model_manifest_schema.MODEL_MANIFEST_VERSION,
             runtimes={
@@ -76,9 +91,7 @@ class ModelManifest:
                     language="PYTHON",
                     version=runtime_to_use.runtime_env.python_version,
                     imports=runtime_dict["imports"],
-                    dependencies=
-                        conda=runtime_dict["dependencies"]["conda"]
-                    ),
+                    dependencies=dependencies,
                 )
             },
             methods=[
@@ -121,12 +134,13 @@ class ModelManifest:
         result = []
         if data_sources:
             for source in data_sources:
-
-
-
-
-
-
+                if isinstance(source, data_source.DatasetInfo):
+                    result.append(
+                        model_manifest_schema.LineageSourceDict(
+                            # Currently, we only support lineage from Dataset.
+                            type=model_manifest_schema.LineageSourceTypes.DATASET.value,
+                            entity=source.fully_qualified_name,
+                            version=source.version,
+                        )
                     )
-                )
         return result
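
Illustrative note (not part of the diff): the reworked save() builds the runtime dependencies mapping before writing the manifest, always recording the conda environment and adding pip requirements only when include_pip_dependencies is set, along with a warning that such a model may only run in SPCS. A minimal standalone sketch of that decision, using a hypothetical runtime_dict in place of the one returned by runtime_to_use.save():

import warnings

# Hypothetical runtime_dict standing in for the one produced by runtime_to_use.save();
# only the keys used above ("imports", "dependencies") are modeled here.
runtime_dict = {
    "imports": ["model/"],
    "dependencies": {"conda": "runtimes/cpu/env/conda.yml", "pip": "runtimes/cpu/env/requirements.txt"},
}

def build_dependencies(runtime_dict: dict, include_pip_dependencies: bool = False) -> dict:
    # Conda dependencies are always recorded in the manifest runtime.
    dependencies = {"conda": runtime_dict["dependencies"]["conda"]}
    if include_pip_dependencies:
        warnings.warn(
            "Including pip dependencies; the model may not be warehouse-compatible and may need to run in SPCS.",
            category=UserWarning,
        )
        dependencies["pip"] = runtime_dict["dependencies"]["pip"]
    return dependencies

print(build_dependencies(runtime_dict))                                 # {'conda': 'runtimes/cpu/env/conda.yml'}
print(build_dependencies(runtime_dict, include_pip_dependencies=True))  # adds the 'pip' entry as well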

snowflake/ml/model/_model_composer/model_method/function_generator.py

@@ -3,7 +3,14 @@ from typing import Optional, TypedDict

 from typing_extensions import NotRequired

+from snowflake.ml._internal.exceptions import (
+    error_codes,
+    exceptions as snowml_exceptions,
+)
 from snowflake.ml.model import type_hints
+from snowflake.ml.model._model_composer.model_manifest.model_manifest_schema import (
+    ModelMethodFunctionTypes,
+)


 class FunctionGenerateOptions(TypedDict):
@@ -26,15 +33,16 @@ class FunctionGenerator:

     def __init__(
         self,
-
+        model_dir_rel_path: pathlib.PurePosixPath,
     ) -> None:
-        self.
+        self.model_dir_rel_path = model_dir_rel_path

     def generate(
         self,
         function_file_path: pathlib.Path,
         target_method: str,
         function_type: str,
+        is_partitioned_function: bool = False,
         options: Optional[FunctionGenerateOptions] = None,
     ) -> None:
         import importlib_resources
@@ -42,7 +50,15 @@ class FunctionGenerator:
         if options is None:
             options = {}

-
+        if is_partitioned_function:
+            if function_type != ModelMethodFunctionTypes.TABLE_FUNCTION.value:
+                raise snowml_exceptions.SnowflakeMLException(
+                    error_code=error_codes.INVALID_DATA,
+                    original_exception=ValueError("Partitioned inference api functions must have type TABLE_FUNCTION."),
+                )
+            template_filename = "infer_partitioned.py_template"
+        else:
+            template_filename = f"infer_{function_type.lower()}.py_template"

         function_template = (
             importlib_resources.files("snowflake.ml.model._model_composer.model_method")
@@ -51,7 +67,7 @@ class FunctionGenerator:
         )

         udf_code = function_template.format(
-
+            model_dir_name=self.model_dir_rel_path.name,
             target_method=target_method,
             max_batch_size=options.get("max_batch_size", None),
             function_name=FunctionGenerator.FUNCTION_NAME,

snowflake/ml/model/_model_composer/model_method/infer_function.py_template

@@ -1,12 +1,7 @@
-import fcntl
 import functools
 import inspect
 import os
 import sys
-import threading
-import zipfile
-from types import TracebackType
-from typing import Optional, Type

 import anyio
 import pandas as pd
@@ -15,42 +10,18 @@ from _snowflake import vectorized
 from snowflake.ml.model._packager import model_packager


-class FileLock:
-    def __enter__(self) -> None:
-        self._lock = threading.Lock()
-        self._lock.acquire()
-        self._fd = open("/tmp/lockfile.LOCK", "w+")
-        fcntl.lockf(self._fd, fcntl.LOCK_EX)
-
-    def __exit__(
-        self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]
-    ) -> None:
-        self._fd.close()
-        self._lock.release()
-
-
 # User-defined parameters
-
+MODEL_DIR_REL_PATH = "{model_dir_name}"
 TARGET_METHOD = "{target_method}"
 MAX_BATCH_SIZE = {max_batch_size}

-
 # Retrieve the model
 IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
 import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
-
-model_dir_name = os.path.splitext(MODEL_FILE_NAME)[0]
-zip_model_path = os.path.join(import_dir, MODEL_FILE_NAME)
-extracted = "/tmp/models"
-extracted_model_dir_path = os.path.join(extracted, model_dir_name)
-
-with FileLock():
-    if not os.path.isdir(extracted_model_dir_path):
-        with zipfile.ZipFile(zip_model_path, "r") as myzip:
-            myzip.extractall(extracted_model_dir_path)
+model_dir_path = os.path.join(import_dir, MODEL_DIR_REL_PATH)

 # Load the model
-pk = model_packager.ModelPackager(
+pk = model_packager.ModelPackager(model_dir_path)
 pk.load(as_custom_model=True)
 assert pk.model, "model is not loaded"
 assert pk.meta, "model metadata is not loaded"

snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template (new file)

@@ -0,0 +1,55 @@
+import fcntl
+import functools
+import inspect
+import os
+import sys
+import threading
+import zipfile
+from types import TracebackType
+from typing import Optional, Type
+
+import anyio
+import pandas as pd
+from _snowflake import vectorized
+
+from snowflake.ml.model._packager import model_packager
+
+
+# User-defined parameters
+MODEL_DIR_REL_PATH = "{model_dir_name}"
+TARGET_METHOD = "{target_method}"
+MAX_BATCH_SIZE = {max_batch_size}
+
+# Retrieve the model
+IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
+import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
+model_dir_path = os.path.join(import_dir, MODEL_DIR_REL_PATH)
+
+# Load the model
+pk = model_packager.ModelPackager(model_dir_path)
+pk.load(as_custom_model=True)
+assert pk.model, "model is not loaded"
+assert pk.meta, "model metadata is not loaded"
+
+# Determine the actual runner
+model = pk.model
+meta = pk.meta
+func = getattr(model, TARGET_METHOD)
+if inspect.iscoroutinefunction(func):
+    runner = functools.partial(anyio.run, func)
+else:
+    runner = functools.partial(func)
+
+# Determine preprocess parameters
+features = meta.signatures[TARGET_METHOD].inputs
+input_cols = [feature.name for feature in features]
+dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
+
+
+# Actual table function
+class {function_name}:
+    @vectorized(input=pd.DataFrame)
+    def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
+        df.columns = input_cols
+        input_df = df.astype(dtype=dtype_map)
+        return runner(input_df[input_cols])
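
Illustrative note (not part of the diff): in this new partitioned template the generated handler exposes end_partition rather than process, and the @vectorized decorator carries no max_batch_size, so each call receives every row of one partition at once. A rough local simulation of that per-partition flow with a stand-in model function (predict_total and the column names are made up for the example):

import pandas as pd

# Stand-in for the loaded model's target method; the real one is resolved via model_packager.
def predict_total(df: pd.DataFrame) -> pd.DataFrame:
    return pd.DataFrame({"TOTAL": [float(df["A"].sum() + df["B"].sum())]})

input_cols = ["A", "B"]
dtype_map = {"A": "float64", "B": "float64"}

class Handler:
    # Mirrors end_partition in the generated template: one call per partition, all rows at once.
    def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
        df.columns = input_cols
        input_df = df.astype(dtype=dtype_map)
        return predict_total(input_df[input_cols])

partition = pd.DataFrame([[1, 2], [3, 4]])
print(Handler().end_partition(partition))  # one aggregated result row for the whole partition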

snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template

@@ -1,12 +1,7 @@
-import fcntl
 import functools
 import inspect
 import os
 import sys
-import threading
-import zipfile
-from types import TracebackType
-from typing import Optional, Type

 import anyio
 import pandas as pd
@@ -15,42 +10,18 @@ from _snowflake import vectorized
 from snowflake.ml.model._packager import model_packager


-class FileLock:
-    def __enter__(self) -> None:
-        self._lock = threading.Lock()
-        self._lock.acquire()
-        self._fd = open("/tmp/lockfile.LOCK", "w+")
-        fcntl.lockf(self._fd, fcntl.LOCK_EX)
-
-    def __exit__(
-        self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]
-    ) -> None:
-        self._fd.close()
-        self._lock.release()
-
-
 # User-defined parameters
-
+MODEL_DIR_REL_PATH = "{model_dir_name}"
 TARGET_METHOD = "{target_method}"
 MAX_BATCH_SIZE = {max_batch_size}

-
 # Retrieve the model
 IMPORT_DIRECTORY_NAME = "snowflake_import_directory"
 import_dir = sys._xoptions[IMPORT_DIRECTORY_NAME]
-
-model_dir_name = os.path.splitext(MODEL_FILE_NAME)[0]
-zip_model_path = os.path.join(import_dir, MODEL_FILE_NAME)
-extracted = "/tmp/models"
-extracted_model_dir_path = os.path.join(extracted, model_dir_name)
-
-with FileLock():
-    if not os.path.isdir(extracted_model_dir_path):
-        with zipfile.ZipFile(zip_model_path, "r") as myzip:
-            myzip.extractall(extracted_model_dir_path)
+model_dir_path = os.path.join(import_dir, MODEL_DIR_REL_PATH)

 # Load the model
-pk = model_packager.ModelPackager(
+pk = model_packager.ModelPackager(model_dir_path)
 pk.load(as_custom_model=True)
 assert pk.model, "model is not loaded"
 assert pk.meta, "model metadata is not loaded"
@@ -72,8 +43,8 @@ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}

 # Actual table function
 class {function_name}:
-    @vectorized(input=pd.DataFrame)
-    def
+    @vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE)
+    def process(self, df: pd.DataFrame) -> pd.DataFrame:
         df.columns = input_cols
         input_df = df.astype(dtype=dtype_map)
         return runner(input_df[input_cols])

snowflake/ml/model/_model_composer/model_method/model_method.py

@@ -26,13 +26,14 @@ class ModelMethodOptions(TypedDict):
 def get_model_method_options_from_options(
     options: type_hints.ModelSaveOption, target_method: str
 ) -> ModelMethodOptions:
+    default_function_type = model_manifest_schema.ModelMethodFunctionTypes.FUNCTION.value
+    if options.get("enable_explainability", False) and target_method.startswith("explain"):
+        default_function_type = model_manifest_schema.ModelMethodFunctionTypes.TABLE_FUNCTION.value
     method_option = options.get("method_options", {}).get(target_method, {})
-    global_function_type = options.get("function_type",
+    global_function_type = options.get("function_type", default_function_type)
     function_type = method_option.get("function_type", global_function_type)
     if function_type not in [function_type.value for function_type in model_manifest_schema.ModelMethodFunctionTypes]:
-        raise NotImplementedError
-
-        # TODO(TH): enforce minimum snowflake version
+        raise NotImplementedError(f"Function type {function_type} is not supported.")

     return ModelMethodOptions(
         case_sensitive=method_option.get("case_sensitive", False),
@@ -47,10 +48,9 @@ class ModelMethod:
     Attributes:
         model_meta: Model Metadata.
        target_method: Original target method name to call with the model.
-        method_name: The actual method name registered in manifest and used in SQL.
-
-        function_generator: Function file generator.
         runtime_name: Name of the Model Runtime to run the method.
+        function_generator: Function file generator.
+        is_partitioned_function: Whether the model method function is partitioned.

         options: Model Method Options.
     """
@@ -63,11 +63,13 @@ class ModelMethod:
         target_method: str,
         runtime_name: str,
         function_generator: function_generator.FunctionGenerator,
+        is_partitioned_function: bool = False,
         options: Optional[ModelMethodOptions] = None,
     ) -> None:
         self.model_meta = model_meta
         self.target_method = target_method
         self.function_generator = function_generator
+        self.is_partitioned_function = is_partitioned_function
         self.runtime_name = runtime_name
         self.options = options or {}
         try:
@@ -111,6 +113,7 @@ class ModelMethod:
             workspace_path / ModelMethod.FUNCTIONS_DIR_REL_PATH / f"{self.target_method}.py",
             self.target_method,
             self.function_type,
+            self.is_partitioned_function,
             options=options,
         )
         input_list = [
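
Illustrative note (not part of the diff): get_model_method_options_from_options now defaults explain* methods to a table function when explainability is enabled, while everything else keeps FUNCTION unless a global or per-method function_type overrides it. A standalone paraphrase of that resolution order (the string values are assumed to mirror ModelMethodFunctionTypes):

FUNCTION = "FUNCTION"
TABLE_FUNCTION = "TABLE_FUNCTION"

def resolve_function_type(options: dict, target_method: str) -> str:
    # Explain methods default to TABLE_FUNCTION when explainability is enabled.
    default_function_type = FUNCTION
    if options.get("enable_explainability", False) and target_method.startswith("explain"):
        default_function_type = TABLE_FUNCTION
    method_option = options.get("method_options", {}).get(target_method, {})
    global_function_type = options.get("function_type", default_function_type)
    return method_option.get("function_type", global_function_type)

print(resolve_function_type({"enable_explainability": True}, "explain"))     # TABLE_FUNCTION
print(resolve_function_type({}, "predict"))                                  # FUNCTION
print(resolve_function_type({"function_type": TABLE_FUNCTION}, "predict"))   # explicit override wins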

snowflake/ml/model/_packager/model_handlers/_base.py

@@ -1,4 +1,5 @@
 from abc import abstractmethod
+from enum import Enum
 from typing import Dict, Generic, Optional, Protocol, Type, final

 from typing_extensions import TypeGuard, Unpack
@@ -8,6 +9,15 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
 from snowflake.ml.model._packager.model_meta import model_meta


+class ModelObjective(Enum):
+    # This is not getting stored anywhere as metadata yet so it should be fine to slowly extend it for better coverage
+    UNKNOWN = "unknown"
+    BINARY_CLASSIFICATION = "binary_classification"
+    MULTI_CLASSIFICATION = "multi_classification"
+    REGRESSION = "regression"
+    RANKING = "ranking"
+
+
 class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
     HANDLER_TYPE: model_types.SupportedModelHandlerType
     HANDLER_VERSION: str
@@ -16,7 +26,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):

     @classmethod
     @abstractmethod
-    def can_handle(cls, model: model_types.
+    def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard[model_types._ModelType]:
         """Whether this handler could support the type of the `model`.

         Args:
@@ -75,7 +85,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
         name: str,
         model_meta: model_meta.ModelMetadata,
         model_blobs_dir_path: str,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.BaseModelLoadOption],
     ) -> model_types._ModelType:
         """Load the model into memory.

@@ -96,7 +106,7 @@ class _BaseModelHandlerProtocol(Protocol[model_types._ModelType]):
         cls,
         raw_model: model_types._ModelType,
         model_meta: model_meta.ModelMetadata,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.BaseModelLoadOption],
     ) -> custom_model.CustomModel:
         """Create a custom model class wrap for unified interface when being deployed. The predict method will be
         re-targeted based on target_method metadata.

snowflake/ml/model/_packager/model_handlers/_utils.py

@@ -1,4 +1,9 @@
-
+import json
+from typing import Any, Callable, Iterable, Optional, Sequence, cast
+
+import numpy as np
+import numpy.typing as npt
+import pandas as pd

 from snowflake.ml.model import model_signature, type_hints as model_types
 from snowflake.ml.model._packager.model_meta import model_meta
@@ -36,6 +41,25 @@ def validate_signature(
         predictions_df = get_prediction_fn(target_method, local_sample_input)
         sig = model_signature.infer_signature(local_sample_input, predictions_df)
         model_meta.signatures[target_method] = sig
+
+    return model_meta
+
+
+def add_explain_method_signature(
+    model_meta: model_meta.ModelMetadata,
+    explain_method: str,
+    target_method: str,
+    output_return_type: model_signature.DataType = model_signature.DataType.DOUBLE,
+) -> model_meta.ModelMetadata:
+    if target_method not in model_meta.signatures:
+        raise ValueError(f"Signature for target method {target_method} is missing")
+    inputs = model_meta.signatures[target_method].inputs
+    model_meta.signatures[explain_method] = model_signature.ModelSignature(
+        inputs=inputs,
+        outputs=[
+            model_signature.FeatureSpec(dtype=output_return_type, name=f"{spec.name}_explanation") for spec in inputs
+        ],
+    )
     return model_meta


@@ -55,3 +79,37 @@ def validate_target_methods(model: model_types.SupportedModelType, target_method
     for method_name in target_methods:
         if not _is_callable(model, method_name):
             raise ValueError(f"Target method {method_name} is not callable or does not exist in the model.")
+
+
+def get_num_classes_if_exists(model: model_types.SupportedModelType) -> int:
+    num_classes = getattr(model, "classes_", [])
+    return len(num_classes)
+
+
+def convert_explanations_to_2D_df(
+    model: model_types.SupportedModelType, explanations: npt.NDArray[Any]
+) -> pd.DataFrame:
+    if explanations.ndim != 3:
+        return pd.DataFrame(explanations)
+
+    if hasattr(model, "classes_"):
+        classes_list = [cl for cl in model.classes_]  # type:ignore[union-attr]
+        len_classes = len(classes_list)
+        if explanations.shape[2] != len_classes:
+            raise ValueError(f"Model has {len_classes} classes but explanations have {explanations.shape[2]}")
+    else:
+        classes_list = [i for i in range(explanations.shape[2])]
+    exp_2d = []
+    # TODO (SNOW-1549044): Optimize this
+    for row in explanations:
+        col_list = []
+        for column in row:
+            class_explanations = {}
+            for cl, cl_exp in zip(classes_list, column):
+                if isinstance(cl, (int, np.integer)):
+                    cl = int(cl)
+                class_explanations[cl] = cl_exp
+            col_list.append(json.dumps(class_explanations))
+        exp_2d.append(col_list)
+
+    return pd.DataFrame(exp_2d)
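
Illustrative note (not part of the diff): convert_explanations_to_2D_df collapses a (rows x features x classes) SHAP-style array into a 2-D frame whose cells are JSON strings keyed by class label; 2-D inputs pass through unchanged. A small usage sketch against the internal module added above (the dummy classifier is made up for the example; this is not a public API):

import numpy as np

from snowflake.ml.model._packager.model_handlers import _utils as handlers_utils

class DummyClassifier:
    classes_ = np.array([0, 1])  # two classes, so the last axis must have length 2

explanations = np.arange(12, dtype=float).reshape(3, 2, 2)  # 3 rows, 2 features, 2 classes
df = handlers_utils.convert_explanations_to_2D_df(DummyClassifier(), explanations)
print(df.shape)       # (3, 2): one JSON cell per (row, feature) pair
print(df.iloc[0, 0])  # {"0": 0.0, "1": 1.0}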

snowflake/ml/model/_packager/model_handlers/catboost.py

@@ -33,6 +33,22 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
     MODELE_BLOB_FILE_OR_DIR = "model.bin"
     DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]

+    @classmethod
+    def get_model_objective(cls, model: "catboost.CatBoost") -> _base.ModelObjective:
+        import catboost
+
+        if isinstance(model, catboost.CatBoostClassifier):
+            num_classes = handlers_utils.get_num_classes_if_exists(model)
+            if num_classes == 2:
+                return _base.ModelObjective.BINARY_CLASSIFICATION
+            return _base.ModelObjective.MULTI_CLASSIFICATION
+        if isinstance(model, catboost.CatBoostRanker):
+            return _base.ModelObjective.RANKING
+        if isinstance(model, catboost.CatBoostRegressor):
+            return _base.ModelObjective.REGRESSION
+        # TODO: Find out model type from the generic Catboost Model
+        return _base.ModelObjective.UNKNOWN
+
     @classmethod
     def can_handle(cls, model: model_types.SupportedModelType) -> TypeGuard["catboost.CatBoost"]:
         return (type_utils.LazyType("catboost.CatBoost").isinstance(model)) and any(
@@ -89,6 +105,16 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
             sample_input_data=sample_input_data,
             get_prediction_fn=get_prediction,
         )
+        if kwargs.get("enable_explainability", False):
+            output_type = model_signature.DataType.DOUBLE
+            if cls.get_model_objective(model) == _base.ModelObjective.MULTI_CLASSIFICATION:
+                output_type = model_signature.DataType.STRING
+            model_meta = handlers_utils.add_explain_method_signature(
+                model_meta=model_meta,
+                explain_method="explain",
+                target_method="predict",
+                output_return_type=output_type,
+            )

         model_blob_path = os.path.join(model_blobs_dir_path, name)
         os.makedirs(model_blob_path, exist_ok=True)
@@ -112,6 +138,11 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
             ],
             check_local_version=True,
         )
+        if kwargs.get("enable_explainability", False):
+            model_meta.env.include_if_absent(
+                [model_env.ModelDependency(requirement="shap", pip_name="shap")],
+                check_local_version=True,
+            )
         model_meta.env.cuda_version = kwargs.get("cuda_version", model_env.DEFAULT_CUDA_VERSION)

         return None
@@ -122,7 +153,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
         name: str,
         model_meta: model_meta_api.ModelMetadata,
         model_blobs_dir_path: str,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.CatBoostModelLoadOptions],
     ) -> "catboost.CatBoost":
         import catboost

@@ -157,7 +188,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
         cls,
         raw_model: "catboost.CatBoost",
         model_meta: model_meta_api.ModelMetadata,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.CatBoostModelLoadOptions],
     ) -> custom_model.CustomModel:
         import catboost

@@ -186,6 +217,17 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):

                 return model_signature_utils.rename_pandas_df(df, signature.outputs)

+            @custom_model.inference_api
+            def explain_fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
+                import shap
+
+                explainer = shap.TreeExplainer(raw_model)
+                df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X).values)
+                return model_signature_utils.rename_pandas_df(df, signature.outputs)
+
+            if target_method == "explain":
+                return explain_fn
+
             return fn

         type_method_dict: Dict[str, Any] = {"_raw_model": raw_model}
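
Illustrative note (not part of the diff): the CatBoost handler now infers a coarse model objective (binary versus multi-class classification via the number of classes, ranking, regression) and, when explainability is enabled, generates an explain method backed by shap.TreeExplainer. A local sketch of both ingredients on a toy classifier (catboost and shap are third-party packages that must be installed separately):

import catboost
import numpy as np
import shap

X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0, 0, 1, 1])
model = catboost.CatBoostClassifier(iterations=5, verbose=False).fit(X, y)

# get_model_objective treats a classifier with exactly two classes as binary classification.
print(len(getattr(model, "classes_", [])))  # 2

# The generated "explain" method wraps SHAP's TreeExplainer over the raw model.
explainer = shap.TreeExplainer(model)
print(explainer(X).values.shape)  # per-row, per-feature SHAP contributions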

snowflake/ml/model/_packager/model_handlers/custom.py

@@ -17,6 +17,7 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
 from snowflake.ml.model._packager.model_meta import (
     model_blob_meta,
     model_meta as model_meta_api,
+    model_meta_schema,
 )


@@ -68,6 +69,11 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
                 predictions_df = target_method(model, sample_input_data)
                 return predictions_df

+        for func_name in model._get_partitioned_infer_methods():
+            function_properties = model_meta.function_properties.get(func_name, {})
+            function_properties[model_meta_schema.FunctionProperties.PARTITIONED.value] = True
+            model_meta.function_properties[func_name] = function_properties
+
         if not is_sub_model:
             model_meta = handlers_utils.validate_signature(
                 model=model,
@@ -101,14 +107,16 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):

         # Make sure that the module where the model is defined get pickled by value as well.
         cloudpickle.register_pickle_by_value(sys.modules[model.__module__])
-
+        pickled_obj = (model.__class__, model.context)
         with open(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR), "wb") as f:
-            cloudpickle.dump(
+            cloudpickle.dump(pickled_obj, f)
+        # model meta will be saved by the context manager
         model_meta.models[name] = model_blob_meta.ModelBlobMeta(
             name=name,
             model_type=cls.HANDLER_TYPE,
             path=cls.MODELE_BLOB_FILE_OR_DIR,
             handler_version=cls.HANDLER_VERSION,
+            function_properties=model_meta.function_properties,
             artifacts={
                 name: pathlib.Path(
                     os.path.join(cls.MODEL_ARTIFACTS_DIR, os.path.basename(os.path.normpath(path=uri)))
@@ -128,7 +136,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
         name: str,
         model_meta: model_meta_api.ModelMetadata,
         model_blobs_dir_path: str,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.CustomModelLoadOption],
     ) -> "custom_model.CustomModel":
         model_blob_path = os.path.join(model_blobs_dir_path, name)

@@ -175,6 +183,6 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
         cls,
         raw_model: custom_model.CustomModel,
         model_meta: model_meta_api.ModelMetadata,
-        **kwargs: Unpack[model_types.
+        **kwargs: Unpack[model_types.CustomModelLoadOption],
     ) -> custom_model.CustomModel:
         return raw_model
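
Illustrative note (not part of the diff): at save time the custom handler now asks the model for its partitioned inference methods (_get_partitioned_infer_methods) and records FunctionProperties.PARTITIONED in the metadata, which ModelManifest and FunctionGenerator then use to select infer_partitioned.py_template. A hedged sketch of what such a custom model might look like; the partitioned_inference_api decorator name is an assumption tied to the custom_model.py changes (+22 -2) and is not visible in this excerpt:

import pandas as pd

from snowflake.ml.model import custom_model

class PartitionedModel(custom_model.CustomModel):
    def __init__(self, context: custom_model.ModelContext) -> None:
        super().__init__(context)

    # Assumed decorator name, analogous to custom_model.inference_api used elsewhere in this diff;
    # methods marked this way are the ones _get_partitioned_infer_methods() reports at save time.
    @custom_model.partitioned_inference_api
    def predict(self, input: pd.DataFrame) -> pd.DataFrame:
        # Receives all rows of one partition and returns one result frame per partition.
        return pd.DataFrame({"ROW_COUNT": [len(input)]})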