snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252) hide show
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -536,12 +533,23 @@ class NearestNeighbors(BaseTransformer):
536
533
  autogenerated=self._autogenerated,
537
534
  subproject=_SUBPROJECT,
538
535
  )
539
- output_result, fitted_estimator = model_trainer.train_fit_predict(
540
- drop_input_cols=self._drop_input_cols,
541
- expected_output_cols_list=(
542
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
543
- ),
536
+ expected_output_cols = (
537
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
544
538
  )
539
+ if isinstance(dataset, DataFrame):
540
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
541
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
542
+ )
543
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
544
+ drop_input_cols=self._drop_input_cols,
545
+ expected_output_cols_list=expected_output_cols,
546
+ example_output_pd_df=example_output_pd_df,
547
+ )
548
+ else:
549
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
550
+ drop_input_cols=self._drop_input_cols,
551
+ expected_output_cols_list=expected_output_cols,
552
+ )
545
553
  self._sklearn_object = fitted_estimator
546
554
  self._is_fitted = True
547
555
  return output_result
@@ -620,12 +628,41 @@ class NearestNeighbors(BaseTransformer):
620
628
 
621
629
  return rv
622
630
 
623
- def _align_expected_output_names(
624
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
625
- ) -> List[str]:
631
+ def _align_expected_output(
632
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
633
+ ) -> Tuple[List[str], pd.DataFrame]:
634
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
635
+ and output dataframe with 1 line.
636
+ If the method is fit_predict, run 2 lines of data.
637
+ """
626
638
  # in case the inferred output column names dimension is different
627
639
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
628
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
640
+
641
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
642
+ # so change the minimum of number of rows to 2
643
+ num_examples = 2
644
+ statement_params = telemetry.get_function_usage_statement_params(
645
+ project=_PROJECT,
646
+ subproject=_SUBPROJECT,
647
+ function_name=telemetry.get_statement_params_full_func_name(
648
+ inspect.currentframe(), NearestNeighbors.__class__.__name__
649
+ ),
650
+ api_calls=[Session.call],
651
+ custom_tags={"autogen": True} if self._autogenerated else None,
652
+ )
653
+ if output_cols_prefix == "fit_predict_":
654
+ if hasattr(self._sklearn_object, "n_clusters"):
655
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
656
+ num_examples = self._sklearn_object.n_clusters
657
+ elif hasattr(self._sklearn_object, "min_samples"):
658
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
659
+ num_examples = self._sklearn_object.min_samples
660
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
661
+ # LocalOutlierFactor expects n_neighbors <= n_samples
662
+ num_examples = self._sklearn_object.n_neighbors
663
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
664
+ else:
665
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
629
666
 
630
667
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
631
668
  # seen during the fit.
@@ -637,12 +674,14 @@ class NearestNeighbors(BaseTransformer):
637
674
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
638
675
  if self.sample_weight_col:
639
676
  output_df_columns_set -= set(self.sample_weight_col)
677
+
640
678
  # if the dimension of inferred output column names is correct; use it
641
679
  if len(expected_output_cols_list) == len(output_df_columns_set):
642
- return expected_output_cols_list
680
+ return expected_output_cols_list, output_df_pd
643
681
  # otherwise, use the sklearn estimator's output
644
682
  else:
645
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
683
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
684
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
646
685
 
647
686
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
648
687
  @telemetry.send_api_usage_telemetry(
@@ -688,7 +727,7 @@ class NearestNeighbors(BaseTransformer):
688
727
  drop_input_cols=self._drop_input_cols,
689
728
  expected_output_cols_type="float",
690
729
  )
691
- expected_output_cols = self._align_expected_output_names(
730
+ expected_output_cols, _ = self._align_expected_output(
692
731
  inference_method, dataset, expected_output_cols, output_cols_prefix
693
732
  )
694
733
 
@@ -754,7 +793,7 @@ class NearestNeighbors(BaseTransformer):
754
793
  drop_input_cols=self._drop_input_cols,
755
794
  expected_output_cols_type="float",
756
795
  )
757
- expected_output_cols = self._align_expected_output_names(
796
+ expected_output_cols, _ = self._align_expected_output(
758
797
  inference_method, dataset, expected_output_cols, output_cols_prefix
759
798
  )
760
799
  elif isinstance(dataset, pd.DataFrame):
@@ -817,7 +856,7 @@ class NearestNeighbors(BaseTransformer):
817
856
  drop_input_cols=self._drop_input_cols,
818
857
  expected_output_cols_type="float",
819
858
  )
820
- expected_output_cols = self._align_expected_output_names(
859
+ expected_output_cols, _ = self._align_expected_output(
821
860
  inference_method, dataset, expected_output_cols, output_cols_prefix
822
861
  )
823
862
 
@@ -882,7 +921,7 @@ class NearestNeighbors(BaseTransformer):
882
921
  drop_input_cols = self._drop_input_cols,
883
922
  expected_output_cols_type="float",
884
923
  )
885
- expected_output_cols = self._align_expected_output_names(
924
+ expected_output_cols, _ = self._align_expected_output(
886
925
  inference_method, dataset, expected_output_cols, output_cols_prefix
887
926
  )
888
927
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -557,12 +554,23 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
557
554
  autogenerated=self._autogenerated,
558
555
  subproject=_SUBPROJECT,
559
556
  )
560
- output_result, fitted_estimator = model_trainer.train_fit_predict(
561
- drop_input_cols=self._drop_input_cols,
562
- expected_output_cols_list=(
563
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
564
- ),
557
+ expected_output_cols = (
558
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
565
559
  )
560
+ if isinstance(dataset, DataFrame):
561
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
562
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
563
+ )
564
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
565
+ drop_input_cols=self._drop_input_cols,
566
+ expected_output_cols_list=expected_output_cols,
567
+ example_output_pd_df=example_output_pd_df,
568
+ )
569
+ else:
570
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
571
+ drop_input_cols=self._drop_input_cols,
572
+ expected_output_cols_list=expected_output_cols,
573
+ )
566
574
  self._sklearn_object = fitted_estimator
567
575
  self._is_fitted = True
568
576
  return output_result
@@ -643,12 +651,41 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
643
651
 
644
652
  return rv
645
653
 
646
- def _align_expected_output_names(
647
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
648
- ) -> List[str]:
654
+ def _align_expected_output(
655
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
656
+ ) -> Tuple[List[str], pd.DataFrame]:
657
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
658
+ and output dataframe with 1 line.
659
+ If the method is fit_predict, run 2 lines of data.
660
+ """
649
661
  # in case the inferred output column names dimension is different
650
662
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
651
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
663
+
664
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
665
+ # so change the minimum of number of rows to 2
666
+ num_examples = 2
667
+ statement_params = telemetry.get_function_usage_statement_params(
668
+ project=_PROJECT,
669
+ subproject=_SUBPROJECT,
670
+ function_name=telemetry.get_statement_params_full_func_name(
671
+ inspect.currentframe(), NeighborhoodComponentsAnalysis.__class__.__name__
672
+ ),
673
+ api_calls=[Session.call],
674
+ custom_tags={"autogen": True} if self._autogenerated else None,
675
+ )
676
+ if output_cols_prefix == "fit_predict_":
677
+ if hasattr(self._sklearn_object, "n_clusters"):
678
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
679
+ num_examples = self._sklearn_object.n_clusters
680
+ elif hasattr(self._sklearn_object, "min_samples"):
681
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
682
+ num_examples = self._sklearn_object.min_samples
683
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
684
+ # LocalOutlierFactor expects n_neighbors <= n_samples
685
+ num_examples = self._sklearn_object.n_neighbors
686
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
687
+ else:
688
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
652
689
 
653
690
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
654
691
  # seen during the fit.
@@ -660,12 +697,14 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
660
697
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
661
698
  if self.sample_weight_col:
662
699
  output_df_columns_set -= set(self.sample_weight_col)
700
+
663
701
  # if the dimension of inferred output column names is correct; use it
664
702
  if len(expected_output_cols_list) == len(output_df_columns_set):
665
- return expected_output_cols_list
703
+ return expected_output_cols_list, output_df_pd
666
704
  # otherwise, use the sklearn estimator's output
667
705
  else:
668
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
706
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
707
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
669
708
 
670
709
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
671
710
  @telemetry.send_api_usage_telemetry(
@@ -711,7 +750,7 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
711
750
  drop_input_cols=self._drop_input_cols,
712
751
  expected_output_cols_type="float",
713
752
  )
714
- expected_output_cols = self._align_expected_output_names(
753
+ expected_output_cols, _ = self._align_expected_output(
715
754
  inference_method, dataset, expected_output_cols, output_cols_prefix
716
755
  )
717
756
 
@@ -777,7 +816,7 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
777
816
  drop_input_cols=self._drop_input_cols,
778
817
  expected_output_cols_type="float",
779
818
  )
780
- expected_output_cols = self._align_expected_output_names(
819
+ expected_output_cols, _ = self._align_expected_output(
781
820
  inference_method, dataset, expected_output_cols, output_cols_prefix
782
821
  )
783
822
  elif isinstance(dataset, pd.DataFrame):
@@ -840,7 +879,7 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
840
879
  drop_input_cols=self._drop_input_cols,
841
880
  expected_output_cols_type="float",
842
881
  )
843
- expected_output_cols = self._align_expected_output_names(
882
+ expected_output_cols, _ = self._align_expected_output(
844
883
  inference_method, dataset, expected_output_cols, output_cols_prefix
845
884
  )
846
885
 
@@ -905,7 +944,7 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
905
944
  drop_input_cols = self._drop_input_cols,
906
945
  expected_output_cols_type="float",
907
946
  )
908
- expected_output_cols = self._align_expected_output_names(
947
+ expected_output_cols, _ = self._align_expected_output(
909
948
  inference_method, dataset, expected_output_cols, output_cols_prefix
910
949
  )
911
950
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -558,12 +555,23 @@ class RadiusNeighborsClassifier(BaseTransformer):
558
555
  autogenerated=self._autogenerated,
559
556
  subproject=_SUBPROJECT,
560
557
  )
561
- output_result, fitted_estimator = model_trainer.train_fit_predict(
562
- drop_input_cols=self._drop_input_cols,
563
- expected_output_cols_list=(
564
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
565
- ),
558
+ expected_output_cols = (
559
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
566
560
  )
561
+ if isinstance(dataset, DataFrame):
562
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
563
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
564
+ )
565
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
566
+ drop_input_cols=self._drop_input_cols,
567
+ expected_output_cols_list=expected_output_cols,
568
+ example_output_pd_df=example_output_pd_df,
569
+ )
570
+ else:
571
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
572
+ drop_input_cols=self._drop_input_cols,
573
+ expected_output_cols_list=expected_output_cols,
574
+ )
567
575
  self._sklearn_object = fitted_estimator
568
576
  self._is_fitted = True
569
577
  return output_result
@@ -642,12 +650,41 @@ class RadiusNeighborsClassifier(BaseTransformer):
642
650
 
643
651
  return rv
644
652
 
645
- def _align_expected_output_names(
646
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
647
- ) -> List[str]:
653
+ def _align_expected_output(
654
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
655
+ ) -> Tuple[List[str], pd.DataFrame]:
656
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
657
+ and output dataframe with 1 line.
658
+ If the method is fit_predict, run 2 lines of data.
659
+ """
648
660
  # in case the inferred output column names dimension is different
649
661
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
650
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
662
+
663
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
664
+ # so change the minimum of number of rows to 2
665
+ num_examples = 2
666
+ statement_params = telemetry.get_function_usage_statement_params(
667
+ project=_PROJECT,
668
+ subproject=_SUBPROJECT,
669
+ function_name=telemetry.get_statement_params_full_func_name(
670
+ inspect.currentframe(), RadiusNeighborsClassifier.__class__.__name__
671
+ ),
672
+ api_calls=[Session.call],
673
+ custom_tags={"autogen": True} if self._autogenerated else None,
674
+ )
675
+ if output_cols_prefix == "fit_predict_":
676
+ if hasattr(self._sklearn_object, "n_clusters"):
677
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
678
+ num_examples = self._sklearn_object.n_clusters
679
+ elif hasattr(self._sklearn_object, "min_samples"):
680
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
681
+ num_examples = self._sklearn_object.min_samples
682
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
683
+ # LocalOutlierFactor expects n_neighbors <= n_samples
684
+ num_examples = self._sklearn_object.n_neighbors
685
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
686
+ else:
687
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
651
688
 
652
689
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
653
690
  # seen during the fit.
@@ -659,12 +696,14 @@ class RadiusNeighborsClassifier(BaseTransformer):
659
696
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
660
697
  if self.sample_weight_col:
661
698
  output_df_columns_set -= set(self.sample_weight_col)
699
+
662
700
  # if the dimension of inferred output column names is correct; use it
663
701
  if len(expected_output_cols_list) == len(output_df_columns_set):
664
- return expected_output_cols_list
702
+ return expected_output_cols_list, output_df_pd
665
703
  # otherwise, use the sklearn estimator's output
666
704
  else:
667
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
705
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
706
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
668
707
 
669
708
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
670
709
  @telemetry.send_api_usage_telemetry(
@@ -712,7 +751,7 @@ class RadiusNeighborsClassifier(BaseTransformer):
712
751
  drop_input_cols=self._drop_input_cols,
713
752
  expected_output_cols_type="float",
714
753
  )
715
- expected_output_cols = self._align_expected_output_names(
754
+ expected_output_cols, _ = self._align_expected_output(
716
755
  inference_method, dataset, expected_output_cols, output_cols_prefix
717
756
  )
718
757
 
@@ -780,7 +819,7 @@ class RadiusNeighborsClassifier(BaseTransformer):
780
819
  drop_input_cols=self._drop_input_cols,
781
820
  expected_output_cols_type="float",
782
821
  )
783
- expected_output_cols = self._align_expected_output_names(
822
+ expected_output_cols, _ = self._align_expected_output(
784
823
  inference_method, dataset, expected_output_cols, output_cols_prefix
785
824
  )
786
825
  elif isinstance(dataset, pd.DataFrame):
@@ -843,7 +882,7 @@ class RadiusNeighborsClassifier(BaseTransformer):
843
882
  drop_input_cols=self._drop_input_cols,
844
883
  expected_output_cols_type="float",
845
884
  )
846
- expected_output_cols = self._align_expected_output_names(
885
+ expected_output_cols, _ = self._align_expected_output(
847
886
  inference_method, dataset, expected_output_cols, output_cols_prefix
848
887
  )
849
888
 
@@ -908,7 +947,7 @@ class RadiusNeighborsClassifier(BaseTransformer):
908
947
  drop_input_cols = self._drop_input_cols,
909
948
  expected_output_cols_type="float",
910
949
  )
911
- expected_output_cols = self._align_expected_output_names(
950
+ expected_output_cols, _ = self._align_expected_output(
912
951
  inference_method, dataset, expected_output_cols, output_cols_prefix
913
952
  )
914
953
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -548,12 +545,23 @@ class RadiusNeighborsRegressor(BaseTransformer):
548
545
  autogenerated=self._autogenerated,
549
546
  subproject=_SUBPROJECT,
550
547
  )
551
- output_result, fitted_estimator = model_trainer.train_fit_predict(
552
- drop_input_cols=self._drop_input_cols,
553
- expected_output_cols_list=(
554
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
555
- ),
548
+ expected_output_cols = (
549
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
556
550
  )
551
+ if isinstance(dataset, DataFrame):
552
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
553
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
554
+ )
555
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
556
+ drop_input_cols=self._drop_input_cols,
557
+ expected_output_cols_list=expected_output_cols,
558
+ example_output_pd_df=example_output_pd_df,
559
+ )
560
+ else:
561
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
562
+ drop_input_cols=self._drop_input_cols,
563
+ expected_output_cols_list=expected_output_cols,
564
+ )
557
565
  self._sklearn_object = fitted_estimator
558
566
  self._is_fitted = True
559
567
  return output_result
@@ -632,12 +640,41 @@ class RadiusNeighborsRegressor(BaseTransformer):
632
640
 
633
641
  return rv
634
642
 
635
- def _align_expected_output_names(
636
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
637
- ) -> List[str]:
643
+ def _align_expected_output(
644
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
645
+ ) -> Tuple[List[str], pd.DataFrame]:
646
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
647
+ and output dataframe with 1 line.
648
+ If the method is fit_predict, run 2 lines of data.
649
+ """
638
650
  # in case the inferred output column names dimension is different
639
651
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
640
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
652
+
653
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
654
+ # so change the minimum of number of rows to 2
655
+ num_examples = 2
656
+ statement_params = telemetry.get_function_usage_statement_params(
657
+ project=_PROJECT,
658
+ subproject=_SUBPROJECT,
659
+ function_name=telemetry.get_statement_params_full_func_name(
660
+ inspect.currentframe(), RadiusNeighborsRegressor.__class__.__name__
661
+ ),
662
+ api_calls=[Session.call],
663
+ custom_tags={"autogen": True} if self._autogenerated else None,
664
+ )
665
+ if output_cols_prefix == "fit_predict_":
666
+ if hasattr(self._sklearn_object, "n_clusters"):
667
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
668
+ num_examples = self._sklearn_object.n_clusters
669
+ elif hasattr(self._sklearn_object, "min_samples"):
670
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
671
+ num_examples = self._sklearn_object.min_samples
672
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
673
+ # LocalOutlierFactor expects n_neighbors <= n_samples
674
+ num_examples = self._sklearn_object.n_neighbors
675
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
676
+ else:
677
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
641
678
 
642
679
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
643
680
  # seen during the fit.
@@ -649,12 +686,14 @@ class RadiusNeighborsRegressor(BaseTransformer):
649
686
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
650
687
  if self.sample_weight_col:
651
688
  output_df_columns_set -= set(self.sample_weight_col)
689
+
652
690
  # if the dimension of inferred output column names is correct; use it
653
691
  if len(expected_output_cols_list) == len(output_df_columns_set):
654
- return expected_output_cols_list
692
+ return expected_output_cols_list, output_df_pd
655
693
  # otherwise, use the sklearn estimator's output
656
694
  else:
657
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
695
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
696
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
658
697
 
659
698
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
660
699
  @telemetry.send_api_usage_telemetry(
@@ -700,7 +739,7 @@ class RadiusNeighborsRegressor(BaseTransformer):
700
739
  drop_input_cols=self._drop_input_cols,
701
740
  expected_output_cols_type="float",
702
741
  )
703
- expected_output_cols = self._align_expected_output_names(
742
+ expected_output_cols, _ = self._align_expected_output(
704
743
  inference_method, dataset, expected_output_cols, output_cols_prefix
705
744
  )
706
745
 
@@ -766,7 +805,7 @@ class RadiusNeighborsRegressor(BaseTransformer):
766
805
  drop_input_cols=self._drop_input_cols,
767
806
  expected_output_cols_type="float",
768
807
  )
769
- expected_output_cols = self._align_expected_output_names(
808
+ expected_output_cols, _ = self._align_expected_output(
770
809
  inference_method, dataset, expected_output_cols, output_cols_prefix
771
810
  )
772
811
  elif isinstance(dataset, pd.DataFrame):
@@ -829,7 +868,7 @@ class RadiusNeighborsRegressor(BaseTransformer):
829
868
  drop_input_cols=self._drop_input_cols,
830
869
  expected_output_cols_type="float",
831
870
  )
832
- expected_output_cols = self._align_expected_output_names(
871
+ expected_output_cols, _ = self._align_expected_output(
833
872
  inference_method, dataset, expected_output_cols, output_cols_prefix
834
873
  )
835
874
 
@@ -894,7 +933,7 @@ class RadiusNeighborsRegressor(BaseTransformer):
894
933
  drop_input_cols = self._drop_input_cols,
895
934
  expected_output_cols_type="float",
896
935
  )
897
- expected_output_cols = self._align_expected_output_names(
936
+ expected_output_cols, _ = self._align_expected_output(
898
937
  inference_method, dataset, expected_output_cols, output_cols_prefix
899
938
  )
900
939