snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252) hide show
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -548,12 +545,23 @@ class KNeighborsRegressor(BaseTransformer):
548
545
  autogenerated=self._autogenerated,
549
546
  subproject=_SUBPROJECT,
550
547
  )
551
- output_result, fitted_estimator = model_trainer.train_fit_predict(
552
- drop_input_cols=self._drop_input_cols,
553
- expected_output_cols_list=(
554
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
555
- ),
548
+ expected_output_cols = (
549
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
556
550
  )
551
+ if isinstance(dataset, DataFrame):
552
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
553
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
554
+ )
555
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
556
+ drop_input_cols=self._drop_input_cols,
557
+ expected_output_cols_list=expected_output_cols,
558
+ example_output_pd_df=example_output_pd_df,
559
+ )
560
+ else:
561
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
562
+ drop_input_cols=self._drop_input_cols,
563
+ expected_output_cols_list=expected_output_cols,
564
+ )
557
565
  self._sklearn_object = fitted_estimator
558
566
  self._is_fitted = True
559
567
  return output_result
@@ -632,12 +640,41 @@ class KNeighborsRegressor(BaseTransformer):
632
640
 
633
641
  return rv
634
642
 
635
- def _align_expected_output_names(
636
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
637
- ) -> List[str]:
643
+ def _align_expected_output(
644
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
645
+ ) -> Tuple[List[str], pd.DataFrame]:
646
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
647
+ and output dataframe with 1 line.
648
+ If the method is fit_predict, run 2 lines of data.
649
+ """
638
650
  # in case the inferred output column names dimension is different
639
651
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
640
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
652
+
653
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
654
+ # so change the minimum of number of rows to 2
655
+ num_examples = 2
656
+ statement_params = telemetry.get_function_usage_statement_params(
657
+ project=_PROJECT,
658
+ subproject=_SUBPROJECT,
659
+ function_name=telemetry.get_statement_params_full_func_name(
660
+ inspect.currentframe(), KNeighborsRegressor.__class__.__name__
661
+ ),
662
+ api_calls=[Session.call],
663
+ custom_tags={"autogen": True} if self._autogenerated else None,
664
+ )
665
+ if output_cols_prefix == "fit_predict_":
666
+ if hasattr(self._sklearn_object, "n_clusters"):
667
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
668
+ num_examples = self._sklearn_object.n_clusters
669
+ elif hasattr(self._sklearn_object, "min_samples"):
670
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
671
+ num_examples = self._sklearn_object.min_samples
672
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
673
+ # LocalOutlierFactor expects n_neighbors <= n_samples
674
+ num_examples = self._sklearn_object.n_neighbors
675
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
676
+ else:
677
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
641
678
 
642
679
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
643
680
  # seen during the fit.
@@ -649,12 +686,14 @@ class KNeighborsRegressor(BaseTransformer):
649
686
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
650
687
  if self.sample_weight_col:
651
688
  output_df_columns_set -= set(self.sample_weight_col)
689
+
652
690
  # if the dimension of inferred output column names is correct; use it
653
691
  if len(expected_output_cols_list) == len(output_df_columns_set):
654
- return expected_output_cols_list
692
+ return expected_output_cols_list, output_df_pd
655
693
  # otherwise, use the sklearn estimator's output
656
694
  else:
657
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
695
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
696
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
658
697
 
659
698
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
660
699
  @telemetry.send_api_usage_telemetry(
@@ -700,7 +739,7 @@ class KNeighborsRegressor(BaseTransformer):
700
739
  drop_input_cols=self._drop_input_cols,
701
740
  expected_output_cols_type="float",
702
741
  )
703
- expected_output_cols = self._align_expected_output_names(
742
+ expected_output_cols, _ = self._align_expected_output(
704
743
  inference_method, dataset, expected_output_cols, output_cols_prefix
705
744
  )
706
745
 
@@ -766,7 +805,7 @@ class KNeighborsRegressor(BaseTransformer):
766
805
  drop_input_cols=self._drop_input_cols,
767
806
  expected_output_cols_type="float",
768
807
  )
769
- expected_output_cols = self._align_expected_output_names(
808
+ expected_output_cols, _ = self._align_expected_output(
770
809
  inference_method, dataset, expected_output_cols, output_cols_prefix
771
810
  )
772
811
  elif isinstance(dataset, pd.DataFrame):
@@ -829,7 +868,7 @@ class KNeighborsRegressor(BaseTransformer):
829
868
  drop_input_cols=self._drop_input_cols,
830
869
  expected_output_cols_type="float",
831
870
  )
832
- expected_output_cols = self._align_expected_output_names(
871
+ expected_output_cols, _ = self._align_expected_output(
833
872
  inference_method, dataset, expected_output_cols, output_cols_prefix
834
873
  )
835
874
 
@@ -894,7 +933,7 @@ class KNeighborsRegressor(BaseTransformer):
894
933
  drop_input_cols = self._drop_input_cols,
895
934
  expected_output_cols_type="float",
896
935
  )
897
- expected_output_cols = self._align_expected_output_names(
936
+ expected_output_cols, _ = self._align_expected_output(
898
937
  inference_method, dataset, expected_output_cols, output_cols_prefix
899
938
  )
900
939
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -525,12 +522,23 @@ class KernelDensity(BaseTransformer):
525
522
  autogenerated=self._autogenerated,
526
523
  subproject=_SUBPROJECT,
527
524
  )
528
- output_result, fitted_estimator = model_trainer.train_fit_predict(
529
- drop_input_cols=self._drop_input_cols,
530
- expected_output_cols_list=(
531
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
532
- ),
525
+ expected_output_cols = (
526
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
533
527
  )
528
+ if isinstance(dataset, DataFrame):
529
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
530
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
531
+ )
532
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
533
+ drop_input_cols=self._drop_input_cols,
534
+ expected_output_cols_list=expected_output_cols,
535
+ example_output_pd_df=example_output_pd_df,
536
+ )
537
+ else:
538
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
539
+ drop_input_cols=self._drop_input_cols,
540
+ expected_output_cols_list=expected_output_cols,
541
+ )
534
542
  self._sklearn_object = fitted_estimator
535
543
  self._is_fitted = True
536
544
  return output_result
@@ -609,12 +617,41 @@ class KernelDensity(BaseTransformer):
609
617
 
610
618
  return rv
611
619
 
612
- def _align_expected_output_names(
613
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
614
- ) -> List[str]:
620
+ def _align_expected_output(
621
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
622
+ ) -> Tuple[List[str], pd.DataFrame]:
623
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
624
+ and output dataframe with 1 line.
625
+ If the method is fit_predict, run 2 lines of data.
626
+ """
615
627
  # in case the inferred output column names dimension is different
616
628
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
617
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
629
+
630
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
631
+ # so change the minimum of number of rows to 2
632
+ num_examples = 2
633
+ statement_params = telemetry.get_function_usage_statement_params(
634
+ project=_PROJECT,
635
+ subproject=_SUBPROJECT,
636
+ function_name=telemetry.get_statement_params_full_func_name(
637
+ inspect.currentframe(), KernelDensity.__class__.__name__
638
+ ),
639
+ api_calls=[Session.call],
640
+ custom_tags={"autogen": True} if self._autogenerated else None,
641
+ )
642
+ if output_cols_prefix == "fit_predict_":
643
+ if hasattr(self._sklearn_object, "n_clusters"):
644
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
645
+ num_examples = self._sklearn_object.n_clusters
646
+ elif hasattr(self._sklearn_object, "min_samples"):
647
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
648
+ num_examples = self._sklearn_object.min_samples
649
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
650
+ # LocalOutlierFactor expects n_neighbors <= n_samples
651
+ num_examples = self._sklearn_object.n_neighbors
652
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
653
+ else:
654
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
618
655
 
619
656
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
620
657
  # seen during the fit.
@@ -626,12 +663,14 @@ class KernelDensity(BaseTransformer):
626
663
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
627
664
  if self.sample_weight_col:
628
665
  output_df_columns_set -= set(self.sample_weight_col)
666
+
629
667
  # if the dimension of inferred output column names is correct; use it
630
668
  if len(expected_output_cols_list) == len(output_df_columns_set):
631
- return expected_output_cols_list
669
+ return expected_output_cols_list, output_df_pd
632
670
  # otherwise, use the sklearn estimator's output
633
671
  else:
634
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
672
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
673
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
635
674
 
636
675
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
637
676
  @telemetry.send_api_usage_telemetry(
@@ -677,7 +716,7 @@ class KernelDensity(BaseTransformer):
677
716
  drop_input_cols=self._drop_input_cols,
678
717
  expected_output_cols_type="float",
679
718
  )
680
- expected_output_cols = self._align_expected_output_names(
719
+ expected_output_cols, _ = self._align_expected_output(
681
720
  inference_method, dataset, expected_output_cols, output_cols_prefix
682
721
  )
683
722
 
@@ -743,7 +782,7 @@ class KernelDensity(BaseTransformer):
743
782
  drop_input_cols=self._drop_input_cols,
744
783
  expected_output_cols_type="float",
745
784
  )
746
- expected_output_cols = self._align_expected_output_names(
785
+ expected_output_cols, _ = self._align_expected_output(
747
786
  inference_method, dataset, expected_output_cols, output_cols_prefix
748
787
  )
749
788
  elif isinstance(dataset, pd.DataFrame):
@@ -806,7 +845,7 @@ class KernelDensity(BaseTransformer):
806
845
  drop_input_cols=self._drop_input_cols,
807
846
  expected_output_cols_type="float",
808
847
  )
809
- expected_output_cols = self._align_expected_output_names(
848
+ expected_output_cols, _ = self._align_expected_output(
810
849
  inference_method, dataset, expected_output_cols, output_cols_prefix
811
850
  )
812
851
 
@@ -873,7 +912,7 @@ class KernelDensity(BaseTransformer):
873
912
  drop_input_cols = self._drop_input_cols,
874
913
  expected_output_cols_type="float",
875
914
  )
876
- expected_output_cols = self._align_expected_output_names(
915
+ expected_output_cols, _ = self._align_expected_output(
877
916
  inference_method, dataset, expected_output_cols, output_cols_prefix
878
917
  )
879
918
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -557,12 +554,23 @@ class LocalOutlierFactor(BaseTransformer):
557
554
  autogenerated=self._autogenerated,
558
555
  subproject=_SUBPROJECT,
559
556
  )
560
- output_result, fitted_estimator = model_trainer.train_fit_predict(
561
- drop_input_cols=self._drop_input_cols,
562
- expected_output_cols_list=(
563
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
564
- ),
557
+ expected_output_cols = (
558
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
565
559
  )
560
+ if isinstance(dataset, DataFrame):
561
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
562
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
563
+ )
564
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
565
+ drop_input_cols=self._drop_input_cols,
566
+ expected_output_cols_list=expected_output_cols,
567
+ example_output_pd_df=example_output_pd_df,
568
+ )
569
+ else:
570
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
571
+ drop_input_cols=self._drop_input_cols,
572
+ expected_output_cols_list=expected_output_cols,
573
+ )
566
574
  self._sklearn_object = fitted_estimator
567
575
  self._is_fitted = True
568
576
  return output_result
@@ -641,12 +649,41 @@ class LocalOutlierFactor(BaseTransformer):
641
649
 
642
650
  return rv
643
651
 
644
- def _align_expected_output_names(
645
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
646
- ) -> List[str]:
652
+ def _align_expected_output(
653
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
654
+ ) -> Tuple[List[str], pd.DataFrame]:
655
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
656
+ and output dataframe with 1 line.
657
+ If the method is fit_predict, run 2 lines of data.
658
+ """
647
659
  # in case the inferred output column names dimension is different
648
660
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
649
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
661
+
662
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
663
+ # so change the minimum of number of rows to 2
664
+ num_examples = 2
665
+ statement_params = telemetry.get_function_usage_statement_params(
666
+ project=_PROJECT,
667
+ subproject=_SUBPROJECT,
668
+ function_name=telemetry.get_statement_params_full_func_name(
669
+ inspect.currentframe(), LocalOutlierFactor.__class__.__name__
670
+ ),
671
+ api_calls=[Session.call],
672
+ custom_tags={"autogen": True} if self._autogenerated else None,
673
+ )
674
+ if output_cols_prefix == "fit_predict_":
675
+ if hasattr(self._sklearn_object, "n_clusters"):
676
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
677
+ num_examples = self._sklearn_object.n_clusters
678
+ elif hasattr(self._sklearn_object, "min_samples"):
679
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
680
+ num_examples = self._sklearn_object.min_samples
681
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
682
+ # LocalOutlierFactor expects n_neighbors <= n_samples
683
+ num_examples = self._sklearn_object.n_neighbors
684
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
685
+ else:
686
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
650
687
 
651
688
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
652
689
  # seen during the fit.
@@ -658,12 +695,14 @@ class LocalOutlierFactor(BaseTransformer):
658
695
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
659
696
  if self.sample_weight_col:
660
697
  output_df_columns_set -= set(self.sample_weight_col)
698
+
661
699
  # if the dimension of inferred output column names is correct; use it
662
700
  if len(expected_output_cols_list) == len(output_df_columns_set):
663
- return expected_output_cols_list
701
+ return expected_output_cols_list, output_df_pd
664
702
  # otherwise, use the sklearn estimator's output
665
703
  else:
666
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
704
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
705
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
667
706
 
668
707
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
669
708
  @telemetry.send_api_usage_telemetry(
@@ -709,7 +748,7 @@ class LocalOutlierFactor(BaseTransformer):
709
748
  drop_input_cols=self._drop_input_cols,
710
749
  expected_output_cols_type="float",
711
750
  )
712
- expected_output_cols = self._align_expected_output_names(
751
+ expected_output_cols, _ = self._align_expected_output(
713
752
  inference_method, dataset, expected_output_cols, output_cols_prefix
714
753
  )
715
754
 
@@ -775,7 +814,7 @@ class LocalOutlierFactor(BaseTransformer):
775
814
  drop_input_cols=self._drop_input_cols,
776
815
  expected_output_cols_type="float",
777
816
  )
778
- expected_output_cols = self._align_expected_output_names(
817
+ expected_output_cols, _ = self._align_expected_output(
779
818
  inference_method, dataset, expected_output_cols, output_cols_prefix
780
819
  )
781
820
  elif isinstance(dataset, pd.DataFrame):
@@ -840,7 +879,7 @@ class LocalOutlierFactor(BaseTransformer):
840
879
  drop_input_cols=self._drop_input_cols,
841
880
  expected_output_cols_type="float",
842
881
  )
843
- expected_output_cols = self._align_expected_output_names(
882
+ expected_output_cols, _ = self._align_expected_output(
844
883
  inference_method, dataset, expected_output_cols, output_cols_prefix
845
884
  )
846
885
 
@@ -907,7 +946,7 @@ class LocalOutlierFactor(BaseTransformer):
907
946
  drop_input_cols = self._drop_input_cols,
908
947
  expected_output_cols_type="float",
909
948
  )
910
- expected_output_cols = self._align_expected_output_names(
949
+ expected_output_cols, _ = self._align_expected_output(
911
950
  inference_method, dataset, expected_output_cols, output_cols_prefix
912
951
  )
913
952
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -486,12 +483,23 @@ class NearestCentroid(BaseTransformer):
486
483
  autogenerated=self._autogenerated,
487
484
  subproject=_SUBPROJECT,
488
485
  )
489
- output_result, fitted_estimator = model_trainer.train_fit_predict(
490
- drop_input_cols=self._drop_input_cols,
491
- expected_output_cols_list=(
492
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
493
- ),
486
+ expected_output_cols = (
487
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
494
488
  )
489
+ if isinstance(dataset, DataFrame):
490
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
491
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
492
+ )
493
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
494
+ drop_input_cols=self._drop_input_cols,
495
+ expected_output_cols_list=expected_output_cols,
496
+ example_output_pd_df=example_output_pd_df,
497
+ )
498
+ else:
499
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
500
+ drop_input_cols=self._drop_input_cols,
501
+ expected_output_cols_list=expected_output_cols,
502
+ )
495
503
  self._sklearn_object = fitted_estimator
496
504
  self._is_fitted = True
497
505
  return output_result
@@ -570,12 +578,41 @@ class NearestCentroid(BaseTransformer):
570
578
 
571
579
  return rv
572
580
 
573
- def _align_expected_output_names(
574
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
575
- ) -> List[str]:
581
+ def _align_expected_output(
582
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
583
+ ) -> Tuple[List[str], pd.DataFrame]:
584
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
585
+ and output dataframe with 1 line.
586
+ If the method is fit_predict, run 2 lines of data.
587
+ """
576
588
  # in case the inferred output column names dimension is different
577
589
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
578
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
590
+
591
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
592
+ # so change the minimum of number of rows to 2
593
+ num_examples = 2
594
+ statement_params = telemetry.get_function_usage_statement_params(
595
+ project=_PROJECT,
596
+ subproject=_SUBPROJECT,
597
+ function_name=telemetry.get_statement_params_full_func_name(
598
+ inspect.currentframe(), NearestCentroid.__class__.__name__
599
+ ),
600
+ api_calls=[Session.call],
601
+ custom_tags={"autogen": True} if self._autogenerated else None,
602
+ )
603
+ if output_cols_prefix == "fit_predict_":
604
+ if hasattr(self._sklearn_object, "n_clusters"):
605
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
606
+ num_examples = self._sklearn_object.n_clusters
607
+ elif hasattr(self._sklearn_object, "min_samples"):
608
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
609
+ num_examples = self._sklearn_object.min_samples
610
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
611
+ # LocalOutlierFactor expects n_neighbors <= n_samples
612
+ num_examples = self._sklearn_object.n_neighbors
613
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
614
+ else:
615
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
579
616
 
580
617
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
581
618
  # seen during the fit.
@@ -587,12 +624,14 @@ class NearestCentroid(BaseTransformer):
587
624
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
588
625
  if self.sample_weight_col:
589
626
  output_df_columns_set -= set(self.sample_weight_col)
627
+
590
628
  # if the dimension of inferred output column names is correct; use it
591
629
  if len(expected_output_cols_list) == len(output_df_columns_set):
592
- return expected_output_cols_list
630
+ return expected_output_cols_list, output_df_pd
593
631
  # otherwise, use the sklearn estimator's output
594
632
  else:
595
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
633
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
634
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
596
635
 
597
636
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
598
637
  @telemetry.send_api_usage_telemetry(
@@ -638,7 +677,7 @@ class NearestCentroid(BaseTransformer):
638
677
  drop_input_cols=self._drop_input_cols,
639
678
  expected_output_cols_type="float",
640
679
  )
641
- expected_output_cols = self._align_expected_output_names(
680
+ expected_output_cols, _ = self._align_expected_output(
642
681
  inference_method, dataset, expected_output_cols, output_cols_prefix
643
682
  )
644
683
 
@@ -704,7 +743,7 @@ class NearestCentroid(BaseTransformer):
704
743
  drop_input_cols=self._drop_input_cols,
705
744
  expected_output_cols_type="float",
706
745
  )
707
- expected_output_cols = self._align_expected_output_names(
746
+ expected_output_cols, _ = self._align_expected_output(
708
747
  inference_method, dataset, expected_output_cols, output_cols_prefix
709
748
  )
710
749
  elif isinstance(dataset, pd.DataFrame):
@@ -767,7 +806,7 @@ class NearestCentroid(BaseTransformer):
767
806
  drop_input_cols=self._drop_input_cols,
768
807
  expected_output_cols_type="float",
769
808
  )
770
- expected_output_cols = self._align_expected_output_names(
809
+ expected_output_cols, _ = self._align_expected_output(
771
810
  inference_method, dataset, expected_output_cols, output_cols_prefix
772
811
  )
773
812
 
@@ -832,7 +871,7 @@ class NearestCentroid(BaseTransformer):
832
871
  drop_input_cols = self._drop_input_cols,
833
872
  expected_output_cols_type="float",
834
873
  )
835
- expected_output_cols = self._align_expected_output_names(
874
+ expected_output_cols, _ = self._align_expected_output(
836
875
  inference_method, dataset, expected_output_cols, output_cols_prefix
837
876
  )
838
877