snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -25,12 +23,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
25
23
  from snowflake.ml._internal import telemetry
26
24
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
27
25
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
28
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
26
+ from snowflake.ml._internal.utils import identifier
29
27
  from snowflake.snowpark import DataFrame, Session
30
28
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
31
29
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
32
30
  from snowflake.ml.modeling._internal.transformer_protocols import (
33
- ModelTransformHandlers,
34
31
  BatchInferenceKwargsTypedDict,
35
32
  ScoreKwargsTypedDict
36
33
  )
@@ -479,12 +476,23 @@ class SelectKBest(BaseTransformer):
479
476
  autogenerated=self._autogenerated,
480
477
  subproject=_SUBPROJECT,
481
478
  )
482
- output_result, fitted_estimator = model_trainer.train_fit_predict(
483
- drop_input_cols=self._drop_input_cols,
484
- expected_output_cols_list=(
485
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
486
- ),
479
+ expected_output_cols = (
480
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
487
481
  )
482
+ if isinstance(dataset, DataFrame):
483
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
484
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
485
+ )
486
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
487
+ drop_input_cols=self._drop_input_cols,
488
+ expected_output_cols_list=expected_output_cols,
489
+ example_output_pd_df=example_output_pd_df,
490
+ )
491
+ else:
492
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
493
+ drop_input_cols=self._drop_input_cols,
494
+ expected_output_cols_list=expected_output_cols,
495
+ )
488
496
  self._sklearn_object = fitted_estimator
489
497
  self._is_fitted = True
490
498
  return output_result
@@ -565,12 +573,41 @@ class SelectKBest(BaseTransformer):
565
573
 
566
574
  return rv
567
575
 
568
- def _align_expected_output_names(
569
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
570
- ) -> List[str]:
576
+ def _align_expected_output(
577
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
578
+ ) -> Tuple[List[str], pd.DataFrame]:
579
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
580
+ and output dataframe with 1 line.
581
+ If the method is fit_predict, run 2 lines of data.
582
+ """
571
583
  # in case the inferred output column names dimension is different
572
584
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
573
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
585
+
586
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
587
+ # so change the minimum of number of rows to 2
588
+ num_examples = 2
589
+ statement_params = telemetry.get_function_usage_statement_params(
590
+ project=_PROJECT,
591
+ subproject=_SUBPROJECT,
592
+ function_name=telemetry.get_statement_params_full_func_name(
593
+ inspect.currentframe(), SelectKBest.__class__.__name__
594
+ ),
595
+ api_calls=[Session.call],
596
+ custom_tags={"autogen": True} if self._autogenerated else None,
597
+ )
598
+ if output_cols_prefix == "fit_predict_":
599
+ if hasattr(self._sklearn_object, "n_clusters"):
600
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
601
+ num_examples = self._sklearn_object.n_clusters
602
+ elif hasattr(self._sklearn_object, "min_samples"):
603
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
604
+ num_examples = self._sklearn_object.min_samples
605
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
606
+ # LocalOutlierFactor expects n_neighbors <= n_samples
607
+ num_examples = self._sklearn_object.n_neighbors
608
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
609
+ else:
610
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
574
611
 
575
612
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
576
613
  # seen during the fit.
@@ -582,12 +619,14 @@ class SelectKBest(BaseTransformer):
582
619
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
583
620
  if self.sample_weight_col:
584
621
  output_df_columns_set -= set(self.sample_weight_col)
622
+
585
623
  # if the dimension of inferred output column names is correct; use it
586
624
  if len(expected_output_cols_list) == len(output_df_columns_set):
587
- return expected_output_cols_list
625
+ return expected_output_cols_list, output_df_pd
588
626
  # otherwise, use the sklearn estimator's output
589
627
  else:
590
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
628
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
629
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
591
630
 
592
631
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
593
632
  @telemetry.send_api_usage_telemetry(
@@ -633,7 +672,7 @@ class SelectKBest(BaseTransformer):
633
672
  drop_input_cols=self._drop_input_cols,
634
673
  expected_output_cols_type="float",
635
674
  )
636
- expected_output_cols = self._align_expected_output_names(
675
+ expected_output_cols, _ = self._align_expected_output(
637
676
  inference_method, dataset, expected_output_cols, output_cols_prefix
638
677
  )
639
678
 
@@ -699,7 +738,7 @@ class SelectKBest(BaseTransformer):
699
738
  drop_input_cols=self._drop_input_cols,
700
739
  expected_output_cols_type="float",
701
740
  )
702
- expected_output_cols = self._align_expected_output_names(
741
+ expected_output_cols, _ = self._align_expected_output(
703
742
  inference_method, dataset, expected_output_cols, output_cols_prefix
704
743
  )
705
744
  elif isinstance(dataset, pd.DataFrame):
@@ -762,7 +801,7 @@ class SelectKBest(BaseTransformer):
762
801
  drop_input_cols=self._drop_input_cols,
763
802
  expected_output_cols_type="float",
764
803
  )
765
- expected_output_cols = self._align_expected_output_names(
804
+ expected_output_cols, _ = self._align_expected_output(
766
805
  inference_method, dataset, expected_output_cols, output_cols_prefix
767
806
  )
768
807
 
@@ -827,7 +866,7 @@ class SelectKBest(BaseTransformer):
827
866
  drop_input_cols = self._drop_input_cols,
828
867
  expected_output_cols_type="float",
829
868
  )
830
- expected_output_cols = self._align_expected_output_names(
869
+ expected_output_cols, _ = self._align_expected_output(
831
870
  inference_method, dataset, expected_output_cols, output_cols_prefix
832
871
  )
833
872
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -25,12 +23,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
25
23
  from snowflake.ml._internal import telemetry
26
24
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
27
25
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
28
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
26
+ from snowflake.ml._internal.utils import identifier
29
27
  from snowflake.snowpark import DataFrame, Session
30
28
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
31
29
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
32
30
  from snowflake.ml.modeling._internal.transformer_protocols import (
33
- ModelTransformHandlers,
34
31
  BatchInferenceKwargsTypedDict,
35
32
  ScoreKwargsTypedDict
36
33
  )
@@ -478,12 +475,23 @@ class SelectPercentile(BaseTransformer):
478
475
  autogenerated=self._autogenerated,
479
476
  subproject=_SUBPROJECT,
480
477
  )
481
- output_result, fitted_estimator = model_trainer.train_fit_predict(
482
- drop_input_cols=self._drop_input_cols,
483
- expected_output_cols_list=(
484
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
485
- ),
478
+ expected_output_cols = (
479
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
486
480
  )
481
+ if isinstance(dataset, DataFrame):
482
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
483
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
484
+ )
485
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
486
+ drop_input_cols=self._drop_input_cols,
487
+ expected_output_cols_list=expected_output_cols,
488
+ example_output_pd_df=example_output_pd_df,
489
+ )
490
+ else:
491
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
492
+ drop_input_cols=self._drop_input_cols,
493
+ expected_output_cols_list=expected_output_cols,
494
+ )
487
495
  self._sklearn_object = fitted_estimator
488
496
  self._is_fitted = True
489
497
  return output_result
@@ -564,12 +572,41 @@ class SelectPercentile(BaseTransformer):
564
572
 
565
573
  return rv
566
574
 
567
- def _align_expected_output_names(
568
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
569
- ) -> List[str]:
575
+ def _align_expected_output(
576
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
577
+ ) -> Tuple[List[str], pd.DataFrame]:
578
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
579
+ and output dataframe with 1 line.
580
+ If the method is fit_predict, run 2 lines of data.
581
+ """
570
582
  # in case the inferred output column names dimension is different
571
583
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
572
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
584
+
585
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
586
+ # so change the minimum of number of rows to 2
587
+ num_examples = 2
588
+ statement_params = telemetry.get_function_usage_statement_params(
589
+ project=_PROJECT,
590
+ subproject=_SUBPROJECT,
591
+ function_name=telemetry.get_statement_params_full_func_name(
592
+ inspect.currentframe(), SelectPercentile.__class__.__name__
593
+ ),
594
+ api_calls=[Session.call],
595
+ custom_tags={"autogen": True} if self._autogenerated else None,
596
+ )
597
+ if output_cols_prefix == "fit_predict_":
598
+ if hasattr(self._sklearn_object, "n_clusters"):
599
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
600
+ num_examples = self._sklearn_object.n_clusters
601
+ elif hasattr(self._sklearn_object, "min_samples"):
602
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
603
+ num_examples = self._sklearn_object.min_samples
604
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
605
+ # LocalOutlierFactor expects n_neighbors <= n_samples
606
+ num_examples = self._sklearn_object.n_neighbors
607
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
608
+ else:
609
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
573
610
 
574
611
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
575
612
  # seen during the fit.
@@ -581,12 +618,14 @@ class SelectPercentile(BaseTransformer):
581
618
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
582
619
  if self.sample_weight_col:
583
620
  output_df_columns_set -= set(self.sample_weight_col)
621
+
584
622
  # if the dimension of inferred output column names is correct; use it
585
623
  if len(expected_output_cols_list) == len(output_df_columns_set):
586
- return expected_output_cols_list
624
+ return expected_output_cols_list, output_df_pd
587
625
  # otherwise, use the sklearn estimator's output
588
626
  else:
589
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
627
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
628
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
590
629
 
591
630
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
592
631
  @telemetry.send_api_usage_telemetry(
@@ -632,7 +671,7 @@ class SelectPercentile(BaseTransformer):
632
671
  drop_input_cols=self._drop_input_cols,
633
672
  expected_output_cols_type="float",
634
673
  )
635
- expected_output_cols = self._align_expected_output_names(
674
+ expected_output_cols, _ = self._align_expected_output(
636
675
  inference_method, dataset, expected_output_cols, output_cols_prefix
637
676
  )
638
677
 
@@ -698,7 +737,7 @@ class SelectPercentile(BaseTransformer):
698
737
  drop_input_cols=self._drop_input_cols,
699
738
  expected_output_cols_type="float",
700
739
  )
701
- expected_output_cols = self._align_expected_output_names(
740
+ expected_output_cols, _ = self._align_expected_output(
702
741
  inference_method, dataset, expected_output_cols, output_cols_prefix
703
742
  )
704
743
  elif isinstance(dataset, pd.DataFrame):
@@ -761,7 +800,7 @@ class SelectPercentile(BaseTransformer):
761
800
  drop_input_cols=self._drop_input_cols,
762
801
  expected_output_cols_type="float",
763
802
  )
764
- expected_output_cols = self._align_expected_output_names(
803
+ expected_output_cols, _ = self._align_expected_output(
765
804
  inference_method, dataset, expected_output_cols, output_cols_prefix
766
805
  )
767
806
 
@@ -826,7 +865,7 @@ class SelectPercentile(BaseTransformer):
826
865
  drop_input_cols = self._drop_input_cols,
827
866
  expected_output_cols_type="float",
828
867
  )
829
- expected_output_cols = self._align_expected_output_names(
868
+ expected_output_cols, _ = self._align_expected_output(
830
869
  inference_method, dataset, expected_output_cols, output_cols_prefix
831
870
  )
832
871
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -538,12 +535,23 @@ class SequentialFeatureSelector(BaseTransformer):
538
535
  autogenerated=self._autogenerated,
539
536
  subproject=_SUBPROJECT,
540
537
  )
541
- output_result, fitted_estimator = model_trainer.train_fit_predict(
542
- drop_input_cols=self._drop_input_cols,
543
- expected_output_cols_list=(
544
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
545
- ),
538
+ expected_output_cols = (
539
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
546
540
  )
541
+ if isinstance(dataset, DataFrame):
542
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
543
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
544
+ )
545
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
546
+ drop_input_cols=self._drop_input_cols,
547
+ expected_output_cols_list=expected_output_cols,
548
+ example_output_pd_df=example_output_pd_df,
549
+ )
550
+ else:
551
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
552
+ drop_input_cols=self._drop_input_cols,
553
+ expected_output_cols_list=expected_output_cols,
554
+ )
547
555
  self._sklearn_object = fitted_estimator
548
556
  self._is_fitted = True
549
557
  return output_result
@@ -624,12 +632,41 @@ class SequentialFeatureSelector(BaseTransformer):
624
632
 
625
633
  return rv
626
634
 
627
- def _align_expected_output_names(
628
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
629
- ) -> List[str]:
635
+ def _align_expected_output(
636
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
637
+ ) -> Tuple[List[str], pd.DataFrame]:
638
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
639
+ and output dataframe with 1 line.
640
+ If the method is fit_predict, run 2 lines of data.
641
+ """
630
642
  # in case the inferred output column names dimension is different
631
643
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
632
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
644
+
645
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
646
+ # so change the minimum of number of rows to 2
647
+ num_examples = 2
648
+ statement_params = telemetry.get_function_usage_statement_params(
649
+ project=_PROJECT,
650
+ subproject=_SUBPROJECT,
651
+ function_name=telemetry.get_statement_params_full_func_name(
652
+ inspect.currentframe(), SequentialFeatureSelector.__class__.__name__
653
+ ),
654
+ api_calls=[Session.call],
655
+ custom_tags={"autogen": True} if self._autogenerated else None,
656
+ )
657
+ if output_cols_prefix == "fit_predict_":
658
+ if hasattr(self._sklearn_object, "n_clusters"):
659
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
660
+ num_examples = self._sklearn_object.n_clusters
661
+ elif hasattr(self._sklearn_object, "min_samples"):
662
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
663
+ num_examples = self._sklearn_object.min_samples
664
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
665
+ # LocalOutlierFactor expects n_neighbors <= n_samples
666
+ num_examples = self._sklearn_object.n_neighbors
667
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
668
+ else:
669
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
633
670
 
634
671
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
635
672
  # seen during the fit.
@@ -641,12 +678,14 @@ class SequentialFeatureSelector(BaseTransformer):
641
678
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
642
679
  if self.sample_weight_col:
643
680
  output_df_columns_set -= set(self.sample_weight_col)
681
+
644
682
  # if the dimension of inferred output column names is correct; use it
645
683
  if len(expected_output_cols_list) == len(output_df_columns_set):
646
- return expected_output_cols_list
684
+ return expected_output_cols_list, output_df_pd
647
685
  # otherwise, use the sklearn estimator's output
648
686
  else:
649
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
687
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
688
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
650
689
 
651
690
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
652
691
  @telemetry.send_api_usage_telemetry(
@@ -692,7 +731,7 @@ class SequentialFeatureSelector(BaseTransformer):
692
731
  drop_input_cols=self._drop_input_cols,
693
732
  expected_output_cols_type="float",
694
733
  )
695
- expected_output_cols = self._align_expected_output_names(
734
+ expected_output_cols, _ = self._align_expected_output(
696
735
  inference_method, dataset, expected_output_cols, output_cols_prefix
697
736
  )
698
737
 
@@ -758,7 +797,7 @@ class SequentialFeatureSelector(BaseTransformer):
758
797
  drop_input_cols=self._drop_input_cols,
759
798
  expected_output_cols_type="float",
760
799
  )
761
- expected_output_cols = self._align_expected_output_names(
800
+ expected_output_cols, _ = self._align_expected_output(
762
801
  inference_method, dataset, expected_output_cols, output_cols_prefix
763
802
  )
764
803
  elif isinstance(dataset, pd.DataFrame):
@@ -821,7 +860,7 @@ class SequentialFeatureSelector(BaseTransformer):
821
860
  drop_input_cols=self._drop_input_cols,
822
861
  expected_output_cols_type="float",
823
862
  )
824
- expected_output_cols = self._align_expected_output_names(
863
+ expected_output_cols, _ = self._align_expected_output(
825
864
  inference_method, dataset, expected_output_cols, output_cols_prefix
826
865
  )
827
866
 
@@ -886,7 +925,7 @@ class SequentialFeatureSelector(BaseTransformer):
886
925
  drop_input_cols = self._drop_input_cols,
887
926
  expected_output_cols_type="float",
888
927
  )
889
- expected_output_cols = self._align_expected_output_names(
928
+ expected_output_cols, _ = self._align_expected_output(
890
929
  inference_method, dataset, expected_output_cols, output_cols_prefix
891
930
  )
892
931
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -471,12 +468,23 @@ class VarianceThreshold(BaseTransformer):
471
468
  autogenerated=self._autogenerated,
472
469
  subproject=_SUBPROJECT,
473
470
  )
474
- output_result, fitted_estimator = model_trainer.train_fit_predict(
475
- drop_input_cols=self._drop_input_cols,
476
- expected_output_cols_list=(
477
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
478
- ),
471
+ expected_output_cols = (
472
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
479
473
  )
474
+ if isinstance(dataset, DataFrame):
475
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
476
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
477
+ )
478
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
479
+ drop_input_cols=self._drop_input_cols,
480
+ expected_output_cols_list=expected_output_cols,
481
+ example_output_pd_df=example_output_pd_df,
482
+ )
483
+ else:
484
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
485
+ drop_input_cols=self._drop_input_cols,
486
+ expected_output_cols_list=expected_output_cols,
487
+ )
480
488
  self._sklearn_object = fitted_estimator
481
489
  self._is_fitted = True
482
490
  return output_result
@@ -557,12 +565,41 @@ class VarianceThreshold(BaseTransformer):
557
565
 
558
566
  return rv
559
567
 
560
- def _align_expected_output_names(
561
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
562
- ) -> List[str]:
568
+ def _align_expected_output(
569
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
570
+ ) -> Tuple[List[str], pd.DataFrame]:
571
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
572
+ and output dataframe with 1 line.
573
+ If the method is fit_predict, run 2 lines of data.
574
+ """
563
575
  # in case the inferred output column names dimension is different
564
576
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
565
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
577
+
578
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
579
+ # so change the minimum of number of rows to 2
580
+ num_examples = 2
581
+ statement_params = telemetry.get_function_usage_statement_params(
582
+ project=_PROJECT,
583
+ subproject=_SUBPROJECT,
584
+ function_name=telemetry.get_statement_params_full_func_name(
585
+ inspect.currentframe(), VarianceThreshold.__class__.__name__
586
+ ),
587
+ api_calls=[Session.call],
588
+ custom_tags={"autogen": True} if self._autogenerated else None,
589
+ )
590
+ if output_cols_prefix == "fit_predict_":
591
+ if hasattr(self._sklearn_object, "n_clusters"):
592
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
593
+ num_examples = self._sklearn_object.n_clusters
594
+ elif hasattr(self._sklearn_object, "min_samples"):
595
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
596
+ num_examples = self._sklearn_object.min_samples
597
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
598
+ # LocalOutlierFactor expects n_neighbors <= n_samples
599
+ num_examples = self._sklearn_object.n_neighbors
600
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
601
+ else:
602
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
566
603
 
567
604
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
568
605
  # seen during the fit.
@@ -574,12 +611,14 @@ class VarianceThreshold(BaseTransformer):
574
611
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
575
612
  if self.sample_weight_col:
576
613
  output_df_columns_set -= set(self.sample_weight_col)
614
+
577
615
  # if the dimension of inferred output column names is correct; use it
578
616
  if len(expected_output_cols_list) == len(output_df_columns_set):
579
- return expected_output_cols_list
617
+ return expected_output_cols_list, output_df_pd
580
618
  # otherwise, use the sklearn estimator's output
581
619
  else:
582
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
620
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
621
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
583
622
 
584
623
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
585
624
  @telemetry.send_api_usage_telemetry(
@@ -625,7 +664,7 @@ class VarianceThreshold(BaseTransformer):
625
664
  drop_input_cols=self._drop_input_cols,
626
665
  expected_output_cols_type="float",
627
666
  )
628
- expected_output_cols = self._align_expected_output_names(
667
+ expected_output_cols, _ = self._align_expected_output(
629
668
  inference_method, dataset, expected_output_cols, output_cols_prefix
630
669
  )
631
670
 
@@ -691,7 +730,7 @@ class VarianceThreshold(BaseTransformer):
691
730
  drop_input_cols=self._drop_input_cols,
692
731
  expected_output_cols_type="float",
693
732
  )
694
- expected_output_cols = self._align_expected_output_names(
733
+ expected_output_cols, _ = self._align_expected_output(
695
734
  inference_method, dataset, expected_output_cols, output_cols_prefix
696
735
  )
697
736
  elif isinstance(dataset, pd.DataFrame):
@@ -754,7 +793,7 @@ class VarianceThreshold(BaseTransformer):
754
793
  drop_input_cols=self._drop_input_cols,
755
794
  expected_output_cols_type="float",
756
795
  )
757
- expected_output_cols = self._align_expected_output_names(
796
+ expected_output_cols, _ = self._align_expected_output(
758
797
  inference_method, dataset, expected_output_cols, output_cols_prefix
759
798
  )
760
799
 
@@ -819,7 +858,7 @@ class VarianceThreshold(BaseTransformer):
819
858
  drop_input_cols = self._drop_input_cols,
820
859
  expected_output_cols_type="float",
821
860
  )
822
- expected_output_cols = self._align_expected_output_names(
861
+ expected_output_cols, _ = self._align_expected_output(
823
862
  inference_method, dataset, expected_output_cols, output_cols_prefix
824
863
  )
825
864