snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212) hide show
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -552,12 +549,23 @@ class MultiTaskLassoCV(BaseTransformer):
552
549
  autogenerated=self._autogenerated,
553
550
  subproject=_SUBPROJECT,
554
551
  )
555
- output_result, fitted_estimator = model_trainer.train_fit_predict(
556
- drop_input_cols=self._drop_input_cols,
557
- expected_output_cols_list=(
558
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
559
- ),
552
+ expected_output_cols = (
553
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
560
554
  )
555
+ if isinstance(dataset, DataFrame):
556
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
557
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
558
+ )
559
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
560
+ drop_input_cols=self._drop_input_cols,
561
+ expected_output_cols_list=expected_output_cols,
562
+ example_output_pd_df=example_output_pd_df,
563
+ )
564
+ else:
565
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
566
+ drop_input_cols=self._drop_input_cols,
567
+ expected_output_cols_list=expected_output_cols,
568
+ )
561
569
  self._sklearn_object = fitted_estimator
562
570
  self._is_fitted = True
563
571
  return output_result
@@ -636,12 +644,41 @@ class MultiTaskLassoCV(BaseTransformer):
636
644
 
637
645
  return rv
638
646
 
639
- def _align_expected_output_names(
640
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
641
- ) -> List[str]:
647
+ def _align_expected_output(
648
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
649
+ ) -> Tuple[List[str], pd.DataFrame]:
650
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
651
+ and output dataframe with 1 line.
652
+ If the method is fit_predict, run 2 lines of data.
653
+ """
642
654
  # in case the inferred output column names dimension is different
643
655
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
644
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
656
+
657
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
658
+ # so change the minimum of number of rows to 2
659
+ num_examples = 2
660
+ statement_params = telemetry.get_function_usage_statement_params(
661
+ project=_PROJECT,
662
+ subproject=_SUBPROJECT,
663
+ function_name=telemetry.get_statement_params_full_func_name(
664
+ inspect.currentframe(), MultiTaskLassoCV.__class__.__name__
665
+ ),
666
+ api_calls=[Session.call],
667
+ custom_tags={"autogen": True} if self._autogenerated else None,
668
+ )
669
+ if output_cols_prefix == "fit_predict_":
670
+ if hasattr(self._sklearn_object, "n_clusters"):
671
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
672
+ num_examples = self._sklearn_object.n_clusters
673
+ elif hasattr(self._sklearn_object, "min_samples"):
674
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
675
+ num_examples = self._sklearn_object.min_samples
676
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
677
+ # LocalOutlierFactor expects n_neighbors <= n_samples
678
+ num_examples = self._sklearn_object.n_neighbors
679
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
680
+ else:
681
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
645
682
 
646
683
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
647
684
  # seen during the fit.
@@ -653,12 +690,14 @@ class MultiTaskLassoCV(BaseTransformer):
653
690
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
654
691
  if self.sample_weight_col:
655
692
  output_df_columns_set -= set(self.sample_weight_col)
693
+
656
694
  # if the dimension of inferred output column names is correct; use it
657
695
  if len(expected_output_cols_list) == len(output_df_columns_set):
658
- return expected_output_cols_list
696
+ return expected_output_cols_list, output_df_pd
659
697
  # otherwise, use the sklearn estimator's output
660
698
  else:
661
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
699
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
700
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
662
701
 
663
702
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
664
703
  @telemetry.send_api_usage_telemetry(
@@ -704,7 +743,7 @@ class MultiTaskLassoCV(BaseTransformer):
704
743
  drop_input_cols=self._drop_input_cols,
705
744
  expected_output_cols_type="float",
706
745
  )
707
- expected_output_cols = self._align_expected_output_names(
746
+ expected_output_cols, _ = self._align_expected_output(
708
747
  inference_method, dataset, expected_output_cols, output_cols_prefix
709
748
  )
710
749
 
@@ -770,7 +809,7 @@ class MultiTaskLassoCV(BaseTransformer):
770
809
  drop_input_cols=self._drop_input_cols,
771
810
  expected_output_cols_type="float",
772
811
  )
773
- expected_output_cols = self._align_expected_output_names(
812
+ expected_output_cols, _ = self._align_expected_output(
774
813
  inference_method, dataset, expected_output_cols, output_cols_prefix
775
814
  )
776
815
  elif isinstance(dataset, pd.DataFrame):
@@ -833,7 +872,7 @@ class MultiTaskLassoCV(BaseTransformer):
833
872
  drop_input_cols=self._drop_input_cols,
834
873
  expected_output_cols_type="float",
835
874
  )
836
- expected_output_cols = self._align_expected_output_names(
875
+ expected_output_cols, _ = self._align_expected_output(
837
876
  inference_method, dataset, expected_output_cols, output_cols_prefix
838
877
  )
839
878
 
@@ -898,7 +937,7 @@ class MultiTaskLassoCV(BaseTransformer):
898
937
  drop_input_cols = self._drop_input_cols,
899
938
  expected_output_cols_type="float",
900
939
  )
901
- expected_output_cols = self._align_expected_output_names(
940
+ expected_output_cols, _ = self._align_expected_output(
902
941
  inference_method, dataset, expected_output_cols, output_cols_prefix
903
942
  )
904
943
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -500,12 +497,23 @@ class OrthogonalMatchingPursuit(BaseTransformer):
500
497
  autogenerated=self._autogenerated,
501
498
  subproject=_SUBPROJECT,
502
499
  )
503
- output_result, fitted_estimator = model_trainer.train_fit_predict(
504
- drop_input_cols=self._drop_input_cols,
505
- expected_output_cols_list=(
506
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
507
- ),
500
+ expected_output_cols = (
501
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
508
502
  )
503
+ if isinstance(dataset, DataFrame):
504
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
505
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
506
+ )
507
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
508
+ drop_input_cols=self._drop_input_cols,
509
+ expected_output_cols_list=expected_output_cols,
510
+ example_output_pd_df=example_output_pd_df,
511
+ )
512
+ else:
513
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
514
+ drop_input_cols=self._drop_input_cols,
515
+ expected_output_cols_list=expected_output_cols,
516
+ )
509
517
  self._sklearn_object = fitted_estimator
510
518
  self._is_fitted = True
511
519
  return output_result
@@ -584,12 +592,41 @@ class OrthogonalMatchingPursuit(BaseTransformer):
584
592
 
585
593
  return rv
586
594
 
587
- def _align_expected_output_names(
588
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
589
- ) -> List[str]:
595
+ def _align_expected_output(
596
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
597
+ ) -> Tuple[List[str], pd.DataFrame]:
598
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
599
+ and output dataframe with 1 line.
600
+ If the method is fit_predict, run 2 lines of data.
601
+ """
590
602
  # in case the inferred output column names dimension is different
591
603
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
592
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
604
+
605
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
606
+ # so change the minimum of number of rows to 2
607
+ num_examples = 2
608
+ statement_params = telemetry.get_function_usage_statement_params(
609
+ project=_PROJECT,
610
+ subproject=_SUBPROJECT,
611
+ function_name=telemetry.get_statement_params_full_func_name(
612
+ inspect.currentframe(), OrthogonalMatchingPursuit.__class__.__name__
613
+ ),
614
+ api_calls=[Session.call],
615
+ custom_tags={"autogen": True} if self._autogenerated else None,
616
+ )
617
+ if output_cols_prefix == "fit_predict_":
618
+ if hasattr(self._sklearn_object, "n_clusters"):
619
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
620
+ num_examples = self._sklearn_object.n_clusters
621
+ elif hasattr(self._sklearn_object, "min_samples"):
622
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
623
+ num_examples = self._sklearn_object.min_samples
624
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
625
+ # LocalOutlierFactor expects n_neighbors <= n_samples
626
+ num_examples = self._sklearn_object.n_neighbors
627
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
628
+ else:
629
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
593
630
 
594
631
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
595
632
  # seen during the fit.
@@ -601,12 +638,14 @@ class OrthogonalMatchingPursuit(BaseTransformer):
601
638
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
602
639
  if self.sample_weight_col:
603
640
  output_df_columns_set -= set(self.sample_weight_col)
641
+
604
642
  # if the dimension of inferred output column names is correct; use it
605
643
  if len(expected_output_cols_list) == len(output_df_columns_set):
606
- return expected_output_cols_list
644
+ return expected_output_cols_list, output_df_pd
607
645
  # otherwise, use the sklearn estimator's output
608
646
  else:
609
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
647
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
648
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
610
649
 
611
650
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
612
651
  @telemetry.send_api_usage_telemetry(
@@ -652,7 +691,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
652
691
  drop_input_cols=self._drop_input_cols,
653
692
  expected_output_cols_type="float",
654
693
  )
655
- expected_output_cols = self._align_expected_output_names(
694
+ expected_output_cols, _ = self._align_expected_output(
656
695
  inference_method, dataset, expected_output_cols, output_cols_prefix
657
696
  )
658
697
 
@@ -718,7 +757,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
718
757
  drop_input_cols=self._drop_input_cols,
719
758
  expected_output_cols_type="float",
720
759
  )
721
- expected_output_cols = self._align_expected_output_names(
760
+ expected_output_cols, _ = self._align_expected_output(
722
761
  inference_method, dataset, expected_output_cols, output_cols_prefix
723
762
  )
724
763
  elif isinstance(dataset, pd.DataFrame):
@@ -781,7 +820,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
781
820
  drop_input_cols=self._drop_input_cols,
782
821
  expected_output_cols_type="float",
783
822
  )
784
- expected_output_cols = self._align_expected_output_names(
823
+ expected_output_cols, _ = self._align_expected_output(
785
824
  inference_method, dataset, expected_output_cols, output_cols_prefix
786
825
  )
787
826
 
@@ -846,7 +885,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
846
885
  drop_input_cols = self._drop_input_cols,
847
886
  expected_output_cols_type="float",
848
887
  )
849
- expected_output_cols = self._align_expected_output_names(
888
+ expected_output_cols, _ = self._align_expected_output(
850
889
  inference_method, dataset, expected_output_cols, output_cols_prefix
851
890
  )
852
891
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -574,12 +571,23 @@ class PassiveAggressiveClassifier(BaseTransformer):
574
571
  autogenerated=self._autogenerated,
575
572
  subproject=_SUBPROJECT,
576
573
  )
577
- output_result, fitted_estimator = model_trainer.train_fit_predict(
578
- drop_input_cols=self._drop_input_cols,
579
- expected_output_cols_list=(
580
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
581
- ),
574
+ expected_output_cols = (
575
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
582
576
  )
577
+ if isinstance(dataset, DataFrame):
578
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
579
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
580
+ )
581
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
582
+ drop_input_cols=self._drop_input_cols,
583
+ expected_output_cols_list=expected_output_cols,
584
+ example_output_pd_df=example_output_pd_df,
585
+ )
586
+ else:
587
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
588
+ drop_input_cols=self._drop_input_cols,
589
+ expected_output_cols_list=expected_output_cols,
590
+ )
583
591
  self._sklearn_object = fitted_estimator
584
592
  self._is_fitted = True
585
593
  return output_result
@@ -658,12 +666,41 @@ class PassiveAggressiveClassifier(BaseTransformer):
658
666
 
659
667
  return rv
660
668
 
661
- def _align_expected_output_names(
662
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
663
- ) -> List[str]:
669
+ def _align_expected_output(
670
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
671
+ ) -> Tuple[List[str], pd.DataFrame]:
672
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
673
+ and output dataframe with 1 line.
674
+ If the method is fit_predict, run 2 lines of data.
675
+ """
664
676
  # in case the inferred output column names dimension is different
665
677
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
666
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
678
+
679
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
680
+ # so change the minimum of number of rows to 2
681
+ num_examples = 2
682
+ statement_params = telemetry.get_function_usage_statement_params(
683
+ project=_PROJECT,
684
+ subproject=_SUBPROJECT,
685
+ function_name=telemetry.get_statement_params_full_func_name(
686
+ inspect.currentframe(), PassiveAggressiveClassifier.__class__.__name__
687
+ ),
688
+ api_calls=[Session.call],
689
+ custom_tags={"autogen": True} if self._autogenerated else None,
690
+ )
691
+ if output_cols_prefix == "fit_predict_":
692
+ if hasattr(self._sklearn_object, "n_clusters"):
693
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
694
+ num_examples = self._sklearn_object.n_clusters
695
+ elif hasattr(self._sklearn_object, "min_samples"):
696
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
697
+ num_examples = self._sklearn_object.min_samples
698
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
699
+ # LocalOutlierFactor expects n_neighbors <= n_samples
700
+ num_examples = self._sklearn_object.n_neighbors
701
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
702
+ else:
703
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
667
704
 
668
705
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
669
706
  # seen during the fit.
@@ -675,12 +712,14 @@ class PassiveAggressiveClassifier(BaseTransformer):
675
712
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
676
713
  if self.sample_weight_col:
677
714
  output_df_columns_set -= set(self.sample_weight_col)
715
+
678
716
  # if the dimension of inferred output column names is correct; use it
679
717
  if len(expected_output_cols_list) == len(output_df_columns_set):
680
- return expected_output_cols_list
718
+ return expected_output_cols_list, output_df_pd
681
719
  # otherwise, use the sklearn estimator's output
682
720
  else:
683
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
721
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
722
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
684
723
 
685
724
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
686
725
  @telemetry.send_api_usage_telemetry(
@@ -726,7 +765,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
726
765
  drop_input_cols=self._drop_input_cols,
727
766
  expected_output_cols_type="float",
728
767
  )
729
- expected_output_cols = self._align_expected_output_names(
768
+ expected_output_cols, _ = self._align_expected_output(
730
769
  inference_method, dataset, expected_output_cols, output_cols_prefix
731
770
  )
732
771
 
@@ -792,7 +831,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
792
831
  drop_input_cols=self._drop_input_cols,
793
832
  expected_output_cols_type="float",
794
833
  )
795
- expected_output_cols = self._align_expected_output_names(
834
+ expected_output_cols, _ = self._align_expected_output(
796
835
  inference_method, dataset, expected_output_cols, output_cols_prefix
797
836
  )
798
837
  elif isinstance(dataset, pd.DataFrame):
@@ -857,7 +896,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
857
896
  drop_input_cols=self._drop_input_cols,
858
897
  expected_output_cols_type="float",
859
898
  )
860
- expected_output_cols = self._align_expected_output_names(
899
+ expected_output_cols, _ = self._align_expected_output(
861
900
  inference_method, dataset, expected_output_cols, output_cols_prefix
862
901
  )
863
902
 
@@ -922,7 +961,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
922
961
  drop_input_cols = self._drop_input_cols,
923
962
  expected_output_cols_type="float",
924
963
  )
925
- expected_output_cols = self._align_expected_output_names(
964
+ expected_output_cols, _ = self._align_expected_output(
926
965
  inference_method, dataset, expected_output_cols, output_cols_prefix
927
966
  )
928
967
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -560,12 +557,23 @@ class PassiveAggressiveRegressor(BaseTransformer):
560
557
  autogenerated=self._autogenerated,
561
558
  subproject=_SUBPROJECT,
562
559
  )
563
- output_result, fitted_estimator = model_trainer.train_fit_predict(
564
- drop_input_cols=self._drop_input_cols,
565
- expected_output_cols_list=(
566
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
567
- ),
560
+ expected_output_cols = (
561
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
568
562
  )
563
+ if isinstance(dataset, DataFrame):
564
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
565
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
566
+ )
567
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
568
+ drop_input_cols=self._drop_input_cols,
569
+ expected_output_cols_list=expected_output_cols,
570
+ example_output_pd_df=example_output_pd_df,
571
+ )
572
+ else:
573
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
574
+ drop_input_cols=self._drop_input_cols,
575
+ expected_output_cols_list=expected_output_cols,
576
+ )
569
577
  self._sklearn_object = fitted_estimator
570
578
  self._is_fitted = True
571
579
  return output_result
@@ -644,12 +652,41 @@ class PassiveAggressiveRegressor(BaseTransformer):
644
652
 
645
653
  return rv
646
654
 
647
- def _align_expected_output_names(
648
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
649
- ) -> List[str]:
655
+ def _align_expected_output(
656
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
657
+ ) -> Tuple[List[str], pd.DataFrame]:
658
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
659
+ and output dataframe with 1 line.
660
+ If the method is fit_predict, run 2 lines of data.
661
+ """
650
662
  # in case the inferred output column names dimension is different
651
663
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
652
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
664
+
665
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
666
+ # so change the minimum of number of rows to 2
667
+ num_examples = 2
668
+ statement_params = telemetry.get_function_usage_statement_params(
669
+ project=_PROJECT,
670
+ subproject=_SUBPROJECT,
671
+ function_name=telemetry.get_statement_params_full_func_name(
672
+ inspect.currentframe(), PassiveAggressiveRegressor.__class__.__name__
673
+ ),
674
+ api_calls=[Session.call],
675
+ custom_tags={"autogen": True} if self._autogenerated else None,
676
+ )
677
+ if output_cols_prefix == "fit_predict_":
678
+ if hasattr(self._sklearn_object, "n_clusters"):
679
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
680
+ num_examples = self._sklearn_object.n_clusters
681
+ elif hasattr(self._sklearn_object, "min_samples"):
682
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
683
+ num_examples = self._sklearn_object.min_samples
684
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
685
+ # LocalOutlierFactor expects n_neighbors <= n_samples
686
+ num_examples = self._sklearn_object.n_neighbors
687
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
688
+ else:
689
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
653
690
 
654
691
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
655
692
  # seen during the fit.
@@ -661,12 +698,14 @@ class PassiveAggressiveRegressor(BaseTransformer):
661
698
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
662
699
  if self.sample_weight_col:
663
700
  output_df_columns_set -= set(self.sample_weight_col)
701
+
664
702
  # if the dimension of inferred output column names is correct; use it
665
703
  if len(expected_output_cols_list) == len(output_df_columns_set):
666
- return expected_output_cols_list
704
+ return expected_output_cols_list, output_df_pd
667
705
  # otherwise, use the sklearn estimator's output
668
706
  else:
669
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
707
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
708
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
670
709
 
671
710
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
672
711
  @telemetry.send_api_usage_telemetry(
@@ -712,7 +751,7 @@ class PassiveAggressiveRegressor(BaseTransformer):
712
751
  drop_input_cols=self._drop_input_cols,
713
752
  expected_output_cols_type="float",
714
753
  )
715
- expected_output_cols = self._align_expected_output_names(
754
+ expected_output_cols, _ = self._align_expected_output(
716
755
  inference_method, dataset, expected_output_cols, output_cols_prefix
717
756
  )
718
757
 
@@ -778,7 +817,7 @@ class PassiveAggressiveRegressor(BaseTransformer):
778
817
  drop_input_cols=self._drop_input_cols,
779
818
  expected_output_cols_type="float",
780
819
  )
781
- expected_output_cols = self._align_expected_output_names(
820
+ expected_output_cols, _ = self._align_expected_output(
782
821
  inference_method, dataset, expected_output_cols, output_cols_prefix
783
822
  )
784
823
  elif isinstance(dataset, pd.DataFrame):
@@ -841,7 +880,7 @@ class PassiveAggressiveRegressor(BaseTransformer):
841
880
  drop_input_cols=self._drop_input_cols,
842
881
  expected_output_cols_type="float",
843
882
  )
844
- expected_output_cols = self._align_expected_output_names(
883
+ expected_output_cols, _ = self._align_expected_output(
845
884
  inference_method, dataset, expected_output_cols, output_cols_prefix
846
885
  )
847
886
 
@@ -906,7 +945,7 @@ class PassiveAggressiveRegressor(BaseTransformer):
906
945
  drop_input_cols = self._drop_input_cols,
907
946
  expected_output_cols_type="float",
908
947
  )
909
- expected_output_cols = self._align_expected_output_names(
948
+ expected_output_cols, _ = self._align_expected_output(
910
949
  inference_method, dataset, expected_output_cols, output_cols_prefix
911
950
  )
912
951