snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -564,12 +561,23 @@ class GaussianProcessClassifier(BaseTransformer):
564
561
  autogenerated=self._autogenerated,
565
562
  subproject=_SUBPROJECT,
566
563
  )
567
- output_result, fitted_estimator = model_trainer.train_fit_predict(
568
- drop_input_cols=self._drop_input_cols,
569
- expected_output_cols_list=(
570
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
571
- ),
564
+ expected_output_cols = (
565
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
572
566
  )
567
+ if isinstance(dataset, DataFrame):
568
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
569
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
570
+ )
571
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
572
+ drop_input_cols=self._drop_input_cols,
573
+ expected_output_cols_list=expected_output_cols,
574
+ example_output_pd_df=example_output_pd_df,
575
+ )
576
+ else:
577
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
578
+ drop_input_cols=self._drop_input_cols,
579
+ expected_output_cols_list=expected_output_cols,
580
+ )
573
581
  self._sklearn_object = fitted_estimator
574
582
  self._is_fitted = True
575
583
  return output_result
@@ -648,12 +656,41 @@ class GaussianProcessClassifier(BaseTransformer):
648
656
 
649
657
  return rv
650
658
 
651
- def _align_expected_output_names(
652
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
653
- ) -> List[str]:
659
+ def _align_expected_output(
660
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
661
+ ) -> Tuple[List[str], pd.DataFrame]:
662
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
663
+ and output dataframe with 1 line.
664
+ If the method is fit_predict, run 2 lines of data.
665
+ """
654
666
  # in case the inferred output column names dimension is different
655
667
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
656
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
668
+
669
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
670
+ # so change the minimum of number of rows to 2
671
+ num_examples = 2
672
+ statement_params = telemetry.get_function_usage_statement_params(
673
+ project=_PROJECT,
674
+ subproject=_SUBPROJECT,
675
+ function_name=telemetry.get_statement_params_full_func_name(
676
+ inspect.currentframe(), GaussianProcessClassifier.__class__.__name__
677
+ ),
678
+ api_calls=[Session.call],
679
+ custom_tags={"autogen": True} if self._autogenerated else None,
680
+ )
681
+ if output_cols_prefix == "fit_predict_":
682
+ if hasattr(self._sklearn_object, "n_clusters"):
683
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
684
+ num_examples = self._sklearn_object.n_clusters
685
+ elif hasattr(self._sklearn_object, "min_samples"):
686
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
687
+ num_examples = self._sklearn_object.min_samples
688
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
689
+ # LocalOutlierFactor expects n_neighbors <= n_samples
690
+ num_examples = self._sklearn_object.n_neighbors
691
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
692
+ else:
693
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
657
694
 
658
695
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
659
696
  # seen during the fit.
@@ -665,12 +702,14 @@ class GaussianProcessClassifier(BaseTransformer):
665
702
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
666
703
  if self.sample_weight_col:
667
704
  output_df_columns_set -= set(self.sample_weight_col)
705
+
668
706
  # if the dimension of inferred output column names is correct; use it
669
707
  if len(expected_output_cols_list) == len(output_df_columns_set):
670
- return expected_output_cols_list
708
+ return expected_output_cols_list, output_df_pd
671
709
  # otherwise, use the sklearn estimator's output
672
710
  else:
673
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
711
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
712
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
674
713
 
675
714
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
676
715
  @telemetry.send_api_usage_telemetry(
@@ -718,7 +757,7 @@ class GaussianProcessClassifier(BaseTransformer):
718
757
  drop_input_cols=self._drop_input_cols,
719
758
  expected_output_cols_type="float",
720
759
  )
721
- expected_output_cols = self._align_expected_output_names(
760
+ expected_output_cols, _ = self._align_expected_output(
722
761
  inference_method, dataset, expected_output_cols, output_cols_prefix
723
762
  )
724
763
 
@@ -786,7 +825,7 @@ class GaussianProcessClassifier(BaseTransformer):
786
825
  drop_input_cols=self._drop_input_cols,
787
826
  expected_output_cols_type="float",
788
827
  )
789
- expected_output_cols = self._align_expected_output_names(
828
+ expected_output_cols, _ = self._align_expected_output(
790
829
  inference_method, dataset, expected_output_cols, output_cols_prefix
791
830
  )
792
831
  elif isinstance(dataset, pd.DataFrame):
@@ -849,7 +888,7 @@ class GaussianProcessClassifier(BaseTransformer):
849
888
  drop_input_cols=self._drop_input_cols,
850
889
  expected_output_cols_type="float",
851
890
  )
852
- expected_output_cols = self._align_expected_output_names(
891
+ expected_output_cols, _ = self._align_expected_output(
853
892
  inference_method, dataset, expected_output_cols, output_cols_prefix
854
893
  )
855
894
 
@@ -914,7 +953,7 @@ class GaussianProcessClassifier(BaseTransformer):
914
953
  drop_input_cols = self._drop_input_cols,
915
954
  expected_output_cols_type="float",
916
955
  )
917
- expected_output_cols = self._align_expected_output_names(
956
+ expected_output_cols, _ = self._align_expected_output(
918
957
  inference_method, dataset, expected_output_cols, output_cols_prefix
919
958
  )
920
959
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -555,12 +552,23 @@ class GaussianProcessRegressor(BaseTransformer):
555
552
  autogenerated=self._autogenerated,
556
553
  subproject=_SUBPROJECT,
557
554
  )
558
- output_result, fitted_estimator = model_trainer.train_fit_predict(
559
- drop_input_cols=self._drop_input_cols,
560
- expected_output_cols_list=(
561
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
562
- ),
555
+ expected_output_cols = (
556
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
563
557
  )
558
+ if isinstance(dataset, DataFrame):
559
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
560
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
561
+ )
562
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
563
+ drop_input_cols=self._drop_input_cols,
564
+ expected_output_cols_list=expected_output_cols,
565
+ example_output_pd_df=example_output_pd_df,
566
+ )
567
+ else:
568
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
569
+ drop_input_cols=self._drop_input_cols,
570
+ expected_output_cols_list=expected_output_cols,
571
+ )
564
572
  self._sklearn_object = fitted_estimator
565
573
  self._is_fitted = True
566
574
  return output_result
@@ -639,12 +647,41 @@ class GaussianProcessRegressor(BaseTransformer):
639
647
 
640
648
  return rv
641
649
 
642
- def _align_expected_output_names(
643
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
644
- ) -> List[str]:
650
+ def _align_expected_output(
651
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
652
+ ) -> Tuple[List[str], pd.DataFrame]:
653
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
654
+ and output dataframe with 1 line.
655
+ If the method is fit_predict, run 2 lines of data.
656
+ """
645
657
  # in case the inferred output column names dimension is different
646
658
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
647
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
659
+
660
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
661
+ # so change the minimum of number of rows to 2
662
+ num_examples = 2
663
+ statement_params = telemetry.get_function_usage_statement_params(
664
+ project=_PROJECT,
665
+ subproject=_SUBPROJECT,
666
+ function_name=telemetry.get_statement_params_full_func_name(
667
+ inspect.currentframe(), GaussianProcessRegressor.__class__.__name__
668
+ ),
669
+ api_calls=[Session.call],
670
+ custom_tags={"autogen": True} if self._autogenerated else None,
671
+ )
672
+ if output_cols_prefix == "fit_predict_":
673
+ if hasattr(self._sklearn_object, "n_clusters"):
674
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
675
+ num_examples = self._sklearn_object.n_clusters
676
+ elif hasattr(self._sklearn_object, "min_samples"):
677
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
678
+ num_examples = self._sklearn_object.min_samples
679
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
680
+ # LocalOutlierFactor expects n_neighbors <= n_samples
681
+ num_examples = self._sklearn_object.n_neighbors
682
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
683
+ else:
684
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
648
685
 
649
686
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
650
687
  # seen during the fit.
@@ -656,12 +693,14 @@ class GaussianProcessRegressor(BaseTransformer):
656
693
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
657
694
  if self.sample_weight_col:
658
695
  output_df_columns_set -= set(self.sample_weight_col)
696
+
659
697
  # if the dimension of inferred output column names is correct; use it
660
698
  if len(expected_output_cols_list) == len(output_df_columns_set):
661
- return expected_output_cols_list
699
+ return expected_output_cols_list, output_df_pd
662
700
  # otherwise, use the sklearn estimator's output
663
701
  else:
664
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
702
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
703
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
665
704
 
666
705
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
667
706
  @telemetry.send_api_usage_telemetry(
@@ -707,7 +746,7 @@ class GaussianProcessRegressor(BaseTransformer):
707
746
  drop_input_cols=self._drop_input_cols,
708
747
  expected_output_cols_type="float",
709
748
  )
710
- expected_output_cols = self._align_expected_output_names(
749
+ expected_output_cols, _ = self._align_expected_output(
711
750
  inference_method, dataset, expected_output_cols, output_cols_prefix
712
751
  )
713
752
 
@@ -773,7 +812,7 @@ class GaussianProcessRegressor(BaseTransformer):
773
812
  drop_input_cols=self._drop_input_cols,
774
813
  expected_output_cols_type="float",
775
814
  )
776
- expected_output_cols = self._align_expected_output_names(
815
+ expected_output_cols, _ = self._align_expected_output(
777
816
  inference_method, dataset, expected_output_cols, output_cols_prefix
778
817
  )
779
818
  elif isinstance(dataset, pd.DataFrame):
@@ -836,7 +875,7 @@ class GaussianProcessRegressor(BaseTransformer):
836
875
  drop_input_cols=self._drop_input_cols,
837
876
  expected_output_cols_type="float",
838
877
  )
839
- expected_output_cols = self._align_expected_output_names(
878
+ expected_output_cols, _ = self._align_expected_output(
840
879
  inference_method, dataset, expected_output_cols, output_cols_prefix
841
880
  )
842
881
 
@@ -901,7 +940,7 @@ class GaussianProcessRegressor(BaseTransformer):
901
940
  drop_input_cols = self._drop_input_cols,
902
941
  expected_output_cols_type="float",
903
942
  )
904
- expected_output_cols = self._align_expected_output_names(
943
+ expected_output_cols, _ = self._align_expected_output(
905
944
  inference_method, dataset, expected_output_cols, output_cols_prefix
906
945
  )
907
946
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -25,12 +23,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
25
23
  from snowflake.ml._internal import telemetry
26
24
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
27
25
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
28
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
26
+ from snowflake.ml._internal.utils import identifier
29
27
  from snowflake.snowpark import DataFrame, Session
30
28
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
31
29
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
32
30
  from snowflake.ml.modeling._internal.transformer_protocols import (
33
- ModelTransformHandlers,
34
31
  BatchInferenceKwargsTypedDict,
35
32
  ScoreKwargsTypedDict
36
33
  )
@@ -599,12 +596,23 @@ class IterativeImputer(BaseTransformer):
599
596
  autogenerated=self._autogenerated,
600
597
  subproject=_SUBPROJECT,
601
598
  )
602
- output_result, fitted_estimator = model_trainer.train_fit_predict(
603
- drop_input_cols=self._drop_input_cols,
604
- expected_output_cols_list=(
605
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
606
- ),
599
+ expected_output_cols = (
600
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
607
601
  )
602
+ if isinstance(dataset, DataFrame):
603
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
604
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
605
+ )
606
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
607
+ drop_input_cols=self._drop_input_cols,
608
+ expected_output_cols_list=expected_output_cols,
609
+ example_output_pd_df=example_output_pd_df,
610
+ )
611
+ else:
612
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
613
+ drop_input_cols=self._drop_input_cols,
614
+ expected_output_cols_list=expected_output_cols,
615
+ )
608
616
  self._sklearn_object = fitted_estimator
609
617
  self._is_fitted = True
610
618
  return output_result
@@ -685,12 +693,41 @@ class IterativeImputer(BaseTransformer):
685
693
 
686
694
  return rv
687
695
 
688
- def _align_expected_output_names(
689
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
690
- ) -> List[str]:
696
+ def _align_expected_output(
697
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
698
+ ) -> Tuple[List[str], pd.DataFrame]:
699
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
700
+ and output dataframe with 1 line.
701
+ If the method is fit_predict, run 2 lines of data.
702
+ """
691
703
  # in case the inferred output column names dimension is different
692
704
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
693
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
705
+
706
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
707
+ # so change the minimum of number of rows to 2
708
+ num_examples = 2
709
+ statement_params = telemetry.get_function_usage_statement_params(
710
+ project=_PROJECT,
711
+ subproject=_SUBPROJECT,
712
+ function_name=telemetry.get_statement_params_full_func_name(
713
+ inspect.currentframe(), IterativeImputer.__class__.__name__
714
+ ),
715
+ api_calls=[Session.call],
716
+ custom_tags={"autogen": True} if self._autogenerated else None,
717
+ )
718
+ if output_cols_prefix == "fit_predict_":
719
+ if hasattr(self._sklearn_object, "n_clusters"):
720
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
721
+ num_examples = self._sklearn_object.n_clusters
722
+ elif hasattr(self._sklearn_object, "min_samples"):
723
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
724
+ num_examples = self._sklearn_object.min_samples
725
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
726
+ # LocalOutlierFactor expects n_neighbors <= n_samples
727
+ num_examples = self._sklearn_object.n_neighbors
728
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
729
+ else:
730
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
694
731
 
695
732
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
696
733
  # seen during the fit.
@@ -702,12 +739,14 @@ class IterativeImputer(BaseTransformer):
702
739
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
703
740
  if self.sample_weight_col:
704
741
  output_df_columns_set -= set(self.sample_weight_col)
742
+
705
743
  # if the dimension of inferred output column names is correct; use it
706
744
  if len(expected_output_cols_list) == len(output_df_columns_set):
707
- return expected_output_cols_list
745
+ return expected_output_cols_list, output_df_pd
708
746
  # otherwise, use the sklearn estimator's output
709
747
  else:
710
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
748
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
749
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
711
750
 
712
751
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
713
752
  @telemetry.send_api_usage_telemetry(
@@ -753,7 +792,7 @@ class IterativeImputer(BaseTransformer):
753
792
  drop_input_cols=self._drop_input_cols,
754
793
  expected_output_cols_type="float",
755
794
  )
756
- expected_output_cols = self._align_expected_output_names(
795
+ expected_output_cols, _ = self._align_expected_output(
757
796
  inference_method, dataset, expected_output_cols, output_cols_prefix
758
797
  )
759
798
 
@@ -819,7 +858,7 @@ class IterativeImputer(BaseTransformer):
819
858
  drop_input_cols=self._drop_input_cols,
820
859
  expected_output_cols_type="float",
821
860
  )
822
- expected_output_cols = self._align_expected_output_names(
861
+ expected_output_cols, _ = self._align_expected_output(
823
862
  inference_method, dataset, expected_output_cols, output_cols_prefix
824
863
  )
825
864
  elif isinstance(dataset, pd.DataFrame):
@@ -882,7 +921,7 @@ class IterativeImputer(BaseTransformer):
882
921
  drop_input_cols=self._drop_input_cols,
883
922
  expected_output_cols_type="float",
884
923
  )
885
- expected_output_cols = self._align_expected_output_names(
924
+ expected_output_cols, _ = self._align_expected_output(
886
925
  inference_method, dataset, expected_output_cols, output_cols_prefix
887
926
  )
888
927
 
@@ -947,7 +986,7 @@ class IterativeImputer(BaseTransformer):
947
986
  drop_input_cols = self._drop_input_cols,
948
987
  expected_output_cols_type="float",
949
988
  )
950
- expected_output_cols = self._align_expected_output_names(
989
+ expected_output_cols, _ = self._align_expected_output(
951
990
  inference_method, dataset, expected_output_cols, output_cols_prefix
952
991
  )
953
992
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -525,12 +522,23 @@ class KNNImputer(BaseTransformer):
525
522
  autogenerated=self._autogenerated,
526
523
  subproject=_SUBPROJECT,
527
524
  )
528
- output_result, fitted_estimator = model_trainer.train_fit_predict(
529
- drop_input_cols=self._drop_input_cols,
530
- expected_output_cols_list=(
531
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
532
- ),
525
+ expected_output_cols = (
526
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
533
527
  )
528
+ if isinstance(dataset, DataFrame):
529
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
530
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
531
+ )
532
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
533
+ drop_input_cols=self._drop_input_cols,
534
+ expected_output_cols_list=expected_output_cols,
535
+ example_output_pd_df=example_output_pd_df,
536
+ )
537
+ else:
538
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
539
+ drop_input_cols=self._drop_input_cols,
540
+ expected_output_cols_list=expected_output_cols,
541
+ )
534
542
  self._sklearn_object = fitted_estimator
535
543
  self._is_fitted = True
536
544
  return output_result
@@ -611,12 +619,41 @@ class KNNImputer(BaseTransformer):
611
619
 
612
620
  return rv
613
621
 
614
- def _align_expected_output_names(
615
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
616
- ) -> List[str]:
622
+ def _align_expected_output(
623
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
624
+ ) -> Tuple[List[str], pd.DataFrame]:
625
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
626
+ and output dataframe with 1 line.
627
+ If the method is fit_predict, run 2 lines of data.
628
+ """
617
629
  # in case the inferred output column names dimension is different
618
630
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
619
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
631
+
632
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
633
+ # so change the minimum of number of rows to 2
634
+ num_examples = 2
635
+ statement_params = telemetry.get_function_usage_statement_params(
636
+ project=_PROJECT,
637
+ subproject=_SUBPROJECT,
638
+ function_name=telemetry.get_statement_params_full_func_name(
639
+ inspect.currentframe(), KNNImputer.__class__.__name__
640
+ ),
641
+ api_calls=[Session.call],
642
+ custom_tags={"autogen": True} if self._autogenerated else None,
643
+ )
644
+ if output_cols_prefix == "fit_predict_":
645
+ if hasattr(self._sklearn_object, "n_clusters"):
646
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
647
+ num_examples = self._sklearn_object.n_clusters
648
+ elif hasattr(self._sklearn_object, "min_samples"):
649
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
650
+ num_examples = self._sklearn_object.min_samples
651
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
652
+ # LocalOutlierFactor expects n_neighbors <= n_samples
653
+ num_examples = self._sklearn_object.n_neighbors
654
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
655
+ else:
656
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
620
657
 
621
658
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
622
659
  # seen during the fit.
@@ -628,12 +665,14 @@ class KNNImputer(BaseTransformer):
628
665
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
629
666
  if self.sample_weight_col:
630
667
  output_df_columns_set -= set(self.sample_weight_col)
668
+
631
669
  # if the dimension of inferred output column names is correct; use it
632
670
  if len(expected_output_cols_list) == len(output_df_columns_set):
633
- return expected_output_cols_list
671
+ return expected_output_cols_list, output_df_pd
634
672
  # otherwise, use the sklearn estimator's output
635
673
  else:
636
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
674
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
675
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
637
676
 
638
677
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
639
678
  @telemetry.send_api_usage_telemetry(
@@ -679,7 +718,7 @@ class KNNImputer(BaseTransformer):
679
718
  drop_input_cols=self._drop_input_cols,
680
719
  expected_output_cols_type="float",
681
720
  )
682
- expected_output_cols = self._align_expected_output_names(
721
+ expected_output_cols, _ = self._align_expected_output(
683
722
  inference_method, dataset, expected_output_cols, output_cols_prefix
684
723
  )
685
724
 
@@ -745,7 +784,7 @@ class KNNImputer(BaseTransformer):
745
784
  drop_input_cols=self._drop_input_cols,
746
785
  expected_output_cols_type="float",
747
786
  )
748
- expected_output_cols = self._align_expected_output_names(
787
+ expected_output_cols, _ = self._align_expected_output(
749
788
  inference_method, dataset, expected_output_cols, output_cols_prefix
750
789
  )
751
790
  elif isinstance(dataset, pd.DataFrame):
@@ -808,7 +847,7 @@ class KNNImputer(BaseTransformer):
808
847
  drop_input_cols=self._drop_input_cols,
809
848
  expected_output_cols_type="float",
810
849
  )
811
- expected_output_cols = self._align_expected_output_names(
850
+ expected_output_cols, _ = self._align_expected_output(
812
851
  inference_method, dataset, expected_output_cols, output_cols_prefix
813
852
  )
814
853
 
@@ -873,7 +912,7 @@ class KNNImputer(BaseTransformer):
873
912
  drop_input_cols = self._drop_input_cols,
874
913
  expected_output_cols_type="float",
875
914
  )
876
- expected_output_cols = self._align_expected_output_names(
915
+ expected_output_cols, _ = self._align_expected_output(
877
916
  inference_method, dataset, expected_output_cols, output_cols_prefix
878
917
  )
879
918