snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -627,12 +624,23 @@ class RandomForestRegressor(BaseTransformer):
627
624
  autogenerated=self._autogenerated,
628
625
  subproject=_SUBPROJECT,
629
626
  )
630
- output_result, fitted_estimator = model_trainer.train_fit_predict(
631
- drop_input_cols=self._drop_input_cols,
632
- expected_output_cols_list=(
633
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
634
- ),
627
+ expected_output_cols = (
628
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
635
629
  )
630
+ if isinstance(dataset, DataFrame):
631
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
632
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
633
+ )
634
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
635
+ drop_input_cols=self._drop_input_cols,
636
+ expected_output_cols_list=expected_output_cols,
637
+ example_output_pd_df=example_output_pd_df,
638
+ )
639
+ else:
640
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
641
+ drop_input_cols=self._drop_input_cols,
642
+ expected_output_cols_list=expected_output_cols,
643
+ )
636
644
  self._sklearn_object = fitted_estimator
637
645
  self._is_fitted = True
638
646
  return output_result
@@ -711,12 +719,41 @@ class RandomForestRegressor(BaseTransformer):
711
719
 
712
720
  return rv
713
721
 
714
- def _align_expected_output_names(
715
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
716
- ) -> List[str]:
722
+ def _align_expected_output(
723
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
724
+ ) -> Tuple[List[str], pd.DataFrame]:
725
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
726
+ and output dataframe with 1 line.
727
+ If the method is fit_predict, run 2 lines of data.
728
+ """
717
729
  # in case the inferred output column names dimension is different
718
730
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
719
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
731
+
732
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
733
+ # so change the minimum of number of rows to 2
734
+ num_examples = 2
735
+ statement_params = telemetry.get_function_usage_statement_params(
736
+ project=_PROJECT,
737
+ subproject=_SUBPROJECT,
738
+ function_name=telemetry.get_statement_params_full_func_name(
739
+ inspect.currentframe(), RandomForestRegressor.__class__.__name__
740
+ ),
741
+ api_calls=[Session.call],
742
+ custom_tags={"autogen": True} if self._autogenerated else None,
743
+ )
744
+ if output_cols_prefix == "fit_predict_":
745
+ if hasattr(self._sklearn_object, "n_clusters"):
746
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
747
+ num_examples = self._sklearn_object.n_clusters
748
+ elif hasattr(self._sklearn_object, "min_samples"):
749
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
750
+ num_examples = self._sklearn_object.min_samples
751
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
752
+ # LocalOutlierFactor expects n_neighbors <= n_samples
753
+ num_examples = self._sklearn_object.n_neighbors
754
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
755
+ else:
756
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
720
757
 
721
758
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
722
759
  # seen during the fit.
@@ -728,12 +765,14 @@ class RandomForestRegressor(BaseTransformer):
728
765
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
729
766
  if self.sample_weight_col:
730
767
  output_df_columns_set -= set(self.sample_weight_col)
768
+
731
769
  # if the dimension of inferred output column names is correct; use it
732
770
  if len(expected_output_cols_list) == len(output_df_columns_set):
733
- return expected_output_cols_list
771
+ return expected_output_cols_list, output_df_pd
734
772
  # otherwise, use the sklearn estimator's output
735
773
  else:
736
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
774
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
775
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
737
776
 
738
777
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
739
778
  @telemetry.send_api_usage_telemetry(
@@ -779,7 +818,7 @@ class RandomForestRegressor(BaseTransformer):
779
818
  drop_input_cols=self._drop_input_cols,
780
819
  expected_output_cols_type="float",
781
820
  )
782
- expected_output_cols = self._align_expected_output_names(
821
+ expected_output_cols, _ = self._align_expected_output(
783
822
  inference_method, dataset, expected_output_cols, output_cols_prefix
784
823
  )
785
824
 
@@ -845,7 +884,7 @@ class RandomForestRegressor(BaseTransformer):
845
884
  drop_input_cols=self._drop_input_cols,
846
885
  expected_output_cols_type="float",
847
886
  )
848
- expected_output_cols = self._align_expected_output_names(
887
+ expected_output_cols, _ = self._align_expected_output(
849
888
  inference_method, dataset, expected_output_cols, output_cols_prefix
850
889
  )
851
890
  elif isinstance(dataset, pd.DataFrame):
@@ -908,7 +947,7 @@ class RandomForestRegressor(BaseTransformer):
908
947
  drop_input_cols=self._drop_input_cols,
909
948
  expected_output_cols_type="float",
910
949
  )
911
- expected_output_cols = self._align_expected_output_names(
950
+ expected_output_cols, _ = self._align_expected_output(
912
951
  inference_method, dataset, expected_output_cols, output_cols_prefix
913
952
  )
914
953
 
@@ -973,7 +1012,7 @@ class RandomForestRegressor(BaseTransformer):
973
1012
  drop_input_cols = self._drop_input_cols,
974
1013
  expected_output_cols_type="float",
975
1014
  )
976
- expected_output_cols = self._align_expected_output_names(
1015
+ expected_output_cols, _ = self._align_expected_output(
977
1016
  inference_method, dataset, expected_output_cols, output_cols_prefix
978
1017
  )
979
1018
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -530,12 +527,23 @@ class StackingRegressor(BaseTransformer):
530
527
  autogenerated=self._autogenerated,
531
528
  subproject=_SUBPROJECT,
532
529
  )
533
- output_result, fitted_estimator = model_trainer.train_fit_predict(
534
- drop_input_cols=self._drop_input_cols,
535
- expected_output_cols_list=(
536
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
537
- ),
530
+ expected_output_cols = (
531
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
538
532
  )
533
+ if isinstance(dataset, DataFrame):
534
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
535
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
536
+ )
537
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
538
+ drop_input_cols=self._drop_input_cols,
539
+ expected_output_cols_list=expected_output_cols,
540
+ example_output_pd_df=example_output_pd_df,
541
+ )
542
+ else:
543
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
544
+ drop_input_cols=self._drop_input_cols,
545
+ expected_output_cols_list=expected_output_cols,
546
+ )
539
547
  self._sklearn_object = fitted_estimator
540
548
  self._is_fitted = True
541
549
  return output_result
@@ -616,12 +624,41 @@ class StackingRegressor(BaseTransformer):
616
624
 
617
625
  return rv
618
626
 
619
- def _align_expected_output_names(
620
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
621
- ) -> List[str]:
627
+ def _align_expected_output(
628
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
629
+ ) -> Tuple[List[str], pd.DataFrame]:
630
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
631
+ and output dataframe with 1 line.
632
+ If the method is fit_predict, run 2 lines of data.
633
+ """
622
634
  # in case the inferred output column names dimension is different
623
635
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
624
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
636
+
637
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
638
+ # so change the minimum of number of rows to 2
639
+ num_examples = 2
640
+ statement_params = telemetry.get_function_usage_statement_params(
641
+ project=_PROJECT,
642
+ subproject=_SUBPROJECT,
643
+ function_name=telemetry.get_statement_params_full_func_name(
644
+ inspect.currentframe(), StackingRegressor.__class__.__name__
645
+ ),
646
+ api_calls=[Session.call],
647
+ custom_tags={"autogen": True} if self._autogenerated else None,
648
+ )
649
+ if output_cols_prefix == "fit_predict_":
650
+ if hasattr(self._sklearn_object, "n_clusters"):
651
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
652
+ num_examples = self._sklearn_object.n_clusters
653
+ elif hasattr(self._sklearn_object, "min_samples"):
654
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
655
+ num_examples = self._sklearn_object.min_samples
656
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
657
+ # LocalOutlierFactor expects n_neighbors <= n_samples
658
+ num_examples = self._sklearn_object.n_neighbors
659
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
660
+ else:
661
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
625
662
 
626
663
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
627
664
  # seen during the fit.
@@ -633,12 +670,14 @@ class StackingRegressor(BaseTransformer):
633
670
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
634
671
  if self.sample_weight_col:
635
672
  output_df_columns_set -= set(self.sample_weight_col)
673
+
636
674
  # if the dimension of inferred output column names is correct; use it
637
675
  if len(expected_output_cols_list) == len(output_df_columns_set):
638
- return expected_output_cols_list
676
+ return expected_output_cols_list, output_df_pd
639
677
  # otherwise, use the sklearn estimator's output
640
678
  else:
641
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
679
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
680
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
642
681
 
643
682
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
644
683
  @telemetry.send_api_usage_telemetry(
@@ -684,7 +723,7 @@ class StackingRegressor(BaseTransformer):
684
723
  drop_input_cols=self._drop_input_cols,
685
724
  expected_output_cols_type="float",
686
725
  )
687
- expected_output_cols = self._align_expected_output_names(
726
+ expected_output_cols, _ = self._align_expected_output(
688
727
  inference_method, dataset, expected_output_cols, output_cols_prefix
689
728
  )
690
729
 
@@ -750,7 +789,7 @@ class StackingRegressor(BaseTransformer):
750
789
  drop_input_cols=self._drop_input_cols,
751
790
  expected_output_cols_type="float",
752
791
  )
753
- expected_output_cols = self._align_expected_output_names(
792
+ expected_output_cols, _ = self._align_expected_output(
754
793
  inference_method, dataset, expected_output_cols, output_cols_prefix
755
794
  )
756
795
  elif isinstance(dataset, pd.DataFrame):
@@ -813,7 +852,7 @@ class StackingRegressor(BaseTransformer):
813
852
  drop_input_cols=self._drop_input_cols,
814
853
  expected_output_cols_type="float",
815
854
  )
816
- expected_output_cols = self._align_expected_output_names(
855
+ expected_output_cols, _ = self._align_expected_output(
817
856
  inference_method, dataset, expected_output_cols, output_cols_prefix
818
857
  )
819
858
 
@@ -878,7 +917,7 @@ class StackingRegressor(BaseTransformer):
878
917
  drop_input_cols = self._drop_input_cols,
879
918
  expected_output_cols_type="float",
880
919
  )
881
- expected_output_cols = self._align_expected_output_names(
920
+ expected_output_cols, _ = self._align_expected_output(
882
921
  inference_method, dataset, expected_output_cols, output_cols_prefix
883
922
  )
884
923
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -512,12 +509,23 @@ class VotingClassifier(BaseTransformer):
512
509
  autogenerated=self._autogenerated,
513
510
  subproject=_SUBPROJECT,
514
511
  )
515
- output_result, fitted_estimator = model_trainer.train_fit_predict(
516
- drop_input_cols=self._drop_input_cols,
517
- expected_output_cols_list=(
518
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
519
- ),
512
+ expected_output_cols = (
513
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
520
514
  )
515
+ if isinstance(dataset, DataFrame):
516
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
517
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
518
+ )
519
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
520
+ drop_input_cols=self._drop_input_cols,
521
+ expected_output_cols_list=expected_output_cols,
522
+ example_output_pd_df=example_output_pd_df,
523
+ )
524
+ else:
525
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
526
+ drop_input_cols=self._drop_input_cols,
527
+ expected_output_cols_list=expected_output_cols,
528
+ )
521
529
  self._sklearn_object = fitted_estimator
522
530
  self._is_fitted = True
523
531
  return output_result
@@ -598,12 +606,41 @@ class VotingClassifier(BaseTransformer):
598
606
 
599
607
  return rv
600
608
 
601
- def _align_expected_output_names(
602
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
603
- ) -> List[str]:
609
+ def _align_expected_output(
610
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
611
+ ) -> Tuple[List[str], pd.DataFrame]:
612
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
613
+ and output dataframe with 1 line.
614
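+ If the method is fit_predict, run 2 lines of data by default, or as many lines as the estimator's n_clusters, min_samples or n_neighbors when it defines one.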
615
+ """
604
616
  # in case the inferred output column names dimension is different
605
617
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
606
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
618
+
619
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
620
+ # so change the minimum of number of rows to 2
621
+ num_examples = 2
622
+ statement_params = telemetry.get_function_usage_statement_params(
623
+ project=_PROJECT,
624
+ subproject=_SUBPROJECT,
625
+ function_name=telemetry.get_statement_params_full_func_name(
626
+ inspect.currentframe(), VotingClassifier.__class__.__name__
627
+ ),
628
+ api_calls=[Session.call],
629
+ custom_tags={"autogen": True} if self._autogenerated else None,
630
+ )
631
+ if output_cols_prefix == "fit_predict_":
632
+ if hasattr(self._sklearn_object, "n_clusters"):
633
+ # cluster classes such as BisectingKMeansTest require the number of examples to be >= n_clusters
634
+ num_examples = self._sklearn_object.n_clusters
635
+ elif hasattr(self._sklearn_object, "min_samples"):
636
+ # OPTICS defaults min_samples to 5, which requires at least 5 lines of data
637
+ num_examples = self._sklearn_object.min_samples
638
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
639
+ # LocalOutlierFactor expects n_neighbors <= n_samples
640
+ num_examples = self._sklearn_object.n_neighbors
641
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
642
+ else:
643
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
607
644
 
608
645
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
609
646
  # seen during the fit.
@@ -615,12 +652,14 @@ class VotingClassifier(BaseTransformer):
615
652
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
616
653
  if self.sample_weight_col:
617
654
  output_df_columns_set -= set(self.sample_weight_col)
655
+
618
656
  # if the dimension of inferred output column names is correct; use it
619
657
  if len(expected_output_cols_list) == len(output_df_columns_set):
620
- return expected_output_cols_list
658
+ return expected_output_cols_list, output_df_pd
621
659
  # otherwise, use the sklearn estimator's output
622
660
  else:
623
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
661
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
662
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
624
663
 
625
664
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
626
665
  @telemetry.send_api_usage_telemetry(
@@ -668,7 +707,7 @@ class VotingClassifier(BaseTransformer):
668
707
  drop_input_cols=self._drop_input_cols,
669
708
  expected_output_cols_type="float",
670
709
  )
671
- expected_output_cols = self._align_expected_output_names(
710
+ expected_output_cols, _ = self._align_expected_output(
672
711
  inference_method, dataset, expected_output_cols, output_cols_prefix
673
712
  )
674
713
 
@@ -736,7 +775,7 @@ class VotingClassifier(BaseTransformer):
736
775
  drop_input_cols=self._drop_input_cols,
737
776
  expected_output_cols_type="float",
738
777
  )
739
- expected_output_cols = self._align_expected_output_names(
778
+ expected_output_cols, _ = self._align_expected_output(
740
779
  inference_method, dataset, expected_output_cols, output_cols_prefix
741
780
  )
742
781
  elif isinstance(dataset, pd.DataFrame):
@@ -799,7 +838,7 @@ class VotingClassifier(BaseTransformer):
799
838
  drop_input_cols=self._drop_input_cols,
800
839
  expected_output_cols_type="float",
801
840
  )
802
- expected_output_cols = self._align_expected_output_names(
841
+ expected_output_cols, _ = self._align_expected_output(
803
842
  inference_method, dataset, expected_output_cols, output_cols_prefix
804
843
  )
805
844
 
@@ -864,7 +903,7 @@ class VotingClassifier(BaseTransformer):
864
903
  drop_input_cols = self._drop_input_cols,
865
904
  expected_output_cols_type="float",
866
905
  )
867
- expected_output_cols = self._align_expected_output_names(
906
+ expected_output_cols, _ = self._align_expected_output(
868
907
  inference_method, dataset, expected_output_cols, output_cols_prefix
869
908
  )
870
909
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -494,12 +491,23 @@ class VotingRegressor(BaseTransformer):
494
491
  autogenerated=self._autogenerated,
495
492
  subproject=_SUBPROJECT,
496
493
  )
497
- output_result, fitted_estimator = model_trainer.train_fit_predict(
498
- drop_input_cols=self._drop_input_cols,
499
- expected_output_cols_list=(
500
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
501
- ),
494
+ expected_output_cols = (
495
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
502
496
  )
497
+ if isinstance(dataset, DataFrame):
498
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
499
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
500
+ )
501
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
502
+ drop_input_cols=self._drop_input_cols,
503
+ expected_output_cols_list=expected_output_cols,
504
+ example_output_pd_df=example_output_pd_df,
505
+ )
506
+ else:
507
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
508
+ drop_input_cols=self._drop_input_cols,
509
+ expected_output_cols_list=expected_output_cols,
510
+ )
503
511
  self._sklearn_object = fitted_estimator
504
512
  self._is_fitted = True
505
513
  return output_result
@@ -580,12 +588,41 @@ class VotingRegressor(BaseTransformer):
580
588
 
581
589
  return rv
582
590
 
583
- def _align_expected_output_names(
584
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
585
- ) -> List[str]:
591
+ def _align_expected_output(
592
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
593
+ ) -> Tuple[List[str], pd.DataFrame]:
594
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
595
+ and output dataframe with 1 line.
596
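+ If the method is fit_predict, run 2 lines of data by default, or as many lines as the estimator's n_clusters, min_samples or n_neighbors when it defines one.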
597
+ """
586
598
  # in case the inferred output column names dimension is different
587
599
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
588
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
600
+
601
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
602
+ # so change the minimum of number of rows to 2
603
+ num_examples = 2
604
+ statement_params = telemetry.get_function_usage_statement_params(
605
+ project=_PROJECT,
606
+ subproject=_SUBPROJECT,
607
+ function_name=telemetry.get_statement_params_full_func_name(
608
+ inspect.currentframe(), VotingRegressor.__class__.__name__
609
+ ),
610
+ api_calls=[Session.call],
611
+ custom_tags={"autogen": True} if self._autogenerated else None,
612
+ )
613
+ if output_cols_prefix == "fit_predict_":
614
+ if hasattr(self._sklearn_object, "n_clusters"):
615
+ # cluster classes such as BisectingKMeansTest require the number of examples to be >= n_clusters
616
+ num_examples = self._sklearn_object.n_clusters
617
+ elif hasattr(self._sklearn_object, "min_samples"):
618
+ # OPTICS defaults min_samples to 5, which requires at least 5 lines of data
619
+ num_examples = self._sklearn_object.min_samples
620
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
621
+ # LocalOutlierFactor expects n_neighbors <= n_samples
622
+ num_examples = self._sklearn_object.n_neighbors
623
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
624
+ else:
625
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
589
626
 
590
627
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
591
628
  # seen during the fit.
@@ -597,12 +634,14 @@ class VotingRegressor(BaseTransformer):
597
634
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
598
635
  if self.sample_weight_col:
599
636
  output_df_columns_set -= set(self.sample_weight_col)
637
+
600
638
  # if the dimension of inferred output column names is correct; use it
601
639
  if len(expected_output_cols_list) == len(output_df_columns_set):
602
- return expected_output_cols_list
640
+ return expected_output_cols_list, output_df_pd
603
641
  # otherwise, use the sklearn estimator's output
604
642
  else:
605
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
643
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
644
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
606
645
 
607
646
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
608
647
  @telemetry.send_api_usage_telemetry(
@@ -648,7 +687,7 @@ class VotingRegressor(BaseTransformer):
648
687
  drop_input_cols=self._drop_input_cols,
649
688
  expected_output_cols_type="float",
650
689
  )
651
- expected_output_cols = self._align_expected_output_names(
690
+ expected_output_cols, _ = self._align_expected_output(
652
691
  inference_method, dataset, expected_output_cols, output_cols_prefix
653
692
  )
654
693
 
@@ -714,7 +753,7 @@ class VotingRegressor(BaseTransformer):
714
753
  drop_input_cols=self._drop_input_cols,
715
754
  expected_output_cols_type="float",
716
755
  )
717
- expected_output_cols = self._align_expected_output_names(
756
+ expected_output_cols, _ = self._align_expected_output(
718
757
  inference_method, dataset, expected_output_cols, output_cols_prefix
719
758
  )
720
759
  elif isinstance(dataset, pd.DataFrame):
@@ -777,7 +816,7 @@ class VotingRegressor(BaseTransformer):
777
816
  drop_input_cols=self._drop_input_cols,
778
817
  expected_output_cols_type="float",
779
818
  )
780
- expected_output_cols = self._align_expected_output_names(
819
+ expected_output_cols, _ = self._align_expected_output(
781
820
  inference_method, dataset, expected_output_cols, output_cols_prefix
782
821
  )
783
822
 
@@ -842,7 +881,7 @@ class VotingRegressor(BaseTransformer):
842
881
  drop_input_cols = self._drop_input_cols,
843
882
  expected_output_cols_type="float",
844
883
  )
845
- expected_output_cols = self._align_expected_output_names(
884
+ expected_output_cols, _ = self._align_expected_output(
846
885
  inference_method, dataset, expected_output_cols, output_cols_prefix
847
886
  )
848
887