snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212) hide show
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -499,12 +496,23 @@ class MissingIndicator(BaseTransformer):
499
496
  autogenerated=self._autogenerated,
500
497
  subproject=_SUBPROJECT,
501
498
  )
502
- output_result, fitted_estimator = model_trainer.train_fit_predict(
503
- drop_input_cols=self._drop_input_cols,
504
- expected_output_cols_list=(
505
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
506
- ),
499
+ expected_output_cols = (
500
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
507
501
  )
502
+ if isinstance(dataset, DataFrame):
503
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
504
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
505
+ )
506
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
507
+ drop_input_cols=self._drop_input_cols,
508
+ expected_output_cols_list=expected_output_cols,
509
+ example_output_pd_df=example_output_pd_df,
510
+ )
511
+ else:
512
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
513
+ drop_input_cols=self._drop_input_cols,
514
+ expected_output_cols_list=expected_output_cols,
515
+ )
508
516
  self._sklearn_object = fitted_estimator
509
517
  self._is_fitted = True
510
518
  return output_result
@@ -585,12 +593,41 @@ class MissingIndicator(BaseTransformer):
585
593
 
586
594
  return rv
587
595
 
588
- def _align_expected_output_names(
589
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
590
- ) -> List[str]:
596
+ def _align_expected_output(
597
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
598
+ ) -> Tuple[List[str], pd.DataFrame]:
599
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
600
+ and output dataframe with 1 line.
601
+ If the method is fit_predict, run 2 lines of data.
602
+ """
591
603
  # in case the inferred output column names dimension is different
592
604
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
593
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
605
+
606
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
607
+ # so change the minimum of number of rows to 2
608
+ num_examples = 2
609
+ statement_params = telemetry.get_function_usage_statement_params(
610
+ project=_PROJECT,
611
+ subproject=_SUBPROJECT,
612
+ function_name=telemetry.get_statement_params_full_func_name(
613
+ inspect.currentframe(), MissingIndicator.__class__.__name__
614
+ ),
615
+ api_calls=[Session.call],
616
+ custom_tags={"autogen": True} if self._autogenerated else None,
617
+ )
618
+ if output_cols_prefix == "fit_predict_":
619
+ if hasattr(self._sklearn_object, "n_clusters"):
620
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
621
+ num_examples = self._sklearn_object.n_clusters
622
+ elif hasattr(self._sklearn_object, "min_samples"):
623
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
624
+ num_examples = self._sklearn_object.min_samples
625
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
626
+ # LocalOutlierFactor expects n_neighbors <= n_samples
627
+ num_examples = self._sklearn_object.n_neighbors
628
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
629
+ else:
630
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
594
631
 
595
632
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
596
633
  # seen during the fit.
@@ -602,12 +639,14 @@ class MissingIndicator(BaseTransformer):
602
639
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
603
640
  if self.sample_weight_col:
604
641
  output_df_columns_set -= set(self.sample_weight_col)
642
+
605
643
  # if the dimension of inferred output column names is correct; use it
606
644
  if len(expected_output_cols_list) == len(output_df_columns_set):
607
- return expected_output_cols_list
645
+ return expected_output_cols_list, output_df_pd
608
646
  # otherwise, use the sklearn estimator's output
609
647
  else:
610
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
648
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
649
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
611
650
 
612
651
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
613
652
  @telemetry.send_api_usage_telemetry(
@@ -653,7 +692,7 @@ class MissingIndicator(BaseTransformer):
653
692
  drop_input_cols=self._drop_input_cols,
654
693
  expected_output_cols_type="float",
655
694
  )
656
- expected_output_cols = self._align_expected_output_names(
695
+ expected_output_cols, _ = self._align_expected_output(
657
696
  inference_method, dataset, expected_output_cols, output_cols_prefix
658
697
  )
659
698
 
@@ -719,7 +758,7 @@ class MissingIndicator(BaseTransformer):
719
758
  drop_input_cols=self._drop_input_cols,
720
759
  expected_output_cols_type="float",
721
760
  )
722
- expected_output_cols = self._align_expected_output_names(
761
+ expected_output_cols, _ = self._align_expected_output(
723
762
  inference_method, dataset, expected_output_cols, output_cols_prefix
724
763
  )
725
764
  elif isinstance(dataset, pd.DataFrame):
@@ -782,7 +821,7 @@ class MissingIndicator(BaseTransformer):
782
821
  drop_input_cols=self._drop_input_cols,
783
822
  expected_output_cols_type="float",
784
823
  )
785
- expected_output_cols = self._align_expected_output_names(
824
+ expected_output_cols, _ = self._align_expected_output(
786
825
  inference_method, dataset, expected_output_cols, output_cols_prefix
787
826
  )
788
827
 
@@ -847,7 +886,7 @@ class MissingIndicator(BaseTransformer):
847
886
  drop_input_cols = self._drop_input_cols,
848
887
  expected_output_cols_type="float",
849
888
  )
850
- expected_output_cols = self._align_expected_output_names(
889
+ expected_output_cols, _ = self._align_expected_output(
851
890
  inference_method, dataset, expected_output_cols, output_cols_prefix
852
891
  )
853
892
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -474,12 +471,23 @@ class AdditiveChi2Sampler(BaseTransformer):
474
471
  autogenerated=self._autogenerated,
475
472
  subproject=_SUBPROJECT,
476
473
  )
477
- output_result, fitted_estimator = model_trainer.train_fit_predict(
478
- drop_input_cols=self._drop_input_cols,
479
- expected_output_cols_list=(
480
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
481
- ),
474
+ expected_output_cols = (
475
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
482
476
  )
477
+ if isinstance(dataset, DataFrame):
478
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
479
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
480
+ )
481
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
482
+ drop_input_cols=self._drop_input_cols,
483
+ expected_output_cols_list=expected_output_cols,
484
+ example_output_pd_df=example_output_pd_df,
485
+ )
486
+ else:
487
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
488
+ drop_input_cols=self._drop_input_cols,
489
+ expected_output_cols_list=expected_output_cols,
490
+ )
483
491
  self._sklearn_object = fitted_estimator
484
492
  self._is_fitted = True
485
493
  return output_result
@@ -560,12 +568,41 @@ class AdditiveChi2Sampler(BaseTransformer):
560
568
 
561
569
  return rv
562
570
 
563
- def _align_expected_output_names(
564
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
565
- ) -> List[str]:
571
+ def _align_expected_output(
572
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
573
+ ) -> Tuple[List[str], pd.DataFrame]:
574
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
575
+ and output dataframe with 1 line.
576
+ If the method is fit_predict, run 2 lines of data.
577
+ """
566
578
  # in case the inferred output column names dimension is different
567
579
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
568
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
580
+
581
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
582
+ # so change the minimum of number of rows to 2
583
+ num_examples = 2
584
+ statement_params = telemetry.get_function_usage_statement_params(
585
+ project=_PROJECT,
586
+ subproject=_SUBPROJECT,
587
+ function_name=telemetry.get_statement_params_full_func_name(
588
+ inspect.currentframe(), AdditiveChi2Sampler.__class__.__name__
589
+ ),
590
+ api_calls=[Session.call],
591
+ custom_tags={"autogen": True} if self._autogenerated else None,
592
+ )
593
+ if output_cols_prefix == "fit_predict_":
594
+ if hasattr(self._sklearn_object, "n_clusters"):
595
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
596
+ num_examples = self._sklearn_object.n_clusters
597
+ elif hasattr(self._sklearn_object, "min_samples"):
598
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
599
+ num_examples = self._sklearn_object.min_samples
600
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
601
+ # LocalOutlierFactor expects n_neighbors <= n_samples
602
+ num_examples = self._sklearn_object.n_neighbors
603
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
604
+ else:
605
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
569
606
 
570
607
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
571
608
  # seen during the fit.
@@ -577,12 +614,14 @@ class AdditiveChi2Sampler(BaseTransformer):
577
614
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
578
615
  if self.sample_weight_col:
579
616
  output_df_columns_set -= set(self.sample_weight_col)
617
+
580
618
  # if the dimension of inferred output column names is correct; use it
581
619
  if len(expected_output_cols_list) == len(output_df_columns_set):
582
- return expected_output_cols_list
620
+ return expected_output_cols_list, output_df_pd
583
621
  # otherwise, use the sklearn estimator's output
584
622
  else:
585
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
623
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
624
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
586
625
 
587
626
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
588
627
  @telemetry.send_api_usage_telemetry(
@@ -628,7 +667,7 @@ class AdditiveChi2Sampler(BaseTransformer):
628
667
  drop_input_cols=self._drop_input_cols,
629
668
  expected_output_cols_type="float",
630
669
  )
631
- expected_output_cols = self._align_expected_output_names(
670
+ expected_output_cols, _ = self._align_expected_output(
632
671
  inference_method, dataset, expected_output_cols, output_cols_prefix
633
672
  )
634
673
 
@@ -694,7 +733,7 @@ class AdditiveChi2Sampler(BaseTransformer):
694
733
  drop_input_cols=self._drop_input_cols,
695
734
  expected_output_cols_type="float",
696
735
  )
697
- expected_output_cols = self._align_expected_output_names(
736
+ expected_output_cols, _ = self._align_expected_output(
698
737
  inference_method, dataset, expected_output_cols, output_cols_prefix
699
738
  )
700
739
  elif isinstance(dataset, pd.DataFrame):
@@ -757,7 +796,7 @@ class AdditiveChi2Sampler(BaseTransformer):
757
796
  drop_input_cols=self._drop_input_cols,
758
797
  expected_output_cols_type="float",
759
798
  )
760
- expected_output_cols = self._align_expected_output_names(
799
+ expected_output_cols, _ = self._align_expected_output(
761
800
  inference_method, dataset, expected_output_cols, output_cols_prefix
762
801
  )
763
802
 
@@ -822,7 +861,7 @@ class AdditiveChi2Sampler(BaseTransformer):
822
861
  drop_input_cols = self._drop_input_cols,
823
862
  expected_output_cols_type="float",
824
863
  )
825
- expected_output_cols = self._align_expected_output_names(
864
+ expected_output_cols, _ = self._align_expected_output(
826
865
  inference_method, dataset, expected_output_cols, output_cols_prefix
827
866
  )
828
867
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -522,12 +519,23 @@ class Nystroem(BaseTransformer):
522
519
  autogenerated=self._autogenerated,
523
520
  subproject=_SUBPROJECT,
524
521
  )
525
- output_result, fitted_estimator = model_trainer.train_fit_predict(
526
- drop_input_cols=self._drop_input_cols,
527
- expected_output_cols_list=(
528
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
529
- ),
522
+ expected_output_cols = (
523
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
530
524
  )
525
+ if isinstance(dataset, DataFrame):
526
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
527
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
528
+ )
529
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
530
+ drop_input_cols=self._drop_input_cols,
531
+ expected_output_cols_list=expected_output_cols,
532
+ example_output_pd_df=example_output_pd_df,
533
+ )
534
+ else:
535
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
536
+ drop_input_cols=self._drop_input_cols,
537
+ expected_output_cols_list=expected_output_cols,
538
+ )
531
539
  self._sklearn_object = fitted_estimator
532
540
  self._is_fitted = True
533
541
  return output_result
@@ -608,12 +616,41 @@ class Nystroem(BaseTransformer):
608
616
 
609
617
  return rv
610
618
 
611
- def _align_expected_output_names(
612
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
613
- ) -> List[str]:
619
+ def _align_expected_output(
620
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
621
+ ) -> Tuple[List[str], pd.DataFrame]:
622
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
623
+ and output dataframe with 1 line.
624
+ If the method is fit_predict, run 2 lines of data by default, or as many lines as the estimator's n_clusters, min_samples, or n_neighbors.
625
+ """
614
626
  # in case the inferred output column names dimension is different
615
627
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
616
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
628
+
629
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
630
+ # so change the minimum number of rows to 2
631
+ num_examples = 2
632
+ statement_params = telemetry.get_function_usage_statement_params(
633
+ project=_PROJECT,
634
+ subproject=_SUBPROJECT,
635
+ function_name=telemetry.get_statement_params_full_func_name(
636
+ inspect.currentframe(), Nystroem.__class__.__name__
637
+ ),
638
+ api_calls=[Session.call],
639
+ custom_tags={"autogen": True} if self._autogenerated else None,
640
+ )
641
+ if output_cols_prefix == "fit_predict_":
642
+ if hasattr(self._sklearn_object, "n_clusters"):
643
+ # cluster classes such as BisectingKMeans require the number of examples to be >= n_clusters
644
+ num_examples = self._sklearn_object.n_clusters
645
+ elif hasattr(self._sklearn_object, "min_samples"):
646
+ # OPTICS defaults to min_samples=5, which requires at least 5 lines of data
647
+ num_examples = self._sklearn_object.min_samples
648
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
649
+ # LocalOutlierFactor expects n_neighbors <= n_samples
650
+ num_examples = self._sklearn_object.n_neighbors
651
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
652
+ else:
653
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
617
654
 
618
655
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
619
656
  # seen during the fit.
@@ -625,12 +662,14 @@ class Nystroem(BaseTransformer):
625
662
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
626
663
  if self.sample_weight_col:
627
664
  output_df_columns_set -= set(self.sample_weight_col)
665
+
628
666
  # if the dimension of inferred output column names is correct; use it
629
667
  if len(expected_output_cols_list) == len(output_df_columns_set):
630
- return expected_output_cols_list
668
+ return expected_output_cols_list, output_df_pd
631
669
  # otherwise, use the sklearn estimator's output
632
670
  else:
633
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
671
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
672
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
634
673
 
635
674
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
636
675
  @telemetry.send_api_usage_telemetry(
@@ -676,7 +715,7 @@ class Nystroem(BaseTransformer):
676
715
  drop_input_cols=self._drop_input_cols,
677
716
  expected_output_cols_type="float",
678
717
  )
679
- expected_output_cols = self._align_expected_output_names(
718
+ expected_output_cols, _ = self._align_expected_output(
680
719
  inference_method, dataset, expected_output_cols, output_cols_prefix
681
720
  )
682
721
 
@@ -742,7 +781,7 @@ class Nystroem(BaseTransformer):
742
781
  drop_input_cols=self._drop_input_cols,
743
782
  expected_output_cols_type="float",
744
783
  )
745
- expected_output_cols = self._align_expected_output_names(
784
+ expected_output_cols, _ = self._align_expected_output(
746
785
  inference_method, dataset, expected_output_cols, output_cols_prefix
747
786
  )
748
787
  elif isinstance(dataset, pd.DataFrame):
@@ -805,7 +844,7 @@ class Nystroem(BaseTransformer):
805
844
  drop_input_cols=self._drop_input_cols,
806
845
  expected_output_cols_type="float",
807
846
  )
808
- expected_output_cols = self._align_expected_output_names(
847
+ expected_output_cols, _ = self._align_expected_output(
809
848
  inference_method, dataset, expected_output_cols, output_cols_prefix
810
849
  )
811
850
 
@@ -870,7 +909,7 @@ class Nystroem(BaseTransformer):
870
909
  drop_input_cols = self._drop_input_cols,
871
910
  expected_output_cols_type="float",
872
911
  )
873
- expected_output_cols = self._align_expected_output_names(
912
+ expected_output_cols, _ = self._align_expected_output(
874
913
  inference_method, dataset, expected_output_cols, output_cols_prefix
875
914
  )
876
915
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -498,12 +495,23 @@ class PolynomialCountSketch(BaseTransformer):
498
495
  autogenerated=self._autogenerated,
499
496
  subproject=_SUBPROJECT,
500
497
  )
501
- output_result, fitted_estimator = model_trainer.train_fit_predict(
502
- drop_input_cols=self._drop_input_cols,
503
- expected_output_cols_list=(
504
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
505
- ),
498
+ expected_output_cols = (
499
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
506
500
  )
501
+ if isinstance(dataset, DataFrame):
502
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
503
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
504
+ )
505
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
506
+ drop_input_cols=self._drop_input_cols,
507
+ expected_output_cols_list=expected_output_cols,
508
+ example_output_pd_df=example_output_pd_df,
509
+ )
510
+ else:
511
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
512
+ drop_input_cols=self._drop_input_cols,
513
+ expected_output_cols_list=expected_output_cols,
514
+ )
507
515
  self._sklearn_object = fitted_estimator
508
516
  self._is_fitted = True
509
517
  return output_result
@@ -584,12 +592,41 @@ class PolynomialCountSketch(BaseTransformer):
584
592
 
585
593
  return rv
586
594
 
587
- def _align_expected_output_names(
588
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
589
- ) -> List[str]:
595
+ def _align_expected_output(
596
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
597
+ ) -> Tuple[List[str], pd.DataFrame]:
598
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
599
+ and output dataframe with 1 line.
600
+ If the method is fit_predict, run 2 lines of data by default, or as many lines as the estimator's n_clusters, min_samples, or n_neighbors.
601
+ """
590
602
  # in case the inferred output column names dimension is different
591
603
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
592
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
604
+
605
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
606
+ # so change the minimum number of rows to 2
607
+ num_examples = 2
608
+ statement_params = telemetry.get_function_usage_statement_params(
609
+ project=_PROJECT,
610
+ subproject=_SUBPROJECT,
611
+ function_name=telemetry.get_statement_params_full_func_name(
612
+ inspect.currentframe(), PolynomialCountSketch.__class__.__name__
613
+ ),
614
+ api_calls=[Session.call],
615
+ custom_tags={"autogen": True} if self._autogenerated else None,
616
+ )
617
+ if output_cols_prefix == "fit_predict_":
618
+ if hasattr(self._sklearn_object, "n_clusters"):
619
+ # cluster classes such as BisectingKMeans require the number of examples to be >= n_clusters
620
+ num_examples = self._sklearn_object.n_clusters
621
+ elif hasattr(self._sklearn_object, "min_samples"):
622
+ # OPTICS defaults to min_samples=5, which requires at least 5 lines of data
623
+ num_examples = self._sklearn_object.min_samples
624
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
625
+ # LocalOutlierFactor expects n_neighbors <= n_samples
626
+ num_examples = self._sklearn_object.n_neighbors
627
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
628
+ else:
629
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
593
630
 
594
631
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
595
632
  # seen during the fit.
@@ -601,12 +638,14 @@ class PolynomialCountSketch(BaseTransformer):
601
638
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
602
639
  if self.sample_weight_col:
603
640
  output_df_columns_set -= set(self.sample_weight_col)
641
+
604
642
  # if the dimension of inferred output column names is correct; use it
605
643
  if len(expected_output_cols_list) == len(output_df_columns_set):
606
- return expected_output_cols_list
644
+ return expected_output_cols_list, output_df_pd
607
645
  # otherwise, use the sklearn estimator's output
608
646
  else:
609
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
647
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
648
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
610
649
 
611
650
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
612
651
  @telemetry.send_api_usage_telemetry(
@@ -652,7 +691,7 @@ class PolynomialCountSketch(BaseTransformer):
652
691
  drop_input_cols=self._drop_input_cols,
653
692
  expected_output_cols_type="float",
654
693
  )
655
- expected_output_cols = self._align_expected_output_names(
694
+ expected_output_cols, _ = self._align_expected_output(
656
695
  inference_method, dataset, expected_output_cols, output_cols_prefix
657
696
  )
658
697
 
@@ -718,7 +757,7 @@ class PolynomialCountSketch(BaseTransformer):
718
757
  drop_input_cols=self._drop_input_cols,
719
758
  expected_output_cols_type="float",
720
759
  )
721
- expected_output_cols = self._align_expected_output_names(
760
+ expected_output_cols, _ = self._align_expected_output(
722
761
  inference_method, dataset, expected_output_cols, output_cols_prefix
723
762
  )
724
763
  elif isinstance(dataset, pd.DataFrame):
@@ -781,7 +820,7 @@ class PolynomialCountSketch(BaseTransformer):
781
820
  drop_input_cols=self._drop_input_cols,
782
821
  expected_output_cols_type="float",
783
822
  )
784
- expected_output_cols = self._align_expected_output_names(
823
+ expected_output_cols, _ = self._align_expected_output(
785
824
  inference_method, dataset, expected_output_cols, output_cols_prefix
786
825
  )
787
826
 
@@ -846,7 +885,7 @@ class PolynomialCountSketch(BaseTransformer):
846
885
  drop_input_cols = self._drop_input_cols,
847
886
  expected_output_cols_type="float",
848
887
  )
849
- expected_output_cols = self._align_expected_output_names(
888
+ expected_output_cols, _ = self._align_expected_output(
850
889
  inference_method, dataset, expected_output_cols, output_cols_prefix
851
890
  )
852
891