snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/linear_model/sgd_one_class_svm.py
@@ -4,14 +4,12 @@
  #
  import inspect
  import os
- import posixpath
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
- from typing_extensions import TypeGuard
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
  from uuid import uuid4

  import cloudpickle as cp
- import pandas as pd
  import numpy as np
+ import pandas as pd
  from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
+ from snowflake.ml._internal.utils import identifier
  from snowflake.snowpark import DataFrame, Session
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
  from snowflake.ml.modeling._internal.transformer_protocols import (
- ModelTransformHandlers,
  BatchInferenceKwargsTypedDict,
  ScoreKwargsTypedDict
  )
@@ -563,12 +560,23 @@ class SGDOneClassSVM(BaseTransformer):
  autogenerated=self._autogenerated,
  subproject=_SUBPROJECT,
  )
- output_result, fitted_estimator = model_trainer.train_fit_predict(
- drop_input_cols=self._drop_input_cols,
- expected_output_cols_list=(
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
- ),
+ expected_output_cols = (
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
  )
+ if isinstance(dataset, DataFrame):
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
+ )
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ example_output_pd_df=example_output_pd_df,
+ )
+ else:
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ )
  self._sklearn_object = fitted_estimator
  self._is_fitted = True
  return output_result
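
The fit_predict rewrite above is repeated in each of the four estimator files in this diff: for Snowpark DataFrame inputs the wrapper now asks _align_expected_output for both the output column names and a small example output frame, and forwards that frame to the trainer. A minimal stand-alone sketch of that dispatch follows; the stub trainer and helper names are illustrative only, not the package's internal classes.

# Hedged sketch of the new fit_predict dispatch (stub code, not library internals).
from typing import Any, List, Optional

import pandas as pd


def train_fit_predict_stub(
    expected_output_cols_list: List[str],
    example_output_pd_df: Optional[pd.DataFrame] = None,
) -> pd.DataFrame:
    # A real trainer runs fit_predict inside Snowflake; the example frame, when
    # provided, tells it what the output schema should look like.
    cols = list(example_output_pd_df.columns) if example_output_pd_df is not None else expected_output_cols_list
    return pd.DataFrame(columns=cols)


def fit_predict_dispatch(dataset: Any, expected_output_cols: List[str], align) -> pd.DataFrame:
    if not isinstance(dataset, pd.DataFrame):
        # Snowpark path: also fetch an example output frame for schema inference.
        expected_output_cols, example_output_pd_df = align("fit_predict", dataset, expected_output_cols)
        return train_fit_predict_stub(expected_output_cols, example_output_pd_df)
    # pandas path: unchanged, no example frame is needed.
    return train_fit_predict_stub(expected_output_cols)
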
@@ -647,12 +655,41 @@ class SGDOneClassSVM(BaseTransformer):

  return rv

- def _align_expected_output_names(
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
- ) -> List[str]:
+ def _align_expected_output(
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+ ) -> Tuple[List[str], pd.DataFrame]:
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+ and output dataframe with 1 line.
+ If the method is fit_predict, run 2 lines of data.
+ """
  # in case the inferred output column names dimension is different
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+ # so change the minimum of number of rows to 2
+ num_examples = 2
+ statement_params = telemetry.get_function_usage_statement_params(
+ project=_PROJECT,
+ subproject=_SUBPROJECT,
+ function_name=telemetry.get_statement_params_full_func_name(
+ inspect.currentframe(), SGDOneClassSVM.__class__.__name__
+ ),
+ api_calls=[Session.call],
+ custom_tags={"autogen": True} if self._autogenerated else None,
+ )
+ if output_cols_prefix == "fit_predict_":
+ if hasattr(self._sklearn_object, "n_clusters"):
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+ num_examples = self._sklearn_object.n_clusters
+ elif hasattr(self._sklearn_object, "min_samples"):
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
+ num_examples = self._sklearn_object.min_samples
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+ # LocalOutlierFactor expects n_neighbors <= n_samples
+ num_examples = self._sklearn_object.n_neighbors
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+ else:
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
  # seen during the fit.
@@ -664,12 +701,14 @@ class SGDOneClassSVM(BaseTransformer):
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
  if self.sample_weight_col:
  output_df_columns_set -= set(self.sample_weight_col)
+
  # if the dimension of inferred output column names is correct; use it
  if len(expected_output_cols_list) == len(output_df_columns_set):
- return expected_output_cols_list
+ return expected_output_cols_list, output_df_pd
  # otherwise, use the sklearn estimator's output
  else:
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]

  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
  @telemetry.send_api_usage_telemetry(
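
The comments added in this hunk explain why fit_predict may need more than one sample row: MinCovDet and BayesianGaussianMixture need at least 2, BisectingKMeans needs n_clusters, OPTICS needs min_samples, and LocalOutlierFactor needs n_neighbors rows. The self-contained sketch below reproduces that row-count heuristic outside the wrapper class; the estimator arguments are plain stand-in objects rather than the wrapped scikit-learn estimators.

# Hedged sketch of the row-count heuristic applied before sampling rows from the
# Snowpark DataFrame; stand-in objects replace the real estimators.
from types import SimpleNamespace
from typing import Any


def rows_needed(estimator: Any, output_cols_prefix: str) -> int:
    if output_cols_prefix != "fit_predict_":
        return 1  # one row is enough just to infer output column names
    num_examples = 2  # MinCovDet / BayesianGaussianMixture need at least 2 rows
    if hasattr(estimator, "n_clusters"):
        num_examples = estimator.n_clusters      # e.g. BisectingKMeans
    elif hasattr(estimator, "min_samples"):
        num_examples = estimator.min_samples     # e.g. OPTICS (default 5)
    elif hasattr(estimator, "n_neighbors") and hasattr(estimator, "n_samples"):
        num_examples = estimator.n_neighbors     # LocalOutlierFactor
    return num_examples


print(rows_needed(SimpleNamespace(n_clusters=8), "fit_predict_"))   # 8
print(rows_needed(SimpleNamespace(min_samples=5), "fit_predict_"))  # 5
print(rows_needed(SimpleNamespace(), "predict_"))                   # 1
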
@@ -715,7 +754,7 @@ class SGDOneClassSVM(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -781,7 +820,7 @@ class SGDOneClassSVM(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )
  elif isinstance(dataset, pd.DataFrame):
@@ -846,7 +885,7 @@ class SGDOneClassSVM(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -913,7 +952,7 @@ class SGDOneClassSVM(BaseTransformer):
  drop_input_cols = self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )
snowflake/ml/modeling/linear_model/sgd_regressor.py
@@ -4,14 +4,12 @@
  #
  import inspect
  import os
- import posixpath
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
- from typing_extensions import TypeGuard
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
  from uuid import uuid4

  import cloudpickle as cp
- import pandas as pd
  import numpy as np
+ import pandas as pd
  from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
+ from snowflake.ml._internal.utils import identifier
  from snowflake.snowpark import DataFrame, Session
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
  from snowflake.ml.modeling._internal.transformer_protocols import (
- ModelTransformHandlers,
  BatchInferenceKwargsTypedDict,
  ScoreKwargsTypedDict
  )
@@ -627,12 +624,23 @@ class SGDRegressor(BaseTransformer):
  autogenerated=self._autogenerated,
  subproject=_SUBPROJECT,
  )
- output_result, fitted_estimator = model_trainer.train_fit_predict(
- drop_input_cols=self._drop_input_cols,
- expected_output_cols_list=(
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
- ),
+ expected_output_cols = (
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
  )
+ if isinstance(dataset, DataFrame):
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
+ )
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ example_output_pd_df=example_output_pd_df,
+ )
+ else:
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ )
  self._sklearn_object = fitted_estimator
  self._is_fitted = True
  return output_result
@@ -711,12 +719,41 @@ class SGDRegressor(BaseTransformer):

  return rv

- def _align_expected_output_names(
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
- ) -> List[str]:
+ def _align_expected_output(
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+ ) -> Tuple[List[str], pd.DataFrame]:
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+ and output dataframe with 1 line.
+ If the method is fit_predict, run 2 lines of data.
+ """
  # in case the inferred output column names dimension is different
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+ # so change the minimum of number of rows to 2
+ num_examples = 2
+ statement_params = telemetry.get_function_usage_statement_params(
+ project=_PROJECT,
+ subproject=_SUBPROJECT,
+ function_name=telemetry.get_statement_params_full_func_name(
+ inspect.currentframe(), SGDRegressor.__class__.__name__
+ ),
+ api_calls=[Session.call],
+ custom_tags={"autogen": True} if self._autogenerated else None,
+ )
+ if output_cols_prefix == "fit_predict_":
+ if hasattr(self._sklearn_object, "n_clusters"):
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+ num_examples = self._sklearn_object.n_clusters
+ elif hasattr(self._sklearn_object, "min_samples"):
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
+ num_examples = self._sklearn_object.min_samples
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+ # LocalOutlierFactor expects n_neighbors <= n_samples
+ num_examples = self._sklearn_object.n_neighbors
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+ else:
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
  # seen during the fit.
@@ -728,12 +765,14 @@ class SGDRegressor(BaseTransformer):
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
  if self.sample_weight_col:
  output_df_columns_set -= set(self.sample_weight_col)
+
  # if the dimension of inferred output column names is correct; use it
  if len(expected_output_cols_list) == len(output_df_columns_set):
- return expected_output_cols_list
+ return expected_output_cols_list, output_df_pd
  # otherwise, use the sklearn estimator's output
  else:
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]

  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
  @telemetry.send_api_usage_telemetry(
@@ -779,7 +818,7 @@ class SGDRegressor(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -845,7 +884,7 @@ class SGDRegressor(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )
  elif isinstance(dataset, pd.DataFrame):
@@ -908,7 +947,7 @@ class SGDRegressor(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -973,7 +1012,7 @@ class SGDRegressor(BaseTransformer):
  drop_input_cols = self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )
snowflake/ml/modeling/linear_model/theil_sen_regressor.py
@@ -4,14 +4,12 @@
  #
  import inspect
  import os
- import posixpath
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
- from typing_extensions import TypeGuard
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
  from uuid import uuid4

  import cloudpickle as cp
- import pandas as pd
  import numpy as np
+ import pandas as pd
  from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
+ from snowflake.ml._internal.utils import identifier
  from snowflake.snowpark import DataFrame, Session
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
  from snowflake.ml.modeling._internal.transformer_protocols import (
- ModelTransformHandlers,
  BatchInferenceKwargsTypedDict,
  ScoreKwargsTypedDict
  )
@@ -529,12 +526,23 @@ class TheilSenRegressor(BaseTransformer):
  autogenerated=self._autogenerated,
  subproject=_SUBPROJECT,
  )
- output_result, fitted_estimator = model_trainer.train_fit_predict(
- drop_input_cols=self._drop_input_cols,
- expected_output_cols_list=(
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
- ),
+ expected_output_cols = (
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
  )
+ if isinstance(dataset, DataFrame):
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
+ )
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ example_output_pd_df=example_output_pd_df,
+ )
+ else:
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ )
  self._sklearn_object = fitted_estimator
  self._is_fitted = True
  return output_result
@@ -613,12 +621,41 @@ class TheilSenRegressor(BaseTransformer):

  return rv

- def _align_expected_output_names(
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
- ) -> List[str]:
+ def _align_expected_output(
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+ ) -> Tuple[List[str], pd.DataFrame]:
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+ and output dataframe with 1 line.
+ If the method is fit_predict, run 2 lines of data.
+ """
  # in case the inferred output column names dimension is different
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+ # so change the minimum of number of rows to 2
+ num_examples = 2
+ statement_params = telemetry.get_function_usage_statement_params(
+ project=_PROJECT,
+ subproject=_SUBPROJECT,
+ function_name=telemetry.get_statement_params_full_func_name(
+ inspect.currentframe(), TheilSenRegressor.__class__.__name__
+ ),
+ api_calls=[Session.call],
+ custom_tags={"autogen": True} if self._autogenerated else None,
+ )
+ if output_cols_prefix == "fit_predict_":
+ if hasattr(self._sklearn_object, "n_clusters"):
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+ num_examples = self._sklearn_object.n_clusters
+ elif hasattr(self._sklearn_object, "min_samples"):
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
+ num_examples = self._sklearn_object.min_samples
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+ # LocalOutlierFactor expects n_neighbors <= n_samples
+ num_examples = self._sklearn_object.n_neighbors
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+ else:
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
  # seen during the fit.
@@ -630,12 +667,14 @@ class TheilSenRegressor(BaseTransformer):
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
  if self.sample_weight_col:
  output_df_columns_set -= set(self.sample_weight_col)
+
  # if the dimension of inferred output column names is correct; use it
  if len(expected_output_cols_list) == len(output_df_columns_set):
- return expected_output_cols_list
+ return expected_output_cols_list, output_df_pd
  # otherwise, use the sklearn estimator's output
  else:
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]

  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
  @telemetry.send_api_usage_telemetry(
@@ -681,7 +720,7 @@ class TheilSenRegressor(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -747,7 +786,7 @@ class TheilSenRegressor(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )
  elif isinstance(dataset, pd.DataFrame):
@@ -810,7 +849,7 @@ class TheilSenRegressor(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -875,7 +914,7 @@ class TheilSenRegressor(BaseTransformer):
  drop_input_cols = self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )
snowflake/ml/modeling/linear_model/tweedie_regressor.py
@@ -4,14 +4,12 @@
  #
  import inspect
  import os
- import posixpath
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
- from typing_extensions import TypeGuard
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
  from uuid import uuid4

  import cloudpickle as cp
- import pandas as pd
  import numpy as np
+ import pandas as pd
  from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
+ from snowflake.ml._internal.utils import identifier
  from snowflake.snowpark import DataFrame, Session
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
  from snowflake.ml.modeling._internal.transformer_protocols import (
- ModelTransformHandlers,
  BatchInferenceKwargsTypedDict,
  ScoreKwargsTypedDict
  )
@@ -555,12 +552,23 @@ class TweedieRegressor(BaseTransformer):
  autogenerated=self._autogenerated,
  subproject=_SUBPROJECT,
  )
- output_result, fitted_estimator = model_trainer.train_fit_predict(
- drop_input_cols=self._drop_input_cols,
- expected_output_cols_list=(
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
- ),
+ expected_output_cols = (
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
  )
+ if isinstance(dataset, DataFrame):
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
+ )
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ example_output_pd_df=example_output_pd_df,
+ )
+ else:
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ )
  self._sklearn_object = fitted_estimator
  self._is_fitted = True
  return output_result
@@ -639,12 +647,41 @@ class TweedieRegressor(BaseTransformer):

  return rv

- def _align_expected_output_names(
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
- ) -> List[str]:
+ def _align_expected_output(
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+ ) -> Tuple[List[str], pd.DataFrame]:
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+ and output dataframe with 1 line.
+ If the method is fit_predict, run 2 lines of data.
+ """
  # in case the inferred output column names dimension is different
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+ # so change the minimum of number of rows to 2
+ num_examples = 2
+ statement_params = telemetry.get_function_usage_statement_params(
+ project=_PROJECT,
+ subproject=_SUBPROJECT,
+ function_name=telemetry.get_statement_params_full_func_name(
+ inspect.currentframe(), TweedieRegressor.__class__.__name__
+ ),
+ api_calls=[Session.call],
+ custom_tags={"autogen": True} if self._autogenerated else None,
+ )
+ if output_cols_prefix == "fit_predict_":
+ if hasattr(self._sklearn_object, "n_clusters"):
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+ num_examples = self._sklearn_object.n_clusters
+ elif hasattr(self._sklearn_object, "min_samples"):
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
+ num_examples = self._sklearn_object.min_samples
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+ # LocalOutlierFactor expects n_neighbors <= n_samples
+ num_examples = self._sklearn_object.n_neighbors
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+ else:
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
  # seen during the fit.
@@ -656,12 +693,14 @@ class TweedieRegressor(BaseTransformer):
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
  if self.sample_weight_col:
  output_df_columns_set -= set(self.sample_weight_col)
+
  # if the dimension of inferred output column names is correct; use it
  if len(expected_output_cols_list) == len(output_df_columns_set):
- return expected_output_cols_list
+ return expected_output_cols_list, output_df_pd
  # otherwise, use the sklearn estimator's output
  else:
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]

  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
  @telemetry.send_api_usage_telemetry(
@@ -707,7 +746,7 @@ class TweedieRegressor(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -773,7 +812,7 @@ class TweedieRegressor(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )
  elif isinstance(dataset, pd.DataFrame):
@@ -836,7 +875,7 @@ class TweedieRegressor(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -901,7 +940,7 @@ class TweedieRegressor(BaseTransformer):
  drop_input_cols = self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )