snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -497,12 +494,23 @@ class PolynomialFeatures(BaseTransformer):
497
494
  autogenerated=self._autogenerated,
498
495
  subproject=_SUBPROJECT,
499
496
  )
500
- output_result, fitted_estimator = model_trainer.train_fit_predict(
501
- drop_input_cols=self._drop_input_cols,
502
- expected_output_cols_list=(
503
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
504
- ),
497
+ expected_output_cols = (
498
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
505
499
  )
500
+ if isinstance(dataset, DataFrame):
501
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
502
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
503
+ )
504
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
505
+ drop_input_cols=self._drop_input_cols,
506
+ expected_output_cols_list=expected_output_cols,
507
+ example_output_pd_df=example_output_pd_df,
508
+ )
509
+ else:
510
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
511
+ drop_input_cols=self._drop_input_cols,
512
+ expected_output_cols_list=expected_output_cols,
513
+ )
506
514
  self._sklearn_object = fitted_estimator
507
515
  self._is_fitted = True
508
516
  return output_result
@@ -583,12 +591,41 @@ class PolynomialFeatures(BaseTransformer):
583
591
 
584
592
  return rv
585
593
 
586
- def _align_expected_output_names(
587
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
588
- ) -> List[str]:
594
+ def _align_expected_output(
595
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
596
+ ) -> Tuple[List[str], pd.DataFrame]:
597
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
598
+ and output dataframe with 1 line.
599
+ If the method is fit_predict, run 2 lines of data.
600
+ """
589
601
  # in case the inferred output column names dimension is different
590
602
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
591
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
603
+
604
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
605
+ # so change the minimum of number of rows to 2
606
+ num_examples = 2
607
+ statement_params = telemetry.get_function_usage_statement_params(
608
+ project=_PROJECT,
609
+ subproject=_SUBPROJECT,
610
+ function_name=telemetry.get_statement_params_full_func_name(
611
+ inspect.currentframe(), PolynomialFeatures.__class__.__name__
612
+ ),
613
+ api_calls=[Session.call],
614
+ custom_tags={"autogen": True} if self._autogenerated else None,
615
+ )
616
+ if output_cols_prefix == "fit_predict_":
617
+ if hasattr(self._sklearn_object, "n_clusters"):
618
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
619
+ num_examples = self._sklearn_object.n_clusters
620
+ elif hasattr(self._sklearn_object, "min_samples"):
621
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
622
+ num_examples = self._sklearn_object.min_samples
623
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
624
+ # LocalOutlierFactor expects n_neighbors <= n_samples
625
+ num_examples = self._sklearn_object.n_neighbors
626
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
627
+ else:
628
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
592
629
 
593
630
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
594
631
  # seen during the fit.
@@ -600,12 +637,14 @@ class PolynomialFeatures(BaseTransformer):
600
637
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
601
638
  if self.sample_weight_col:
602
639
  output_df_columns_set -= set(self.sample_weight_col)
640
+
603
641
  # if the dimension of inferred output column names is correct; use it
604
642
  if len(expected_output_cols_list) == len(output_df_columns_set):
605
- return expected_output_cols_list
643
+ return expected_output_cols_list, output_df_pd
606
644
  # otherwise, use the sklearn estimator's output
607
645
  else:
608
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
646
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
647
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
609
648
 
610
649
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
611
650
  @telemetry.send_api_usage_telemetry(
@@ -651,7 +690,7 @@ class PolynomialFeatures(BaseTransformer):
651
690
  drop_input_cols=self._drop_input_cols,
652
691
  expected_output_cols_type="float",
653
692
  )
654
- expected_output_cols = self._align_expected_output_names(
693
+ expected_output_cols, _ = self._align_expected_output(
655
694
  inference_method, dataset, expected_output_cols, output_cols_prefix
656
695
  )
657
696
 
@@ -717,7 +756,7 @@ class PolynomialFeatures(BaseTransformer):
717
756
  drop_input_cols=self._drop_input_cols,
718
757
  expected_output_cols_type="float",
719
758
  )
720
- expected_output_cols = self._align_expected_output_names(
759
+ expected_output_cols, _ = self._align_expected_output(
721
760
  inference_method, dataset, expected_output_cols, output_cols_prefix
722
761
  )
723
762
  elif isinstance(dataset, pd.DataFrame):
@@ -780,7 +819,7 @@ class PolynomialFeatures(BaseTransformer):
780
819
  drop_input_cols=self._drop_input_cols,
781
820
  expected_output_cols_type="float",
782
821
  )
783
- expected_output_cols = self._align_expected_output_names(
822
+ expected_output_cols, _ = self._align_expected_output(
784
823
  inference_method, dataset, expected_output_cols, output_cols_prefix
785
824
  )
786
825
 
@@ -845,7 +884,7 @@ class PolynomialFeatures(BaseTransformer):
845
884
  drop_input_cols = self._drop_input_cols,
846
885
  expected_output_cols_type="float",
847
886
  )
848
- expected_output_cols = self._align_expected_output_names(
887
+ expected_output_cols, _ = self._align_expected_output(
849
888
  inference_method, dataset, expected_output_cols, output_cols_prefix
850
889
  )
851
890
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -501,12 +498,23 @@ class LabelPropagation(BaseTransformer):
501
498
  autogenerated=self._autogenerated,
502
499
  subproject=_SUBPROJECT,
503
500
  )
504
- output_result, fitted_estimator = model_trainer.train_fit_predict(
505
- drop_input_cols=self._drop_input_cols,
506
- expected_output_cols_list=(
507
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
508
- ),
501
+ expected_output_cols = (
502
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
509
503
  )
504
+ if isinstance(dataset, DataFrame):
505
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
506
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
507
+ )
508
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
509
+ drop_input_cols=self._drop_input_cols,
510
+ expected_output_cols_list=expected_output_cols,
511
+ example_output_pd_df=example_output_pd_df,
512
+ )
513
+ else:
514
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
515
+ drop_input_cols=self._drop_input_cols,
516
+ expected_output_cols_list=expected_output_cols,
517
+ )
510
518
  self._sklearn_object = fitted_estimator
511
519
  self._is_fitted = True
512
520
  return output_result
@@ -585,12 +593,41 @@ class LabelPropagation(BaseTransformer):
585
593
 
586
594
  return rv
587
595
 
588
- def _align_expected_output_names(
589
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
590
- ) -> List[str]:
596
+ def _align_expected_output(
597
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
598
+ ) -> Tuple[List[str], pd.DataFrame]:
599
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
600
+ and output dataframe with 1 line.
601
+ If the method is fit_predict, run 2 lines of data.
602
+ """
591
603
  # in case the inferred output column names dimension is different
592
604
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
593
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
605
+
606
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
607
+ # so change the minimum of number of rows to 2
608
+ num_examples = 2
609
+ statement_params = telemetry.get_function_usage_statement_params(
610
+ project=_PROJECT,
611
+ subproject=_SUBPROJECT,
612
+ function_name=telemetry.get_statement_params_full_func_name(
613
+ inspect.currentframe(), LabelPropagation.__class__.__name__
614
+ ),
615
+ api_calls=[Session.call],
616
+ custom_tags={"autogen": True} if self._autogenerated else None,
617
+ )
618
+ if output_cols_prefix == "fit_predict_":
619
+ if hasattr(self._sklearn_object, "n_clusters"):
620
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
621
+ num_examples = self._sklearn_object.n_clusters
622
+ elif hasattr(self._sklearn_object, "min_samples"):
623
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
624
+ num_examples = self._sklearn_object.min_samples
625
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
626
+ # LocalOutlierFactor expects n_neighbors <= n_samples
627
+ num_examples = self._sklearn_object.n_neighbors
628
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
629
+ else:
630
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
594
631
 
595
632
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
596
633
  # seen during the fit.
@@ -602,12 +639,14 @@ class LabelPropagation(BaseTransformer):
602
639
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
603
640
  if self.sample_weight_col:
604
641
  output_df_columns_set -= set(self.sample_weight_col)
642
+
605
643
  # if the dimension of inferred output column names is correct; use it
606
644
  if len(expected_output_cols_list) == len(output_df_columns_set):
607
- return expected_output_cols_list
645
+ return expected_output_cols_list, output_df_pd
608
646
  # otherwise, use the sklearn estimator's output
609
647
  else:
610
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
648
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
649
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
611
650
 
612
651
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
613
652
  @telemetry.send_api_usage_telemetry(
@@ -655,7 +694,7 @@ class LabelPropagation(BaseTransformer):
655
694
  drop_input_cols=self._drop_input_cols,
656
695
  expected_output_cols_type="float",
657
696
  )
658
- expected_output_cols = self._align_expected_output_names(
697
+ expected_output_cols, _ = self._align_expected_output(
659
698
  inference_method, dataset, expected_output_cols, output_cols_prefix
660
699
  )
661
700
 
@@ -723,7 +762,7 @@ class LabelPropagation(BaseTransformer):
723
762
  drop_input_cols=self._drop_input_cols,
724
763
  expected_output_cols_type="float",
725
764
  )
726
- expected_output_cols = self._align_expected_output_names(
765
+ expected_output_cols, _ = self._align_expected_output(
727
766
  inference_method, dataset, expected_output_cols, output_cols_prefix
728
767
  )
729
768
  elif isinstance(dataset, pd.DataFrame):
@@ -786,7 +825,7 @@ class LabelPropagation(BaseTransformer):
786
825
  drop_input_cols=self._drop_input_cols,
787
826
  expected_output_cols_type="float",
788
827
  )
789
- expected_output_cols = self._align_expected_output_names(
828
+ expected_output_cols, _ = self._align_expected_output(
790
829
  inference_method, dataset, expected_output_cols, output_cols_prefix
791
830
  )
792
831
 
@@ -851,7 +890,7 @@ class LabelPropagation(BaseTransformer):
851
890
  drop_input_cols = self._drop_input_cols,
852
891
  expected_output_cols_type="float",
853
892
  )
854
- expected_output_cols = self._align_expected_output_names(
893
+ expected_output_cols, _ = self._align_expected_output(
855
894
  inference_method, dataset, expected_output_cols, output_cols_prefix
856
895
  )
857
896
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -510,12 +507,23 @@ class LabelSpreading(BaseTransformer):
510
507
  autogenerated=self._autogenerated,
511
508
  subproject=_SUBPROJECT,
512
509
  )
513
- output_result, fitted_estimator = model_trainer.train_fit_predict(
514
- drop_input_cols=self._drop_input_cols,
515
- expected_output_cols_list=(
516
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
517
- ),
510
+ expected_output_cols = (
511
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
518
512
  )
513
+ if isinstance(dataset, DataFrame):
514
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
515
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
516
+ )
517
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
518
+ drop_input_cols=self._drop_input_cols,
519
+ expected_output_cols_list=expected_output_cols,
520
+ example_output_pd_df=example_output_pd_df,
521
+ )
522
+ else:
523
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
524
+ drop_input_cols=self._drop_input_cols,
525
+ expected_output_cols_list=expected_output_cols,
526
+ )
519
527
  self._sklearn_object = fitted_estimator
520
528
  self._is_fitted = True
521
529
  return output_result
@@ -594,12 +602,41 @@ class LabelSpreading(BaseTransformer):
594
602
 
595
603
  return rv
596
604
 
597
- def _align_expected_output_names(
598
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
599
- ) -> List[str]:
605
+ def _align_expected_output(
606
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
607
+ ) -> Tuple[List[str], pd.DataFrame]:
608
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
609
+ and output dataframe with 1 line.
610
+ If the method is fit_predict, run 2 lines of data.
611
+ """
600
612
  # in case the inferred output column names dimension is different
601
613
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
602
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
614
+
615
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
616
+ # so change the minimum of number of rows to 2
617
+ num_examples = 2
618
+ statement_params = telemetry.get_function_usage_statement_params(
619
+ project=_PROJECT,
620
+ subproject=_SUBPROJECT,
621
+ function_name=telemetry.get_statement_params_full_func_name(
622
+ inspect.currentframe(), LabelSpreading.__class__.__name__
623
+ ),
624
+ api_calls=[Session.call],
625
+ custom_tags={"autogen": True} if self._autogenerated else None,
626
+ )
627
+ if output_cols_prefix == "fit_predict_":
628
+ if hasattr(self._sklearn_object, "n_clusters"):
629
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
630
+ num_examples = self._sklearn_object.n_clusters
631
+ elif hasattr(self._sklearn_object, "min_samples"):
632
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
633
+ num_examples = self._sklearn_object.min_samples
634
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
635
+ # LocalOutlierFactor expects n_neighbors <= n_samples
636
+ num_examples = self._sklearn_object.n_neighbors
637
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
638
+ else:
639
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
603
640
 
604
641
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
605
642
  # seen during the fit.
@@ -611,12 +648,14 @@ class LabelSpreading(BaseTransformer):
611
648
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
612
649
  if self.sample_weight_col:
613
650
  output_df_columns_set -= set(self.sample_weight_col)
651
+
614
652
  # if the dimension of inferred output column names is correct; use it
615
653
  if len(expected_output_cols_list) == len(output_df_columns_set):
616
- return expected_output_cols_list
654
+ return expected_output_cols_list, output_df_pd
617
655
  # otherwise, use the sklearn estimator's output
618
656
  else:
619
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
657
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
658
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
620
659
 
621
660
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
622
661
  @telemetry.send_api_usage_telemetry(
@@ -664,7 +703,7 @@ class LabelSpreading(BaseTransformer):
664
703
  drop_input_cols=self._drop_input_cols,
665
704
  expected_output_cols_type="float",
666
705
  )
667
- expected_output_cols = self._align_expected_output_names(
706
+ expected_output_cols, _ = self._align_expected_output(
668
707
  inference_method, dataset, expected_output_cols, output_cols_prefix
669
708
  )
670
709
 
@@ -732,7 +771,7 @@ class LabelSpreading(BaseTransformer):
732
771
  drop_input_cols=self._drop_input_cols,
733
772
  expected_output_cols_type="float",
734
773
  )
735
- expected_output_cols = self._align_expected_output_names(
774
+ expected_output_cols, _ = self._align_expected_output(
736
775
  inference_method, dataset, expected_output_cols, output_cols_prefix
737
776
  )
738
777
  elif isinstance(dataset, pd.DataFrame):
@@ -795,7 +834,7 @@ class LabelSpreading(BaseTransformer):
795
834
  drop_input_cols=self._drop_input_cols,
796
835
  expected_output_cols_type="float",
797
836
  )
798
- expected_output_cols = self._align_expected_output_names(
837
+ expected_output_cols, _ = self._align_expected_output(
799
838
  inference_method, dataset, expected_output_cols, output_cols_prefix
800
839
  )
801
840
 
@@ -860,7 +899,7 @@ class LabelSpreading(BaseTransformer):
860
899
  drop_input_cols = self._drop_input_cols,
861
900
  expected_output_cols_type="float",
862
901
  )
863
- expected_output_cols = self._align_expected_output_names(
902
+ expected_output_cols, _ = self._align_expected_output(
864
903
  inference_method, dataset, expected_output_cols, output_cols_prefix
865
904
  )
866
905
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -566,12 +563,23 @@ class LinearSVC(BaseTransformer):
566
563
  autogenerated=self._autogenerated,
567
564
  subproject=_SUBPROJECT,
568
565
  )
569
- output_result, fitted_estimator = model_trainer.train_fit_predict(
570
- drop_input_cols=self._drop_input_cols,
571
- expected_output_cols_list=(
572
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
573
- ),
566
+ expected_output_cols = (
567
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
574
568
  )
569
+ if isinstance(dataset, DataFrame):
570
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
571
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
572
+ )
573
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
574
+ drop_input_cols=self._drop_input_cols,
575
+ expected_output_cols_list=expected_output_cols,
576
+ example_output_pd_df=example_output_pd_df,
577
+ )
578
+ else:
579
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
580
+ drop_input_cols=self._drop_input_cols,
581
+ expected_output_cols_list=expected_output_cols,
582
+ )
575
583
  self._sklearn_object = fitted_estimator
576
584
  self._is_fitted = True
577
585
  return output_result
@@ -650,12 +658,41 @@ class LinearSVC(BaseTransformer):
650
658
 
651
659
  return rv
652
660
 
653
- def _align_expected_output_names(
654
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
655
- ) -> List[str]:
661
+ def _align_expected_output(
662
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
663
+ ) -> Tuple[List[str], pd.DataFrame]:
664
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
665
+ and output dataframe with 1 line.
666
+ If the method is fit_predict, run 2 lines of data.
667
+ """
656
668
  # in case the inferred output column names dimension is different
657
669
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
658
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
670
+
671
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
672
+ # so change the minimum of number of rows to 2
673
+ num_examples = 2
674
+ statement_params = telemetry.get_function_usage_statement_params(
675
+ project=_PROJECT,
676
+ subproject=_SUBPROJECT,
677
+ function_name=telemetry.get_statement_params_full_func_name(
678
+ inspect.currentframe(), LinearSVC.__class__.__name__
679
+ ),
680
+ api_calls=[Session.call],
681
+ custom_tags={"autogen": True} if self._autogenerated else None,
682
+ )
683
+ if output_cols_prefix == "fit_predict_":
684
+ if hasattr(self._sklearn_object, "n_clusters"):
685
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
686
+ num_examples = self._sklearn_object.n_clusters
687
+ elif hasattr(self._sklearn_object, "min_samples"):
688
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
689
+ num_examples = self._sklearn_object.min_samples
690
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
691
+ # LocalOutlierFactor expects n_neighbors <= n_samples
692
+ num_examples = self._sklearn_object.n_neighbors
693
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
694
+ else:
695
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
659
696
 
660
697
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
661
698
  # seen during the fit.
@@ -667,12 +704,14 @@ class LinearSVC(BaseTransformer):
667
704
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
668
705
  if self.sample_weight_col:
669
706
  output_df_columns_set -= set(self.sample_weight_col)
707
+
670
708
  # if the dimension of inferred output column names is correct; use it
671
709
  if len(expected_output_cols_list) == len(output_df_columns_set):
672
- return expected_output_cols_list
710
+ return expected_output_cols_list, output_df_pd
673
711
  # otherwise, use the sklearn estimator's output
674
712
  else:
675
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
713
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
714
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
676
715
 
677
716
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
678
717
  @telemetry.send_api_usage_telemetry(
@@ -718,7 +757,7 @@ class LinearSVC(BaseTransformer):
718
757
  drop_input_cols=self._drop_input_cols,
719
758
  expected_output_cols_type="float",
720
759
  )
721
- expected_output_cols = self._align_expected_output_names(
760
+ expected_output_cols, _ = self._align_expected_output(
722
761
  inference_method, dataset, expected_output_cols, output_cols_prefix
723
762
  )
724
763
 
@@ -784,7 +823,7 @@ class LinearSVC(BaseTransformer):
784
823
  drop_input_cols=self._drop_input_cols,
785
824
  expected_output_cols_type="float",
786
825
  )
787
- expected_output_cols = self._align_expected_output_names(
826
+ expected_output_cols, _ = self._align_expected_output(
788
827
  inference_method, dataset, expected_output_cols, output_cols_prefix
789
828
  )
790
829
  elif isinstance(dataset, pd.DataFrame):
@@ -849,7 +888,7 @@ class LinearSVC(BaseTransformer):
849
888
  drop_input_cols=self._drop_input_cols,
850
889
  expected_output_cols_type="float",
851
890
  )
852
- expected_output_cols = self._align_expected_output_names(
891
+ expected_output_cols, _ = self._align_expected_output(
853
892
  inference_method, dataset, expected_output_cols, output_cols_prefix
854
893
  )
855
894
 
@@ -914,7 +953,7 @@ class LinearSVC(BaseTransformer):
914
953
  drop_input_cols = self._drop_input_cols,
915
954
  expected_output_cols_type="float",
916
955
  )
917
- expected_output_cols = self._align_expected_output_names(
956
+ expected_output_cols, _ = self._align_expected_output(
918
957
  inference_method, dataset, expected_output_cols, output_cols_prefix
919
958
  )
920
959