snowflake-ml-python: snowflake_ml_python-1.3.0-py3-none-any.whl → snowflake_ml_python-1.4.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (211):
  1. snowflake/ml/_internal/file_utils.py +3 -3
  2. snowflake/ml/_internal/human_readable_id/adjectives.txt +128 -0
  3. snowflake/ml/_internal/human_readable_id/animals.txt +128 -0
  4. snowflake/ml/_internal/human_readable_id/hrid_generator.py +40 -0
  5. snowflake/ml/_internal/human_readable_id/hrid_generator_base.py +135 -0
  6. snowflake/ml/_internal/telemetry.py +11 -2
  7. snowflake/ml/_internal/utils/formatting.py +1 -1
  8. snowflake/ml/feature_store/feature_store.py +15 -106
  9. snowflake/ml/fileset/sfcfs.py +4 -3
  10. snowflake/ml/fileset/stage_fs.py +18 -0
  11. snowflake/ml/model/_api.py +9 -9
  12. snowflake/ml/model/_client/model/model_version_impl.py +20 -15
  13. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +3 -9
  14. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -5
  15. snowflake/ml/model/_deploy_client/snowservice/deploy.py +7 -6
  16. snowflake/ml/model/_model_composer/model_composer.py +10 -8
  17. snowflake/ml/model/_model_composer/model_method/function_generator.py +1 -1
  18. snowflake/ml/model/_model_composer/model_method/infer_table_function.py_template +2 -1
  19. snowflake/ml/model/_model_composer/model_method/model_method.py +2 -2
  20. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +1 -1
  21. snowflake/ml/model/_packager/model_handlers/_base.py +2 -2
  22. snowflake/ml/model/_packager/model_handlers/_utils.py +5 -5
  23. snowflake/ml/model/_packager/model_handlers/custom.py +7 -7
  24. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +2 -2
  25. snowflake/ml/model/_packager/model_handlers/llm.py +1 -1
  26. snowflake/ml/model/_packager/model_handlers/mlflow.py +1 -1
  27. snowflake/ml/model/_packager/model_handlers/pytorch.py +13 -10
  28. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +214 -0
  29. snowflake/ml/model/_packager/model_handlers/sklearn.py +6 -6
  30. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +15 -3
  31. snowflake/ml/model/_packager/model_handlers/tensorflow.py +8 -8
  32. snowflake/ml/model/_packager/model_handlers/torchscript.py +7 -7
  33. snowflake/ml/model/_packager/model_handlers/xgboost.py +8 -8
  34. snowflake/ml/model/_packager/model_meta/_core_requirements.py +1 -1
  35. snowflake/ml/model/_packager/model_packager.py +8 -6
  36. snowflake/ml/model/custom_model.py +3 -1
  37. snowflake/ml/model/type_hints.py +13 -0
  38. snowflake/ml/modeling/_internal/estimator_utils.py +61 -1
  39. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +4 -43
  40. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +4 -4
  41. snowflake/ml/modeling/_internal/ml_runtime_implementations/ml_runtime_handlers.py +21 -17
  42. snowflake/ml/modeling/_internal/model_specifications.py +3 -1
  43. snowflake/ml/modeling/_internal/model_trainer.py +2 -2
  44. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +547 -1
  45. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +67 -114
  46. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +9 -9
  47. snowflake/ml/modeling/_internal/transformer_protocols.py +2 -3
  48. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +33 -61
  49. snowflake/ml/modeling/cluster/affinity_propagation.py +33 -61
  50. snowflake/ml/modeling/cluster/agglomerative_clustering.py +33 -61
  51. snowflake/ml/modeling/cluster/birch.py +33 -61
  52. snowflake/ml/modeling/cluster/bisecting_k_means.py +33 -61
  53. snowflake/ml/modeling/cluster/dbscan.py +33 -61
  54. snowflake/ml/modeling/cluster/feature_agglomeration.py +33 -61
  55. snowflake/ml/modeling/cluster/k_means.py +33 -61
  56. snowflake/ml/modeling/cluster/mean_shift.py +33 -61
  57. snowflake/ml/modeling/cluster/mini_batch_k_means.py +33 -61
  58. snowflake/ml/modeling/cluster/optics.py +33 -61
  59. snowflake/ml/modeling/cluster/spectral_biclustering.py +33 -61
  60. snowflake/ml/modeling/cluster/spectral_clustering.py +33 -61
  61. snowflake/ml/modeling/cluster/spectral_coclustering.py +33 -61
  62. snowflake/ml/modeling/compose/column_transformer.py +33 -61
  63. snowflake/ml/modeling/compose/transformed_target_regressor.py +33 -61
  64. snowflake/ml/modeling/covariance/elliptic_envelope.py +33 -61
  65. snowflake/ml/modeling/covariance/empirical_covariance.py +33 -61
  66. snowflake/ml/modeling/covariance/graphical_lasso.py +33 -61
  67. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +33 -61
  68. snowflake/ml/modeling/covariance/ledoit_wolf.py +33 -61
  69. snowflake/ml/modeling/covariance/min_cov_det.py +33 -61
  70. snowflake/ml/modeling/covariance/oas.py +33 -61
  71. snowflake/ml/modeling/covariance/shrunk_covariance.py +33 -61
  72. snowflake/ml/modeling/decomposition/dictionary_learning.py +33 -61
  73. snowflake/ml/modeling/decomposition/factor_analysis.py +33 -61
  74. snowflake/ml/modeling/decomposition/fast_ica.py +33 -61
  75. snowflake/ml/modeling/decomposition/incremental_pca.py +33 -61
  76. snowflake/ml/modeling/decomposition/kernel_pca.py +33 -61
  77. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +33 -61
  78. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +33 -61
  79. snowflake/ml/modeling/decomposition/pca.py +33 -61
  80. snowflake/ml/modeling/decomposition/sparse_pca.py +33 -61
  81. snowflake/ml/modeling/decomposition/truncated_svd.py +33 -61
  82. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +33 -61
  83. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +33 -61
  84. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +33 -61
  85. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +33 -61
  86. snowflake/ml/modeling/ensemble/bagging_classifier.py +33 -61
  87. snowflake/ml/modeling/ensemble/bagging_regressor.py +33 -61
  88. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +33 -61
  89. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +33 -61
  90. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +33 -61
  91. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +33 -61
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +33 -61
  93. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +33 -61
  94. snowflake/ml/modeling/ensemble/isolation_forest.py +33 -61
  95. snowflake/ml/modeling/ensemble/random_forest_classifier.py +33 -61
  96. snowflake/ml/modeling/ensemble/random_forest_regressor.py +33 -61
  97. snowflake/ml/modeling/ensemble/stacking_regressor.py +33 -61
  98. snowflake/ml/modeling/ensemble/voting_classifier.py +33 -61
  99. snowflake/ml/modeling/ensemble/voting_regressor.py +33 -61
  100. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +33 -61
  101. snowflake/ml/modeling/feature_selection/select_fdr.py +33 -61
  102. snowflake/ml/modeling/feature_selection/select_fpr.py +33 -61
  103. snowflake/ml/modeling/feature_selection/select_fwe.py +33 -61
  104. snowflake/ml/modeling/feature_selection/select_k_best.py +33 -61
  105. snowflake/ml/modeling/feature_selection/select_percentile.py +33 -61
  106. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +33 -61
  107. snowflake/ml/modeling/feature_selection/variance_threshold.py +33 -61
  108. snowflake/ml/modeling/framework/base.py +55 -5
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +33 -61
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +33 -61
  111. snowflake/ml/modeling/impute/iterative_imputer.py +33 -61
  112. snowflake/ml/modeling/impute/knn_imputer.py +33 -61
  113. snowflake/ml/modeling/impute/missing_indicator.py +33 -61
  114. snowflake/ml/modeling/impute/simple_imputer.py +4 -15
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +33 -61
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +33 -61
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +33 -61
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +33 -61
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +33 -61
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +33 -61
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +36 -63
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +36 -63
  123. snowflake/ml/modeling/linear_model/ard_regression.py +33 -61
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +33 -61
  125. snowflake/ml/modeling/linear_model/elastic_net.py +33 -61
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +33 -61
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +33 -61
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +33 -61
  129. snowflake/ml/modeling/linear_model/lars.py +33 -61
  130. snowflake/ml/modeling/linear_model/lars_cv.py +33 -61
  131. snowflake/ml/modeling/linear_model/lasso.py +33 -61
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +33 -61
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +33 -61
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +33 -61
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +33 -61
  136. snowflake/ml/modeling/linear_model/linear_regression.py +33 -61
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +33 -61
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +33 -61
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +33 -61
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +33 -61
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +33 -61
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +33 -61
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +33 -61
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +33 -61
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +33 -61
  146. snowflake/ml/modeling/linear_model/perceptron.py +33 -61
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +33 -61
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +33 -61
  149. snowflake/ml/modeling/linear_model/ridge.py +33 -61
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +33 -61
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +33 -61
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +33 -61
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +33 -61
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +33 -61
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +33 -61
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +33 -61
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +33 -61
  158. snowflake/ml/modeling/manifold/isomap.py +33 -61
  159. snowflake/ml/modeling/manifold/mds.py +33 -61
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +33 -61
  161. snowflake/ml/modeling/manifold/tsne.py +33 -61
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +33 -61
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +33 -61
  164. snowflake/ml/modeling/model_selection/grid_search_cv.py +39 -57
  165. snowflake/ml/modeling/model_selection/randomized_search_cv.py +26 -57
  166. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +33 -61
  167. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +33 -61
  168. snowflake/ml/modeling/multiclass/output_code_classifier.py +33 -61
  169. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +33 -61
  170. snowflake/ml/modeling/naive_bayes/categorical_nb.py +33 -61
  171. snowflake/ml/modeling/naive_bayes/complement_nb.py +33 -61
  172. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +33 -61
  173. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +33 -61
  174. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +33 -61
  175. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +33 -61
  176. snowflake/ml/modeling/neighbors/kernel_density.py +33 -61
  177. snowflake/ml/modeling/neighbors/local_outlier_factor.py +33 -61
  178. snowflake/ml/modeling/neighbors/nearest_centroid.py +33 -61
  179. snowflake/ml/modeling/neighbors/nearest_neighbors.py +33 -61
  180. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +33 -61
  181. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +33 -61
  182. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +33 -61
  183. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +33 -61
  184. snowflake/ml/modeling/neural_network/mlp_classifier.py +33 -61
  185. snowflake/ml/modeling/neural_network/mlp_regressor.py +33 -61
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +33 -61
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +33 -61
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +33 -61
  189. snowflake/ml/modeling/svm/linear_svc.py +33 -61
  190. snowflake/ml/modeling/svm/linear_svr.py +33 -61
  191. snowflake/ml/modeling/svm/nu_svc.py +33 -61
  192. snowflake/ml/modeling/svm/nu_svr.py +33 -61
  193. snowflake/ml/modeling/svm/svc.py +33 -61
  194. snowflake/ml/modeling/svm/svr.py +33 -61
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +33 -61
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +33 -61
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +33 -61
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +33 -61
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +33 -61
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +33 -61
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +33 -61
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +33 -61
  203. snowflake/ml/registry/_manager/model_manager.py +6 -2
  204. snowflake/ml/registry/model_registry.py +100 -27
  205. snowflake/ml/registry/registry.py +6 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/METADATA +43 -7
  208. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/RECORD +211 -206
  209. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/LICENSE.txt +0 -0
  210. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/WHEEL +0 -0
  211. {snowflake_ml_python-1.3.0.dist-info → snowflake_ml_python-1.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,214 @@
1
+ import logging
2
+ import os
3
+ from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, cast, final
4
+
5
+ import cloudpickle
6
+ import pandas as pd
7
+ from typing_extensions import TypeGuard, Unpack
8
+
9
+ from snowflake.ml._internal import type_utils
10
+ from snowflake.ml.model import custom_model, model_signature, type_hints as model_types
11
+ from snowflake.ml.model._packager.model_env import model_env
12
+ from snowflake.ml.model._packager.model_handlers import _base, _utils as handlers_utils
13
+ from snowflake.ml.model._packager.model_handlers_migrator import base_migrator
14
+ from snowflake.ml.model._packager.model_meta import (
15
+ model_blob_meta,
16
+ model_meta as model_meta_api,
17
+ )
18
+ from snowflake.ml.model._signatures import utils as model_signature_utils
19
+ from snowflake.snowpark._internal import utils as snowpark_utils
20
+
21
+ if TYPE_CHECKING:
22
+ import sentence_transformers
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
@final
class SentenceTransformerHandler(_base.BaseModelHandler["sentence_transformers.SentenceTransformer"]):
    """Model handler for ``sentence_transformers.SentenceTransformer`` objects.

    The only target method this handler accepts is ``encode``.
    """

    HANDLER_TYPE = "sentence_transformers"
    HANDLER_VERSION = "2024-03-15"
    _MIN_SNOWPARK_ML_VERSION = "1.3.1"
    _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}

    # Name of the sub-directory (inside the model's blob directory) the model is saved to.
    MODELE_BLOB_FILE_OR_DIR = "model"
    DEFAULT_TARGET_METHODS = ["encode"]

    @classmethod
    def can_handle(
        cls,
        model: model_types.SupportedModelType,
    ) -> TypeGuard["sentence_transformers.SentenceTransformer"]:
        """Return True when ``model`` is a ``sentence_transformers.SentenceTransformer``.

        The check goes through ``type_utils.LazyType`` with the dotted type name
        instead of importing ``sentence_transformers`` here.
        """
        if type_utils.LazyType("sentence_transformers.SentenceTransformer").isinstance(model):
            return True
        return False

    @classmethod
    def cast_model(
        cls,
        model: model_types.SupportedModelType,
    ) -> "sentence_transformers.SentenceTransformer":
        """Narrow ``model`` to ``SentenceTransformer``, asserting its runtime type."""
        import sentence_transformers

        assert isinstance(model, sentence_transformers.SentenceTransformer)
        return cast(sentence_transformers.SentenceTransformer, model)

    @classmethod
    def save_model(
        cls,
        name: str,
        model: "sentence_transformers.SentenceTransformer",
        model_meta: model_meta_api.ModelMetadata,
        model_blobs_dir_path: str,
        sample_input_data: Optional[model_types.SupportedDataType] = None,
        is_sub_model: Optional[bool] = False,
        **kwargs: Unpack[model_types.SentenceTransformersSaveOptions],  # registry.log_model(options={...})
    ) -> None:
        """Validate the target methods/signature, then save the model and record its metadata.

        Args:
            name: Name of the model; used as the blob sub-directory and as the key in
                ``model_meta.models``.
            model: The model to save.
            model_meta: Metadata object updated in place with the blob entry, the minimum
                Snowpark ML version and the ``sentence-transformers`` dependency.
            model_blobs_dir_path: Directory under which the model's blob directory is created.
            sample_input_data: Optional sample input forwarded to signature validation.
            is_sub_model: When truthy, target-method and signature validation is skipped.
            **kwargs: Save options; ``target_methods`` is popped from here.
        """
        # Validate target methods and signature (if possible)
        if not is_sub_model:
            target_methods = handlers_utils.get_target_methods(
                model=model,
                target_methods=kwargs.pop("target_methods", None),
                default_target_methods=cls.DEFAULT_TARGET_METHODS,
            )
            assert target_methods == ["encode"], "target_methods can only be ['encode']"

            def get_prediction(
                target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
            ) -> model_types.SupportedLocalDataType:
                # Only `encode` is allowed (asserted above), so the method name is not consulted.
                return _sentence_transformer_encode(model, sample_input_data)

            if model_meta.signatures:
                handlers_utils.validate_target_methods(model, list(model_meta.signatures.keys()))
                model_meta = handlers_utils.validate_signature(
                    model=model,
                    model_meta=model_meta,
                    target_methods=target_methods,
                    sample_input_data=sample_input_data,
                    get_prediction_fn=get_prediction,
                )
            else:
                handlers_utils.validate_target_methods(model, target_methods)  # DEFAULT_TARGET_METHODS only
                if sample_input_data is not None:
                    model_meta = handlers_utils.validate_signature(
                        model=model,
                        model_meta=model_meta,
                        target_methods=target_methods,
                        sample_input_data=sample_input_data,
                        get_prediction_fn=get_prediction,
                    )

        # save model
        model_blob_path = os.path.join(model_blobs_dir_path, name)
        os.makedirs(model_blob_path, exist_ok=True)
        model.save(os.path.join(model_blob_path, cls.MODELE_BLOB_FILE_OR_DIR))

        # save model metadata
        base_meta = model_blob_meta.ModelBlobMeta(
            name=name,
            model_type=cls.HANDLER_TYPE,
            handler_version=cls.HANDLER_VERSION,
            path=cls.MODELE_BLOB_FILE_OR_DIR,
        )
        model_meta.models[name] = base_meta
        model_meta.min_snowpark_ml_version = cls._MIN_SNOWPARK_ML_VERSION

        model_meta.env.include_if_absent(
            [
                model_env.ModelDependency(requirement="sentence-transformers", pip_name="sentence-transformers"),
            ],
            check_local_version=True,
        )

    @classmethod
    def load_model(
        cls,
        name: str,
        model_meta: model_meta_api.ModelMetadata,
        model_blobs_dir_path: str,
        **kwargs: Unpack[model_types.ModelLoadOption],  # use_gpu
    ) -> "sentence_transformers.SentenceTransformer":
        """Load the model stored under ``model_blobs_dir_path/name``.

        The blob location is read from ``model_meta.models[name].path``. A directory is
        loaded through the ``SentenceTransformer`` constructor; a plain file is loaded
        with ``cloudpickle``.

        Args:
            name: Name of the model, i.e. the key in ``model_meta.models``.
            model_meta: Metadata holding the blob entry for ``name``.
            model_blobs_dir_path: Directory that contains the model's blob directory.
            **kwargs: Load options (not used by the visible code).

        Returns:
            The loaded ``SentenceTransformer`` instance.
        """
        import sentence_transformers

        if snowpark_utils.is_in_stored_procedure():  # type: ignore[no-untyped-call]
            # We need to redirect the same folders to a writable location in the sandbox.
            os.environ["TRANSFORMERS_CACHE"] = "/tmp"

        model_blob_path = os.path.join(model_blobs_dir_path, name)
        model_blobs_metadata = model_meta.models
        model_blob_metadata = model_blobs_metadata[name]
        model_blob_filename = model_blob_metadata.path
        model_blob_file_or_dir_path = os.path.join(model_blob_path, model_blob_filename)

        if os.path.isdir(model_blob_file_or_dir_path):  # if the saved model is a directory
            model = sentence_transformers.SentenceTransformer(model_blob_file_or_dir_path)
        else:
            assert os.path.isfile(model_blob_file_or_dir_path)  # if the saved model is a file
            # NOTE(security): unpickling executes arbitrary code; only load blobs from a trusted source.
            with open(model_blob_file_or_dir_path, "rb") as f:
                model = cloudpickle.load(f)
            assert isinstance(model, sentence_transformers.SentenceTransformer)
        return model

    @classmethod
    def convert_as_custom_model(
        cls,
        raw_model: "sentence_transformers.SentenceTransformer",
        model_meta: model_meta_api.ModelMetadata,
        **kwargs: Unpack[model_types.ModelLoadOption],
    ) -> custom_model.CustomModel:
        """Wrap ``raw_model`` in a dynamically created ``CustomModel`` exposing ``encode``.

        Args:
            raw_model: The model to wrap.
            model_meta: Metadata whose ``signatures`` decide which methods are exposed
                and how the output columns are renamed.
            **kwargs: Load options (not used by the visible code).

        Returns:
            An instance of the generated ``CustomModel`` subclass.

        Raises:
            ValueError: If ``model_meta.signatures`` contains a method other than ``encode``.
        """
        import sentence_transformers

        from snowflake.ml.model import custom_model

        def _create_custom_model(
            raw_model: "sentence_transformers.SentenceTransformer",
            model_meta: model_meta_api.ModelMetadata,
        ) -> Type[custom_model.CustomModel]:
            def get_prediction(
                raw_model: "sentence_transformers.SentenceTransformer",
                signature: model_signature.ModelSignature,
                target_method: str,
            ) -> Callable[[custom_model.CustomModel, pd.DataFrame], pd.DataFrame]:
                @custom_model.inference_api
                def fn(self: custom_model.CustomModel, X: pd.DataFrame) -> pd.DataFrame:
                    predictions_df = _sentence_transformer_encode(raw_model, X)
                    return model_signature_utils.rename_pandas_df(predictions_df, signature.outputs)

                return fn

            type_method_dict = {}
            for target_method_name, sig in model_meta.signatures.items():
                if target_method_name == "encode":
                    type_method_dict[target_method_name] = get_prediction(raw_model, sig, target_method_name)
                else:
                    # The exception used to be constructed without `raise`, which silently
                    # dropped unsupported methods instead of reporting them.
                    raise ValueError(f"{target_method_name} is currently not supported.")

            _SentenceTransformer = type(
                "_SentenceTransformer",
                (custom_model.CustomModel,),
                type_method_dict,
            )
            return _SentenceTransformer

        assert isinstance(raw_model, sentence_transformers.SentenceTransformer)
        model = raw_model

        _SentenceTransformer = _create_custom_model(model, model_meta)
        sentence_transformers_SentenceTransformer_model = _SentenceTransformer(custom_model.ModelContext())
        predict_method = getattr(sentence_transformers_SentenceTransformer_model, "encode", None)
        assert callable(predict_method)
        return sentence_transformers_SentenceTransformer_model
201
+
202
+
203
def _sentence_transformer_encode(
    model: "sentence_transformers.SentenceTransformer", X: model_types.SupportedLocalDataType
) -> model_types.SupportedLocalDataType:
    """Encode the single input column of ``X`` with ``model.encode``.

    Args:
        model: Object providing a callable ``encode`` method.
        X: Input data; anything that is not already a ``pd.DataFrame`` is converted
            with ``model_signature._convert_local_data_to_df``. The resulting frame
            must have exactly one column.

    Returns:
        A one-column ``pd.DataFrame`` (column label ``0``) holding, per input row,
        the ``tolist()`` form of the corresponding ``encode`` output.
    """
    if not isinstance(X, pd.DataFrame):
        X = model_signature._convert_local_data_to_df(X)

    assert X.shape[1] == 1, "SentenceTransformer can only accept 1 input column when converted to pd.DataFrame"
    X_list = X.iloc[:, 0].tolist()

    assert callable(getattr(model, "encode", None))
    # All rows are sent as a single batch. Clamp to 1 so an empty frame does not
    # pass batch_size=0 (presumably unusable as a batching stride - confirm against
    # the sentence-transformers `encode` implementation).
    batch_size = max(X.shape[0], 1)
    return pd.DataFrame({0: model.encode(X_list, batch_size=batch_size).tolist()})
@@ -72,7 +72,7 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
72
72
  model: Union["sklearn.base.BaseEstimator", "sklearn.pipeline.Pipeline"],
73
73
  model_meta: model_meta_api.ModelMetadata,
74
74
  model_blobs_dir_path: str,
75
- sample_input: Optional[model_types.SupportedDataType] = None,
75
+ sample_input_data: Optional[model_types.SupportedDataType] = None,
76
76
  is_sub_model: Optional[bool] = False,
77
77
  **kwargs: Unpack[model_types.SKLModelSaveOptions],
78
78
  ) -> None:
@@ -89,21 +89,21 @@ class SKLModelHandler(_base.BaseModelHandler[Union["sklearn.base.BaseEstimator",
89
89
  )
90
90
 
91
91
  def get_prediction(
92
- target_method_name: str, sample_input: model_types.SupportedLocalDataType
92
+ target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
93
93
  ) -> model_types.SupportedLocalDataType:
94
- if not isinstance(sample_input, (pd.DataFrame, np.ndarray)):
95
- sample_input = model_signature._convert_local_data_to_df(sample_input)
94
+ if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
95
+ sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
96
96
 
97
97
  target_method = getattr(model, target_method_name, None)
98
98
  assert callable(target_method)
99
- predictions_df = target_method(sample_input)
99
+ predictions_df = target_method(sample_input_data)
100
100
  return predictions_df
101
101
 
102
102
  model_meta = handlers_utils.validate_signature(
103
103
  model=model,
104
104
  model_meta=model_meta,
105
105
  target_methods=target_methods,
106
- sample_input=sample_input,
106
+ sample_input_data=sample_input_data,
107
107
  get_prediction_fn=get_prediction,
108
108
  )
109
109
 
@@ -69,7 +69,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
69
69
  model: "BaseEstimator",
70
70
  model_meta: model_meta_api.ModelMetadata,
71
71
  model_blobs_dir_path: str,
72
- sample_input: Optional[model_types.SupportedDataType] = None,
72
+ sample_input_data: Optional[model_types.SupportedDataType] = None,
73
73
  is_sub_model: Optional[bool] = False,
74
74
  **kwargs: Unpack[model_types.SNOWModelSaveOptions],
75
75
  ) -> None:
@@ -79,7 +79,7 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
79
79
  # Pipeline is inherited from BaseEstimator, so no need to add one more check
80
80
 
81
81
  if not is_sub_model:
82
- if sample_input is not None or model_meta.signatures:
82
+ if sample_input_data is not None or model_meta.signatures:
83
83
  warnings.warn(
84
84
  "Inferring model signature from sample input or providing model signature for Snowpark ML "
85
85
  + "Modeling model is not required. Model signature will automatically be inferred during fitting. ",
@@ -87,7 +87,19 @@ class SnowMLModelHandler(_base.BaseModelHandler["BaseEstimator"]):
87
87
  stacklevel=2,
88
88
  )
89
89
  assert hasattr(model, "model_signatures"), "Model does not have model signatures as expected."
90
- model_meta.signatures = getattr(model, "model_signatures", {})
90
+ model_signature_dict = getattr(model, "model_signatures", {})
91
+ target_methods = kwargs.pop("target_methods", None)
92
+ if not target_methods:
93
+ model_meta.signatures = model_signature_dict
94
+ else:
95
+ temp_model_signature_dict = {}
96
+ for method_name in target_methods:
97
+ method_model_signature = model_signature_dict.get(method_name, None)
98
+ if method_model_signature is not None:
99
+ temp_model_signature_dict[method_name] = method_model_signature
100
+ else:
101
+ raise ValueError(f"Target method {method_name} does not exist in the model.")
102
+ model_meta.signatures = temp_model_signature_dict
91
103
 
92
104
  model_blob_path = os.path.join(model_blobs_dir_path, name)
93
105
  os.makedirs(model_blob_path, exist_ok=True)
@@ -64,7 +64,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
64
64
  model: "tensorflow.Module",
65
65
  model_meta: model_meta_api.ModelMetadata,
66
66
  model_blobs_dir_path: str,
67
- sample_input: Optional[model_types.SupportedDataType] = None,
67
+ sample_input_data: Optional[model_types.SupportedDataType] = None,
68
68
  is_sub_model: Optional[bool] = False,
69
69
  **kwargs: Unpack[model_types.TensorflowSaveOptions],
70
70
  ) -> None:
@@ -85,18 +85,18 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
85
85
  )
86
86
 
87
87
  def get_prediction(
88
- target_method_name: str, sample_input: "model_types.SupportedLocalDataType"
88
+ target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
89
89
  ) -> model_types.SupportedLocalDataType:
90
- if not tensorflow_handler.SeqOfTensorflowTensorHandler.can_handle(sample_input):
91
- sample_input = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(
92
- model_signature._convert_local_data_to_df(sample_input)
90
+ if not tensorflow_handler.SeqOfTensorflowTensorHandler.can_handle(sample_input_data):
91
+ sample_input_data = tensorflow_handler.SeqOfTensorflowTensorHandler.convert_from_df(
92
+ model_signature._convert_local_data_to_df(sample_input_data)
93
93
  )
94
94
 
95
95
  target_method = getattr(model, target_method_name, None)
96
96
  assert callable(target_method)
97
- for tensor in sample_input:
97
+ for tensor in sample_input_data:
98
98
  tensorflow.stop_gradient(tensor)
99
- predictions_df = target_method(*sample_input)
99
+ predictions_df = target_method(*sample_input_data)
100
100
 
101
101
  if isinstance(predictions_df, (tensorflow.Tensor, tensorflow.Variable, np.ndarray)):
102
102
  predictions_df = [predictions_df]
@@ -107,7 +107,7 @@ class TensorFlowHandler(_base.BaseModelHandler["tensorflow.Module"]):
107
107
  model=model,
108
108
  model_meta=model_meta,
109
109
  target_methods=target_methods,
110
- sample_input=sample_input,
110
+ sample_input_data=sample_input_data,
111
111
  get_prediction_fn=get_prediction,
112
112
  )
113
113
 
@@ -62,7 +62,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
62
62
  model: "torch.jit.ScriptModule", # type:ignore[name-defined]
63
63
  model_meta: model_meta_api.ModelMetadata,
64
64
  model_blobs_dir_path: str,
65
- sample_input: Optional[model_types.SupportedDataType] = None,
65
+ sample_input_data: Optional[model_types.SupportedDataType] = None,
66
66
  is_sub_model: Optional[bool] = False,
67
67
  **kwargs: Unpack[model_types.TorchScriptSaveOptions],
68
68
  ) -> None:
@@ -78,18 +78,18 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
78
78
  )
79
79
 
80
80
  def get_prediction(
81
- target_method_name: str, sample_input: "model_types.SupportedLocalDataType"
81
+ target_method_name: str, sample_input_data: "model_types.SupportedLocalDataType"
82
82
  ) -> model_types.SupportedLocalDataType:
83
- if not pytorch_handler.SeqOfPyTorchTensorHandler.can_handle(sample_input):
84
- sample_input = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(
85
- model_signature._convert_local_data_to_df(sample_input)
83
+ if not pytorch_handler.SeqOfPyTorchTensorHandler.can_handle(sample_input_data):
84
+ sample_input_data = pytorch_handler.SeqOfPyTorchTensorHandler.convert_from_df(
85
+ model_signature._convert_local_data_to_df(sample_input_data)
86
86
  )
87
87
 
88
88
  model.eval()
89
89
  target_method = getattr(model, target_method_name, None)
90
90
  assert callable(target_method)
91
91
  with torch.no_grad():
92
- predictions_df = target_method(*sample_input)
92
+ predictions_df = target_method(*sample_input_data)
93
93
 
94
94
  if isinstance(predictions_df, torch.Tensor):
95
95
  predictions_df = [predictions_df]
@@ -100,7 +100,7 @@ class TorchScriptHandler(_base.BaseModelHandler["torch.jit.ScriptModule"]): # t
100
100
  model=model,
101
101
  model_meta=model_meta,
102
102
  target_methods=target_methods,
103
- sample_input=sample_input,
103
+ sample_input_data=sample_input_data,
104
104
  get_prediction_fn=get_prediction,
105
105
  )
106
106
 
@@ -45,7 +45,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
45
45
  _HANDLER_MIGRATOR_PLANS: Dict[str, Type[base_migrator.BaseModelHandlerMigrator]] = {}
46
46
 
47
47
  MODELE_BLOB_FILE_OR_DIR = "model.ubj"
48
- DEFAULT_TARGET_METHODS = ["apply", "predict", "predict_proba"]
48
+ DEFAULT_TARGET_METHODS = ["predict", "predict_proba"]
49
49
 
50
50
  @classmethod
51
51
  def can_handle(
@@ -76,7 +76,7 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
76
76
  model: Union["xgboost.Booster", "xgboost.XGBModel"],
77
77
  model_meta: model_meta_api.ModelMetadata,
78
78
  model_blobs_dir_path: str,
79
- sample_input: Optional[model_types.SupportedDataType] = None,
79
+ sample_input_data: Optional[model_types.SupportedDataType] = None,
80
80
  is_sub_model: Optional[bool] = False,
81
81
  **kwargs: Unpack[model_types.XGBModelSaveOptions],
82
82
  ) -> None:
@@ -92,24 +92,24 @@ class XGBModelHandler(_base.BaseModelHandler[Union["xgboost.Booster", "xgboost.X
92
92
  )
93
93
 
94
94
  def get_prediction(
95
- target_method_name: str, sample_input: model_types.SupportedLocalDataType
95
+ target_method_name: str, sample_input_data: model_types.SupportedLocalDataType
96
96
  ) -> model_types.SupportedLocalDataType:
97
- if not isinstance(sample_input, (pd.DataFrame, np.ndarray)):
98
- sample_input = model_signature._convert_local_data_to_df(sample_input)
97
+ if not isinstance(sample_input_data, (pd.DataFrame, np.ndarray)):
98
+ sample_input_data = model_signature._convert_local_data_to_df(sample_input_data)
99
99
 
100
100
  if isinstance(model, xgboost.Booster):
101
- sample_input = xgboost.DMatrix(sample_input)
101
+ sample_input_data = xgboost.DMatrix(sample_input_data)
102
102
 
103
103
  target_method = getattr(model, target_method_name, None)
104
104
  assert callable(target_method)
105
- predictions_df = target_method(sample_input)
105
+ predictions_df = target_method(sample_input_data)
106
106
  return predictions_df
107
107
 
108
108
  model_meta = handlers_utils.validate_signature(
109
109
  model=model,
110
110
  model_meta=model_meta,
111
111
  target_methods=target_methods,
112
- sample_input=sample_input,
112
+ sample_input_data=sample_input_data,
113
113
  get_prediction_fn=get_prediction,
114
114
  )
115
115
 
@@ -6,6 +6,6 @@ REQUIREMENTS = [
6
6
  "packaging>=20.9,<24",
7
7
  "pandas>=1.0.0,<2",
8
8
  "pyyaml>=6.0,<7",
9
- "snowflake-snowpark-python>=1.8.0,<2,!=1.12.0",
9
+ "snowflake-snowpark-python>=1.11.1,<2,!=1.12.0",
10
10
  "typing-extensions>=4.1.0,<5"
11
11
  ]
@@ -40,7 +40,7 @@ class ModelPackager:
40
40
  name: str,
41
41
  model: model_types.SupportedModelType,
42
42
  signatures: Optional[Dict[str, model_signature.ModelSignature]] = None,
43
- sample_input: Optional[model_types.SupportedDataType] = None,
43
+ sample_input_data: Optional[model_types.SupportedDataType] = None,
44
44
  metadata: Optional[Dict[str, str]] = None,
45
45
  conda_dependencies: Optional[List[str]] = None,
46
46
  pip_requirements: Optional[List[str]] = None,
@@ -49,18 +49,20 @@ class ModelPackager:
49
49
  code_paths: Optional[List[str]] = None,
50
50
  options: Optional[model_types.ModelSaveOption] = None,
51
51
  ) -> None:
52
- if (signatures is None) and (sample_input is None) and not model_handler.is_auto_signature_model(model):
52
+ if (signatures is None) and (sample_input_data is None) and not model_handler.is_auto_signature_model(model):
53
53
  raise snowml_exceptions.SnowflakeMLException(
54
54
  error_code=error_codes.INVALID_ARGUMENT,
55
55
  original_exception=ValueError(
56
- "Signatures and sample_input both cannot be None at the same time for this kind of model."
56
+ "Signatures and sample_input_data both cannot be None at the same time for this kind of model."
57
57
  ),
58
58
  )
59
59
 
60
- if (signatures is not None) and (sample_input is not None):
60
+ if (signatures is not None) and (sample_input_data is not None):
61
61
  raise snowml_exceptions.SnowflakeMLException(
62
62
  error_code=error_codes.INVALID_ARGUMENT,
63
- original_exception=ValueError("Signatures and sample_input both cannot be specified at the same time."),
63
+ original_exception=ValueError(
64
+ "Signatures and sample_input_data both cannot be specified at the same time."
65
+ ),
64
66
  )
65
67
 
66
68
  if not options:
@@ -93,7 +95,7 @@ class ModelPackager:
93
95
  model=model,
94
96
  model_meta=meta,
95
97
  model_blobs_dir_path=model_blobs_path,
96
- sample_input=sample_input,
98
+ sample_input_data=sample_input_data,
97
99
  is_sub_model=False,
98
100
  **options,
99
101
  )
@@ -149,7 +149,9 @@ class CustomModel:
149
149
  context: A ModelContext object showing sub-models and artifacts related to this model.
150
150
  """
151
151
 
152
- def __init__(self, context: ModelContext) -> None:
152
+ def __init__(self, context: Optional[ModelContext] = None) -> None:
153
+ if context is None:
154
+ context = ModelContext()
153
155
  self.context = context
154
156
  for method in self._get_infer_methods():
155
157
  _validate_predict_function(method)
@@ -22,6 +22,7 @@ if TYPE_CHECKING:
22
22
  import mlflow
23
23
  import numpy as np
24
24
  import pandas as pd
25
+ import sentence_transformers
25
26
  import sklearn.base
26
27
  import sklearn.pipeline
27
28
  import tensorflow
@@ -32,6 +33,7 @@ if TYPE_CHECKING:
32
33
  import snowflake.ml.model.custom_model
33
34
  import snowflake.ml.model.models.huggingface_pipeline
34
35
  import snowflake.ml.model.models.llm
36
+ import snowflake.ml.model.models.sentence_transformers
35
37
  import snowflake.snowpark
36
38
  from snowflake.ml.modeling.framework import base # noqa: F401
37
39
 
@@ -81,7 +83,9 @@ SupportedNoSignatureRequirementsModelType = Union[
81
83
  "base.BaseEstimator",
82
84
  "mlflow.pyfunc.PyFuncModel",
83
85
  "transformers.Pipeline",
86
+ "sentence_transformers.SentenceTransformer",
84
87
  "snowflake.ml.model.models.huggingface_pipeline.HuggingFacePipelineModel",
88
+ "snowflake.ml.model.models.sentence_transformers.SentenceTransformer",
85
89
  "snowflake.ml.model.models.llm.LLM",
86
90
  ]
87
91
 
@@ -106,6 +110,7 @@ Here is all acceptable types of Snowflake native model packaging and its handler
106
110
  | mlflow.pyfunc.PyFuncModel | mlflow.py | _MLFlowHandler |
107
111
  | transformers.Pipeline | huggingface_pipeline.py | _HuggingFacePipelineHandler |
108
112
  | huggingface_pipeline.HuggingFacePipelineModel | huggingface_pipeline.py | _HuggingFacePipelineHandler |
113
+ | sentence_transformers.SentenceTransformer | sentence_transformers.py | _SentenceTransformerHandler |
109
114
  """
110
115
 
111
116
  SupportedModelHandlerType = Literal[
@@ -113,6 +118,7 @@ SupportedModelHandlerType = Literal[
113
118
  "huggingface_pipeline",
114
119
  "mlflow",
115
120
  "pytorch",
121
+ "sentence_transformers",
116
122
  "sklearn",
117
123
  "snowml",
118
124
  "tensorflow",
@@ -215,6 +221,7 @@ class BaseModelSaveOption(TypedDict):
215
221
  embed_local_ml_library: NotRequired[bool]
216
222
  relax_version: NotRequired[bool]
217
223
  _legacy_save: NotRequired[bool]
224
+ function_type: NotRequired[Literal["FUNCTION", "TABLE_FUNCTION"]]
218
225
  method_options: NotRequired[Dict[str, ModelMethodSaveOptions]]
219
226
 
220
227
 
@@ -261,6 +268,11 @@ class HuggingFaceSaveOptions(BaseModelSaveOption):
261
268
  cuda_version: NotRequired[str]
262
269
 
263
270
 
271
+ class SentenceTransformersSaveOptions(BaseModelSaveOption):
272
+ target_methods: NotRequired[Sequence[str]]
273
+ cuda_version: NotRequired[str]
274
+
275
+
264
276
  class LLMSaveOptions(BaseModelSaveOption):
265
277
  cuda_version: NotRequired[str]
266
278
 
@@ -276,6 +288,7 @@ ModelSaveOption = Union[
276
288
  TensorflowSaveOptions,
277
289
  MLFlowSaveOptions,
278
290
  HuggingFaceSaveOptions,
291
+ SentenceTransformersSaveOptions,
279
292
  LLMSaveOptions,
280
293
  ]
281
294
 
@@ -1,7 +1,9 @@
1
1
  import inspect
2
- from typing import Any, Callable, Dict, Set, Tuple
2
+ import numbers
3
+ from typing import Any, Callable, Dict, List, Set, Tuple
3
4
 
4
5
  import numpy as np
6
+ from numpy import typing as npt
5
7
  from typing_extensions import TypeGuard
6
8
 
7
9
  from snowflake.ml._internal.exceptions import error_codes, exceptions
@@ -153,3 +155,61 @@ def get_module_name(model: object) -> str:
153
155
  original_exception=ValueError(f"Unable to infer the source module of the given object {model}."),
154
156
  )
155
157
  return module.__name__
158
+
159
+
160
def handle_inference_result(
    inference_res: Any, output_cols: List[str], inference_method: str, within_udf: bool = False
) -> Tuple[npt.NDArray[Any], List[str]]:
    """Normalize a raw inference result into an array plus matching output column names.

    Args:
        inference_res: Value returned by the inference method. Recognized forms are a non-empty
            list whose first element is an ndarray, a non-empty tuple whose first element is an
            ndarray, a single number, or any object exposing ``.shape`` (used as-is).
        output_cols: Expected output column names.
        inference_method: Name of the method that produced ``inference_res``. Results of
            ``"kneighbors"`` are not flattened when they are 3-dimensional.
        within_udf: When True, a single expected column is kept as-is even if the result has more
            columns; when False, that single name is expanded into one name per result column.

    Returns:
        A tuple of the normalized array and the (possibly expanded) list of output column names.

    Raises:
        TypeError: If the result's column count differs from ``len(output_cols)`` and
            ``output_cols`` does not have exactly one element.
    """
    if isinstance(inference_res, list) and len(inference_res) > 0 and isinstance(inference_res[0], np.ndarray):
        # In case of multioutput estimators, predict_proba, decision_function etc., functions return a list of
        # ndarrays. We need to concatenate them.

        # Promote 0-d/1-d parts to a single column so that they can be concatenated along axis 1; the
        # column naming below already counts such parts as exactly one column.
        parts = [np.reshape(part, (-1, 1)) if np.ndim(part) <= 1 else part for part in inference_res]

        # First compute output column names: one name per produced column, prefixed by the expected name.
        if len(output_cols) == len(parts):
            output_cols = [
                f"{col_name}_{i}" for col_name, part in zip(output_cols, parts) for i in range(np.shape(part)[1])
            ]

        # Concatenate np arrays
        transformed_numpy_array = np.concatenate(parts, axis=1)
    elif isinstance(inference_res, tuple) and len(inference_res) > 0 and isinstance(inference_res[0], np.ndarray):
        # In case of kneighbors, functions return a tuple of ndarrays.
        transformed_numpy_array = np.stack(inference_res, axis=1)
    elif isinstance(inference_res, numbers.Number):
        # In case of BernoulliRBM, functions return a float
        transformed_numpy_array = np.array([inference_res])
    else:
        transformed_numpy_array = inference_res

    if (len(transformed_numpy_array.shape) == 3) and inference_method != "kneighbors":
        # VotingClassifier will return results of shape (n_classifiers, n_samples, n_classes)
        # when voting = "soft" and flatten_transform = False. We can't handle unflatten transforms,
        # so we ignore flatten_transform flag and flatten the results.
        transformed_numpy_array = np.hstack(transformed_numpy_array)  # type: ignore[call-overload]

    if len(transformed_numpy_array.shape) == 1:
        transformed_numpy_array = np.reshape(transformed_numpy_array, (-1, 1))

    shape = transformed_numpy_array.shape
    if len(shape) > 1 and shape[1] != len(output_cols):
        # HeterogeneousEnsemble's transform method produces results with varying shapes
        # from (n_samples, n_estimators) to (n_samples, n_estimators * n_classes).
        # It is hard to predict the response shape without using fragile introspection logic.
        # So, to avoid that we are packing the results into a dataframe of shape (n_samples, 1) with
        # each element being a list.
        if len(output_cols) != 1:
            raise TypeError(
                "expected_output_cols must be same length as transformed array or should be of length 1. "
                f"Currently expected_output_cols shape is {len(output_cols)}, "
                f"transformed array shape is {shape}."
            )
        if not within_udf:
            # Expand the single expected name into one name per actual result column.
            output_cols = [f"{output_cols[0]}_{i}" for i in range(shape[1])]

    return transformed_numpy_array, output_cols