snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -627,12 +624,23 @@ class LogisticRegressionCV(BaseTransformer):
627
624
  autogenerated=self._autogenerated,
628
625
  subproject=_SUBPROJECT,
629
626
  )
630
- output_result, fitted_estimator = model_trainer.train_fit_predict(
631
- drop_input_cols=self._drop_input_cols,
632
- expected_output_cols_list=(
633
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
634
- ),
627
+ expected_output_cols = (
628
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
635
629
  )
630
+ if isinstance(dataset, DataFrame):
631
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
632
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
633
+ )
634
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
635
+ drop_input_cols=self._drop_input_cols,
636
+ expected_output_cols_list=expected_output_cols,
637
+ example_output_pd_df=example_output_pd_df,
638
+ )
639
+ else:
640
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
641
+ drop_input_cols=self._drop_input_cols,
642
+ expected_output_cols_list=expected_output_cols,
643
+ )
636
644
  self._sklearn_object = fitted_estimator
637
645
  self._is_fitted = True
638
646
  return output_result
@@ -711,12 +719,41 @@ class LogisticRegressionCV(BaseTransformer):
711
719
 
712
720
  return rv
713
721
 
714
- def _align_expected_output_names(
715
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
716
- ) -> List[str]:
722
+ def _align_expected_output(
723
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
724
+ ) -> Tuple[List[str], pd.DataFrame]:
725
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
726
+ and output dataframe with 1 line.
727
+ If the method is fit_predict, run 2 lines of data.
728
+ """
717
729
  # in case the inferred output column names dimension is different
718
730
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
719
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
731
+
732
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
733
+ # so change the minimum of number of rows to 2
734
+ num_examples = 2
735
+ statement_params = telemetry.get_function_usage_statement_params(
736
+ project=_PROJECT,
737
+ subproject=_SUBPROJECT,
738
+ function_name=telemetry.get_statement_params_full_func_name(
739
+ inspect.currentframe(), LogisticRegressionCV.__class__.__name__
740
+ ),
741
+ api_calls=[Session.call],
742
+ custom_tags={"autogen": True} if self._autogenerated else None,
743
+ )
744
+ if output_cols_prefix == "fit_predict_":
745
+ if hasattr(self._sklearn_object, "n_clusters"):
746
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
747
+ num_examples = self._sklearn_object.n_clusters
748
+ elif hasattr(self._sklearn_object, "min_samples"):
749
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
750
+ num_examples = self._sklearn_object.min_samples
751
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
752
+ # LocalOutlierFactor expects n_neighbors <= n_samples
753
+ num_examples = self._sklearn_object.n_neighbors
754
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
755
+ else:
756
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
720
757
 
721
758
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
722
759
  # seen during the fit.
@@ -728,12 +765,14 @@ class LogisticRegressionCV(BaseTransformer):
728
765
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
729
766
  if self.sample_weight_col:
730
767
  output_df_columns_set -= set(self.sample_weight_col)
768
+
731
769
  # if the dimension of inferred output column names is correct; use it
732
770
  if len(expected_output_cols_list) == len(output_df_columns_set):
733
- return expected_output_cols_list
771
+ return expected_output_cols_list, output_df_pd
734
772
  # otherwise, use the sklearn estimator's output
735
773
  else:
736
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
774
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
775
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
737
776
 
738
777
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
739
778
  @telemetry.send_api_usage_telemetry(
@@ -781,7 +820,7 @@ class LogisticRegressionCV(BaseTransformer):
781
820
  drop_input_cols=self._drop_input_cols,
782
821
  expected_output_cols_type="float",
783
822
  )
784
- expected_output_cols = self._align_expected_output_names(
823
+ expected_output_cols, _ = self._align_expected_output(
785
824
  inference_method, dataset, expected_output_cols, output_cols_prefix
786
825
  )
787
826
 
@@ -849,7 +888,7 @@ class LogisticRegressionCV(BaseTransformer):
849
888
  drop_input_cols=self._drop_input_cols,
850
889
  expected_output_cols_type="float",
851
890
  )
852
- expected_output_cols = self._align_expected_output_names(
891
+ expected_output_cols, _ = self._align_expected_output(
853
892
  inference_method, dataset, expected_output_cols, output_cols_prefix
854
893
  )
855
894
  elif isinstance(dataset, pd.DataFrame):
@@ -914,7 +953,7 @@ class LogisticRegressionCV(BaseTransformer):
914
953
  drop_input_cols=self._drop_input_cols,
915
954
  expected_output_cols_type="float",
916
955
  )
917
- expected_output_cols = self._align_expected_output_names(
956
+ expected_output_cols, _ = self._align_expected_output(
918
957
  inference_method, dataset, expected_output_cols, output_cols_prefix
919
958
  )
920
959
 
@@ -979,7 +1018,7 @@ class LogisticRegressionCV(BaseTransformer):
979
1018
  drop_input_cols = self._drop_input_cols,
980
1019
  expected_output_cols_type="float",
981
1020
  )
982
- expected_output_cols = self._align_expected_output_names(
1021
+ expected_output_cols, _ = self._align_expected_output(
983
1022
  inference_method, dataset, expected_output_cols, output_cols_prefix
984
1023
  )
985
1024
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -525,12 +522,23 @@ class MultiTaskElasticNet(BaseTransformer):
525
522
  autogenerated=self._autogenerated,
526
523
  subproject=_SUBPROJECT,
527
524
  )
528
- output_result, fitted_estimator = model_trainer.train_fit_predict(
529
- drop_input_cols=self._drop_input_cols,
530
- expected_output_cols_list=(
531
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
532
- ),
525
+ expected_output_cols = (
526
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
533
527
  )
528
+ if isinstance(dataset, DataFrame):
529
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
530
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
531
+ )
532
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
533
+ drop_input_cols=self._drop_input_cols,
534
+ expected_output_cols_list=expected_output_cols,
535
+ example_output_pd_df=example_output_pd_df,
536
+ )
537
+ else:
538
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
539
+ drop_input_cols=self._drop_input_cols,
540
+ expected_output_cols_list=expected_output_cols,
541
+ )
534
542
  self._sklearn_object = fitted_estimator
535
543
  self._is_fitted = True
536
544
  return output_result
@@ -609,12 +617,41 @@ class MultiTaskElasticNet(BaseTransformer):
609
617
 
610
618
  return rv
611
619
 
612
- def _align_expected_output_names(
613
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
614
- ) -> List[str]:
620
+ def _align_expected_output(
621
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
622
+ ) -> Tuple[List[str], pd.DataFrame]:
623
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
624
+ and output dataframe with 1 line.
625
+ If the method is fit_predict, run 2 lines of data.
626
+ """
615
627
  # in case the inferred output column names dimension is different
616
628
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
617
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
629
+
630
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
631
+ # so change the minimum of number of rows to 2
632
+ num_examples = 2
633
+ statement_params = telemetry.get_function_usage_statement_params(
634
+ project=_PROJECT,
635
+ subproject=_SUBPROJECT,
636
+ function_name=telemetry.get_statement_params_full_func_name(
637
+ inspect.currentframe(), MultiTaskElasticNet.__class__.__name__
638
+ ),
639
+ api_calls=[Session.call],
640
+ custom_tags={"autogen": True} if self._autogenerated else None,
641
+ )
642
+ if output_cols_prefix == "fit_predict_":
643
+ if hasattr(self._sklearn_object, "n_clusters"):
644
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
645
+ num_examples = self._sklearn_object.n_clusters
646
+ elif hasattr(self._sklearn_object, "min_samples"):
647
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
648
+ num_examples = self._sklearn_object.min_samples
649
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
650
+ # LocalOutlierFactor expects n_neighbors <= n_samples
651
+ num_examples = self._sklearn_object.n_neighbors
652
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
653
+ else:
654
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
618
655
 
619
656
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
620
657
  # seen during the fit.
@@ -626,12 +663,14 @@ class MultiTaskElasticNet(BaseTransformer):
626
663
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
627
664
  if self.sample_weight_col:
628
665
  output_df_columns_set -= set(self.sample_weight_col)
666
+
629
667
  # if the dimension of inferred output column names is correct; use it
630
668
  if len(expected_output_cols_list) == len(output_df_columns_set):
631
- return expected_output_cols_list
669
+ return expected_output_cols_list, output_df_pd
632
670
  # otherwise, use the sklearn estimator's output
633
671
  else:
634
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
672
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
673
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
635
674
 
636
675
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
637
676
  @telemetry.send_api_usage_telemetry(
@@ -677,7 +716,7 @@ class MultiTaskElasticNet(BaseTransformer):
677
716
  drop_input_cols=self._drop_input_cols,
678
717
  expected_output_cols_type="float",
679
718
  )
680
- expected_output_cols = self._align_expected_output_names(
719
+ expected_output_cols, _ = self._align_expected_output(
681
720
  inference_method, dataset, expected_output_cols, output_cols_prefix
682
721
  )
683
722
 
@@ -743,7 +782,7 @@ class MultiTaskElasticNet(BaseTransformer):
743
782
  drop_input_cols=self._drop_input_cols,
744
783
  expected_output_cols_type="float",
745
784
  )
746
- expected_output_cols = self._align_expected_output_names(
785
+ expected_output_cols, _ = self._align_expected_output(
747
786
  inference_method, dataset, expected_output_cols, output_cols_prefix
748
787
  )
749
788
  elif isinstance(dataset, pd.DataFrame):
@@ -806,7 +845,7 @@ class MultiTaskElasticNet(BaseTransformer):
806
845
  drop_input_cols=self._drop_input_cols,
807
846
  expected_output_cols_type="float",
808
847
  )
809
- expected_output_cols = self._align_expected_output_names(
848
+ expected_output_cols, _ = self._align_expected_output(
810
849
  inference_method, dataset, expected_output_cols, output_cols_prefix
811
850
  )
812
851
 
@@ -871,7 +910,7 @@ class MultiTaskElasticNet(BaseTransformer):
871
910
  drop_input_cols = self._drop_input_cols,
872
911
  expected_output_cols_type="float",
873
912
  )
874
- expected_output_cols = self._align_expected_output_names(
913
+ expected_output_cols, _ = self._align_expected_output(
875
914
  inference_method, dataset, expected_output_cols, output_cols_prefix
876
915
  )
877
916
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -566,12 +563,23 @@ class MultiTaskElasticNetCV(BaseTransformer):
566
563
  autogenerated=self._autogenerated,
567
564
  subproject=_SUBPROJECT,
568
565
  )
569
- output_result, fitted_estimator = model_trainer.train_fit_predict(
570
- drop_input_cols=self._drop_input_cols,
571
- expected_output_cols_list=(
572
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
573
- ),
566
+ expected_output_cols = (
567
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
574
568
  )
569
+ if isinstance(dataset, DataFrame):
570
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
571
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
572
+ )
573
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
574
+ drop_input_cols=self._drop_input_cols,
575
+ expected_output_cols_list=expected_output_cols,
576
+ example_output_pd_df=example_output_pd_df,
577
+ )
578
+ else:
579
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
580
+ drop_input_cols=self._drop_input_cols,
581
+ expected_output_cols_list=expected_output_cols,
582
+ )
575
583
  self._sklearn_object = fitted_estimator
576
584
  self._is_fitted = True
577
585
  return output_result
@@ -650,12 +658,41 @@ class MultiTaskElasticNetCV(BaseTransformer):
650
658
 
651
659
  return rv
652
660
 
653
- def _align_expected_output_names(
654
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
655
- ) -> List[str]:
661
+ def _align_expected_output(
662
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
663
+ ) -> Tuple[List[str], pd.DataFrame]:
664
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
665
+ and output dataframe with 1 line.
666
+ If the method is fit_predict, run 2 lines of data.
667
+ """
656
668
  # in case the inferred output column names dimension is different
657
669
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
658
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
670
+
671
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
672
+ # so change the minimum of number of rows to 2
673
+ num_examples = 2
674
+ statement_params = telemetry.get_function_usage_statement_params(
675
+ project=_PROJECT,
676
+ subproject=_SUBPROJECT,
677
+ function_name=telemetry.get_statement_params_full_func_name(
678
+ inspect.currentframe(), MultiTaskElasticNetCV.__class__.__name__
679
+ ),
680
+ api_calls=[Session.call],
681
+ custom_tags={"autogen": True} if self._autogenerated else None,
682
+ )
683
+ if output_cols_prefix == "fit_predict_":
684
+ if hasattr(self._sklearn_object, "n_clusters"):
685
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
686
+ num_examples = self._sklearn_object.n_clusters
687
+ elif hasattr(self._sklearn_object, "min_samples"):
688
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
689
+ num_examples = self._sklearn_object.min_samples
690
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
691
+ # LocalOutlierFactor expects n_neighbors <= n_samples
692
+ num_examples = self._sklearn_object.n_neighbors
693
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
694
+ else:
695
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
659
696
 
660
697
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
661
698
  # seen during the fit.
@@ -667,12 +704,14 @@ class MultiTaskElasticNetCV(BaseTransformer):
667
704
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
668
705
  if self.sample_weight_col:
669
706
  output_df_columns_set -= set(self.sample_weight_col)
707
+
670
708
  # if the dimension of inferred output column names is correct; use it
671
709
  if len(expected_output_cols_list) == len(output_df_columns_set):
672
- return expected_output_cols_list
710
+ return expected_output_cols_list, output_df_pd
673
711
  # otherwise, use the sklearn estimator's output
674
712
  else:
675
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
713
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
714
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
676
715
 
677
716
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
678
717
  @telemetry.send_api_usage_telemetry(
@@ -718,7 +757,7 @@ class MultiTaskElasticNetCV(BaseTransformer):
718
757
  drop_input_cols=self._drop_input_cols,
719
758
  expected_output_cols_type="float",
720
759
  )
721
- expected_output_cols = self._align_expected_output_names(
760
+ expected_output_cols, _ = self._align_expected_output(
722
761
  inference_method, dataset, expected_output_cols, output_cols_prefix
723
762
  )
724
763
 
@@ -784,7 +823,7 @@ class MultiTaskElasticNetCV(BaseTransformer):
784
823
  drop_input_cols=self._drop_input_cols,
785
824
  expected_output_cols_type="float",
786
825
  )
787
- expected_output_cols = self._align_expected_output_names(
826
+ expected_output_cols, _ = self._align_expected_output(
788
827
  inference_method, dataset, expected_output_cols, output_cols_prefix
789
828
  )
790
829
  elif isinstance(dataset, pd.DataFrame):
@@ -847,7 +886,7 @@ class MultiTaskElasticNetCV(BaseTransformer):
847
886
  drop_input_cols=self._drop_input_cols,
848
887
  expected_output_cols_type="float",
849
888
  )
850
- expected_output_cols = self._align_expected_output_names(
889
+ expected_output_cols, _ = self._align_expected_output(
851
890
  inference_method, dataset, expected_output_cols, output_cols_prefix
852
891
  )
853
892
 
@@ -912,7 +951,7 @@ class MultiTaskElasticNetCV(BaseTransformer):
912
951
  drop_input_cols = self._drop_input_cols,
913
952
  expected_output_cols_type="float",
914
953
  )
915
- expected_output_cols = self._align_expected_output_names(
954
+ expected_output_cols, _ = self._align_expected_output(
916
955
  inference_method, dataset, expected_output_cols, output_cols_prefix
917
956
  )
918
957
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -517,12 +514,23 @@ class MultiTaskLasso(BaseTransformer):
517
514
  autogenerated=self._autogenerated,
518
515
  subproject=_SUBPROJECT,
519
516
  )
520
- output_result, fitted_estimator = model_trainer.train_fit_predict(
521
- drop_input_cols=self._drop_input_cols,
522
- expected_output_cols_list=(
523
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
524
- ),
517
+ expected_output_cols = (
518
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
525
519
  )
520
+ if isinstance(dataset, DataFrame):
521
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
522
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
523
+ )
524
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
525
+ drop_input_cols=self._drop_input_cols,
526
+ expected_output_cols_list=expected_output_cols,
527
+ example_output_pd_df=example_output_pd_df,
528
+ )
529
+ else:
530
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
531
+ drop_input_cols=self._drop_input_cols,
532
+ expected_output_cols_list=expected_output_cols,
533
+ )
526
534
  self._sklearn_object = fitted_estimator
527
535
  self._is_fitted = True
528
536
  return output_result
@@ -601,12 +609,41 @@ class MultiTaskLasso(BaseTransformer):
601
609
 
602
610
  return rv
603
611
 
604
- def _align_expected_output_names(
605
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
606
- ) -> List[str]:
612
+ def _align_expected_output(
613
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
614
+ ) -> Tuple[List[str], pd.DataFrame]:
615
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
616
+ and output dataframe with 1 line.
617
+ If the method is fit_predict, run 2 lines of data.
618
+ """
607
619
  # in case the inferred output column names dimension is different
608
620
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
609
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
621
+
622
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
623
+ # so change the minimum of number of rows to 2
624
+ num_examples = 2
625
+ statement_params = telemetry.get_function_usage_statement_params(
626
+ project=_PROJECT,
627
+ subproject=_SUBPROJECT,
628
+ function_name=telemetry.get_statement_params_full_func_name(
629
+ inspect.currentframe(), MultiTaskLasso.__class__.__name__
630
+ ),
631
+ api_calls=[Session.call],
632
+ custom_tags={"autogen": True} if self._autogenerated else None,
633
+ )
634
+ if output_cols_prefix == "fit_predict_":
635
+ if hasattr(self._sklearn_object, "n_clusters"):
636
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
637
+ num_examples = self._sklearn_object.n_clusters
638
+ elif hasattr(self._sklearn_object, "min_samples"):
639
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
640
+ num_examples = self._sklearn_object.min_samples
641
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
642
+ # LocalOutlierFactor expects n_neighbors <= n_samples
643
+ num_examples = self._sklearn_object.n_neighbors
644
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
645
+ else:
646
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
610
647
 
611
648
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
612
649
  # seen during the fit.
@@ -618,12 +655,14 @@ class MultiTaskLasso(BaseTransformer):
618
655
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
619
656
  if self.sample_weight_col:
620
657
  output_df_columns_set -= set(self.sample_weight_col)
658
+
621
659
  # if the dimension of inferred output column names is correct; use it
622
660
  if len(expected_output_cols_list) == len(output_df_columns_set):
623
- return expected_output_cols_list
661
+ return expected_output_cols_list, output_df_pd
624
662
  # otherwise, use the sklearn estimator's output
625
663
  else:
626
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
664
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
665
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
627
666
 
628
667
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
629
668
  @telemetry.send_api_usage_telemetry(
@@ -669,7 +708,7 @@ class MultiTaskLasso(BaseTransformer):
669
708
  drop_input_cols=self._drop_input_cols,
670
709
  expected_output_cols_type="float",
671
710
  )
672
- expected_output_cols = self._align_expected_output_names(
711
+ expected_output_cols, _ = self._align_expected_output(
673
712
  inference_method, dataset, expected_output_cols, output_cols_prefix
674
713
  )
675
714
 
@@ -735,7 +774,7 @@ class MultiTaskLasso(BaseTransformer):
735
774
  drop_input_cols=self._drop_input_cols,
736
775
  expected_output_cols_type="float",
737
776
  )
738
- expected_output_cols = self._align_expected_output_names(
777
+ expected_output_cols, _ = self._align_expected_output(
739
778
  inference_method, dataset, expected_output_cols, output_cols_prefix
740
779
  )
741
780
  elif isinstance(dataset, pd.DataFrame):
@@ -798,7 +837,7 @@ class MultiTaskLasso(BaseTransformer):
798
837
  drop_input_cols=self._drop_input_cols,
799
838
  expected_output_cols_type="float",
800
839
  )
801
- expected_output_cols = self._align_expected_output_names(
840
+ expected_output_cols, _ = self._align_expected_output(
802
841
  inference_method, dataset, expected_output_cols, output_cols_prefix
803
842
  )
804
843
 
@@ -863,7 +902,7 @@ class MultiTaskLasso(BaseTransformer):
863
902
  drop_input_cols = self._drop_input_cols,
864
903
  expected_output_cols_type="float",
865
904
  )
866
- expected_output_cols = self._align_expected_output_names(
905
+ expected_output_cols, _ = self._align_expected_output(
867
906
  inference_method, dataset, expected_output_cols, output_cols_prefix
868
907
  )
869
908