snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (212)
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/calibration/calibrated_classifier_cv.py

@@ -4,14 +4,12 @@
 #
 import inspect
 import os
-import posixpath
-from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
-from typing_extensions import TypeGuard
+from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
 from uuid import uuid4

 import cloudpickle as cp
-import pandas as pd
 import numpy as np
+import pandas as pd
 from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
 from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
-from snowflake.ml._internal.utils import pkg_version_utils, identifier
+from snowflake.ml._internal.utils import identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.transformer_protocols import (
-    ModelTransformHandlers,
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
@@ -540,12 +537,23 @@ class CalibratedClassifierCV(BaseTransformer):
             autogenerated=self._autogenerated,
             subproject=_SUBPROJECT,
         )
-        output_result, fitted_estimator = model_trainer.train_fit_predict(
-            drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=(
-                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
-            ),
+        expected_output_cols = (
+            self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
         )
+        if isinstance(dataset, DataFrame):
+            expected_output_cols, example_output_pd_df = self._align_expected_output(
+                "fit_predict", dataset, expected_output_cols, output_cols_prefix
+            )
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+                example_output_pd_df=example_output_pd_df,
+            )
+        else:
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+            )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result
@@ -624,12 +632,41 @@ class CalibratedClassifierCV(BaseTransformer):

         return rv

-    def _align_expected_output_names(
-        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
-    ) -> List[str]:
+    def _align_expected_output(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+    ) -> Tuple[List[str], pd.DataFrame]:
+        """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+        and output dataframe with 1 line.
+        If the method is fit_predict, run 2 lines of data.
+        """
         # in case the inferred output column names dimension is different
         # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
-        sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+        # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+        # so change the minimum of number of rows to 2
+        num_examples = 2
+        statement_params = telemetry.get_function_usage_statement_params(
+            project=_PROJECT,
+            subproject=_SUBPROJECT,
+            function_name=telemetry.get_statement_params_full_func_name(
+                inspect.currentframe(), CalibratedClassifierCV.__class__.__name__
+            ),
+            api_calls=[Session.call],
+            custom_tags={"autogen": True} if self._autogenerated else None,
+        )
+        if output_cols_prefix == "fit_predict_":
+            if hasattr(self._sklearn_object, "n_clusters"):
+                # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+                num_examples = self._sklearn_object.n_clusters
+            elif hasattr(self._sklearn_object, "min_samples"):
+                # OPTICS default min_samples 5, which requires at least 5 lines of data
+                num_examples = self._sklearn_object.min_samples
+            elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+                # LocalOutlierFactor expects n_neighbors <= n_samples
+                num_examples = self._sklearn_object.n_neighbors
+            sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+        else:
+            sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

         # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
         # seen during the fit.
@@ -641,12 +678,14 @@ class CalibratedClassifierCV(BaseTransformer):
         output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
         if self.sample_weight_col:
             output_df_columns_set -= set(self.sample_weight_col)
+
         # if the dimension of inferred output column names is correct; use it
         if len(expected_output_cols_list) == len(output_df_columns_set):
-            return expected_output_cols_list
+            return expected_output_cols_list, output_df_pd
         # otherwise, use the sklearn estimator's output
         else:
-            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            return expected_output_cols_list, output_df_pd[expected_output_cols_list]

     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -694,7 +733,7 @@ class CalibratedClassifierCV(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -762,7 +801,7 @@ class CalibratedClassifierCV(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -825,7 +864,7 @@ class CalibratedClassifierCV(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -890,7 +929,7 @@ class CalibratedClassifierCV(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

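In the fit_predict hunk above, the expected output columns are resolved first, and when the input is a Snowpark DataFrame they are aligned against a locally executed sample whose output frame is forwarded to the trainer as example_output_pd_df. A minimal sketch of that call pattern follows; run_fit_predict, transformer and model_trainer are illustrative stand-ins, and only the attribute and method names visible in the hunk are taken from the package:

    from snowflake.snowpark import DataFrame

    def run_fit_predict(transformer, model_trainer, dataset, output_cols_prefix: str):
        # Resolve the output column names once, before calling the trainer.
        expected_output_cols = (
            transformer.output_cols
            if transformer.output_cols
            else transformer._get_output_column_names(output_cols_prefix)
        )
        if isinstance(dataset, DataFrame):
            # Snowpark path: align the names against a sampled local run and hand the
            # resulting example output frame through to the trainer.
            expected_output_cols, example_output_pd_df = transformer._align_expected_output(
                "fit_predict", dataset, expected_output_cols, output_cols_prefix
            )
            return model_trainer.train_fit_predict(
                drop_input_cols=transformer._drop_input_cols,
                expected_output_cols_list=expected_output_cols,
                example_output_pd_df=example_output_pd_df,
            )
        # pandas path: no example output frame is needed.
        return model_trainer.train_fit_predict(
            drop_input_cols=transformer._drop_input_cols,
            expected_output_cols_list=expected_output_cols,
        )

The same restructuring appears in every generated estimator below.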
snowflake/ml/modeling/cluster/affinity_propagation.py

@@ -4,14 +4,12 @@
 #
 import inspect
 import os
-import posixpath
-from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
-from typing_extensions import TypeGuard
+from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
 from uuid import uuid4

 import cloudpickle as cp
-import pandas as pd
 import numpy as np
+import pandas as pd
 from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
 from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
-from snowflake.ml._internal.utils import pkg_version_utils, identifier
+from snowflake.ml._internal.utils import identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.transformer_protocols import (
-    ModelTransformHandlers,
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
@@ -519,12 +516,23 @@ class AffinityPropagation(BaseTransformer):
             autogenerated=self._autogenerated,
             subproject=_SUBPROJECT,
         )
-        output_result, fitted_estimator = model_trainer.train_fit_predict(
-            drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=(
-                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
-            ),
+        expected_output_cols = (
+            self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
         )
+        if isinstance(dataset, DataFrame):
+            expected_output_cols, example_output_pd_df = self._align_expected_output(
+                "fit_predict", dataset, expected_output_cols, output_cols_prefix
+            )
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+                example_output_pd_df=example_output_pd_df,
+            )
+        else:
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+            )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result
@@ -603,12 +611,41 @@ class AffinityPropagation(BaseTransformer):

         return rv

-    def _align_expected_output_names(
-        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
-    ) -> List[str]:
+    def _align_expected_output(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+    ) -> Tuple[List[str], pd.DataFrame]:
+        """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+        and output dataframe with 1 line.
+        If the method is fit_predict, run 2 lines of data.
+        """
         # in case the inferred output column names dimension is different
         # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
-        sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+        # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+        # so change the minimum of number of rows to 2
+        num_examples = 2
+        statement_params = telemetry.get_function_usage_statement_params(
+            project=_PROJECT,
+            subproject=_SUBPROJECT,
+            function_name=telemetry.get_statement_params_full_func_name(
+                inspect.currentframe(), AffinityPropagation.__class__.__name__
+            ),
+            api_calls=[Session.call],
+            custom_tags={"autogen": True} if self._autogenerated else None,
+        )
+        if output_cols_prefix == "fit_predict_":
+            if hasattr(self._sklearn_object, "n_clusters"):
+                # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+                num_examples = self._sklearn_object.n_clusters
+            elif hasattr(self._sklearn_object, "min_samples"):
+                # OPTICS default min_samples 5, which requires at least 5 lines of data
+                num_examples = self._sklearn_object.min_samples
+            elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+                # LocalOutlierFactor expects n_neighbors <= n_samples
+                num_examples = self._sklearn_object.n_neighbors
+            sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+        else:
+            sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

         # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
         # seen during the fit.
@@ -620,12 +657,14 @@ class AffinityPropagation(BaseTransformer):
         output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
         if self.sample_weight_col:
             output_df_columns_set -= set(self.sample_weight_col)
+
         # if the dimension of inferred output column names is correct; use it
         if len(expected_output_cols_list) == len(output_df_columns_set):
-            return expected_output_cols_list
+            return expected_output_cols_list, output_df_pd
         # otherwise, use the sklearn estimator's output
         else:
-            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            return expected_output_cols_list, output_df_pd[expected_output_cols_list]

     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -671,7 +710,7 @@ class AffinityPropagation(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -737,7 +776,7 @@ class AffinityPropagation(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -800,7 +839,7 @@ class AffinityPropagation(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -865,7 +904,7 @@ class AffinityPropagation(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

snowflake/ml/modeling/cluster/agglomerative_clustering.py

@@ -4,14 +4,12 @@
 #
 import inspect
 import os
-import posixpath
-from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
-from typing_extensions import TypeGuard
+from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
 from uuid import uuid4

 import cloudpickle as cp
-import pandas as pd
 import numpy as np
+import pandas as pd
 from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
 from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
-from snowflake.ml._internal.utils import pkg_version_utils, identifier
+from snowflake.ml._internal.utils import identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.transformer_protocols import (
-    ModelTransformHandlers,
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
@@ -550,12 +547,23 @@ class AgglomerativeClustering(BaseTransformer):
             autogenerated=self._autogenerated,
             subproject=_SUBPROJECT,
         )
-        output_result, fitted_estimator = model_trainer.train_fit_predict(
-            drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=(
-                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
-            ),
+        expected_output_cols = (
+            self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
         )
+        if isinstance(dataset, DataFrame):
+            expected_output_cols, example_output_pd_df = self._align_expected_output(
+                "fit_predict", dataset, expected_output_cols, output_cols_prefix
+            )
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+                example_output_pd_df=example_output_pd_df,
+            )
+        else:
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+            )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result
@@ -634,12 +642,41 @@ class AgglomerativeClustering(BaseTransformer):

         return rv

-    def _align_expected_output_names(
-        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
-    ) -> List[str]:
+    def _align_expected_output(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+    ) -> Tuple[List[str], pd.DataFrame]:
+        """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+        and output dataframe with 1 line.
+        If the method is fit_predict, run 2 lines of data.
+        """
         # in case the inferred output column names dimension is different
         # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
-        sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+        # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+        # so change the minimum of number of rows to 2
+        num_examples = 2
+        statement_params = telemetry.get_function_usage_statement_params(
+            project=_PROJECT,
+            subproject=_SUBPROJECT,
+            function_name=telemetry.get_statement_params_full_func_name(
+                inspect.currentframe(), AgglomerativeClustering.__class__.__name__
+            ),
+            api_calls=[Session.call],
+            custom_tags={"autogen": True} if self._autogenerated else None,
+        )
+        if output_cols_prefix == "fit_predict_":
+            if hasattr(self._sklearn_object, "n_clusters"):
+                # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+                num_examples = self._sklearn_object.n_clusters
+            elif hasattr(self._sklearn_object, "min_samples"):
+                # OPTICS default min_samples 5, which requires at least 5 lines of data
+                num_examples = self._sklearn_object.min_samples
+            elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+                # LocalOutlierFactor expects n_neighbors <= n_samples
+                num_examples = self._sklearn_object.n_neighbors
+            sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+        else:
+            sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

         # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
         # seen during the fit.
@@ -651,12 +688,14 @@ class AgglomerativeClustering(BaseTransformer):
         output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
         if self.sample_weight_col:
             output_df_columns_set -= set(self.sample_weight_col)
+
         # if the dimension of inferred output column names is correct; use it
         if len(expected_output_cols_list) == len(output_df_columns_set):
-            return expected_output_cols_list
+            return expected_output_cols_list, output_df_pd
         # otherwise, use the sklearn estimator's output
         else:
-            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            return expected_output_cols_list, output_df_pd[expected_output_cols_list]

     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -702,7 +741,7 @@ class AgglomerativeClustering(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -768,7 +807,7 @@ class AgglomerativeClustering(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -831,7 +870,7 @@ class AgglomerativeClustering(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -896,7 +935,7 @@ class AgglomerativeClustering(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

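The sampling query that _align_expected_output issues now carries telemetry statement parameters, so the extra to_pandas call is attributed to the originating API call. A sketch of how those parameters are built and attached, with sample_rows as a hypothetical wrapper; project and subproject correspond to the module-level _PROJECT and _SUBPROJECT constants of each generated file, and the class-name argument is simplified relative to what the generated code passes:

    import inspect

    from snowflake.ml._internal import telemetry
    from snowflake.snowpark import Session

    def sample_rows(transformer, dataset, num_examples: int, project: str, subproject: str):
        # Identify the project, subproject and calling function for the query.
        statement_params = telemetry.get_function_usage_statement_params(
            project=project,
            subproject=subproject,
            function_name=telemetry.get_statement_params_full_func_name(
                inspect.currentframe(), type(transformer).__name__
            ),
            api_calls=[Session.call],
            custom_tags={"autogen": True} if transformer._autogenerated else None,
        )
        # Attach the parameters to the query that pulls the example rows back as pandas.
        return (
            dataset.select(transformer.input_cols)
            .limit(num_examples)
            .to_pandas(statement_params=statement_params)
        )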
snowflake/ml/modeling/cluster/birch.py

@@ -4,14 +4,12 @@
 #
 import inspect
 import os
-import posixpath
-from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
-from typing_extensions import TypeGuard
+from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
 from uuid import uuid4

 import cloudpickle as cp
-import pandas as pd
 import numpy as np
+import pandas as pd
 from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
 from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
-from snowflake.ml._internal.utils import pkg_version_utils, identifier
+from snowflake.ml._internal.utils import identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.transformer_protocols import (
-    ModelTransformHandlers,
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
@@ -512,12 +509,23 @@ class Birch(BaseTransformer):
             autogenerated=self._autogenerated,
             subproject=_SUBPROJECT,
         )
-        output_result, fitted_estimator = model_trainer.train_fit_predict(
-            drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=(
-                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
-            ),
+        expected_output_cols = (
+            self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
         )
+        if isinstance(dataset, DataFrame):
+            expected_output_cols, example_output_pd_df = self._align_expected_output(
+                "fit_predict", dataset, expected_output_cols, output_cols_prefix
+            )
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+                example_output_pd_df=example_output_pd_df,
+            )
+        else:
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+            )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result
@@ -598,12 +606,41 @@ class Birch(BaseTransformer):

         return rv

-    def _align_expected_output_names(
-        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
-    ) -> List[str]:
+    def _align_expected_output(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+    ) -> Tuple[List[str], pd.DataFrame]:
+        """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+        and output dataframe with 1 line.
+        If the method is fit_predict, run 2 lines of data.
+        """
         # in case the inferred output column names dimension is different
         # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
-        sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+        # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+        # so change the minimum of number of rows to 2
+        num_examples = 2
+        statement_params = telemetry.get_function_usage_statement_params(
+            project=_PROJECT,
+            subproject=_SUBPROJECT,
+            function_name=telemetry.get_statement_params_full_func_name(
+                inspect.currentframe(), Birch.__class__.__name__
+            ),
+            api_calls=[Session.call],
+            custom_tags={"autogen": True} if self._autogenerated else None,
+        )
+        if output_cols_prefix == "fit_predict_":
+            if hasattr(self._sklearn_object, "n_clusters"):
+                # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+                num_examples = self._sklearn_object.n_clusters
+            elif hasattr(self._sklearn_object, "min_samples"):
+                # OPTICS default min_samples 5, which requires at least 5 lines of data
+                num_examples = self._sklearn_object.min_samples
+            elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+                # LocalOutlierFactor expects n_neighbors <= n_samples
+                num_examples = self._sklearn_object.n_neighbors
+            sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+        else:
+            sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

         # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
         # seen during the fit.
@@ -615,12 +652,14 @@ class Birch(BaseTransformer):
         output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
         if self.sample_weight_col:
             output_df_columns_set -= set(self.sample_weight_col)
+
         # if the dimension of inferred output column names is correct; use it
         if len(expected_output_cols_list) == len(output_df_columns_set):
-            return expected_output_cols_list
+            return expected_output_cols_list, output_df_pd
         # otherwise, use the sklearn estimator's output
         else:
-            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            return expected_output_cols_list, output_df_pd[expected_output_cols_list]

     @available_if(original_estimator_has_callable("predict_proba"))  # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -666,7 +705,7 @@ class Birch(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -732,7 +771,7 @@ class Birch(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -795,7 +834,7 @@ class Birch(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -860,7 +899,7 @@ class Birch(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

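The other half of the rename is visible in the inference hunks: call sites that only need the column names now unpack a tuple, expected_output_cols, _ = self._align_expected_output(...). A compact sketch of the alignment decision itself, with align_output and its arguments as illustrative names:

    from typing import List, Sequence, Tuple

    import pandas as pd

    def align_output(
        expected_cols: List[str],
        output_df_pd: pd.DataFrame,
        input_and_weight_cols: Sequence[str],
    ) -> Tuple[List[str], pd.DataFrame]:
        # Columns produced by the estimator that are not inputs (or the sample-weight
        # column) are the candidate output columns, kept in produced order.
        produced = [c for c in output_df_pd.columns if c not in set(input_and_weight_cols)]
        if len(expected_cols) == len(produced):
            # The inferred names have the right dimension: keep them and return the
            # example output frame unchanged.
            return expected_cols, output_df_pd
        # Otherwise fall back to the estimator's own column names and narrow the example
        # frame to exactly those columns.
        return produced, output_df_pd[produced]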