snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212) hide show
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -533,12 +530,23 @@ class NuSVR(BaseTransformer):
533
530
  autogenerated=self._autogenerated,
534
531
  subproject=_SUBPROJECT,
535
532
  )
536
- output_result, fitted_estimator = model_trainer.train_fit_predict(
537
- drop_input_cols=self._drop_input_cols,
538
- expected_output_cols_list=(
539
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
540
- ),
533
+ expected_output_cols = (
534
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
541
535
  )
536
+ if isinstance(dataset, DataFrame):
537
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
538
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
539
+ )
540
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
541
+ drop_input_cols=self._drop_input_cols,
542
+ expected_output_cols_list=expected_output_cols,
543
+ example_output_pd_df=example_output_pd_df,
544
+ )
545
+ else:
546
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
547
+ drop_input_cols=self._drop_input_cols,
548
+ expected_output_cols_list=expected_output_cols,
549
+ )
542
550
  self._sklearn_object = fitted_estimator
543
551
  self._is_fitted = True
544
552
  return output_result
@@ -617,12 +625,41 @@ class NuSVR(BaseTransformer):
617
625
 
618
626
  return rv
619
627
 
620
- def _align_expected_output_names(
621
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
622
- ) -> List[str]:
628
+ def _align_expected_output(
629
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
630
+ ) -> Tuple[List[str], pd.DataFrame]:
631
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
632
+ and output dataframe with 1 line.
633
+ If the method is fit_predict, run 2 lines of data.
634
+ """
623
635
  # in case the inferred output column names dimension is different
624
636
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
625
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
637
+
638
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
639
+ # so change the minimum of number of rows to 2
640
+ num_examples = 2
641
+ statement_params = telemetry.get_function_usage_statement_params(
642
+ project=_PROJECT,
643
+ subproject=_SUBPROJECT,
644
+ function_name=telemetry.get_statement_params_full_func_name(
645
+ inspect.currentframe(), NuSVR.__class__.__name__
646
+ ),
647
+ api_calls=[Session.call],
648
+ custom_tags={"autogen": True} if self._autogenerated else None,
649
+ )
650
+ if output_cols_prefix == "fit_predict_":
651
+ if hasattr(self._sklearn_object, "n_clusters"):
652
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
653
+ num_examples = self._sklearn_object.n_clusters
654
+ elif hasattr(self._sklearn_object, "min_samples"):
655
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
656
+ num_examples = self._sklearn_object.min_samples
657
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
658
+ # LocalOutlierFactor expects n_neighbors <= n_samples
659
+ num_examples = self._sklearn_object.n_neighbors
660
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
661
+ else:
662
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
626
663
 
627
664
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
628
665
  # seen during the fit.
@@ -634,12 +671,14 @@ class NuSVR(BaseTransformer):
634
671
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
635
672
  if self.sample_weight_col:
636
673
  output_df_columns_set -= set(self.sample_weight_col)
674
+
637
675
  # if the dimension of inferred output column names is correct; use it
638
676
  if len(expected_output_cols_list) == len(output_df_columns_set):
639
- return expected_output_cols_list
677
+ return expected_output_cols_list, output_df_pd
640
678
  # otherwise, use the sklearn estimator's output
641
679
  else:
642
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
680
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
681
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
643
682
 
644
683
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
645
684
  @telemetry.send_api_usage_telemetry(
@@ -685,7 +724,7 @@ class NuSVR(BaseTransformer):
685
724
  drop_input_cols=self._drop_input_cols,
686
725
  expected_output_cols_type="float",
687
726
  )
688
- expected_output_cols = self._align_expected_output_names(
727
+ expected_output_cols, _ = self._align_expected_output(
689
728
  inference_method, dataset, expected_output_cols, output_cols_prefix
690
729
  )
691
730
 
@@ -751,7 +790,7 @@ class NuSVR(BaseTransformer):
751
790
  drop_input_cols=self._drop_input_cols,
752
791
  expected_output_cols_type="float",
753
792
  )
754
- expected_output_cols = self._align_expected_output_names(
793
+ expected_output_cols, _ = self._align_expected_output(
755
794
  inference_method, dataset, expected_output_cols, output_cols_prefix
756
795
  )
757
796
  elif isinstance(dataset, pd.DataFrame):
@@ -814,7 +853,7 @@ class NuSVR(BaseTransformer):
814
853
  drop_input_cols=self._drop_input_cols,
815
854
  expected_output_cols_type="float",
816
855
  )
817
- expected_output_cols = self._align_expected_output_names(
856
+ expected_output_cols, _ = self._align_expected_output(
818
857
  inference_method, dataset, expected_output_cols, output_cols_prefix
819
858
  )
820
859
 
@@ -879,7 +918,7 @@ class NuSVR(BaseTransformer):
879
918
  drop_input_cols = self._drop_input_cols,
880
919
  expected_output_cols_type="float",
881
920
  )
882
- expected_output_cols = self._align_expected_output_names(
921
+ expected_output_cols, _ = self._align_expected_output(
883
922
  inference_method, dataset, expected_output_cols, output_cols_prefix
884
923
  )
885
924
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -575,12 +572,23 @@ class SVC(BaseTransformer):
575
572
  autogenerated=self._autogenerated,
576
573
  subproject=_SUBPROJECT,
577
574
  )
578
- output_result, fitted_estimator = model_trainer.train_fit_predict(
579
- drop_input_cols=self._drop_input_cols,
580
- expected_output_cols_list=(
581
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
582
- ),
575
+ expected_output_cols = (
576
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
583
577
  )
578
+ if isinstance(dataset, DataFrame):
579
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
580
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
581
+ )
582
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
583
+ drop_input_cols=self._drop_input_cols,
584
+ expected_output_cols_list=expected_output_cols,
585
+ example_output_pd_df=example_output_pd_df,
586
+ )
587
+ else:
588
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
589
+ drop_input_cols=self._drop_input_cols,
590
+ expected_output_cols_list=expected_output_cols,
591
+ )
584
592
  self._sklearn_object = fitted_estimator
585
593
  self._is_fitted = True
586
594
  return output_result
@@ -659,12 +667,41 @@ class SVC(BaseTransformer):
659
667
 
660
668
  return rv
661
669
 
662
- def _align_expected_output_names(
663
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
664
- ) -> List[str]:
670
+ def _align_expected_output(
671
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
672
+ ) -> Tuple[List[str], pd.DataFrame]:
673
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
674
+ and output dataframe with 1 line.
675
+ If the method is fit_predict, run 2 lines of data.
676
+ """
665
677
  # in case the inferred output column names dimension is different
666
678
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
667
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
679
+
680
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
681
+ # so change the minimum of number of rows to 2
682
+ num_examples = 2
683
+ statement_params = telemetry.get_function_usage_statement_params(
684
+ project=_PROJECT,
685
+ subproject=_SUBPROJECT,
686
+ function_name=telemetry.get_statement_params_full_func_name(
687
+ inspect.currentframe(), SVC.__class__.__name__
688
+ ),
689
+ api_calls=[Session.call],
690
+ custom_tags={"autogen": True} if self._autogenerated else None,
691
+ )
692
+ if output_cols_prefix == "fit_predict_":
693
+ if hasattr(self._sklearn_object, "n_clusters"):
694
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
695
+ num_examples = self._sklearn_object.n_clusters
696
+ elif hasattr(self._sklearn_object, "min_samples"):
697
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
698
+ num_examples = self._sklearn_object.min_samples
699
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
700
+ # LocalOutlierFactor expects n_neighbors <= n_samples
701
+ num_examples = self._sklearn_object.n_neighbors
702
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
703
+ else:
704
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
668
705
 
669
706
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
670
707
  # seen during the fit.
@@ -676,12 +713,14 @@ class SVC(BaseTransformer):
676
713
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
677
714
  if self.sample_weight_col:
678
715
  output_df_columns_set -= set(self.sample_weight_col)
716
+
679
717
  # if the dimension of inferred output column names is correct; use it
680
718
  if len(expected_output_cols_list) == len(output_df_columns_set):
681
- return expected_output_cols_list
719
+ return expected_output_cols_list, output_df_pd
682
720
  # otherwise, use the sklearn estimator's output
683
721
  else:
684
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
722
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
723
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
685
724
 
686
725
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
687
726
  @telemetry.send_api_usage_telemetry(
@@ -729,7 +768,7 @@ class SVC(BaseTransformer):
729
768
  drop_input_cols=self._drop_input_cols,
730
769
  expected_output_cols_type="float",
731
770
  )
732
- expected_output_cols = self._align_expected_output_names(
771
+ expected_output_cols, _ = self._align_expected_output(
733
772
  inference_method, dataset, expected_output_cols, output_cols_prefix
734
773
  )
735
774
 
@@ -797,7 +836,7 @@ class SVC(BaseTransformer):
797
836
  drop_input_cols=self._drop_input_cols,
798
837
  expected_output_cols_type="float",
799
838
  )
800
- expected_output_cols = self._align_expected_output_names(
839
+ expected_output_cols, _ = self._align_expected_output(
801
840
  inference_method, dataset, expected_output_cols, output_cols_prefix
802
841
  )
803
842
  elif isinstance(dataset, pd.DataFrame):
@@ -862,7 +901,7 @@ class SVC(BaseTransformer):
862
901
  drop_input_cols=self._drop_input_cols,
863
902
  expected_output_cols_type="float",
864
903
  )
865
- expected_output_cols = self._align_expected_output_names(
904
+ expected_output_cols, _ = self._align_expected_output(
866
905
  inference_method, dataset, expected_output_cols, output_cols_prefix
867
906
  )
868
907
 
@@ -927,7 +966,7 @@ class SVC(BaseTransformer):
927
966
  drop_input_cols = self._drop_input_cols,
928
967
  expected_output_cols_type="float",
929
968
  )
930
- expected_output_cols = self._align_expected_output_names(
969
+ expected_output_cols, _ = self._align_expected_output(
931
970
  inference_method, dataset, expected_output_cols, output_cols_prefix
932
971
  )
933
972
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -536,12 +533,23 @@ class SVR(BaseTransformer):
536
533
  autogenerated=self._autogenerated,
537
534
  subproject=_SUBPROJECT,
538
535
  )
539
- output_result, fitted_estimator = model_trainer.train_fit_predict(
540
- drop_input_cols=self._drop_input_cols,
541
- expected_output_cols_list=(
542
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
543
- ),
536
+ expected_output_cols = (
537
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
544
538
  )
539
+ if isinstance(dataset, DataFrame):
540
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
541
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
542
+ )
543
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
544
+ drop_input_cols=self._drop_input_cols,
545
+ expected_output_cols_list=expected_output_cols,
546
+ example_output_pd_df=example_output_pd_df,
547
+ )
548
+ else:
549
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
550
+ drop_input_cols=self._drop_input_cols,
551
+ expected_output_cols_list=expected_output_cols,
552
+ )
545
553
  self._sklearn_object = fitted_estimator
546
554
  self._is_fitted = True
547
555
  return output_result
@@ -620,12 +628,41 @@ class SVR(BaseTransformer):
620
628
 
621
629
  return rv
622
630
 
623
- def _align_expected_output_names(
624
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
625
- ) -> List[str]:
631
+ def _align_expected_output(
632
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
633
+ ) -> Tuple[List[str], pd.DataFrame]:
634
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
635
+ and output dataframe with 1 line.
636
+ If the method is fit_predict, run 2 lines of data.
637
+ """
626
638
  # in case the inferred output column names dimension is different
627
639
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
628
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
640
+
641
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
642
+ # so change the minimum of number of rows to 2
643
+ num_examples = 2
644
+ statement_params = telemetry.get_function_usage_statement_params(
645
+ project=_PROJECT,
646
+ subproject=_SUBPROJECT,
647
+ function_name=telemetry.get_statement_params_full_func_name(
648
+ inspect.currentframe(), SVR.__class__.__name__
649
+ ),
650
+ api_calls=[Session.call],
651
+ custom_tags={"autogen": True} if self._autogenerated else None,
652
+ )
653
+ if output_cols_prefix == "fit_predict_":
654
+ if hasattr(self._sklearn_object, "n_clusters"):
655
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
656
+ num_examples = self._sklearn_object.n_clusters
657
+ elif hasattr(self._sklearn_object, "min_samples"):
658
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
659
+ num_examples = self._sklearn_object.min_samples
660
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
661
+ # LocalOutlierFactor expects n_neighbors <= n_samples
662
+ num_examples = self._sklearn_object.n_neighbors
663
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
664
+ else:
665
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
629
666
 
630
667
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
631
668
  # seen during the fit.
@@ -637,12 +674,14 @@ class SVR(BaseTransformer):
637
674
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
638
675
  if self.sample_weight_col:
639
676
  output_df_columns_set -= set(self.sample_weight_col)
677
+
640
678
  # if the dimension of inferred output column names is correct; use it
641
679
  if len(expected_output_cols_list) == len(output_df_columns_set):
642
- return expected_output_cols_list
680
+ return expected_output_cols_list, output_df_pd
643
681
  # otherwise, use the sklearn estimator's output
644
682
  else:
645
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
683
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
684
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
646
685
 
647
686
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
648
687
  @telemetry.send_api_usage_telemetry(
@@ -688,7 +727,7 @@ class SVR(BaseTransformer):
688
727
  drop_input_cols=self._drop_input_cols,
689
728
  expected_output_cols_type="float",
690
729
  )
691
- expected_output_cols = self._align_expected_output_names(
730
+ expected_output_cols, _ = self._align_expected_output(
692
731
  inference_method, dataset, expected_output_cols, output_cols_prefix
693
732
  )
694
733
 
@@ -754,7 +793,7 @@ class SVR(BaseTransformer):
754
793
  drop_input_cols=self._drop_input_cols,
755
794
  expected_output_cols_type="float",
756
795
  )
757
- expected_output_cols = self._align_expected_output_names(
796
+ expected_output_cols, _ = self._align_expected_output(
758
797
  inference_method, dataset, expected_output_cols, output_cols_prefix
759
798
  )
760
799
  elif isinstance(dataset, pd.DataFrame):
@@ -817,7 +856,7 @@ class SVR(BaseTransformer):
817
856
  drop_input_cols=self._drop_input_cols,
818
857
  expected_output_cols_type="float",
819
858
  )
820
- expected_output_cols = self._align_expected_output_names(
859
+ expected_output_cols, _ = self._align_expected_output(
821
860
  inference_method, dataset, expected_output_cols, output_cols_prefix
822
861
  )
823
862
 
@@ -882,7 +921,7 @@ class SVR(BaseTransformer):
882
921
  drop_input_cols = self._drop_input_cols,
883
922
  expected_output_cols_type="float",
884
923
  )
885
- expected_output_cols = self._align_expected_output_names(
924
+ expected_output_cols, _ = self._align_expected_output(
886
925
  inference_method, dataset, expected_output_cols, output_cols_prefix
887
926
  )
888
927
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -603,12 +600,23 @@ class DecisionTreeClassifier(BaseTransformer):
603
600
  autogenerated=self._autogenerated,
604
601
  subproject=_SUBPROJECT,
605
602
  )
606
- output_result, fitted_estimator = model_trainer.train_fit_predict(
607
- drop_input_cols=self._drop_input_cols,
608
- expected_output_cols_list=(
609
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
610
- ),
603
+ expected_output_cols = (
604
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
611
605
  )
606
+ if isinstance(dataset, DataFrame):
607
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
608
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
609
+ )
610
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
611
+ drop_input_cols=self._drop_input_cols,
612
+ expected_output_cols_list=expected_output_cols,
613
+ example_output_pd_df=example_output_pd_df,
614
+ )
615
+ else:
616
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
617
+ drop_input_cols=self._drop_input_cols,
618
+ expected_output_cols_list=expected_output_cols,
619
+ )
612
620
  self._sklearn_object = fitted_estimator
613
621
  self._is_fitted = True
614
622
  return output_result
@@ -687,12 +695,41 @@ class DecisionTreeClassifier(BaseTransformer):
687
695
 
688
696
  return rv
689
697
 
690
- def _align_expected_output_names(
691
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
692
- ) -> List[str]:
698
+ def _align_expected_output(
699
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
700
+ ) -> Tuple[List[str], pd.DataFrame]:
701
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
702
+ and output dataframe with 1 line.
703
+ If the method is fit_predict, run 2 lines of data.
704
+ """
693
705
  # in case the inferred output column names dimension is different
694
706
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
695
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
707
+
708
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
709
+ # so change the minimum of number of rows to 2
710
+ num_examples = 2
711
+ statement_params = telemetry.get_function_usage_statement_params(
712
+ project=_PROJECT,
713
+ subproject=_SUBPROJECT,
714
+ function_name=telemetry.get_statement_params_full_func_name(
715
+ inspect.currentframe(), DecisionTreeClassifier.__class__.__name__
716
+ ),
717
+ api_calls=[Session.call],
718
+ custom_tags={"autogen": True} if self._autogenerated else None,
719
+ )
720
+ if output_cols_prefix == "fit_predict_":
721
+ if hasattr(self._sklearn_object, "n_clusters"):
722
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
723
+ num_examples = self._sklearn_object.n_clusters
724
+ elif hasattr(self._sklearn_object, "min_samples"):
725
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
726
+ num_examples = self._sklearn_object.min_samples
727
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
728
+ # LocalOutlierFactor expects n_neighbors <= n_samples
729
+ num_examples = self._sklearn_object.n_neighbors
730
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
731
+ else:
732
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
696
733
 
697
734
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
698
735
  # seen during the fit.
@@ -704,12 +741,14 @@ class DecisionTreeClassifier(BaseTransformer):
704
741
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
705
742
  if self.sample_weight_col:
706
743
  output_df_columns_set -= set(self.sample_weight_col)
744
+
707
745
  # if the dimension of inferred output column names is correct; use it
708
746
  if len(expected_output_cols_list) == len(output_df_columns_set):
709
- return expected_output_cols_list
747
+ return expected_output_cols_list, output_df_pd
710
748
  # otherwise, use the sklearn estimator's output
711
749
  else:
712
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
750
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
751
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
713
752
 
714
753
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
715
754
  @telemetry.send_api_usage_telemetry(
@@ -757,7 +796,7 @@ class DecisionTreeClassifier(BaseTransformer):
757
796
  drop_input_cols=self._drop_input_cols,
758
797
  expected_output_cols_type="float",
759
798
  )
760
- expected_output_cols = self._align_expected_output_names(
799
+ expected_output_cols, _ = self._align_expected_output(
761
800
  inference_method, dataset, expected_output_cols, output_cols_prefix
762
801
  )
763
802
 
@@ -825,7 +864,7 @@ class DecisionTreeClassifier(BaseTransformer):
825
864
  drop_input_cols=self._drop_input_cols,
826
865
  expected_output_cols_type="float",
827
866
  )
828
- expected_output_cols = self._align_expected_output_names(
867
+ expected_output_cols, _ = self._align_expected_output(
829
868
  inference_method, dataset, expected_output_cols, output_cols_prefix
830
869
  )
831
870
  elif isinstance(dataset, pd.DataFrame):
@@ -888,7 +927,7 @@ class DecisionTreeClassifier(BaseTransformer):
888
927
  drop_input_cols=self._drop_input_cols,
889
928
  expected_output_cols_type="float",
890
929
  )
891
- expected_output_cols = self._align_expected_output_names(
930
+ expected_output_cols, _ = self._align_expected_output(
892
931
  inference_method, dataset, expected_output_cols, output_cols_prefix
893
932
  )
894
933
 
@@ -953,7 +992,7 @@ class DecisionTreeClassifier(BaseTransformer):
953
992
  drop_input_cols = self._drop_input_cols,
954
993
  expected_output_cols_type="float",
955
994
  )
956
- expected_output_cols = self._align_expected_output_names(
995
+ expected_output_cols, _ = self._align_expected_output(
957
996
  inference_method, dataset, expected_output_cols, output_cols_prefix
958
997
  )
959
998