snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -495,12 +492,23 @@ class ComplementNB(BaseTransformer):
495
492
  autogenerated=self._autogenerated,
496
493
  subproject=_SUBPROJECT,
497
494
  )
498
- output_result, fitted_estimator = model_trainer.train_fit_predict(
499
- drop_input_cols=self._drop_input_cols,
500
- expected_output_cols_list=(
501
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
502
- ),
495
+ expected_output_cols = (
496
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
503
497
  )
498
+ if isinstance(dataset, DataFrame):
499
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
500
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
501
+ )
502
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
503
+ drop_input_cols=self._drop_input_cols,
504
+ expected_output_cols_list=expected_output_cols,
505
+ example_output_pd_df=example_output_pd_df,
506
+ )
507
+ else:
508
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
509
+ drop_input_cols=self._drop_input_cols,
510
+ expected_output_cols_list=expected_output_cols,
511
+ )
504
512
  self._sklearn_object = fitted_estimator
505
513
  self._is_fitted = True
506
514
  return output_result
@@ -579,12 +587,41 @@ class ComplementNB(BaseTransformer):
579
587
 
580
588
  return rv
581
589
 
582
- def _align_expected_output_names(
583
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
584
- ) -> List[str]:
590
+ def _align_expected_output(
591
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
592
+ ) -> Tuple[List[str], pd.DataFrame]:
593
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
594
+ and output dataframe with 1 line.
595
+ If the method is fit_predict, run 2 lines of data.
596
+ """
585
597
  # in case the inferred output column names dimension is different
586
598
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
587
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
599
+
600
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
601
+ # so change the minimum of number of rows to 2
602
+ num_examples = 2
603
+ statement_params = telemetry.get_function_usage_statement_params(
604
+ project=_PROJECT,
605
+ subproject=_SUBPROJECT,
606
+ function_name=telemetry.get_statement_params_full_func_name(
607
+ inspect.currentframe(), ComplementNB.__class__.__name__
608
+ ),
609
+ api_calls=[Session.call],
610
+ custom_tags={"autogen": True} if self._autogenerated else None,
611
+ )
612
+ if output_cols_prefix == "fit_predict_":
613
+ if hasattr(self._sklearn_object, "n_clusters"):
614
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
615
+ num_examples = self._sklearn_object.n_clusters
616
+ elif hasattr(self._sklearn_object, "min_samples"):
617
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
618
+ num_examples = self._sklearn_object.min_samples
619
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
620
+ # LocalOutlierFactor expects n_neighbors <= n_samples
621
+ num_examples = self._sklearn_object.n_neighbors
622
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
623
+ else:
624
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
588
625
 
589
626
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
590
627
  # seen during the fit.
@@ -596,12 +633,14 @@ class ComplementNB(BaseTransformer):
596
633
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
597
634
  if self.sample_weight_col:
598
635
  output_df_columns_set -= set(self.sample_weight_col)
636
+
599
637
  # if the dimension of inferred output column names is correct; use it
600
638
  if len(expected_output_cols_list) == len(output_df_columns_set):
601
- return expected_output_cols_list
639
+ return expected_output_cols_list, output_df_pd
602
640
  # otherwise, use the sklearn estimator's output
603
641
  else:
604
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
642
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
643
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
605
644
 
606
645
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
607
646
  @telemetry.send_api_usage_telemetry(
@@ -649,7 +688,7 @@ class ComplementNB(BaseTransformer):
649
688
  drop_input_cols=self._drop_input_cols,
650
689
  expected_output_cols_type="float",
651
690
  )
652
- expected_output_cols = self._align_expected_output_names(
691
+ expected_output_cols, _ = self._align_expected_output(
653
692
  inference_method, dataset, expected_output_cols, output_cols_prefix
654
693
  )
655
694
 
@@ -717,7 +756,7 @@ class ComplementNB(BaseTransformer):
717
756
  drop_input_cols=self._drop_input_cols,
718
757
  expected_output_cols_type="float",
719
758
  )
720
- expected_output_cols = self._align_expected_output_names(
759
+ expected_output_cols, _ = self._align_expected_output(
721
760
  inference_method, dataset, expected_output_cols, output_cols_prefix
722
761
  )
723
762
  elif isinstance(dataset, pd.DataFrame):
@@ -780,7 +819,7 @@ class ComplementNB(BaseTransformer):
780
819
  drop_input_cols=self._drop_input_cols,
781
820
  expected_output_cols_type="float",
782
821
  )
783
- expected_output_cols = self._align_expected_output_names(
822
+ expected_output_cols, _ = self._align_expected_output(
784
823
  inference_method, dataset, expected_output_cols, output_cols_prefix
785
824
  )
786
825
 
@@ -845,7 +884,7 @@ class ComplementNB(BaseTransformer):
845
884
  drop_input_cols = self._drop_input_cols,
846
885
  expected_output_cols_type="float",
847
886
  )
848
- expected_output_cols = self._align_expected_output_names(
887
+ expected_output_cols, _ = self._align_expected_output(
849
888
  inference_method, dataset, expected_output_cols, output_cols_prefix
850
889
  )
851
890
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -476,12 +473,23 @@ class GaussianNB(BaseTransformer):
476
473
  autogenerated=self._autogenerated,
477
474
  subproject=_SUBPROJECT,
478
475
  )
479
- output_result, fitted_estimator = model_trainer.train_fit_predict(
480
- drop_input_cols=self._drop_input_cols,
481
- expected_output_cols_list=(
482
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
483
- ),
476
+ expected_output_cols = (
477
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
484
478
  )
479
+ if isinstance(dataset, DataFrame):
480
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
481
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
482
+ )
483
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
484
+ drop_input_cols=self._drop_input_cols,
485
+ expected_output_cols_list=expected_output_cols,
486
+ example_output_pd_df=example_output_pd_df,
487
+ )
488
+ else:
489
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
490
+ drop_input_cols=self._drop_input_cols,
491
+ expected_output_cols_list=expected_output_cols,
492
+ )
485
493
  self._sklearn_object = fitted_estimator
486
494
  self._is_fitted = True
487
495
  return output_result
@@ -560,12 +568,41 @@ class GaussianNB(BaseTransformer):
560
568
 
561
569
  return rv
562
570
 
563
- def _align_expected_output_names(
564
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
565
- ) -> List[str]:
571
+ def _align_expected_output(
572
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
573
+ ) -> Tuple[List[str], pd.DataFrame]:
574
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
575
+ and output dataframe with 1 line.
576
+ If the method is fit_predict, run 2 lines of data.
577
+ """
566
578
  # in case the inferred output column names dimension is different
567
579
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
568
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
580
+
581
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
582
+ # so change the minimum of number of rows to 2
583
+ num_examples = 2
584
+ statement_params = telemetry.get_function_usage_statement_params(
585
+ project=_PROJECT,
586
+ subproject=_SUBPROJECT,
587
+ function_name=telemetry.get_statement_params_full_func_name(
588
+ inspect.currentframe(), GaussianNB.__class__.__name__
589
+ ),
590
+ api_calls=[Session.call],
591
+ custom_tags={"autogen": True} if self._autogenerated else None,
592
+ )
593
+ if output_cols_prefix == "fit_predict_":
594
+ if hasattr(self._sklearn_object, "n_clusters"):
595
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
596
+ num_examples = self._sklearn_object.n_clusters
597
+ elif hasattr(self._sklearn_object, "min_samples"):
598
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
599
+ num_examples = self._sklearn_object.min_samples
600
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
601
+ # LocalOutlierFactor expects n_neighbors <= n_samples
602
+ num_examples = self._sklearn_object.n_neighbors
603
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
604
+ else:
605
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
569
606
 
570
607
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
571
608
  # seen during the fit.
@@ -577,12 +614,14 @@ class GaussianNB(BaseTransformer):
577
614
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
578
615
  if self.sample_weight_col:
579
616
  output_df_columns_set -= set(self.sample_weight_col)
617
+
580
618
  # if the dimension of inferred output column names is correct; use it
581
619
  if len(expected_output_cols_list) == len(output_df_columns_set):
582
- return expected_output_cols_list
620
+ return expected_output_cols_list, output_df_pd
583
621
  # otherwise, use the sklearn estimator's output
584
622
  else:
585
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
623
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
624
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
586
625
 
587
626
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
588
627
  @telemetry.send_api_usage_telemetry(
@@ -630,7 +669,7 @@ class GaussianNB(BaseTransformer):
630
669
  drop_input_cols=self._drop_input_cols,
631
670
  expected_output_cols_type="float",
632
671
  )
633
- expected_output_cols = self._align_expected_output_names(
672
+ expected_output_cols, _ = self._align_expected_output(
634
673
  inference_method, dataset, expected_output_cols, output_cols_prefix
635
674
  )
636
675
 
@@ -698,7 +737,7 @@ class GaussianNB(BaseTransformer):
698
737
  drop_input_cols=self._drop_input_cols,
699
738
  expected_output_cols_type="float",
700
739
  )
701
- expected_output_cols = self._align_expected_output_names(
740
+ expected_output_cols, _ = self._align_expected_output(
702
741
  inference_method, dataset, expected_output_cols, output_cols_prefix
703
742
  )
704
743
  elif isinstance(dataset, pd.DataFrame):
@@ -761,7 +800,7 @@ class GaussianNB(BaseTransformer):
761
800
  drop_input_cols=self._drop_input_cols,
762
801
  expected_output_cols_type="float",
763
802
  )
764
- expected_output_cols = self._align_expected_output_names(
803
+ expected_output_cols, _ = self._align_expected_output(
765
804
  inference_method, dataset, expected_output_cols, output_cols_prefix
766
805
  )
767
806
 
@@ -826,7 +865,7 @@ class GaussianNB(BaseTransformer):
826
865
  drop_input_cols = self._drop_input_cols,
827
866
  expected_output_cols_type="float",
828
867
  )
829
- expected_output_cols = self._align_expected_output_names(
868
+ expected_output_cols, _ = self._align_expected_output(
830
869
  inference_method, dataset, expected_output_cols, output_cols_prefix
831
870
  )
832
871
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -489,12 +486,23 @@ class MultinomialNB(BaseTransformer):
489
486
  autogenerated=self._autogenerated,
490
487
  subproject=_SUBPROJECT,
491
488
  )
492
- output_result, fitted_estimator = model_trainer.train_fit_predict(
493
- drop_input_cols=self._drop_input_cols,
494
- expected_output_cols_list=(
495
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
496
- ),
489
+ expected_output_cols = (
490
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
497
491
  )
492
+ if isinstance(dataset, DataFrame):
493
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
494
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
495
+ )
496
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
497
+ drop_input_cols=self._drop_input_cols,
498
+ expected_output_cols_list=expected_output_cols,
499
+ example_output_pd_df=example_output_pd_df,
500
+ )
501
+ else:
502
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
503
+ drop_input_cols=self._drop_input_cols,
504
+ expected_output_cols_list=expected_output_cols,
505
+ )
498
506
  self._sklearn_object = fitted_estimator
499
507
  self._is_fitted = True
500
508
  return output_result
@@ -573,12 +581,41 @@ class MultinomialNB(BaseTransformer):
573
581
 
574
582
  return rv
575
583
 
576
- def _align_expected_output_names(
577
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
578
- ) -> List[str]:
584
+ def _align_expected_output(
585
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
586
+ ) -> Tuple[List[str], pd.DataFrame]:
587
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
588
+ and output dataframe with 1 line.
589
+ If the method is fit_predict, run 2 lines of data.
590
+ """
579
591
  # in case the inferred output column names dimension is different
580
592
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
581
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
593
+
594
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
595
+ # so change the minimum of number of rows to 2
596
+ num_examples = 2
597
+ statement_params = telemetry.get_function_usage_statement_params(
598
+ project=_PROJECT,
599
+ subproject=_SUBPROJECT,
600
+ function_name=telemetry.get_statement_params_full_func_name(
601
+ inspect.currentframe(), MultinomialNB.__class__.__name__
602
+ ),
603
+ api_calls=[Session.call],
604
+ custom_tags={"autogen": True} if self._autogenerated else None,
605
+ )
606
+ if output_cols_prefix == "fit_predict_":
607
+ if hasattr(self._sklearn_object, "n_clusters"):
608
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
609
+ num_examples = self._sklearn_object.n_clusters
610
+ elif hasattr(self._sklearn_object, "min_samples"):
611
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
612
+ num_examples = self._sklearn_object.min_samples
613
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
614
+ # LocalOutlierFactor expects n_neighbors <= n_samples
615
+ num_examples = self._sklearn_object.n_neighbors
616
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
617
+ else:
618
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
582
619
 
583
620
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
584
621
  # seen during the fit.
@@ -590,12 +627,14 @@ class MultinomialNB(BaseTransformer):
590
627
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
591
628
  if self.sample_weight_col:
592
629
  output_df_columns_set -= set(self.sample_weight_col)
630
+
593
631
  # if the dimension of inferred output column names is correct; use it
594
632
  if len(expected_output_cols_list) == len(output_df_columns_set):
595
- return expected_output_cols_list
633
+ return expected_output_cols_list, output_df_pd
596
634
  # otherwise, use the sklearn estimator's output
597
635
  else:
598
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
636
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
637
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
599
638
 
600
639
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
601
640
  @telemetry.send_api_usage_telemetry(
@@ -643,7 +682,7 @@ class MultinomialNB(BaseTransformer):
643
682
  drop_input_cols=self._drop_input_cols,
644
683
  expected_output_cols_type="float",
645
684
  )
646
- expected_output_cols = self._align_expected_output_names(
685
+ expected_output_cols, _ = self._align_expected_output(
647
686
  inference_method, dataset, expected_output_cols, output_cols_prefix
648
687
  )
649
688
 
@@ -711,7 +750,7 @@ class MultinomialNB(BaseTransformer):
711
750
  drop_input_cols=self._drop_input_cols,
712
751
  expected_output_cols_type="float",
713
752
  )
714
- expected_output_cols = self._align_expected_output_names(
753
+ expected_output_cols, _ = self._align_expected_output(
715
754
  inference_method, dataset, expected_output_cols, output_cols_prefix
716
755
  )
717
756
  elif isinstance(dataset, pd.DataFrame):
@@ -774,7 +813,7 @@ class MultinomialNB(BaseTransformer):
774
813
  drop_input_cols=self._drop_input_cols,
775
814
  expected_output_cols_type="float",
776
815
  )
777
- expected_output_cols = self._align_expected_output_names(
816
+ expected_output_cols, _ = self._align_expected_output(
778
817
  inference_method, dataset, expected_output_cols, output_cols_prefix
779
818
  )
780
819
 
@@ -839,7 +878,7 @@ class MultinomialNB(BaseTransformer):
839
878
  drop_input_cols = self._drop_input_cols,
840
879
  expected_output_cols_type="float",
841
880
  )
842
- expected_output_cols = self._align_expected_output_names(
881
+ expected_output_cols, _ = self._align_expected_output(
843
882
  inference_method, dataset, expected_output_cols, output_cols_prefix
844
883
  )
845
884
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -546,12 +543,23 @@ class KNeighborsClassifier(BaseTransformer):
546
543
  autogenerated=self._autogenerated,
547
544
  subproject=_SUBPROJECT,
548
545
  )
549
- output_result, fitted_estimator = model_trainer.train_fit_predict(
550
- drop_input_cols=self._drop_input_cols,
551
- expected_output_cols_list=(
552
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
553
- ),
546
+ expected_output_cols = (
547
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
554
548
  )
549
+ if isinstance(dataset, DataFrame):
550
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
551
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
552
+ )
553
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
554
+ drop_input_cols=self._drop_input_cols,
555
+ expected_output_cols_list=expected_output_cols,
556
+ example_output_pd_df=example_output_pd_df,
557
+ )
558
+ else:
559
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
560
+ drop_input_cols=self._drop_input_cols,
561
+ expected_output_cols_list=expected_output_cols,
562
+ )
555
563
  self._sklearn_object = fitted_estimator
556
564
  self._is_fitted = True
557
565
  return output_result
@@ -630,12 +638,41 @@ class KNeighborsClassifier(BaseTransformer):
630
638
 
631
639
  return rv
632
640
 
633
- def _align_expected_output_names(
634
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
635
- ) -> List[str]:
641
+ def _align_expected_output(
642
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
643
+ ) -> Tuple[List[str], pd.DataFrame]:
644
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
645
+ and output dataframe with 1 line.
646
+ If the method is fit_predict, run 2 lines of data.
647
+ """
636
648
  # in case the inferred output column names dimension is different
637
649
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
638
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
650
+
651
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
652
+ # so change the minimum of number of rows to 2
653
+ num_examples = 2
654
+ statement_params = telemetry.get_function_usage_statement_params(
655
+ project=_PROJECT,
656
+ subproject=_SUBPROJECT,
657
+ function_name=telemetry.get_statement_params_full_func_name(
658
+ inspect.currentframe(), KNeighborsClassifier.__class__.__name__
659
+ ),
660
+ api_calls=[Session.call],
661
+ custom_tags={"autogen": True} if self._autogenerated else None,
662
+ )
663
+ if output_cols_prefix == "fit_predict_":
664
+ if hasattr(self._sklearn_object, "n_clusters"):
665
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
666
+ num_examples = self._sklearn_object.n_clusters
667
+ elif hasattr(self._sklearn_object, "min_samples"):
668
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
669
+ num_examples = self._sklearn_object.min_samples
670
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
671
+ # LocalOutlierFactor expects n_neighbors <= n_samples
672
+ num_examples = self._sklearn_object.n_neighbors
673
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
674
+ else:
675
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
639
676
 
640
677
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
641
678
  # seen during the fit.
@@ -647,12 +684,14 @@ class KNeighborsClassifier(BaseTransformer):
647
684
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
648
685
  if self.sample_weight_col:
649
686
  output_df_columns_set -= set(self.sample_weight_col)
687
+
650
688
  # if the dimension of inferred output column names is correct; use it
651
689
  if len(expected_output_cols_list) == len(output_df_columns_set):
652
- return expected_output_cols_list
690
+ return expected_output_cols_list, output_df_pd
653
691
  # otherwise, use the sklearn estimator's output
654
692
  else:
655
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
693
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
694
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
656
695
 
657
696
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
658
697
  @telemetry.send_api_usage_telemetry(
@@ -700,7 +739,7 @@ class KNeighborsClassifier(BaseTransformer):
700
739
  drop_input_cols=self._drop_input_cols,
701
740
  expected_output_cols_type="float",
702
741
  )
703
- expected_output_cols = self._align_expected_output_names(
742
+ expected_output_cols, _ = self._align_expected_output(
704
743
  inference_method, dataset, expected_output_cols, output_cols_prefix
705
744
  )
706
745
 
@@ -768,7 +807,7 @@ class KNeighborsClassifier(BaseTransformer):
768
807
  drop_input_cols=self._drop_input_cols,
769
808
  expected_output_cols_type="float",
770
809
  )
771
- expected_output_cols = self._align_expected_output_names(
810
+ expected_output_cols, _ = self._align_expected_output(
772
811
  inference_method, dataset, expected_output_cols, output_cols_prefix
773
812
  )
774
813
  elif isinstance(dataset, pd.DataFrame):
@@ -831,7 +870,7 @@ class KNeighborsClassifier(BaseTransformer):
831
870
  drop_input_cols=self._drop_input_cols,
832
871
  expected_output_cols_type="float",
833
872
  )
834
- expected_output_cols = self._align_expected_output_names(
873
+ expected_output_cols, _ = self._align_expected_output(
835
874
  inference_method, dataset, expected_output_cols, output_cols_prefix
836
875
  )
837
876
 
@@ -896,7 +935,7 @@ class KNeighborsClassifier(BaseTransformer):
896
935
  drop_input_cols = self._drop_input_cols,
897
936
  expected_output_cols_type="float",
898
937
  )
899
- expected_output_cols = self._align_expected_output_names(
938
+ expected_output_cols, _ = self._align_expected_output(
900
939
  inference_method, dataset, expected_output_cols, output_cols_prefix
901
940
  )
902
941