snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252) hide show
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -536,12 +533,23 @@ class SVR(BaseTransformer):
536
533
  autogenerated=self._autogenerated,
537
534
  subproject=_SUBPROJECT,
538
535
  )
539
- output_result, fitted_estimator = model_trainer.train_fit_predict(
540
- drop_input_cols=self._drop_input_cols,
541
- expected_output_cols_list=(
542
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
543
- ),
536
+ expected_output_cols = (
537
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
544
538
  )
539
+ if isinstance(dataset, DataFrame):
540
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
541
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
542
+ )
543
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
544
+ drop_input_cols=self._drop_input_cols,
545
+ expected_output_cols_list=expected_output_cols,
546
+ example_output_pd_df=example_output_pd_df,
547
+ )
548
+ else:
549
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
550
+ drop_input_cols=self._drop_input_cols,
551
+ expected_output_cols_list=expected_output_cols,
552
+ )
545
553
  self._sklearn_object = fitted_estimator
546
554
  self._is_fitted = True
547
555
  return output_result
@@ -620,12 +628,41 @@ class SVR(BaseTransformer):
620
628
 
621
629
  return rv
622
630
 
623
- def _align_expected_output_names(
624
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
625
- ) -> List[str]:
631
+ def _align_expected_output(
632
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
633
+ ) -> Tuple[List[str], pd.DataFrame]:
634
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
635
+ and output dataframe with 1 line.
636
+ If the method is fit_predict, run 2 lines of data.
637
+ """
626
638
  # in case the inferred output column names dimension is different
627
639
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
628
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
640
+
641
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
642
+ # so change the minimum of number of rows to 2
643
+ num_examples = 2
644
+ statement_params = telemetry.get_function_usage_statement_params(
645
+ project=_PROJECT,
646
+ subproject=_SUBPROJECT,
647
+ function_name=telemetry.get_statement_params_full_func_name(
648
+ inspect.currentframe(), SVR.__class__.__name__
649
+ ),
650
+ api_calls=[Session.call],
651
+ custom_tags={"autogen": True} if self._autogenerated else None,
652
+ )
653
+ if output_cols_prefix == "fit_predict_":
654
+ if hasattr(self._sklearn_object, "n_clusters"):
655
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
656
+ num_examples = self._sklearn_object.n_clusters
657
+ elif hasattr(self._sklearn_object, "min_samples"):
658
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
659
+ num_examples = self._sklearn_object.min_samples
660
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
661
+ # LocalOutlierFactor expects n_neighbors <= n_samples
662
+ num_examples = self._sklearn_object.n_neighbors
663
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
664
+ else:
665
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
629
666
 
630
667
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
631
668
  # seen during the fit.
@@ -637,12 +674,14 @@ class SVR(BaseTransformer):
637
674
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
638
675
  if self.sample_weight_col:
639
676
  output_df_columns_set -= set(self.sample_weight_col)
677
+
640
678
  # if the dimension of inferred output column names is correct; use it
641
679
  if len(expected_output_cols_list) == len(output_df_columns_set):
642
- return expected_output_cols_list
680
+ return expected_output_cols_list, output_df_pd
643
681
  # otherwise, use the sklearn estimator's output
644
682
  else:
645
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
683
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
684
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
646
685
 
647
686
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
648
687
  @telemetry.send_api_usage_telemetry(
@@ -688,7 +727,7 @@ class SVR(BaseTransformer):
688
727
  drop_input_cols=self._drop_input_cols,
689
728
  expected_output_cols_type="float",
690
729
  )
691
- expected_output_cols = self._align_expected_output_names(
730
+ expected_output_cols, _ = self._align_expected_output(
692
731
  inference_method, dataset, expected_output_cols, output_cols_prefix
693
732
  )
694
733
 
@@ -754,7 +793,7 @@ class SVR(BaseTransformer):
754
793
  drop_input_cols=self._drop_input_cols,
755
794
  expected_output_cols_type="float",
756
795
  )
757
- expected_output_cols = self._align_expected_output_names(
796
+ expected_output_cols, _ = self._align_expected_output(
758
797
  inference_method, dataset, expected_output_cols, output_cols_prefix
759
798
  )
760
799
  elif isinstance(dataset, pd.DataFrame):
@@ -817,7 +856,7 @@ class SVR(BaseTransformer):
817
856
  drop_input_cols=self._drop_input_cols,
818
857
  expected_output_cols_type="float",
819
858
  )
820
- expected_output_cols = self._align_expected_output_names(
859
+ expected_output_cols, _ = self._align_expected_output(
821
860
  inference_method, dataset, expected_output_cols, output_cols_prefix
822
861
  )
823
862
 
@@ -882,7 +921,7 @@ class SVR(BaseTransformer):
882
921
  drop_input_cols = self._drop_input_cols,
883
922
  expected_output_cols_type="float",
884
923
  )
885
- expected_output_cols = self._align_expected_output_names(
924
+ expected_output_cols, _ = self._align_expected_output(
886
925
  inference_method, dataset, expected_output_cols, output_cols_prefix
887
926
  )
888
927
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -603,12 +600,23 @@ class DecisionTreeClassifier(BaseTransformer):
603
600
  autogenerated=self._autogenerated,
604
601
  subproject=_SUBPROJECT,
605
602
  )
606
- output_result, fitted_estimator = model_trainer.train_fit_predict(
607
- drop_input_cols=self._drop_input_cols,
608
- expected_output_cols_list=(
609
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
610
- ),
603
+ expected_output_cols = (
604
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
611
605
  )
606
+ if isinstance(dataset, DataFrame):
607
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
608
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
609
+ )
610
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
611
+ drop_input_cols=self._drop_input_cols,
612
+ expected_output_cols_list=expected_output_cols,
613
+ example_output_pd_df=example_output_pd_df,
614
+ )
615
+ else:
616
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
617
+ drop_input_cols=self._drop_input_cols,
618
+ expected_output_cols_list=expected_output_cols,
619
+ )
612
620
  self._sklearn_object = fitted_estimator
613
621
  self._is_fitted = True
614
622
  return output_result
@@ -687,12 +695,41 @@ class DecisionTreeClassifier(BaseTransformer):
687
695
 
688
696
  return rv
689
697
 
690
- def _align_expected_output_names(
691
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
692
- ) -> List[str]:
698
+ def _align_expected_output(
699
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
700
+ ) -> Tuple[List[str], pd.DataFrame]:
701
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
702
+ and output dataframe with 1 line.
703
+ If the method is fit_predict, run 2 lines of data.
704
+ """
693
705
  # in case the inferred output column names dimension is different
694
706
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
695
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
707
+
708
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
709
+ # so change the minimum of number of rows to 2
710
+ num_examples = 2
711
+ statement_params = telemetry.get_function_usage_statement_params(
712
+ project=_PROJECT,
713
+ subproject=_SUBPROJECT,
714
+ function_name=telemetry.get_statement_params_full_func_name(
715
+ inspect.currentframe(), DecisionTreeClassifier.__class__.__name__
716
+ ),
717
+ api_calls=[Session.call],
718
+ custom_tags={"autogen": True} if self._autogenerated else None,
719
+ )
720
+ if output_cols_prefix == "fit_predict_":
721
+ if hasattr(self._sklearn_object, "n_clusters"):
722
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
723
+ num_examples = self._sklearn_object.n_clusters
724
+ elif hasattr(self._sklearn_object, "min_samples"):
725
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
726
+ num_examples = self._sklearn_object.min_samples
727
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
728
+ # LocalOutlierFactor expects n_neighbors <= n_samples
729
+ num_examples = self._sklearn_object.n_neighbors
730
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
731
+ else:
732
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
696
733
 
697
734
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
698
735
  # seen during the fit.
@@ -704,12 +741,14 @@ class DecisionTreeClassifier(BaseTransformer):
704
741
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
705
742
  if self.sample_weight_col:
706
743
  output_df_columns_set -= set(self.sample_weight_col)
744
+
707
745
  # if the dimension of inferred output column names is correct; use it
708
746
  if len(expected_output_cols_list) == len(output_df_columns_set):
709
- return expected_output_cols_list
747
+ return expected_output_cols_list, output_df_pd
710
748
  # otherwise, use the sklearn estimator's output
711
749
  else:
712
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
750
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
751
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
713
752
 
714
753
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
715
754
  @telemetry.send_api_usage_telemetry(
@@ -757,7 +796,7 @@ class DecisionTreeClassifier(BaseTransformer):
757
796
  drop_input_cols=self._drop_input_cols,
758
797
  expected_output_cols_type="float",
759
798
  )
760
- expected_output_cols = self._align_expected_output_names(
799
+ expected_output_cols, _ = self._align_expected_output(
761
800
  inference_method, dataset, expected_output_cols, output_cols_prefix
762
801
  )
763
802
 
@@ -825,7 +864,7 @@ class DecisionTreeClassifier(BaseTransformer):
825
864
  drop_input_cols=self._drop_input_cols,
826
865
  expected_output_cols_type="float",
827
866
  )
828
- expected_output_cols = self._align_expected_output_names(
867
+ expected_output_cols, _ = self._align_expected_output(
829
868
  inference_method, dataset, expected_output_cols, output_cols_prefix
830
869
  )
831
870
  elif isinstance(dataset, pd.DataFrame):
@@ -888,7 +927,7 @@ class DecisionTreeClassifier(BaseTransformer):
888
927
  drop_input_cols=self._drop_input_cols,
889
928
  expected_output_cols_type="float",
890
929
  )
891
- expected_output_cols = self._align_expected_output_names(
930
+ expected_output_cols, _ = self._align_expected_output(
892
931
  inference_method, dataset, expected_output_cols, output_cols_prefix
893
932
  )
894
933
 
@@ -953,7 +992,7 @@ class DecisionTreeClassifier(BaseTransformer):
953
992
  drop_input_cols = self._drop_input_cols,
954
993
  expected_output_cols_type="float",
955
994
  )
956
- expected_output_cols = self._align_expected_output_names(
995
+ expected_output_cols, _ = self._align_expected_output(
957
996
  inference_method, dataset, expected_output_cols, output_cols_prefix
958
997
  )
959
998
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -585,12 +582,23 @@ class DecisionTreeRegressor(BaseTransformer):
585
582
  autogenerated=self._autogenerated,
586
583
  subproject=_SUBPROJECT,
587
584
  )
588
- output_result, fitted_estimator = model_trainer.train_fit_predict(
589
- drop_input_cols=self._drop_input_cols,
590
- expected_output_cols_list=(
591
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
592
- ),
585
+ expected_output_cols = (
586
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
593
587
  )
588
+ if isinstance(dataset, DataFrame):
589
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
590
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
591
+ )
592
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
593
+ drop_input_cols=self._drop_input_cols,
594
+ expected_output_cols_list=expected_output_cols,
595
+ example_output_pd_df=example_output_pd_df,
596
+ )
597
+ else:
598
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
599
+ drop_input_cols=self._drop_input_cols,
600
+ expected_output_cols_list=expected_output_cols,
601
+ )
594
602
  self._sklearn_object = fitted_estimator
595
603
  self._is_fitted = True
596
604
  return output_result
@@ -669,12 +677,41 @@ class DecisionTreeRegressor(BaseTransformer):
669
677
 
670
678
  return rv
671
679
 
672
- def _align_expected_output_names(
673
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
674
- ) -> List[str]:
680
+ def _align_expected_output(
681
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
682
+ ) -> Tuple[List[str], pd.DataFrame]:
683
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
684
+ and output dataframe with 1 line.
685
+ If the method is fit_predict, run 2 lines of data.
686
+ """
675
687
  # in case the inferred output column names dimension is different
676
688
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
677
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
689
+
690
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
691
+ # so change the minimum of number of rows to 2
692
+ num_examples = 2
693
+ statement_params = telemetry.get_function_usage_statement_params(
694
+ project=_PROJECT,
695
+ subproject=_SUBPROJECT,
696
+ function_name=telemetry.get_statement_params_full_func_name(
697
+ inspect.currentframe(), DecisionTreeRegressor.__class__.__name__
698
+ ),
699
+ api_calls=[Session.call],
700
+ custom_tags={"autogen": True} if self._autogenerated else None,
701
+ )
702
+ if output_cols_prefix == "fit_predict_":
703
+ if hasattr(self._sklearn_object, "n_clusters"):
704
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
705
+ num_examples = self._sklearn_object.n_clusters
706
+ elif hasattr(self._sklearn_object, "min_samples"):
707
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
708
+ num_examples = self._sklearn_object.min_samples
709
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
710
+ # LocalOutlierFactor expects n_neighbors <= n_samples
711
+ num_examples = self._sklearn_object.n_neighbors
712
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
713
+ else:
714
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
678
715
 
679
716
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
680
717
  # seen during the fit.
@@ -686,12 +723,14 @@ class DecisionTreeRegressor(BaseTransformer):
686
723
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
687
724
  if self.sample_weight_col:
688
725
  output_df_columns_set -= set(self.sample_weight_col)
726
+
689
727
  # if the dimension of inferred output column names is correct; use it
690
728
  if len(expected_output_cols_list) == len(output_df_columns_set):
691
- return expected_output_cols_list
729
+ return expected_output_cols_list, output_df_pd
692
730
  # otherwise, use the sklearn estimator's output
693
731
  else:
694
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
732
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
733
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
695
734
 
696
735
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
697
736
  @telemetry.send_api_usage_telemetry(
@@ -737,7 +776,7 @@ class DecisionTreeRegressor(BaseTransformer):
737
776
  drop_input_cols=self._drop_input_cols,
738
777
  expected_output_cols_type="float",
739
778
  )
740
- expected_output_cols = self._align_expected_output_names(
779
+ expected_output_cols, _ = self._align_expected_output(
741
780
  inference_method, dataset, expected_output_cols, output_cols_prefix
742
781
  )
743
782
 
@@ -803,7 +842,7 @@ class DecisionTreeRegressor(BaseTransformer):
803
842
  drop_input_cols=self._drop_input_cols,
804
843
  expected_output_cols_type="float",
805
844
  )
806
- expected_output_cols = self._align_expected_output_names(
845
+ expected_output_cols, _ = self._align_expected_output(
807
846
  inference_method, dataset, expected_output_cols, output_cols_prefix
808
847
  )
809
848
  elif isinstance(dataset, pd.DataFrame):
@@ -866,7 +905,7 @@ class DecisionTreeRegressor(BaseTransformer):
866
905
  drop_input_cols=self._drop_input_cols,
867
906
  expected_output_cols_type="float",
868
907
  )
869
- expected_output_cols = self._align_expected_output_names(
908
+ expected_output_cols, _ = self._align_expected_output(
870
909
  inference_method, dataset, expected_output_cols, output_cols_prefix
871
910
  )
872
911
 
@@ -931,7 +970,7 @@ class DecisionTreeRegressor(BaseTransformer):
931
970
  drop_input_cols = self._drop_input_cols,
932
971
  expected_output_cols_type="float",
933
972
  )
934
- expected_output_cols = self._align_expected_output_names(
973
+ expected_output_cols, _ = self._align_expected_output(
935
974
  inference_method, dataset, expected_output_cols, output_cols_prefix
936
975
  )
937
976
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -595,12 +592,23 @@ class ExtraTreeClassifier(BaseTransformer):
595
592
  autogenerated=self._autogenerated,
596
593
  subproject=_SUBPROJECT,
597
594
  )
598
- output_result, fitted_estimator = model_trainer.train_fit_predict(
599
- drop_input_cols=self._drop_input_cols,
600
- expected_output_cols_list=(
601
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
602
- ),
595
+ expected_output_cols = (
596
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
603
597
  )
598
+ if isinstance(dataset, DataFrame):
599
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
600
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
601
+ )
602
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
603
+ drop_input_cols=self._drop_input_cols,
604
+ expected_output_cols_list=expected_output_cols,
605
+ example_output_pd_df=example_output_pd_df,
606
+ )
607
+ else:
608
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
609
+ drop_input_cols=self._drop_input_cols,
610
+ expected_output_cols_list=expected_output_cols,
611
+ )
604
612
  self._sklearn_object = fitted_estimator
605
613
  self._is_fitted = True
606
614
  return output_result
@@ -679,12 +687,41 @@ class ExtraTreeClassifier(BaseTransformer):
679
687
 
680
688
  return rv
681
689
 
682
- def _align_expected_output_names(
683
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
684
- ) -> List[str]:
690
+ def _align_expected_output(
691
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
692
+ ) -> Tuple[List[str], pd.DataFrame]:
693
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
694
+ and output dataframe with 1 line.
695
+ If the method is fit_predict, run 2 lines of data.
696
+ """
685
697
  # in case the inferred output column names dimension is different
686
698
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
687
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
699
+
700
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
701
+ # so change the minimum of number of rows to 2
702
+ num_examples = 2
703
+ statement_params = telemetry.get_function_usage_statement_params(
704
+ project=_PROJECT,
705
+ subproject=_SUBPROJECT,
706
+ function_name=telemetry.get_statement_params_full_func_name(
707
+ inspect.currentframe(), ExtraTreeClassifier.__class__.__name__
708
+ ),
709
+ api_calls=[Session.call],
710
+ custom_tags={"autogen": True} if self._autogenerated else None,
711
+ )
712
+ if output_cols_prefix == "fit_predict_":
713
+ if hasattr(self._sklearn_object, "n_clusters"):
714
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
715
+ num_examples = self._sklearn_object.n_clusters
716
+ elif hasattr(self._sklearn_object, "min_samples"):
717
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
718
+ num_examples = self._sklearn_object.min_samples
719
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
720
+ # LocalOutlierFactor expects n_neighbors <= n_samples
721
+ num_examples = self._sklearn_object.n_neighbors
722
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
723
+ else:
724
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
688
725
 
689
726
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
690
727
  # seen during the fit.
@@ -696,12 +733,14 @@ class ExtraTreeClassifier(BaseTransformer):
696
733
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
697
734
  if self.sample_weight_col:
698
735
  output_df_columns_set -= set(self.sample_weight_col)
736
+
699
737
  # if the dimension of inferred output column names is correct; use it
700
738
  if len(expected_output_cols_list) == len(output_df_columns_set):
701
- return expected_output_cols_list
739
+ return expected_output_cols_list, output_df_pd
702
740
  # otherwise, use the sklearn estimator's output
703
741
  else:
704
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
742
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
743
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
705
744
 
706
745
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
707
746
  @telemetry.send_api_usage_telemetry(
@@ -749,7 +788,7 @@ class ExtraTreeClassifier(BaseTransformer):
749
788
  drop_input_cols=self._drop_input_cols,
750
789
  expected_output_cols_type="float",
751
790
  )
752
- expected_output_cols = self._align_expected_output_names(
791
+ expected_output_cols, _ = self._align_expected_output(
753
792
  inference_method, dataset, expected_output_cols, output_cols_prefix
754
793
  )
755
794
 
@@ -817,7 +856,7 @@ class ExtraTreeClassifier(BaseTransformer):
817
856
  drop_input_cols=self._drop_input_cols,
818
857
  expected_output_cols_type="float",
819
858
  )
820
- expected_output_cols = self._align_expected_output_names(
859
+ expected_output_cols, _ = self._align_expected_output(
821
860
  inference_method, dataset, expected_output_cols, output_cols_prefix
822
861
  )
823
862
  elif isinstance(dataset, pd.DataFrame):
@@ -880,7 +919,7 @@ class ExtraTreeClassifier(BaseTransformer):
880
919
  drop_input_cols=self._drop_input_cols,
881
920
  expected_output_cols_type="float",
882
921
  )
883
- expected_output_cols = self._align_expected_output_names(
922
+ expected_output_cols, _ = self._align_expected_output(
884
923
  inference_method, dataset, expected_output_cols, output_cols_prefix
885
924
  )
886
925
 
@@ -945,7 +984,7 @@ class ExtraTreeClassifier(BaseTransformer):
945
984
  drop_input_cols = self._drop_input_cols,
946
985
  expected_output_cols_type="float",
947
986
  )
948
- expected_output_cols = self._align_expected_output_names(
987
+ expected_output_cols, _ = self._align_expected_output(
949
988
  inference_method, dataset, expected_output_cols, output_cols_prefix
950
989
  )
951
990