snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -517,12 +514,23 @@ class AdaBoostClassifier(BaseTransformer):
517
514
  autogenerated=self._autogenerated,
518
515
  subproject=_SUBPROJECT,
519
516
  )
520
- output_result, fitted_estimator = model_trainer.train_fit_predict(
521
- drop_input_cols=self._drop_input_cols,
522
- expected_output_cols_list=(
523
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
524
- ),
517
+ expected_output_cols = (
518
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
525
519
  )
520
+ if isinstance(dataset, DataFrame):
521
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
522
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
523
+ )
524
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
525
+ drop_input_cols=self._drop_input_cols,
526
+ expected_output_cols_list=expected_output_cols,
527
+ example_output_pd_df=example_output_pd_df,
528
+ )
529
+ else:
530
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
531
+ drop_input_cols=self._drop_input_cols,
532
+ expected_output_cols_list=expected_output_cols,
533
+ )
526
534
  self._sklearn_object = fitted_estimator
527
535
  self._is_fitted = True
528
536
  return output_result
@@ -601,12 +609,41 @@ class AdaBoostClassifier(BaseTransformer):
601
609
 
602
610
  return rv
603
611
 
604
- def _align_expected_output_names(
605
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
606
- ) -> List[str]:
612
+ def _align_expected_output(
613
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
614
+ ) -> Tuple[List[str], pd.DataFrame]:
615
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
616
+ and output dataframe with 1 line.
617
+ If the method is fit_predict, run 2 lines of data.
618
+ """
607
619
  # in case the inferred output column names dimension is different
608
620
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
609
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
621
+
622
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
623
+ # so change the minimum of number of rows to 2
624
+ num_examples = 2
625
+ statement_params = telemetry.get_function_usage_statement_params(
626
+ project=_PROJECT,
627
+ subproject=_SUBPROJECT,
628
+ function_name=telemetry.get_statement_params_full_func_name(
629
+ inspect.currentframe(), AdaBoostClassifier.__class__.__name__
630
+ ),
631
+ api_calls=[Session.call],
632
+ custom_tags={"autogen": True} if self._autogenerated else None,
633
+ )
634
+ if output_cols_prefix == "fit_predict_":
635
+ if hasattr(self._sklearn_object, "n_clusters"):
636
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
637
+ num_examples = self._sklearn_object.n_clusters
638
+ elif hasattr(self._sklearn_object, "min_samples"):
639
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
640
+ num_examples = self._sklearn_object.min_samples
641
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
642
+ # LocalOutlierFactor expects n_neighbors <= n_samples
643
+ num_examples = self._sklearn_object.n_neighbors
644
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
645
+ else:
646
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
610
647
 
611
648
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
612
649
  # seen during the fit.
@@ -618,12 +655,14 @@ class AdaBoostClassifier(BaseTransformer):
618
655
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
619
656
  if self.sample_weight_col:
620
657
  output_df_columns_set -= set(self.sample_weight_col)
658
+
621
659
  # if the dimension of inferred output column names is correct; use it
622
660
  if len(expected_output_cols_list) == len(output_df_columns_set):
623
- return expected_output_cols_list
661
+ return expected_output_cols_list, output_df_pd
624
662
  # otherwise, use the sklearn estimator's output
625
663
  else:
626
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
664
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
665
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
627
666
 
628
667
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
629
668
  @telemetry.send_api_usage_telemetry(
@@ -671,7 +710,7 @@ class AdaBoostClassifier(BaseTransformer):
671
710
  drop_input_cols=self._drop_input_cols,
672
711
  expected_output_cols_type="float",
673
712
  )
674
- expected_output_cols = self._align_expected_output_names(
713
+ expected_output_cols, _ = self._align_expected_output(
675
714
  inference_method, dataset, expected_output_cols, output_cols_prefix
676
715
  )
677
716
 
@@ -739,7 +778,7 @@ class AdaBoostClassifier(BaseTransformer):
739
778
  drop_input_cols=self._drop_input_cols,
740
779
  expected_output_cols_type="float",
741
780
  )
742
- expected_output_cols = self._align_expected_output_names(
781
+ expected_output_cols, _ = self._align_expected_output(
743
782
  inference_method, dataset, expected_output_cols, output_cols_prefix
744
783
  )
745
784
  elif isinstance(dataset, pd.DataFrame):
@@ -804,7 +843,7 @@ class AdaBoostClassifier(BaseTransformer):
804
843
  drop_input_cols=self._drop_input_cols,
805
844
  expected_output_cols_type="float",
806
845
  )
807
- expected_output_cols = self._align_expected_output_names(
846
+ expected_output_cols, _ = self._align_expected_output(
808
847
  inference_method, dataset, expected_output_cols, output_cols_prefix
809
848
  )
810
849
 
@@ -869,7 +908,7 @@ class AdaBoostClassifier(BaseTransformer):
869
908
  drop_input_cols = self._drop_input_cols,
870
909
  expected_output_cols_type="float",
871
910
  )
872
- expected_output_cols = self._align_expected_output_names(
911
+ expected_output_cols, _ = self._align_expected_output(
873
912
  inference_method, dataset, expected_output_cols, output_cols_prefix
874
913
  )
875
914
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -514,12 +511,23 @@ class AdaBoostRegressor(BaseTransformer):
514
511
  autogenerated=self._autogenerated,
515
512
  subproject=_SUBPROJECT,
516
513
  )
517
- output_result, fitted_estimator = model_trainer.train_fit_predict(
518
- drop_input_cols=self._drop_input_cols,
519
- expected_output_cols_list=(
520
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
521
- ),
514
+ expected_output_cols = (
515
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
522
516
  )
517
+ if isinstance(dataset, DataFrame):
518
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
519
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
520
+ )
521
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
522
+ drop_input_cols=self._drop_input_cols,
523
+ expected_output_cols_list=expected_output_cols,
524
+ example_output_pd_df=example_output_pd_df,
525
+ )
526
+ else:
527
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
528
+ drop_input_cols=self._drop_input_cols,
529
+ expected_output_cols_list=expected_output_cols,
530
+ )
523
531
  self._sklearn_object = fitted_estimator
524
532
  self._is_fitted = True
525
533
  return output_result
@@ -598,12 +606,41 @@ class AdaBoostRegressor(BaseTransformer):
598
606
 
599
607
  return rv
600
608
 
601
- def _align_expected_output_names(
602
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
603
- ) -> List[str]:
609
+ def _align_expected_output(
610
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
611
+ ) -> Tuple[List[str], pd.DataFrame]:
612
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
613
+ and output dataframe with 1 line.
614
+ If the method is fit_predict, run 2 lines of data.
615
+ """
604
616
  # in case the inferred output column names dimension is different
605
617
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
606
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
618
+
619
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
620
+ # so change the minimum of number of rows to 2
621
+ num_examples = 2
622
+ statement_params = telemetry.get_function_usage_statement_params(
623
+ project=_PROJECT,
624
+ subproject=_SUBPROJECT,
625
+ function_name=telemetry.get_statement_params_full_func_name(
626
+ inspect.currentframe(), AdaBoostRegressor.__class__.__name__
627
+ ),
628
+ api_calls=[Session.call],
629
+ custom_tags={"autogen": True} if self._autogenerated else None,
630
+ )
631
+ if output_cols_prefix == "fit_predict_":
632
+ if hasattr(self._sklearn_object, "n_clusters"):
633
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
634
+ num_examples = self._sklearn_object.n_clusters
635
+ elif hasattr(self._sklearn_object, "min_samples"):
636
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
637
+ num_examples = self._sklearn_object.min_samples
638
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
639
+ # LocalOutlierFactor expects n_neighbors <= n_samples
640
+ num_examples = self._sklearn_object.n_neighbors
641
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
642
+ else:
643
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
607
644
 
608
645
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
609
646
  # seen during the fit.
@@ -615,12 +652,14 @@ class AdaBoostRegressor(BaseTransformer):
615
652
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
616
653
  if self.sample_weight_col:
617
654
  output_df_columns_set -= set(self.sample_weight_col)
655
+
618
656
  # if the dimension of inferred output column names is correct; use it
619
657
  if len(expected_output_cols_list) == len(output_df_columns_set):
620
- return expected_output_cols_list
658
+ return expected_output_cols_list, output_df_pd
621
659
  # otherwise, use the sklearn estimator's output
622
660
  else:
623
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
661
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
662
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
624
663
 
625
664
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
626
665
  @telemetry.send_api_usage_telemetry(
@@ -666,7 +705,7 @@ class AdaBoostRegressor(BaseTransformer):
666
705
  drop_input_cols=self._drop_input_cols,
667
706
  expected_output_cols_type="float",
668
707
  )
669
- expected_output_cols = self._align_expected_output_names(
708
+ expected_output_cols, _ = self._align_expected_output(
670
709
  inference_method, dataset, expected_output_cols, output_cols_prefix
671
710
  )
672
711
 
@@ -732,7 +771,7 @@ class AdaBoostRegressor(BaseTransformer):
732
771
  drop_input_cols=self._drop_input_cols,
733
772
  expected_output_cols_type="float",
734
773
  )
735
- expected_output_cols = self._align_expected_output_names(
774
+ expected_output_cols, _ = self._align_expected_output(
736
775
  inference_method, dataset, expected_output_cols, output_cols_prefix
737
776
  )
738
777
  elif isinstance(dataset, pd.DataFrame):
@@ -795,7 +834,7 @@ class AdaBoostRegressor(BaseTransformer):
795
834
  drop_input_cols=self._drop_input_cols,
796
835
  expected_output_cols_type="float",
797
836
  )
798
- expected_output_cols = self._align_expected_output_names(
837
+ expected_output_cols, _ = self._align_expected_output(
799
838
  inference_method, dataset, expected_output_cols, output_cols_prefix
800
839
  )
801
840
 
@@ -860,7 +899,7 @@ class AdaBoostRegressor(BaseTransformer):
860
899
  drop_input_cols = self._drop_input_cols,
861
900
  expected_output_cols_type="float",
862
901
  )
863
- expected_output_cols = self._align_expected_output_names(
902
+ expected_output_cols, _ = self._align_expected_output(
864
903
  inference_method, dataset, expected_output_cols, output_cols_prefix
865
904
  )
866
905
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -549,12 +546,23 @@ class BaggingClassifier(BaseTransformer):
549
546
  autogenerated=self._autogenerated,
550
547
  subproject=_SUBPROJECT,
551
548
  )
552
- output_result, fitted_estimator = model_trainer.train_fit_predict(
553
- drop_input_cols=self._drop_input_cols,
554
- expected_output_cols_list=(
555
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
556
- ),
549
+ expected_output_cols = (
550
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
557
551
  )
552
+ if isinstance(dataset, DataFrame):
553
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
554
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
555
+ )
556
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
557
+ drop_input_cols=self._drop_input_cols,
558
+ expected_output_cols_list=expected_output_cols,
559
+ example_output_pd_df=example_output_pd_df,
560
+ )
561
+ else:
562
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
563
+ drop_input_cols=self._drop_input_cols,
564
+ expected_output_cols_list=expected_output_cols,
565
+ )
558
566
  self._sklearn_object = fitted_estimator
559
567
  self._is_fitted = True
560
568
  return output_result
@@ -633,12 +641,41 @@ class BaggingClassifier(BaseTransformer):
633
641
 
634
642
  return rv
635
643
 
636
- def _align_expected_output_names(
637
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
638
- ) -> List[str]:
644
+ def _align_expected_output(
645
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
646
+ ) -> Tuple[List[str], pd.DataFrame]:
647
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
648
+ and output dataframe with 1 line.
649
+ If the method is fit_predict, run 2 lines of data.
650
+ """
639
651
  # in case the inferred output column names dimension is different
640
652
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
641
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
653
+
654
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
655
+ # so change the minimum of number of rows to 2
656
+ num_examples = 2
657
+ statement_params = telemetry.get_function_usage_statement_params(
658
+ project=_PROJECT,
659
+ subproject=_SUBPROJECT,
660
+ function_name=telemetry.get_statement_params_full_func_name(
661
+ inspect.currentframe(), BaggingClassifier.__class__.__name__
662
+ ),
663
+ api_calls=[Session.call],
664
+ custom_tags={"autogen": True} if self._autogenerated else None,
665
+ )
666
+ if output_cols_prefix == "fit_predict_":
667
+ if hasattr(self._sklearn_object, "n_clusters"):
668
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
669
+ num_examples = self._sklearn_object.n_clusters
670
+ elif hasattr(self._sklearn_object, "min_samples"):
671
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
672
+ num_examples = self._sklearn_object.min_samples
673
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
674
+ # LocalOutlierFactor expects n_neighbors <= n_samples
675
+ num_examples = self._sklearn_object.n_neighbors
676
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
677
+ else:
678
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
642
679
 
643
680
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
644
681
  # seen during the fit.
@@ -650,12 +687,14 @@ class BaggingClassifier(BaseTransformer):
650
687
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
651
688
  if self.sample_weight_col:
652
689
  output_df_columns_set -= set(self.sample_weight_col)
690
+
653
691
  # if the dimension of inferred output column names is correct; use it
654
692
  if len(expected_output_cols_list) == len(output_df_columns_set):
655
- return expected_output_cols_list
693
+ return expected_output_cols_list, output_df_pd
656
694
  # otherwise, use the sklearn estimator's output
657
695
  else:
658
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
696
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
697
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
659
698
 
660
699
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
661
700
  @telemetry.send_api_usage_telemetry(
@@ -703,7 +742,7 @@ class BaggingClassifier(BaseTransformer):
703
742
  drop_input_cols=self._drop_input_cols,
704
743
  expected_output_cols_type="float",
705
744
  )
706
- expected_output_cols = self._align_expected_output_names(
745
+ expected_output_cols, _ = self._align_expected_output(
707
746
  inference_method, dataset, expected_output_cols, output_cols_prefix
708
747
  )
709
748
 
@@ -771,7 +810,7 @@ class BaggingClassifier(BaseTransformer):
771
810
  drop_input_cols=self._drop_input_cols,
772
811
  expected_output_cols_type="float",
773
812
  )
774
- expected_output_cols = self._align_expected_output_names(
813
+ expected_output_cols, _ = self._align_expected_output(
775
814
  inference_method, dataset, expected_output_cols, output_cols_prefix
776
815
  )
777
816
  elif isinstance(dataset, pd.DataFrame):
@@ -836,7 +875,7 @@ class BaggingClassifier(BaseTransformer):
836
875
  drop_input_cols=self._drop_input_cols,
837
876
  expected_output_cols_type="float",
838
877
  )
839
- expected_output_cols = self._align_expected_output_names(
878
+ expected_output_cols, _ = self._align_expected_output(
840
879
  inference_method, dataset, expected_output_cols, output_cols_prefix
841
880
  )
842
881
 
@@ -901,7 +940,7 @@ class BaggingClassifier(BaseTransformer):
901
940
  drop_input_cols = self._drop_input_cols,
902
941
  expected_output_cols_type="float",
903
942
  )
904
- expected_output_cols = self._align_expected_output_names(
943
+ expected_output_cols, _ = self._align_expected_output(
905
944
  inference_method, dataset, expected_output_cols, output_cols_prefix
906
945
  )
907
946
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -549,12 +546,23 @@ class BaggingRegressor(BaseTransformer):
549
546
  autogenerated=self._autogenerated,
550
547
  subproject=_SUBPROJECT,
551
548
  )
552
- output_result, fitted_estimator = model_trainer.train_fit_predict(
553
- drop_input_cols=self._drop_input_cols,
554
- expected_output_cols_list=(
555
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
556
- ),
549
+ expected_output_cols = (
550
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
557
551
  )
552
+ if isinstance(dataset, DataFrame):
553
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
554
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
555
+ )
556
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
557
+ drop_input_cols=self._drop_input_cols,
558
+ expected_output_cols_list=expected_output_cols,
559
+ example_output_pd_df=example_output_pd_df,
560
+ )
561
+ else:
562
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
563
+ drop_input_cols=self._drop_input_cols,
564
+ expected_output_cols_list=expected_output_cols,
565
+ )
558
566
  self._sklearn_object = fitted_estimator
559
567
  self._is_fitted = True
560
568
  return output_result
@@ -633,12 +641,41 @@ class BaggingRegressor(BaseTransformer):
633
641
 
634
642
  return rv
635
643
 
636
- def _align_expected_output_names(
637
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
638
- ) -> List[str]:
644
+ def _align_expected_output(
645
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
646
+ ) -> Tuple[List[str], pd.DataFrame]:
647
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
648
+ and output dataframe with 1 line.
649
+ If the method is fit_predict, run 2 lines of data.
650
+ """
639
651
  # in case the inferred output column names dimension is different
640
652
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
641
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
653
+
654
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
655
+ # so change the minimum of number of rows to 2
656
+ num_examples = 2
657
+ statement_params = telemetry.get_function_usage_statement_params(
658
+ project=_PROJECT,
659
+ subproject=_SUBPROJECT,
660
+ function_name=telemetry.get_statement_params_full_func_name(
661
+ inspect.currentframe(), BaggingRegressor.__class__.__name__
662
+ ),
663
+ api_calls=[Session.call],
664
+ custom_tags={"autogen": True} if self._autogenerated else None,
665
+ )
666
+ if output_cols_prefix == "fit_predict_":
667
+ if hasattr(self._sklearn_object, "n_clusters"):
668
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
669
+ num_examples = self._sklearn_object.n_clusters
670
+ elif hasattr(self._sklearn_object, "min_samples"):
671
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
672
+ num_examples = self._sklearn_object.min_samples
673
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
674
+ # LocalOutlierFactor expects n_neighbors <= n_samples
675
+ num_examples = self._sklearn_object.n_neighbors
676
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
677
+ else:
678
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
642
679
 
643
680
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
644
681
  # seen during the fit.
@@ -650,12 +687,14 @@ class BaggingRegressor(BaseTransformer):
650
687
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
651
688
  if self.sample_weight_col:
652
689
  output_df_columns_set -= set(self.sample_weight_col)
690
+
653
691
  # if the dimension of inferred output column names is correct; use it
654
692
  if len(expected_output_cols_list) == len(output_df_columns_set):
655
- return expected_output_cols_list
693
+ return expected_output_cols_list, output_df_pd
656
694
  # otherwise, use the sklearn estimator's output
657
695
  else:
658
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
696
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
697
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
659
698
 
660
699
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
661
700
  @telemetry.send_api_usage_telemetry(
@@ -701,7 +740,7 @@ class BaggingRegressor(BaseTransformer):
701
740
  drop_input_cols=self._drop_input_cols,
702
741
  expected_output_cols_type="float",
703
742
  )
704
- expected_output_cols = self._align_expected_output_names(
743
+ expected_output_cols, _ = self._align_expected_output(
705
744
  inference_method, dataset, expected_output_cols, output_cols_prefix
706
745
  )
707
746
 
@@ -767,7 +806,7 @@ class BaggingRegressor(BaseTransformer):
767
806
  drop_input_cols=self._drop_input_cols,
768
807
  expected_output_cols_type="float",
769
808
  )
770
- expected_output_cols = self._align_expected_output_names(
809
+ expected_output_cols, _ = self._align_expected_output(
771
810
  inference_method, dataset, expected_output_cols, output_cols_prefix
772
811
  )
773
812
  elif isinstance(dataset, pd.DataFrame):
@@ -830,7 +869,7 @@ class BaggingRegressor(BaseTransformer):
830
869
  drop_input_cols=self._drop_input_cols,
831
870
  expected_output_cols_type="float",
832
871
  )
833
- expected_output_cols = self._align_expected_output_names(
872
+ expected_output_cols, _ = self._align_expected_output(
834
873
  inference_method, dataset, expected_output_cols, output_cols_prefix
835
874
  )
836
875
 
@@ -895,7 +934,7 @@ class BaggingRegressor(BaseTransformer):
895
934
  drop_input_cols = self._drop_input_cols,
896
935
  expected_output_cols_type="float",
897
936
  )
898
- expected_output_cols = self._align_expected_output_names(
937
+ expected_output_cols, _ = self._align_expected_output(
899
938
  inference_method, dataset, expected_output_cols, output_cols_prefix
900
939
  )
901
940