snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -506,12 +503,23 @@ class LGBMClassifier(BaseTransformer):
506
503
  autogenerated=self._autogenerated,
507
504
  subproject=_SUBPROJECT,
508
505
  )
509
- output_result, fitted_estimator = model_trainer.train_fit_predict(
510
- drop_input_cols=self._drop_input_cols,
511
- expected_output_cols_list=(
512
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
513
- ),
506
+ expected_output_cols = (
507
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
514
508
  )
509
+ if isinstance(dataset, DataFrame):
510
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
511
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
512
+ )
513
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
514
+ drop_input_cols=self._drop_input_cols,
515
+ expected_output_cols_list=expected_output_cols,
516
+ example_output_pd_df=example_output_pd_df,
517
+ )
518
+ else:
519
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
520
+ drop_input_cols=self._drop_input_cols,
521
+ expected_output_cols_list=expected_output_cols,
522
+ )
515
523
  self._sklearn_object = fitted_estimator
516
524
  self._is_fitted = True
517
525
  return output_result
@@ -590,12 +598,41 @@ class LGBMClassifier(BaseTransformer):
590
598
 
591
599
  return rv
592
600
 
593
- def _align_expected_output_names(
594
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
595
- ) -> List[str]:
601
+ def _align_expected_output(
602
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
603
+ ) -> Tuple[List[str], pd.DataFrame]:
604
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
605
+ and output dataframe with 1 line.
606
+ If the method is fit_predict, run 2 lines of data.
607
+ """
596
608
  # in case the inferred output column names dimension is different
597
609
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
598
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
610
+
611
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
612
+ # so change the minimum of number of rows to 2
613
+ num_examples = 2
614
+ statement_params = telemetry.get_function_usage_statement_params(
615
+ project=_PROJECT,
616
+ subproject=_SUBPROJECT,
617
+ function_name=telemetry.get_statement_params_full_func_name(
618
+ inspect.currentframe(), LGBMClassifier.__class__.__name__
619
+ ),
620
+ api_calls=[Session.call],
621
+ custom_tags={"autogen": True} if self._autogenerated else None,
622
+ )
623
+ if output_cols_prefix == "fit_predict_":
624
+ if hasattr(self._sklearn_object, "n_clusters"):
625
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
626
+ num_examples = self._sklearn_object.n_clusters
627
+ elif hasattr(self._sklearn_object, "min_samples"):
628
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
629
+ num_examples = self._sklearn_object.min_samples
630
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
631
+ # LocalOutlierFactor expects n_neighbors <= n_samples
632
+ num_examples = self._sklearn_object.n_neighbors
633
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
634
+ else:
635
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
599
636
 
600
637
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
601
638
  # seen during the fit.
@@ -607,12 +644,14 @@ class LGBMClassifier(BaseTransformer):
607
644
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
608
645
  if self.sample_weight_col:
609
646
  output_df_columns_set -= set(self.sample_weight_col)
647
+
610
648
  # if the dimension of inferred output column names is correct; use it
611
649
  if len(expected_output_cols_list) == len(output_df_columns_set):
612
- return expected_output_cols_list
650
+ return expected_output_cols_list, output_df_pd
613
651
  # otherwise, use the sklearn estimator's output
614
652
  else:
615
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
653
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
654
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
616
655
 
617
656
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
618
657
  @telemetry.send_api_usage_telemetry(
@@ -660,7 +699,7 @@ class LGBMClassifier(BaseTransformer):
660
699
  drop_input_cols=self._drop_input_cols,
661
700
  expected_output_cols_type="float",
662
701
  )
663
- expected_output_cols = self._align_expected_output_names(
702
+ expected_output_cols, _ = self._align_expected_output(
664
703
  inference_method, dataset, expected_output_cols, output_cols_prefix
665
704
  )
666
705
 
@@ -728,7 +767,7 @@ class LGBMClassifier(BaseTransformer):
728
767
  drop_input_cols=self._drop_input_cols,
729
768
  expected_output_cols_type="float",
730
769
  )
731
- expected_output_cols = self._align_expected_output_names(
770
+ expected_output_cols, _ = self._align_expected_output(
732
771
  inference_method, dataset, expected_output_cols, output_cols_prefix
733
772
  )
734
773
  elif isinstance(dataset, pd.DataFrame):
@@ -791,7 +830,7 @@ class LGBMClassifier(BaseTransformer):
791
830
  drop_input_cols=self._drop_input_cols,
792
831
  expected_output_cols_type="float",
793
832
  )
794
- expected_output_cols = self._align_expected_output_names(
833
+ expected_output_cols, _ = self._align_expected_output(
795
834
  inference_method, dataset, expected_output_cols, output_cols_prefix
796
835
  )
797
836
 
@@ -856,7 +895,7 @@ class LGBMClassifier(BaseTransformer):
856
895
  drop_input_cols = self._drop_input_cols,
857
896
  expected_output_cols_type="float",
858
897
  )
859
- expected_output_cols = self._align_expected_output_names(
898
+ expected_output_cols, _ = self._align_expected_output(
860
899
  inference_method, dataset, expected_output_cols, output_cols_prefix
861
900
  )
862
901
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -506,12 +503,23 @@ class LGBMRegressor(BaseTransformer):
506
503
  autogenerated=self._autogenerated,
507
504
  subproject=_SUBPROJECT,
508
505
  )
509
- output_result, fitted_estimator = model_trainer.train_fit_predict(
510
- drop_input_cols=self._drop_input_cols,
511
- expected_output_cols_list=(
512
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
513
- ),
506
+ expected_output_cols = (
507
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
514
508
  )
509
+ if isinstance(dataset, DataFrame):
510
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
511
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
512
+ )
513
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
514
+ drop_input_cols=self._drop_input_cols,
515
+ expected_output_cols_list=expected_output_cols,
516
+ example_output_pd_df=example_output_pd_df,
517
+ )
518
+ else:
519
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
520
+ drop_input_cols=self._drop_input_cols,
521
+ expected_output_cols_list=expected_output_cols,
522
+ )
515
523
  self._sklearn_object = fitted_estimator
516
524
  self._is_fitted = True
517
525
  return output_result
@@ -590,12 +598,41 @@ class LGBMRegressor(BaseTransformer):
590
598
 
591
599
  return rv
592
600
 
593
- def _align_expected_output_names(
594
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
595
- ) -> List[str]:
601
+ def _align_expected_output(
602
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
603
+ ) -> Tuple[List[str], pd.DataFrame]:
604
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
605
+ and output dataframe with 1 line.
606
+ If the method is fit_predict, run 2 lines of data.
607
+ """
596
608
  # in case the inferred output column names dimension is different
597
609
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
598
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
610
+
611
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
612
+ # so change the minimum of number of rows to 2
613
+ num_examples = 2
614
+ statement_params = telemetry.get_function_usage_statement_params(
615
+ project=_PROJECT,
616
+ subproject=_SUBPROJECT,
617
+ function_name=telemetry.get_statement_params_full_func_name(
618
+ inspect.currentframe(), LGBMRegressor.__class__.__name__
619
+ ),
620
+ api_calls=[Session.call],
621
+ custom_tags={"autogen": True} if self._autogenerated else None,
622
+ )
623
+ if output_cols_prefix == "fit_predict_":
624
+ if hasattr(self._sklearn_object, "n_clusters"):
625
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
626
+ num_examples = self._sklearn_object.n_clusters
627
+ elif hasattr(self._sklearn_object, "min_samples"):
628
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
629
+ num_examples = self._sklearn_object.min_samples
630
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
631
+ # LocalOutlierFactor expects n_neighbors <= n_samples
632
+ num_examples = self._sklearn_object.n_neighbors
633
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
634
+ else:
635
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
599
636
 
600
637
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
601
638
  # seen during the fit.
@@ -607,12 +644,14 @@ class LGBMRegressor(BaseTransformer):
607
644
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
608
645
  if self.sample_weight_col:
609
646
  output_df_columns_set -= set(self.sample_weight_col)
647
+
610
648
  # if the dimension of inferred output column names is correct; use it
611
649
  if len(expected_output_cols_list) == len(output_df_columns_set):
612
- return expected_output_cols_list
650
+ return expected_output_cols_list, output_df_pd
613
651
  # otherwise, use the sklearn estimator's output
614
652
  else:
615
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
653
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
654
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
616
655
 
617
656
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
618
657
  @telemetry.send_api_usage_telemetry(
@@ -658,7 +697,7 @@ class LGBMRegressor(BaseTransformer):
658
697
  drop_input_cols=self._drop_input_cols,
659
698
  expected_output_cols_type="float",
660
699
  )
661
- expected_output_cols = self._align_expected_output_names(
700
+ expected_output_cols, _ = self._align_expected_output(
662
701
  inference_method, dataset, expected_output_cols, output_cols_prefix
663
702
  )
664
703
 
@@ -724,7 +763,7 @@ class LGBMRegressor(BaseTransformer):
724
763
  drop_input_cols=self._drop_input_cols,
725
764
  expected_output_cols_type="float",
726
765
  )
727
- expected_output_cols = self._align_expected_output_names(
766
+ expected_output_cols, _ = self._align_expected_output(
728
767
  inference_method, dataset, expected_output_cols, output_cols_prefix
729
768
  )
730
769
  elif isinstance(dataset, pd.DataFrame):
@@ -787,7 +826,7 @@ class LGBMRegressor(BaseTransformer):
787
826
  drop_input_cols=self._drop_input_cols,
788
827
  expected_output_cols_type="float",
789
828
  )
790
- expected_output_cols = self._align_expected_output_names(
829
+ expected_output_cols, _ = self._align_expected_output(
791
830
  inference_method, dataset, expected_output_cols, output_cols_prefix
792
831
  )
793
832
 
@@ -852,7 +891,7 @@ class LGBMRegressor(BaseTransformer):
852
891
  drop_input_cols = self._drop_input_cols,
853
892
  expected_output_cols_type="float",
854
893
  )
855
- expected_output_cols = self._align_expected_output_names(
894
+ expected_output_cols, _ = self._align_expected_output(
856
895
  inference_method, dataset, expected_output_cols, output_cols_prefix
857
896
  )
858
897
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -531,12 +528,23 @@ class ARDRegression(BaseTransformer):
531
528
  autogenerated=self._autogenerated,
532
529
  subproject=_SUBPROJECT,
533
530
  )
534
- output_result, fitted_estimator = model_trainer.train_fit_predict(
535
- drop_input_cols=self._drop_input_cols,
536
- expected_output_cols_list=(
537
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
538
- ),
531
+ expected_output_cols = (
532
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
539
533
  )
534
+ if isinstance(dataset, DataFrame):
535
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
536
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
537
+ )
538
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
539
+ drop_input_cols=self._drop_input_cols,
540
+ expected_output_cols_list=expected_output_cols,
541
+ example_output_pd_df=example_output_pd_df,
542
+ )
543
+ else:
544
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
545
+ drop_input_cols=self._drop_input_cols,
546
+ expected_output_cols_list=expected_output_cols,
547
+ )
540
548
  self._sklearn_object = fitted_estimator
541
549
  self._is_fitted = True
542
550
  return output_result
@@ -615,12 +623,41 @@ class ARDRegression(BaseTransformer):
615
623
 
616
624
  return rv
617
625
 
618
- def _align_expected_output_names(
619
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
620
- ) -> List[str]:
626
+ def _align_expected_output(
627
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
628
+ ) -> Tuple[List[str], pd.DataFrame]:
629
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
630
+ and output dataframe with 1 line.
631
+ If the method is fit_predict, run 2 lines of data.
632
+ """
621
633
  # in case the inferred output column names dimension is different
622
634
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
623
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
635
+
636
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
637
+ # so change the minimum of number of rows to 2
638
+ num_examples = 2
639
+ statement_params = telemetry.get_function_usage_statement_params(
640
+ project=_PROJECT,
641
+ subproject=_SUBPROJECT,
642
+ function_name=telemetry.get_statement_params_full_func_name(
643
+ inspect.currentframe(), ARDRegression.__class__.__name__
644
+ ),
645
+ api_calls=[Session.call],
646
+ custom_tags={"autogen": True} if self._autogenerated else None,
647
+ )
648
+ if output_cols_prefix == "fit_predict_":
649
+ if hasattr(self._sklearn_object, "n_clusters"):
650
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
651
+ num_examples = self._sklearn_object.n_clusters
652
+ elif hasattr(self._sklearn_object, "min_samples"):
653
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
654
+ num_examples = self._sklearn_object.min_samples
655
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
656
+ # LocalOutlierFactor expects n_neighbors <= n_samples
657
+ num_examples = self._sklearn_object.n_neighbors
658
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
659
+ else:
660
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
624
661
 
625
662
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
626
663
  # seen during the fit.
@@ -632,12 +669,14 @@ class ARDRegression(BaseTransformer):
632
669
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
633
670
  if self.sample_weight_col:
634
671
  output_df_columns_set -= set(self.sample_weight_col)
672
+
635
673
  # if the dimension of inferred output column names is correct; use it
636
674
  if len(expected_output_cols_list) == len(output_df_columns_set):
637
- return expected_output_cols_list
675
+ return expected_output_cols_list, output_df_pd
638
676
  # otherwise, use the sklearn estimator's output
639
677
  else:
640
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
678
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
679
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
641
680
 
642
681
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
643
682
  @telemetry.send_api_usage_telemetry(
@@ -683,7 +722,7 @@ class ARDRegression(BaseTransformer):
683
722
  drop_input_cols=self._drop_input_cols,
684
723
  expected_output_cols_type="float",
685
724
  )
686
- expected_output_cols = self._align_expected_output_names(
725
+ expected_output_cols, _ = self._align_expected_output(
687
726
  inference_method, dataset, expected_output_cols, output_cols_prefix
688
727
  )
689
728
 
@@ -749,7 +788,7 @@ class ARDRegression(BaseTransformer):
749
788
  drop_input_cols=self._drop_input_cols,
750
789
  expected_output_cols_type="float",
751
790
  )
752
- expected_output_cols = self._align_expected_output_names(
791
+ expected_output_cols, _ = self._align_expected_output(
753
792
  inference_method, dataset, expected_output_cols, output_cols_prefix
754
793
  )
755
794
  elif isinstance(dataset, pd.DataFrame):
@@ -812,7 +851,7 @@ class ARDRegression(BaseTransformer):
812
851
  drop_input_cols=self._drop_input_cols,
813
852
  expected_output_cols_type="float",
814
853
  )
815
- expected_output_cols = self._align_expected_output_names(
854
+ expected_output_cols, _ = self._align_expected_output(
816
855
  inference_method, dataset, expected_output_cols, output_cols_prefix
817
856
  )
818
857
 
@@ -877,7 +916,7 @@ class ARDRegression(BaseTransformer):
877
916
  drop_input_cols = self._drop_input_cols,
878
917
  expected_output_cols_type="float",
879
918
  )
880
- expected_output_cols = self._align_expected_output_names(
919
+ expected_output_cols, _ = self._align_expected_output(
881
920
  inference_method, dataset, expected_output_cols, output_cols_prefix
882
921
  )
883
922
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -542,12 +539,23 @@ class BayesianRidge(BaseTransformer):
542
539
  autogenerated=self._autogenerated,
543
540
  subproject=_SUBPROJECT,
544
541
  )
545
- output_result, fitted_estimator = model_trainer.train_fit_predict(
546
- drop_input_cols=self._drop_input_cols,
547
- expected_output_cols_list=(
548
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
549
- ),
542
+ expected_output_cols = (
543
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
550
544
  )
545
+ if isinstance(dataset, DataFrame):
546
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
547
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
548
+ )
549
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
550
+ drop_input_cols=self._drop_input_cols,
551
+ expected_output_cols_list=expected_output_cols,
552
+ example_output_pd_df=example_output_pd_df,
553
+ )
554
+ else:
555
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
556
+ drop_input_cols=self._drop_input_cols,
557
+ expected_output_cols_list=expected_output_cols,
558
+ )
551
559
  self._sklearn_object = fitted_estimator
552
560
  self._is_fitted = True
553
561
  return output_result
@@ -626,12 +634,41 @@ class BayesianRidge(BaseTransformer):
626
634
 
627
635
  return rv
628
636
 
629
- def _align_expected_output_names(
630
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
631
- ) -> List[str]:
637
+ def _align_expected_output(
638
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
639
+ ) -> Tuple[List[str], pd.DataFrame]:
640
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
641
+ and output dataframe with 1 line.
642
+ If the method is fit_predict, run 2 lines of data.
643
+ """
632
644
  # in case the inferred output column names dimension is different
633
645
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
634
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
646
+
647
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
648
+ # so change the minimum of number of rows to 2
649
+ num_examples = 2
650
+ statement_params = telemetry.get_function_usage_statement_params(
651
+ project=_PROJECT,
652
+ subproject=_SUBPROJECT,
653
+ function_name=telemetry.get_statement_params_full_func_name(
654
+ inspect.currentframe(), BayesianRidge.__class__.__name__
655
+ ),
656
+ api_calls=[Session.call],
657
+ custom_tags={"autogen": True} if self._autogenerated else None,
658
+ )
659
+ if output_cols_prefix == "fit_predict_":
660
+ if hasattr(self._sklearn_object, "n_clusters"):
661
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
662
+ num_examples = self._sklearn_object.n_clusters
663
+ elif hasattr(self._sklearn_object, "min_samples"):
664
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
665
+ num_examples = self._sklearn_object.min_samples
666
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
667
+ # LocalOutlierFactor expects n_neighbors <= n_samples
668
+ num_examples = self._sklearn_object.n_neighbors
669
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
670
+ else:
671
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
635
672
 
636
673
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
637
674
  # seen during the fit.
@@ -643,12 +680,14 @@ class BayesianRidge(BaseTransformer):
643
680
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
644
681
  if self.sample_weight_col:
645
682
  output_df_columns_set -= set(self.sample_weight_col)
683
+
646
684
  # if the dimension of inferred output column names is correct; use it
647
685
  if len(expected_output_cols_list) == len(output_df_columns_set):
648
- return expected_output_cols_list
686
+ return expected_output_cols_list, output_df_pd
649
687
  # otherwise, use the sklearn estimator's output
650
688
  else:
651
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
689
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
690
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
652
691
 
653
692
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
654
693
  @telemetry.send_api_usage_telemetry(
@@ -694,7 +733,7 @@ class BayesianRidge(BaseTransformer):
694
733
  drop_input_cols=self._drop_input_cols,
695
734
  expected_output_cols_type="float",
696
735
  )
697
- expected_output_cols = self._align_expected_output_names(
736
+ expected_output_cols, _ = self._align_expected_output(
698
737
  inference_method, dataset, expected_output_cols, output_cols_prefix
699
738
  )
700
739
 
@@ -760,7 +799,7 @@ class BayesianRidge(BaseTransformer):
760
799
  drop_input_cols=self._drop_input_cols,
761
800
  expected_output_cols_type="float",
762
801
  )
763
- expected_output_cols = self._align_expected_output_names(
802
+ expected_output_cols, _ = self._align_expected_output(
764
803
  inference_method, dataset, expected_output_cols, output_cols_prefix
765
804
  )
766
805
  elif isinstance(dataset, pd.DataFrame):
@@ -823,7 +862,7 @@ class BayesianRidge(BaseTransformer):
823
862
  drop_input_cols=self._drop_input_cols,
824
863
  expected_output_cols_type="float",
825
864
  )
826
- expected_output_cols = self._align_expected_output_names(
865
+ expected_output_cols, _ = self._align_expected_output(
827
866
  inference_method, dataset, expected_output_cols, output_cols_prefix
828
867
  )
829
868
 
@@ -888,7 +927,7 @@ class BayesianRidge(BaseTransformer):
888
927
  drop_input_cols = self._drop_input_cols,
889
928
  expected_output_cols_type="float",
890
929
  )
891
- expected_output_cols = self._align_expected_output_names(
930
+ expected_output_cols, _ = self._align_expected_output(
892
931
  inference_method, dataset, expected_output_cols, output_cols_prefix
893
932
  )
894
933