snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -525,12 +522,23 @@ class KernelDensity(BaseTransformer):
525
522
  autogenerated=self._autogenerated,
526
523
  subproject=_SUBPROJECT,
527
524
  )
528
- output_result, fitted_estimator = model_trainer.train_fit_predict(
529
- drop_input_cols=self._drop_input_cols,
530
- expected_output_cols_list=(
531
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
532
- ),
525
+ expected_output_cols = (
526
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
533
527
  )
528
+ if isinstance(dataset, DataFrame):
529
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
530
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
531
+ )
532
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
533
+ drop_input_cols=self._drop_input_cols,
534
+ expected_output_cols_list=expected_output_cols,
535
+ example_output_pd_df=example_output_pd_df,
536
+ )
537
+ else:
538
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
539
+ drop_input_cols=self._drop_input_cols,
540
+ expected_output_cols_list=expected_output_cols,
541
+ )
534
542
  self._sklearn_object = fitted_estimator
535
543
  self._is_fitted = True
536
544
  return output_result
@@ -609,12 +617,41 @@ class KernelDensity(BaseTransformer):
609
617
 
610
618
  return rv
611
619
 
612
- def _align_expected_output_names(
613
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
614
- ) -> List[str]:
620
+ def _align_expected_output(
621
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
622
+ ) -> Tuple[List[str], pd.DataFrame]:
623
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
624
+ and output dataframe with 1 line.
625
+ If the method is fit_predict, run 2 lines of data.
626
+ """
615
627
  # in case the inferred output column names dimension is different
616
628
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
617
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
629
+
630
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
631
+ # so change the minimum of number of rows to 2
632
+ num_examples = 2
633
+ statement_params = telemetry.get_function_usage_statement_params(
634
+ project=_PROJECT,
635
+ subproject=_SUBPROJECT,
636
+ function_name=telemetry.get_statement_params_full_func_name(
637
+ inspect.currentframe(), KernelDensity.__class__.__name__
638
+ ),
639
+ api_calls=[Session.call],
640
+ custom_tags={"autogen": True} if self._autogenerated else None,
641
+ )
642
+ if output_cols_prefix == "fit_predict_":
643
+ if hasattr(self._sklearn_object, "n_clusters"):
644
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
645
+ num_examples = self._sklearn_object.n_clusters
646
+ elif hasattr(self._sklearn_object, "min_samples"):
647
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
648
+ num_examples = self._sklearn_object.min_samples
649
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
650
+ # LocalOutlierFactor expects n_neighbors <= n_samples
651
+ num_examples = self._sklearn_object.n_neighbors
652
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
653
+ else:
654
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
618
655
 
619
656
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
620
657
  # seen during the fit.
@@ -626,12 +663,14 @@ class KernelDensity(BaseTransformer):
626
663
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
627
664
  if self.sample_weight_col:
628
665
  output_df_columns_set -= set(self.sample_weight_col)
666
+
629
667
  # if the dimension of inferred output column names is correct; use it
630
668
  if len(expected_output_cols_list) == len(output_df_columns_set):
631
- return expected_output_cols_list
669
+ return expected_output_cols_list, output_df_pd
632
670
  # otherwise, use the sklearn estimator's output
633
671
  else:
634
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
672
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
673
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
635
674
 
636
675
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
637
676
  @telemetry.send_api_usage_telemetry(
@@ -677,7 +716,7 @@ class KernelDensity(BaseTransformer):
677
716
  drop_input_cols=self._drop_input_cols,
678
717
  expected_output_cols_type="float",
679
718
  )
680
- expected_output_cols = self._align_expected_output_names(
719
+ expected_output_cols, _ = self._align_expected_output(
681
720
  inference_method, dataset, expected_output_cols, output_cols_prefix
682
721
  )
683
722
 
@@ -743,7 +782,7 @@ class KernelDensity(BaseTransformer):
743
782
  drop_input_cols=self._drop_input_cols,
744
783
  expected_output_cols_type="float",
745
784
  )
746
- expected_output_cols = self._align_expected_output_names(
785
+ expected_output_cols, _ = self._align_expected_output(
747
786
  inference_method, dataset, expected_output_cols, output_cols_prefix
748
787
  )
749
788
  elif isinstance(dataset, pd.DataFrame):
@@ -806,7 +845,7 @@ class KernelDensity(BaseTransformer):
806
845
  drop_input_cols=self._drop_input_cols,
807
846
  expected_output_cols_type="float",
808
847
  )
809
- expected_output_cols = self._align_expected_output_names(
848
+ expected_output_cols, _ = self._align_expected_output(
810
849
  inference_method, dataset, expected_output_cols, output_cols_prefix
811
850
  )
812
851
 
@@ -873,7 +912,7 @@ class KernelDensity(BaseTransformer):
873
912
  drop_input_cols = self._drop_input_cols,
874
913
  expected_output_cols_type="float",
875
914
  )
876
- expected_output_cols = self._align_expected_output_names(
915
+ expected_output_cols, _ = self._align_expected_output(
877
916
  inference_method, dataset, expected_output_cols, output_cols_prefix
878
917
  )
879
918
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -557,12 +554,23 @@ class LocalOutlierFactor(BaseTransformer):
557
554
  autogenerated=self._autogenerated,
558
555
  subproject=_SUBPROJECT,
559
556
  )
560
- output_result, fitted_estimator = model_trainer.train_fit_predict(
561
- drop_input_cols=self._drop_input_cols,
562
- expected_output_cols_list=(
563
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
564
- ),
557
+ expected_output_cols = (
558
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
565
559
  )
560
+ if isinstance(dataset, DataFrame):
561
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
562
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
563
+ )
564
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
565
+ drop_input_cols=self._drop_input_cols,
566
+ expected_output_cols_list=expected_output_cols,
567
+ example_output_pd_df=example_output_pd_df,
568
+ )
569
+ else:
570
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
571
+ drop_input_cols=self._drop_input_cols,
572
+ expected_output_cols_list=expected_output_cols,
573
+ )
566
574
  self._sklearn_object = fitted_estimator
567
575
  self._is_fitted = True
568
576
  return output_result
@@ -641,12 +649,41 @@ class LocalOutlierFactor(BaseTransformer):
641
649
 
642
650
  return rv
643
651
 
644
- def _align_expected_output_names(
645
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
646
- ) -> List[str]:
652
+ def _align_expected_output(
653
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
654
+ ) -> Tuple[List[str], pd.DataFrame]:
655
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
656
+ and output dataframe with 1 line.
657
+ If the method is fit_predict, run 2 lines of data.
658
+ """
647
659
  # in case the inferred output column names dimension is different
648
660
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
649
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
661
+
662
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
663
+ # so change the minimum of number of rows to 2
664
+ num_examples = 2
665
+ statement_params = telemetry.get_function_usage_statement_params(
666
+ project=_PROJECT,
667
+ subproject=_SUBPROJECT,
668
+ function_name=telemetry.get_statement_params_full_func_name(
669
+ inspect.currentframe(), LocalOutlierFactor.__class__.__name__
670
+ ),
671
+ api_calls=[Session.call],
672
+ custom_tags={"autogen": True} if self._autogenerated else None,
673
+ )
674
+ if output_cols_prefix == "fit_predict_":
675
+ if hasattr(self._sklearn_object, "n_clusters"):
676
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
677
+ num_examples = self._sklearn_object.n_clusters
678
+ elif hasattr(self._sklearn_object, "min_samples"):
679
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
680
+ num_examples = self._sklearn_object.min_samples
681
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
682
+ # LocalOutlierFactor expects n_neighbors <= n_samples
683
+ num_examples = self._sklearn_object.n_neighbors
684
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
685
+ else:
686
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
650
687
 
651
688
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
652
689
  # seen during the fit.
@@ -658,12 +695,14 @@ class LocalOutlierFactor(BaseTransformer):
658
695
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
659
696
  if self.sample_weight_col:
660
697
  output_df_columns_set -= set(self.sample_weight_col)
698
+
661
699
  # if the dimension of inferred output column names is correct; use it
662
700
  if len(expected_output_cols_list) == len(output_df_columns_set):
663
- return expected_output_cols_list
701
+ return expected_output_cols_list, output_df_pd
664
702
  # otherwise, use the sklearn estimator's output
665
703
  else:
666
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
704
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
705
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
667
706
 
668
707
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
669
708
  @telemetry.send_api_usage_telemetry(
@@ -709,7 +748,7 @@ class LocalOutlierFactor(BaseTransformer):
709
748
  drop_input_cols=self._drop_input_cols,
710
749
  expected_output_cols_type="float",
711
750
  )
712
- expected_output_cols = self._align_expected_output_names(
751
+ expected_output_cols, _ = self._align_expected_output(
713
752
  inference_method, dataset, expected_output_cols, output_cols_prefix
714
753
  )
715
754
 
@@ -775,7 +814,7 @@ class LocalOutlierFactor(BaseTransformer):
775
814
  drop_input_cols=self._drop_input_cols,
776
815
  expected_output_cols_type="float",
777
816
  )
778
- expected_output_cols = self._align_expected_output_names(
817
+ expected_output_cols, _ = self._align_expected_output(
779
818
  inference_method, dataset, expected_output_cols, output_cols_prefix
780
819
  )
781
820
  elif isinstance(dataset, pd.DataFrame):
@@ -840,7 +879,7 @@ class LocalOutlierFactor(BaseTransformer):
840
879
  drop_input_cols=self._drop_input_cols,
841
880
  expected_output_cols_type="float",
842
881
  )
843
- expected_output_cols = self._align_expected_output_names(
882
+ expected_output_cols, _ = self._align_expected_output(
844
883
  inference_method, dataset, expected_output_cols, output_cols_prefix
845
884
  )
846
885
 
@@ -907,7 +946,7 @@ class LocalOutlierFactor(BaseTransformer):
907
946
  drop_input_cols = self._drop_input_cols,
908
947
  expected_output_cols_type="float",
909
948
  )
910
- expected_output_cols = self._align_expected_output_names(
949
+ expected_output_cols, _ = self._align_expected_output(
911
950
  inference_method, dataset, expected_output_cols, output_cols_prefix
912
951
  )
913
952
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -486,12 +483,23 @@ class NearestCentroid(BaseTransformer):
486
483
  autogenerated=self._autogenerated,
487
484
  subproject=_SUBPROJECT,
488
485
  )
489
- output_result, fitted_estimator = model_trainer.train_fit_predict(
490
- drop_input_cols=self._drop_input_cols,
491
- expected_output_cols_list=(
492
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
493
- ),
486
+ expected_output_cols = (
487
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
494
488
  )
489
+ if isinstance(dataset, DataFrame):
490
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
491
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
492
+ )
493
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
494
+ drop_input_cols=self._drop_input_cols,
495
+ expected_output_cols_list=expected_output_cols,
496
+ example_output_pd_df=example_output_pd_df,
497
+ )
498
+ else:
499
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
500
+ drop_input_cols=self._drop_input_cols,
501
+ expected_output_cols_list=expected_output_cols,
502
+ )
495
503
  self._sklearn_object = fitted_estimator
496
504
  self._is_fitted = True
497
505
  return output_result
@@ -570,12 +578,41 @@ class NearestCentroid(BaseTransformer):
570
578
 
571
579
  return rv
572
580
 
573
- def _align_expected_output_names(
574
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
575
- ) -> List[str]:
581
+ def _align_expected_output(
582
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
583
+ ) -> Tuple[List[str], pd.DataFrame]:
584
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
585
+ and output dataframe with 1 line.
586
+ If the method is fit_predict, run 2 lines of data.
587
+ """
576
588
  # in case the inferred output column names dimension is different
577
589
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
578
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
590
+
591
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
592
+ # so change the minimum of number of rows to 2
593
+ num_examples = 2
594
+ statement_params = telemetry.get_function_usage_statement_params(
595
+ project=_PROJECT,
596
+ subproject=_SUBPROJECT,
597
+ function_name=telemetry.get_statement_params_full_func_name(
598
+ inspect.currentframe(), NearestCentroid.__class__.__name__
599
+ ),
600
+ api_calls=[Session.call],
601
+ custom_tags={"autogen": True} if self._autogenerated else None,
602
+ )
603
+ if output_cols_prefix == "fit_predict_":
604
+ if hasattr(self._sklearn_object, "n_clusters"):
605
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
606
+ num_examples = self._sklearn_object.n_clusters
607
+ elif hasattr(self._sklearn_object, "min_samples"):
608
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
609
+ num_examples = self._sklearn_object.min_samples
610
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
611
+ # LocalOutlierFactor expects n_neighbors <= n_samples
612
+ num_examples = self._sklearn_object.n_neighbors
613
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
614
+ else:
615
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
579
616
 
580
617
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
581
618
  # seen during the fit.
@@ -587,12 +624,14 @@ class NearestCentroid(BaseTransformer):
587
624
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
588
625
  if self.sample_weight_col:
589
626
  output_df_columns_set -= set(self.sample_weight_col)
627
+
590
628
  # if the dimension of inferred output column names is correct; use it
591
629
  if len(expected_output_cols_list) == len(output_df_columns_set):
592
- return expected_output_cols_list
630
+ return expected_output_cols_list, output_df_pd
593
631
  # otherwise, use the sklearn estimator's output
594
632
  else:
595
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
633
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
634
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
596
635
 
597
636
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
598
637
  @telemetry.send_api_usage_telemetry(
@@ -638,7 +677,7 @@ class NearestCentroid(BaseTransformer):
638
677
  drop_input_cols=self._drop_input_cols,
639
678
  expected_output_cols_type="float",
640
679
  )
641
- expected_output_cols = self._align_expected_output_names(
680
+ expected_output_cols, _ = self._align_expected_output(
642
681
  inference_method, dataset, expected_output_cols, output_cols_prefix
643
682
  )
644
683
 
@@ -704,7 +743,7 @@ class NearestCentroid(BaseTransformer):
704
743
  drop_input_cols=self._drop_input_cols,
705
744
  expected_output_cols_type="float",
706
745
  )
707
- expected_output_cols = self._align_expected_output_names(
746
+ expected_output_cols, _ = self._align_expected_output(
708
747
  inference_method, dataset, expected_output_cols, output_cols_prefix
709
748
  )
710
749
  elif isinstance(dataset, pd.DataFrame):
@@ -767,7 +806,7 @@ class NearestCentroid(BaseTransformer):
767
806
  drop_input_cols=self._drop_input_cols,
768
807
  expected_output_cols_type="float",
769
808
  )
770
- expected_output_cols = self._align_expected_output_names(
809
+ expected_output_cols, _ = self._align_expected_output(
771
810
  inference_method, dataset, expected_output_cols, output_cols_prefix
772
811
  )
773
812
 
@@ -832,7 +871,7 @@ class NearestCentroid(BaseTransformer):
832
871
  drop_input_cols = self._drop_input_cols,
833
872
  expected_output_cols_type="float",
834
873
  )
835
- expected_output_cols = self._align_expected_output_names(
874
+ expected_output_cols, _ = self._align_expected_output(
836
875
  inference_method, dataset, expected_output_cols, output_cols_prefix
837
876
  )
838
877
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -536,12 +533,23 @@ class NearestNeighbors(BaseTransformer):
536
533
  autogenerated=self._autogenerated,
537
534
  subproject=_SUBPROJECT,
538
535
  )
539
- output_result, fitted_estimator = model_trainer.train_fit_predict(
540
- drop_input_cols=self._drop_input_cols,
541
- expected_output_cols_list=(
542
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
543
- ),
536
+ expected_output_cols = (
537
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
544
538
  )
539
+ if isinstance(dataset, DataFrame):
540
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
541
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
542
+ )
543
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
544
+ drop_input_cols=self._drop_input_cols,
545
+ expected_output_cols_list=expected_output_cols,
546
+ example_output_pd_df=example_output_pd_df,
547
+ )
548
+ else:
549
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
550
+ drop_input_cols=self._drop_input_cols,
551
+ expected_output_cols_list=expected_output_cols,
552
+ )
545
553
  self._sklearn_object = fitted_estimator
546
554
  self._is_fitted = True
547
555
  return output_result
@@ -620,12 +628,41 @@ class NearestNeighbors(BaseTransformer):
620
628
 
621
629
  return rv
622
630
 
623
- def _align_expected_output_names(
624
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
625
- ) -> List[str]:
631
+ def _align_expected_output(
632
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
633
+ ) -> Tuple[List[str], pd.DataFrame]:
634
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
635
+ and output dataframe with 1 line.
636
+ If the method is fit_predict, run 2 lines of data.
637
+ """
626
638
  # in case the inferred output column names dimension is different
627
639
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
628
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
640
+
641
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
642
+ # so change the minimum of number of rows to 2
643
+ num_examples = 2
644
+ statement_params = telemetry.get_function_usage_statement_params(
645
+ project=_PROJECT,
646
+ subproject=_SUBPROJECT,
647
+ function_name=telemetry.get_statement_params_full_func_name(
648
+ inspect.currentframe(), NearestNeighbors.__class__.__name__
649
+ ),
650
+ api_calls=[Session.call],
651
+ custom_tags={"autogen": True} if self._autogenerated else None,
652
+ )
653
+ if output_cols_prefix == "fit_predict_":
654
+ if hasattr(self._sklearn_object, "n_clusters"):
655
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
656
+ num_examples = self._sklearn_object.n_clusters
657
+ elif hasattr(self._sklearn_object, "min_samples"):
658
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
659
+ num_examples = self._sklearn_object.min_samples
660
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
661
+ # LocalOutlierFactor expects n_neighbors <= n_samples
662
+ num_examples = self._sklearn_object.n_neighbors
663
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
664
+ else:
665
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
629
666
 
630
667
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
631
668
  # seen during the fit.
@@ -637,12 +674,14 @@ class NearestNeighbors(BaseTransformer):
637
674
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
638
675
  if self.sample_weight_col:
639
676
  output_df_columns_set -= set(self.sample_weight_col)
677
+
640
678
  # if the dimension of inferred output column names is correct; use it
641
679
  if len(expected_output_cols_list) == len(output_df_columns_set):
642
- return expected_output_cols_list
680
+ return expected_output_cols_list, output_df_pd
643
681
  # otherwise, use the sklearn estimator's output
644
682
  else:
645
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
683
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
684
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
646
685
 
647
686
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
648
687
  @telemetry.send_api_usage_telemetry(
@@ -688,7 +727,7 @@ class NearestNeighbors(BaseTransformer):
688
727
  drop_input_cols=self._drop_input_cols,
689
728
  expected_output_cols_type="float",
690
729
  )
691
- expected_output_cols = self._align_expected_output_names(
730
+ expected_output_cols, _ = self._align_expected_output(
692
731
  inference_method, dataset, expected_output_cols, output_cols_prefix
693
732
  )
694
733
 
@@ -754,7 +793,7 @@ class NearestNeighbors(BaseTransformer):
754
793
  drop_input_cols=self._drop_input_cols,
755
794
  expected_output_cols_type="float",
756
795
  )
757
- expected_output_cols = self._align_expected_output_names(
796
+ expected_output_cols, _ = self._align_expected_output(
758
797
  inference_method, dataset, expected_output_cols, output_cols_prefix
759
798
  )
760
799
  elif isinstance(dataset, pd.DataFrame):
@@ -817,7 +856,7 @@ class NearestNeighbors(BaseTransformer):
817
856
  drop_input_cols=self._drop_input_cols,
818
857
  expected_output_cols_type="float",
819
858
  )
820
- expected_output_cols = self._align_expected_output_names(
859
+ expected_output_cols, _ = self._align_expected_output(
821
860
  inference_method, dataset, expected_output_cols, output_cols_prefix
822
861
  )
823
862
 
@@ -882,7 +921,7 @@ class NearestNeighbors(BaseTransformer):
882
921
  drop_input_cols = self._drop_input_cols,
883
922
  expected_output_cols_type="float",
884
923
  )
885
- expected_output_cols = self._align_expected_output_names(
924
+ expected_output_cols, _ = self._align_expected_output(
886
925
  inference_method, dataset, expected_output_cols, output_cols_prefix
887
926
  )
888
927