snowflake-ml-python 1.3.0__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- snowflake/ml/_internal/file_utils.py +3 -3
- snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
- snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
- snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
- snowflake/ml/_internal/telemetry.py +11 -2
- snowflake/ml/_internal/utils/formatting.py +1 -1
- snowflake/ml/feature_store/feature_store.py +15 -106
- snowflake/ml/fileset/sfcfs.py +4 -3
- snowflake/ml/fileset/stage_fs.py +18 -0
- snowflake/ml/model/_api.py +9 -9
- snowflake/ml/model/_client/model/model_version_impl.py +20 -15
- snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
- snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
- snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
- snowflake/ml/model/_model_composer/model_composer.py +10 -8
- snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
- snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
- snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
- snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
- snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
- snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
- snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
- snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
- snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
- snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
- snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
- snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
- snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
- snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
- snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
- snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
- snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
- snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
- snowflake/ml/model/_packager/model_packager.py +8 -6
- snowflake/ml/model/custom_model.py +3 -1
- snowflake/ml/model/type_hints.py +13 -0
- snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
- snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
- snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
- snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
- snowflake/ml/modeling/_internal/model_specifications.py +3 -1
- snowflake/ml/modeling/_internal/model_trainer.py +2 -2
- snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
- snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
- snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
- snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
- snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
- snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
- snowflake/ml/modeling/cluster/birch.py +33 -61
- snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
- snowflake/ml/modeling/cluster/dbscan.py +33 -61
- snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
- snowflake/ml/modeling/cluster/k_means.py +33 -61
- snowflake/ml/modeling/cluster/mean_shift.py +33 -61
- snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
- snowflake/ml/modeling/cluster/optics.py +33 -61
- snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
- snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
- snowflake/ml/modeling/compose/column_transformer.py +33 -61
- snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
- snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
- snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
- snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
- snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
- snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
- snowflake/ml/modeling/covariance/oas.py +33 -61
- snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
- snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
- snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
- snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
- snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
- snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/pca.py +33 -61
- snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
- snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
- snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
- snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
- snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
- snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
- snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
- snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
- snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
- snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
- snowflake/ml/modeling/framework/base.py +55 -5
- snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
- snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
- snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
- snowflake/ml/modeling/impute/knn_imputer.py +33 -61
- snowflake/ml/modeling/impute/missing_indicator.py +33 -61
- snowflake/ml/modeling/impute/simple_imputer.py +4 -15
- snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
- snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
- snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
- snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
- snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
- snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
- snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
- snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
- snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/lars.py +33 -61
- snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
- snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
- snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
- snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
- snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
- snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/perceptron.py +33 -61
- snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/ridge.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
- snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
- snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
- snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
- snowflake/ml/modeling/manifold/isomap.py +33 -61
- snowflake/ml/modeling/manifold/mds.py +33 -61
- snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
- snowflake/ml/modeling/manifold/tsne.py +33 -61
- snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
- snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
- snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
- snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
- snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
- snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
- snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
- snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
- snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
- snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
- snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
- snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
- snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
- snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
- snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
- snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
- snowflake/ml/modeling/svm/linear_svc.py +33 -61
- snowflake/ml/modeling/svm/linear_svr.py +33 -61
- snowflake/ml/modeling/svm/nu_svc.py +33 -61
- snowflake/ml/modeling/svm/nu_svr.py +33 -61
- snowflake/ml/modeling/svm/svc.py +33 -61
- snowflake/ml/modeling/svm/svr.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
- snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
- snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
- snowflake/ml/registry/_manager/model_manager.py +6 -2
- snowflake/ml/registry/model_registry.py +100 -27
- snowflake/ml/registry/registry.py +6 -2
- snowflake/ml/version.py +1 -1
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
- {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,214 @@
|
|
1
|
+
import logging
|
2
|
+
import os
|
3
|
+
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
|
4
|
+
|
5
|
+
import cloudpickle
|
6
|
+
import pandas as pd
|
7
|
+
from typing_extensions import TypeGuard, Unpack
|
8
|
+
|
9
|
+
from snowflake.ml._internal import type_utils
|
10
|
+
from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
|
11
|
+
from snowflake.ml.model._packager.model_env import model_env
|
12
|
+
from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
|
13
|
+
from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
|
14
|
+
from snowflake.ml.model._packager.model_meta import (
|
15
|
+
model_blob_meta,
|
16
|
+
model_meta as model_meta_api,
|
17
|
+
)
|
18
|
+
from snowflake.ml.model._signatures import utils as model_signature_utils
|
19
|
+
from snowflake.snowpark._internal import utils as snowpark_utils
|
20
|
+
|
21
|
+
if TYPE_CHECKING:
|
22
|
+
import sentence_transformers
|
23
|
+
|
24
|
+
# Module-level logger for this handler module (currently not referenced elsewhere in this module).
logger = logging.getLogger(__name__)
|
25
|
+
|
26
|
+
|
27
|
+
@final
class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.SentenceTransformer"]):
    """Model handler for ``sentence_transformers.SentenceTransformer`` models.

    Only the ``encode`` method is supported as an inference target. Inference is delegated to
    ``_sentence_transformer_encode``, which takes a single input column and returns one embedding per row.
    """

    HANDLER_TYPE = "sentence_transformers"
    HANDLER_VERSION = "2024-03-15"
    _MIN_SNOWPARK_ML_VERSION = "1.3.1"
    _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}

    # Name of the file/directory, inside the per-model blob directory, that holds the saved model.
    MODELE_BLOB_FILE_OR_DIR = "model"
    DEFAULT_TARGET_METHODS = ["encode"]

    @classmethod
    def can_handle(
        cls,
        model: model_types.SupportedModelType,
    ) -> TypeGuard["sentence_transformers.SentenceTransformer"]:
        """Return whether ``model`` is a ``sentence_transformers.SentenceTransformer``.

        The type is referenced by its qualified name through ``type_utils.LazyType``.
        """
        if type_utils.LazyType("sentence_transformers.SentenceTransformer").isinstance(model):
            return True
        return False

    @classmethod
    def cast_model(
        cls,
        model: model_types.SupportedModelType,
    ) -> "sentence_transformers.SentenceTransformer":
        """Narrow ``model`` to ``SentenceTransformer`` (asserts the runtime type)."""
        import sentence_transformers

        assert isinstance(model, sentence_transformers.SentenceTransformer)
        return cast(sentence_transformers.SentenceTransformer, model)

    @classmethod
    def save_model(
        cls,
        name: str,
        model: "sentence_transformers.SentenceTransformer",
        model_meta: model_meta_api.ModelMetadata,
        model_blobs_dir_path: str,
        sample_input_data: Optional[model_types.SupportedDataType] = None,
        is_sub_model: Optional[bool] = False,
        **kwargs: Unpack[model_types.SentenceTransformersSaveOptions],  # registry.log_model(options={...})
    ) -> None:
        """Save ``model`` under ``model_blobs_dir_path/name`` and record it in ``model_meta``.

        Args:
            name: Name of the model blob; used as the sub-directory under ``model_blobs_dir_path`` and as
                the key in ``model_meta.models``.
            model: The SentenceTransformer to save.
            model_meta: Model metadata, updated in place (signatures, ``models[name]``,
                ``min_snowpark_ml_version`` and the ``sentence-transformers`` dependency).
            model_blobs_dir_path: Directory into which model blobs are written.
            sample_input_data: Optional sample input used to infer/validate the ``encode`` signature.
            is_sub_model: When True, target-method and signature validation is skipped.
            **kwargs: Save options; ``target_methods`` is consumed here.

        Raises:
            ValueError: If ``target_methods`` resolves to anything other than ``["encode"]``.
        """
        # Validate target methods and signature (if possible)
        if not is_sub_model:
            target_methods = handlers_utils.get_target_methods(
                model=model,
                target_methods=kwargs.pop("target_methods", None),
                default_target_methods=cls.DEFAULT_TARGET_METHODS,
            )
            # Explicit check instead of `assert` so user-input validation still runs under `python -O`.
            if target_methods != ["encode"]:
                raise ValueError("target_methods can only be ['encode']")

            def get_prediction(
                target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
            ) -> model_types.SupportedLocalDataType:
                # Only `encode` is supported, so the method name is not needed to dispatch.
                return _sentence_transformer_encode(model, sample_input_data)

            if model_meta.signatures:
                handlers_utils.validate_target_methods(model, list(model_meta.signatures.keys()))
                model_meta = handlers_utils.validate_signature(
                    model=model,
                    model_meta=model_meta,
                    target_methods=target_methods,
                    sample_input_data=sample_input_data,
                    get_prediction_fn=get_prediction,
                )
            else:
                handlers_utils.validate_target_methods(model, target_methods)  # DEFAULT_TARGET_METHODS only
                if sample_input_data is not None:
                    model_meta = handlers_utils.validate_signature(
                        model=model,
                        model_meta=model_meta,
                        target_methods=target_methods,
                        sample_input_data=sample_input_data,
                        get_prediction_fn=get_prediction,
                    )

        # save model
        model_blob_path = os.path.join(model_blobs_dir_path, name)
        os.makedirs(model_blob_path, exist_ok=True)
        model.save(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR))

        # save model metadata
        base_meta = model_blob_meta.ModelBlobMeta(
            name=name,
            model_type=cls.HANDLER_TYPE,
            handler_version=cls.HANDLER_VERSION,
            path=cls.MODELE_BLOB_FILE_OR_DIR,
        )
        model_meta.models[name] = base_meta
        model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION

        model_meta.env.include_if_absent(
            [
                model_env.ModelDependency(requirement="sentence-transformers", pip_name="sentence-transformers"),
            ],
            check_local_version=True,
        )

    @classmethod
    def load_model(
        cls,
        name: str,
        model_meta: model_meta_api.ModelMetadata,
        model_blobs_dir_path: str,
        **kwargs: Unpack[model_types.ModelLoadOption],  # use_gpu
    ) -> "sentence_transformers.SentenceTransformer":
        """Load the SentenceTransformer previously saved by ``save_model``.

        If the stored blob is a directory it is loaded with ``SentenceTransformer(path)``; if it is a single
        file it is unpickled with ``cloudpickle``. ``kwargs`` (e.g. ``use_gpu``) are currently not used here.

        Returns:
            The loaded ``sentence_transformers.SentenceTransformer``.
        """
        import sentence_transformers

        if snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
            # We need to redirect the same folders to a writable location in the sandbox.
            os.environ["TRANSFORMERS_CACHE"] = "/tmp"

        model_blob_path = os.path.join(model_blobs_dir_path, name)
        model_blobs_metadata = model_meta.models
        model_blob_metadata = model_blobs_metadata[name]
        model_blob_filename = model_blob_metadata.path
        model_blob_file_or_dir_path = os.path.join(model_blob_path, model_blob_filename)

        if os.path.isdir(model_blob_file_or_dir_path):  # if the saved model is a directory
            model = sentence_transformers.SentenceTransformer(model_blob_file_or_dir_path)
        else:
            assert os.path.isfile(model_blob_file_or_dir_path)  # if the saved model is a file
            # NOTE(review): unpickling can execute arbitrary code; only load trusted model packages.
            with open(model_blob_file_or_dir_path, "rb") as f:
                model = cloudpickle.load(f)
        assert isinstance(model, sentence_transformers.SentenceTransformer)
        return model

    @classmethod
    def convert_as_custom_model(
        cls,
        raw_model: "sentence_transformers.SentenceTransformer",
        model_meta: model_meta_api.ModelMetadata,
        **kwargs: Unpack[model_types.ModelLoadOption],
    ) -> custom_model.CustomModel:
        """Wrap ``raw_model`` in a ``CustomModel`` exposing one inference method per signature.

        Raises:
            ValueError: If ``model_meta.signatures`` contains a method other than ``encode``.
        """
        import sentence_transformers

        from snowflake.ml.model import custom_model

        def _create_custom_model(
            raw_model: "sentence_transformers.SentenceTransformer",
            model_meta: model_meta_api.ModelMetadata,
        ) -> Type[custom_model.CustomModel]:
            def get_prediction(
                raw_model: "sentence_transformers.SentenceTransformer",
                signature: model_signature.ModelSignature,
                target_method: str,
            ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
                @custom_model.inference_api
                def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
                    predictions_df = _sentence_transformer_encode(raw_model, X)
                    # Rename the single output column to match the signature's output spec.
                    return model_signature_utils.rename_pandas_df(predictions_df, signature.outputs)

                return fn

            type_method_dict = {}
            for target_method_name, sig in model_meta.signatures.items():
                if target_method_name == "encode":
                    type_method_dict[target_method_name] = get_prediction(raw_model, sig, target_method_name)
                else:
                    # Previously the exception was constructed but never raised, silently dropping the method.
                    raise ValueError(f"{target_method_name} is currently not supported.")

            _SentenceTransformer = type(
                "_SentenceTransformer",
                (custom_model.CustomModel,),
                type_method_dict,
            )
            return _SentenceTransformer

        assert isinstance(raw_model, sentence_transformers.SentenceTransformer)
        model = raw_model

        _SentenceTransformer = _create_custom_model(model, model_meta)
        sentence_transformers_SentenceTransformer_model = _SentenceTransformer(custom_model.ModelContext())
        predict_method = getattr(sentence_transformers_SentenceTransformer_model, "encode", None)
        assert callable(predict_method)
        return sentence_transformers_SentenceTransformer_model
|
201
|
+
|
202
|
+
|
203
|
+
def _sentence_transformer_encode(
    model: "sentence_transformers.SentenceTransformer", X: "model_types.SupportedLocalDataType"
) -> "model_types.SupportedLocalDataType":
    """Encode a single input column with a SentenceTransformer.

    Args:
        model: The SentenceTransformer whose ``encode`` method is called.
        X: Local input data. Anything that is not already a ``pd.DataFrame`` is converted with
            ``model_signature._convert_local_data_to_df``. After conversion it must have exactly one column.

    Returns:
        A ``pd.DataFrame`` with a single column named ``0`` holding one embedding (as a Python list) per row.

    Raises:
        ValueError: If the (converted) input does not have exactly one column.
    """
    if not isinstance(X, pd.DataFrame):
        X = model_signature._convert_local_data_to_df(X)

    # Explicit check rather than `assert`, so the validation still runs under `python -O`.
    if X.shape[1] != 1:
        raise ValueError("SentenceTransformer can only accept 1 input column when converted to pd.DataFrame")
    X_list = X.iloc[:, 0].tolist()

    assert callable(getattr(model, "encode", None))
    # All rows are encoded in a single batch. Clamp to at least 1 because an empty input would
    # otherwise request a batch size of 0, which is not a valid batch size for `encode`.
    batch_size = max(X.shape[0], 1)
    return pd.DataFrame({0: model.encode(X_list, batch_size=batch_size).tolist()})
|
@@ -72,7 +72,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
72
72
|
model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
|
73
73
|
model_meta: model_meta_api.ModelMetadata,
|
74
74
|
model_blobs_dir_path: str,
|
75
|
-
|
75
|
+
sample_input_data: Optional[model_types.SupportedDataType] = None,
|
76
76
|
is_sub_model: Optional[bool] = False,
|
77
77
|
**kwargs: Unpack[model_types.SKLModelSaveOptions],
|
78
78
|
) -> None:
|
@@ -89,21 +89,21 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
|
|
89
89
|
)
|
90
90
|
|
91
91
|
def get_prediction(
|
92
|
-
target_method_name: str,
|
92
|
+
target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
|
93
93
|
) -> model_types.SupportedLocalDataType:
|
94
|
-
if not isinstance(
|
95
|
-
|
94
|
+
if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
|
95
|
+
sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
|
96
96
|
|
97
97
|
target_method = getattr(model, target_method_name, None)
|
98
98
|
assert callable(target_method)
|
99
|
-
predictions_df = target_method(
|
99
|
+
predictions_df = target_method(sample_input_data)
|
100
100
|
return predictions_df
|
101
101
|
|
102
102
|
model_meta = handlers_utils.validate_signature(
|
103
103
|
model=model,
|
104
104
|
model_meta=model_meta,
|
105
105
|
target_methods=target_methods,
|
106
|
-
|
106
|
+
sample_input_data=sample_input_data,
|
107
107
|
get_prediction_fn=get_prediction,
|
108
108
|
)
|
109
109
|
|
@@ -69,7 +69,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
69
69
|
model: "BaseEstimator",
|
70
70
|
model_meta: model_meta_api.ModelMetadata,
|
71
71
|
model_blobs_dir_path: str,
|
72
|
-
|
72
|
+
sample_input_data: Optional[model_types.SupportedDataType] = None,
|
73
73
|
is_sub_model: Optional[bool] = False,
|
74
74
|
**kwargs: Unpack[model_types.SNOWModelSaveOptions],
|
75
75
|
) -> None:
|
@@ -79,7 +79,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
79
79
|
# Pipeline is inherited from BaseEstimator, so no need to add one more check
|
80
80
|
|
81
81
|
if not is_sub_model:
|
82
|
-
if
|
82
|
+
if sample_input_data is not None or model_meta.signatures:
|
83
83
|
warnings.warn(
|
84
84
|
"Inferring model signature from sample input or providing model signature for Snowpark ML "
|
85
85
|
+ "Modeling model is not required. Model signature will automatically be inferred during fitting. ",
|
@@ -87,7 +87,19 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
|
|
87
87
|
stacklevel=2,
|
88
88
|
)
|
89
89
|
assert hasattr(model, "model_signatures"), "Model does not have model signatures as expected."
|
90
|
-
|
90
|
+
model_signature_dict = getattr(model, "model_signatures", {})
|
91
|
+
target_methods = kwargs.pop("target_methods", None)
|
92
|
+
if not target_methods:
|
93
|
+
model_meta.signatures = model_signature_dict
|
94
|
+
else:
|
95
|
+
temp_model_signature_dict = {}
|
96
|
+
for method_name in target_methods:
|
97
|
+
method_model_signature = model_signature_dict.get(method_name, None)
|
98
|
+
if method_model_signature is not None:
|
99
|
+
temp_model_signature_dict[method_name] = method_model_signature
|
100
|
+
else:
|
101
|
+
raise ValueError(f"Target method {method_name} does not exist in the model.")
|
102
|
+
model_meta.signatures = temp_model_signature_dict
|
91
103
|
|
92
104
|
model_blob_path = os.path.join(model_blobs_dir_path, name)
|
93
105
|
os.makedirs(model_blob_path, exist_ok=True)
|
@@ -64,7 +64,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
64
64
|
model: "tensorflow.Module",
|
65
65
|
model_meta: model_meta_api.ModelMetadata,
|
66
66
|
model_blobs_dir_path: str,
|
67
|
-
|
67
|
+
sample_input_data: Optional[model_types.SupportedDataType] = None,
|
68
68
|
is_sub_model: Optional[bool] = False,
|
69
69
|
**kwargs: Unpack[model_types.TensorflowSaveOptions],
|
70
70
|
) -> None:
|
@@ -85,18 +85,18 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
85
85
|
)
|
86
86
|
|
87
87
|
def get_prediction(
|
88
|
-
target_method_name: str,
|
88
|
+
target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
|
89
89
|
) -> model_types.SupportedLocalDataType:
|
90
|
-
if not tensorflow_handler.SeqOfTensorflowTensorHandler.can_handle(
|
91
|
-
|
92
|
-
model_signature._convert_local_data_to_df(
|
90
|
+
if not tensorflow_handler.SeqOfTensorflowTensorHandler.can_handle(sample_input_data):
|
91
|
+
sample_input_data = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(
|
92
|
+
model_signature._convert_local_data_to_df(sample_input_data)
|
93
93
|
)
|
94
94
|
|
95
95
|
target_method = getattr(model, target_method_name, None)
|
96
96
|
assert callable(target_method)
|
97
|
-
for tensor in
|
97
|
+
for tensor in sample_input_data:
|
98
98
|
tensorflow.stop_gradient(tensor)
|
99
|
-
predictions_df = target_method(*
|
99
|
+
predictions_df = target_method(*sample_input_data)
|
100
100
|
|
101
101
|
if isinstance(predictions_df, (tensorflow.Tensor, tensorflow.Variable, np.ndarray)):
|
102
102
|
predictions_df = [predictions_df]
|
@@ -107,7 +107,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
|
|
107
107
|
model=model,
|
108
108
|
model_meta=model_meta,
|
109
109
|
target_methods=target_methods,
|
110
|
-
|
110
|
+
sample_input_data=sample_input_data,
|
111
111
|
get_prediction_fn=get_prediction,
|
112
112
|
)
|
113
113
|
|
@@ -62,7 +62,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
62
62
|
model: "torch.jit.ScriptModule", # type:ignore[name-defined]
|
63
63
|
model_meta: model_meta_api.ModelMetadata,
|
64
64
|
model_blobs_dir_path: str,
|
65
|
-
|
65
|
+
sample_input_data: Optional[model_types.SupportedDataType] = None,
|
66
66
|
is_sub_model: Optional[bool] = False,
|
67
67
|
**kwargs: Unpack[model_types.TorchScriptSaveOptions],
|
68
68
|
) -> None:
|
@@ -78,18 +78,18 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
78
78
|
)
|
79
79
|
|
80
80
|
def get_prediction(
|
81
|
-
target_method_name: str,
|
81
|
+
target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
|
82
82
|
) -> model_types.SupportedLocalDataType:
|
83
|
-
if not pytorch_handler.SeqOfPyTorchTensorHandler.can_handle(
|
84
|
-
|
85
|
-
model_signature._convert_local_data_to_df(
|
83
|
+
if not pytorch_handler.SeqOfPyTorchTensorHandler.can_handle(sample_input_data):
|
84
|
+
sample_input_data = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(
|
85
|
+
model_signature._convert_local_data_to_df(sample_input_data)
|
86
86
|
)
|
87
87
|
|
88
88
|
model.eval()
|
89
89
|
target_method = getattr(model, target_method_name, None)
|
90
90
|
assert callable(target_method)
|
91
91
|
with torch.no_grad():
|
92
|
-
predictions_df = target_method(*
|
92
|
+
predictions_df = target_method(*sample_input_data)
|
93
93
|
|
94
94
|
if isinstance(predictions_df, torch.Tensor):
|
95
95
|
predictions_df = [predictions_df]
|
@@ -100,7 +100,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
|
|
100
100
|
model=model,
|
101
101
|
model_meta=model_meta,
|
102
102
|
target_methods=target_methods,
|
103
|
-
|
103
|
+
sample_input_data=sample_input_data,
|
104
104
|
get_prediction_fn=get_prediction,
|
105
105
|
)
|
106
106
|
|
@@ -45,7 +45,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
45
45
|
_HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
|
46
46
|
|
47
47
|
MODELE_BLOB_FILE_OR_DIR = "model.ubj"
|
48
|
-
DEFAULT_TARGET_METHODS = ["
|
48
|
+
DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
|
49
49
|
|
50
50
|
@classmethod
|
51
51
|
def can_handle(
|
@@ -76,7 +76,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
76
76
|
model: Union["xgboost.Booster", "xgboost.XGBModel"],
|
77
77
|
model_meta: model_meta_api.ModelMetadata,
|
78
78
|
model_blobs_dir_path: str,
|
79
|
-
|
79
|
+
sample_input_data: Optional[model_types.SupportedDataType] = None,
|
80
80
|
is_sub_model: Optional[bool] = False,
|
81
81
|
**kwargs: Unpack[model_types.XGBModelSaveOptions],
|
82
82
|
) -> None:
|
@@ -92,24 +92,24 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
|
|
92
92
|
)
|
93
93
|
|
94
94
|
def get_prediction(
|
95
|
-
target_method_name: str,
|
95
|
+
target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
|
96
96
|
) -> model_types.SupportedLocalDataType:
|
97
|
-
if not isinstance(
|
98
|
-
|
97
|
+
if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
|
98
|
+
sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
|
99
99
|
|
100
100
|
if isinstance(model, xgboost.Booster):
|
101
|
-
|
101
|
+
sample_input_data = xgboost.DMatrix(sample_input_data)
|
102
102
|
|
103
103
|
target_method = getattr(model, target_method_name, None)
|
104
104
|
assert callable(target_method)
|
105
|
-
predictions_df = target_method(
|
105
|
+
predictions_df = target_method(sample_input_data)
|
106
106
|
return predictions_df
|
107
107
|
|
108
108
|
model_meta = handlers_utils.validate_signature(
|
109
109
|
model=model,
|
110
110
|
model_meta=model_meta,
|
111
111
|
target_methods=target_methods,
|
112
|
-
|
112
|
+
sample_input_data=sample_input_data,
|
113
113
|
get_prediction_fn=get_prediction,
|
114
114
|
)
|
115
115
|
|
@@ -40,7 +40,7 @@ class ModelPackager:
|
|
40
40
|
name: str,
|
41
41
|
model: model_types.SupportedModelType,
|
42
42
|
signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
|
43
|
-
|
43
|
+
sample_input_data: Optional[model_types.SupportedDataType] = None,
|
44
44
|
metadata: Optional[Dict[str, str]] = None,
|
45
45
|
conda_dependencies: Optional[List[str]] = None,
|
46
46
|
pip_requirements: Optional[List[str]] = None,
|
@@ -49,18 +49,20 @@ class ModelPackager:
|
|
49
49
|
code_paths: Optional[List[str]] = None,
|
50
50
|
options: Optional[model_types.ModelSaveOption] = None,
|
51
51
|
) -> None:
|
52
|
-
if (signatures is None) and (
|
52
|
+
if (signatures is None) and (sample_input_data is None) and not model_handler.is_auto_signature_model(model):
|
53
53
|
raise snowml_exceptions.SnowflakeMLException(
|
54
54
|
error_code=error_codes.INVALID_ARGUMENT,
|
55
55
|
original_exception=ValueError(
|
56
|
-
"Signatures and
|
56
|
+
"Signatures and sample_input_data both cannot be None at the same time for this kind of model."
|
57
57
|
),
|
58
58
|
)
|
59
59
|
|
60
|
-
if (signatures is not None) and (
|
60
|
+
if (signatures is not None) and (sample_input_data is not None):
|
61
61
|
raise snowml_exceptions.SnowflakeMLException(
|
62
62
|
error_code=error_codes.INVALID_ARGUMENT,
|
63
|
-
original_exception=ValueError(
|
63
|
+
original_exception=ValueError(
|
64
|
+
"Signatures and sample_input_data both cannot be specified at the same time."
|
65
|
+
),
|
64
66
|
)
|
65
67
|
|
66
68
|
if not options:
|
@@ -93,7 +95,7 @@ class ModelPackager:
|
|
93
95
|
model=model,
|
94
96
|
model_meta=meta,
|
95
97
|
model_blobs_dir_path=model_blobs_path,
|
96
|
-
|
98
|
+
sample_input_data=sample_input_data,
|
97
99
|
is_sub_model=False,
|
98
100
|
**options,
|
99
101
|
)
|
@@ -149,7 +149,9 @@ class CustomModel:
|
|
149
149
|
context: A ModelContext object showing sub-models and artifacts related to this model.
|
150
150
|
"""
|
151
151
|
|
152
|
-
def __init__(self, context: ModelContext) -> None:
|
152
|
+
def __init__(self, context: Optional[ModelContext] = None) -> None:
|
153
|
+
if context is None:
|
154
|
+
context = ModelContext()
|
153
155
|
self.context = context
|
154
156
|
for method in self._get_infer_methods():
|
155
157
|
_validate_predict_function(method)
|
snowflake/ml/model/type_hints.py
CHANGED
@@ -22,6 +22,7 @@ if TYPE_CHECKING:
|
|
22
22
|
import mlflow
|
23
23
|
import numpy as np
|
24
24
|
import pandas as pd
|
25
|
+
import sentence_transformers
|
25
26
|
import sklearn.base
|
26
27
|
import sklearn.pipeline
|
27
28
|
import tensorflow
|
@@ -32,6 +33,7 @@ if TYPE_CHECKING:
|
|
32
33
|
import snowflake.ml.model.custom_model
|
33
34
|
import snowflake.ml.model.models.huggingface_pipeline
|
34
35
|
import snowflake.ml.model.models.llm
|
36
|
+
import snowflake.ml.model.models.sentence_transformers
|
35
37
|
import snowflake.snowpark
|
36
38
|
from snowflake.ml.modeling.framework import base # noqa: F401
|
37
39
|
|
@@ -81,7 +83,9 @@ SupportedNoSignatureRequirementsModelType = Union[
|
|
81
83
|
"base.BaseEstimator",
|
82
84
|
"mlflow.pyfunc.PyFuncModel",
|
83
85
|
"transformers.Pipeline",
|
86
|
+
"sentence_transformers.SentenceTransformer",
|
84
87
|
"snowflake.ml.model.models.huggingface_pipeline.HuggingFacePipelineModel",
|
88
|
+
"snowflake.ml.model.models.sentence_transformers.SentenceTransformer",
|
85
89
|
"snowflake.ml.model.models.llm.LLM",
|
86
90
|
]
|
87
91
|
|
@@ -106,6 +110,7 @@ Here is all acceptable types of Snowflake native model packaging and its handler
|
|
106
110
|
| mlflow.pyfunc.PyFuncModel | mlflow.py | _MLFlowHandler |
|
107
111
|
| transformers.Pipeline | huggingface_pipeline.py | _HuggingFacePipelineHandler |
|
108
112
|
| huggingface_pipeline.HuggingFacePipelineModel | huggingface_pipeline.py | _HuggingFacePipelineHandler |
|
113
|
+
| sentence_transformers.SentenceTransformer | sentence_transformers.py | _SentenceTransformerHandler |
|
109
114
|
"""
|
110
115
|
|
111
116
|
SupportedModelHandlerType = Literal[
|
@@ -113,6 +118,7 @@ SupportedModelHandlerType = Literal[
|
|
113
118
|
"huggingface_pipeline",
|
114
119
|
"mlflow",
|
115
120
|
"pytorch",
|
121
|
+
"sentence_transformers",
|
116
122
|
"sklearn",
|
117
123
|
"snowml",
|
118
124
|
"tensorflow",
|
@@ -215,6 +221,7 @@ class BaseModelSaveOption(TypedDict):
|
|
215
221
|
embed_local_ml_library: NotRequired[bool]
|
216
222
|
relax_version: NotRequired[bool]
|
217
223
|
_legacy_save: NotRequired[bool]
|
224
|
+
function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
|
218
225
|
method_options: NotRequired[Dict[str, ModelMethodSaveOptions]]
|
219
226
|
|
220
227
|
|
@@ -261,6 +268,11 @@ class HuggingFaceSaveOptions(BaseModelSaveOption):
|
|
261
268
|
cuda_version: NotRequired[str]
|
262
269
|
|
263
270
|
|
271
|
+
class SentenceTransformersSaveOptions(BaseModelSaveOption):
|
272
|
+
target_methods: NotRequired[Sequence[str]]
|
273
|
+
cuda_version: NotRequired[str]
|
274
|
+
|
275
|
+
|
264
276
|
class LLMSaveOptions(BaseModelSaveOption):
|
265
277
|
cuda_version: NotRequired[str]
|
266
278
|
|
@@ -276,6 +288,7 @@ ModelSaveOption = Union[
|
|
276
288
|
TensorflowSaveOptions,
|
277
289
|
MLFlowSaveOptions,
|
278
290
|
HuggingFaceSaveOptions,
|
291
|
+
SentenceTransformersSaveOptions,
|
279
292
|
LLMSaveOptions,
|
280
293
|
]
|
281
294
|
|
@@ -1,7 +1,9 @@
|
|
1
1
|
import inspect
|
2
|
-
|
2
|
+
import numbers
|
3
|
+
from typing import Any, Callable, Dict, List, Set, Tuple
|
3
4
|
|
4
5
|
import numpy as np
|
6
|
+
from numpy import typing as npt
|
5
7
|
from typing_extensions import TypeGuard
|
6
8
|
|
7
9
|
from snowflake.ml._internal.exceptions import error_codes, exceptions
|
@@ -153,3 +155,61 @@ def get_module_name(model: object) -> str:
|
|
153
155
|
original_exception=ValueError(f"Unable to infer the source module of the given object {model}."),
|
154
156
|
)
|
155
157
|
return module.__name__
|
158
|
+
|
159
|
+
|
160
|
+
def handle_inference_result(
|
161
|
+
inference_res: Any, output_cols: List[str], inference_method: str, within_udf: bool = False
|
162
|
+
) -> Tuple[npt.NDArray[Any], List[str]]:
|
163
|
+
if isinstance(inference_res, list) and len(inference_res) > 0 and isinstance(inference_res[0], np.ndarray):
|
164
|
+
# In case of multioutput estimators, predict_proba, decision_function etc., functions return a list of
|
165
|
+
# ndarrays. We need to concatenate them.
|
166
|
+
|
167
|
+
# First compute output column names
|
168
|
+
if len(output_cols) == len(inference_res):
|
169
|
+
actual_output_cols = []
|
170
|
+
for idx, np_arr in enumerate(inference_res):
|
171
|
+
for i in range(1 if len(np_arr.shape) <= 1 else np_arr.shape[1]):
|
172
|
+
actual_output_cols.append(f"{output_cols[idx]}_{i}")
|
173
|
+
output_cols = actual_output_cols
|
174
|
+
|
175
|
+
# Concatenate np arrays
|
176
|
+
transformed_numpy_array = np.concatenate(inference_res, axis=1)
|
177
|
+
elif isinstance(inference_res, tuple) and len(inference_res) > 0 and isinstance(inference_res[0], np.ndarray):
|
178
|
+
# In case of kneighbors, functions return a tuple of ndarrays.
|
179
|
+
transformed_numpy_array = np.stack(inference_res, axis=1)
|
180
|
+
elif isinstance(inference_res, numbers.Number):
|
181
|
+
# In case of BernoulliRBM, functions return a float
|
182
|
+
transformed_numpy_array = np.array([inference_res])
|
183
|
+
else:
|
184
|
+
transformed_numpy_array = inference_res
|
185
|
+
|
186
|
+
if (len(transformed_numpy_array.shape) == 3) and inference_method != "kneighbors":
|
187
|
+
# VotingClassifier will return results of shape (n_classifiers, n_samples, n_classes)
|
188
|
+
# when voting = "soft" and flatten_transform = False. We can't handle unflatten transforms,
|
189
|
+
# so we ignore flatten_transform flag and flatten the results.
|
190
|
+
transformed_numpy_array = np.hstack(transformed_numpy_array) # type: ignore[call-overload]
|
191
|
+
|
192
|
+
if len(transformed_numpy_array.shape) == 1:
|
193
|
+
transformed_numpy_array = np.reshape(transformed_numpy_array, (-1, 1))
|
194
|
+
|
195
|
+
shape = transformed_numpy_array.shape
|
196
|
+
if len(shape) > 1:
|
197
|
+
if shape[1] != len(output_cols):
|
198
|
+
# HeterogeneousEnsemble's transform method produce results with variying shapes
|
199
|
+
# from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes).
|
200
|
+
# It is hard to predict the response shape without using fragile introspection logic.
|
201
|
+
# So, to avoid that we are packing the results into a dataframe of shape (n_samples, 1) with
|
202
|
+
# each element being a list.
|
203
|
+
if len(output_cols) != 1:
|
204
|
+
raise TypeError(
|
205
|
+
"expected_output_cols must be same length as transformed array or should be of length 1."
|
206
|
+
f"Currently expected_output_cols shape is {len(output_cols)}, "
|
207
|
+
f"transformed array shape is {shape}. "
|
208
|
+
)
|
209
|
+
if not within_udf:
|
210
|
+
actual_output_cols = []
|
211
|
+
for i in range(shape[1]):
|
212
|
+
actual_output_cols.append(f"{output_cols[0]}_{i}")
|
213
|
+
output_cols = actual_output_cols
|
214
|
+
|
215
|
+
return transformed_numpy_array, output_cols
|