snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212) hide show
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -592,12 +589,23 @@ class KernelPCA(BaseTransformer):
592
589
  autogenerated=self._autogenerated,
593
590
  subproject=_SUBPROJECT,
594
591
  )
595
- output_result, fitted_estimator = model_trainer.train_fit_predict(
596
- drop_input_cols=self._drop_input_cols,
597
- expected_output_cols_list=(
598
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
599
- ),
592
+ expected_output_cols = (
593
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
600
594
  )
595
+ if isinstance(dataset, DataFrame):
596
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
597
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
598
+ )
599
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
600
+ drop_input_cols=self._drop_input_cols,
601
+ expected_output_cols_list=expected_output_cols,
602
+ example_output_pd_df=example_output_pd_df,
603
+ )
604
+ else:
605
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
606
+ drop_input_cols=self._drop_input_cols,
607
+ expected_output_cols_list=expected_output_cols,
608
+ )
601
609
  self._sklearn_object = fitted_estimator
602
610
  self._is_fitted = True
603
611
  return output_result
@@ -678,12 +686,41 @@ class KernelPCA(BaseTransformer):
678
686
 
679
687
  return rv
680
688
 
681
- def _align_expected_output_names(
682
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
683
- ) -> List[str]:
689
+ def _align_expected_output(
690
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
691
+ ) -> Tuple[List[str], pd.DataFrame]:
692
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
693
+ and output dataframe with 1 line.
694
+ If the method is fit_predict, run 2 lines of data.
695
+ """
684
696
  # in case the inferred output column names dimension is different
685
697
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
686
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
698
+
699
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
700
+ # so change the minimum of number of rows to 2
701
+ num_examples = 2
702
+ statement_params = telemetry.get_function_usage_statement_params(
703
+ project=_PROJECT,
704
+ subproject=_SUBPROJECT,
705
+ function_name=telemetry.get_statement_params_full_func_name(
706
+ inspect.currentframe(), KernelPCA.__class__.__name__
707
+ ),
708
+ api_calls=[Session.call],
709
+ custom_tags={"autogen": True} if self._autogenerated else None,
710
+ )
711
+ if output_cols_prefix == "fit_predict_":
712
+ if hasattr(self._sklearn_object, "n_clusters"):
713
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
714
+ num_examples = self._sklearn_object.n_clusters
715
+ elif hasattr(self._sklearn_object, "min_samples"):
716
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
717
+ num_examples = self._sklearn_object.min_samples
718
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
719
+ # LocalOutlierFactor expects n_neighbors <= n_samples
720
+ num_examples = self._sklearn_object.n_neighbors
721
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
722
+ else:
723
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
687
724
 
688
725
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
689
726
  # seen during the fit.
@@ -695,12 +732,14 @@ class KernelPCA(BaseTransformer):
695
732
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
696
733
  if self.sample_weight_col:
697
734
  output_df_columns_set -= set(self.sample_weight_col)
735
+
698
736
  # if the dimension of inferred output column names is correct; use it
699
737
  if len(expected_output_cols_list) == len(output_df_columns_set):
700
- return expected_output_cols_list
738
+ return expected_output_cols_list, output_df_pd
701
739
  # otherwise, use the sklearn estimator's output
702
740
  else:
703
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
741
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
742
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
704
743
 
705
744
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
706
745
  @telemetry.send_api_usage_telemetry(
@@ -746,7 +785,7 @@ class KernelPCA(BaseTransformer):
746
785
  drop_input_cols=self._drop_input_cols,
747
786
  expected_output_cols_type="float",
748
787
  )
749
- expected_output_cols = self._align_expected_output_names(
788
+ expected_output_cols, _ = self._align_expected_output(
750
789
  inference_method, dataset, expected_output_cols, output_cols_prefix
751
790
  )
752
791
 
@@ -812,7 +851,7 @@ class KernelPCA(BaseTransformer):
812
851
  drop_input_cols=self._drop_input_cols,
813
852
  expected_output_cols_type="float",
814
853
  )
815
- expected_output_cols = self._align_expected_output_names(
854
+ expected_output_cols, _ = self._align_expected_output(
816
855
  inference_method, dataset, expected_output_cols, output_cols_prefix
817
856
  )
818
857
  elif isinstance(dataset, pd.DataFrame):
@@ -875,7 +914,7 @@ class KernelPCA(BaseTransformer):
875
914
  drop_input_cols=self._drop_input_cols,
876
915
  expected_output_cols_type="float",
877
916
  )
878
- expected_output_cols = self._align_expected_output_names(
917
+ expected_output_cols, _ = self._align_expected_output(
879
918
  inference_method, dataset, expected_output_cols, output_cols_prefix
880
919
  )
881
920
 
@@ -940,7 +979,7 @@ class KernelPCA(BaseTransformer):
940
979
  drop_input_cols = self._drop_input_cols,
941
980
  expected_output_cols_type="float",
942
981
  )
943
- expected_output_cols = self._align_expected_output_names(
982
+ expected_output_cols, _ = self._align_expected_output(
944
983
  inference_method, dataset, expected_output_cols, output_cols_prefix
945
984
  )
946
985
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -614,12 +611,23 @@ class MiniBatchDictionaryLearning(BaseTransformer):
614
611
  autogenerated=self._autogenerated,
615
612
  subproject=_SUBPROJECT,
616
613
  )
617
- output_result, fitted_estimator = model_trainer.train_fit_predict(
618
- drop_input_cols=self._drop_input_cols,
619
- expected_output_cols_list=(
620
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
621
- ),
614
+ expected_output_cols = (
615
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
622
616
  )
617
+ if isinstance(dataset, DataFrame):
618
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
619
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
620
+ )
621
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
622
+ drop_input_cols=self._drop_input_cols,
623
+ expected_output_cols_list=expected_output_cols,
624
+ example_output_pd_df=example_output_pd_df,
625
+ )
626
+ else:
627
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
628
+ drop_input_cols=self._drop_input_cols,
629
+ expected_output_cols_list=expected_output_cols,
630
+ )
623
631
  self._sklearn_object = fitted_estimator
624
632
  self._is_fitted = True
625
633
  return output_result
@@ -700,12 +708,41 @@ class MiniBatchDictionaryLearning(BaseTransformer):
700
708
 
701
709
  return rv
702
710
 
703
- def _align_expected_output_names(
704
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
705
- ) -> List[str]:
711
+ def _align_expected_output(
712
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
713
+ ) -> Tuple[List[str], pd.DataFrame]:
714
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
715
+ and output dataframe with 1 line.
716
+ If the method is fit_predict, run 2 lines of data.
717
+ """
706
718
  # in case the inferred output column names dimension is different
707
719
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
708
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
720
+
721
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
722
+ # so change the minimum of number of rows to 2
723
+ num_examples = 2
724
+ statement_params = telemetry.get_function_usage_statement_params(
725
+ project=_PROJECT,
726
+ subproject=_SUBPROJECT,
727
+ function_name=telemetry.get_statement_params_full_func_name(
728
+ inspect.currentframe(), MiniBatchDictionaryLearning.__class__.__name__
729
+ ),
730
+ api_calls=[Session.call],
731
+ custom_tags={"autogen": True} if self._autogenerated else None,
732
+ )
733
+ if output_cols_prefix == "fit_predict_":
734
+ if hasattr(self._sklearn_object, "n_clusters"):
735
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
736
+ num_examples = self._sklearn_object.n_clusters
737
+ elif hasattr(self._sklearn_object, "min_samples"):
738
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
739
+ num_examples = self._sklearn_object.min_samples
740
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
741
+ # LocalOutlierFactor expects n_neighbors <= n_samples
742
+ num_examples = self._sklearn_object.n_neighbors
743
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
744
+ else:
745
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
709
746
 
710
747
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
711
748
  # seen during the fit.
@@ -717,12 +754,14 @@ class MiniBatchDictionaryLearning(BaseTransformer):
717
754
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
718
755
  if self.sample_weight_col:
719
756
  output_df_columns_set -= set(self.sample_weight_col)
757
+
720
758
  # if the dimension of inferred output column names is correct; use it
721
759
  if len(expected_output_cols_list) == len(output_df_columns_set):
722
- return expected_output_cols_list
760
+ return expected_output_cols_list, output_df_pd
723
761
  # otherwise, use the sklearn estimator's output
724
762
  else:
725
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
763
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
764
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
726
765
 
727
766
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
728
767
  @telemetry.send_api_usage_telemetry(
@@ -768,7 +807,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
768
807
  drop_input_cols=self._drop_input_cols,
769
808
  expected_output_cols_type="float",
770
809
  )
771
- expected_output_cols = self._align_expected_output_names(
810
+ expected_output_cols, _ = self._align_expected_output(
772
811
  inference_method, dataset, expected_output_cols, output_cols_prefix
773
812
  )
774
813
 
@@ -834,7 +873,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
834
873
  drop_input_cols=self._drop_input_cols,
835
874
  expected_output_cols_type="float",
836
875
  )
837
- expected_output_cols = self._align_expected_output_names(
876
+ expected_output_cols, _ = self._align_expected_output(
838
877
  inference_method, dataset, expected_output_cols, output_cols_prefix
839
878
  )
840
879
  elif isinstance(dataset, pd.DataFrame):
@@ -897,7 +936,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
897
936
  drop_input_cols=self._drop_input_cols,
898
937
  expected_output_cols_type="float",
899
938
  )
900
- expected_output_cols = self._align_expected_output_names(
939
+ expected_output_cols, _ = self._align_expected_output(
901
940
  inference_method, dataset, expected_output_cols, output_cols_prefix
902
941
  )
903
942
 
@@ -962,7 +1001,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
962
1001
  drop_input_cols = self._drop_input_cols,
963
1002
  expected_output_cols_type="float",
964
1003
  )
965
- expected_output_cols = self._align_expected_output_names(
1004
+ expected_output_cols, _ = self._align_expected_output(
966
1005
  inference_method, dataset, expected_output_cols, output_cols_prefix
967
1006
  )
968
1007
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -559,12 +556,23 @@ class MiniBatchSparsePCA(BaseTransformer):
559
556
  autogenerated=self._autogenerated,
560
557
  subproject=_SUBPROJECT,
561
558
  )
562
- output_result, fitted_estimator = model_trainer.train_fit_predict(
563
- drop_input_cols=self._drop_input_cols,
564
- expected_output_cols_list=(
565
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
566
- ),
559
+ expected_output_cols = (
560
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
567
561
  )
562
+ if isinstance(dataset, DataFrame):
563
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
564
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
565
+ )
566
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
567
+ drop_input_cols=self._drop_input_cols,
568
+ expected_output_cols_list=expected_output_cols,
569
+ example_output_pd_df=example_output_pd_df,
570
+ )
571
+ else:
572
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
573
+ drop_input_cols=self._drop_input_cols,
574
+ expected_output_cols_list=expected_output_cols,
575
+ )
568
576
  self._sklearn_object = fitted_estimator
569
577
  self._is_fitted = True
570
578
  return output_result
@@ -645,12 +653,41 @@ class MiniBatchSparsePCA(BaseTransformer):
645
653
 
646
654
  return rv
647
655
 
648
- def _align_expected_output_names(
649
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
650
- ) -> List[str]:
656
+ def _align_expected_output(
657
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
658
+ ) -> Tuple[List[str], pd.DataFrame]:
659
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
660
+ and output dataframe with 1 line.
661
+ If the method is fit_predict, run 2 lines of data.
662
+ """
651
663
  # in case the inferred output column names dimension is different
652
664
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
653
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
665
+
666
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
667
+ # so change the minimum of number of rows to 2
668
+ num_examples = 2
669
+ statement_params = telemetry.get_function_usage_statement_params(
670
+ project=_PROJECT,
671
+ subproject=_SUBPROJECT,
672
+ function_name=telemetry.get_statement_params_full_func_name(
673
+ inspect.currentframe(), MiniBatchSparsePCA.__class__.__name__
674
+ ),
675
+ api_calls=[Session.call],
676
+ custom_tags={"autogen": True} if self._autogenerated else None,
677
+ )
678
+ if output_cols_prefix == "fit_predict_":
679
+ if hasattr(self._sklearn_object, "n_clusters"):
680
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
681
+ num_examples = self._sklearn_object.n_clusters
682
+ elif hasattr(self._sklearn_object, "min_samples"):
683
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
684
+ num_examples = self._sklearn_object.min_samples
685
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
686
+ # LocalOutlierFactor expects n_neighbors <= n_samples
687
+ num_examples = self._sklearn_object.n_neighbors
688
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
689
+ else:
690
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
654
691
 
655
692
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
656
693
  # seen during the fit.
@@ -662,12 +699,14 @@ class MiniBatchSparsePCA(BaseTransformer):
662
699
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
663
700
  if self.sample_weight_col:
664
701
  output_df_columns_set -= set(self.sample_weight_col)
702
+
665
703
  # if the dimension of inferred output column names is correct; use it
666
704
  if len(expected_output_cols_list) == len(output_df_columns_set):
667
- return expected_output_cols_list
705
+ return expected_output_cols_list, output_df_pd
668
706
  # otherwise, use the sklearn estimator's output
669
707
  else:
670
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
708
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
709
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
671
710
 
672
711
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
673
712
  @telemetry.send_api_usage_telemetry(
@@ -713,7 +752,7 @@ class MiniBatchSparsePCA(BaseTransformer):
713
752
  drop_input_cols=self._drop_input_cols,
714
753
  expected_output_cols_type="float",
715
754
  )
716
- expected_output_cols = self._align_expected_output_names(
755
+ expected_output_cols, _ = self._align_expected_output(
717
756
  inference_method, dataset, expected_output_cols, output_cols_prefix
718
757
  )
719
758
 
@@ -779,7 +818,7 @@ class MiniBatchSparsePCA(BaseTransformer):
779
818
  drop_input_cols=self._drop_input_cols,
780
819
  expected_output_cols_type="float",
781
820
  )
782
- expected_output_cols = self._align_expected_output_names(
821
+ expected_output_cols, _ = self._align_expected_output(
783
822
  inference_method, dataset, expected_output_cols, output_cols_prefix
784
823
  )
785
824
  elif isinstance(dataset, pd.DataFrame):
@@ -842,7 +881,7 @@ class MiniBatchSparsePCA(BaseTransformer):
842
881
  drop_input_cols=self._drop_input_cols,
843
882
  expected_output_cols_type="float",
844
883
  )
845
- expected_output_cols = self._align_expected_output_names(
884
+ expected_output_cols, _ = self._align_expected_output(
846
885
  inference_method, dataset, expected_output_cols, output_cols_prefix
847
886
  )
848
887
 
@@ -907,7 +946,7 @@ class MiniBatchSparsePCA(BaseTransformer):
907
946
  drop_input_cols = self._drop_input_cols,
908
947
  expected_output_cols_type="float",
909
948
  )
910
- expected_output_cols = self._align_expected_output_names(
949
+ expected_output_cols, _ = self._align_expected_output(
911
950
  inference_method, dataset, expected_output_cols, output_cols_prefix
912
951
  )
913
952
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -561,12 +558,23 @@ class PCA(BaseTransformer):
561
558
  autogenerated=self._autogenerated,
562
559
  subproject=_SUBPROJECT,
563
560
  )
564
- output_result, fitted_estimator = model_trainer.train_fit_predict(
565
- drop_input_cols=self._drop_input_cols,
566
- expected_output_cols_list=(
567
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
568
- ),
561
+ expected_output_cols = (
562
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
569
563
  )
564
+ if isinstance(dataset, DataFrame):
565
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
566
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
567
+ )
568
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
569
+ drop_input_cols=self._drop_input_cols,
570
+ expected_output_cols_list=expected_output_cols,
571
+ example_output_pd_df=example_output_pd_df,
572
+ )
573
+ else:
574
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
575
+ drop_input_cols=self._drop_input_cols,
576
+ expected_output_cols_list=expected_output_cols,
577
+ )
570
578
  self._sklearn_object = fitted_estimator
571
579
  self._is_fitted = True
572
580
  return output_result
@@ -647,12 +655,41 @@ class PCA(BaseTransformer):
647
655
 
648
656
  return rv
649
657
 
650
- def _align_expected_output_names(
651
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
652
- ) -> List[str]:
658
+ def _align_expected_output(
659
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
660
+ ) -> Tuple[List[str], pd.DataFrame]:
661
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
662
+ and output dataframe with 1 line.
663
+ If the method is fit_predict, run 2 lines of data.
664
+ """
653
665
  # in case the inferred output column names dimension is different
654
666
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
655
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
667
+
668
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
669
+ # so change the minimum of number of rows to 2
670
+ num_examples = 2
671
+ statement_params = telemetry.get_function_usage_statement_params(
672
+ project=_PROJECT,
673
+ subproject=_SUBPROJECT,
674
+ function_name=telemetry.get_statement_params_full_func_name(
675
+ inspect.currentframe(), PCA.__class__.__name__
676
+ ),
677
+ api_calls=[Session.call],
678
+ custom_tags={"autogen": True} if self._autogenerated else None,
679
+ )
680
+ if output_cols_prefix == "fit_predict_":
681
+ if hasattr(self._sklearn_object, "n_clusters"):
682
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
683
+ num_examples = self._sklearn_object.n_clusters
684
+ elif hasattr(self._sklearn_object, "min_samples"):
685
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
686
+ num_examples = self._sklearn_object.min_samples
687
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
688
+ # LocalOutlierFactor expects n_neighbors <= n_samples
689
+ num_examples = self._sklearn_object.n_neighbors
690
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
691
+ else:
692
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
656
693
 
657
694
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
658
695
  # seen during the fit.
@@ -664,12 +701,14 @@ class PCA(BaseTransformer):
664
701
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
665
702
  if self.sample_weight_col:
666
703
  output_df_columns_set -= set(self.sample_weight_col)
704
+
667
705
  # if the dimension of inferred output column names is correct; use it
668
706
  if len(expected_output_cols_list) == len(output_df_columns_set):
669
- return expected_output_cols_list
707
+ return expected_output_cols_list, output_df_pd
670
708
  # otherwise, use the sklearn estimator's output
671
709
  else:
672
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
710
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
711
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
673
712
 
674
713
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
675
714
  @telemetry.send_api_usage_telemetry(
@@ -715,7 +754,7 @@ class PCA(BaseTransformer):
715
754
  drop_input_cols=self._drop_input_cols,
716
755
  expected_output_cols_type="float",
717
756
  )
718
- expected_output_cols = self._align_expected_output_names(
757
+ expected_output_cols, _ = self._align_expected_output(
719
758
  inference_method, dataset, expected_output_cols, output_cols_prefix
720
759
  )
721
760
 
@@ -781,7 +820,7 @@ class PCA(BaseTransformer):
781
820
  drop_input_cols=self._drop_input_cols,
782
821
  expected_output_cols_type="float",
783
822
  )
784
- expected_output_cols = self._align_expected_output_names(
823
+ expected_output_cols, _ = self._align_expected_output(
785
824
  inference_method, dataset, expected_output_cols, output_cols_prefix
786
825
  )
787
826
  elif isinstance(dataset, pd.DataFrame):
@@ -844,7 +883,7 @@ class PCA(BaseTransformer):
844
883
  drop_input_cols=self._drop_input_cols,
845
884
  expected_output_cols_type="float",
846
885
  )
847
- expected_output_cols = self._align_expected_output_names(
886
+ expected_output_cols, _ = self._align_expected_output(
848
887
  inference_method, dataset, expected_output_cols, output_cols_prefix
849
888
  )
850
889
 
@@ -911,7 +950,7 @@ class PCA(BaseTransformer):
911
950
  drop_input_cols = self._drop_input_cols,
912
951
  expected_output_cols_type="float",
913
952
  )
914
- expected_output_cols = self._align_expected_output_names(
953
+ expected_output_cols, _ = self._align_expected_output(
915
954
  inference_method, dataset, expected_output_cols, output_cols_prefix
916
955
  )
917
956