snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported public registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
Files changed (212)
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
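
The version bump itself is the one-line change to snowflake/ml/version.py listed above. A quick way to confirm which build is installed locally, assuming the module continues to expose a VERSION string constant (which is what the one-line change suggests is being updated):

    # Hedged check: assumes snowflake/ml/version.py exposes a VERSION string.
    from snowflake.ml import version

    print(version.VERSION)  # expected "1.6.2" after upgrading, "1.6.1" before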
snowflake/ml/modeling/cluster/bisecting_k_means.py

@@ -4,14 +4,12 @@
 #
 import inspect
 import os
-import posixpath
-from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
-from typing_extensions import TypeGuard
+from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
 from uuid import uuid4

 import cloudpickle as cp
-import pandas as pd
 import numpy as np
+import pandas as pd
 from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
 from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
-from snowflake.ml._internal.utils import pkg_version_utils, identifier
+from snowflake.ml._internal.utils import identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.transformer_protocols import (
-    ModelTransformHandlers,
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
@@ -561,12 +558,23 @@ class BisectingKMeans(BaseTransformer):
             autogenerated=self._autogenerated,
             subproject=_SUBPROJECT,
         )
-        output_result, fitted_estimator = model_trainer.train_fit_predict(
-            drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=(
-                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
-            ),
+        expected_output_cols = (
+            self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
         )
+        if isinstance(dataset, DataFrame):
+            expected_output_cols, example_output_pd_df = self._align_expected_output(
+                "fit_predict", dataset, expected_output_cols, output_cols_prefix
+            )
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+                example_output_pd_df=example_output_pd_df,
+            )
+        else:
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+            )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result
@@ -647,12 +655,41 @@ class BisectingKMeans(BaseTransformer):

         return rv

-    def _align_expected_output_names(
-        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
-    ) -> List[str]:
+    def _align_expected_output(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+    ) -> Tuple[List[str], pd.DataFrame]:
+        """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+        and output dataframe with 1 line.
+        If the method is fit_predict, run 2 lines of data.
+        """
         # in case the inferred output column names dimension is different
         # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
-        sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+        # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+        # so change the minimum of number of rows to 2
+        num_examples = 2
+        statement_params = telemetry.get_function_usage_statement_params(
+            project=_PROJECT,
+            subproject=_SUBPROJECT,
+            function_name=telemetry.get_statement_params_full_func_name(
+                inspect.currentframe(), BisectingKMeans.__class__.__name__
+            ),
+            api_calls=[Session.call],
+            custom_tags={"autogen": True} if self._autogenerated else None,
+        )
+        if output_cols_prefix == "fit_predict_":
+            if hasattr(self._sklearn_object, "n_clusters"):
+                # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+                num_examples = self._sklearn_object.n_clusters
+            elif hasattr(self._sklearn_object, "min_samples"):
+                # OPTICS default min_samples 5, which requires at least 5 lines of data
+                num_examples = self._sklearn_object.min_samples
+            elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+                # LocalOutlierFactor expects n_neighbors <= n_samples
+                num_examples = self._sklearn_object.n_neighbors
+            sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+        else:
+            sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

         # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
         # seen during the fit.
@@ -664,12 +701,14 @@ class BisectingKMeans(BaseTransformer):
         output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
         if self.sample_weight_col:
             output_df_columns_set -= set(self.sample_weight_col)
+
         # if the dimension of inferred output column names is correct; use it
         if len(expected_output_cols_list) == len(output_df_columns_set):
-            return expected_output_cols_list
+            return expected_output_cols_list, output_df_pd
         # otherwise, use the sklearn estimator's output
         else:
-            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            return expected_output_cols_list, output_df_pd[expected_output_cols_list]

     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -715,7 +754,7 @@ class BisectingKMeans(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -781,7 +820,7 @@ class BisectingKMeans(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -844,7 +883,7 @@ class BisectingKMeans(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -909,7 +948,7 @@ class BisectingKMeans(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
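
The same refactor repeats in the DBSCAN, FeatureAgglomeration, and KMeans diffs that follow: _align_expected_output_names becomes _align_expected_output, which now returns the aligned column names together with a small example output DataFrame, and fit_predict on a Snowpark DataFrame forwards that example to train_fit_predict as example_output_pd_df. For orientation, a minimal usage sketch of the public API these internals sit behind; the connection parameters, data, and column names are illustrative assumptions, not taken from the package:

    # Minimal sketch (not from the diff): fit_predict on a Snowpark DataFrame.
    import pandas as pd
    from snowflake.snowpark import Session
    from snowflake.ml.modeling.cluster import BisectingKMeans

    connection_parameters = {
        "account": "<account>", "user": "<user>", "password": "<password>",
        "warehouse": "<warehouse>", "database": "<database>", "schema": "<schema>",
    }
    session = Session.builder.configs(connection_parameters).create()

    # Two loose clusters of points; FEATURE_1, FEATURE_2, and CLUSTER are made-up names.
    df = session.create_dataframe(pd.DataFrame({
        "FEATURE_1": [1.0, 1.1, 0.9, 8.0, 8.2, 7.9],
        "FEATURE_2": [0.5, 0.4, 0.6, 3.0, 3.1, 2.9],
    }))

    km = BisectingKMeans(n_clusters=2, input_cols=["FEATURE_1", "FEATURE_2"], output_cols=["CLUSTER"])

    # With a Snowpark DataFrame input, the 1.6.2 code path above first calls
    # _align_expected_output("fit_predict", ...) to sample at least n_clusters rows,
    # infer the output schema, and pass example_output_pd_df to train_fit_predict.
    result = km.fit_predict(df)
    result.show()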
 
snowflake/ml/modeling/cluster/dbscan.py

@@ -4,14 +4,12 @@
 #
 import inspect
 import os
-import posixpath
-from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
-from typing_extensions import TypeGuard
+from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
 from uuid import uuid4

 import cloudpickle as cp
-import pandas as pd
 import numpy as np
+import pandas as pd
 from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
 from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
-from snowflake.ml._internal.utils import pkg_version_utils, identifier
+from snowflake.ml._internal.utils import identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.transformer_protocols import (
-    ModelTransformHandlers,
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
@@ -525,12 +522,23 @@ class DBSCAN(BaseTransformer):
             autogenerated=self._autogenerated,
             subproject=_SUBPROJECT,
         )
-        output_result, fitted_estimator = model_trainer.train_fit_predict(
-            drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=(
-                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
-            ),
+        expected_output_cols = (
+            self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
         )
+        if isinstance(dataset, DataFrame):
+            expected_output_cols, example_output_pd_df = self._align_expected_output(
+                "fit_predict", dataset, expected_output_cols, output_cols_prefix
+            )
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+                example_output_pd_df=example_output_pd_df,
+            )
+        else:
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+            )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result
@@ -609,12 +617,41 @@ class DBSCAN(BaseTransformer):

         return rv

-    def _align_expected_output_names(
-        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
-    ) -> List[str]:
+    def _align_expected_output(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+    ) -> Tuple[List[str], pd.DataFrame]:
+        """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+        and output dataframe with 1 line.
+        If the method is fit_predict, run 2 lines of data.
+        """
         # in case the inferred output column names dimension is different
         # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
-        sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+        # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+        # so change the minimum of number of rows to 2
+        num_examples = 2
+        statement_params = telemetry.get_function_usage_statement_params(
+            project=_PROJECT,
+            subproject=_SUBPROJECT,
+            function_name=telemetry.get_statement_params_full_func_name(
+                inspect.currentframe(), DBSCAN.__class__.__name__
+            ),
+            api_calls=[Session.call],
+            custom_tags={"autogen": True} if self._autogenerated else None,
+        )
+        if output_cols_prefix == "fit_predict_":
+            if hasattr(self._sklearn_object, "n_clusters"):
+                # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+                num_examples = self._sklearn_object.n_clusters
+            elif hasattr(self._sklearn_object, "min_samples"):
+                # OPTICS default min_samples 5, which requires at least 5 lines of data
+                num_examples = self._sklearn_object.min_samples
+            elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+                # LocalOutlierFactor expects n_neighbors <= n_samples
+                num_examples = self._sklearn_object.n_neighbors
+            sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+        else:
+            sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

         # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
         # seen during the fit.
@@ -626,12 +663,14 @@ class DBSCAN(BaseTransformer):
         output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
         if self.sample_weight_col:
             output_df_columns_set -= set(self.sample_weight_col)
+
         # if the dimension of inferred output column names is correct; use it
         if len(expected_output_cols_list) == len(output_df_columns_set):
-            return expected_output_cols_list
+            return expected_output_cols_list, output_df_pd
         # otherwise, use the sklearn estimator's output
         else:
-            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            return expected_output_cols_list, output_df_pd[expected_output_cols_list]

     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -677,7 +716,7 @@ class DBSCAN(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -743,7 +782,7 @@ class DBSCAN(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -806,7 +845,7 @@ class DBSCAN(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -871,7 +910,7 @@ class DBSCAN(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
snowflake/ml/modeling/cluster/feature_agglomeration.py

@@ -4,14 +4,12 @@
 #
 import inspect
 import os
-import posixpath
-from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
-from typing_extensions import TypeGuard
+from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
 from uuid import uuid4

 import cloudpickle as cp
-import pandas as pd
 import numpy as np
+import pandas as pd
 from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
 from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
-from snowflake.ml._internal.utils import pkg_version_utils, identifier
+from snowflake.ml._internal.utils import identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.transformer_protocols import (
-    ModelTransformHandlers,
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
@@ -559,12 +556,23 @@ class FeatureAgglomeration(BaseTransformer):
             autogenerated=self._autogenerated,
             subproject=_SUBPROJECT,
         )
-        output_result, fitted_estimator = model_trainer.train_fit_predict(
-            drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=(
-                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
-            ),
+        expected_output_cols = (
+            self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
         )
+        if isinstance(dataset, DataFrame):
+            expected_output_cols, example_output_pd_df = self._align_expected_output(
+                "fit_predict", dataset, expected_output_cols, output_cols_prefix
+            )
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+                example_output_pd_df=example_output_pd_df,
+            )
+        else:
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+            )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result
@@ -645,12 +653,41 @@ class FeatureAgglomeration(BaseTransformer):

         return rv

-    def _align_expected_output_names(
-        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
-    ) -> List[str]:
+    def _align_expected_output(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+    ) -> Tuple[List[str], pd.DataFrame]:
+        """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+        and output dataframe with 1 line.
+        If the method is fit_predict, run 2 lines of data.
+        """
         # in case the inferred output column names dimension is different
         # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
-        sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+        # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+        # so change the minimum of number of rows to 2
+        num_examples = 2
+        statement_params = telemetry.get_function_usage_statement_params(
+            project=_PROJECT,
+            subproject=_SUBPROJECT,
+            function_name=telemetry.get_statement_params_full_func_name(
+                inspect.currentframe(), FeatureAgglomeration.__class__.__name__
+            ),
+            api_calls=[Session.call],
+            custom_tags={"autogen": True} if self._autogenerated else None,
+        )
+        if output_cols_prefix == "fit_predict_":
+            if hasattr(self._sklearn_object, "n_clusters"):
+                # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+                num_examples = self._sklearn_object.n_clusters
+            elif hasattr(self._sklearn_object, "min_samples"):
+                # OPTICS default min_samples 5, which requires at least 5 lines of data
+                num_examples = self._sklearn_object.min_samples
+            elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+                # LocalOutlierFactor expects n_neighbors <= n_samples
+                num_examples = self._sklearn_object.n_neighbors
+            sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+        else:
+            sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

         # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
         # seen during the fit.
@@ -662,12 +699,14 @@ class FeatureAgglomeration(BaseTransformer):
         output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
         if self.sample_weight_col:
             output_df_columns_set -= set(self.sample_weight_col)
+
         # if the dimension of inferred output column names is correct; use it
         if len(expected_output_cols_list) == len(output_df_columns_set):
-            return expected_output_cols_list
+            return expected_output_cols_list, output_df_pd
         # otherwise, use the sklearn estimator's output
         else:
-            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            return expected_output_cols_list, output_df_pd[expected_output_cols_list]

     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -713,7 +752,7 @@ class FeatureAgglomeration(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -779,7 +818,7 @@ class FeatureAgglomeration(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -842,7 +881,7 @@ class FeatureAgglomeration(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -907,7 +946,7 @@ class FeatureAgglomeration(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
snowflake/ml/modeling/cluster/k_means.py

@@ -4,14 +4,12 @@
 #
 import inspect
 import os
-import posixpath
-from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
-from typing_extensions import TypeGuard
+from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
 from uuid import uuid4

 import cloudpickle as cp
-import pandas as pd
 import numpy as np
+import pandas as pd
 from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
 from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
-from snowflake.ml._internal.utils import pkg_version_utils, identifier
+from snowflake.ml._internal.utils import identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.transformer_protocols import (
-    ModelTransformHandlers,
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
@@ -556,12 +553,23 @@ class KMeans(BaseTransformer):
             autogenerated=self._autogenerated,
             subproject=_SUBPROJECT,
         )
-        output_result, fitted_estimator = model_trainer.train_fit_predict(
-            drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=(
-                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
-            ),
+        expected_output_cols = (
+            self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
         )
+        if isinstance(dataset, DataFrame):
+            expected_output_cols, example_output_pd_df = self._align_expected_output(
+                "fit_predict", dataset, expected_output_cols, output_cols_prefix
+            )
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+                example_output_pd_df=example_output_pd_df,
+            )
+        else:
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+            )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result
@@ -642,12 +650,41 @@ class KMeans(BaseTransformer):

         return rv

-    def _align_expected_output_names(
-        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
-    ) -> List[str]:
+    def _align_expected_output(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+    ) -> Tuple[List[str], pd.DataFrame]:
+        """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+        and output dataframe with 1 line.
+        If the method is fit_predict, run 2 lines of data.
+        """
         # in case the inferred output column names dimension is different
         # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
-        sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+        # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+        # so change the minimum of number of rows to 2
+        num_examples = 2
+        statement_params = telemetry.get_function_usage_statement_params(
+            project=_PROJECT,
+            subproject=_SUBPROJECT,
+            function_name=telemetry.get_statement_params_full_func_name(
+                inspect.currentframe(), KMeans.__class__.__name__
+            ),
+            api_calls=[Session.call],
+            custom_tags={"autogen": True} if self._autogenerated else None,
+        )
+        if output_cols_prefix == "fit_predict_":
+            if hasattr(self._sklearn_object, "n_clusters"):
+                # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+                num_examples = self._sklearn_object.n_clusters
+            elif hasattr(self._sklearn_object, "min_samples"):
+                # OPTICS default min_samples 5, which requires at least 5 lines of data
+                num_examples = self._sklearn_object.min_samples
+            elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+                # LocalOutlierFactor expects n_neighbors <= n_samples
+                num_examples = self._sklearn_object.n_neighbors
+            sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+        else:
+            sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

         # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
         # seen during the fit.
@@ -659,12 +696,14 @@ class KMeans(BaseTransformer):
         output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
         if self.sample_weight_col:
             output_df_columns_set -= set(self.sample_weight_col)
+
         # if the dimension of inferred output column names is correct; use it
         if len(expected_output_cols_list) == len(output_df_columns_set):
-            return expected_output_cols_list
+            return expected_output_cols_list, output_df_pd
         # otherwise, use the sklearn estimator's output
         else:
-            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            return expected_output_cols_list, output_df_pd[expected_output_cols_list]

     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -710,7 +749,7 @@ class KMeans(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -776,7 +815,7 @@ class KMeans(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
            )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -839,7 +878,7 @@ class KMeans(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -904,7 +943,7 @@ class KMeans(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
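
Across all four estimators shown above, the substantive change is the row-sampling rule inside the new _align_expected_output. Restated as a standalone helper for readability (the helper name is hypothetical; in the package this logic is inline in each autogenerated class):

    from typing import Any

    # Illustrative restatement of the sampling rule added in 1.6.2; the branch logic
    # mirrors the added lines in the hunks above, the function itself is made up.
    def num_sample_rows(estimator: Any, output_cols_prefix: str) -> int:
        if output_cols_prefix != "fit_predict_":
            # Schema inference for predict/transform-style methods needs only one row.
            return 1
        # fit_predict needs at least 2 rows (MinCovDet, BayesianGaussianMixture) ...
        num_examples = 2
        if hasattr(estimator, "n_clusters"):
            # ... clustering estimators need at least n_clusters rows (KMeans, BisectingKMeans)
            num_examples = estimator.n_clusters
        elif hasattr(estimator, "min_samples"):
            # ... OPTICS needs at least min_samples rows (default 5)
            num_examples = estimator.min_samples
        elif hasattr(estimator, "n_neighbors") and hasattr(estimator, "n_samples"):
            # ... LocalOutlierFactor expects n_neighbors <= number of sampled rows
            num_examples = estimator.n_neighbors
        return num_examples

Returning the sampled output frame alongside the names (the new Tuple[List[str], pd.DataFrame] return type) is what allows fit_predict to hand example_output_pd_df to train_fit_predict when the input is a Snowpark DataFrame.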