snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252) hide show
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -498,12 +495,23 @@ class PolynomialCountSketch(BaseTransformer):
498
495
  autogenerated=self._autogenerated,
499
496
  subproject=_SUBPROJECT,
500
497
  )
501
- output_result, fitted_estimator = model_trainer.train_fit_predict(
502
- drop_input_cols=self._drop_input_cols,
503
- expected_output_cols_list=(
504
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
505
- ),
498
+ expected_output_cols = (
499
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
506
500
  )
501
+ if isinstance(dataset, DataFrame):
502
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
503
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
504
+ )
505
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
506
+ drop_input_cols=self._drop_input_cols,
507
+ expected_output_cols_list=expected_output_cols,
508
+ example_output_pd_df=example_output_pd_df,
509
+ )
510
+ else:
511
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
512
+ drop_input_cols=self._drop_input_cols,
513
+ expected_output_cols_list=expected_output_cols,
514
+ )
507
515
  self._sklearn_object = fitted_estimator
508
516
  self._is_fitted = True
509
517
  return output_result
@@ -584,12 +592,41 @@ class PolynomialCountSketch(BaseTransformer):
584
592
 
585
593
  return rv
586
594
 
587
- def _align_expected_output_names(
588
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
589
- ) -> List[str]:
595
+ def _align_expected_output(
596
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
597
+ ) -> Tuple[List[str], pd.DataFrame]:
598
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
599
+ and output dataframe with 1 line.
600
+ If the method is fit_predict, run 2 lines of data.
601
+ """
590
602
  # in case the inferred output column names dimension is different
591
603
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
592
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
604
+
605
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
606
+ # so change the minimum of number of rows to 2
607
+ num_examples = 2
608
+ statement_params = telemetry.get_function_usage_statement_params(
609
+ project=_PROJECT,
610
+ subproject=_SUBPROJECT,
611
+ function_name=telemetry.get_statement_params_full_func_name(
612
+ inspect.currentframe(), PolynomialCountSketch.__class__.__name__
613
+ ),
614
+ api_calls=[Session.call],
615
+ custom_tags={"autogen": True} if self._autogenerated else None,
616
+ )
617
+ if output_cols_prefix == "fit_predict_":
618
+ if hasattr(self._sklearn_object, "n_clusters"):
619
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
620
+ num_examples = self._sklearn_object.n_clusters
621
+ elif hasattr(self._sklearn_object, "min_samples"):
622
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
623
+ num_examples = self._sklearn_object.min_samples
624
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
625
+ # LocalOutlierFactor expects n_neighbors <= n_samples
626
+ num_examples = self._sklearn_object.n_neighbors
627
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
628
+ else:
629
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
593
630
 
594
631
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
595
632
  # seen during the fit.
@@ -601,12 +638,14 @@ class PolynomialCountSketch(BaseTransformer):
601
638
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
602
639
  if self.sample_weight_col:
603
640
  output_df_columns_set -= set(self.sample_weight_col)
641
+
604
642
  # if the dimension of inferred output column names is correct; use it
605
643
  if len(expected_output_cols_list) == len(output_df_columns_set):
606
- return expected_output_cols_list
644
+ return expected_output_cols_list, output_df_pd
607
645
  # otherwise, use the sklearn estimator's output
608
646
  else:
609
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
647
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
648
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
610
649
 
611
650
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
612
651
  @telemetry.send_api_usage_telemetry(
@@ -652,7 +691,7 @@ class PolynomialCountSketch(BaseTransformer):
652
691
  drop_input_cols=self._drop_input_cols,
653
692
  expected_output_cols_type="float",
654
693
  )
655
- expected_output_cols = self._align_expected_output_names(
694
+ expected_output_cols, _ = self._align_expected_output(
656
695
  inference_method, dataset, expected_output_cols, output_cols_prefix
657
696
  )
658
697
 
@@ -718,7 +757,7 @@ class PolynomialCountSketch(BaseTransformer):
718
757
  drop_input_cols=self._drop_input_cols,
719
758
  expected_output_cols_type="float",
720
759
  )
721
- expected_output_cols = self._align_expected_output_names(
760
+ expected_output_cols, _ = self._align_expected_output(
722
761
  inference_method, dataset, expected_output_cols, output_cols_prefix
723
762
  )
724
763
  elif isinstance(dataset, pd.DataFrame):
@@ -781,7 +820,7 @@ class PolynomialCountSketch(BaseTransformer):
781
820
  drop_input_cols=self._drop_input_cols,
782
821
  expected_output_cols_type="float",
783
822
  )
784
- expected_output_cols = self._align_expected_output_names(
823
+ expected_output_cols, _ = self._align_expected_output(
785
824
  inference_method, dataset, expected_output_cols, output_cols_prefix
786
825
  )
787
826
 
@@ -846,7 +885,7 @@ class PolynomialCountSketch(BaseTransformer):
846
885
  drop_input_cols = self._drop_input_cols,
847
886
  expected_output_cols_type="float",
848
887
  )
849
- expected_output_cols = self._align_expected_output_names(
888
+ expected_output_cols, _ = self._align_expected_output(
850
889
  inference_method, dataset, expected_output_cols, output_cols_prefix
851
890
  )
852
891
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -485,12 +482,23 @@ class RBFSampler(BaseTransformer):
485
482
  autogenerated=self._autogenerated,
486
483
  subproject=_SUBPROJECT,
487
484
  )
488
- output_result, fitted_estimator = model_trainer.train_fit_predict(
489
- drop_input_cols=self._drop_input_cols,
490
- expected_output_cols_list=(
491
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
492
- ),
485
+ expected_output_cols = (
486
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
493
487
  )
488
+ if isinstance(dataset, DataFrame):
489
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
490
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
491
+ )
492
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
493
+ drop_input_cols=self._drop_input_cols,
494
+ expected_output_cols_list=expected_output_cols,
495
+ example_output_pd_df=example_output_pd_df,
496
+ )
497
+ else:
498
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
499
+ drop_input_cols=self._drop_input_cols,
500
+ expected_output_cols_list=expected_output_cols,
501
+ )
494
502
  self._sklearn_object = fitted_estimator
495
503
  self._is_fitted = True
496
504
  return output_result
@@ -571,12 +579,41 @@ class RBFSampler(BaseTransformer):
571
579
 
572
580
  return rv
573
581
 
574
- def _align_expected_output_names(
575
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
576
- ) -> List[str]:
582
+ def _align_expected_output(
583
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
584
+ ) -> Tuple[List[str], pd.DataFrame]:
585
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
586
+ and output dataframe with 1 line.
587
+ If the method is fit_predict, run 2 lines of data.
588
+ """
577
589
  # in case the inferred output column names dimension is different
578
590
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
579
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
591
+
592
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
593
+ # so change the minimum of number of rows to 2
594
+ num_examples = 2
595
+ statement_params = telemetry.get_function_usage_statement_params(
596
+ project=_PROJECT,
597
+ subproject=_SUBPROJECT,
598
+ function_name=telemetry.get_statement_params_full_func_name(
599
+ inspect.currentframe(), RBFSampler.__class__.__name__
600
+ ),
601
+ api_calls=[Session.call],
602
+ custom_tags={"autogen": True} if self._autogenerated else None,
603
+ )
604
+ if output_cols_prefix == "fit_predict_":
605
+ if hasattr(self._sklearn_object, "n_clusters"):
606
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
607
+ num_examples = self._sklearn_object.n_clusters
608
+ elif hasattr(self._sklearn_object, "min_samples"):
609
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
610
+ num_examples = self._sklearn_object.min_samples
611
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
612
+ # LocalOutlierFactor expects n_neighbors <= n_samples
613
+ num_examples = self._sklearn_object.n_neighbors
614
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
615
+ else:
616
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
580
617
 
581
618
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
582
619
  # seen during the fit.
@@ -588,12 +625,14 @@ class RBFSampler(BaseTransformer):
588
625
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
589
626
  if self.sample_weight_col:
590
627
  output_df_columns_set -= set(self.sample_weight_col)
628
+
591
629
  # if the dimension of inferred output column names is correct; use it
592
630
  if len(expected_output_cols_list) == len(output_df_columns_set):
593
- return expected_output_cols_list
631
+ return expected_output_cols_list, output_df_pd
594
632
  # otherwise, use the sklearn estimator's output
595
633
  else:
596
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
634
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
635
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
597
636
 
598
637
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
599
638
  @telemetry.send_api_usage_telemetry(
@@ -639,7 +678,7 @@ class RBFSampler(BaseTransformer):
639
678
  drop_input_cols=self._drop_input_cols,
640
679
  expected_output_cols_type="float",
641
680
  )
642
- expected_output_cols = self._align_expected_output_names(
681
+ expected_output_cols, _ = self._align_expected_output(
643
682
  inference_method, dataset, expected_output_cols, output_cols_prefix
644
683
  )
645
684
 
@@ -705,7 +744,7 @@ class RBFSampler(BaseTransformer):
705
744
  drop_input_cols=self._drop_input_cols,
706
745
  expected_output_cols_type="float",
707
746
  )
708
- expected_output_cols = self._align_expected_output_names(
747
+ expected_output_cols, _ = self._align_expected_output(
709
748
  inference_method, dataset, expected_output_cols, output_cols_prefix
710
749
  )
711
750
  elif isinstance(dataset, pd.DataFrame):
@@ -768,7 +807,7 @@ class RBFSampler(BaseTransformer):
768
807
  drop_input_cols=self._drop_input_cols,
769
808
  expected_output_cols_type="float",
770
809
  )
771
- expected_output_cols = self._align_expected_output_names(
810
+ expected_output_cols, _ = self._align_expected_output(
772
811
  inference_method, dataset, expected_output_cols, output_cols_prefix
773
812
  )
774
813
 
@@ -833,7 +872,7 @@ class RBFSampler(BaseTransformer):
833
872
  drop_input_cols = self._drop_input_cols,
834
873
  expected_output_cols_type="float",
835
874
  )
836
- expected_output_cols = self._align_expected_output_names(
875
+ expected_output_cols, _ = self._align_expected_output(
837
876
  inference_method, dataset, expected_output_cols, output_cols_prefix
838
877
  )
839
878
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -483,12 +480,23 @@ class SkewedChi2Sampler(BaseTransformer):
483
480
  autogenerated=self._autogenerated,
484
481
  subproject=_SUBPROJECT,
485
482
  )
486
- output_result, fitted_estimator = model_trainer.train_fit_predict(
487
- drop_input_cols=self._drop_input_cols,
488
- expected_output_cols_list=(
489
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
490
- ),
483
+ expected_output_cols = (
484
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
491
485
  )
486
+ if isinstance(dataset, DataFrame):
487
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
488
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
489
+ )
490
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
491
+ drop_input_cols=self._drop_input_cols,
492
+ expected_output_cols_list=expected_output_cols,
493
+ example_output_pd_df=example_output_pd_df,
494
+ )
495
+ else:
496
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
497
+ drop_input_cols=self._drop_input_cols,
498
+ expected_output_cols_list=expected_output_cols,
499
+ )
492
500
  self._sklearn_object = fitted_estimator
493
501
  self._is_fitted = True
494
502
  return output_result
@@ -569,12 +577,41 @@ class SkewedChi2Sampler(BaseTransformer):
569
577
 
570
578
  return rv
571
579
 
572
- def _align_expected_output_names(
573
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
574
- ) -> List[str]:
580
+ def _align_expected_output(
581
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
582
+ ) -> Tuple[List[str], pd.DataFrame]:
583
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
584
+ and output dataframe with 1 line.
585
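+ For fit_predict, run 2 lines of data by default, or as many lines as the estimator's n_clusters, min_samples, or n_neighbors setting when one is present.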
586
+ """
575
587
  # in case the inferred output column names dimension is different
576
588
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
577
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
589
+
590
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
591
+ # so the default number of sample rows is set to 2
592
+ num_examples = 2
593
+ statement_params = telemetry.get_function_usage_statement_params(
594
+ project=_PROJECT,
595
+ subproject=_SUBPROJECT,
596
+ function_name=telemetry.get_statement_params_full_func_name(
597
+ inspect.currentframe(), SkewedChi2Sampler.__class__.__name__
598
+ ),
599
+ api_calls=[Session.call],
600
+ custom_tags={"autogen": True} if self._autogenerated else None,
601
+ )
602
+ if output_cols_prefix == "fit_predict_":
603
+ if hasattr(self._sklearn_object, "n_clusters"):
604
+ # cluster estimators such as BisectingKMeans require the number of examples to be >= n_clusters
605
+ num_examples = self._sklearn_object.n_clusters
606
+ elif hasattr(self._sklearn_object, "min_samples"):
607
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
608
+ num_examples = self._sklearn_object.min_samples
609
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
610
+ # LocalOutlierFactor expects n_neighbors <= n_samples
611
+ num_examples = self._sklearn_object.n_neighbors
612
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
613
+ else:
614
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
578
615
 
579
616
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
580
617
  # seen during the fit.
@@ -586,12 +623,14 @@ class SkewedChi2Sampler(BaseTransformer):
586
623
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
587
624
  if self.sample_weight_col:
588
625
  output_df_columns_set -= set(self.sample_weight_col)
626
+
589
627
  # if the dimension of inferred output column names is correct; use it
590
628
  if len(expected_output_cols_list) == len(output_df_columns_set):
591
- return expected_output_cols_list
629
+ return expected_output_cols_list, output_df_pd
592
630
  # otherwise, use the sklearn estimator's output
593
631
  else:
594
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
632
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
633
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
595
634
 
596
635
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
597
636
  @telemetry.send_api_usage_telemetry(
@@ -637,7 +676,7 @@ class SkewedChi2Sampler(BaseTransformer):
637
676
  drop_input_cols=self._drop_input_cols,
638
677
  expected_output_cols_type="float",
639
678
  )
640
- expected_output_cols = self._align_expected_output_names(
679
+ expected_output_cols, _ = self._align_expected_output(
641
680
  inference_method, dataset, expected_output_cols, output_cols_prefix
642
681
  )
643
682
 
@@ -703,7 +742,7 @@ class SkewedChi2Sampler(BaseTransformer):
703
742
  drop_input_cols=self._drop_input_cols,
704
743
  expected_output_cols_type="float",
705
744
  )
706
- expected_output_cols = self._align_expected_output_names(
745
+ expected_output_cols, _ = self._align_expected_output(
707
746
  inference_method, dataset, expected_output_cols, output_cols_prefix
708
747
  )
709
748
  elif isinstance(dataset, pd.DataFrame):
@@ -766,7 +805,7 @@ class SkewedChi2Sampler(BaseTransformer):
766
805
  drop_input_cols=self._drop_input_cols,
767
806
  expected_output_cols_type="float",
768
807
  )
769
- expected_output_cols = self._align_expected_output_names(
808
+ expected_output_cols, _ = self._align_expected_output(
770
809
  inference_method, dataset, expected_output_cols, output_cols_prefix
771
810
  )
772
811
 
@@ -831,7 +870,7 @@ class SkewedChi2Sampler(BaseTransformer):
831
870
  drop_input_cols = self._drop_input_cols,
832
871
  expected_output_cols_type="float",
833
872
  )
834
- expected_output_cols = self._align_expected_output_names(
873
+ expected_output_cols, _ = self._align_expected_output(
835
874
  inference_method, dataset, expected_output_cols, output_cols_prefix
836
875
  )
837
876
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -517,12 +514,23 @@ class KernelRidge(BaseTransformer):
517
514
  autogenerated=self._autogenerated,
518
515
  subproject=_SUBPROJECT,
519
516
  )
520
- output_result, fitted_estimator = model_trainer.train_fit_predict(
521
- drop_input_cols=self._drop_input_cols,
522
- expected_output_cols_list=(
523
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
524
- ),
517
+ expected_output_cols = (
518
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
525
519
  )
520
+ if isinstance(dataset, DataFrame):
521
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
522
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
523
+ )
524
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
525
+ drop_input_cols=self._drop_input_cols,
526
+ expected_output_cols_list=expected_output_cols,
527
+ example_output_pd_df=example_output_pd_df,
528
+ )
529
+ else:
530
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
531
+ drop_input_cols=self._drop_input_cols,
532
+ expected_output_cols_list=expected_output_cols,
533
+ )
526
534
  self._sklearn_object = fitted_estimator
527
535
  self._is_fitted = True
528
536
  return output_result
@@ -601,12 +609,41 @@ class KernelRidge(BaseTransformer):
601
609
 
602
610
  return rv
603
611
 
604
- def _align_expected_output_names(
605
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
606
- ) -> List[str]:
612
+ def _align_expected_output(
613
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
614
+ ) -> Tuple[List[str], pd.DataFrame]:
615
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
616
+ and output dataframe with 1 line.
617
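+ For fit_predict, run 2 lines of data by default, or as many lines as the estimator's n_clusters, min_samples, or n_neighbors setting when one is present.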
618
+ """
607
619
  # in case the inferred output column names dimension is different
608
620
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
609
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
621
+
622
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
623
+ # so the default number of sample rows is set to 2
624
+ num_examples = 2
625
+ statement_params = telemetry.get_function_usage_statement_params(
626
+ project=_PROJECT,
627
+ subproject=_SUBPROJECT,
628
+ function_name=telemetry.get_statement_params_full_func_name(
629
+ inspect.currentframe(), KernelRidge.__class__.__name__
630
+ ),
631
+ api_calls=[Session.call],
632
+ custom_tags={"autogen": True} if self._autogenerated else None,
633
+ )
634
+ if output_cols_prefix == "fit_predict_":
635
+ if hasattr(self._sklearn_object, "n_clusters"):
636
+ # cluster estimators such as BisectingKMeans require the number of examples to be >= n_clusters
637
+ num_examples = self._sklearn_object.n_clusters
638
+ elif hasattr(self._sklearn_object, "min_samples"):
639
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
640
+ num_examples = self._sklearn_object.min_samples
641
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
642
+ # LocalOutlierFactor expects n_neighbors <= n_samples
643
+ num_examples = self._sklearn_object.n_neighbors
644
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
645
+ else:
646
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
610
647
 
611
648
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
612
649
  # seen during the fit.
@@ -618,12 +655,14 @@ class KernelRidge(BaseTransformer):
618
655
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
619
656
  if self.sample_weight_col:
620
657
  output_df_columns_set -= set(self.sample_weight_col)
658
+
621
659
  # if the dimension of inferred output column names is correct; use it
622
660
  if len(expected_output_cols_list) == len(output_df_columns_set):
623
- return expected_output_cols_list
661
+ return expected_output_cols_list, output_df_pd
624
662
  # otherwise, use the sklearn estimator's output
625
663
  else:
626
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
664
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
665
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
627
666
 
628
667
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
629
668
  @telemetry.send_api_usage_telemetry(
@@ -669,7 +708,7 @@ class KernelRidge(BaseTransformer):
669
708
  drop_input_cols=self._drop_input_cols,
670
709
  expected_output_cols_type="float",
671
710
  )
672
- expected_output_cols = self._align_expected_output_names(
711
+ expected_output_cols, _ = self._align_expected_output(
673
712
  inference_method, dataset, expected_output_cols, output_cols_prefix
674
713
  )
675
714
 
@@ -735,7 +774,7 @@ class KernelRidge(BaseTransformer):
735
774
  drop_input_cols=self._drop_input_cols,
736
775
  expected_output_cols_type="float",
737
776
  )
738
- expected_output_cols = self._align_expected_output_names(
777
+ expected_output_cols, _ = self._align_expected_output(
739
778
  inference_method, dataset, expected_output_cols, output_cols_prefix
740
779
  )
741
780
  elif isinstance(dataset, pd.DataFrame):
@@ -798,7 +837,7 @@ class KernelRidge(BaseTransformer):
798
837
  drop_input_cols=self._drop_input_cols,
799
838
  expected_output_cols_type="float",
800
839
  )
801
- expected_output_cols = self._align_expected_output_names(
840
+ expected_output_cols, _ = self._align_expected_output(
802
841
  inference_method, dataset, expected_output_cols, output_cols_prefix
803
842
  )
804
843
 
@@ -863,7 +902,7 @@ class KernelRidge(BaseTransformer):
863
902
  drop_input_cols = self._drop_input_cols,
864
903
  expected_output_cols_type="float",
865
904
  )
866
- expected_output_cols = self._align_expected_output_names(
905
+ expected_output_cols, _ = self._align_expected_output(
867
906
  inference_method, dataset, expected_output_cols, output_cols_prefix
868
907
  )
869
908