snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/semi_supervised/label_spreading.py

@@ -4,14 +4,12 @@
  #
  import inspect
  import os
- import posixpath
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
- from typing_extensions import TypeGuard
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
  from uuid import uuid4

  import cloudpickle as cp
- import pandas as pd
  import numpy as np
+ import pandas as pd
  from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
+ from snowflake.ml._internal.utils import identifier
  from snowflake.snowpark import DataFrame, Session
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
  from snowflake.ml.modeling._internal.transformer_protocols import (
- ModelTransformHandlers,
  BatchInferenceKwargsTypedDict,
  ScoreKwargsTypedDict
  )
@@ -510,12 +507,23 @@ class LabelSpreading(BaseTransformer):
  autogenerated=self._autogenerated,
  subproject=_SUBPROJECT,
  )
- output_result, fitted_estimator = model_trainer.train_fit_predict(
- drop_input_cols=self._drop_input_cols,
- expected_output_cols_list=(
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
- ),
+ expected_output_cols = (
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
  )
+ if isinstance(dataset, DataFrame):
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
+ )
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ example_output_pd_df=example_output_pd_df,
+ )
+ else:
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ )
  self._sklearn_object = fitted_estimator
  self._is_fitted = True
  return output_result
@@ -594,12 +602,41 @@ class LabelSpreading(BaseTransformer):

  return rv

- def _align_expected_output_names(
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
- ) -> List[str]:
+ def _align_expected_output(
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+ ) -> Tuple[List[str], pd.DataFrame]:
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+ and output dataframe with 1 line.
+ If the method is fit_predict, run 2 lines of data.
+ """
  # in case the inferred output column names dimension is different
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+ # so change the minimum of number of rows to 2
+ num_examples = 2
+ statement_params = telemetry.get_function_usage_statement_params(
+ project=_PROJECT,
+ subproject=_SUBPROJECT,
+ function_name=telemetry.get_statement_params_full_func_name(
+ inspect.currentframe(), LabelSpreading.__class__.__name__
+ ),
+ api_calls=[Session.call],
+ custom_tags={"autogen": True} if self._autogenerated else None,
+ )
+ if output_cols_prefix == "fit_predict_":
+ if hasattr(self._sklearn_object, "n_clusters"):
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+ num_examples = self._sklearn_object.n_clusters
+ elif hasattr(self._sklearn_object, "min_samples"):
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
+ num_examples = self._sklearn_object.min_samples
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+ # LocalOutlierFactor expects n_neighbors <= n_samples
+ num_examples = self._sklearn_object.n_neighbors
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+ else:
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
  # seen during the fit.
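
Aside on the sampling rule in the hunk above: several scikit-learn estimators refuse to run fit_predict on fewer rows than their hyperparameters imply, which is why num_examples is raised to n_clusters, min_samples, or n_neighbors when those attributes exist. A minimal standalone sketch against plain scikit-learn (illustrative only, not code from this package; the toy arrays are made up) that reproduces the constraint the comments describe:

    import numpy as np
    from sklearn.cluster import KMeans

    X_two_rows = np.array([[0.0, 0.0], [1.0, 1.0]])                # fewer rows than n_clusters
    X_three_rows = np.array([[0.0, 0.0], [1.0, 1.0], [5.0, 5.0]])  # exactly n_clusters rows

    try:
        KMeans(n_clusters=3, random_state=0).fit_predict(X_two_rows)
    except ValueError as exc:
        print("two rows rejected:", exc)   # scikit-learn requires n_samples >= n_clusters

    labels = KMeans(n_clusters=3, random_state=0).fit_predict(X_three_rows)
    print(labels.shape)                    # (3,): enough rows to run the method and inspect its output

Sampling a single row, as the previous implementation did, could therefore fail before any output columns were inferred.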
@@ -611,12 +648,14 @@ class LabelSpreading(BaseTransformer):
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
  if self.sample_weight_col:
  output_df_columns_set -= set(self.sample_weight_col)
+
  # if the dimension of inferred output column names is correct; use it
  if len(expected_output_cols_list) == len(output_df_columns_set):
- return expected_output_cols_list
+ return expected_output_cols_list, output_df_pd
  # otherwise, use the sklearn estimator's output
  else:
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]

  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
  @telemetry.send_api_usage_telemetry(
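
The fallback branch above exists because the number of output columns cannot always be guessed from the configured names; predict_proba, for example, returns one column per class. A short illustration with plain scikit-learn and pandas (a made-up three-class toy frame; only the variable names mirror the generated code):

    import pandas as pd
    from sklearn.linear_model import LogisticRegression

    X = pd.DataFrame({"F1": [0.0, 1.0, 2.0, 3.0, 4.0, 5.0], "F2": [1.0, 0.0, 1.0, 0.0, 1.0, 0.0]})
    y = [0, 1, 2, 0, 1, 2]  # three classes

    clf = LogisticRegression().fit(X, y)
    output_df_pd = pd.DataFrame(clf.predict_proba(X.head(1)))

    expected_output_cols_list = ["PREDICT_PROBA_OUTPUT"]  # one guessed name, but three columns come back
    if len(expected_output_cols_list) == output_df_pd.shape[1]:
        print("inferred names are usable as-is")
    else:
        # dimensions differ, so fall back to the estimator's own column layout,
        # which the rewritten helper now returns together with the sample output frame
        print("fall back to columns:", list(output_df_pd.columns))

Returning the aligned sample frame alongside the names is what lets the fit_predict path shown earlier forward it to the trainer as example_output_pd_df.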
@@ -664,7 +703,7 @@ class LabelSpreading(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -732,7 +771,7 @@ class LabelSpreading(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )
  elif isinstance(dataset, pd.DataFrame):
@@ -795,7 +834,7 @@ class LabelSpreading(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -860,7 +899,7 @@ class LabelSpreading(BaseTransformer):
  drop_input_cols = self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

snowflake/ml/modeling/svm/linear_svc.py

@@ -4,14 +4,12 @@
  #
  import inspect
  import os
- import posixpath
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
- from typing_extensions import TypeGuard
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
  from uuid import uuid4

  import cloudpickle as cp
- import pandas as pd
  import numpy as np
+ import pandas as pd
  from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
+ from snowflake.ml._internal.utils import identifier
  from snowflake.snowpark import DataFrame, Session
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
  from snowflake.ml.modeling._internal.transformer_protocols import (
- ModelTransformHandlers,
  BatchInferenceKwargsTypedDict,
  ScoreKwargsTypedDict
  )
@@ -566,12 +563,23 @@ class LinearSVC(BaseTransformer):
  autogenerated=self._autogenerated,
  subproject=_SUBPROJECT,
  )
- output_result, fitted_estimator = model_trainer.train_fit_predict(
- drop_input_cols=self._drop_input_cols,
- expected_output_cols_list=(
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
- ),
+ expected_output_cols = (
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
  )
+ if isinstance(dataset, DataFrame):
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
+ )
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ example_output_pd_df=example_output_pd_df,
+ )
+ else:
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ )
  self._sklearn_object = fitted_estimator
  self._is_fitted = True
  return output_result
@@ -650,12 +658,41 @@ class LinearSVC(BaseTransformer):

  return rv

- def _align_expected_output_names(
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
- ) -> List[str]:
+ def _align_expected_output(
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+ ) -> Tuple[List[str], pd.DataFrame]:
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+ and output dataframe with 1 line.
+ If the method is fit_predict, run 2 lines of data.
+ """
  # in case the inferred output column names dimension is different
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+ # so change the minimum of number of rows to 2
+ num_examples = 2
+ statement_params = telemetry.get_function_usage_statement_params(
+ project=_PROJECT,
+ subproject=_SUBPROJECT,
+ function_name=telemetry.get_statement_params_full_func_name(
+ inspect.currentframe(), LinearSVC.__class__.__name__
+ ),
+ api_calls=[Session.call],
+ custom_tags={"autogen": True} if self._autogenerated else None,
+ )
+ if output_cols_prefix == "fit_predict_":
+ if hasattr(self._sklearn_object, "n_clusters"):
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+ num_examples = self._sklearn_object.n_clusters
+ elif hasattr(self._sklearn_object, "min_samples"):
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
+ num_examples = self._sklearn_object.min_samples
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+ # LocalOutlierFactor expects n_neighbors <= n_samples
+ num_examples = self._sklearn_object.n_neighbors
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+ else:
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
  # seen during the fit.
@@ -667,12 +704,14 @@ class LinearSVC(BaseTransformer):
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
  if self.sample_weight_col:
  output_df_columns_set -= set(self.sample_weight_col)
+
  # if the dimension of inferred output column names is correct; use it
  if len(expected_output_cols_list) == len(output_df_columns_set):
- return expected_output_cols_list
+ return expected_output_cols_list, output_df_pd
  # otherwise, use the sklearn estimator's output
  else:
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]

  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
  @telemetry.send_api_usage_telemetry(
@@ -718,7 +757,7 @@ class LinearSVC(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -784,7 +823,7 @@ class LinearSVC(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )
  elif isinstance(dataset, pd.DataFrame):
@@ -849,7 +888,7 @@ class LinearSVC(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -914,7 +953,7 @@ class LinearSVC(BaseTransformer):
  drop_input_cols = self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

snowflake/ml/modeling/svm/linear_svr.py

@@ -4,14 +4,12 @@
  #
  import inspect
  import os
- import posixpath
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
- from typing_extensions import TypeGuard
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
  from uuid import uuid4

  import cloudpickle as cp
- import pandas as pd
  import numpy as np
+ import pandas as pd
  from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
+ from snowflake.ml._internal.utils import identifier
  from snowflake.snowpark import DataFrame, Session
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
  from snowflake.ml.modeling._internal.transformer_protocols import (
- ModelTransformHandlers,
  BatchInferenceKwargsTypedDict,
  ScoreKwargsTypedDict
  )
@@ -538,12 +535,23 @@ class LinearSVR(BaseTransformer):
  autogenerated=self._autogenerated,
  subproject=_SUBPROJECT,
  )
- output_result, fitted_estimator = model_trainer.train_fit_predict(
- drop_input_cols=self._drop_input_cols,
- expected_output_cols_list=(
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
- ),
+ expected_output_cols = (
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
  )
+ if isinstance(dataset, DataFrame):
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
+ )
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ example_output_pd_df=example_output_pd_df,
+ )
+ else:
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ )
  self._sklearn_object = fitted_estimator
  self._is_fitted = True
  return output_result
@@ -622,12 +630,41 @@ class LinearSVR(BaseTransformer):

  return rv

- def _align_expected_output_names(
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
- ) -> List[str]:
+ def _align_expected_output(
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+ ) -> Tuple[List[str], pd.DataFrame]:
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+ and output dataframe with 1 line.
+ If the method is fit_predict, run 2 lines of data.
+ """
  # in case the inferred output column names dimension is different
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+ # so change the minimum of number of rows to 2
+ num_examples = 2
+ statement_params = telemetry.get_function_usage_statement_params(
+ project=_PROJECT,
+ subproject=_SUBPROJECT,
+ function_name=telemetry.get_statement_params_full_func_name(
+ inspect.currentframe(), LinearSVR.__class__.__name__
+ ),
+ api_calls=[Session.call],
+ custom_tags={"autogen": True} if self._autogenerated else None,
+ )
+ if output_cols_prefix == "fit_predict_":
+ if hasattr(self._sklearn_object, "n_clusters"):
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+ num_examples = self._sklearn_object.n_clusters
+ elif hasattr(self._sklearn_object, "min_samples"):
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
+ num_examples = self._sklearn_object.min_samples
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+ # LocalOutlierFactor expects n_neighbors <= n_samples
+ num_examples = self._sklearn_object.n_neighbors
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+ else:
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
  # seen during the fit.
@@ -639,12 +676,14 @@ class LinearSVR(BaseTransformer):
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
  if self.sample_weight_col:
  output_df_columns_set -= set(self.sample_weight_col)
+
  # if the dimension of inferred output column names is correct; use it
  if len(expected_output_cols_list) == len(output_df_columns_set):
- return expected_output_cols_list
+ return expected_output_cols_list, output_df_pd
  # otherwise, use the sklearn estimator's output
  else:
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]

  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
  @telemetry.send_api_usage_telemetry(
@@ -690,7 +729,7 @@ class LinearSVR(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -756,7 +795,7 @@ class LinearSVR(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )
  elif isinstance(dataset, pd.DataFrame):
@@ -819,7 +858,7 @@ class LinearSVR(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -884,7 +923,7 @@ class LinearSVR(BaseTransformer):
  drop_input_cols = self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

snowflake/ml/modeling/svm/nu_svc.py

@@ -4,14 +4,12 @@
  #
  import inspect
  import os
- import posixpath
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
- from typing_extensions import TypeGuard
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
  from uuid import uuid4

  import cloudpickle as cp
- import pandas as pd
  import numpy as np
+ import pandas as pd
  from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
+ from snowflake.ml._internal.utils import identifier
  from snowflake.snowpark import DataFrame, Session
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
  from snowflake.ml.modeling._internal.transformer_protocols import (
- ModelTransformHandlers,
  BatchInferenceKwargsTypedDict,
  ScoreKwargsTypedDict
  )
@@ -572,12 +569,23 @@ class NuSVC(BaseTransformer):
  autogenerated=self._autogenerated,
  subproject=_SUBPROJECT,
  )
- output_result, fitted_estimator = model_trainer.train_fit_predict(
- drop_input_cols=self._drop_input_cols,
- expected_output_cols_list=(
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
- ),
+ expected_output_cols = (
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
  )
+ if isinstance(dataset, DataFrame):
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
+ )
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ example_output_pd_df=example_output_pd_df,
+ )
+ else:
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ )
  self._sklearn_object = fitted_estimator
  self._is_fitted = True
  return output_result
@@ -656,12 +664,41 @@ class NuSVC(BaseTransformer):

  return rv

- def _align_expected_output_names(
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
- ) -> List[str]:
+ def _align_expected_output(
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+ ) -> Tuple[List[str], pd.DataFrame]:
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+ and output dataframe with 1 line.
+ If the method is fit_predict, run 2 lines of data.
+ """
  # in case the inferred output column names dimension is different
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+ # so change the minimum of number of rows to 2
+ num_examples = 2
+ statement_params = telemetry.get_function_usage_statement_params(
+ project=_PROJECT,
+ subproject=_SUBPROJECT,
+ function_name=telemetry.get_statement_params_full_func_name(
+ inspect.currentframe(), NuSVC.__class__.__name__
+ ),
+ api_calls=[Session.call],
+ custom_tags={"autogen": True} if self._autogenerated else None,
+ )
+ if output_cols_prefix == "fit_predict_":
+ if hasattr(self._sklearn_object, "n_clusters"):
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+ num_examples = self._sklearn_object.n_clusters
+ elif hasattr(self._sklearn_object, "min_samples"):
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
+ num_examples = self._sklearn_object.min_samples
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+ # LocalOutlierFactor expects n_neighbors <= n_samples
+ num_examples = self._sklearn_object.n_neighbors
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+ else:
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
  # seen during the fit.
@@ -673,12 +710,14 @@ class NuSVC(BaseTransformer):
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
  if self.sample_weight_col:
  output_df_columns_set -= set(self.sample_weight_col)
+
  # if the dimension of inferred output column names is correct; use it
  if len(expected_output_cols_list) == len(output_df_columns_set):
- return expected_output_cols_list
+ return expected_output_cols_list, output_df_pd
  # otherwise, use the sklearn estimator's output
  else:
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]

  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
  @telemetry.send_api_usage_telemetry(
@@ -726,7 +765,7 @@ class NuSVC(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -794,7 +833,7 @@ class NuSVC(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )
  elif isinstance(dataset, pd.DataFrame):
@@ -859,7 +898,7 @@ class NuSVC(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -924,7 +963,7 @@ class NuSVC(BaseTransformer):
  drop_input_cols = self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )
