snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212) hide show
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -553,12 +550,23 @@ class Isomap(BaseTransformer):
553
550
  autogenerated=self._autogenerated,
554
551
  subproject=_SUBPROJECT,
555
552
  )
556
- output_result, fitted_estimator = model_trainer.train_fit_predict(
557
- drop_input_cols=self._drop_input_cols,
558
- expected_output_cols_list=(
559
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
560
- ),
553
+ expected_output_cols = (
554
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
561
555
  )
556
+ if isinstance(dataset, DataFrame):
557
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
558
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
559
+ )
560
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
561
+ drop_input_cols=self._drop_input_cols,
562
+ expected_output_cols_list=expected_output_cols,
563
+ example_output_pd_df=example_output_pd_df,
564
+ )
565
+ else:
566
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
567
+ drop_input_cols=self._drop_input_cols,
568
+ expected_output_cols_list=expected_output_cols,
569
+ )
562
570
  self._sklearn_object = fitted_estimator
563
571
  self._is_fitted = True
564
572
  return output_result
@@ -639,12 +647,41 @@ class Isomap(BaseTransformer):
639
647
 
640
648
  return rv
641
649
 
642
- def _align_expected_output_names(
643
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
644
- ) -> List[str]:
650
+ def _align_expected_output(
651
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
652
+ ) -> Tuple[List[str], pd.DataFrame]:
653
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
654
+ and output dataframe with 1 line.
655
+ If the method is fit_predict, run 2 lines of data.
656
+ """
645
657
  # in case the inferred output column names dimension is different
646
658
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
647
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
659
+
660
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
661
+ # so change the minimum of number of rows to 2
662
+ num_examples = 2
663
+ statement_params = telemetry.get_function_usage_statement_params(
664
+ project=_PROJECT,
665
+ subproject=_SUBPROJECT,
666
+ function_name=telemetry.get_statement_params_full_func_name(
667
+ inspect.currentframe(), Isomap.__class__.__name__
668
+ ),
669
+ api_calls=[Session.call],
670
+ custom_tags={"autogen": True} if self._autogenerated else None,
671
+ )
672
+ if output_cols_prefix == "fit_predict_":
673
+ if hasattr(self._sklearn_object, "n_clusters"):
674
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
675
+ num_examples = self._sklearn_object.n_clusters
676
+ elif hasattr(self._sklearn_object, "min_samples"):
677
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
678
+ num_examples = self._sklearn_object.min_samples
679
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
680
+ # LocalOutlierFactor expects n_neighbors <= n_samples
681
+ num_examples = self._sklearn_object.n_neighbors
682
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
683
+ else:
684
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
648
685
 
649
686
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
650
687
  # seen during the fit.
@@ -656,12 +693,14 @@ class Isomap(BaseTransformer):
656
693
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
657
694
  if self.sample_weight_col:
658
695
  output_df_columns_set -= set(self.sample_weight_col)
696
+
659
697
  # if the dimension of inferred output column names is correct; use it
660
698
  if len(expected_output_cols_list) == len(output_df_columns_set):
661
- return expected_output_cols_list
699
+ return expected_output_cols_list, output_df_pd
662
700
  # otherwise, use the sklearn estimator's output
663
701
  else:
664
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
702
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
703
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
665
704
 
666
705
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
667
706
  @telemetry.send_api_usage_telemetry(
@@ -707,7 +746,7 @@ class Isomap(BaseTransformer):
707
746
  drop_input_cols=self._drop_input_cols,
708
747
  expected_output_cols_type="float",
709
748
  )
710
- expected_output_cols = self._align_expected_output_names(
749
+ expected_output_cols, _ = self._align_expected_output(
711
750
  inference_method, dataset, expected_output_cols, output_cols_prefix
712
751
  )
713
752
 
@@ -773,7 +812,7 @@ class Isomap(BaseTransformer):
773
812
  drop_input_cols=self._drop_input_cols,
774
813
  expected_output_cols_type="float",
775
814
  )
776
- expected_output_cols = self._align_expected_output_names(
815
+ expected_output_cols, _ = self._align_expected_output(
777
816
  inference_method, dataset, expected_output_cols, output_cols_prefix
778
817
  )
779
818
  elif isinstance(dataset, pd.DataFrame):
@@ -836,7 +875,7 @@ class Isomap(BaseTransformer):
836
875
  drop_input_cols=self._drop_input_cols,
837
876
  expected_output_cols_type="float",
838
877
  )
839
- expected_output_cols = self._align_expected_output_names(
878
+ expected_output_cols, _ = self._align_expected_output(
840
879
  inference_method, dataset, expected_output_cols, output_cols_prefix
841
880
  )
842
881
 
@@ -901,7 +940,7 @@ class Isomap(BaseTransformer):
901
940
  drop_input_cols = self._drop_input_cols,
902
941
  expected_output_cols_type="float",
903
942
  )
904
- expected_output_cols = self._align_expected_output_names(
943
+ expected_output_cols, _ = self._align_expected_output(
905
944
  inference_method, dataset, expected_output_cols, output_cols_prefix
906
945
  )
907
946
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -534,12 +531,23 @@ class MDS(BaseTransformer):
534
531
  autogenerated=self._autogenerated,
535
532
  subproject=_SUBPROJECT,
536
533
  )
537
- output_result, fitted_estimator = model_trainer.train_fit_predict(
538
- drop_input_cols=self._drop_input_cols,
539
- expected_output_cols_list=(
540
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
541
- ),
534
+ expected_output_cols = (
535
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
542
536
  )
537
+ if isinstance(dataset, DataFrame):
538
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
539
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
540
+ )
541
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
542
+ drop_input_cols=self._drop_input_cols,
543
+ expected_output_cols_list=expected_output_cols,
544
+ example_output_pd_df=example_output_pd_df,
545
+ )
546
+ else:
547
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
548
+ drop_input_cols=self._drop_input_cols,
549
+ expected_output_cols_list=expected_output_cols,
550
+ )
543
551
  self._sklearn_object = fitted_estimator
544
552
  self._is_fitted = True
545
553
  return output_result
@@ -620,12 +628,41 @@ class MDS(BaseTransformer):
620
628
 
621
629
  return rv
622
630
 
623
- def _align_expected_output_names(
624
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
625
- ) -> List[str]:
631
+ def _align_expected_output(
632
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
633
+ ) -> Tuple[List[str], pd.DataFrame]:
634
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
635
+ and output dataframe with 1 line.
636
+ If the method is fit_predict, run 2 lines of data.
637
+ """
626
638
  # in case the inferred output column names dimension is different
627
639
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
628
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
640
+
641
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
642
+ # so change the minimum of number of rows to 2
643
+ num_examples = 2
644
+ statement_params = telemetry.get_function_usage_statement_params(
645
+ project=_PROJECT,
646
+ subproject=_SUBPROJECT,
647
+ function_name=telemetry.get_statement_params_full_func_name(
648
+ inspect.currentframe(), MDS.__class__.__name__
649
+ ),
650
+ api_calls=[Session.call],
651
+ custom_tags={"autogen": True} if self._autogenerated else None,
652
+ )
653
+ if output_cols_prefix == "fit_predict_":
654
+ if hasattr(self._sklearn_object, "n_clusters"):
655
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
656
+ num_examples = self._sklearn_object.n_clusters
657
+ elif hasattr(self._sklearn_object, "min_samples"):
658
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
659
+ num_examples = self._sklearn_object.min_samples
660
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
661
+ # LocalOutlierFactor expects n_neighbors <= n_samples
662
+ num_examples = self._sklearn_object.n_neighbors
663
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
664
+ else:
665
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
629
666
 
630
667
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
631
668
  # seen during the fit.
@@ -637,12 +674,14 @@ class MDS(BaseTransformer):
637
674
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
638
675
  if self.sample_weight_col:
639
676
  output_df_columns_set -= set(self.sample_weight_col)
677
+
640
678
  # if the dimension of inferred output column names is correct; use it
641
679
  if len(expected_output_cols_list) == len(output_df_columns_set):
642
- return expected_output_cols_list
680
+ return expected_output_cols_list, output_df_pd
643
681
  # otherwise, use the sklearn estimator's output
644
682
  else:
645
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
683
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
684
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
646
685
 
647
686
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
648
687
  @telemetry.send_api_usage_telemetry(
@@ -688,7 +727,7 @@ class MDS(BaseTransformer):
688
727
  drop_input_cols=self._drop_input_cols,
689
728
  expected_output_cols_type="float",
690
729
  )
691
- expected_output_cols = self._align_expected_output_names(
730
+ expected_output_cols, _ = self._align_expected_output(
692
731
  inference_method, dataset, expected_output_cols, output_cols_prefix
693
732
  )
694
733
 
@@ -754,7 +793,7 @@ class MDS(BaseTransformer):
754
793
  drop_input_cols=self._drop_input_cols,
755
794
  expected_output_cols_type="float",
756
795
  )
757
- expected_output_cols = self._align_expected_output_names(
796
+ expected_output_cols, _ = self._align_expected_output(
758
797
  inference_method, dataset, expected_output_cols, output_cols_prefix
759
798
  )
760
799
  elif isinstance(dataset, pd.DataFrame):
@@ -817,7 +856,7 @@ class MDS(BaseTransformer):
817
856
  drop_input_cols=self._drop_input_cols,
818
857
  expected_output_cols_type="float",
819
858
  )
820
- expected_output_cols = self._align_expected_output_names(
859
+ expected_output_cols, _ = self._align_expected_output(
821
860
  inference_method, dataset, expected_output_cols, output_cols_prefix
822
861
  )
823
862
 
@@ -882,7 +921,7 @@ class MDS(BaseTransformer):
882
921
  drop_input_cols = self._drop_input_cols,
883
922
  expected_output_cols_type="float",
884
923
  )
885
- expected_output_cols = self._align_expected_output_names(
924
+ expected_output_cols, _ = self._align_expected_output(
886
925
  inference_method, dataset, expected_output_cols, output_cols_prefix
887
926
  )
888
927
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -536,12 +533,23 @@ class SpectralEmbedding(BaseTransformer):
536
533
  autogenerated=self._autogenerated,
537
534
  subproject=_SUBPROJECT,
538
535
  )
539
- output_result, fitted_estimator = model_trainer.train_fit_predict(
540
- drop_input_cols=self._drop_input_cols,
541
- expected_output_cols_list=(
542
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
543
- ),
536
+ expected_output_cols = (
537
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
544
538
  )
539
+ if isinstance(dataset, DataFrame):
540
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
541
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
542
+ )
543
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
544
+ drop_input_cols=self._drop_input_cols,
545
+ expected_output_cols_list=expected_output_cols,
546
+ example_output_pd_df=example_output_pd_df,
547
+ )
548
+ else:
549
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
550
+ drop_input_cols=self._drop_input_cols,
551
+ expected_output_cols_list=expected_output_cols,
552
+ )
545
553
  self._sklearn_object = fitted_estimator
546
554
  self._is_fitted = True
547
555
  return output_result
@@ -622,12 +630,41 @@ class SpectralEmbedding(BaseTransformer):
622
630
 
623
631
  return rv
624
632
 
625
- def _align_expected_output_names(
626
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
627
- ) -> List[str]:
633
+ def _align_expected_output(
634
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
635
+ ) -> Tuple[List[str], pd.DataFrame]:
636
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
637
+ and output dataframe with 1 line.
638
+ If the method is fit_predict, run 2 lines of data.
639
+ """
628
640
  # in case the inferred output column names dimension is different
629
641
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
630
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
642
+
643
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
644
+ # so change the minimum of number of rows to 2
645
+ num_examples = 2
646
+ statement_params = telemetry.get_function_usage_statement_params(
647
+ project=_PROJECT,
648
+ subproject=_SUBPROJECT,
649
+ function_name=telemetry.get_statement_params_full_func_name(
650
+ inspect.currentframe(), SpectralEmbedding.__class__.__name__
651
+ ),
652
+ api_calls=[Session.call],
653
+ custom_tags={"autogen": True} if self._autogenerated else None,
654
+ )
655
+ if output_cols_prefix == "fit_predict_":
656
+ if hasattr(self._sklearn_object, "n_clusters"):
657
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
658
+ num_examples = self._sklearn_object.n_clusters
659
+ elif hasattr(self._sklearn_object, "min_samples"):
660
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
661
+ num_examples = self._sklearn_object.min_samples
662
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
663
+ # LocalOutlierFactor expects n_neighbors <= n_samples
664
+ num_examples = self._sklearn_object.n_neighbors
665
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
666
+ else:
667
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
631
668
 
632
669
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
633
670
  # seen during the fit.
@@ -639,12 +676,14 @@ class SpectralEmbedding(BaseTransformer):
639
676
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
640
677
  if self.sample_weight_col:
641
678
  output_df_columns_set -= set(self.sample_weight_col)
679
+
642
680
  # if the dimension of inferred output column names is correct; use it
643
681
  if len(expected_output_cols_list) == len(output_df_columns_set):
644
- return expected_output_cols_list
682
+ return expected_output_cols_list, output_df_pd
645
683
  # otherwise, use the sklearn estimator's output
646
684
  else:
647
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
685
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
686
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
648
687
 
649
688
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
650
689
  @telemetry.send_api_usage_telemetry(
@@ -690,7 +729,7 @@ class SpectralEmbedding(BaseTransformer):
690
729
  drop_input_cols=self._drop_input_cols,
691
730
  expected_output_cols_type="float",
692
731
  )
693
- expected_output_cols = self._align_expected_output_names(
732
+ expected_output_cols, _ = self._align_expected_output(
694
733
  inference_method, dataset, expected_output_cols, output_cols_prefix
695
734
  )
696
735
 
@@ -756,7 +795,7 @@ class SpectralEmbedding(BaseTransformer):
756
795
  drop_input_cols=self._drop_input_cols,
757
796
  expected_output_cols_type="float",
758
797
  )
759
- expected_output_cols = self._align_expected_output_names(
798
+ expected_output_cols, _ = self._align_expected_output(
760
799
  inference_method, dataset, expected_output_cols, output_cols_prefix
761
800
  )
762
801
  elif isinstance(dataset, pd.DataFrame):
@@ -819,7 +858,7 @@ class SpectralEmbedding(BaseTransformer):
819
858
  drop_input_cols=self._drop_input_cols,
820
859
  expected_output_cols_type="float",
821
860
  )
822
- expected_output_cols = self._align_expected_output_names(
861
+ expected_output_cols, _ = self._align_expected_output(
823
862
  inference_method, dataset, expected_output_cols, output_cols_prefix
824
863
  )
825
864
 
@@ -884,7 +923,7 @@ class SpectralEmbedding(BaseTransformer):
884
923
  drop_input_cols = self._drop_input_cols,
885
924
  expected_output_cols_type="float",
886
925
  )
887
- expected_output_cols = self._align_expected_output_names(
926
+ expected_output_cols, _ = self._align_expected_output(
888
927
  inference_method, dataset, expected_output_cols, output_cols_prefix
889
928
  )
890
929
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -595,12 +592,23 @@ class TSNE(BaseTransformer):
595
592
  autogenerated=self._autogenerated,
596
593
  subproject=_SUBPROJECT,
597
594
  )
598
- output_result, fitted_estimator = model_trainer.train_fit_predict(
599
- drop_input_cols=self._drop_input_cols,
600
- expected_output_cols_list=(
601
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
602
- ),
595
+ expected_output_cols = (
596
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
603
597
  )
598
+ if isinstance(dataset, DataFrame):
599
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
600
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
601
+ )
602
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
603
+ drop_input_cols=self._drop_input_cols,
604
+ expected_output_cols_list=expected_output_cols,
605
+ example_output_pd_df=example_output_pd_df,
606
+ )
607
+ else:
608
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
609
+ drop_input_cols=self._drop_input_cols,
610
+ expected_output_cols_list=expected_output_cols,
611
+ )
604
612
  self._sklearn_object = fitted_estimator
605
613
  self._is_fitted = True
606
614
  return output_result
@@ -681,12 +689,41 @@ class TSNE(BaseTransformer):
681
689
 
682
690
  return rv
683
691
 
684
- def _align_expected_output_names(
685
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
686
- ) -> List[str]:
692
+ def _align_expected_output(
693
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
694
+ ) -> Tuple[List[str], pd.DataFrame]:
695
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
696
+ and output dataframe with 1 line.
697
+ If the method is fit_predict, run 2 lines of data.
698
+ """
687
699
  # in case the inferred output column names dimension is different
688
700
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
689
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
701
+
702
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
703
+ # so change the minimum of number of rows to 2
704
+ num_examples = 2
705
+ statement_params = telemetry.get_function_usage_statement_params(
706
+ project=_PROJECT,
707
+ subproject=_SUBPROJECT,
708
+ function_name=telemetry.get_statement_params_full_func_name(
709
+ inspect.currentframe(), TSNE.__class__.__name__
710
+ ),
711
+ api_calls=[Session.call],
712
+ custom_tags={"autogen": True} if self._autogenerated else None,
713
+ )
714
+ if output_cols_prefix == "fit_predict_":
715
+ if hasattr(self._sklearn_object, "n_clusters"):
716
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
717
+ num_examples = self._sklearn_object.n_clusters
718
+ elif hasattr(self._sklearn_object, "min_samples"):
719
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
720
+ num_examples = self._sklearn_object.min_samples
721
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
722
+ # LocalOutlierFactor expects n_neighbors <= n_samples
723
+ num_examples = self._sklearn_object.n_neighbors
724
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
725
+ else:
726
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
690
727
 
691
728
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
692
729
  # seen during the fit.
@@ -698,12 +735,14 @@ class TSNE(BaseTransformer):
698
735
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
699
736
  if self.sample_weight_col:
700
737
  output_df_columns_set -= set(self.sample_weight_col)
738
+
701
739
  # if the dimension of inferred output column names is correct; use it
702
740
  if len(expected_output_cols_list) == len(output_df_columns_set):
703
- return expected_output_cols_list
741
+ return expected_output_cols_list, output_df_pd
704
742
  # otherwise, use the sklearn estimator's output
705
743
  else:
706
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
744
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
745
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
707
746
 
708
747
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
709
748
  @telemetry.send_api_usage_telemetry(
@@ -749,7 +788,7 @@ class TSNE(BaseTransformer):
749
788
  drop_input_cols=self._drop_input_cols,
750
789
  expected_output_cols_type="float",
751
790
  )
752
- expected_output_cols = self._align_expected_output_names(
791
+ expected_output_cols, _ = self._align_expected_output(
753
792
  inference_method, dataset, expected_output_cols, output_cols_prefix
754
793
  )
755
794
 
@@ -815,7 +854,7 @@ class TSNE(BaseTransformer):
815
854
  drop_input_cols=self._drop_input_cols,
816
855
  expected_output_cols_type="float",
817
856
  )
818
- expected_output_cols = self._align_expected_output_names(
857
+ expected_output_cols, _ = self._align_expected_output(
819
858
  inference_method, dataset, expected_output_cols, output_cols_prefix
820
859
  )
821
860
  elif isinstance(dataset, pd.DataFrame):
@@ -878,7 +917,7 @@ class TSNE(BaseTransformer):
878
917
  drop_input_cols=self._drop_input_cols,
879
918
  expected_output_cols_type="float",
880
919
  )
881
- expected_output_cols = self._align_expected_output_names(
920
+ expected_output_cols, _ = self._align_expected_output(
882
921
  inference_method, dataset, expected_output_cols, output_cols_prefix
883
922
  )
884
923
 
@@ -943,7 +982,7 @@ class TSNE(BaseTransformer):
943
982
  drop_input_cols = self._drop_input_cols,
944
983
  expected_output_cols_type="float",
945
984
  )
946
- expected_output_cols = self._align_expected_output_names(
985
+ expected_output_cols, _ = self._align_expected_output(
947
986
  inference_method, dataset, expected_output_cols, output_cols_prefix
948
987
  )
949
988