snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -25,12 +23,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
25
23
  from snowflake.ml._internal import telemetry
26
24
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
27
25
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
28
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
26
+ from snowflake.ml._internal.utils import identifier
29
27
  from snowflake.snowpark import DataFrame, Session
30
28
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
31
29
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
32
30
  from snowflake.ml.modeling._internal.transformer_protocols import (
33
- ModelTransformHandlers,
34
31
  BatchInferenceKwargsTypedDict,
35
32
  ScoreKwargsTypedDict
36
33
  )
@@ -482,12 +479,23 @@ class GenericUnivariateSelect(BaseTransformer):
482
479
  autogenerated=self._autogenerated,
483
480
  subproject=_SUBPROJECT,
484
481
  )
485
- output_result, fitted_estimator = model_trainer.train_fit_predict(
486
- drop_input_cols=self._drop_input_cols,
487
- expected_output_cols_list=(
488
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
489
- ),
482
+ expected_output_cols = (
483
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
490
484
  )
485
+ if isinstance(dataset, DataFrame):
486
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
487
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
488
+ )
489
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
490
+ drop_input_cols=self._drop_input_cols,
491
+ expected_output_cols_list=expected_output_cols,
492
+ example_output_pd_df=example_output_pd_df,
493
+ )
494
+ else:
495
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
496
+ drop_input_cols=self._drop_input_cols,
497
+ expected_output_cols_list=expected_output_cols,
498
+ )
491
499
  self._sklearn_object = fitted_estimator
492
500
  self._is_fitted = True
493
501
  return output_result
@@ -568,12 +576,41 @@ class GenericUnivariateSelect(BaseTransformer):
568
576
 
569
577
  return rv
570
578
 
571
- def _align_expected_output_names(
572
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
573
- ) -> List[str]:
579
+ def _align_expected_output(
580
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
581
+ ) -> Tuple[List[str], pd.DataFrame]:
582
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
583
+ and output dataframe with 1 line.
584
+ If the method is fit_predict, run 2 lines of data.
585
+ """
574
586
  # in case the inferred output column names dimension is different
575
587
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
576
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
588
+
589
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
590
+ # so change the minimum of number of rows to 2
591
+ num_examples = 2
592
+ statement_params = telemetry.get_function_usage_statement_params(
593
+ project=_PROJECT,
594
+ subproject=_SUBPROJECT,
595
+ function_name=telemetry.get_statement_params_full_func_name(
596
+ inspect.currentframe(), GenericUnivariateSelect.__class__.__name__
597
+ ),
598
+ api_calls=[Session.call],
599
+ custom_tags={"autogen": True} if self._autogenerated else None,
600
+ )
601
+ if output_cols_prefix == "fit_predict_":
602
+ if hasattr(self._sklearn_object, "n_clusters"):
603
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
604
+ num_examples = self._sklearn_object.n_clusters
605
+ elif hasattr(self._sklearn_object, "min_samples"):
606
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
607
+ num_examples = self._sklearn_object.min_samples
608
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
609
+ # LocalOutlierFactor expects n_neighbors <= n_samples
610
+ num_examples = self._sklearn_object.n_neighbors
611
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
612
+ else:
613
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
577
614
 
578
615
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
579
616
  # seen during the fit.
@@ -585,12 +622,14 @@ class GenericUnivariateSelect(BaseTransformer):
585
622
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
586
623
  if self.sample_weight_col:
587
624
  output_df_columns_set -= set(self.sample_weight_col)
625
+
588
626
  # if the dimension of inferred output column names is correct; use it
589
627
  if len(expected_output_cols_list) == len(output_df_columns_set):
590
- return expected_output_cols_list
628
+ return expected_output_cols_list, output_df_pd
591
629
  # otherwise, use the sklearn estimator's output
592
630
  else:
593
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
631
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
632
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
594
633
 
595
634
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
596
635
  @telemetry.send_api_usage_telemetry(
@@ -636,7 +675,7 @@ class GenericUnivariateSelect(BaseTransformer):
636
675
  drop_input_cols=self._drop_input_cols,
637
676
  expected_output_cols_type="float",
638
677
  )
639
- expected_output_cols = self._align_expected_output_names(
678
+ expected_output_cols, _ = self._align_expected_output(
640
679
  inference_method, dataset, expected_output_cols, output_cols_prefix
641
680
  )
642
681
 
@@ -702,7 +741,7 @@ class GenericUnivariateSelect(BaseTransformer):
702
741
  drop_input_cols=self._drop_input_cols,
703
742
  expected_output_cols_type="float",
704
743
  )
705
- expected_output_cols = self._align_expected_output_names(
744
+ expected_output_cols, _ = self._align_expected_output(
706
745
  inference_method, dataset, expected_output_cols, output_cols_prefix
707
746
  )
708
747
  elif isinstance(dataset, pd.DataFrame):
@@ -765,7 +804,7 @@ class GenericUnivariateSelect(BaseTransformer):
765
804
  drop_input_cols=self._drop_input_cols,
766
805
  expected_output_cols_type="float",
767
806
  )
768
- expected_output_cols = self._align_expected_output_names(
807
+ expected_output_cols, _ = self._align_expected_output(
769
808
  inference_method, dataset, expected_output_cols, output_cols_prefix
770
809
  )
771
810
 
@@ -830,7 +869,7 @@ class GenericUnivariateSelect(BaseTransformer):
830
869
  drop_input_cols = self._drop_input_cols,
831
870
  expected_output_cols_type="float",
832
871
  )
833
- expected_output_cols = self._align_expected_output_names(
872
+ expected_output_cols, _ = self._align_expected_output(
834
873
  inference_method, dataset, expected_output_cols, output_cols_prefix
835
874
  )
836
875
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -25,12 +23,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
25
23
  from snowflake.ml._internal import telemetry
26
24
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
27
25
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
28
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
26
+ from snowflake.ml._internal.utils import identifier
29
27
  from snowflake.snowpark import DataFrame, Session
30
28
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
31
29
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
32
30
  from snowflake.ml.modeling._internal.transformer_protocols import (
33
- ModelTransformHandlers,
34
31
  BatchInferenceKwargsTypedDict,
35
32
  ScoreKwargsTypedDict
36
33
  )
@@ -478,12 +475,23 @@ class SelectFdr(BaseTransformer):
478
475
  autogenerated=self._autogenerated,
479
476
  subproject=_SUBPROJECT,
480
477
  )
481
- output_result, fitted_estimator = model_trainer.train_fit_predict(
482
- drop_input_cols=self._drop_input_cols,
483
- expected_output_cols_list=(
484
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
485
- ),
478
+ expected_output_cols = (
479
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
486
480
  )
481
+ if isinstance(dataset, DataFrame):
482
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
483
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
484
+ )
485
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
486
+ drop_input_cols=self._drop_input_cols,
487
+ expected_output_cols_list=expected_output_cols,
488
+ example_output_pd_df=example_output_pd_df,
489
+ )
490
+ else:
491
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
492
+ drop_input_cols=self._drop_input_cols,
493
+ expected_output_cols_list=expected_output_cols,
494
+ )
487
495
  self._sklearn_object = fitted_estimator
488
496
  self._is_fitted = True
489
497
  return output_result
@@ -564,12 +572,41 @@ class SelectFdr(BaseTransformer):
564
572
 
565
573
  return rv
566
574
 
567
- def _align_expected_output_names(
568
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
569
- ) -> List[str]:
575
+ def _align_expected_output(
576
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
577
+ ) -> Tuple[List[str], pd.DataFrame]:
578
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
579
+ and output dataframe with 1 line.
580
+ If the method is fit_predict, run 2 lines of data.
581
+ """
570
582
  # in case the inferred output column names dimension is different
571
583
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
572
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
584
+
585
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
586
+ # so change the minimum of number of rows to 2
587
+ num_examples = 2
588
+ statement_params = telemetry.get_function_usage_statement_params(
589
+ project=_PROJECT,
590
+ subproject=_SUBPROJECT,
591
+ function_name=telemetry.get_statement_params_full_func_name(
592
+ inspect.currentframe(), SelectFdr.__class__.__name__
593
+ ),
594
+ api_calls=[Session.call],
595
+ custom_tags={"autogen": True} if self._autogenerated else None,
596
+ )
597
+ if output_cols_prefix == "fit_predict_":
598
+ if hasattr(self._sklearn_object, "n_clusters"):
599
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
600
+ num_examples = self._sklearn_object.n_clusters
601
+ elif hasattr(self._sklearn_object, "min_samples"):
602
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
603
+ num_examples = self._sklearn_object.min_samples
604
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
605
+ # LocalOutlierFactor expects n_neighbors <= n_samples
606
+ num_examples = self._sklearn_object.n_neighbors
607
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
608
+ else:
609
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
573
610
 
574
611
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
575
612
  # seen during the fit.
@@ -581,12 +618,14 @@ class SelectFdr(BaseTransformer):
581
618
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
582
619
  if self.sample_weight_col:
583
620
  output_df_columns_set -= set(self.sample_weight_col)
621
+
584
622
  # if the dimension of inferred output column names is correct; use it
585
623
  if len(expected_output_cols_list) == len(output_df_columns_set):
586
- return expected_output_cols_list
624
+ return expected_output_cols_list, output_df_pd
587
625
  # otherwise, use the sklearn estimator's output
588
626
  else:
589
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
627
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
628
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
590
629
 
591
630
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
592
631
  @telemetry.send_api_usage_telemetry(
@@ -632,7 +671,7 @@ class SelectFdr(BaseTransformer):
632
671
  drop_input_cols=self._drop_input_cols,
633
672
  expected_output_cols_type="float",
634
673
  )
635
- expected_output_cols = self._align_expected_output_names(
674
+ expected_output_cols, _ = self._align_expected_output(
636
675
  inference_method, dataset, expected_output_cols, output_cols_prefix
637
676
  )
638
677
 
@@ -698,7 +737,7 @@ class SelectFdr(BaseTransformer):
698
737
  drop_input_cols=self._drop_input_cols,
699
738
  expected_output_cols_type="float",
700
739
  )
701
- expected_output_cols = self._align_expected_output_names(
740
+ expected_output_cols, _ = self._align_expected_output(
702
741
  inference_method, dataset, expected_output_cols, output_cols_prefix
703
742
  )
704
743
  elif isinstance(dataset, pd.DataFrame):
@@ -761,7 +800,7 @@ class SelectFdr(BaseTransformer):
761
800
  drop_input_cols=self._drop_input_cols,
762
801
  expected_output_cols_type="float",
763
802
  )
764
- expected_output_cols = self._align_expected_output_names(
803
+ expected_output_cols, _ = self._align_expected_output(
765
804
  inference_method, dataset, expected_output_cols, output_cols_prefix
766
805
  )
767
806
 
@@ -826,7 +865,7 @@ class SelectFdr(BaseTransformer):
826
865
  drop_input_cols = self._drop_input_cols,
827
866
  expected_output_cols_type="float",
828
867
  )
829
- expected_output_cols = self._align_expected_output_names(
868
+ expected_output_cols, _ = self._align_expected_output(
830
869
  inference_method, dataset, expected_output_cols, output_cols_prefix
831
870
  )
832
871
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -25,12 +23,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
25
23
  from snowflake.ml._internal import telemetry
26
24
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
27
25
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
28
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
26
+ from snowflake.ml._internal.utils import identifier
29
27
  from snowflake.snowpark import DataFrame, Session
30
28
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
31
29
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
32
30
  from snowflake.ml.modeling._internal.transformer_protocols import (
33
- ModelTransformHandlers,
34
31
  BatchInferenceKwargsTypedDict,
35
32
  ScoreKwargsTypedDict
36
33
  )
@@ -478,12 +475,23 @@ class SelectFpr(BaseTransformer):
478
475
  autogenerated=self._autogenerated,
479
476
  subproject=_SUBPROJECT,
480
477
  )
481
- output_result, fitted_estimator = model_trainer.train_fit_predict(
482
- drop_input_cols=self._drop_input_cols,
483
- expected_output_cols_list=(
484
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
485
- ),
478
+ expected_output_cols = (
479
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
486
480
  )
481
+ if isinstance(dataset, DataFrame):
482
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
483
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
484
+ )
485
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
486
+ drop_input_cols=self._drop_input_cols,
487
+ expected_output_cols_list=expected_output_cols,
488
+ example_output_pd_df=example_output_pd_df,
489
+ )
490
+ else:
491
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
492
+ drop_input_cols=self._drop_input_cols,
493
+ expected_output_cols_list=expected_output_cols,
494
+ )
487
495
  self._sklearn_object = fitted_estimator
488
496
  self._is_fitted = True
489
497
  return output_result
@@ -564,12 +572,41 @@ class SelectFpr(BaseTransformer):
564
572
 
565
573
  return rv
566
574
 
567
- def _align_expected_output_names(
568
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
569
- ) -> List[str]:
575
+ def _align_expected_output(
576
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
577
+ ) -> Tuple[List[str], pd.DataFrame]:
578
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
579
+ and output dataframe with 1 line.
580
+ If the method is fit_predict, run 2 lines of data.
581
+ """
570
582
  # in case the inferred output column names dimension is different
571
583
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
572
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
584
+
585
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
586
+ # so change the minimum of number of rows to 2
587
+ num_examples = 2
588
+ statement_params = telemetry.get_function_usage_statement_params(
589
+ project=_PROJECT,
590
+ subproject=_SUBPROJECT,
591
+ function_name=telemetry.get_statement_params_full_func_name(
592
+ inspect.currentframe(), SelectFpr.__class__.__name__
593
+ ),
594
+ api_calls=[Session.call],
595
+ custom_tags={"autogen": True} if self._autogenerated else None,
596
+ )
597
+ if output_cols_prefix == "fit_predict_":
598
+ if hasattr(self._sklearn_object, "n_clusters"):
599
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
600
+ num_examples = self._sklearn_object.n_clusters
601
+ elif hasattr(self._sklearn_object, "min_samples"):
602
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
603
+ num_examples = self._sklearn_object.min_samples
604
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
605
+ # LocalOutlierFactor expects n_neighbors <= n_samples
606
+ num_examples = self._sklearn_object.n_neighbors
607
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
608
+ else:
609
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
573
610
 
574
611
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
575
612
  # seen during the fit.
@@ -581,12 +618,14 @@ class SelectFpr(BaseTransformer):
581
618
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
582
619
  if self.sample_weight_col:
583
620
  output_df_columns_set -= set(self.sample_weight_col)
621
+
584
622
  # if the dimension of inferred output column names is correct; use it
585
623
  if len(expected_output_cols_list) == len(output_df_columns_set):
586
- return expected_output_cols_list
624
+ return expected_output_cols_list, output_df_pd
587
625
  # otherwise, use the sklearn estimator's output
588
626
  else:
589
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
627
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
628
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
590
629
 
591
630
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
592
631
  @telemetry.send_api_usage_telemetry(
@@ -632,7 +671,7 @@ class SelectFpr(BaseTransformer):
632
671
  drop_input_cols=self._drop_input_cols,
633
672
  expected_output_cols_type="float",
634
673
  )
635
- expected_output_cols = self._align_expected_output_names(
674
+ expected_output_cols, _ = self._align_expected_output(
636
675
  inference_method, dataset, expected_output_cols, output_cols_prefix
637
676
  )
638
677
 
@@ -698,7 +737,7 @@ class SelectFpr(BaseTransformer):
698
737
  drop_input_cols=self._drop_input_cols,
699
738
  expected_output_cols_type="float",
700
739
  )
701
- expected_output_cols = self._align_expected_output_names(
740
+ expected_output_cols, _ = self._align_expected_output(
702
741
  inference_method, dataset, expected_output_cols, output_cols_prefix
703
742
  )
704
743
  elif isinstance(dataset, pd.DataFrame):
@@ -761,7 +800,7 @@ class SelectFpr(BaseTransformer):
761
800
  drop_input_cols=self._drop_input_cols,
762
801
  expected_output_cols_type="float",
763
802
  )
764
- expected_output_cols = self._align_expected_output_names(
803
+ expected_output_cols, _ = self._align_expected_output(
765
804
  inference_method, dataset, expected_output_cols, output_cols_prefix
766
805
  )
767
806
 
@@ -826,7 +865,7 @@ class SelectFpr(BaseTransformer):
826
865
  drop_input_cols = self._drop_input_cols,
827
866
  expected_output_cols_type="float",
828
867
  )
829
- expected_output_cols = self._align_expected_output_names(
868
+ expected_output_cols, _ = self._align_expected_output(
830
869
  inference_method, dataset, expected_output_cols, output_cols_prefix
831
870
  )
832
871
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -25,12 +23,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
25
23
  from snowflake.ml._internal import telemetry
26
24
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
27
25
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
28
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
26
+ from snowflake.ml._internal.utils import identifier
29
27
  from snowflake.snowpark import DataFrame, Session
30
28
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
31
29
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
32
30
  from snowflake.ml.modeling._internal.transformer_protocols import (
33
- ModelTransformHandlers,
34
31
  BatchInferenceKwargsTypedDict,
35
32
  ScoreKwargsTypedDict
36
33
  )
@@ -478,12 +475,23 @@ class SelectFwe(BaseTransformer):
478
475
  autogenerated=self._autogenerated,
479
476
  subproject=_SUBPROJECT,
480
477
  )
481
- output_result, fitted_estimator = model_trainer.train_fit_predict(
482
- drop_input_cols=self._drop_input_cols,
483
- expected_output_cols_list=(
484
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
485
- ),
478
+ expected_output_cols = (
479
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
486
480
  )
481
+ if isinstance(dataset, DataFrame):
482
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
483
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
484
+ )
485
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
486
+ drop_input_cols=self._drop_input_cols,
487
+ expected_output_cols_list=expected_output_cols,
488
+ example_output_pd_df=example_output_pd_df,
489
+ )
490
+ else:
491
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
492
+ drop_input_cols=self._drop_input_cols,
493
+ expected_output_cols_list=expected_output_cols,
494
+ )
487
495
  self._sklearn_object = fitted_estimator
488
496
  self._is_fitted = True
489
497
  return output_result
@@ -564,12 +572,41 @@ class SelectFwe(BaseTransformer):
564
572
 
565
573
  return rv
566
574
 
567
- def _align_expected_output_names(
568
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
569
- ) -> List[str]:
575
+ def _align_expected_output(
576
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
577
+ ) -> Tuple[List[str], pd.DataFrame]:
578
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
579
+ and output dataframe with 1 line.
580
+ If the method is fit_predict, run 2 lines of data.
581
+ """
570
582
  # in case the inferred output column names dimension is different
571
583
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
572
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
584
+
585
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
586
+ # so change the minimum of number of rows to 2
587
+ num_examples = 2
588
+ statement_params = telemetry.get_function_usage_statement_params(
589
+ project=_PROJECT,
590
+ subproject=_SUBPROJECT,
591
+ function_name=telemetry.get_statement_params_full_func_name(
592
+ inspect.currentframe(), SelectFwe.__class__.__name__
593
+ ),
594
+ api_calls=[Session.call],
595
+ custom_tags={"autogen": True} if self._autogenerated else None,
596
+ )
597
+ if output_cols_prefix == "fit_predict_":
598
+ if hasattr(self._sklearn_object, "n_clusters"):
599
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
600
+ num_examples = self._sklearn_object.n_clusters
601
+ elif hasattr(self._sklearn_object, "min_samples"):
602
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
603
+ num_examples = self._sklearn_object.min_samples
604
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
605
+ # LocalOutlierFactor expects n_neighbors <= n_samples
606
+ num_examples = self._sklearn_object.n_neighbors
607
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
608
+ else:
609
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
573
610
 
574
611
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
575
612
  # seen during the fit.
@@ -581,12 +618,14 @@ class SelectFwe(BaseTransformer):
581
618
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
582
619
  if self.sample_weight_col:
583
620
  output_df_columns_set -= set(self.sample_weight_col)
621
+
584
622
  # if the dimension of inferred output column names is correct; use it
585
623
  if len(expected_output_cols_list) == len(output_df_columns_set):
586
- return expected_output_cols_list
624
+ return expected_output_cols_list, output_df_pd
587
625
  # otherwise, use the sklearn estimator's output
588
626
  else:
589
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
627
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
628
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
590
629
 
591
630
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
592
631
  @telemetry.send_api_usage_telemetry(
@@ -632,7 +671,7 @@ class SelectFwe(BaseTransformer):
632
671
  drop_input_cols=self._drop_input_cols,
633
672
  expected_output_cols_type="float",
634
673
  )
635
- expected_output_cols = self._align_expected_output_names(
674
+ expected_output_cols, _ = self._align_expected_output(
636
675
  inference_method, dataset, expected_output_cols, output_cols_prefix
637
676
  )
638
677
 
@@ -698,7 +737,7 @@ class SelectFwe(BaseTransformer):
698
737
  drop_input_cols=self._drop_input_cols,
699
738
  expected_output_cols_type="float",
700
739
  )
701
- expected_output_cols = self._align_expected_output_names(
740
+ expected_output_cols, _ = self._align_expected_output(
702
741
  inference_method, dataset, expected_output_cols, output_cols_prefix
703
742
  )
704
743
  elif isinstance(dataset, pd.DataFrame):
@@ -761,7 +800,7 @@ class SelectFwe(BaseTransformer):
761
800
  drop_input_cols=self._drop_input_cols,
762
801
  expected_output_cols_type="float",
763
802
  )
764
- expected_output_cols = self._align_expected_output_names(
803
+ expected_output_cols, _ = self._align_expected_output(
765
804
  inference_method, dataset, expected_output_cols, output_cols_prefix
766
805
  )
767
806
 
@@ -826,7 +865,7 @@ class SelectFwe(BaseTransformer):
826
865
  drop_input_cols = self._drop_input_cols,
827
866
  expected_output_cols_type="float",
828
867
  )
829
- expected_output_cols = self._align_expected_output_names(
868
+ expected_output_cols, _ = self._align_expected_output(
830
869
  inference_method, dataset, expected_output_cols, output_cols_prefix
831
870
  )
832
871