snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -595,12 +592,23 @@ class TSNE(BaseTransformer):
595
592
  autogenerated=self._autogenerated,
596
593
  subproject=_SUBPROJECT,
597
594
  )
598
- output_result, fitted_estimator = model_trainer.train_fit_predict(
599
- drop_input_cols=self._drop_input_cols,
600
- expected_output_cols_list=(
601
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
602
- ),
595
+ expected_output_cols = (
596
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
603
597
  )
598
+ if isinstance(dataset, DataFrame):
599
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
600
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
601
+ )
602
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
603
+ drop_input_cols=self._drop_input_cols,
604
+ expected_output_cols_list=expected_output_cols,
605
+ example_output_pd_df=example_output_pd_df,
606
+ )
607
+ else:
608
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
609
+ drop_input_cols=self._drop_input_cols,
610
+ expected_output_cols_list=expected_output_cols,
611
+ )
604
612
  self._sklearn_object = fitted_estimator
605
613
  self._is_fitted = True
606
614
  return output_result
@@ -681,12 +689,41 @@ class TSNE(BaseTransformer):
681
689
 
682
690
  return rv
683
691
 
684
- def _align_expected_output_names(
685
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
686
- ) -> List[str]:
692
+ def _align_expected_output(
693
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
694
+ ) -> Tuple[List[str], pd.DataFrame]:
695
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
696
+ and output dataframe with 1 line.
697
+ If the method is fit_predict, run 2 lines of data.
698
+ """
687
699
  # in case the inferred output column names dimension is different
688
700
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
689
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
701
+
702
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
703
+ # so change the minimum of number of rows to 2
704
+ num_examples = 2
705
+ statement_params = telemetry.get_function_usage_statement_params(
706
+ project=_PROJECT,
707
+ subproject=_SUBPROJECT,
708
+ function_name=telemetry.get_statement_params_full_func_name(
709
+ inspect.currentframe(), TSNE.__class__.__name__
710
+ ),
711
+ api_calls=[Session.call],
712
+ custom_tags={"autogen": True} if self._autogenerated else None,
713
+ )
714
+ if output_cols_prefix == "fit_predict_":
715
+ if hasattr(self._sklearn_object, "n_clusters"):
716
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
717
+ num_examples = self._sklearn_object.n_clusters
718
+ elif hasattr(self._sklearn_object, "min_samples"):
719
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
720
+ num_examples = self._sklearn_object.min_samples
721
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
722
+ # LocalOutlierFactor expects n_neighbors <= n_samples
723
+ num_examples = self._sklearn_object.n_neighbors
724
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
725
+ else:
726
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
690
727
 
691
728
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
692
729
  # seen during the fit.
@@ -698,12 +735,14 @@ class TSNE(BaseTransformer):
698
735
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
699
736
  if self.sample_weight_col:
700
737
  output_df_columns_set -= set(self.sample_weight_col)
738
+
701
739
  # if the dimension of inferred output column names is correct; use it
702
740
  if len(expected_output_cols_list) == len(output_df_columns_set):
703
- return expected_output_cols_list
741
+ return expected_output_cols_list, output_df_pd
704
742
  # otherwise, use the sklearn estimator's output
705
743
  else:
706
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
744
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
745
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
707
746
 
708
747
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
709
748
  @telemetry.send_api_usage_telemetry(
@@ -749,7 +788,7 @@ class TSNE(BaseTransformer):
749
788
  drop_input_cols=self._drop_input_cols,
750
789
  expected_output_cols_type="float",
751
790
  )
752
- expected_output_cols = self._align_expected_output_names(
791
+ expected_output_cols, _ = self._align_expected_output(
753
792
  inference_method, dataset, expected_output_cols, output_cols_prefix
754
793
  )
755
794
 
@@ -815,7 +854,7 @@ class TSNE(BaseTransformer):
815
854
  drop_input_cols=self._drop_input_cols,
816
855
  expected_output_cols_type="float",
817
856
  )
818
- expected_output_cols = self._align_expected_output_names(
857
+ expected_output_cols, _ = self._align_expected_output(
819
858
  inference_method, dataset, expected_output_cols, output_cols_prefix
820
859
  )
821
860
  elif isinstance(dataset, pd.DataFrame):
@@ -878,7 +917,7 @@ class TSNE(BaseTransformer):
878
917
  drop_input_cols=self._drop_input_cols,
879
918
  expected_output_cols_type="float",
880
919
  )
881
- expected_output_cols = self._align_expected_output_names(
920
+ expected_output_cols, _ = self._align_expected_output(
882
921
  inference_method, dataset, expected_output_cols, output_cols_prefix
883
922
  )
884
923
 
@@ -943,7 +982,7 @@ class TSNE(BaseTransformer):
943
982
  drop_input_cols = self._drop_input_cols,
944
983
  expected_output_cols_type="float",
945
984
  )
946
- expected_output_cols = self._align_expected_output_names(
985
+ expected_output_cols, _ = self._align_expected_output(
947
986
  inference_method, dataset, expected_output_cols, output_cols_prefix
948
987
  )
949
988
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -602,12 +599,23 @@ class BayesianGaussianMixture(BaseTransformer):
602
599
  autogenerated=self._autogenerated,
603
600
  subproject=_SUBPROJECT,
604
601
  )
605
- output_result, fitted_estimator = model_trainer.train_fit_predict(
606
- drop_input_cols=self._drop_input_cols,
607
- expected_output_cols_list=(
608
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
609
- ),
602
+ expected_output_cols = (
603
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
610
604
  )
605
+ if isinstance(dataset, DataFrame):
606
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
607
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
608
+ )
609
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
610
+ drop_input_cols=self._drop_input_cols,
611
+ expected_output_cols_list=expected_output_cols,
612
+ example_output_pd_df=example_output_pd_df,
613
+ )
614
+ else:
615
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
616
+ drop_input_cols=self._drop_input_cols,
617
+ expected_output_cols_list=expected_output_cols,
618
+ )
611
619
  self._sklearn_object = fitted_estimator
612
620
  self._is_fitted = True
613
621
  return output_result
@@ -686,12 +694,41 @@ class BayesianGaussianMixture(BaseTransformer):
686
694
 
687
695
  return rv
688
696
 
689
- def _align_expected_output_names(
690
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
691
- ) -> List[str]:
697
+ def _align_expected_output(
698
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
699
+ ) -> Tuple[List[str], pd.DataFrame]:
700
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
701
+ and output dataframe with 1 line.
702
+ If the method is fit_predict, run 2 lines of data.
703
+ """
692
704
  # in case the inferred output column names dimension is different
693
705
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
694
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
706
+
707
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
708
+ # so change the minimum of number of rows to 2
709
+ num_examples = 2
710
+ statement_params = telemetry.get_function_usage_statement_params(
711
+ project=_PROJECT,
712
+ subproject=_SUBPROJECT,
713
+ function_name=telemetry.get_statement_params_full_func_name(
714
+ inspect.currentframe(), BayesianGaussianMixture.__class__.__name__
715
+ ),
716
+ api_calls=[Session.call],
717
+ custom_tags={"autogen": True} if self._autogenerated else None,
718
+ )
719
+ if output_cols_prefix == "fit_predict_":
720
+ if hasattr(self._sklearn_object, "n_clusters"):
721
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
722
+ num_examples = self._sklearn_object.n_clusters
723
+ elif hasattr(self._sklearn_object, "min_samples"):
724
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
725
+ num_examples = self._sklearn_object.min_samples
726
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
727
+ # LocalOutlierFactor expects n_neighbors <= n_samples
728
+ num_examples = self._sklearn_object.n_neighbors
729
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
730
+ else:
731
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
695
732
 
696
733
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
697
734
  # seen during the fit.
@@ -703,12 +740,14 @@ class BayesianGaussianMixture(BaseTransformer):
703
740
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
704
741
  if self.sample_weight_col:
705
742
  output_df_columns_set -= set(self.sample_weight_col)
743
+
706
744
  # if the dimension of inferred output column names is correct; use it
707
745
  if len(expected_output_cols_list) == len(output_df_columns_set):
708
- return expected_output_cols_list
746
+ return expected_output_cols_list, output_df_pd
709
747
  # otherwise, use the sklearn estimator's output
710
748
  else:
711
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
749
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
750
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
712
751
 
713
752
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
714
753
  @telemetry.send_api_usage_telemetry(
@@ -756,7 +795,7 @@ class BayesianGaussianMixture(BaseTransformer):
756
795
  drop_input_cols=self._drop_input_cols,
757
796
  expected_output_cols_type="float",
758
797
  )
759
- expected_output_cols = self._align_expected_output_names(
798
+ expected_output_cols, _ = self._align_expected_output(
760
799
  inference_method, dataset, expected_output_cols, output_cols_prefix
761
800
  )
762
801
 
@@ -824,7 +863,7 @@ class BayesianGaussianMixture(BaseTransformer):
824
863
  drop_input_cols=self._drop_input_cols,
825
864
  expected_output_cols_type="float",
826
865
  )
827
- expected_output_cols = self._align_expected_output_names(
866
+ expected_output_cols, _ = self._align_expected_output(
828
867
  inference_method, dataset, expected_output_cols, output_cols_prefix
829
868
  )
830
869
  elif isinstance(dataset, pd.DataFrame):
@@ -887,7 +926,7 @@ class BayesianGaussianMixture(BaseTransformer):
887
926
  drop_input_cols=self._drop_input_cols,
888
927
  expected_output_cols_type="float",
889
928
  )
890
- expected_output_cols = self._align_expected_output_names(
929
+ expected_output_cols, _ = self._align_expected_output(
891
930
  inference_method, dataset, expected_output_cols, output_cols_prefix
892
931
  )
893
932
 
@@ -954,7 +993,7 @@ class BayesianGaussianMixture(BaseTransformer):
954
993
  drop_input_cols = self._drop_input_cols,
955
994
  expected_output_cols_type="float",
956
995
  )
957
- expected_output_cols = self._align_expected_output_names(
996
+ expected_output_cols, _ = self._align_expected_output(
958
997
  inference_method, dataset, expected_output_cols, output_cols_prefix
959
998
  )
960
999
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -575,12 +572,23 @@ class GaussianMixture(BaseTransformer):
575
572
  autogenerated=self._autogenerated,
576
573
  subproject=_SUBPROJECT,
577
574
  )
578
- output_result, fitted_estimator = model_trainer.train_fit_predict(
579
- drop_input_cols=self._drop_input_cols,
580
- expected_output_cols_list=(
581
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
582
- ),
575
+ expected_output_cols = (
576
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
583
577
  )
578
+ if isinstance(dataset, DataFrame):
579
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
580
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
581
+ )
582
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
583
+ drop_input_cols=self._drop_input_cols,
584
+ expected_output_cols_list=expected_output_cols,
585
+ example_output_pd_df=example_output_pd_df,
586
+ )
587
+ else:
588
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
589
+ drop_input_cols=self._drop_input_cols,
590
+ expected_output_cols_list=expected_output_cols,
591
+ )
584
592
  self._sklearn_object = fitted_estimator
585
593
  self._is_fitted = True
586
594
  return output_result
@@ -659,12 +667,41 @@ class GaussianMixture(BaseTransformer):
659
667
 
660
668
  return rv
661
669
 
662
- def _align_expected_output_names(
663
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
664
- ) -> List[str]:
670
+ def _align_expected_output(
671
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
672
+ ) -> Tuple[List[str], pd.DataFrame]:
673
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
674
+ and output dataframe with 1 line.
675
+ If the method is fit_predict, run 2 lines of data.
676
+ """
665
677
  # in case the inferred output column names dimension is different
666
678
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
667
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
679
+
680
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
681
+ # so change the minimum of number of rows to 2
682
+ num_examples = 2
683
+ statement_params = telemetry.get_function_usage_statement_params(
684
+ project=_PROJECT,
685
+ subproject=_SUBPROJECT,
686
+ function_name=telemetry.get_statement_params_full_func_name(
687
+ inspect.currentframe(), GaussianMixture.__class__.__name__
688
+ ),
689
+ api_calls=[Session.call],
690
+ custom_tags={"autogen": True} if self._autogenerated else None,
691
+ )
692
+ if output_cols_prefix == "fit_predict_":
693
+ if hasattr(self._sklearn_object, "n_clusters"):
694
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
695
+ num_examples = self._sklearn_object.n_clusters
696
+ elif hasattr(self._sklearn_object, "min_samples"):
697
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
698
+ num_examples = self._sklearn_object.min_samples
699
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
700
+ # LocalOutlierFactor expects n_neighbors <= n_samples
701
+ num_examples = self._sklearn_object.n_neighbors
702
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
703
+ else:
704
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
668
705
 
669
706
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
670
707
  # seen during the fit.
@@ -676,12 +713,14 @@ class GaussianMixture(BaseTransformer):
676
713
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
677
714
  if self.sample_weight_col:
678
715
  output_df_columns_set -= set(self.sample_weight_col)
716
+
679
717
  # if the dimension of inferred output column names is correct; use it
680
718
  if len(expected_output_cols_list) == len(output_df_columns_set):
681
- return expected_output_cols_list
719
+ return expected_output_cols_list, output_df_pd
682
720
  # otherwise, use the sklearn estimator's output
683
721
  else:
684
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
722
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
723
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
685
724
 
686
725
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
687
726
  @telemetry.send_api_usage_telemetry(
@@ -729,7 +768,7 @@ class GaussianMixture(BaseTransformer):
729
768
  drop_input_cols=self._drop_input_cols,
730
769
  expected_output_cols_type="float",
731
770
  )
732
- expected_output_cols = self._align_expected_output_names(
771
+ expected_output_cols, _ = self._align_expected_output(
733
772
  inference_method, dataset, expected_output_cols, output_cols_prefix
734
773
  )
735
774
 
@@ -797,7 +836,7 @@ class GaussianMixture(BaseTransformer):
797
836
  drop_input_cols=self._drop_input_cols,
798
837
  expected_output_cols_type="float",
799
838
  )
800
- expected_output_cols = self._align_expected_output_names(
839
+ expected_output_cols, _ = self._align_expected_output(
801
840
  inference_method, dataset, expected_output_cols, output_cols_prefix
802
841
  )
803
842
  elif isinstance(dataset, pd.DataFrame):
@@ -860,7 +899,7 @@ class GaussianMixture(BaseTransformer):
860
899
  drop_input_cols=self._drop_input_cols,
861
900
  expected_output_cols_type="float",
862
901
  )
863
- expected_output_cols = self._align_expected_output_names(
902
+ expected_output_cols, _ = self._align_expected_output(
864
903
  inference_method, dataset, expected_output_cols, output_cols_prefix
865
904
  )
866
905
 
@@ -927,7 +966,7 @@ class GaussianMixture(BaseTransformer):
927
966
  drop_input_cols = self._drop_input_cols,
928
967
  expected_output_cols_type="float",
929
968
  )
930
- expected_output_cols = self._align_expected_output_names(
969
+ expected_output_cols, _ = self._align_expected_output(
931
970
  inference_method, dataset, expected_output_cols, output_cols_prefix
932
971
  )
933
972
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -483,12 +480,23 @@ class OneVsOneClassifier(BaseTransformer):
483
480
  autogenerated=self._autogenerated,
484
481
  subproject=_SUBPROJECT,
485
482
  )
486
- output_result, fitted_estimator = model_trainer.train_fit_predict(
487
- drop_input_cols=self._drop_input_cols,
488
- expected_output_cols_list=(
489
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
490
- ),
483
+ expected_output_cols = (
484
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
491
485
  )
486
+ if isinstance(dataset, DataFrame):
487
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
488
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
489
+ )
490
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
491
+ drop_input_cols=self._drop_input_cols,
492
+ expected_output_cols_list=expected_output_cols,
493
+ example_output_pd_df=example_output_pd_df,
494
+ )
495
+ else:
496
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
497
+ drop_input_cols=self._drop_input_cols,
498
+ expected_output_cols_list=expected_output_cols,
499
+ )
492
500
  self._sklearn_object = fitted_estimator
493
501
  self._is_fitted = True
494
502
  return output_result
@@ -567,12 +575,41 @@ class OneVsOneClassifier(BaseTransformer):
567
575
 
568
576
  return rv
569
577
 
570
- def _align_expected_output_names(
571
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
572
- ) -> List[str]:
578
+ def _align_expected_output(
579
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
580
+ ) -> Tuple[List[str], pd.DataFrame]:
581
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
582
+ and output dataframe with 1 line.
583
+ If the method is fit_predict, run 2 lines of data.
584
+ """
573
585
  # in case the inferred output column names dimension is different
574
586
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
575
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
587
+
588
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
589
+ # so change the minimum of number of rows to 2
590
+ num_examples = 2
591
+ statement_params = telemetry.get_function_usage_statement_params(
592
+ project=_PROJECT,
593
+ subproject=_SUBPROJECT,
594
+ function_name=telemetry.get_statement_params_full_func_name(
595
+ inspect.currentframe(), OneVsOneClassifier.__class__.__name__
596
+ ),
597
+ api_calls=[Session.call],
598
+ custom_tags={"autogen": True} if self._autogenerated else None,
599
+ )
600
+ if output_cols_prefix == "fit_predict_":
601
+ if hasattr(self._sklearn_object, "n_clusters"):
602
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
603
+ num_examples = self._sklearn_object.n_clusters
604
+ elif hasattr(self._sklearn_object, "min_samples"):
605
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
606
+ num_examples = self._sklearn_object.min_samples
607
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
608
+ # LocalOutlierFactor expects n_neighbors <= n_samples
609
+ num_examples = self._sklearn_object.n_neighbors
610
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
611
+ else:
612
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
576
613
 
577
614
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
578
615
  # seen during the fit.
@@ -584,12 +621,14 @@ class OneVsOneClassifier(BaseTransformer):
584
621
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
585
622
  if self.sample_weight_col:
586
623
  output_df_columns_set -= set(self.sample_weight_col)
624
+
587
625
  # if the dimension of inferred output column names is correct; use it
588
626
  if len(expected_output_cols_list) == len(output_df_columns_set):
589
- return expected_output_cols_list
627
+ return expected_output_cols_list, output_df_pd
590
628
  # otherwise, use the sklearn estimator's output
591
629
  else:
592
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
630
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
631
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
593
632
 
594
633
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
595
634
  @telemetry.send_api_usage_telemetry(
@@ -635,7 +674,7 @@ class OneVsOneClassifier(BaseTransformer):
635
674
  drop_input_cols=self._drop_input_cols,
636
675
  expected_output_cols_type="float",
637
676
  )
638
- expected_output_cols = self._align_expected_output_names(
677
+ expected_output_cols, _ = self._align_expected_output(
639
678
  inference_method, dataset, expected_output_cols, output_cols_prefix
640
679
  )
641
680
 
@@ -701,7 +740,7 @@ class OneVsOneClassifier(BaseTransformer):
701
740
  drop_input_cols=self._drop_input_cols,
702
741
  expected_output_cols_type="float",
703
742
  )
704
- expected_output_cols = self._align_expected_output_names(
743
+ expected_output_cols, _ = self._align_expected_output(
705
744
  inference_method, dataset, expected_output_cols, output_cols_prefix
706
745
  )
707
746
  elif isinstance(dataset, pd.DataFrame):
@@ -766,7 +805,7 @@ class OneVsOneClassifier(BaseTransformer):
766
805
  drop_input_cols=self._drop_input_cols,
767
806
  expected_output_cols_type="float",
768
807
  )
769
- expected_output_cols = self._align_expected_output_names(
808
+ expected_output_cols, _ = self._align_expected_output(
770
809
  inference_method, dataset, expected_output_cols, output_cols_prefix
771
810
  )
772
811
 
@@ -831,7 +870,7 @@ class OneVsOneClassifier(BaseTransformer):
831
870
  drop_input_cols = self._drop_input_cols,
832
871
  expected_output_cols_type="float",
833
872
  )
834
- expected_output_cols = self._align_expected_output_names(
873
+ expected_output_cols, _ = self._align_expected_output(
835
874
  inference_method, dataset, expected_output_cols, output_cols_prefix
836
875
  )
837
876