snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -652,12 +649,23 @@ class ExtraTreesClassifier(BaseTransformer):
652
649
  autogenerated=self._autogenerated,
653
650
  subproject=_SUBPROJECT,
654
651
  )
655
- output_result, fitted_estimator = model_trainer.train_fit_predict(
656
- drop_input_cols=self._drop_input_cols,
657
- expected_output_cols_list=(
658
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
659
- ),
652
+ expected_output_cols = (
653
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
660
654
  )
655
+ if isinstance(dataset, DataFrame):
656
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
657
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
658
+ )
659
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
660
+ drop_input_cols=self._drop_input_cols,
661
+ expected_output_cols_list=expected_output_cols,
662
+ example_output_pd_df=example_output_pd_df,
663
+ )
664
+ else:
665
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
666
+ drop_input_cols=self._drop_input_cols,
667
+ expected_output_cols_list=expected_output_cols,
668
+ )
661
669
  self._sklearn_object = fitted_estimator
662
670
  self._is_fitted = True
663
671
  return output_result
@@ -736,12 +744,41 @@ class ExtraTreesClassifier(BaseTransformer):
736
744
 
737
745
  return rv
738
746
 
739
- def _align_expected_output_names(
740
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
741
- ) -> List[str]:
747
+ def _align_expected_output(
748
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
749
+ ) -> Tuple[List[str], pd.DataFrame]:
750
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
751
+ and output dataframe with 1 line.
752
+ If the method is fit_predict, run 2 lines of data.
753
+ """
742
754
  # in case the inferred output column names dimension is different
743
755
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
744
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
756
+
757
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
758
+ # so change the minimum of number of rows to 2
759
+ num_examples = 2
760
+ statement_params = telemetry.get_function_usage_statement_params(
761
+ project=_PROJECT,
762
+ subproject=_SUBPROJECT,
763
+ function_name=telemetry.get_statement_params_full_func_name(
764
+ inspect.currentframe(), ExtraTreesClassifier.__class__.__name__
765
+ ),
766
+ api_calls=[Session.call],
767
+ custom_tags={"autogen": True} if self._autogenerated else None,
768
+ )
769
+ if output_cols_prefix == "fit_predict_":
770
+ if hasattr(self._sklearn_object, "n_clusters"):
771
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
772
+ num_examples = self._sklearn_object.n_clusters
773
+ elif hasattr(self._sklearn_object, "min_samples"):
774
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
775
+ num_examples = self._sklearn_object.min_samples
776
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
777
+ # LocalOutlierFactor expects n_neighbors <= n_samples
778
+ num_examples = self._sklearn_object.n_neighbors
779
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
780
+ else:
781
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
745
782
 
746
783
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
747
784
  # seen during the fit.
@@ -753,12 +790,14 @@ class ExtraTreesClassifier(BaseTransformer):
753
790
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
754
791
  if self.sample_weight_col:
755
792
  output_df_columns_set -= set(self.sample_weight_col)
793
+
756
794
  # if the dimension of inferred output column names is correct; use it
757
795
  if len(expected_output_cols_list) == len(output_df_columns_set):
758
- return expected_output_cols_list
796
+ return expected_output_cols_list, output_df_pd
759
797
  # otherwise, use the sklearn estimator's output
760
798
  else:
761
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
799
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
800
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
762
801
 
763
802
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
764
803
  @telemetry.send_api_usage_telemetry(
@@ -806,7 +845,7 @@ class ExtraTreesClassifier(BaseTransformer):
806
845
  drop_input_cols=self._drop_input_cols,
807
846
  expected_output_cols_type="float",
808
847
  )
809
- expected_output_cols = self._align_expected_output_names(
848
+ expected_output_cols, _ = self._align_expected_output(
810
849
  inference_method, dataset, expected_output_cols, output_cols_prefix
811
850
  )
812
851
 
@@ -874,7 +913,7 @@ class ExtraTreesClassifier(BaseTransformer):
874
913
  drop_input_cols=self._drop_input_cols,
875
914
  expected_output_cols_type="float",
876
915
  )
877
- expected_output_cols = self._align_expected_output_names(
916
+ expected_output_cols, _ = self._align_expected_output(
878
917
  inference_method, dataset, expected_output_cols, output_cols_prefix
879
918
  )
880
919
  elif isinstance(dataset, pd.DataFrame):
@@ -937,7 +976,7 @@ class ExtraTreesClassifier(BaseTransformer):
937
976
  drop_input_cols=self._drop_input_cols,
938
977
  expected_output_cols_type="float",
939
978
  )
940
- expected_output_cols = self._align_expected_output_names(
979
+ expected_output_cols, _ = self._align_expected_output(
941
980
  inference_method, dataset, expected_output_cols, output_cols_prefix
942
981
  )
943
982
 
@@ -1002,7 +1041,7 @@ class ExtraTreesClassifier(BaseTransformer):
1002
1041
  drop_input_cols = self._drop_input_cols,
1003
1042
  expected_output_cols_type="float",
1004
1043
  )
1005
- expected_output_cols = self._align_expected_output_names(
1044
+ expected_output_cols, _ = self._align_expected_output(
1006
1045
  inference_method, dataset, expected_output_cols, output_cols_prefix
1007
1046
  )
1008
1047
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -631,12 +628,23 @@ class ExtraTreesRegressor(BaseTransformer):
631
628
  autogenerated=self._autogenerated,
632
629
  subproject=_SUBPROJECT,
633
630
  )
634
- output_result, fitted_estimator = model_trainer.train_fit_predict(
635
- drop_input_cols=self._drop_input_cols,
636
- expected_output_cols_list=(
637
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
638
- ),
631
+ expected_output_cols = (
632
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
639
633
  )
634
+ if isinstance(dataset, DataFrame):
635
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
636
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
637
+ )
638
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
639
+ drop_input_cols=self._drop_input_cols,
640
+ expected_output_cols_list=expected_output_cols,
641
+ example_output_pd_df=example_output_pd_df,
642
+ )
643
+ else:
644
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
645
+ drop_input_cols=self._drop_input_cols,
646
+ expected_output_cols_list=expected_output_cols,
647
+ )
640
648
  self._sklearn_object = fitted_estimator
641
649
  self._is_fitted = True
642
650
  return output_result
@@ -715,12 +723,41 @@ class ExtraTreesRegressor(BaseTransformer):
715
723
 
716
724
  return rv
717
725
 
718
- def _align_expected_output_names(
719
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
720
- ) -> List[str]:
726
+ def _align_expected_output(
727
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
728
+ ) -> Tuple[List[str], pd.DataFrame]:
729
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
730
+ and output dataframe with 1 line.
731
+ If the method is fit_predict, run 2 lines of data.
732
+ """
721
733
  # in case the inferred output column names dimension is different
722
734
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
723
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
735
+
736
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
737
+ # so change the minimum of number of rows to 2
738
+ num_examples = 2
739
+ statement_params = telemetry.get_function_usage_statement_params(
740
+ project=_PROJECT,
741
+ subproject=_SUBPROJECT,
742
+ function_name=telemetry.get_statement_params_full_func_name(
743
+ inspect.currentframe(), ExtraTreesRegressor.__class__.__name__
744
+ ),
745
+ api_calls=[Session.call],
746
+ custom_tags={"autogen": True} if self._autogenerated else None,
747
+ )
748
+ if output_cols_prefix == "fit_predict_":
749
+ if hasattr(self._sklearn_object, "n_clusters"):
750
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
751
+ num_examples = self._sklearn_object.n_clusters
752
+ elif hasattr(self._sklearn_object, "min_samples"):
753
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
754
+ num_examples = self._sklearn_object.min_samples
755
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
756
+ # LocalOutlierFactor expects n_neighbors <= n_samples
757
+ num_examples = self._sklearn_object.n_neighbors
758
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
759
+ else:
760
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
724
761
 
725
762
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
726
763
  # seen during the fit.
@@ -732,12 +769,14 @@ class ExtraTreesRegressor(BaseTransformer):
732
769
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
733
770
  if self.sample_weight_col:
734
771
  output_df_columns_set -= set(self.sample_weight_col)
772
+
735
773
  # if the dimension of inferred output column names is correct; use it
736
774
  if len(expected_output_cols_list) == len(output_df_columns_set):
737
- return expected_output_cols_list
775
+ return expected_output_cols_list, output_df_pd
738
776
  # otherwise, use the sklearn estimator's output
739
777
  else:
740
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
778
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
779
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
741
780
 
742
781
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
743
782
  @telemetry.send_api_usage_telemetry(
@@ -783,7 +822,7 @@ class ExtraTreesRegressor(BaseTransformer):
783
822
  drop_input_cols=self._drop_input_cols,
784
823
  expected_output_cols_type="float",
785
824
  )
786
- expected_output_cols = self._align_expected_output_names(
825
+ expected_output_cols, _ = self._align_expected_output(
787
826
  inference_method, dataset, expected_output_cols, output_cols_prefix
788
827
  )
789
828
 
@@ -849,7 +888,7 @@ class ExtraTreesRegressor(BaseTransformer):
849
888
  drop_input_cols=self._drop_input_cols,
850
889
  expected_output_cols_type="float",
851
890
  )
852
- expected_output_cols = self._align_expected_output_names(
891
+ expected_output_cols, _ = self._align_expected_output(
853
892
  inference_method, dataset, expected_output_cols, output_cols_prefix
854
893
  )
855
894
  elif isinstance(dataset, pd.DataFrame):
@@ -912,7 +951,7 @@ class ExtraTreesRegressor(BaseTransformer):
912
951
  drop_input_cols=self._drop_input_cols,
913
952
  expected_output_cols_type="float",
914
953
  )
915
- expected_output_cols = self._align_expected_output_names(
954
+ expected_output_cols, _ = self._align_expected_output(
916
955
  inference_method, dataset, expected_output_cols, output_cols_prefix
917
956
  )
918
957
 
@@ -977,7 +1016,7 @@ class ExtraTreesRegressor(BaseTransformer):
977
1016
  drop_input_cols = self._drop_input_cols,
978
1017
  expected_output_cols_type="float",
979
1018
  )
980
- expected_output_cols = self._align_expected_output_names(
1019
+ expected_output_cols, _ = self._align_expected_output(
981
1020
  inference_method, dataset, expected_output_cols, output_cols_prefix
982
1021
  )
983
1022
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -664,12 +661,23 @@ class GradientBoostingClassifier(BaseTransformer):
664
661
  autogenerated=self._autogenerated,
665
662
  subproject=_SUBPROJECT,
666
663
  )
667
- output_result, fitted_estimator = model_trainer.train_fit_predict(
668
- drop_input_cols=self._drop_input_cols,
669
- expected_output_cols_list=(
670
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
671
- ),
664
+ expected_output_cols = (
665
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
672
666
  )
667
+ if isinstance(dataset, DataFrame):
668
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
669
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
670
+ )
671
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
672
+ drop_input_cols=self._drop_input_cols,
673
+ expected_output_cols_list=expected_output_cols,
674
+ example_output_pd_df=example_output_pd_df,
675
+ )
676
+ else:
677
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
678
+ drop_input_cols=self._drop_input_cols,
679
+ expected_output_cols_list=expected_output_cols,
680
+ )
673
681
  self._sklearn_object = fitted_estimator
674
682
  self._is_fitted = True
675
683
  return output_result
@@ -748,12 +756,41 @@ class GradientBoostingClassifier(BaseTransformer):
748
756
 
749
757
  return rv
750
758
 
751
- def _align_expected_output_names(
752
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
753
- ) -> List[str]:
759
+ def _align_expected_output(
760
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
761
+ ) -> Tuple[List[str], pd.DataFrame]:
762
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
763
+ and output dataframe with 1 line.
764
+ If the method is fit_predict, run 2 lines of data.
765
+ """
754
766
  # in case the inferred output column names dimension is different
755
767
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
756
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
768
+
769
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
770
+ # so change the minimum of number of rows to 2
771
+ num_examples = 2
772
+ statement_params = telemetry.get_function_usage_statement_params(
773
+ project=_PROJECT,
774
+ subproject=_SUBPROJECT,
775
+ function_name=telemetry.get_statement_params_full_func_name(
776
+ inspect.currentframe(), GradientBoostingClassifier.__class__.__name__
777
+ ),
778
+ api_calls=[Session.call],
779
+ custom_tags={"autogen": True} if self._autogenerated else None,
780
+ )
781
+ if output_cols_prefix == "fit_predict_":
782
+ if hasattr(self._sklearn_object, "n_clusters"):
783
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
784
+ num_examples = self._sklearn_object.n_clusters
785
+ elif hasattr(self._sklearn_object, "min_samples"):
786
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
787
+ num_examples = self._sklearn_object.min_samples
788
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
789
+ # LocalOutlierFactor expects n_neighbors <= n_samples
790
+ num_examples = self._sklearn_object.n_neighbors
791
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
792
+ else:
793
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
757
794
 
758
795
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
759
796
  # seen during the fit.
@@ -765,12 +802,14 @@ class GradientBoostingClassifier(BaseTransformer):
765
802
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
766
803
  if self.sample_weight_col:
767
804
  output_df_columns_set -= set(self.sample_weight_col)
805
+
768
806
  # if the dimension of inferred output column names is correct; use it
769
807
  if len(expected_output_cols_list) == len(output_df_columns_set):
770
- return expected_output_cols_list
808
+ return expected_output_cols_list, output_df_pd
771
809
  # otherwise, use the sklearn estimator's output
772
810
  else:
773
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
811
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
812
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
774
813
 
775
814
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
776
815
  @telemetry.send_api_usage_telemetry(
@@ -818,7 +857,7 @@ class GradientBoostingClassifier(BaseTransformer):
818
857
  drop_input_cols=self._drop_input_cols,
819
858
  expected_output_cols_type="float",
820
859
  )
821
- expected_output_cols = self._align_expected_output_names(
860
+ expected_output_cols, _ = self._align_expected_output(
822
861
  inference_method, dataset, expected_output_cols, output_cols_prefix
823
862
  )
824
863
 
@@ -886,7 +925,7 @@ class GradientBoostingClassifier(BaseTransformer):
886
925
  drop_input_cols=self._drop_input_cols,
887
926
  expected_output_cols_type="float",
888
927
  )
889
- expected_output_cols = self._align_expected_output_names(
928
+ expected_output_cols, _ = self._align_expected_output(
890
929
  inference_method, dataset, expected_output_cols, output_cols_prefix
891
930
  )
892
931
  elif isinstance(dataset, pd.DataFrame):
@@ -951,7 +990,7 @@ class GradientBoostingClassifier(BaseTransformer):
951
990
  drop_input_cols=self._drop_input_cols,
952
991
  expected_output_cols_type="float",
953
992
  )
954
- expected_output_cols = self._align_expected_output_names(
993
+ expected_output_cols, _ = self._align_expected_output(
955
994
  inference_method, dataset, expected_output_cols, output_cols_prefix
956
995
  )
957
996
 
@@ -1016,7 +1055,7 @@ class GradientBoostingClassifier(BaseTransformer):
1016
1055
  drop_input_cols = self._drop_input_cols,
1017
1056
  expected_output_cols_type="float",
1018
1057
  )
1019
- expected_output_cols = self._align_expected_output_names(
1058
+ expected_output_cols, _ = self._align_expected_output(
1020
1059
  inference_method, dataset, expected_output_cols, output_cols_prefix
1021
1060
  )
1022
1061
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -673,12 +670,23 @@ class GradientBoostingRegressor(BaseTransformer):
673
670
  autogenerated=self._autogenerated,
674
671
  subproject=_SUBPROJECT,
675
672
  )
676
- output_result, fitted_estimator = model_trainer.train_fit_predict(
677
- drop_input_cols=self._drop_input_cols,
678
- expected_output_cols_list=(
679
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
680
- ),
673
+ expected_output_cols = (
674
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
681
675
  )
676
+ if isinstance(dataset, DataFrame):
677
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
678
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
679
+ )
680
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
681
+ drop_input_cols=self._drop_input_cols,
682
+ expected_output_cols_list=expected_output_cols,
683
+ example_output_pd_df=example_output_pd_df,
684
+ )
685
+ else:
686
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
687
+ drop_input_cols=self._drop_input_cols,
688
+ expected_output_cols_list=expected_output_cols,
689
+ )
682
690
  self._sklearn_object = fitted_estimator
683
691
  self._is_fitted = True
684
692
  return output_result
@@ -757,12 +765,41 @@ class GradientBoostingRegressor(BaseTransformer):
757
765
 
758
766
  return rv
759
767
 
760
- def _align_expected_output_names(
761
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
762
- ) -> List[str]:
768
+ def _align_expected_output(
769
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
770
+ ) -> Tuple[List[str], pd.DataFrame]:
771
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
772
+ and output dataframe with 1 line.
773
+ If the method is fit_predict, run 2 lines of data.
774
+ """
763
775
  # in case the inferred output column names dimension is different
764
776
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
765
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
777
+
778
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
779
+ # so change the minimum of number of rows to 2
780
+ num_examples = 2
781
+ statement_params = telemetry.get_function_usage_statement_params(
782
+ project=_PROJECT,
783
+ subproject=_SUBPROJECT,
784
+ function_name=telemetry.get_statement_params_full_func_name(
785
+ inspect.currentframe(), GradientBoostingRegressor.__class__.__name__
786
+ ),
787
+ api_calls=[Session.call],
788
+ custom_tags={"autogen": True} if self._autogenerated else None,
789
+ )
790
+ if output_cols_prefix == "fit_predict_":
791
+ if hasattr(self._sklearn_object, "n_clusters"):
792
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
793
+ num_examples = self._sklearn_object.n_clusters
794
+ elif hasattr(self._sklearn_object, "min_samples"):
795
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
796
+ num_examples = self._sklearn_object.min_samples
797
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
798
+ # LocalOutlierFactor expects n_neighbors <= n_samples
799
+ num_examples = self._sklearn_object.n_neighbors
800
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
801
+ else:
802
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
766
803
 
767
804
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
768
805
  # seen during the fit.
@@ -774,12 +811,14 @@ class GradientBoostingRegressor(BaseTransformer):
774
811
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
775
812
  if self.sample_weight_col:
776
813
  output_df_columns_set -= set(self.sample_weight_col)
814
+
777
815
  # if the dimension of inferred output column names is correct; use it
778
816
  if len(expected_output_cols_list) == len(output_df_columns_set):
779
- return expected_output_cols_list
817
+ return expected_output_cols_list, output_df_pd
780
818
  # otherwise, use the sklearn estimator's output
781
819
  else:
782
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
820
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
821
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
783
822
 
784
823
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
785
824
  @telemetry.send_api_usage_telemetry(
@@ -825,7 +864,7 @@ class GradientBoostingRegressor(BaseTransformer):
825
864
  drop_input_cols=self._drop_input_cols,
826
865
  expected_output_cols_type="float",
827
866
  )
828
- expected_output_cols = self._align_expected_output_names(
867
+ expected_output_cols, _ = self._align_expected_output(
829
868
  inference_method, dataset, expected_output_cols, output_cols_prefix
830
869
  )
831
870
 
@@ -891,7 +930,7 @@ class GradientBoostingRegressor(BaseTransformer):
891
930
  drop_input_cols=self._drop_input_cols,
892
931
  expected_output_cols_type="float",
893
932
  )
894
- expected_output_cols = self._align_expected_output_names(
933
+ expected_output_cols, _ = self._align_expected_output(
895
934
  inference_method, dataset, expected_output_cols, output_cols_prefix
896
935
  )
897
936
  elif isinstance(dataset, pd.DataFrame):
@@ -954,7 +993,7 @@ class GradientBoostingRegressor(BaseTransformer):
954
993
  drop_input_cols=self._drop_input_cols,
955
994
  expected_output_cols_type="float",
956
995
  )
957
- expected_output_cols = self._align_expected_output_names(
996
+ expected_output_cols, _ = self._align_expected_output(
958
997
  inference_method, dataset, expected_output_cols, output_cols_prefix
959
998
  )
960
999
 
@@ -1019,7 +1058,7 @@ class GradientBoostingRegressor(BaseTransformer):
1019
1058
  drop_input_cols = self._drop_input_cols,
1020
1059
  expected_output_cols_type="float",
1021
1060
  )
1022
- expected_output_cols = self._align_expected_output_names(
1061
+ expected_output_cols, _ = self._align_expected_output(
1023
1062
  inference_method, dataset, expected_output_cols, output_cols_prefix
1024
1063
  )
1025
1064