snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/decomposition/sparse_pca.py

@@ -4,14 +4,12 @@
 #
 import inspect
 import os
-import posixpath
-from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
-from typing_extensions import TypeGuard
+from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
 from uuid import uuid4

 import cloudpickle as cp
-import pandas as pd
 import numpy as np
+import pandas as pd
 from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
 from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
-from snowflake.ml._internal.utils import pkg_version_utils, identifier
+from snowflake.ml._internal.utils import identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.transformer_protocols import (
-    ModelTransformHandlers,
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
@@ -534,12 +531,23 @@ class SparsePCA(BaseTransformer):
             autogenerated=self._autogenerated,
             subproject=_SUBPROJECT,
         )
-        output_result, fitted_estimator = model_trainer.train_fit_predict(
-            drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=(
-                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
-            ),
+        expected_output_cols = (
+            self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
         )
+        if isinstance(dataset, DataFrame):
+            expected_output_cols, example_output_pd_df = self._align_expected_output(
+                "fit_predict", dataset, expected_output_cols, output_cols_prefix
+            )
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+                example_output_pd_df=example_output_pd_df,
+            )
+        else:
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+            )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result
@@ -620,12 +628,41 @@ class SparsePCA(BaseTransformer):

         return rv

-    def _align_expected_output_names(
-        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
-    ) -> List[str]:
+    def _align_expected_output(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+    ) -> Tuple[List[str], pd.DataFrame]:
+        """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+        and output dataframe with 1 line.
+        If the method is fit_predict, run 2 lines of data.
+        """
         # in case the inferred output column names dimension is different
         # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
-        sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+        # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+        # so change the minimum of number of rows to 2
+        num_examples = 2
+        statement_params = telemetry.get_function_usage_statement_params(
+            project=_PROJECT,
+            subproject=_SUBPROJECT,
+            function_name=telemetry.get_statement_params_full_func_name(
+                inspect.currentframe(), SparsePCA.__class__.__name__
+            ),
+            api_calls=[Session.call],
+            custom_tags={"autogen": True} if self._autogenerated else None,
+        )
+        if output_cols_prefix == "fit_predict_":
+            if hasattr(self._sklearn_object, "n_clusters"):
+                # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+                num_examples = self._sklearn_object.n_clusters
+            elif hasattr(self._sklearn_object, "min_samples"):
+                # OPTICS default min_samples 5, which requires at least 5 lines of data
+                num_examples = self._sklearn_object.min_samples
+            elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+                # LocalOutlierFactor expects n_neighbors <= n_samples
+                num_examples = self._sklearn_object.n_neighbors
+            sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+        else:
+            sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

         # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
         # seen during the fit.
@@ -637,12 +674,14 @@ class SparsePCA(BaseTransformer):
         output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
         if self.sample_weight_col:
             output_df_columns_set -= set(self.sample_weight_col)
+
         # if the dimension of inferred output column names is correct; use it
         if len(expected_output_cols_list) == len(output_df_columns_set):
-            return expected_output_cols_list
+            return expected_output_cols_list, output_df_pd
         # otherwise, use the sklearn estimator's output
         else:
-            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            return expected_output_cols_list, output_df_pd[expected_output_cols_list]

     @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -688,7 +727,7 @@ class SparsePCA(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -754,7 +793,7 @@ class SparsePCA(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -817,7 +856,7 @@ class SparsePCA(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -882,7 +921,7 @@ class SparsePCA(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

snowflake/ml/modeling/decomposition/truncated_svd.py

@@ -4,14 +4,12 @@
 #
 import inspect
 import os
-import posixpath
-from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
-from typing_extensions import TypeGuard
+from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
 from uuid import uuid4

 import cloudpickle as cp
-import pandas as pd
 import numpy as np
+import pandas as pd
 from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
 from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
-from snowflake.ml._internal.utils import pkg_version_utils, identifier
+from snowflake.ml._internal.utils import identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.transformer_protocols import (
-    ModelTransformHandlers,
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
@@ -515,12 +512,23 @@ class TruncatedSVD(BaseTransformer):
             autogenerated=self._autogenerated,
             subproject=_SUBPROJECT,
         )
-        output_result, fitted_estimator = model_trainer.train_fit_predict(
-            drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=(
-                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
-            ),
+        expected_output_cols = (
+            self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
         )
+        if isinstance(dataset, DataFrame):
+            expected_output_cols, example_output_pd_df = self._align_expected_output(
+                "fit_predict", dataset, expected_output_cols, output_cols_prefix
+            )
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+                example_output_pd_df=example_output_pd_df,
+            )
+        else:
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+            )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result
@@ -601,12 +609,41 @@ class TruncatedSVD(BaseTransformer):

         return rv

-    def _align_expected_output_names(
-        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
-    ) -> List[str]:
+    def _align_expected_output(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+    ) -> Tuple[List[str], pd.DataFrame]:
+        """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+        and output dataframe with 1 line.
+        If the method is fit_predict, run 2 lines of data.
+        """
         # in case the inferred output column names dimension is different
         # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
-        sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+        # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+        # so change the minimum of number of rows to 2
+        num_examples = 2
+        statement_params = telemetry.get_function_usage_statement_params(
+            project=_PROJECT,
+            subproject=_SUBPROJECT,
+            function_name=telemetry.get_statement_params_full_func_name(
+                inspect.currentframe(), TruncatedSVD.__class__.__name__
+            ),
+            api_calls=[Session.call],
+            custom_tags={"autogen": True} if self._autogenerated else None,
+        )
+        if output_cols_prefix == "fit_predict_":
+            if hasattr(self._sklearn_object, "n_clusters"):
+                # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+                num_examples = self._sklearn_object.n_clusters
+            elif hasattr(self._sklearn_object, "min_samples"):
+                # OPTICS default min_samples 5, which requires at least 5 lines of data
+                num_examples = self._sklearn_object.min_samples
+            elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+                # LocalOutlierFactor expects n_neighbors <= n_samples
+                num_examples = self._sklearn_object.n_neighbors
+            sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+        else:
+            sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

         # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
         # seen during the fit.
@@ -618,12 +655,14 @@ class TruncatedSVD(BaseTransformer):
         output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
         if self.sample_weight_col:
             output_df_columns_set -= set(self.sample_weight_col)
+
         # if the dimension of inferred output column names is correct; use it
         if len(expected_output_cols_list) == len(output_df_columns_set):
-            return expected_output_cols_list
+            return expected_output_cols_list, output_df_pd
         # otherwise, use the sklearn estimator's output
         else:
-            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            return expected_output_cols_list, output_df_pd[expected_output_cols_list]

     @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -669,7 +708,7 @@ class TruncatedSVD(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -735,7 +774,7 @@ class TruncatedSVD(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -798,7 +837,7 @@ class TruncatedSVD(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -863,7 +902,7 @@ class TruncatedSVD(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py

@@ -4,14 +4,12 @@
 #
 import inspect
 import os
-import posixpath
-from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
-from typing_extensions import TypeGuard
+from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
 from uuid import uuid4

 import cloudpickle as cp
-import pandas as pd
 import numpy as np
+import pandas as pd
 from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
 from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
-from snowflake.ml._internal.utils import pkg_version_utils, identifier
+from snowflake.ml._internal.utils import identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.transformer_protocols import (
-    ModelTransformHandlers,
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
@@ -532,12 +529,23 @@ class LinearDiscriminantAnalysis(BaseTransformer):
             autogenerated=self._autogenerated,
             subproject=_SUBPROJECT,
         )
-        output_result, fitted_estimator = model_trainer.train_fit_predict(
-            drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=(
-                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
-            ),
+        expected_output_cols = (
+            self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
         )
+        if isinstance(dataset, DataFrame):
+            expected_output_cols, example_output_pd_df = self._align_expected_output(
+                "fit_predict", dataset, expected_output_cols, output_cols_prefix
+            )
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+                example_output_pd_df=example_output_pd_df,
+            )
+        else:
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+            )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result
@@ -618,12 +626,41 @@ class LinearDiscriminantAnalysis(BaseTransformer):

         return rv

-    def _align_expected_output_names(
-        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
-    ) -> List[str]:
+    def _align_expected_output(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+    ) -> Tuple[List[str], pd.DataFrame]:
+        """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+        and output dataframe with 1 line.
+        If the method is fit_predict, run 2 lines of data.
+        """
         # in case the inferred output column names dimension is different
         # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
-        sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+        # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+        # so change the minimum of number of rows to 2
+        num_examples = 2
+        statement_params = telemetry.get_function_usage_statement_params(
+            project=_PROJECT,
+            subproject=_SUBPROJECT,
+            function_name=telemetry.get_statement_params_full_func_name(
+                inspect.currentframe(), LinearDiscriminantAnalysis.__class__.__name__
+            ),
+            api_calls=[Session.call],
+            custom_tags={"autogen": True} if self._autogenerated else None,
+        )
+        if output_cols_prefix == "fit_predict_":
+            if hasattr(self._sklearn_object, "n_clusters"):
+                # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+                num_examples = self._sklearn_object.n_clusters
+            elif hasattr(self._sklearn_object, "min_samples"):
+                # OPTICS default min_samples 5, which requires at least 5 lines of data
+                num_examples = self._sklearn_object.min_samples
+            elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+                # LocalOutlierFactor expects n_neighbors <= n_samples
+                num_examples = self._sklearn_object.n_neighbors
+            sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+        else:
+            sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

         # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
         # seen during the fit.
@@ -635,12 +672,14 @@ class LinearDiscriminantAnalysis(BaseTransformer):
         output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
         if self.sample_weight_col:
             output_df_columns_set -= set(self.sample_weight_col)
+
         # if the dimension of inferred output column names is correct; use it
         if len(expected_output_cols_list) == len(output_df_columns_set):
-            return expected_output_cols_list
+            return expected_output_cols_list, output_df_pd
         # otherwise, use the sklearn estimator's output
         else:
-            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            return expected_output_cols_list, output_df_pd[expected_output_cols_list]

     @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -688,7 +727,7 @@ class LinearDiscriminantAnalysis(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -756,7 +795,7 @@ class LinearDiscriminantAnalysis(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -821,7 +860,7 @@ class LinearDiscriminantAnalysis(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -886,7 +925,7 @@ class LinearDiscriminantAnalysis(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py

@@ -4,14 +4,12 @@
 #
 import inspect
 import os
-import posixpath
-from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
-from typing_extensions import TypeGuard
+from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
 from uuid import uuid4

 import cloudpickle as cp
-import pandas as pd
 import numpy as np
+import pandas as pd
 from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
 from snowflake.ml._internal import telemetry
 from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
 from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
-from snowflake.ml._internal.utils import pkg_version_utils, identifier
+from snowflake.ml._internal.utils import identifier
 from snowflake.snowpark import DataFrame, Session
 from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
 from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
 from snowflake.ml.modeling._internal.transformer_protocols import (
-    ModelTransformHandlers,
     BatchInferenceKwargsTypedDict,
     ScoreKwargsTypedDict
 )
@@ -492,12 +489,23 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
             autogenerated=self._autogenerated,
             subproject=_SUBPROJECT,
         )
-        output_result, fitted_estimator = model_trainer.train_fit_predict(
-            drop_input_cols=self._drop_input_cols,
-            expected_output_cols_list=(
-                self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
-            ),
+        expected_output_cols = (
+            self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
         )
+        if isinstance(dataset, DataFrame):
+            expected_output_cols, example_output_pd_df = self._align_expected_output(
+                "fit_predict", dataset, expected_output_cols, output_cols_prefix
+            )
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+                example_output_pd_df=example_output_pd_df,
+            )
+        else:
+            output_result, fitted_estimator = model_trainer.train_fit_predict(
+                drop_input_cols=self._drop_input_cols,
+                expected_output_cols_list=expected_output_cols,
+            )
         self._sklearn_object = fitted_estimator
         self._is_fitted = True
         return output_result
@@ -576,12 +584,41 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):

         return rv

-    def _align_expected_output_names(
-        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
-    ) -> List[str]:
+    def _align_expected_output(
+        self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+    ) -> Tuple[List[str], pd.DataFrame]:
+        """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+        and output dataframe with 1 line.
+        If the method is fit_predict, run 2 lines of data.
+        """
         # in case the inferred output column names dimension is different
         # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
-        sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+        # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+        # so change the minimum of number of rows to 2
+        num_examples = 2
+        statement_params = telemetry.get_function_usage_statement_params(
+            project=_PROJECT,
+            subproject=_SUBPROJECT,
+            function_name=telemetry.get_statement_params_full_func_name(
+                inspect.currentframe(), QuadraticDiscriminantAnalysis.__class__.__name__
+            ),
+            api_calls=[Session.call],
+            custom_tags={"autogen": True} if self._autogenerated else None,
+        )
+        if output_cols_prefix == "fit_predict_":
+            if hasattr(self._sklearn_object, "n_clusters"):
+                # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+                num_examples = self._sklearn_object.n_clusters
+            elif hasattr(self._sklearn_object, "min_samples"):
+                # OPTICS default min_samples 5, which requires at least 5 lines of data
+                num_examples = self._sklearn_object.min_samples
+            elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+                # LocalOutlierFactor expects n_neighbors <= n_samples
+                num_examples = self._sklearn_object.n_neighbors
+            sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+        else:
+            sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

         # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
         # seen during the fit.
@@ -593,12 +630,14 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
         output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
         if self.sample_weight_col:
             output_df_columns_set -= set(self.sample_weight_col)
+
         # if the dimension of inferred output column names is correct; use it
         if len(expected_output_cols_list) == len(output_df_columns_set):
-            return expected_output_cols_list
+            return expected_output_cols_list, output_df_pd
         # otherwise, use the sklearn estimator's output
         else:
-            return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+            return expected_output_cols_list, output_df_pd[expected_output_cols_list]

     @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
     @telemetry.send_api_usage_telemetry(
@@ -646,7 +685,7 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -714,7 +753,7 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )
         elif isinstance(dataset, pd.DataFrame):
@@ -779,7 +818,7 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
                 drop_input_cols=self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

@@ -844,7 +883,7 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
                 drop_input_cols = self._drop_input_cols,
                 expected_output_cols_type="float",
             )
-            expected_output_cols = self._align_expected_output_names(
+            expected_output_cols, _ = self._align_expected_output(
                 inference_method, dataset, expected_output_cols, output_cols_prefix
             )

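The same `fit_predict` branching appears in all four files above: only the Snowpark `DataFrame` path calls the alignment helper and forwards its example frame through the new `example_output_pd_df` argument of `train_fit_predict`, while the pandas path passes only the column list. A hedged sketch of that calling pattern, with `align` and `trainer` standing in for the package internals; the real code dispatches on the Snowpark `DataFrame` type, whereas this sketch inverts the check so it depends only on pandas.

from typing import Callable, List
import pandas as pd

# Sketch only: mirrors the call pattern shown in the hunks, not the package's classes.
def run_fit_predict(dataset, expected_output_cols: List[str], align: Callable, trainer):
    if isinstance(dataset, pd.DataFrame):
        # Local pandas path: no example frame is produced or forwarded.
        return trainer.train_fit_predict(
            drop_input_cols=False,
            expected_output_cols_list=expected_output_cols,
        )
    # Snowpark path: the helper returns possibly corrected column names plus a
    # small pandas frame built from the sampled rows; both are forwarded.
    expected_output_cols, example_output_pd_df = align(
        "fit_predict", dataset, expected_output_cols, "fit_predict_"
    )
    return trainer.train_fit_predict(
        drop_input_cols=False,
        expected_output_cols_list=expected_output_cols,
        example_output_pd_df=example_output_pd_df,
    )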