snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -570,12 +567,23 @@ class RidgeClassifier(BaseTransformer):
570
567
  autogenerated=self._autogenerated,
571
568
  subproject=_SUBPROJECT,
572
569
  )
573
- output_result, fitted_estimator = model_trainer.train_fit_predict(
574
- drop_input_cols=self._drop_input_cols,
575
- expected_output_cols_list=(
576
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
577
- ),
570
+ expected_output_cols = (
571
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
578
572
  )
573
+ if isinstance(dataset, DataFrame):
574
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
575
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
576
+ )
577
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
578
+ drop_input_cols=self._drop_input_cols,
579
+ expected_output_cols_list=expected_output_cols,
580
+ example_output_pd_df=example_output_pd_df,
581
+ )
582
+ else:
583
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
584
+ drop_input_cols=self._drop_input_cols,
585
+ expected_output_cols_list=expected_output_cols,
586
+ )
579
587
  self._sklearn_object = fitted_estimator
580
588
  self._is_fitted = True
581
589
  return output_result
@@ -654,12 +662,41 @@ class RidgeClassifier(BaseTransformer):
654
662
 
655
663
  return rv
656
664
 
657
- def _align_expected_output_names(
658
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
659
- ) -> List[str]:
665
+ def _align_expected_output(
666
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
667
+ ) -> Tuple[List[str], pd.DataFrame]:
668
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
669
+ and output dataframe with 1 line.
670
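+ If the method is fit_predict, run 2 lines of data by default, or as many lines as the estimator's n_clusters, min_samples or n_neighbors setting calls for.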
671
+ """
660
672
  # in case the inferred output column names dimension is different
661
673
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
662
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
674
+
675
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
676
+ # so change the minimum number of rows to 2
677
+ num_examples = 2
678
+ statement_params = telemetry.get_function_usage_statement_params(
679
+ project=_PROJECT,
680
+ subproject=_SUBPROJECT,
681
+ function_name=telemetry.get_statement_params_full_func_name(
682
+ inspect.currentframe(), RidgeClassifier.__class__.__name__
683
+ ),
684
+ api_calls=[Session.call],
685
+ custom_tags={"autogen": True} if self._autogenerated else None,
686
+ )
687
+ if output_cols_prefix == "fit_predict_":
688
+ if hasattr(self._sklearn_object, "n_clusters"):
689
+ # cluster classes such as BisectingKMeans require the number of examples to be >= n_clusters
690
+ num_examples = self._sklearn_object.n_clusters
691
+ elif hasattr(self._sklearn_object, "min_samples"):
692
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
693
+ num_examples = self._sklearn_object.min_samples
694
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
695
+ # LocalOutlierFactor expects n_neighbors <= n_samples
696
+ num_examples = self._sklearn_object.n_neighbors
697
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
698
+ else:
699
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
663
700
 
664
701
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
665
702
  # seen during the fit.
@@ -671,12 +708,14 @@ class RidgeClassifier(BaseTransformer):
671
708
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
672
709
  if self.sample_weight_col:
673
710
  output_df_columns_set -= set(self.sample_weight_col)
711
+
674
712
  # if the dimension of inferred output column names is correct; use it
675
713
  if len(expected_output_cols_list) == len(output_df_columns_set):
676
- return expected_output_cols_list
714
+ return expected_output_cols_list, output_df_pd
677
715
  # otherwise, use the sklearn estimator's output
678
716
  else:
679
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
717
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
718
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
680
719
 
681
720
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
682
721
  @telemetry.send_api_usage_telemetry(
@@ -722,7 +761,7 @@ class RidgeClassifier(BaseTransformer):
722
761
  drop_input_cols=self._drop_input_cols,
723
762
  expected_output_cols_type="float",
724
763
  )
725
- expected_output_cols = self._align_expected_output_names(
764
+ expected_output_cols, _ = self._align_expected_output(
726
765
  inference_method, dataset, expected_output_cols, output_cols_prefix
727
766
  )
728
767
 
@@ -788,7 +827,7 @@ class RidgeClassifier(BaseTransformer):
788
827
  drop_input_cols=self._drop_input_cols,
789
828
  expected_output_cols_type="float",
790
829
  )
791
- expected_output_cols = self._align_expected_output_names(
830
+ expected_output_cols, _ = self._align_expected_output(
792
831
  inference_method, dataset, expected_output_cols, output_cols_prefix
793
832
  )
794
833
  elif isinstance(dataset, pd.DataFrame):
@@ -853,7 +892,7 @@ class RidgeClassifier(BaseTransformer):
853
892
  drop_input_cols=self._drop_input_cols,
854
893
  expected_output_cols_type="float",
855
894
  )
856
- expected_output_cols = self._align_expected_output_names(
895
+ expected_output_cols, _ = self._align_expected_output(
857
896
  inference_method, dataset, expected_output_cols, output_cols_prefix
858
897
  )
859
898
 
@@ -918,7 +957,7 @@ class RidgeClassifier(BaseTransformer):
918
957
  drop_input_cols = self._drop_input_cols,
919
958
  expected_output_cols_type="float",
920
959
  )
921
- expected_output_cols = self._align_expected_output_names(
960
+ expected_output_cols, _ = self._align_expected_output(
922
961
  inference_method, dataset, expected_output_cols, output_cols_prefix
923
962
  )
924
963
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -521,12 +518,23 @@ class RidgeClassifierCV(BaseTransformer):
521
518
  autogenerated=self._autogenerated,
522
519
  subproject=_SUBPROJECT,
523
520
  )
524
- output_result, fitted_estimator = model_trainer.train_fit_predict(
525
- drop_input_cols=self._drop_input_cols,
526
- expected_output_cols_list=(
527
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
528
- ),
521
+ expected_output_cols = (
522
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
529
523
  )
524
+ if isinstance(dataset, DataFrame):
525
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
526
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
527
+ )
528
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
529
+ drop_input_cols=self._drop_input_cols,
530
+ expected_output_cols_list=expected_output_cols,
531
+ example_output_pd_df=example_output_pd_df,
532
+ )
533
+ else:
534
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
535
+ drop_input_cols=self._drop_input_cols,
536
+ expected_output_cols_list=expected_output_cols,
537
+ )
530
538
  self._sklearn_object = fitted_estimator
531
539
  self._is_fitted = True
532
540
  return output_result
@@ -605,12 +613,41 @@ class RidgeClassifierCV(BaseTransformer):
605
613
 
606
614
  return rv
607
615
 
608
- def _align_expected_output_names(
609
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
610
- ) -> List[str]:
616
+ def _align_expected_output(
617
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
618
+ ) -> Tuple[List[str], pd.DataFrame]:
619
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
620
+ and output dataframe with 1 line.
621
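+ If the method is fit_predict, run 2 lines of data by default, or as many lines as the estimator's n_clusters, min_samples or n_neighbors setting calls for.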
622
+ """
611
623
  # in case the inferred output column names dimension is different
612
624
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
613
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
625
+
626
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
627
+ # so change the minimum number of rows to 2
628
+ num_examples = 2
629
+ statement_params = telemetry.get_function_usage_statement_params(
630
+ project=_PROJECT,
631
+ subproject=_SUBPROJECT,
632
+ function_name=telemetry.get_statement_params_full_func_name(
633
+ inspect.currentframe(), RidgeClassifierCV.__class__.__name__
634
+ ),
635
+ api_calls=[Session.call],
636
+ custom_tags={"autogen": True} if self._autogenerated else None,
637
+ )
638
+ if output_cols_prefix == "fit_predict_":
639
+ if hasattr(self._sklearn_object, "n_clusters"):
640
+ # cluster classes such as BisectingKMeans require the number of examples to be >= n_clusters
641
+ num_examples = self._sklearn_object.n_clusters
642
+ elif hasattr(self._sklearn_object, "min_samples"):
643
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
644
+ num_examples = self._sklearn_object.min_samples
645
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
646
+ # LocalOutlierFactor expects n_neighbors <= n_samples
647
+ num_examples = self._sklearn_object.n_neighbors
648
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
649
+ else:
650
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
614
651
 
615
652
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
616
653
  # seen during the fit.
@@ -622,12 +659,14 @@ class RidgeClassifierCV(BaseTransformer):
622
659
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
623
660
  if self.sample_weight_col:
624
661
  output_df_columns_set -= set(self.sample_weight_col)
662
+
625
663
  # if the dimension of inferred output column names is correct; use it
626
664
  if len(expected_output_cols_list) == len(output_df_columns_set):
627
- return expected_output_cols_list
665
+ return expected_output_cols_list, output_df_pd
628
666
  # otherwise, use the sklearn estimator's output
629
667
  else:
630
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
668
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
669
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
631
670
 
632
671
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
633
672
  @telemetry.send_api_usage_telemetry(
@@ -673,7 +712,7 @@ class RidgeClassifierCV(BaseTransformer):
673
712
  drop_input_cols=self._drop_input_cols,
674
713
  expected_output_cols_type="float",
675
714
  )
676
- expected_output_cols = self._align_expected_output_names(
715
+ expected_output_cols, _ = self._align_expected_output(
677
716
  inference_method, dataset, expected_output_cols, output_cols_prefix
678
717
  )
679
718
 
@@ -739,7 +778,7 @@ class RidgeClassifierCV(BaseTransformer):
739
778
  drop_input_cols=self._drop_input_cols,
740
779
  expected_output_cols_type="float",
741
780
  )
742
- expected_output_cols = self._align_expected_output_names(
781
+ expected_output_cols, _ = self._align_expected_output(
743
782
  inference_method, dataset, expected_output_cols, output_cols_prefix
744
783
  )
745
784
  elif isinstance(dataset, pd.DataFrame):
@@ -804,7 +843,7 @@ class RidgeClassifierCV(BaseTransformer):
804
843
  drop_input_cols=self._drop_input_cols,
805
844
  expected_output_cols_type="float",
806
845
  )
807
- expected_output_cols = self._align_expected_output_names(
846
+ expected_output_cols, _ = self._align_expected_output(
808
847
  inference_method, dataset, expected_output_cols, output_cols_prefix
809
848
  )
810
849
 
@@ -869,7 +908,7 @@ class RidgeClassifierCV(BaseTransformer):
869
908
  drop_input_cols = self._drop_input_cols,
870
909
  expected_output_cols_type="float",
871
910
  )
872
- expected_output_cols = self._align_expected_output_names(
911
+ expected_output_cols, _ = self._align_expected_output(
873
912
  inference_method, dataset, expected_output_cols, output_cols_prefix
874
913
  )
875
914
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -542,12 +539,23 @@ class RidgeCV(BaseTransformer):
542
539
  autogenerated=self._autogenerated,
543
540
  subproject=_SUBPROJECT,
544
541
  )
545
- output_result, fitted_estimator = model_trainer.train_fit_predict(
546
- drop_input_cols=self._drop_input_cols,
547
- expected_output_cols_list=(
548
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
549
- ),
542
+ expected_output_cols = (
543
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
550
544
  )
545
+ if isinstance(dataset, DataFrame):
546
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
547
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
548
+ )
549
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
550
+ drop_input_cols=self._drop_input_cols,
551
+ expected_output_cols_list=expected_output_cols,
552
+ example_output_pd_df=example_output_pd_df,
553
+ )
554
+ else:
555
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
556
+ drop_input_cols=self._drop_input_cols,
557
+ expected_output_cols_list=expected_output_cols,
558
+ )
551
559
  self._sklearn_object = fitted_estimator
552
560
  self._is_fitted = True
553
561
  return output_result
@@ -626,12 +634,41 @@ class RidgeCV(BaseTransformer):
626
634
 
627
635
  return rv
628
636
 
629
- def _align_expected_output_names(
630
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
631
- ) -> List[str]:
637
+ def _align_expected_output(
638
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
639
+ ) -> Tuple[List[str], pd.DataFrame]:
640
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
641
+ and output dataframe with 1 line.
642
+ If the method is fit_predict, run 2 lines of data.
643
+ """
632
644
  # in case the inferred output column names dimension is different
633
645
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
634
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
646
+
647
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
648
+ # so change the minimum of number of rows to 2
649
+ num_examples = 2
650
+ statement_params = telemetry.get_function_usage_statement_params(
651
+ project=_PROJECT,
652
+ subproject=_SUBPROJECT,
653
+ function_name=telemetry.get_statement_params_full_func_name(
654
+ inspect.currentframe(), RidgeCV.__class__.__name__
655
+ ),
656
+ api_calls=[Session.call],
657
+ custom_tags={"autogen": True} if self._autogenerated else None,
658
+ )
659
+ if output_cols_prefix == "fit_predict_":
660
+ if hasattr(self._sklearn_object, "n_clusters"):
661
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
662
+ num_examples = self._sklearn_object.n_clusters
663
+ elif hasattr(self._sklearn_object, "min_samples"):
664
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
665
+ num_examples = self._sklearn_object.min_samples
666
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
667
+ # LocalOutlierFactor expects n_neighbors <= n_samples
668
+ num_examples = self._sklearn_object.n_neighbors
669
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
670
+ else:
671
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
635
672
 
636
673
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
637
674
  # seen during the fit.
@@ -643,12 +680,14 @@ class RidgeCV(BaseTransformer):
643
680
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
644
681
  if self.sample_weight_col:
645
682
  output_df_columns_set -= set(self.sample_weight_col)
683
+
646
684
  # if the dimension of inferred output column names is correct; use it
647
685
  if len(expected_output_cols_list) == len(output_df_columns_set):
648
- return expected_output_cols_list
686
+ return expected_output_cols_list, output_df_pd
649
687
  # otherwise, use the sklearn estimator's output
650
688
  else:
651
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
689
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
690
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
652
691
 
653
692
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
654
693
  @telemetry.send_api_usage_telemetry(
@@ -694,7 +733,7 @@ class RidgeCV(BaseTransformer):
694
733
  drop_input_cols=self._drop_input_cols,
695
734
  expected_output_cols_type="float",
696
735
  )
697
- expected_output_cols = self._align_expected_output_names(
736
+ expected_output_cols, _ = self._align_expected_output(
698
737
  inference_method, dataset, expected_output_cols, output_cols_prefix
699
738
  )
700
739
 
@@ -760,7 +799,7 @@ class RidgeCV(BaseTransformer):
760
799
  drop_input_cols=self._drop_input_cols,
761
800
  expected_output_cols_type="float",
762
801
  )
763
- expected_output_cols = self._align_expected_output_names(
802
+ expected_output_cols, _ = self._align_expected_output(
764
803
  inference_method, dataset, expected_output_cols, output_cols_prefix
765
804
  )
766
805
  elif isinstance(dataset, pd.DataFrame):
@@ -823,7 +862,7 @@ class RidgeCV(BaseTransformer):
823
862
  drop_input_cols=self._drop_input_cols,
824
863
  expected_output_cols_type="float",
825
864
  )
826
- expected_output_cols = self._align_expected_output_names(
865
+ expected_output_cols, _ = self._align_expected_output(
827
866
  inference_method, dataset, expected_output_cols, output_cols_prefix
828
867
  )
829
868
 
@@ -888,7 +927,7 @@ class RidgeCV(BaseTransformer):
888
927
  drop_input_cols = self._drop_input_cols,
889
928
  expected_output_cols_type="float",
890
929
  )
891
- expected_output_cols = self._align_expected_output_names(
930
+ expected_output_cols, _ = self._align_expected_output(
892
931
  inference_method, dataset, expected_output_cols, output_cols_prefix
893
932
  )
894
933
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -661,12 +658,23 @@ class SGDClassifier(BaseTransformer):
661
658
  autogenerated=self._autogenerated,
662
659
  subproject=_SUBPROJECT,
663
660
  )
664
- output_result, fitted_estimator = model_trainer.train_fit_predict(
665
- drop_input_cols=self._drop_input_cols,
666
- expected_output_cols_list=(
667
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
668
- ),
661
+ expected_output_cols = (
662
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
669
663
  )
664
+ if isinstance(dataset, DataFrame):
665
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
666
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
667
+ )
668
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
669
+ drop_input_cols=self._drop_input_cols,
670
+ expected_output_cols_list=expected_output_cols,
671
+ example_output_pd_df=example_output_pd_df,
672
+ )
673
+ else:
674
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
675
+ drop_input_cols=self._drop_input_cols,
676
+ expected_output_cols_list=expected_output_cols,
677
+ )
670
678
  self._sklearn_object = fitted_estimator
671
679
  self._is_fitted = True
672
680
  return output_result
@@ -745,12 +753,41 @@ class SGDClassifier(BaseTransformer):
745
753
 
746
754
  return rv
747
755
 
748
- def _align_expected_output_names(
749
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
750
- ) -> List[str]:
756
+ def _align_expected_output(
757
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
758
+ ) -> Tuple[List[str], pd.DataFrame]:
759
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
760
+ and output dataframe with 1 line.
761
+ If the method is fit_predict, run 2 lines of data.
762
+ """
751
763
  # in case the inferred output column names dimension is different
752
764
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
753
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
765
+
766
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
767
+ # so change the minimum of number of rows to 2
768
+ num_examples = 2
769
+ statement_params = telemetry.get_function_usage_statement_params(
770
+ project=_PROJECT,
771
+ subproject=_SUBPROJECT,
772
+ function_name=telemetry.get_statement_params_full_func_name(
773
+ inspect.currentframe(), SGDClassifier.__class__.__name__
774
+ ),
775
+ api_calls=[Session.call],
776
+ custom_tags={"autogen": True} if self._autogenerated else None,
777
+ )
778
+ if output_cols_prefix == "fit_predict_":
779
+ if hasattr(self._sklearn_object, "n_clusters"):
780
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
781
+ num_examples = self._sklearn_object.n_clusters
782
+ elif hasattr(self._sklearn_object, "min_samples"):
783
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
784
+ num_examples = self._sklearn_object.min_samples
785
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
786
+ # LocalOutlierFactor expects n_neighbors <= n_samples
787
+ num_examples = self._sklearn_object.n_neighbors
788
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
789
+ else:
790
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
754
791
 
755
792
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
756
793
  # seen during the fit.
@@ -762,12 +799,14 @@ class SGDClassifier(BaseTransformer):
762
799
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
763
800
  if self.sample_weight_col:
764
801
  output_df_columns_set -= set(self.sample_weight_col)
802
+
765
803
  # if the dimension of inferred output column names is correct; use it
766
804
  if len(expected_output_cols_list) == len(output_df_columns_set):
767
- return expected_output_cols_list
805
+ return expected_output_cols_list, output_df_pd
768
806
  # otherwise, use the sklearn estimator's output
769
807
  else:
770
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
808
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
809
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
771
810
 
772
811
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
773
812
  @telemetry.send_api_usage_telemetry(
@@ -815,7 +854,7 @@ class SGDClassifier(BaseTransformer):
815
854
  drop_input_cols=self._drop_input_cols,
816
855
  expected_output_cols_type="float",
817
856
  )
818
- expected_output_cols = self._align_expected_output_names(
857
+ expected_output_cols, _ = self._align_expected_output(
819
858
  inference_method, dataset, expected_output_cols, output_cols_prefix
820
859
  )
821
860
 
@@ -883,7 +922,7 @@ class SGDClassifier(BaseTransformer):
883
922
  drop_input_cols=self._drop_input_cols,
884
923
  expected_output_cols_type="float",
885
924
  )
886
- expected_output_cols = self._align_expected_output_names(
925
+ expected_output_cols, _ = self._align_expected_output(
887
926
  inference_method, dataset, expected_output_cols, output_cols_prefix
888
927
  )
889
928
  elif isinstance(dataset, pd.DataFrame):
@@ -948,7 +987,7 @@ class SGDClassifier(BaseTransformer):
948
987
  drop_input_cols=self._drop_input_cols,
949
988
  expected_output_cols_type="float",
950
989
  )
951
- expected_output_cols = self._align_expected_output_names(
990
+ expected_output_cols, _ = self._align_expected_output(
952
991
  inference_method, dataset, expected_output_cols, output_cols_prefix
953
992
  )
954
993
 
@@ -1013,7 +1052,7 @@ class SGDClassifier(BaseTransformer):
1013
1052
  drop_input_cols = self._drop_input_cols,
1014
1053
  expected_output_cols_type="float",
1015
1054
  )
1016
- expected_output_cols = self._align_expected_output_names(
1055
+ expected_output_cols, _ = self._align_expected_output(
1017
1056
  inference_method, dataset, expected_output_cols, output_cols_prefix
1018
1057
  )
1019
1058