snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -570,12 +567,23 @@ class Ridge(BaseTransformer):
570
567
  autogenerated=self._autogenerated,
571
568
  subproject=_SUBPROJECT,
572
569
  )
573
- output_result, fitted_estimator = model_trainer.train_fit_predict(
574
- drop_input_cols=self._drop_input_cols,
575
- expected_output_cols_list=(
576
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
577
- ),
570
+ expected_output_cols = (
571
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
578
572
  )
573
+ if isinstance(dataset, DataFrame):
574
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
575
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
576
+ )
577
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
578
+ drop_input_cols=self._drop_input_cols,
579
+ expected_output_cols_list=expected_output_cols,
580
+ example_output_pd_df=example_output_pd_df,
581
+ )
582
+ else:
583
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
584
+ drop_input_cols=self._drop_input_cols,
585
+ expected_output_cols_list=expected_output_cols,
586
+ )
579
587
  self._sklearn_object = fitted_estimator
580
588
  self._is_fitted = True
581
589
  return output_result
@@ -654,12 +662,41 @@ class Ridge(BaseTransformer):
654
662
 
655
663
  return rv
656
664
 
657
- def _align_expected_output_names(
658
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
659
- ) -> List[str]:
665
+ def _align_expected_output(
666
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
667
+ ) -> Tuple[List[str], pd.DataFrame]:
668
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
669
+ and output dataframe with 1 line.
670
+ If the method is fit_predict, run 2 lines of data.
671
+ """
660
672
  # in case the inferred output column names dimension is different
661
673
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
662
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
674
+
675
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
676
+ # so change the minimum of number of rows to 2
677
+ num_examples = 2
678
+ statement_params = telemetry.get_function_usage_statement_params(
679
+ project=_PROJECT,
680
+ subproject=_SUBPROJECT,
681
+ function_name=telemetry.get_statement_params_full_func_name(
682
+ inspect.currentframe(), Ridge.__class__.__name__
683
+ ),
684
+ api_calls=[Session.call],
685
+ custom_tags={"autogen": True} if self._autogenerated else None,
686
+ )
687
+ if output_cols_prefix == "fit_predict_":
688
+ if hasattr(self._sklearn_object, "n_clusters"):
689
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
690
+ num_examples = self._sklearn_object.n_clusters
691
+ elif hasattr(self._sklearn_object, "min_samples"):
692
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
693
+ num_examples = self._sklearn_object.min_samples
694
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
695
+ # LocalOutlierFactor expects n_neighbors <= n_samples
696
+ num_examples = self._sklearn_object.n_neighbors
697
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
698
+ else:
699
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
663
700
 
664
701
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
665
702
  # seen during the fit.
@@ -671,12 +708,14 @@ class Ridge(BaseTransformer):
671
708
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
672
709
  if self.sample_weight_col:
673
710
  output_df_columns_set -= set(self.sample_weight_col)
711
+
674
712
  # if the dimension of inferred output column names is correct; use it
675
713
  if len(expected_output_cols_list) == len(output_df_columns_set):
676
- return expected_output_cols_list
714
+ return expected_output_cols_list, output_df_pd
677
715
  # otherwise, use the sklearn estimator's output
678
716
  else:
679
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
717
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
718
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
680
719
 
681
720
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
682
721
  @telemetry.send_api_usage_telemetry(
@@ -722,7 +761,7 @@ class Ridge(BaseTransformer):
722
761
  drop_input_cols=self._drop_input_cols,
723
762
  expected_output_cols_type="float",
724
763
  )
725
- expected_output_cols = self._align_expected_output_names(
764
+ expected_output_cols, _ = self._align_expected_output(
726
765
  inference_method, dataset, expected_output_cols, output_cols_prefix
727
766
  )
728
767
 
@@ -788,7 +827,7 @@ class Ridge(BaseTransformer):
788
827
  drop_input_cols=self._drop_input_cols,
789
828
  expected_output_cols_type="float",
790
829
  )
791
- expected_output_cols = self._align_expected_output_names(
830
+ expected_output_cols, _ = self._align_expected_output(
792
831
  inference_method, dataset, expected_output_cols, output_cols_prefix
793
832
  )
794
833
  elif isinstance(dataset, pd.DataFrame):
@@ -851,7 +890,7 @@ class Ridge(BaseTransformer):
851
890
  drop_input_cols=self._drop_input_cols,
852
891
  expected_output_cols_type="float",
853
892
  )
854
- expected_output_cols = self._align_expected_output_names(
893
+ expected_output_cols, _ = self._align_expected_output(
855
894
  inference_method, dataset, expected_output_cols, output_cols_prefix
856
895
  )
857
896
 
@@ -916,7 +955,7 @@ class Ridge(BaseTransformer):
916
955
  drop_input_cols = self._drop_input_cols,
917
956
  expected_output_cols_type="float",
918
957
  )
919
- expected_output_cols = self._align_expected_output_names(
958
+ expected_output_cols, _ = self._align_expected_output(
920
959
  inference_method, dataset, expected_output_cols, output_cols_prefix
921
960
  )
922
961
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -570,12 +567,23 @@ class RidgeClassifier(BaseTransformer):
570
567
  autogenerated=self._autogenerated,
571
568
  subproject=_SUBPROJECT,
572
569
  )
573
- output_result, fitted_estimator = model_trainer.train_fit_predict(
574
- drop_input_cols=self._drop_input_cols,
575
- expected_output_cols_list=(
576
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
577
- ),
570
+ expected_output_cols = (
571
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
578
572
  )
573
+ if isinstance(dataset, DataFrame):
574
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
575
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
576
+ )
577
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
578
+ drop_input_cols=self._drop_input_cols,
579
+ expected_output_cols_list=expected_output_cols,
580
+ example_output_pd_df=example_output_pd_df,
581
+ )
582
+ else:
583
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
584
+ drop_input_cols=self._drop_input_cols,
585
+ expected_output_cols_list=expected_output_cols,
586
+ )
579
587
  self._sklearn_object = fitted_estimator
580
588
  self._is_fitted = True
581
589
  return output_result
@@ -654,12 +662,41 @@ class RidgeClassifier(BaseTransformer):
654
662
 
655
663
  return rv
656
664
 
657
- def _align_expected_output_names(
658
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
659
- ) -> List[str]:
665
+ def _align_expected_output(
666
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
667
+ ) -> Tuple[List[str], pd.DataFrame]:
668
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
669
+ and output dataframe with 1 line.
670
+ If the method is fit_predict, run 2 lines of data.
671
+ """
660
672
  # in case the inferred output column names dimension is different
661
673
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
662
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
674
+
675
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
676
+ # so change the minimum of number of rows to 2
677
+ num_examples = 2
678
+ statement_params = telemetry.get_function_usage_statement_params(
679
+ project=_PROJECT,
680
+ subproject=_SUBPROJECT,
681
+ function_name=telemetry.get_statement_params_full_func_name(
682
+ inspect.currentframe(), RidgeClassifier.__class__.__name__
683
+ ),
684
+ api_calls=[Session.call],
685
+ custom_tags={"autogen": True} if self._autogenerated else None,
686
+ )
687
+ if output_cols_prefix == "fit_predict_":
688
+ if hasattr(self._sklearn_object, "n_clusters"):
689
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
690
+ num_examples = self._sklearn_object.n_clusters
691
+ elif hasattr(self._sklearn_object, "min_samples"):
692
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
693
+ num_examples = self._sklearn_object.min_samples
694
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
695
+ # LocalOutlierFactor expects n_neighbors <= n_samples
696
+ num_examples = self._sklearn_object.n_neighbors
697
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
698
+ else:
699
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
663
700
 
664
701
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
665
702
  # seen during the fit.
@@ -671,12 +708,14 @@ class RidgeClassifier(BaseTransformer):
671
708
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
672
709
  if self.sample_weight_col:
673
710
  output_df_columns_set -= set(self.sample_weight_col)
711
+
674
712
  # if the dimension of inferred output column names is correct; use it
675
713
  if len(expected_output_cols_list) == len(output_df_columns_set):
676
- return expected_output_cols_list
714
+ return expected_output_cols_list, output_df_pd
677
715
  # otherwise, use the sklearn estimator's output
678
716
  else:
679
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
717
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
718
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
680
719
 
681
720
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
682
721
  @telemetry.send_api_usage_telemetry(
@@ -722,7 +761,7 @@ class RidgeClassifier(BaseTransformer):
722
761
  drop_input_cols=self._drop_input_cols,
723
762
  expected_output_cols_type="float",
724
763
  )
725
- expected_output_cols = self._align_expected_output_names(
764
+ expected_output_cols, _ = self._align_expected_output(
726
765
  inference_method, dataset, expected_output_cols, output_cols_prefix
727
766
  )
728
767
 
@@ -788,7 +827,7 @@ class RidgeClassifier(BaseTransformer):
788
827
  drop_input_cols=self._drop_input_cols,
789
828
  expected_output_cols_type="float",
790
829
  )
791
- expected_output_cols = self._align_expected_output_names(
830
+ expected_output_cols, _ = self._align_expected_output(
792
831
  inference_method, dataset, expected_output_cols, output_cols_prefix
793
832
  )
794
833
  elif isinstance(dataset, pd.DataFrame):
@@ -853,7 +892,7 @@ class RidgeClassifier(BaseTransformer):
853
892
  drop_input_cols=self._drop_input_cols,
854
893
  expected_output_cols_type="float",
855
894
  )
856
- expected_output_cols = self._align_expected_output_names(
895
+ expected_output_cols, _ = self._align_expected_output(
857
896
  inference_method, dataset, expected_output_cols, output_cols_prefix
858
897
  )
859
898
 
@@ -918,7 +957,7 @@ class RidgeClassifier(BaseTransformer):
918
957
  drop_input_cols = self._drop_input_cols,
919
958
  expected_output_cols_type="float",
920
959
  )
921
- expected_output_cols = self._align_expected_output_names(
960
+ expected_output_cols, _ = self._align_expected_output(
922
961
  inference_method, dataset, expected_output_cols, output_cols_prefix
923
962
  )
924
963
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -521,12 +518,23 @@ class RidgeClassifierCV(BaseTransformer):
521
518
  autogenerated=self._autogenerated,
522
519
  subproject=_SUBPROJECT,
523
520
  )
524
- output_result, fitted_estimator = model_trainer.train_fit_predict(
525
- drop_input_cols=self._drop_input_cols,
526
- expected_output_cols_list=(
527
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
528
- ),
521
+ expected_output_cols = (
522
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
529
523
  )
524
+ if isinstance(dataset, DataFrame):
525
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
526
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
527
+ )
528
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
529
+ drop_input_cols=self._drop_input_cols,
530
+ expected_output_cols_list=expected_output_cols,
531
+ example_output_pd_df=example_output_pd_df,
532
+ )
533
+ else:
534
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
535
+ drop_input_cols=self._drop_input_cols,
536
+ expected_output_cols_list=expected_output_cols,
537
+ )
530
538
  self._sklearn_object = fitted_estimator
531
539
  self._is_fitted = True
532
540
  return output_result
@@ -605,12 +613,41 @@ class RidgeClassifierCV(BaseTransformer):
605
613
 
606
614
  return rv
607
615
 
608
- def _align_expected_output_names(
609
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
610
- ) -> List[str]:
616
+ def _align_expected_output(
617
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
618
+ ) -> Tuple[List[str], pd.DataFrame]:
619
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
620
+ and output dataframe with 1 line.
621
+ If the method is fit_predict, run 2 lines of data.
622
+ """
611
623
  # in case the inferred output column names dimension is different
612
624
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
613
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
625
+
626
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
627
+ # so change the minimum of number of rows to 2
628
+ num_examples = 2
629
+ statement_params = telemetry.get_function_usage_statement_params(
630
+ project=_PROJECT,
631
+ subproject=_SUBPROJECT,
632
+ function_name=telemetry.get_statement_params_full_func_name(
633
+ inspect.currentframe(), RidgeClassifierCV.__class__.__name__
634
+ ),
635
+ api_calls=[Session.call],
636
+ custom_tags={"autogen": True} if self._autogenerated else None,
637
+ )
638
+ if output_cols_prefix == "fit_predict_":
639
+ if hasattr(self._sklearn_object, "n_clusters"):
640
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
641
+ num_examples = self._sklearn_object.n_clusters
642
+ elif hasattr(self._sklearn_object, "min_samples"):
643
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
644
+ num_examples = self._sklearn_object.min_samples
645
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
646
+ # LocalOutlierFactor expects n_neighbors <= n_samples
647
+ num_examples = self._sklearn_object.n_neighbors
648
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
649
+ else:
650
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
614
651
 
615
652
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
616
653
  # seen during the fit.
@@ -622,12 +659,14 @@ class RidgeClassifierCV(BaseTransformer):
622
659
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
623
660
  if self.sample_weight_col:
624
661
  output_df_columns_set -= set(self.sample_weight_col)
662
+
625
663
  # if the dimension of inferred output column names is correct; use it
626
664
  if len(expected_output_cols_list) == len(output_df_columns_set):
627
- return expected_output_cols_list
665
+ return expected_output_cols_list, output_df_pd
628
666
  # otherwise, use the sklearn estimator's output
629
667
  else:
630
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
668
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
669
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
631
670
 
632
671
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
633
672
  @telemetry.send_api_usage_telemetry(
@@ -673,7 +712,7 @@ class RidgeClassifierCV(BaseTransformer):
673
712
  drop_input_cols=self._drop_input_cols,
674
713
  expected_output_cols_type="float",
675
714
  )
676
- expected_output_cols = self._align_expected_output_names(
715
+ expected_output_cols, _ = self._align_expected_output(
677
716
  inference_method, dataset, expected_output_cols, output_cols_prefix
678
717
  )
679
718
 
@@ -739,7 +778,7 @@ class RidgeClassifierCV(BaseTransformer):
739
778
  drop_input_cols=self._drop_input_cols,
740
779
  expected_output_cols_type="float",
741
780
  )
742
- expected_output_cols = self._align_expected_output_names(
781
+ expected_output_cols, _ = self._align_expected_output(
743
782
  inference_method, dataset, expected_output_cols, output_cols_prefix
744
783
  )
745
784
  elif isinstance(dataset, pd.DataFrame):
@@ -804,7 +843,7 @@ class RidgeClassifierCV(BaseTransformer):
804
843
  drop_input_cols=self._drop_input_cols,
805
844
  expected_output_cols_type="float",
806
845
  )
807
- expected_output_cols = self._align_expected_output_names(
846
+ expected_output_cols, _ = self._align_expected_output(
808
847
  inference_method, dataset, expected_output_cols, output_cols_prefix
809
848
  )
810
849
 
@@ -869,7 +908,7 @@ class RidgeClassifierCV(BaseTransformer):
869
908
  drop_input_cols = self._drop_input_cols,
870
909
  expected_output_cols_type="float",
871
910
  )
872
- expected_output_cols = self._align_expected_output_names(
911
+ expected_output_cols, _ = self._align_expected_output(
873
912
  inference_method, dataset, expected_output_cols, output_cols_prefix
874
913
  )
875
914
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -542,12 +539,23 @@ class RidgeCV(BaseTransformer):
542
539
  autogenerated=self._autogenerated,
543
540
  subproject=_SUBPROJECT,
544
541
  )
545
- output_result, fitted_estimator = model_trainer.train_fit_predict(
546
- drop_input_cols=self._drop_input_cols,
547
- expected_output_cols_list=(
548
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
549
- ),
542
+ expected_output_cols = (
543
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
550
544
  )
545
+ if isinstance(dataset, DataFrame):
546
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
547
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
548
+ )
549
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
550
+ drop_input_cols=self._drop_input_cols,
551
+ expected_output_cols_list=expected_output_cols,
552
+ example_output_pd_df=example_output_pd_df,
553
+ )
554
+ else:
555
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
556
+ drop_input_cols=self._drop_input_cols,
557
+ expected_output_cols_list=expected_output_cols,
558
+ )
551
559
  self._sklearn_object = fitted_estimator
552
560
  self._is_fitted = True
553
561
  return output_result
@@ -626,12 +634,41 @@ class RidgeCV(BaseTransformer):
626
634
 
627
635
  return rv
628
636
 
629
- def _align_expected_output_names(
630
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
631
- ) -> List[str]:
637
+ def _align_expected_output(
638
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
639
+ ) -> Tuple[List[str], pd.DataFrame]:
640
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
641
+ and output dataframe with 1 line.
642
+ If the method is fit_predict, run 2 lines of data.
643
+ """
632
644
  # in case the inferred output column names dimension is different
633
645
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
634
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
646
+
647
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
648
+ # so change the minimum of number of rows to 2
649
+ num_examples = 2
650
+ statement_params = telemetry.get_function_usage_statement_params(
651
+ project=_PROJECT,
652
+ subproject=_SUBPROJECT,
653
+ function_name=telemetry.get_statement_params_full_func_name(
654
+ inspect.currentframe(), RidgeCV.__class__.__name__
655
+ ),
656
+ api_calls=[Session.call],
657
+ custom_tags={"autogen": True} if self._autogenerated else None,
658
+ )
659
+ if output_cols_prefix == "fit_predict_":
660
+ if hasattr(self._sklearn_object, "n_clusters"):
661
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
662
+ num_examples = self._sklearn_object.n_clusters
663
+ elif hasattr(self._sklearn_object, "min_samples"):
664
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
665
+ num_examples = self._sklearn_object.min_samples
666
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
667
+ # LocalOutlierFactor expects n_neighbors <= n_samples
668
+ num_examples = self._sklearn_object.n_neighbors
669
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
670
+ else:
671
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
635
672
 
636
673
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
637
674
  # seen during the fit.
@@ -643,12 +680,14 @@ class RidgeCV(BaseTransformer):
643
680
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
644
681
  if self.sample_weight_col:
645
682
  output_df_columns_set -= set(self.sample_weight_col)
683
+
646
684
  # if the dimension of inferred output column names is correct; use it
647
685
  if len(expected_output_cols_list) == len(output_df_columns_set):
648
- return expected_output_cols_list
686
+ return expected_output_cols_list, output_df_pd
649
687
  # otherwise, use the sklearn estimator's output
650
688
  else:
651
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
689
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
690
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
652
691
 
653
692
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
654
693
  @telemetry.send_api_usage_telemetry(
@@ -694,7 +733,7 @@ class RidgeCV(BaseTransformer):
694
733
  drop_input_cols=self._drop_input_cols,
695
734
  expected_output_cols_type="float",
696
735
  )
697
- expected_output_cols = self._align_expected_output_names(
736
+ expected_output_cols, _ = self._align_expected_output(
698
737
  inference_method, dataset, expected_output_cols, output_cols_prefix
699
738
  )
700
739
 
@@ -760,7 +799,7 @@ class RidgeCV(BaseTransformer):
760
799
  drop_input_cols=self._drop_input_cols,
761
800
  expected_output_cols_type="float",
762
801
  )
763
- expected_output_cols = self._align_expected_output_names(
802
+ expected_output_cols, _ = self._align_expected_output(
764
803
  inference_method, dataset, expected_output_cols, output_cols_prefix
765
804
  )
766
805
  elif isinstance(dataset, pd.DataFrame):
@@ -823,7 +862,7 @@ class RidgeCV(BaseTransformer):
823
862
  drop_input_cols=self._drop_input_cols,
824
863
  expected_output_cols_type="float",
825
864
  )
826
- expected_output_cols = self._align_expected_output_names(
865
+ expected_output_cols, _ = self._align_expected_output(
827
866
  inference_method, dataset, expected_output_cols, output_cols_prefix
828
867
  )
829
868
 
@@ -888,7 +927,7 @@ class RidgeCV(BaseTransformer):
888
927
  drop_input_cols = self._drop_input_cols,
889
928
  expected_output_cols_type="float",
890
929
  )
891
- expected_output_cols = self._align_expected_output_names(
930
+ expected_output_cols, _ = self._align_expected_output(
892
931
  inference_method, dataset, expected_output_cols, output_cols_prefix
893
932
  )
894
933