snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212) hide show
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -645,12 +642,23 @@ class HistGradientBoostingClassifier(BaseTransformer):
645
642
  autogenerated=self._autogenerated,
646
643
  subproject=_SUBPROJECT,
647
644
  )
648
- output_result, fitted_estimator = model_trainer.train_fit_predict(
649
- drop_input_cols=self._drop_input_cols,
650
- expected_output_cols_list=(
651
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
652
- ),
645
+ expected_output_cols = (
646
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
653
647
  )
648
+ if isinstance(dataset, DataFrame):
649
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
650
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
651
+ )
652
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
653
+ drop_input_cols=self._drop_input_cols,
654
+ expected_output_cols_list=expected_output_cols,
655
+ example_output_pd_df=example_output_pd_df,
656
+ )
657
+ else:
658
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
659
+ drop_input_cols=self._drop_input_cols,
660
+ expected_output_cols_list=expected_output_cols,
661
+ )
654
662
  self._sklearn_object = fitted_estimator
655
663
  self._is_fitted = True
656
664
  return output_result
@@ -729,12 +737,41 @@ class HistGradientBoostingClassifier(BaseTransformer):
729
737
 
730
738
  return rv
731
739
 
732
- def _align_expected_output_names(
733
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
734
- ) -> List[str]:
740
+ def _align_expected_output(
741
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
742
+ ) -> Tuple[List[str], pd.DataFrame]:
743
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
744
+ and output dataframe with 1 line.
745
+ If the method is fit_predict, run 2 lines of data.
746
+ """
735
747
  # in case the inferred output column names dimension is different
736
748
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
737
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
749
+
750
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
751
+ # so change the minimum of number of rows to 2
752
+ num_examples = 2
753
+ statement_params = telemetry.get_function_usage_statement_params(
754
+ project=_PROJECT,
755
+ subproject=_SUBPROJECT,
756
+ function_name=telemetry.get_statement_params_full_func_name(
757
+ inspect.currentframe(), HistGradientBoostingClassifier.__class__.__name__
758
+ ),
759
+ api_calls=[Session.call],
760
+ custom_tags={"autogen": True} if self._autogenerated else None,
761
+ )
762
+ if output_cols_prefix == "fit_predict_":
763
+ if hasattr(self._sklearn_object, "n_clusters"):
764
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
765
+ num_examples = self._sklearn_object.n_clusters
766
+ elif hasattr(self._sklearn_object, "min_samples"):
767
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
768
+ num_examples = self._sklearn_object.min_samples
769
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
770
+ # LocalOutlierFactor expects n_neighbors <= n_samples
771
+ num_examples = self._sklearn_object.n_neighbors
772
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
773
+ else:
774
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
738
775
 
739
776
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
740
777
  # seen during the fit.
@@ -746,12 +783,14 @@ class HistGradientBoostingClassifier(BaseTransformer):
746
783
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
747
784
  if self.sample_weight_col:
748
785
  output_df_columns_set -= set(self.sample_weight_col)
786
+
749
787
  # if the dimension of inferred output column names is correct; use it
750
788
  if len(expected_output_cols_list) == len(output_df_columns_set):
751
- return expected_output_cols_list
789
+ return expected_output_cols_list, output_df_pd
752
790
  # otherwise, use the sklearn estimator's output
753
791
  else:
754
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
792
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
793
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
755
794
 
756
795
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
757
796
  @telemetry.send_api_usage_telemetry(
@@ -799,7 +838,7 @@ class HistGradientBoostingClassifier(BaseTransformer):
799
838
  drop_input_cols=self._drop_input_cols,
800
839
  expected_output_cols_type="float",
801
840
  )
802
- expected_output_cols = self._align_expected_output_names(
841
+ expected_output_cols, _ = self._align_expected_output(
803
842
  inference_method, dataset, expected_output_cols, output_cols_prefix
804
843
  )
805
844
 
@@ -867,7 +906,7 @@ class HistGradientBoostingClassifier(BaseTransformer):
867
906
  drop_input_cols=self._drop_input_cols,
868
907
  expected_output_cols_type="float",
869
908
  )
870
- expected_output_cols = self._align_expected_output_names(
909
+ expected_output_cols, _ = self._align_expected_output(
871
910
  inference_method, dataset, expected_output_cols, output_cols_prefix
872
911
  )
873
912
  elif isinstance(dataset, pd.DataFrame):
@@ -932,7 +971,7 @@ class HistGradientBoostingClassifier(BaseTransformer):
932
971
  drop_input_cols=self._drop_input_cols,
933
972
  expected_output_cols_type="float",
934
973
  )
935
- expected_output_cols = self._align_expected_output_names(
974
+ expected_output_cols, _ = self._align_expected_output(
936
975
  inference_method, dataset, expected_output_cols, output_cols_prefix
937
976
  )
938
977
 
@@ -997,7 +1036,7 @@ class HistGradientBoostingClassifier(BaseTransformer):
997
1036
  drop_input_cols = self._drop_input_cols,
998
1037
  expected_output_cols_type="float",
999
1038
  )
1000
- expected_output_cols = self._align_expected_output_names(
1039
+ expected_output_cols, _ = self._align_expected_output(
1001
1040
  inference_method, dataset, expected_output_cols, output_cols_prefix
1002
1041
  )
1003
1042
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -636,12 +633,23 @@ class HistGradientBoostingRegressor(BaseTransformer):
636
633
  autogenerated=self._autogenerated,
637
634
  subproject=_SUBPROJECT,
638
635
  )
639
- output_result, fitted_estimator = model_trainer.train_fit_predict(
640
- drop_input_cols=self._drop_input_cols,
641
- expected_output_cols_list=(
642
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
643
- ),
636
+ expected_output_cols = (
637
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
644
638
  )
639
+ if isinstance(dataset, DataFrame):
640
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
641
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
642
+ )
643
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
644
+ drop_input_cols=self._drop_input_cols,
645
+ expected_output_cols_list=expected_output_cols,
646
+ example_output_pd_df=example_output_pd_df,
647
+ )
648
+ else:
649
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
650
+ drop_input_cols=self._drop_input_cols,
651
+ expected_output_cols_list=expected_output_cols,
652
+ )
645
653
  self._sklearn_object = fitted_estimator
646
654
  self._is_fitted = True
647
655
  return output_result
@@ -720,12 +728,41 @@ class HistGradientBoostingRegressor(BaseTransformer):
720
728
 
721
729
  return rv
722
730
 
723
- def _align_expected_output_names(
724
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
725
- ) -> List[str]:
731
+ def _align_expected_output(
732
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
733
+ ) -> Tuple[List[str], pd.DataFrame]:
734
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
735
+ and output dataframe with 1 line.
736
+ If the method is fit_predict, run 2 lines of data.
737
+ """
726
738
  # in case the inferred output column names dimension is different
727
739
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
728
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
740
+
741
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
742
+ # so change the minimum of number of rows to 2
743
+ num_examples = 2
744
+ statement_params = telemetry.get_function_usage_statement_params(
745
+ project=_PROJECT,
746
+ subproject=_SUBPROJECT,
747
+ function_name=telemetry.get_statement_params_full_func_name(
748
+ inspect.currentframe(), HistGradientBoostingRegressor.__class__.__name__
749
+ ),
750
+ api_calls=[Session.call],
751
+ custom_tags={"autogen": True} if self._autogenerated else None,
752
+ )
753
+ if output_cols_prefix == "fit_predict_":
754
+ if hasattr(self._sklearn_object, "n_clusters"):
755
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
756
+ num_examples = self._sklearn_object.n_clusters
757
+ elif hasattr(self._sklearn_object, "min_samples"):
758
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
759
+ num_examples = self._sklearn_object.min_samples
760
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
761
+ # LocalOutlierFactor expects n_neighbors <= n_samples
762
+ num_examples = self._sklearn_object.n_neighbors
763
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
764
+ else:
765
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
729
766
 
730
767
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
731
768
  # seen during the fit.
@@ -737,12 +774,14 @@ class HistGradientBoostingRegressor(BaseTransformer):
737
774
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
738
775
  if self.sample_weight_col:
739
776
  output_df_columns_set -= set(self.sample_weight_col)
777
+
740
778
  # if the dimension of inferred output column names is correct; use it
741
779
  if len(expected_output_cols_list) == len(output_df_columns_set):
742
- return expected_output_cols_list
780
+ return expected_output_cols_list, output_df_pd
743
781
  # otherwise, use the sklearn estimator's output
744
782
  else:
745
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
783
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
784
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
746
785
 
747
786
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
748
787
  @telemetry.send_api_usage_telemetry(
@@ -788,7 +827,7 @@ class HistGradientBoostingRegressor(BaseTransformer):
788
827
  drop_input_cols=self._drop_input_cols,
789
828
  expected_output_cols_type="float",
790
829
  )
791
- expected_output_cols = self._align_expected_output_names(
830
+ expected_output_cols, _ = self._align_expected_output(
792
831
  inference_method, dataset, expected_output_cols, output_cols_prefix
793
832
  )
794
833
 
@@ -854,7 +893,7 @@ class HistGradientBoostingRegressor(BaseTransformer):
854
893
  drop_input_cols=self._drop_input_cols,
855
894
  expected_output_cols_type="float",
856
895
  )
857
- expected_output_cols = self._align_expected_output_names(
896
+ expected_output_cols, _ = self._align_expected_output(
858
897
  inference_method, dataset, expected_output_cols, output_cols_prefix
859
898
  )
860
899
  elif isinstance(dataset, pd.DataFrame):
@@ -917,7 +956,7 @@ class HistGradientBoostingRegressor(BaseTransformer):
917
956
  drop_input_cols=self._drop_input_cols,
918
957
  expected_output_cols_type="float",
919
958
  )
920
- expected_output_cols = self._align_expected_output_names(
959
+ expected_output_cols, _ = self._align_expected_output(
921
960
  inference_method, dataset, expected_output_cols, output_cols_prefix
922
961
  )
923
962
 
@@ -982,7 +1021,7 @@ class HistGradientBoostingRegressor(BaseTransformer):
982
1021
  drop_input_cols = self._drop_input_cols,
983
1022
  expected_output_cols_type="float",
984
1023
  )
985
- expected_output_cols = self._align_expected_output_names(
1024
+ expected_output_cols, _ = self._align_expected_output(
986
1025
  inference_method, dataset, expected_output_cols, output_cols_prefix
987
1026
  )
988
1027
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -540,12 +537,23 @@ class IsolationForest(BaseTransformer):
540
537
  autogenerated=self._autogenerated,
541
538
  subproject=_SUBPROJECT,
542
539
  )
543
- output_result, fitted_estimator = model_trainer.train_fit_predict(
544
- drop_input_cols=self._drop_input_cols,
545
- expected_output_cols_list=(
546
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
547
- ),
540
+ expected_output_cols = (
541
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
548
542
  )
543
+ if isinstance(dataset, DataFrame):
544
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
545
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
546
+ )
547
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
548
+ drop_input_cols=self._drop_input_cols,
549
+ expected_output_cols_list=expected_output_cols,
550
+ example_output_pd_df=example_output_pd_df,
551
+ )
552
+ else:
553
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
554
+ drop_input_cols=self._drop_input_cols,
555
+ expected_output_cols_list=expected_output_cols,
556
+ )
549
557
  self._sklearn_object = fitted_estimator
550
558
  self._is_fitted = True
551
559
  return output_result
@@ -624,12 +632,41 @@ class IsolationForest(BaseTransformer):
624
632
 
625
633
  return rv
626
634
 
627
- def _align_expected_output_names(
628
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
629
- ) -> List[str]:
635
+ def _align_expected_output(
636
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
637
+ ) -> Tuple[List[str], pd.DataFrame]:
638
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
639
+ and output dataframe with 1 line.
640
+ If the method is fit_predict, run 2 lines of data.
641
+ """
630
642
  # in case the inferred output column names dimension is different
631
643
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
632
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
644
+
645
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
646
+ # so change the minimum of number of rows to 2
647
+ num_examples = 2
648
+ statement_params = telemetry.get_function_usage_statement_params(
649
+ project=_PROJECT,
650
+ subproject=_SUBPROJECT,
651
+ function_name=telemetry.get_statement_params_full_func_name(
652
+ inspect.currentframe(), IsolationForest.__class__.__name__
653
+ ),
654
+ api_calls=[Session.call],
655
+ custom_tags={"autogen": True} if self._autogenerated else None,
656
+ )
657
+ if output_cols_prefix == "fit_predict_":
658
+ if hasattr(self._sklearn_object, "n_clusters"):
659
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
660
+ num_examples = self._sklearn_object.n_clusters
661
+ elif hasattr(self._sklearn_object, "min_samples"):
662
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
663
+ num_examples = self._sklearn_object.min_samples
664
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
665
+ # LocalOutlierFactor expects n_neighbors <= n_samples
666
+ num_examples = self._sklearn_object.n_neighbors
667
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
668
+ else:
669
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
633
670
 
634
671
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
635
672
  # seen during the fit.
@@ -641,12 +678,14 @@ class IsolationForest(BaseTransformer):
641
678
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
642
679
  if self.sample_weight_col:
643
680
  output_df_columns_set -= set(self.sample_weight_col)
681
+
644
682
  # if the dimension of inferred output column names is correct; use it
645
683
  if len(expected_output_cols_list) == len(output_df_columns_set):
646
- return expected_output_cols_list
684
+ return expected_output_cols_list, output_df_pd
647
685
  # otherwise, use the sklearn estimator's output
648
686
  else:
649
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
687
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
688
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
650
689
 
651
690
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
652
691
  @telemetry.send_api_usage_telemetry(
@@ -692,7 +731,7 @@ class IsolationForest(BaseTransformer):
692
731
  drop_input_cols=self._drop_input_cols,
693
732
  expected_output_cols_type="float",
694
733
  )
695
- expected_output_cols = self._align_expected_output_names(
734
+ expected_output_cols, _ = self._align_expected_output(
696
735
  inference_method, dataset, expected_output_cols, output_cols_prefix
697
736
  )
698
737
 
@@ -758,7 +797,7 @@ class IsolationForest(BaseTransformer):
758
797
  drop_input_cols=self._drop_input_cols,
759
798
  expected_output_cols_type="float",
760
799
  )
761
- expected_output_cols = self._align_expected_output_names(
800
+ expected_output_cols, _ = self._align_expected_output(
762
801
  inference_method, dataset, expected_output_cols, output_cols_prefix
763
802
  )
764
803
  elif isinstance(dataset, pd.DataFrame):
@@ -823,7 +862,7 @@ class IsolationForest(BaseTransformer):
823
862
  drop_input_cols=self._drop_input_cols,
824
863
  expected_output_cols_type="float",
825
864
  )
826
- expected_output_cols = self._align_expected_output_names(
865
+ expected_output_cols, _ = self._align_expected_output(
827
866
  inference_method, dataset, expected_output_cols, output_cols_prefix
828
867
  )
829
868
 
@@ -890,7 +929,7 @@ class IsolationForest(BaseTransformer):
890
929
  drop_input_cols = self._drop_input_cols,
891
930
  expected_output_cols_type="float",
892
931
  )
893
- expected_output_cols = self._align_expected_output_names(
932
+ expected_output_cols, _ = self._align_expected_output(
894
933
  inference_method, dataset, expected_output_cols, output_cols_prefix
895
934
  )
896
935
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -648,12 +645,23 @@ class RandomForestClassifier(BaseTransformer):
648
645
  autogenerated=self._autogenerated,
649
646
  subproject=_SUBPROJECT,
650
647
  )
651
- output_result, fitted_estimator = model_trainer.train_fit_predict(
652
- drop_input_cols=self._drop_input_cols,
653
- expected_output_cols_list=(
654
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
655
- ),
648
+ expected_output_cols = (
649
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
656
650
  )
651
+ if isinstance(dataset, DataFrame):
652
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
653
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
654
+ )
655
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
656
+ drop_input_cols=self._drop_input_cols,
657
+ expected_output_cols_list=expected_output_cols,
658
+ example_output_pd_df=example_output_pd_df,
659
+ )
660
+ else:
661
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
662
+ drop_input_cols=self._drop_input_cols,
663
+ expected_output_cols_list=expected_output_cols,
664
+ )
657
665
  self._sklearn_object = fitted_estimator
658
666
  self._is_fitted = True
659
667
  return output_result
@@ -732,12 +740,41 @@ class RandomForestClassifier(BaseTransformer):
732
740
 
733
741
  return rv
734
742
 
735
- def _align_expected_output_names(
736
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
737
- ) -> List[str]:
743
+ def _align_expected_output(
744
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
745
+ ) -> Tuple[List[str], pd.DataFrame]:
746
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
747
+ and output dataframe with 1 line.
748
+ If the method is fit_predict, run 2 lines of data.
749
+ """
738
750
  # in case the inferred output column names dimension is different
739
751
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
740
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
752
+
753
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
754
+ # so change the minimum of number of rows to 2
755
+ num_examples = 2
756
+ statement_params = telemetry.get_function_usage_statement_params(
757
+ project=_PROJECT,
758
+ subproject=_SUBPROJECT,
759
+ function_name=telemetry.get_statement_params_full_func_name(
760
+ inspect.currentframe(), RandomForestClassifier.__class__.__name__
761
+ ),
762
+ api_calls=[Session.call],
763
+ custom_tags={"autogen": True} if self._autogenerated else None,
764
+ )
765
+ if output_cols_prefix == "fit_predict_":
766
+ if hasattr(self._sklearn_object, "n_clusters"):
767
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
768
+ num_examples = self._sklearn_object.n_clusters
769
+ elif hasattr(self._sklearn_object, "min_samples"):
770
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
771
+ num_examples = self._sklearn_object.min_samples
772
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
773
+ # LocalOutlierFactor expects n_neighbors <= n_samples
774
+ num_examples = self._sklearn_object.n_neighbors
775
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
776
+ else:
777
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
741
778
 
742
779
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
743
780
  # seen during the fit.
@@ -749,12 +786,14 @@ class RandomForestClassifier(BaseTransformer):
749
786
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
750
787
  if self.sample_weight_col:
751
788
  output_df_columns_set -= set(self.sample_weight_col)
789
+
752
790
  # if the dimension of inferred output column names is correct; use it
753
791
  if len(expected_output_cols_list) == len(output_df_columns_set):
754
- return expected_output_cols_list
792
+ return expected_output_cols_list, output_df_pd
755
793
  # otherwise, use the sklearn estimator's output
756
794
  else:
757
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
795
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
796
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
758
797
 
759
798
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
760
799
  @telemetry.send_api_usage_telemetry(
@@ -802,7 +841,7 @@ class RandomForestClassifier(BaseTransformer):
802
841
  drop_input_cols=self._drop_input_cols,
803
842
  expected_output_cols_type="float",
804
843
  )
805
- expected_output_cols = self._align_expected_output_names(
844
+ expected_output_cols, _ = self._align_expected_output(
806
845
  inference_method, dataset, expected_output_cols, output_cols_prefix
807
846
  )
808
847
 
@@ -870,7 +909,7 @@ class RandomForestClassifier(BaseTransformer):
870
909
  drop_input_cols=self._drop_input_cols,
871
910
  expected_output_cols_type="float",
872
911
  )
873
- expected_output_cols = self._align_expected_output_names(
912
+ expected_output_cols, _ = self._align_expected_output(
874
913
  inference_method, dataset, expected_output_cols, output_cols_prefix
875
914
  )
876
915
  elif isinstance(dataset, pd.DataFrame):
@@ -933,7 +972,7 @@ class RandomForestClassifier(BaseTransformer):
933
972
  drop_input_cols=self._drop_input_cols,
934
973
  expected_output_cols_type="float",
935
974
  )
936
- expected_output_cols = self._align_expected_output_names(
975
+ expected_output_cols, _ = self._align_expected_output(
937
976
  inference_method, dataset, expected_output_cols, output_cols_prefix
938
977
  )
939
978
 
@@ -998,7 +1037,7 @@ class RandomForestClassifier(BaseTransformer):
998
1037
  drop_input_cols = self._drop_input_cols,
999
1038
  expected_output_cols_type="float",
1000
1039
  )
1001
- expected_output_cols = self._align_expected_output_names(
1040
+ expected_output_cols, _ = self._align_expected_output(
1002
1041
  inference_method, dataset, expected_output_cols, output_cols_prefix
1003
1042
  )
1004
1043