snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
--- a/snowflake/ml/modeling/cluster/spectral_clustering.py
+++ b/snowflake/ml/modeling/cluster/spectral_clustering.py
@@ -4,14 +4,12 @@
 #
 import inspect
 import os
-import posixpath
-from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
-from typing_extensions import TypeGuard
+from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
 from uuid import uuid4
 
 import cloudpickle as cp
-import pandas as pd
 import numpy as np
+import pandas as pd
 from numpy import typing as npt
 
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
 from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
-from snowflake.ml._internal.utils import pkg_version_utils, identifier
+from snowflake.ml._internal.utils import identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.transformer_protocols import (
-    ModelTransformHandlers,
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
@@ -594,12 +591,23 @@ class SpectralClustering(BaseTransformer):
             autogenerated=self._autogenerated,
             subproject=_SUBPROJECT,
         )
-        output_result, fitted_estimator = model_trainer.train_fit_predict(
-            drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=(
-                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
-            ),
+        expected_output_cols = (
+            self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
         )
+        if isinstance(dataset, DataFrame):
+            expected_output_cols, example_output_pd_df = self._align_expected_output(
+                "fit_predict", dataset, expected_output_cols, output_cols_prefix
+            )
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+                example_output_pd_df=example_output_pd_df,
+            )
+        else:
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+            )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result
@@ -678,12 +686,41 @@ class SpectralClustering(BaseTransformer):
 
         return rv
 
-    def _align_expected_output_names(
-        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
-    ) -> List[str]:
+    def _align_expected_output(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+    ) -> Tuple[List[str], pd.DataFrame]:
+        """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+        and output dataframe with 1 line.
+        If the method is fit_predict, run 2 lines of data.
+        """
         # in case the inferred output column names dimension is different
         # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
-        sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+        # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+        # so change the minimum of number of rows to 2
+        num_examples = 2
+        statement_params = telemetry.get_function_usage_statement_params(
+            project=_PROJECT,
+            subproject=_SUBPROJECT,
+            function_name=telemetry.get_statement_params_full_func_name(
+                inspect.currentframe(), SpectralClustering.__class__.__name__
+            ),
+            api_calls=[Session.call],
+            custom_tags={"autogen": True} if self._autogenerated else None,
+        )
+        if output_cols_prefix == "fit_predict_":
+            if hasattr(self._sklearn_object, "n_clusters"):
+                # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+                num_examples = self._sklearn_object.n_clusters
+            elif hasattr(self._sklearn_object, "min_samples"):
+                # OPTICS default min_samples 5, which requires at least 5 lines of data
+                num_examples = self._sklearn_object.min_samples
+            elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+                # LocalOutlierFactor expects n_neighbors <= n_samples
+                num_examples = self._sklearn_object.n_neighbors
+            sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+        else:
+            sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
 
         # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
         # seen during the fit.
@@ -695,12 +732,14 @@ class SpectralClustering(BaseTransformer):
         output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
         if self.sample_weight_col:
            output_df_columns_set -= set(self.sample_weight_col)
+
        # if the dimension of inferred output column names is correct; use it
        if len(expected_output_cols_list) == len(output_df_columns_set):
-            return expected_output_cols_list
+            return expected_output_cols_list, output_df_pd
        # otherwise, use the sklearn estimator's output
        else:
-            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            return expected_output_cols_list, output_df_pd[expected_output_cols_list]
 
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -746,7 +785,7 @@ class SpectralClustering(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
 
@@ -812,7 +851,7 @@ class SpectralClustering(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -875,7 +914,7 @@ class SpectralClustering(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
 
@@ -940,7 +979,7 @@ class SpectralClustering(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
 
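The new _align_expected_output helper above samples more than one row when output_cols_prefix is "fit_predict_" because several scikit-learn estimators refuse to fit on a single row. A minimal sketch of those constraints, using plain scikit-learn rather than anything from this package (KMeans stands in for the n_clusters case the diff comments attribute to BisectingKMeans, OPTICS for the min_samples case):

# Illustrative only: reproduces the scikit-learn row-count constraints that the
# generated _align_expected_output works around when sampling rows for fit_predict.
import numpy as np
from sklearn.cluster import KMeans, OPTICS

rng = np.random.default_rng(0)

# Estimators with n_clusters need at least n_clusters rows.
try:
    KMeans(n_clusters=3).fit_predict(rng.random((2, 4)))
except ValueError as err:
    print("KMeans, 2 rows, 3 clusters ->", err)

# OPTICS needs at least min_samples rows (default 5).
try:
    OPTICS().fit_predict(rng.random((1, 4)))
except ValueError as err:
    print("OPTICS, 1 row ->", err)

# With enough rows both calls succeed, which is why the generated code raises
# num_examples to n_clusters / min_samples before calling to_pandas().
print(KMeans(n_clusters=3).fit_predict(rng.random((3, 4))))
print(OPTICS().fit_predict(rng.random((5, 4))))

The same reasoning explains the num_examples = 2 floor mentioned for MinCovDet and BayesianGaussianMixture, which cannot be estimated from a single sample.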
--- a/snowflake/ml/modeling/cluster/spectral_coclustering.py
+++ b/snowflake/ml/modeling/cluster/spectral_coclustering.py
@@ -4,14 +4,12 @@
 #
 import inspect
 import os
-import posixpath
-from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
-from typing_extensions import TypeGuard
+from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
 from uuid import uuid4
 
 import cloudpickle as cp
-import pandas as pd
 import numpy as np
+import pandas as pd
 from numpy import typing as npt
 
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
 from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
-from snowflake.ml._internal.utils import pkg_version_utils, identifier
+from snowflake.ml._internal.utils import identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.transformer_protocols import (
-    ModelTransformHandlers,
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
@@ -513,12 +510,23 @@ class SpectralCoclustering(BaseTransformer):
             autogenerated=self._autogenerated,
             subproject=_SUBPROJECT,
         )
-        output_result, fitted_estimator = model_trainer.train_fit_predict(
-            drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=(
-                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
-            ),
+        expected_output_cols = (
+            self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
         )
+        if isinstance(dataset, DataFrame):
+            expected_output_cols, example_output_pd_df = self._align_expected_output(
+                "fit_predict", dataset, expected_output_cols, output_cols_prefix
+            )
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+                example_output_pd_df=example_output_pd_df,
+            )
+        else:
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+            )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result
@@ -597,12 +605,41 @@ class SpectralCoclustering(BaseTransformer):
 
         return rv
 
-    def _align_expected_output_names(
-        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
-    ) -> List[str]:
+    def _align_expected_output(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+    ) -> Tuple[List[str], pd.DataFrame]:
+        """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+        and output dataframe with 1 line.
+        If the method is fit_predict, run 2 lines of data.
+        """
         # in case the inferred output column names dimension is different
         # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
-        sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+        # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+        # so change the minimum of number of rows to 2
+        num_examples = 2
+        statement_params = telemetry.get_function_usage_statement_params(
+            project=_PROJECT,
+            subproject=_SUBPROJECT,
+            function_name=telemetry.get_statement_params_full_func_name(
+                inspect.currentframe(), SpectralCoclustering.__class__.__name__
+            ),
+            api_calls=[Session.call],
+            custom_tags={"autogen": True} if self._autogenerated else None,
+        )
+        if output_cols_prefix == "fit_predict_":
+            if hasattr(self._sklearn_object, "n_clusters"):
+                # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+                num_examples = self._sklearn_object.n_clusters
+            elif hasattr(self._sklearn_object, "min_samples"):
+                # OPTICS default min_samples 5, which requires at least 5 lines of data
+                num_examples = self._sklearn_object.min_samples
+            elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+                # LocalOutlierFactor expects n_neighbors <= n_samples
+                num_examples = self._sklearn_object.n_neighbors
+            sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+        else:
+            sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
 
         # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
         # seen during the fit.
@@ -614,12 +651,14 @@ class SpectralCoclustering(BaseTransformer):
         output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
         if self.sample_weight_col:
            output_df_columns_set -= set(self.sample_weight_col)
+
        # if the dimension of inferred output column names is correct; use it
        if len(expected_output_cols_list) == len(output_df_columns_set):
-            return expected_output_cols_list
+            return expected_output_cols_list, output_df_pd
        # otherwise, use the sklearn estimator's output
        else:
-            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            return expected_output_cols_list, output_df_pd[expected_output_cols_list]
 
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -665,7 +704,7 @@ class SpectralCoclustering(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
 
@@ -731,7 +770,7 @@ class SpectralCoclustering(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -794,7 +833,7 @@ class SpectralCoclustering(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
 
@@ -859,7 +898,7 @@ class SpectralCoclustering(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
 
--- a/snowflake/ml/modeling/compose/column_transformer.py
+++ b/snowflake/ml/modeling/compose/column_transformer.py
@@ -4,14 +4,12 @@
 #
 import inspect
 import os
-import posixpath
-from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
-from typing_extensions import TypeGuard
+from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
 from uuid import uuid4
 
 import cloudpickle as cp
-import pandas as pd
 import numpy as np
+import pandas as pd
 from numpy import typing as npt
 
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
 from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
-from snowflake.ml._internal.utils import pkg_version_utils, identifier
+from snowflake.ml._internal.utils import identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.transformer_protocols import (
-    ModelTransformHandlers,
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
@@ -545,12 +542,23 @@ class ColumnTransformer(BaseTransformer):
             autogenerated=self._autogenerated,
             subproject=_SUBPROJECT,
         )
-        output_result, fitted_estimator = model_trainer.train_fit_predict(
-            drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=(
-                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
-            ),
+        expected_output_cols = (
+            self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
         )
+        if isinstance(dataset, DataFrame):
+            expected_output_cols, example_output_pd_df = self._align_expected_output(
+                "fit_predict", dataset, expected_output_cols, output_cols_prefix
+            )
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+                example_output_pd_df=example_output_pd_df,
+            )
+        else:
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+            )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result
@@ -631,12 +639,41 @@ class ColumnTransformer(BaseTransformer):
 
         return rv
 
-    def _align_expected_output_names(
-        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
-    ) -> List[str]:
+    def _align_expected_output(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+    ) -> Tuple[List[str], pd.DataFrame]:
+        """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+        and output dataframe with 1 line.
+        If the method is fit_predict, run 2 lines of data.
+        """
         # in case the inferred output column names dimension is different
         # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
-        sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+        # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+        # so change the minimum of number of rows to 2
+        num_examples = 2
+        statement_params = telemetry.get_function_usage_statement_params(
+            project=_PROJECT,
+            subproject=_SUBPROJECT,
+            function_name=telemetry.get_statement_params_full_func_name(
+                inspect.currentframe(), ColumnTransformer.__class__.__name__
+            ),
+            api_calls=[Session.call],
+            custom_tags={"autogen": True} if self._autogenerated else None,
+        )
+        if output_cols_prefix == "fit_predict_":
+            if hasattr(self._sklearn_object, "n_clusters"):
+                # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+                num_examples = self._sklearn_object.n_clusters
+            elif hasattr(self._sklearn_object, "min_samples"):
+                # OPTICS default min_samples 5, which requires at least 5 lines of data
+                num_examples = self._sklearn_object.min_samples
+            elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+                # LocalOutlierFactor expects n_neighbors <= n_samples
+                num_examples = self._sklearn_object.n_neighbors
+            sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+        else:
+            sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
 
         # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
         # seen during the fit.
@@ -648,12 +685,14 @@ class ColumnTransformer(BaseTransformer):
         output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
         if self.sample_weight_col:
            output_df_columns_set -= set(self.sample_weight_col)
+
        # if the dimension of inferred output column names is correct; use it
        if len(expected_output_cols_list) == len(output_df_columns_set):
-            return expected_output_cols_list
+            return expected_output_cols_list, output_df_pd
        # otherwise, use the sklearn estimator's output
        else:
-            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            return expected_output_cols_list, output_df_pd[expected_output_cols_list]
 
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -699,7 +738,7 @@ class ColumnTransformer(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
 
@@ -765,7 +804,7 @@ class ColumnTransformer(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -828,7 +867,7 @@ class ColumnTransformer(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
 
@@ -893,7 +932,7 @@ class ColumnTransformer(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
 
--- a/snowflake/ml/modeling/compose/transformed_target_regressor.py
+++ b/snowflake/ml/modeling/compose/transformed_target_regressor.py
@@ -4,14 +4,12 @@
 #
 import inspect
 import os
-import posixpath
-from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
-from typing_extensions import TypeGuard
+from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
 from uuid import uuid4
 
 import cloudpickle as cp
-import pandas as pd
 import numpy as np
+import pandas as pd
 from numpy import typing as npt
 
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
 from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
-from snowflake.ml._internal.utils import pkg_version_utils, identifier
+from snowflake.ml._internal.utils import identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.transformer_protocols import (
-    ModelTransformHandlers,
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
@@ -504,12 +501,23 @@ class TransformedTargetRegressor(BaseTransformer):
             autogenerated=self._autogenerated,
             subproject=_SUBPROJECT,
         )
-        output_result, fitted_estimator = model_trainer.train_fit_predict(
-            drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=(
-                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
-            ),
+        expected_output_cols = (
+            self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
         )
+        if isinstance(dataset, DataFrame):
+            expected_output_cols, example_output_pd_df = self._align_expected_output(
+                "fit_predict", dataset, expected_output_cols, output_cols_prefix
+            )
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+                example_output_pd_df=example_output_pd_df,
+            )
+        else:
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+            )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result
@@ -588,12 +596,41 @@ class TransformedTargetRegressor(BaseTransformer):
 
         return rv
 
-    def _align_expected_output_names(
-        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
-    ) -> List[str]:
+    def _align_expected_output(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+    ) -> Tuple[List[str], pd.DataFrame]:
+        """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+        and output dataframe with 1 line.
+        If the method is fit_predict, run 2 lines of data.
+        """
         # in case the inferred output column names dimension is different
         # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
-        sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+        # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+        # so change the minimum of number of rows to 2
+        num_examples = 2
+        statement_params = telemetry.get_function_usage_statement_params(
+            project=_PROJECT,
+            subproject=_SUBPROJECT,
+            function_name=telemetry.get_statement_params_full_func_name(
+                inspect.currentframe(), TransformedTargetRegressor.__class__.__name__
+            ),
+            api_calls=[Session.call],
+            custom_tags={"autogen": True} if self._autogenerated else None,
+        )
+        if output_cols_prefix == "fit_predict_":
+            if hasattr(self._sklearn_object, "n_clusters"):
+                # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+                num_examples = self._sklearn_object.n_clusters
+            elif hasattr(self._sklearn_object, "min_samples"):
+                # OPTICS default min_samples 5, which requires at least 5 lines of data
+                num_examples = self._sklearn_object.min_samples
+            elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+                # LocalOutlierFactor expects n_neighbors <= n_samples
+                num_examples = self._sklearn_object.n_neighbors
+            sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+        else:
+            sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
 
         # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
         # seen during the fit.
@@ -605,12 +642,14 @@ class TransformedTargetRegressor(BaseTransformer):
         output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
         if self.sample_weight_col:
            output_df_columns_set -= set(self.sample_weight_col)
+
        # if the dimension of inferred output column names is correct; use it
        if len(expected_output_cols_list) == len(output_df_columns_set):
-            return expected_output_cols_list
+            return expected_output_cols_list, output_df_pd
        # otherwise, use the sklearn estimator's output
        else:
-            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            return expected_output_cols_list, output_df_pd[expected_output_cols_list]
 
     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -656,7 +695,7 @@ class TransformedTargetRegressor(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
 
@@ -722,7 +761,7 @@ class TransformedTargetRegressor(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -785,7 +824,7 @@ class TransformedTargetRegressor(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
 
@@ -850,7 +889,7 @@ class TransformedTargetRegressor(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
 