snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -560,12 +557,23 @@ class PassiveAggressiveRegressor(BaseTransformer):
560
557
  autogenerated=self._autogenerated,
561
558
  subproject=_SUBPROJECT,
562
559
  )
563
- output_result, fitted_estimator = model_trainer.train_fit_predict(
564
- drop_input_cols=self._drop_input_cols,
565
- expected_output_cols_list=(
566
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
567
- ),
560
+ expected_output_cols = (
561
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
568
562
  )
563
+ if isinstance(dataset, DataFrame):
564
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
565
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
566
+ )
567
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
568
+ drop_input_cols=self._drop_input_cols,
569
+ expected_output_cols_list=expected_output_cols,
570
+ example_output_pd_df=example_output_pd_df,
571
+ )
572
+ else:
573
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
574
+ drop_input_cols=self._drop_input_cols,
575
+ expected_output_cols_list=expected_output_cols,
576
+ )
569
577
  self._sklearn_object = fitted_estimator
570
578
  self._is_fitted = True
571
579
  return output_result
@@ -644,12 +652,41 @@ class PassiveAggressiveRegressor(BaseTransformer):
644
652
 
645
653
  return rv
646
654
 
647
- def _align_expected_output_names(
648
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
649
- ) -> List[str]:
655
+ def _align_expected_output(
656
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
657
+ ) -> Tuple[List[str], pd.DataFrame]:
658
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
659
+ and output dataframe with 1 line.
660
+ If the method is fit_predict, run 2 lines of data.
661
+ """
650
662
  # in case the inferred output column names dimension is different
651
663
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
652
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
664
+
665
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
666
+ # so change the minimum of number of rows to 2
667
+ num_examples = 2
668
+ statement_params = telemetry.get_function_usage_statement_params(
669
+ project=_PROJECT,
670
+ subproject=_SUBPROJECT,
671
+ function_name=telemetry.get_statement_params_full_func_name(
672
+ inspect.currentframe(), PassiveAggressiveRegressor.__class__.__name__
673
+ ),
674
+ api_calls=[Session.call],
675
+ custom_tags={"autogen": True} if self._autogenerated else None,
676
+ )
677
+ if output_cols_prefix == "fit_predict_":
678
+ if hasattr(self._sklearn_object, "n_clusters"):
679
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
680
+ num_examples = self._sklearn_object.n_clusters
681
+ elif hasattr(self._sklearn_object, "min_samples"):
682
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
683
+ num_examples = self._sklearn_object.min_samples
684
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
685
+ # LocalOutlierFactor expects n_neighbors <= n_samples
686
+ num_examples = self._sklearn_object.n_neighbors
687
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
688
+ else:
689
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
653
690
 
654
691
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
655
692
  # seen during the fit.
@@ -661,12 +698,14 @@ class PassiveAggressiveRegressor(BaseTransformer):
661
698
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
662
699
  if self.sample_weight_col:
663
700
  output_df_columns_set -= set(self.sample_weight_col)
701
+
664
702
  # if the dimension of inferred output column names is correct; use it
665
703
  if len(expected_output_cols_list) == len(output_df_columns_set):
666
- return expected_output_cols_list
704
+ return expected_output_cols_list, output_df_pd
667
705
  # otherwise, use the sklearn estimator's output
668
706
  else:
669
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
707
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
708
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
670
709
 
671
710
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
672
711
  @telemetry.send_api_usage_telemetry(
@@ -712,7 +751,7 @@ class PassiveAggressiveRegressor(BaseTransformer):
712
751
  drop_input_cols=self._drop_input_cols,
713
752
  expected_output_cols_type="float",
714
753
  )
715
- expected_output_cols = self._align_expected_output_names(
754
+ expected_output_cols, _ = self._align_expected_output(
716
755
  inference_method, dataset, expected_output_cols, output_cols_prefix
717
756
  )
718
757
 
@@ -778,7 +817,7 @@ class PassiveAggressiveRegressor(BaseTransformer):
778
817
  drop_input_cols=self._drop_input_cols,
779
818
  expected_output_cols_type="float",
780
819
  )
781
- expected_output_cols = self._align_expected_output_names(
820
+ expected_output_cols, _ = self._align_expected_output(
782
821
  inference_method, dataset, expected_output_cols, output_cols_prefix
783
822
  )
784
823
  elif isinstance(dataset, pd.DataFrame):
@@ -841,7 +880,7 @@ class PassiveAggressiveRegressor(BaseTransformer):
841
880
  drop_input_cols=self._drop_input_cols,
842
881
  expected_output_cols_type="float",
843
882
  )
844
- expected_output_cols = self._align_expected_output_names(
883
+ expected_output_cols, _ = self._align_expected_output(
845
884
  inference_method, dataset, expected_output_cols, output_cols_prefix
846
885
  )
847
886
 
@@ -906,7 +945,7 @@ class PassiveAggressiveRegressor(BaseTransformer):
906
945
  drop_input_cols = self._drop_input_cols,
907
946
  expected_output_cols_type="float",
908
947
  )
909
- expected_output_cols = self._align_expected_output_names(
948
+ expected_output_cols, _ = self._align_expected_output(
910
949
  inference_method, dataset, expected_output_cols, output_cols_prefix
911
950
  )
912
951
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -573,12 +570,23 @@ class Perceptron(BaseTransformer):
573
570
  autogenerated=self._autogenerated,
574
571
  subproject=_SUBPROJECT,
575
572
  )
576
- output_result, fitted_estimator = model_trainer.train_fit_predict(
577
- drop_input_cols=self._drop_input_cols,
578
- expected_output_cols_list=(
579
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
580
- ),
573
+ expected_output_cols = (
574
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
581
575
  )
576
+ if isinstance(dataset, DataFrame):
577
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
578
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
579
+ )
580
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
581
+ drop_input_cols=self._drop_input_cols,
582
+ expected_output_cols_list=expected_output_cols,
583
+ example_output_pd_df=example_output_pd_df,
584
+ )
585
+ else:
586
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
587
+ drop_input_cols=self._drop_input_cols,
588
+ expected_output_cols_list=expected_output_cols,
589
+ )
582
590
  self._sklearn_object = fitted_estimator
583
591
  self._is_fitted = True
584
592
  return output_result
@@ -657,12 +665,41 @@ class Perceptron(BaseTransformer):
657
665
 
658
666
  return rv
659
667
 
660
- def _align_expected_output_names(
661
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
662
- ) -> List[str]:
668
+ def _align_expected_output(
669
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
670
+ ) -> Tuple[List[str], pd.DataFrame]:
671
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
672
+ and output dataframe with 1 line.
673
+ If the method is fit_predict, run 2 lines of data.
674
+ """
663
675
  # in case the inferred output column names dimension is different
664
676
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
665
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
677
+
678
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
679
+ # so change the minimum of number of rows to 2
680
+ num_examples = 2
681
+ statement_params = telemetry.get_function_usage_statement_params(
682
+ project=_PROJECT,
683
+ subproject=_SUBPROJECT,
684
+ function_name=telemetry.get_statement_params_full_func_name(
685
+ inspect.currentframe(), Perceptron.__class__.__name__
686
+ ),
687
+ api_calls=[Session.call],
688
+ custom_tags={"autogen": True} if self._autogenerated else None,
689
+ )
690
+ if output_cols_prefix == "fit_predict_":
691
+ if hasattr(self._sklearn_object, "n_clusters"):
692
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
693
+ num_examples = self._sklearn_object.n_clusters
694
+ elif hasattr(self._sklearn_object, "min_samples"):
695
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
696
+ num_examples = self._sklearn_object.min_samples
697
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
698
+ # LocalOutlierFactor expects n_neighbors <= n_samples
699
+ num_examples = self._sklearn_object.n_neighbors
700
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
701
+ else:
702
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
666
703
 
667
704
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
668
705
  # seen during the fit.
@@ -674,12 +711,14 @@ class Perceptron(BaseTransformer):
674
711
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
675
712
  if self.sample_weight_col:
676
713
  output_df_columns_set -= set(self.sample_weight_col)
714
+
677
715
  # if the dimension of inferred output column names is correct; use it
678
716
  if len(expected_output_cols_list) == len(output_df_columns_set):
679
- return expected_output_cols_list
717
+ return expected_output_cols_list, output_df_pd
680
718
  # otherwise, use the sklearn estimator's output
681
719
  else:
682
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
720
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
721
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
683
722
 
684
723
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
685
724
  @telemetry.send_api_usage_telemetry(
@@ -725,7 +764,7 @@ class Perceptron(BaseTransformer):
725
764
  drop_input_cols=self._drop_input_cols,
726
765
  expected_output_cols_type="float",
727
766
  )
728
- expected_output_cols = self._align_expected_output_names(
767
+ expected_output_cols, _ = self._align_expected_output(
729
768
  inference_method, dataset, expected_output_cols, output_cols_prefix
730
769
  )
731
770
 
@@ -791,7 +830,7 @@ class Perceptron(BaseTransformer):
791
830
  drop_input_cols=self._drop_input_cols,
792
831
  expected_output_cols_type="float",
793
832
  )
794
- expected_output_cols = self._align_expected_output_names(
833
+ expected_output_cols, _ = self._align_expected_output(
795
834
  inference_method, dataset, expected_output_cols, output_cols_prefix
796
835
  )
797
836
  elif isinstance(dataset, pd.DataFrame):
@@ -856,7 +895,7 @@ class Perceptron(BaseTransformer):
856
895
  drop_input_cols=self._drop_input_cols,
857
896
  expected_output_cols_type="float",
858
897
  )
859
- expected_output_cols = self._align_expected_output_names(
898
+ expected_output_cols, _ = self._align_expected_output(
860
899
  inference_method, dataset, expected_output_cols, output_cols_prefix
861
900
  )
862
901
 
@@ -921,7 +960,7 @@ class Perceptron(BaseTransformer):
921
960
  drop_input_cols = self._drop_input_cols,
922
961
  expected_output_cols_type="float",
923
962
  )
924
- expected_output_cols = self._align_expected_output_names(
963
+ expected_output_cols, _ = self._align_expected_output(
925
964
  inference_method, dataset, expected_output_cols, output_cols_prefix
926
965
  )
927
966
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -522,12 +519,23 @@ class PoissonRegressor(BaseTransformer):
522
519
  autogenerated=self._autogenerated,
523
520
  subproject=_SUBPROJECT,
524
521
  )
525
- output_result, fitted_estimator = model_trainer.train_fit_predict(
526
- drop_input_cols=self._drop_input_cols,
527
- expected_output_cols_list=(
528
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
529
- ),
522
+ expected_output_cols = (
523
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
530
524
  )
525
+ if isinstance(dataset, DataFrame):
526
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
527
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
528
+ )
529
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
530
+ drop_input_cols=self._drop_input_cols,
531
+ expected_output_cols_list=expected_output_cols,
532
+ example_output_pd_df=example_output_pd_df,
533
+ )
534
+ else:
535
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
536
+ drop_input_cols=self._drop_input_cols,
537
+ expected_output_cols_list=expected_output_cols,
538
+ )
531
539
  self._sklearn_object = fitted_estimator
532
540
  self._is_fitted = True
533
541
  return output_result
@@ -606,12 +614,41 @@ class PoissonRegressor(BaseTransformer):
606
614
 
607
615
  return rv
608
616
 
609
- def _align_expected_output_names(
610
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
611
- ) -> List[str]:
617
+ def _align_expected_output(
618
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
619
+ ) -> Tuple[List[str], pd.DataFrame]:
620
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
621
+ and output dataframe with 1 line.
622
+ If the method is fit_predict, run 2 lines of data.
623
+ """
612
624
  # in case the inferred output column names dimension is different
613
625
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
614
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
626
+
627
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
628
+ # so change the minimum of number of rows to 2
629
+ num_examples = 2
630
+ statement_params = telemetry.get_function_usage_statement_params(
631
+ project=_PROJECT,
632
+ subproject=_SUBPROJECT,
633
+ function_name=telemetry.get_statement_params_full_func_name(
634
+ inspect.currentframe(), PoissonRegressor.__class__.__name__
635
+ ),
636
+ api_calls=[Session.call],
637
+ custom_tags={"autogen": True} if self._autogenerated else None,
638
+ )
639
+ if output_cols_prefix == "fit_predict_":
640
+ if hasattr(self._sklearn_object, "n_clusters"):
641
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
642
+ num_examples = self._sklearn_object.n_clusters
643
+ elif hasattr(self._sklearn_object, "min_samples"):
644
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
645
+ num_examples = self._sklearn_object.min_samples
646
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
647
+ # LocalOutlierFactor expects n_neighbors <= n_samples
648
+ num_examples = self._sklearn_object.n_neighbors
649
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
650
+ else:
651
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
615
652
 
616
653
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
617
654
  # seen during the fit.
@@ -623,12 +660,14 @@ class PoissonRegressor(BaseTransformer):
623
660
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
624
661
  if self.sample_weight_col:
625
662
  output_df_columns_set -= set(self.sample_weight_col)
663
+
626
664
  # if the dimension of inferred output column names is correct; use it
627
665
  if len(expected_output_cols_list) == len(output_df_columns_set):
628
- return expected_output_cols_list
666
+ return expected_output_cols_list, output_df_pd
629
667
  # otherwise, use the sklearn estimator's output
630
668
  else:
631
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
669
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
670
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
632
671
 
633
672
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
634
673
  @telemetry.send_api_usage_telemetry(
@@ -674,7 +713,7 @@ class PoissonRegressor(BaseTransformer):
674
713
  drop_input_cols=self._drop_input_cols,
675
714
  expected_output_cols_type="float",
676
715
  )
677
- expected_output_cols = self._align_expected_output_names(
716
+ expected_output_cols, _ = self._align_expected_output(
678
717
  inference_method, dataset, expected_output_cols, output_cols_prefix
679
718
  )
680
719
 
@@ -740,7 +779,7 @@ class PoissonRegressor(BaseTransformer):
740
779
  drop_input_cols=self._drop_input_cols,
741
780
  expected_output_cols_type="float",
742
781
  )
743
- expected_output_cols = self._align_expected_output_names(
782
+ expected_output_cols, _ = self._align_expected_output(
744
783
  inference_method, dataset, expected_output_cols, output_cols_prefix
745
784
  )
746
785
  elif isinstance(dataset, pd.DataFrame):
@@ -803,7 +842,7 @@ class PoissonRegressor(BaseTransformer):
803
842
  drop_input_cols=self._drop_input_cols,
804
843
  expected_output_cols_type="float",
805
844
  )
806
- expected_output_cols = self._align_expected_output_names(
845
+ expected_output_cols, _ = self._align_expected_output(
807
846
  inference_method, dataset, expected_output_cols, output_cols_prefix
808
847
  )
809
848
 
@@ -868,7 +907,7 @@ class PoissonRegressor(BaseTransformer):
868
907
  drop_input_cols = self._drop_input_cols,
869
908
  expected_output_cols_type="float",
870
909
  )
871
- expected_output_cols = self._align_expected_output_names(
910
+ expected_output_cols, _ = self._align_expected_output(
872
911
  inference_method, dataset, expected_output_cols, output_cols_prefix
873
912
  )
874
913
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -578,12 +575,23 @@ class RANSACRegressor(BaseTransformer):
578
575
  autogenerated=self._autogenerated,
579
576
  subproject=_SUBPROJECT,
580
577
  )
581
- output_result, fitted_estimator = model_trainer.train_fit_predict(
582
- drop_input_cols=self._drop_input_cols,
583
- expected_output_cols_list=(
584
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
585
- ),
578
+ expected_output_cols = (
579
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
586
580
  )
581
+ if isinstance(dataset, DataFrame):
582
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
583
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
584
+ )
585
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
586
+ drop_input_cols=self._drop_input_cols,
587
+ expected_output_cols_list=expected_output_cols,
588
+ example_output_pd_df=example_output_pd_df,
589
+ )
590
+ else:
591
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
592
+ drop_input_cols=self._drop_input_cols,
593
+ expected_output_cols_list=expected_output_cols,
594
+ )
587
595
  self._sklearn_object = fitted_estimator
588
596
  self._is_fitted = True
589
597
  return output_result
@@ -662,12 +670,41 @@ class RANSACRegressor(BaseTransformer):
662
670
 
663
671
  return rv
664
672
 
665
- def _align_expected_output_names(
666
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
667
- ) -> List[str]:
673
+ def _align_expected_output(
674
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
675
+ ) -> Tuple[List[str], pd.DataFrame]:
676
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
677
+ and output dataframe with 1 line.
678
+ If the method is fit_predict, run 2 lines of data.
679
+ """
668
680
  # in case the inferred output column names dimension is different
669
681
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
670
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
682
+
683
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
684
+ # so change the minimum of number of rows to 2
685
+ num_examples = 2
686
+ statement_params = telemetry.get_function_usage_statement_params(
687
+ project=_PROJECT,
688
+ subproject=_SUBPROJECT,
689
+ function_name=telemetry.get_statement_params_full_func_name(
690
+ inspect.currentframe(), RANSACRegressor.__class__.__name__
691
+ ),
692
+ api_calls=[Session.call],
693
+ custom_tags={"autogen": True} if self._autogenerated else None,
694
+ )
695
+ if output_cols_prefix == "fit_predict_":
696
+ if hasattr(self._sklearn_object, "n_clusters"):
697
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
698
+ num_examples = self._sklearn_object.n_clusters
699
+ elif hasattr(self._sklearn_object, "min_samples"):
700
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
701
+ num_examples = self._sklearn_object.min_samples
702
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
703
+ # LocalOutlierFactor expects n_neighbors <= n_samples
704
+ num_examples = self._sklearn_object.n_neighbors
705
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
706
+ else:
707
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
671
708
 
672
709
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
673
710
  # seen during the fit.
@@ -679,12 +716,14 @@ class RANSACRegressor(BaseTransformer):
679
716
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
680
717
  if self.sample_weight_col:
681
718
  output_df_columns_set -= set(self.sample_weight_col)
719
+
682
720
  # if the dimension of inferred output column names is correct; use it
683
721
  if len(expected_output_cols_list) == len(output_df_columns_set):
684
- return expected_output_cols_list
722
+ return expected_output_cols_list, output_df_pd
685
723
  # otherwise, use the sklearn estimator's output
686
724
  else:
687
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
725
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
726
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
688
727
 
689
728
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
690
729
  @telemetry.send_api_usage_telemetry(
@@ -730,7 +769,7 @@ class RANSACRegressor(BaseTransformer):
730
769
  drop_input_cols=self._drop_input_cols,
731
770
  expected_output_cols_type="float",
732
771
  )
733
- expected_output_cols = self._align_expected_output_names(
772
+ expected_output_cols, _ = self._align_expected_output(
734
773
  inference_method, dataset, expected_output_cols, output_cols_prefix
735
774
  )
736
775
 
@@ -796,7 +835,7 @@ class RANSACRegressor(BaseTransformer):
796
835
  drop_input_cols=self._drop_input_cols,
797
836
  expected_output_cols_type="float",
798
837
  )
799
- expected_output_cols = self._align_expected_output_names(
838
+ expected_output_cols, _ = self._align_expected_output(
800
839
  inference_method, dataset, expected_output_cols, output_cols_prefix
801
840
  )
802
841
  elif isinstance(dataset, pd.DataFrame):
@@ -859,7 +898,7 @@ class RANSACRegressor(BaseTransformer):
859
898
  drop_input_cols=self._drop_input_cols,
860
899
  expected_output_cols_type="float",
861
900
  )
862
- expected_output_cols = self._align_expected_output_names(
901
+ expected_output_cols, _ = self._align_expected_output(
863
902
  inference_method, dataset, expected_output_cols, output_cols_prefix
864
903
  )
865
904
 
@@ -924,7 +963,7 @@ class RANSACRegressor(BaseTransformer):
924
963
  drop_input_cols = self._drop_input_cols,
925
964
  expected_output_cols_type="float",
926
965
  )
927
- expected_output_cols = self._align_expected_output_names(
966
+ expected_output_cols, _ = self._align_expected_output(
928
967
  inference_method, dataset, expected_output_cols, output_cols_prefix
929
968
  )
930
969