snowflake-ml-python 1.7.2__py3-none-any.whl → 1.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/cortex/__init__.py +16 -8
- snowflake/cortex/_classify_text.py +12 -1
- snowflake/cortex/_complete.py +101 -13
- snowflake/cortex/_embed_text_1024.py +9 -2
- snowflake/cortex/_embed_text_768.py +9 -2
- snowflake/cortex/_extract_answer.py +9 -2
- snowflake/cortex/_sentiment.py +9 -2
- snowflake/cortex/_summarize.py +9 -2
- snowflake/cortex/_translate.py +9 -2
- snowflake/ml/_internal/env_utils.py +7 -52
- snowflake/ml/_internal/platform_capabilities.py +87 -0
- snowflake/ml/_internal/utils/identifier.py +4 -2
- snowflake/ml/data/__init__.py +3 -0
- snowflake/ml/data/_internal/arrow_ingestor.py +4 -4
- snowflake/ml/data/data_connector.py +53 -11
- snowflake/ml/data/data_ingestor.py +2 -1
- snowflake/ml/data/torch_utils.py +18 -5
- snowflake/ml/dataset/dataset.py +0 -1
- snowflake/ml/feature_store/examples/example_helper.py +2 -1
- snowflake/ml/fileset/fileset.py +24 -18
- snowflake/ml/jobs/__init__.py +21 -0
- snowflake/ml/jobs/_utils/constants.py +51 -0
- snowflake/ml/jobs/_utils/payload_utils.py +352 -0
- snowflake/ml/jobs/_utils/spec_utils.py +298 -0
- snowflake/ml/jobs/_utils/types.py +39 -0
- snowflake/ml/jobs/decorators.py +91 -0
- snowflake/ml/jobs/job.py +113 -0
- snowflake/ml/jobs/manager.py +298 -0
- snowflake/ml/model/_client/model/model_version_impl.py +5 -3
- snowflake/ml/model/_client/ops/model_ops.py +13 -8
- snowflake/ml/model/_client/ops/service_ops.py +1 -11
- snowflake/ml/model/_client/sql/model_version.py +11 -0
- snowflake/ml/model/_client/sql/service.py +13 -6
- snowflake/ml/model/_model_composer/model_composer.py +8 -3
- snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +20 -1
- snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +1 -0
- snowflake/ml/model/_model_composer/model_method/constants.py +1 -0
- snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -0
- snowflake/ml/model/_model_composer/model_method/infer_function.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_partitioned.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +1 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +9 -1
- snowflake/ml/model/_model_composer/model_user_file/model_user_file.py +27 -0
- snowflake/ml/model/_packager/model_handlers/_utils.py +39 -5
- snowflake/ml/model/_packager/model_handlers/catboost.py +3 -3
- snowflake/ml/model/_packager/model_handlers/custom.py +1 -2
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +6 -1
- snowflake/ml/model/_packager/model_handlers/lightgbm.py +5 -3
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +55 -20
- snowflake/ml/model/_packager/model_handlers/sklearn.py +9 -10
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +66 -28
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +70 -17
- snowflake/ml/model/_packager/model_handlers/xgboost.py +3 -3
- snowflake/ml/model/_packager/model_meta/model_meta.py +3 -0
- snowflake/ml/model/_packager/model_meta/model_meta_schema.py +6 -1
- snowflake/ml/model/_packager/model_runtime/_snowml_inference_alternative_requirements.py +2 -2
- snowflake/ml/model/_packager/model_task/model_task_utils.py +3 -2
- snowflake/ml/model/_signatures/base_handler.py +1 -2
- snowflake/ml/model/_signatures/builtins_handler.py +2 -2
- snowflake/ml/model/_signatures/numpy_handler.py +6 -7
- snowflake/ml/model/_signatures/pandas_handler.py +3 -3
- snowflake/ml/model/_signatures/pytorch_handler.py +2 -5
- snowflake/ml/model/_signatures/snowpark_handler.py +11 -5
- snowflake/ml/model/_signatures/tensorflow_handler.py +2 -7
- snowflake/ml/model/model_signature.py +17 -4
- snowflake/ml/model/type_hints.py +1 -0
- snowflake/ml/modeling/_internal/model_trainer_builder.py +0 -8
- snowflake/ml/modeling/_internal/model_transformer_builder.py +0 -13
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +6 -3
- snowflake/ml/modeling/cluster/affinity_propagation.py +6 -3
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +6 -3
- snowflake/ml/modeling/cluster/birch.py +6 -3
- snowflake/ml/modeling/cluster/bisecting_k_means.py +6 -3
- snowflake/ml/modeling/cluster/dbscan.py +6 -3
- snowflake/ml/modeling/cluster/feature_agglomeration.py +6 -3
- snowflake/ml/modeling/cluster/k_means.py +6 -3
- snowflake/ml/modeling/cluster/mean_shift.py +6 -3
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +6 -3
- snowflake/ml/modeling/cluster/optics.py +6 -3
- snowflake/ml/modeling/cluster/spectral_biclustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_clustering.py +6 -3
- snowflake/ml/modeling/cluster/spectral_coclustering.py +6 -3
- snowflake/ml/modeling/compose/column_transformer.py +6 -3
- snowflake/ml/modeling/compose/transformed_target_regressor.py +6 -3
- snowflake/ml/modeling/covariance/elliptic_envelope.py +6 -3
- snowflake/ml/modeling/covariance/empirical_covariance.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso.py +6 -3
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +6 -3
- snowflake/ml/modeling/covariance/ledoit_wolf.py +6 -3
- snowflake/ml/modeling/covariance/min_cov_det.py +6 -3
- snowflake/ml/modeling/covariance/oas.py +6 -3
- snowflake/ml/modeling/covariance/shrunk_covariance.py +6 -3
- snowflake/ml/modeling/decomposition/dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/factor_analysis.py +6 -3
- snowflake/ml/modeling/decomposition/fast_ica.py +6 -3
- snowflake/ml/modeling/decomposition/incremental_pca.py +6 -3
- snowflake/ml/modeling/decomposition/kernel_pca.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +6 -3
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/pca.py +6 -3
- snowflake/ml/modeling/decomposition/sparse_pca.py +6 -3
- snowflake/ml/modeling/decomposition/truncated_svd.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/bagging_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/isolation_forest.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/stacking_regressor.py +6 -3
- snowflake/ml/modeling/ensemble/voting_classifier.py +6 -3
- snowflake/ml/modeling/ensemble/voting_regressor.py +6 -3
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fdr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fpr.py +6 -3
- snowflake/ml/modeling/feature_selection/select_fwe.py +6 -3
- snowflake/ml/modeling/feature_selection/select_k_best.py +6 -3
- snowflake/ml/modeling/feature_selection/select_percentile.py +6 -3
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +6 -3
- snowflake/ml/modeling/feature_selection/variance_threshold.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +6 -3
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +6 -3
- snowflake/ml/modeling/impute/iterative_imputer.py +6 -3
- snowflake/ml/modeling/impute/knn_imputer.py +6 -3
- snowflake/ml/modeling/impute/missing_indicator.py +6 -3
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/nystroem.py +6 -3
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +6 -3
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +6 -3
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +6 -3
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +6 -3
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ard_regression.py +6 -3
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/gamma_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/huber_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/lars.py +6 -3
- snowflake/ml/modeling/linear_model/lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +6 -3
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +6 -3
- snowflake/ml/modeling/linear_model/linear_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression.py +6 -3
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +6 -3
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +6 -3
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/perceptron.py +6 -3
- snowflake/ml/modeling/linear_model/poisson_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ransac_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/ridge.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +6 -3
- snowflake/ml/modeling/linear_model/ridge_cv.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_classifier.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +6 -3
- snowflake/ml/modeling/linear_model/sgd_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +6 -3
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +6 -3
- snowflake/ml/modeling/manifold/isomap.py +6 -3
- snowflake/ml/modeling/manifold/mds.py +6 -3
- snowflake/ml/modeling/manifold/spectral_embedding.py +6 -3
- snowflake/ml/modeling/manifold/tsne.py +6 -3
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +6 -3
- snowflake/ml/modeling/mixture/gaussian_mixture.py +6 -3
- snowflake/ml/modeling/model_selection/grid_search_cv.py +17 -2
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +17 -2
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +6 -3
- snowflake/ml/modeling/multiclass/output_code_classifier.py +6 -3
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/complement_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +6 -3
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neighbors/kernel_density.py +6 -3
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_centroid.py +6 -3
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +6 -3
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +6 -3
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +6 -3
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_classifier.py +6 -3
- snowflake/ml/modeling/neural_network/mlp_regressor.py +6 -3
- snowflake/ml/modeling/pipeline/pipeline.py +16 -178
- snowflake/ml/modeling/preprocessing/polynomial_features.py +6 -3
- snowflake/ml/modeling/semi_supervised/label_propagation.py +6 -3
- snowflake/ml/modeling/semi_supervised/label_spreading.py +6 -3
- snowflake/ml/modeling/svm/linear_svc.py +6 -3
- snowflake/ml/modeling/svm/linear_svr.py +6 -3
- snowflake/ml/modeling/svm/nu_svc.py +6 -3
- snowflake/ml/modeling/svm/nu_svr.py +6 -3
- snowflake/ml/modeling/svm/svc.py +6 -3
- snowflake/ml/modeling/svm/svr.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/decision_tree_regressor.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_classifier.py +6 -3
- snowflake/ml/modeling/tree/extra_tree_regressor.py +6 -3
- snowflake/ml/modeling/xgboost/xgb_classifier.py +167 -91
- snowflake/ml/modeling/xgboost/xgb_regressor.py +166 -88
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +166 -88
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +166 -88
- snowflake/ml/monitoring/_client/model_monitor_sql_client.py +4 -4
- snowflake/ml/registry/_manager/model_manager.py +70 -33
- snowflake/ml/registry/registry.py +41 -22
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/METADATA +63 -19
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/RECORD +231 -226
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/WHEEL +1 -1
- snowflake/ml/_internal/utils/retryable_http.py +0 -39
- snowflake/ml/fileset/parquet_parser.py +0 -170
- snowflake/ml/fileset/tf_dataset.py +0 -88
- snowflake/ml/fileset/torch_datapipe.py +0 -57
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +0 -151
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_trainer.py +0 -66
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.7.2.dist-info → snowflake_ml_python-1.7.4.dist-info}/top_level.txt +0 -0
@@ -88,6 +88,7 @@ class ModelComposer:
|
|
88
88
|
pip_requirements: Optional[List[str]] = None,
|
89
89
|
target_platforms: Optional[List[model_types.TargetPlatform]] = None,
|
90
90
|
python_version: Optional[str] = None,
|
91
|
+
user_files: Optional[Dict[str, List[str]]] = None,
|
91
92
|
ext_modules: Optional[List[ModuleType]] = None,
|
92
93
|
code_paths: Optional[List[str]] = None,
|
93
94
|
task: model_types.Task = model_types.Task.UNKNOWN,
|
@@ -97,9 +98,12 @@ class ModelComposer:
|
|
97
98
|
options = model_types.BaseModelSaveOption()
|
98
99
|
|
99
100
|
if not snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
|
100
|
-
snowml_matched_versions = env_utils.
|
101
|
-
|
102
|
-
|
101
|
+
snowml_matched_versions = env_utils.get_matched_package_versions_in_information_schema(
|
102
|
+
self.session,
|
103
|
+
reqs=[requirements.Requirement(f"{env_utils.SNOWPARK_ML_PKG_NAME}=={snowml_env.VERSION}")],
|
104
|
+
python_version=python_version or snowml_env.PYTHON_VERSION,
|
105
|
+
statement_params=self._statement_params,
|
106
|
+
).get(env_utils.SNOWPARK_ML_PKG_NAME, [])
|
103
107
|
|
104
108
|
if len(snowml_matched_versions) < 1 and options.get("embed_local_ml_library", False) is False:
|
105
109
|
logging.info(
|
@@ -131,6 +135,7 @@ class ModelComposer:
|
|
131
135
|
model_meta=self.packager.meta,
|
132
136
|
model_rel_path=pathlib.PurePosixPath(ModelComposer.MODEL_DIR_REL_PATH),
|
133
137
|
options=options,
|
138
|
+
user_files=user_files,
|
134
139
|
data_sources=self._get_data_sources(model, sample_input_data),
|
135
140
|
target_platforms=target_platforms,
|
136
141
|
)
|
@@ -2,7 +2,7 @@ import collections
|
|
2
2
|
import logging
|
3
3
|
import pathlib
|
4
4
|
import warnings
|
5
|
-
from typing import List, Optional, cast
|
5
|
+
from typing import Dict, List, Optional, cast
|
6
6
|
|
7
7
|
import yaml
|
8
8
|
|
@@ -11,9 +11,11 @@ from snowflake.ml.data import data_source
|
|
11
11
|
from snowflake.ml.model import type_hints
|
12
12
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
13
13
|
from snowflake.ml.model._model_composer.model_method import (
|
14
|
+
constants,
|
14
15
|
function_generator,
|
15
16
|
model_method,
|
16
17
|
)
|
18
|
+
from snowflake.ml.model._model_composer.model_user_file import model_user_file
|
17
19
|
from snowflake.ml.model._packager.model_meta import (
|
18
20
|
model_meta as model_meta_api,
|
19
21
|
model_meta_schema,
|
@@ -30,9 +32,11 @@ class ModelManifest:
|
|
30
32
|
workspace_path: A local path where model related files should be dumped to.
|
31
33
|
runtimes: A list of ModelRuntime objects managing the runtimes and environment in the MODEL object.
|
32
34
|
methods: A list of ModelMethod objects managing the method we registered to the MODEL object.
|
35
|
+
user_files: A list of ModelUserFile objects managing extra files uploaded to the workspace.
|
33
36
|
"""
|
34
37
|
|
35
38
|
MANIFEST_FILE_REL_PATH = "MANIFEST.yml"
|
39
|
+
_ENABLE_USER_FILES = False
|
36
40
|
_DEFAULT_RUNTIME_NAME = "python_runtime"
|
37
41
|
|
38
42
|
def __init__(self, workspace_path: pathlib.Path) -> None:
|
@@ -42,6 +46,7 @@ class ModelManifest:
|
|
42
46
|
self,
|
43
47
|
model_meta: model_meta_api.ModelMetadata,
|
44
48
|
model_rel_path: pathlib.PurePosixPath,
|
49
|
+
user_files: Optional[Dict[str, List[str]]] = None,
|
45
50
|
options: Optional[type_hints.ModelSaveOption] = None,
|
46
51
|
data_sources: Optional[List[data_source.DataSource]] = None,
|
47
52
|
target_platforms: Optional[List[type_hints.TargetPlatform]] = None,
|
@@ -79,6 +84,7 @@ class ModelManifest:
|
|
79
84
|
|
80
85
|
self.function_generator = function_generator.FunctionGenerator(model_dir_rel_path=model_rel_path)
|
81
86
|
self.methods: List[model_method.ModelMethod] = []
|
87
|
+
|
82
88
|
for target_method in model_meta.signatures.keys():
|
83
89
|
method = model_method.ModelMethod(
|
84
90
|
model_meta=model_meta,
|
@@ -88,11 +94,21 @@ class ModelManifest:
|
|
88
94
|
is_partitioned_function=model_meta.function_properties.get(target_method, {}).get(
|
89
95
|
model_meta_schema.FunctionProperties.PARTITIONED.value, False
|
90
96
|
),
|
97
|
+
wide_input=len(model_meta.signatures[target_method].inputs) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT,
|
91
98
|
options=model_method.get_model_method_options_from_options(options, target_method),
|
92
99
|
)
|
93
100
|
|
94
101
|
self.methods.append(method)
|
95
102
|
|
103
|
+
self.user_files: List[model_user_file.ModelUserFile] = []
|
104
|
+
|
105
|
+
if user_files is not None:
|
106
|
+
for subdirectory, paths in user_files.items():
|
107
|
+
for path in paths:
|
108
|
+
self.user_files.append(
|
109
|
+
model_user_file.ModelUserFile(pathlib.PurePosixPath(subdirectory), pathlib.Path(path))
|
110
|
+
)
|
111
|
+
|
96
112
|
method_name_counter = collections.Counter([method.method_name for method in self.methods])
|
97
113
|
dup_method_names = [k for k, v in method_name_counter.items() if v > 1]
|
98
114
|
if dup_method_names:
|
@@ -129,6 +145,9 @@ class ModelManifest:
|
|
129
145
|
],
|
130
146
|
)
|
131
147
|
|
148
|
+
if self._ENABLE_USER_FILES:
|
149
|
+
manifest_dict["user_files"] = [user_file.save(self.workspace_path) for user_file in self.user_files]
|
150
|
+
|
132
151
|
lineage_sources = self._extract_lineage_info(data_sources)
|
133
152
|
if lineage_sources:
|
134
153
|
manifest_dict["lineage_sources"] = lineage_sources
|
@@ -94,5 +94,6 @@ class ModelManifestDict(TypedDict):
|
|
94
94
|
runtimes: Required[Dict[str, ModelRuntimeDict]]
|
95
95
|
methods: Required[List[ModelMethodDict]]
|
96
96
|
user_data: NotRequired[Dict[str, Any]]
|
97
|
+
user_files: NotRequired[List[str]]
|
97
98
|
lineage_sources: NotRequired[List[LineageSourceDict]]
|
98
99
|
target_platforms: NotRequired[List[str]]
|
@@ -0,0 +1 @@
|
|
1
|
+
SNOWPARK_UDF_INPUT_COL_LIMIT = 500
|
@@ -43,6 +43,7 @@ class FunctionGenerator:
|
|
43
43
|
target_method: str,
|
44
44
|
function_type: str,
|
45
45
|
is_partitioned_function: bool = False,
|
46
|
+
wide_input: bool = False,
|
46
47
|
options: Optional[FunctionGenerateOptions] = None,
|
47
48
|
) -> None:
|
48
49
|
import importlib_resources
|
@@ -70,6 +71,7 @@ class FunctionGenerator:
|
|
70
71
|
model_dir_name=self.model_dir_rel_path.name,
|
71
72
|
target_method=target_method,
|
72
73
|
max_batch_size=options.get("max_batch_size", None),
|
74
|
+
wide_input=wide_input,
|
73
75
|
function_name=FunctionGenerator.FUNCTION_NAME,
|
74
76
|
)
|
75
77
|
with open(function_file_path, "w", encoding="utf-8") as f:
|
@@ -43,7 +43,7 @@ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
|
|
43
43
|
|
44
44
|
|
45
45
|
# Actual function
|
46
|
-
@vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE)
|
46
|
+
@vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE, flatten_object_input={wide_input})
|
47
47
|
def {function_name}(df: pd.DataFrame) -> dict:
|
48
48
|
df.columns = input_cols
|
49
49
|
input_df = df.astype(dtype=dtype_map)
|
@@ -48,7 +48,7 @@ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
|
|
48
48
|
|
49
49
|
# Actual table function
|
50
50
|
class {function_name}:
|
51
|
-
@vectorized(input=pd.DataFrame)
|
51
|
+
@vectorized(input=pd.DataFrame, flatten_object_input={wide_input})
|
52
52
|
def end_partition(self, df: pd.DataFrame) -> pd.DataFrame:
|
53
53
|
df.columns = input_cols
|
54
54
|
input_df = df.astype(dtype=dtype_map)
|
@@ -43,7 +43,7 @@ dtype_map = {{feature.name: feature.as_dtype() for feature in features}}
|
|
43
43
|
|
44
44
|
# Actual table function
|
45
45
|
class {function_name}:
|
46
|
-
@vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE)
|
46
|
+
@vectorized(input=pd.DataFrame, max_batch_size=MAX_BATCH_SIZE, flatten_object_input={wide_input})
|
47
47
|
def process(self, df: pd.DataFrame) -> pd.DataFrame:
|
48
48
|
df.columns = input_cols
|
49
49
|
input_df = df.astype(dtype=dtype_map)
|
@@ -7,7 +7,10 @@ from typing_extensions import NotRequired
|
|
7
7
|
from snowflake.ml._internal.utils import sql_identifier
|
8
8
|
from snowflake.ml.model import model_signature, type_hints
|
9
9
|
from snowflake.ml.model._model_composer.model_manifest import model_manifest_schema
|
10
|
-
from snowflake.ml.model._model_composer.model_method import
|
10
|
+
from snowflake.ml.model._model_composer.model_method import (
|
11
|
+
constants,
|
12
|
+
function_generator,
|
13
|
+
)
|
11
14
|
from snowflake.ml.model._packager.model_meta import model_meta as model_meta_api
|
12
15
|
from snowflake.snowpark._internal import type_utils
|
13
16
|
|
@@ -64,6 +67,7 @@ class ModelMethod:
|
|
64
67
|
runtime_name: str,
|
65
68
|
function_generator: function_generator.FunctionGenerator,
|
66
69
|
is_partitioned_function: bool = False,
|
70
|
+
wide_input: bool = False,
|
67
71
|
options: Optional[ModelMethodOptions] = None,
|
68
72
|
) -> None:
|
69
73
|
self.model_meta = model_meta
|
@@ -71,6 +75,7 @@ class ModelMethod:
|
|
71
75
|
self.function_generator = function_generator
|
72
76
|
self.is_partitioned_function = is_partitioned_function
|
73
77
|
self.runtime_name = runtime_name
|
78
|
+
self.wide_input = wide_input
|
74
79
|
self.options = options or {}
|
75
80
|
try:
|
76
81
|
self.method_name = sql_identifier.SqlIdentifier(
|
@@ -114,12 +119,15 @@ class ModelMethod:
|
|
114
119
|
self.target_method,
|
115
120
|
self.function_type,
|
116
121
|
self.is_partitioned_function,
|
122
|
+
self.wide_input,
|
117
123
|
options=options,
|
118
124
|
)
|
119
125
|
input_list = [
|
120
126
|
ModelMethod._get_method_arg_from_feature(ft, case_sensitive=self.options.get("case_sensitive", False))
|
121
127
|
for ft in self.model_meta.signatures[self.target_method].inputs
|
122
128
|
]
|
129
|
+
if len(input_list) > constants.SNOWPARK_UDF_INPUT_COL_LIMIT:
|
130
|
+
input_list = [{"name": "INPUT", "type": "OBJECT"}]
|
123
131
|
input_name_counter = collections.Counter([input_info["name"] for input_info in input_list])
|
124
132
|
dup_input_names = [k for k, v in input_name_counter.items() if v > 1]
|
125
133
|
if dup_input_names:
|
@@ -0,0 +1,27 @@
|
|
1
|
+
import os
|
2
|
+
import pathlib
|
3
|
+
|
4
|
+
from snowflake.ml._internal import file_utils
|
5
|
+
|
6
|
+
|
7
|
+
class ModelUserFile:
|
8
|
+
"""Class representing a user provided file.
|
9
|
+
|
10
|
+
Attributes:
|
11
|
+
subdirectory_name: A local path where model related files should be dumped to.
|
12
|
+
local_path: A list of ModelRuntime objects managing the runtimes and environment in the MODEL object.
|
13
|
+
"""
|
14
|
+
|
15
|
+
USER_FILES_DIR_REL_PATH = "user_files"
|
16
|
+
|
17
|
+
def __init__(self, subdirectory_name: pathlib.PurePosixPath, local_path: pathlib.Path) -> None:
|
18
|
+
self.subdirectory_name = subdirectory_name
|
19
|
+
self.local_path = local_path
|
20
|
+
|
21
|
+
def save(self, workspace_path: pathlib.Path) -> str:
|
22
|
+
user_files_path = workspace_path / ModelUserFile.USER_FILES_DIR_REL_PATH / self.subdirectory_name
|
23
|
+
user_files_path.mkdir(parents=True, exist_ok=True)
|
24
|
+
|
25
|
+
# copy the file to the workspace
|
26
|
+
file_utils.copy_file_or_tree(str(self.local_path), str(user_files_path))
|
27
|
+
return os.path.join(self.subdirectory_name, self.local_path.name)
|
@@ -1,7 +1,8 @@
|
|
1
1
|
import json
|
2
2
|
import os
|
3
|
+
import pathlib
|
3
4
|
import warnings
|
4
|
-
from typing import Any, Callable, Iterable, List, Optional, Sequence, cast
|
5
|
+
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, cast
|
5
6
|
|
6
7
|
import numpy as np
|
7
8
|
import numpy.typing as npt
|
@@ -37,8 +38,10 @@ def _is_callable(model: model_types.SupportedModelType, method_name: str) -> boo
|
|
37
38
|
return callable(getattr(model, method_name, None))
|
38
39
|
|
39
40
|
|
40
|
-
def get_truncated_sample_data(
|
41
|
-
|
41
|
+
def get_truncated_sample_data(
|
42
|
+
sample_input_data: model_types.SupportedDataType, length: int = 100
|
43
|
+
) -> model_types.SupportedLocalDataType:
|
44
|
+
trunc_sample_input = model_signature._truncate_data(sample_input_data, length=length)
|
42
45
|
local_sample_input: model_types.SupportedLocalDataType = None
|
43
46
|
if isinstance(sample_input_data, SnowparkDataFrame):
|
44
47
|
# Added because of Any from missing stubs.
|
@@ -77,7 +80,14 @@ def validate_signature(
|
|
77
80
|
local_sample_input = get_truncated_sample_data(sample_input_data)
|
78
81
|
for target_method in target_methods:
|
79
82
|
predictions_df = get_prediction_fn(target_method, local_sample_input)
|
80
|
-
sig = model_signature.infer_signature(
|
83
|
+
sig = model_signature.infer_signature(
|
84
|
+
sample_input_data,
|
85
|
+
predictions_df,
|
86
|
+
input_feature_names=None,
|
87
|
+
output_feature_names=None,
|
88
|
+
input_data_limit=100,
|
89
|
+
output_data_limit=100,
|
90
|
+
)
|
81
91
|
model_meta.signatures[target_method] = sig
|
82
92
|
|
83
93
|
return model_meta
|
@@ -118,7 +128,7 @@ def get_explainability_supported_background(
|
|
118
128
|
meta: model_meta.ModelMetadata,
|
119
129
|
explain_target_method: Optional[str],
|
120
130
|
) -> pd.DataFrame:
|
121
|
-
if sample_input_data is None:
|
131
|
+
if sample_input_data is None or explain_target_method is None:
|
122
132
|
return None
|
123
133
|
|
124
134
|
if isinstance(sample_input_data, pd.DataFrame):
|
@@ -223,3 +233,27 @@ def get_explain_target_method(
|
|
223
233
|
if method in target_methods_list:
|
224
234
|
return method
|
225
235
|
return None
|
236
|
+
|
237
|
+
|
238
|
+
def save_transformers_config_with_auto_map(local_model_path: str) -> None:
|
239
|
+
import huggingface_hub
|
240
|
+
|
241
|
+
for f_path in pathlib.Path(local_model_path).iterdir():
|
242
|
+
if f_path.name in ["config.json", "tokenizer_config.json"]:
|
243
|
+
with open(f_path) as f:
|
244
|
+
config_dict = json.load(f)
|
245
|
+
|
246
|
+
# a. get repository and class_path from configs
|
247
|
+
auto_map_configs = cast(Dict[str, str], config_dict.get("auto_map", {}))
|
248
|
+
for config_name, config_value in auto_map_configs.items():
|
249
|
+
repository, _, class_path = config_value.rpartition("--")
|
250
|
+
|
251
|
+
# b. download required configs from hf hub
|
252
|
+
if repository:
|
253
|
+
huggingface_hub.snapshot_download(repo_id=repository, local_dir=local_model_path)
|
254
|
+
|
255
|
+
# c. update config files
|
256
|
+
config_dict["auto_map"][config_name] = class_path
|
257
|
+
|
258
|
+
with open(f_path, "w") as f:
|
259
|
+
json.dump(config_dict, f)
|
@@ -94,8 +94,8 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
94
94
|
sample_input_data=sample_input_data,
|
95
95
|
get_prediction_fn=get_prediction,
|
96
96
|
)
|
97
|
-
model_task_and_output = model_task_utils.
|
98
|
-
model_meta.task =
|
97
|
+
model_task_and_output = model_task_utils.resolve_model_task_and_output_type(model, model_meta.task)
|
98
|
+
model_meta.task = model_task_and_output.task
|
99
99
|
if enable_explainability:
|
100
100
|
explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
|
101
101
|
model_meta = handlers_utils.add_explain_method_signature(
|
@@ -227,7 +227,7 @@ class CatBoostModelHandler(_base.BaseModelHandler["catboost.CatBoost"]):
|
|
227
227
|
import shap
|
228
228
|
|
229
229
|
explainer = shap.TreeExplainer(raw_model)
|
230
|
-
df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer(X)
|
230
|
+
df = handlers_utils.convert_explanations_to_2D_df(raw_model, explainer.shap_values(X))
|
231
231
|
return model_signature_utils.rename_pandas_df(df, signature.outputs)
|
232
232
|
|
233
233
|
if target_method == "explain":
|
@@ -66,7 +66,7 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
66
66
|
sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
|
67
67
|
|
68
68
|
if inspect.iscoroutinefunction(target_method):
|
69
|
-
with anyio.start_blocking_portal() as portal:
|
69
|
+
with anyio.from_thread.start_blocking_portal() as portal:
|
70
70
|
predictions_df = portal.call(target_method, model, sample_input_data)
|
71
71
|
else:
|
72
72
|
predictions_df = target_method(model, sample_input_data)
|
@@ -98,7 +98,6 @@ class CustomModelHandler(_base.BaseModelHandler["custom_model.CustomModel"]):
|
|
98
98
|
if model.context.model_refs:
|
99
99
|
for sub_name, model_ref in model.context.model_refs.items():
|
100
100
|
handler = model_handler.find_handler(model_ref.model)
|
101
|
-
assert handler is not None
|
102
101
|
if handler is None:
|
103
102
|
raise TypeError("Your input type to custom model is not currently supported")
|
104
103
|
sub_model = handler.cast_model(model_ref.model)
|
@@ -195,8 +195,12 @@ class HuggingFacePipelineHandler(
|
|
195
195
|
os.makedirs(model_blob_path, exist_ok=True)
|
196
196
|
|
197
197
|
if type_utils.LazyType("transformers.Pipeline").isinstance(model):
|
198
|
+
save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
|
198
199
|
model.save_pretrained( # type:ignore[attr-defined]
|
199
|
-
|
200
|
+
save_path
|
201
|
+
)
|
202
|
+
handlers_utils.save_transformers_config_with_auto_map(
|
203
|
+
save_path,
|
200
204
|
)
|
201
205
|
pipeline_params = {
|
202
206
|
"_batch_size": model._batch_size, # type:ignore[attr-defined]
|
@@ -319,6 +323,7 @@ class HuggingFacePipelineHandler(
|
|
319
323
|
model_blob_options["task"],
|
320
324
|
model=model_blob_file_or_dir_path,
|
321
325
|
trust_remote_code=True,
|
326
|
+
torch_dtype="auto",
|
322
327
|
**device_config,
|
323
328
|
)
|
324
329
|
|
@@ -110,8 +110,8 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
110
110
|
sample_input_data=sample_input_data,
|
111
111
|
get_prediction_fn=get_prediction,
|
112
112
|
)
|
113
|
-
model_task_and_output = model_task_utils.
|
114
|
-
model_meta.task =
|
113
|
+
model_task_and_output = model_task_utils.resolve_model_task_and_output_type(model, model_meta.task)
|
114
|
+
model_meta.task = model_task_and_output.task
|
115
115
|
if enable_explainability:
|
116
116
|
explain_target_method = handlers_utils.get_explain_target_method(model_meta, cls.EXPLAIN_TARGET_METHODS)
|
117
117
|
model_meta = handlers_utils.add_explain_method_signature(
|
@@ -240,7 +240,9 @@ class LGBMModelHandler(_base.BaseModelHandler[Union["lightgbm.Booster", "lightgb
|
|
240
240
|
import shap
|
241
241
|
|
242
242
|
explainer = shap.TreeExplainer(raw_model)
|
243
|
-
df = handlers_utils.convert_explanations_to_2D_df(
|
243
|
+
df = handlers_utils.convert_explanations_to_2D_df(
|
244
|
+
raw_model, explainer.shap_values(X, from_call=True)
|
245
|
+
)
|
244
246
|
return model_signature_utils.rename_pandas_df(df, signature.outputs)
|
245
247
|
|
246
248
|
if target_method == "explain":
|
@@ -14,8 +14,8 @@ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
|
14
14
|
from snowflake.ml.model._packager.model_meta import (
|
15
15
|
model_blob_meta,
|
16
16
|
model_meta as model_meta_api,
|
17
|
+
model_meta_schema,
|
17
18
|
)
|
18
|
-
from snowflake.ml.model._signatures import utils as model_signature_utils
|
19
19
|
from snowflake.snowpark._internal import utils as snowpark_utils
|
20
20
|
|
21
21
|
if TYPE_CHECKING:
|
@@ -24,6 +24,25 @@ if TYPE_CHECKING:
|
|
24
24
|
logger = logging.getLogger(__name__)
|
25
25
|
|
26
26
|
|
27
|
+
def _validate_sentence_transformers_signatures(sigs: Dict[str, model_signature.ModelSignature]) -> None:
|
28
|
+
if list(sigs.keys()) != ["encode"]:
|
29
|
+
raise ValueError("target_methods can only be ['encode']")
|
30
|
+
|
31
|
+
if len(sigs["encode"].inputs) != 1:
|
32
|
+
raise ValueError("SentenceTransformer can only accept 1 input column")
|
33
|
+
|
34
|
+
if len(sigs["encode"].outputs) != 1:
|
35
|
+
raise ValueError("SentenceTransformer can only return 1 output column")
|
36
|
+
|
37
|
+
assert isinstance(sigs["encode"].inputs[0], model_signature.FeatureSpec)
|
38
|
+
|
39
|
+
if sigs["encode"].inputs[0]._shape is not None:
|
40
|
+
raise ValueError("SentenceTransformer does not support input shape")
|
41
|
+
|
42
|
+
if sigs["encode"].inputs[0]._dtype != model_signature.DataType.STRING:
|
43
|
+
raise ValueError("SentenceTransformer only accepts string input")
|
44
|
+
|
45
|
+
|
27
46
|
@final
|
28
47
|
class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.SentenceTransformer"]):
|
29
48
|
HANDLER_TYPE = "sentence_transformers"
|
@@ -68,6 +87,10 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
68
87
|
if enable_explainability:
|
69
88
|
raise NotImplementedError("Explainability is not supported for Sentence Transformer model.")
|
70
89
|
|
90
|
+
batch_size = kwargs.get("batch_size", 32)
|
91
|
+
if not isinstance(batch_size, int) or batch_size <= 0:
|
92
|
+
raise ValueError("batch_size must be a positive integer")
|
93
|
+
|
71
94
|
# Validate target methods and signature (if possible)
|
72
95
|
if not is_sub_model:
|
73
96
|
target_methods = handlers_utils.get_target_methods(
|
@@ -75,12 +98,23 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
75
98
|
target_methods=kwargs.pop("target_methods", None),
|
76
99
|
default_target_methods=cls.DEFAULT_TARGET_METHODS,
|
77
100
|
)
|
78
|
-
|
101
|
+
if target_methods != ["encode"]:
|
102
|
+
raise ValueError("target_methods can only be ['encode']")
|
79
103
|
|
80
104
|
def get_prediction(
|
81
105
|
target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
|
82
106
|
) -> model_types.SupportedLocalDataType:
|
83
|
-
|
107
|
+
if not isinstance(sample_input_data, pd.DataFrame):
|
108
|
+
sample_input_data = model_signature._convert_local_data_to_df(data=sample_input_data)
|
109
|
+
|
110
|
+
if sample_input_data.shape[1] != 1:
|
111
|
+
raise ValueError(
|
112
|
+
"SentenceTransformer can only accept 1 input column when converted to pd.DataFrame"
|
113
|
+
)
|
114
|
+
X_list = sample_input_data.iloc[:, 0].tolist()
|
115
|
+
|
116
|
+
assert callable(getattr(model, "encode", None))
|
117
|
+
return pd.DataFrame({0: model.encode(X_list, batch_size=batch_size).tolist()})
|
84
118
|
|
85
119
|
if model_meta.signatures:
|
86
120
|
handlers_utils.validate_target_methods(model, list(model_meta.signatures.keys()))
|
@@ -102,10 +136,16 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
102
136
|
get_prediction_fn=get_prediction,
|
103
137
|
)
|
104
138
|
|
139
|
+
_validate_sentence_transformers_signatures(model_meta.signatures)
|
140
|
+
|
105
141
|
# save model
|
106
142
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
107
143
|
os.makedirs(model_blob_path, exist_ok=True)
|
108
|
-
|
144
|
+
save_path = os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR)
|
145
|
+
model.save(save_path)
|
146
|
+
handlers_utils.save_transformers_config_with_auto_map(
|
147
|
+
save_path,
|
148
|
+
)
|
109
149
|
|
110
150
|
# save model metadata
|
111
151
|
base_meta = model_blob_meta.ModelBlobMeta(
|
@@ -113,6 +153,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
113
153
|
model_type=cls.HANDLER_TYPE,
|
114
154
|
handler_version=cls.HANDLER_VERSION,
|
115
155
|
path=cls.MODEL_BLOB_FILE_OR_DIR,
|
156
|
+
options=model_meta_schema.SentenceTransformersModelBlobOptions(batch_size=batch_size),
|
116
157
|
)
|
117
158
|
model_meta.models[name] = base_meta
|
118
159
|
model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION
|
@@ -149,6 +190,7 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
149
190
|
if snowpark_utils.is_in_stored_procedure(): # type: ignore[no-untyped-call]
|
150
191
|
# We need to redirect the same folders to a writable location in the sandbox.
|
151
192
|
os.environ["TRANSFORMERS_CACHE"] = "/tmp"
|
193
|
+
os.environ["HF_HOME"] = "/tmp"
|
152
194
|
|
153
195
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
154
196
|
model_blobs_metadata = model_meta.models
|
@@ -183,6 +225,10 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
183
225
|
raw_model: "sentence_transformers.SentenceTransformer",
|
184
226
|
model_meta: model_meta_api.ModelMetadata,
|
185
227
|
) -> Type[custom_model.CustomModel]:
|
228
|
+
batch_size = cast(
|
229
|
+
model_meta_schema.SentenceTransformersModelBlobOptions, model_meta.models[model_meta.name].options
|
230
|
+
).get("batch_size", None)
|
231
|
+
|
186
232
|
def get_prediction(
|
187
233
|
raw_model: "sentence_transformers.SentenceTransformer",
|
188
234
|
signature: model_signature.ModelSignature,
|
@@ -190,8 +236,11 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
190
236
|
) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
|
191
237
|
@custom_model.inference_api
|
192
238
|
def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
|
193
|
-
|
194
|
-
|
239
|
+
X_list = X.iloc[:, 0].tolist()
|
240
|
+
|
241
|
+
return pd.DataFrame(
|
242
|
+
{signature.outputs[0].name: raw_model.encode(X_list, batch_size=batch_size).tolist()}
|
243
|
+
)
|
195
244
|
|
196
245
|
return fn
|
197
246
|
|
@@ -217,17 +266,3 @@ class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.S
|
|
217
266
|
predict_method = getattr(sentence_transformers_SentenceTransformer_model, "encode", None)
|
218
267
|
assert callable(predict_method)
|
219
268
|
return sentence_transformers_SentenceTransformer_model
|
220
|
-
|
221
|
-
|
222
|
-
def _sentence_transformer_encode(
|
223
|
-
model: "sentence_transformers.SentenceTransformer", X: model_types.SupportedLocalDataType
|
224
|
-
) -> model_types.SupportedLocalDataType:
|
225
|
-
|
226
|
-
if not isinstance(X, pd.DataFrame):
|
227
|
-
X = model_signature._convert_local_data_to_df(X)
|
228
|
-
|
229
|
-
assert X.shape[1] == 1, "SentenceTransformer can only accept 1 input column when converted to pd.DataFrame"
|
230
|
-
X_list = X.iloc[:, 0].tolist()
|
231
|
-
|
232
|
-
assert callable(getattr(model, "encode", None))
|
233
|
-
return pd.DataFrame({0: model.encode(X_list, batch_size=X.shape[0]).tolist()})
|
@@ -152,8 +152,8 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
152
152
|
sample_input_data, model_meta, explain_target_method
|
153
153
|
)
|
154
154
|
|
155
|
-
model_task_and_output_type = model_task_utils.
|
156
|
-
model_meta.task =
|
155
|
+
model_task_and_output_type = model_task_utils.resolve_model_task_and_output_type(model, model_meta.task)
|
156
|
+
model_meta.task = model_task_and_output_type.task
|
157
157
|
|
158
158
|
# if users did not ask then we enable if we have background data
|
159
159
|
if enable_explainability is None:
|
@@ -164,11 +164,17 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
164
164
|
stacklevel=1,
|
165
165
|
)
|
166
166
|
enable_explainability = False
|
167
|
-
elif model_meta.task == model_types.Task.UNKNOWN:
|
167
|
+
elif model_meta.task == model_types.Task.UNKNOWN or explain_target_method is None:
|
168
168
|
enable_explainability = False
|
169
169
|
else:
|
170
170
|
enable_explainability = True
|
171
171
|
if enable_explainability:
|
172
|
+
model_meta = handlers_utils.add_explain_method_signature(
|
173
|
+
model_meta=model_meta,
|
174
|
+
explain_method="explain",
|
175
|
+
target_method=explain_target_method,
|
176
|
+
output_return_type=model_task_and_output_type.output_type,
|
177
|
+
)
|
172
178
|
handlers_utils.save_background_data(
|
173
179
|
model_blobs_dir_path,
|
174
180
|
cls.EXPLAIN_ARTIFACTS_DIR,
|
@@ -177,13 +183,6 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
177
183
|
background_data,
|
178
184
|
)
|
179
185
|
|
180
|
-
model_meta = handlers_utils.add_explain_method_signature(
|
181
|
-
model_meta=model_meta,
|
182
|
-
explain_method="explain",
|
183
|
-
target_method=explain_target_method,
|
184
|
-
output_return_type=model_task_and_output_type.output_type,
|
185
|
-
)
|
186
|
-
|
187
186
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
188
187
|
os.makedirs(model_blob_path, exist_ok=True)
|
189
188
|
with open(os.path.join(model_blob_path, cls.MODEL_BLOB_FILE_OR_DIR), "wb") as f:
|