snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -538,12 +535,23 @@ class LinearSVR(BaseTransformer):
538
535
  autogenerated=self._autogenerated,
539
536
  subproject=_SUBPROJECT,
540
537
  )
541
- output_result, fitted_estimator = model_trainer.train_fit_predict(
542
- drop_input_cols=self._drop_input_cols,
543
- expected_output_cols_list=(
544
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
545
- ),
538
+ expected_output_cols = (
539
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
546
540
  )
541
+ if isinstance(dataset, DataFrame):
542
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
543
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
544
+ )
545
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
546
+ drop_input_cols=self._drop_input_cols,
547
+ expected_output_cols_list=expected_output_cols,
548
+ example_output_pd_df=example_output_pd_df,
549
+ )
550
+ else:
551
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
552
+ drop_input_cols=self._drop_input_cols,
553
+ expected_output_cols_list=expected_output_cols,
554
+ )
547
555
  self._sklearn_object = fitted_estimator
548
556
  self._is_fitted = True
549
557
  return output_result
@@ -622,12 +630,41 @@ class LinearSVR(BaseTransformer):
622
630
 
623
631
  return rv
624
632
 
625
- def _align_expected_output_names(
626
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
627
- ) -> List[str]:
633
+ def _align_expected_output(
634
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
635
+ ) -> Tuple[List[str], pd.DataFrame]:
636
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
637
+ and output dataframe with 1 line.
638
+ If the method is fit_predict, run 2 lines of data.
639
+ """
628
640
  # in case the inferred output column names dimension is different
629
641
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
630
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
642
+
643
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
644
+ # so change the minimum of number of rows to 2
645
+ num_examples = 2
646
+ statement_params = telemetry.get_function_usage_statement_params(
647
+ project=_PROJECT,
648
+ subproject=_SUBPROJECT,
649
+ function_name=telemetry.get_statement_params_full_func_name(
650
+ inspect.currentframe(), LinearSVR.__class__.__name__
651
+ ),
652
+ api_calls=[Session.call],
653
+ custom_tags={"autogen": True} if self._autogenerated else None,
654
+ )
655
+ if output_cols_prefix == "fit_predict_":
656
+ if hasattr(self._sklearn_object, "n_clusters"):
657
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
658
+ num_examples = self._sklearn_object.n_clusters
659
+ elif hasattr(self._sklearn_object, "min_samples"):
660
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
661
+ num_examples = self._sklearn_object.min_samples
662
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
663
+ # LocalOutlierFactor expects n_neighbors <= n_samples
664
+ num_examples = self._sklearn_object.n_neighbors
665
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
666
+ else:
667
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
631
668
 
632
669
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
633
670
  # seen during the fit.
@@ -639,12 +676,14 @@ class LinearSVR(BaseTransformer):
639
676
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
640
677
  if self.sample_weight_col:
641
678
  output_df_columns_set -= set(self.sample_weight_col)
679
+
642
680
  # if the dimension of inferred output column names is correct; use it
643
681
  if len(expected_output_cols_list) == len(output_df_columns_set):
644
- return expected_output_cols_list
682
+ return expected_output_cols_list, output_df_pd
645
683
  # otherwise, use the sklearn estimator's output
646
684
  else:
647
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
685
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
686
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
648
687
 
649
688
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
650
689
  @telemetry.send_api_usage_telemetry(
@@ -690,7 +729,7 @@ class LinearSVR(BaseTransformer):
690
729
  drop_input_cols=self._drop_input_cols,
691
730
  expected_output_cols_type="float",
692
731
  )
693
- expected_output_cols = self._align_expected_output_names(
732
+ expected_output_cols, _ = self._align_expected_output(
694
733
  inference_method, dataset, expected_output_cols, output_cols_prefix
695
734
  )
696
735
 
@@ -756,7 +795,7 @@ class LinearSVR(BaseTransformer):
756
795
  drop_input_cols=self._drop_input_cols,
757
796
  expected_output_cols_type="float",
758
797
  )
759
- expected_output_cols = self._align_expected_output_names(
798
+ expected_output_cols, _ = self._align_expected_output(
760
799
  inference_method, dataset, expected_output_cols, output_cols_prefix
761
800
  )
762
801
  elif isinstance(dataset, pd.DataFrame):
@@ -819,7 +858,7 @@ class LinearSVR(BaseTransformer):
819
858
  drop_input_cols=self._drop_input_cols,
820
859
  expected_output_cols_type="float",
821
860
  )
822
- expected_output_cols = self._align_expected_output_names(
861
+ expected_output_cols, _ = self._align_expected_output(
823
862
  inference_method, dataset, expected_output_cols, output_cols_prefix
824
863
  )
825
864
 
@@ -884,7 +923,7 @@ class LinearSVR(BaseTransformer):
884
923
  drop_input_cols = self._drop_input_cols,
885
924
  expected_output_cols_type="float",
886
925
  )
887
- expected_output_cols = self._align_expected_output_names(
926
+ expected_output_cols, _ = self._align_expected_output(
888
927
  inference_method, dataset, expected_output_cols, output_cols_prefix
889
928
  )
890
929
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -572,12 +569,23 @@ class NuSVC(BaseTransformer):
572
569
  autogenerated=self._autogenerated,
573
570
  subproject=_SUBPROJECT,
574
571
  )
575
- output_result, fitted_estimator = model_trainer.train_fit_predict(
576
- drop_input_cols=self._drop_input_cols,
577
- expected_output_cols_list=(
578
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
579
- ),
572
+ expected_output_cols = (
573
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
580
574
  )
575
+ if isinstance(dataset, DataFrame):
576
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
577
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
578
+ )
579
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
580
+ drop_input_cols=self._drop_input_cols,
581
+ expected_output_cols_list=expected_output_cols,
582
+ example_output_pd_df=example_output_pd_df,
583
+ )
584
+ else:
585
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
586
+ drop_input_cols=self._drop_input_cols,
587
+ expected_output_cols_list=expected_output_cols,
588
+ )
581
589
  self._sklearn_object = fitted_estimator
582
590
  self._is_fitted = True
583
591
  return output_result
@@ -656,12 +664,41 @@ class NuSVC(BaseTransformer):
656
664
 
657
665
  return rv
658
666
 
659
- def _align_expected_output_names(
660
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
661
- ) -> List[str]:
667
+ def _align_expected_output(
668
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
669
+ ) -> Tuple[List[str], pd.DataFrame]:
670
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
671
+ and output dataframe with 1 line.
672
+ If the method is fit_predict, run 2 lines of data.
673
+ """
662
674
  # in case the inferred output column names dimension is different
663
675
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
664
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
676
+
677
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
678
+ # so change the minimum of number of rows to 2
679
+ num_examples = 2
680
+ statement_params = telemetry.get_function_usage_statement_params(
681
+ project=_PROJECT,
682
+ subproject=_SUBPROJECT,
683
+ function_name=telemetry.get_statement_params_full_func_name(
684
+ inspect.currentframe(), NuSVC.__class__.__name__
685
+ ),
686
+ api_calls=[Session.call],
687
+ custom_tags={"autogen": True} if self._autogenerated else None,
688
+ )
689
+ if output_cols_prefix == "fit_predict_":
690
+ if hasattr(self._sklearn_object, "n_clusters"):
691
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
692
+ num_examples = self._sklearn_object.n_clusters
693
+ elif hasattr(self._sklearn_object, "min_samples"):
694
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
695
+ num_examples = self._sklearn_object.min_samples
696
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
697
+ # LocalOutlierFactor expects n_neighbors <= n_samples
698
+ num_examples = self._sklearn_object.n_neighbors
699
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
700
+ else:
701
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
665
702
 
666
703
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
667
704
  # seen during the fit.
@@ -673,12 +710,14 @@ class NuSVC(BaseTransformer):
673
710
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
674
711
  if self.sample_weight_col:
675
712
  output_df_columns_set -= set(self.sample_weight_col)
713
+
676
714
  # if the dimension of inferred output column names is correct; use it
677
715
  if len(expected_output_cols_list) == len(output_df_columns_set):
678
- return expected_output_cols_list
716
+ return expected_output_cols_list, output_df_pd
679
717
  # otherwise, use the sklearn estimator's output
680
718
  else:
681
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
719
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
720
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
682
721
 
683
722
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
684
723
  @telemetry.send_api_usage_telemetry(
@@ -726,7 +765,7 @@ class NuSVC(BaseTransformer):
726
765
  drop_input_cols=self._drop_input_cols,
727
766
  expected_output_cols_type="float",
728
767
  )
729
- expected_output_cols = self._align_expected_output_names(
768
+ expected_output_cols, _ = self._align_expected_output(
730
769
  inference_method, dataset, expected_output_cols, output_cols_prefix
731
770
  )
732
771
 
@@ -794,7 +833,7 @@ class NuSVC(BaseTransformer):
794
833
  drop_input_cols=self._drop_input_cols,
795
834
  expected_output_cols_type="float",
796
835
  )
797
- expected_output_cols = self._align_expected_output_names(
836
+ expected_output_cols, _ = self._align_expected_output(
798
837
  inference_method, dataset, expected_output_cols, output_cols_prefix
799
838
  )
800
839
  elif isinstance(dataset, pd.DataFrame):
@@ -859,7 +898,7 @@ class NuSVC(BaseTransformer):
859
898
  drop_input_cols=self._drop_input_cols,
860
899
  expected_output_cols_type="float",
861
900
  )
862
- expected_output_cols = self._align_expected_output_names(
901
+ expected_output_cols, _ = self._align_expected_output(
863
902
  inference_method, dataset, expected_output_cols, output_cols_prefix
864
903
  )
865
904
 
@@ -924,7 +963,7 @@ class NuSVC(BaseTransformer):
924
963
  drop_input_cols = self._drop_input_cols,
925
964
  expected_output_cols_type="float",
926
965
  )
927
- expected_output_cols = self._align_expected_output_names(
966
+ expected_output_cols, _ = self._align_expected_output(
928
967
  inference_method, dataset, expected_output_cols, output_cols_prefix
929
968
  )
930
969
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -533,12 +530,23 @@ class NuSVR(BaseTransformer):
533
530
  autogenerated=self._autogenerated,
534
531
  subproject=_SUBPROJECT,
535
532
  )
536
- output_result, fitted_estimator = model_trainer.train_fit_predict(
537
- drop_input_cols=self._drop_input_cols,
538
- expected_output_cols_list=(
539
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
540
- ),
533
+ expected_output_cols = (
534
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
541
535
  )
536
+ if isinstance(dataset, DataFrame):
537
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
538
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
539
+ )
540
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
541
+ drop_input_cols=self._drop_input_cols,
542
+ expected_output_cols_list=expected_output_cols,
543
+ example_output_pd_df=example_output_pd_df,
544
+ )
545
+ else:
546
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
547
+ drop_input_cols=self._drop_input_cols,
548
+ expected_output_cols_list=expected_output_cols,
549
+ )
542
550
  self._sklearn_object = fitted_estimator
543
551
  self._is_fitted = True
544
552
  return output_result
@@ -617,12 +625,41 @@ class NuSVR(BaseTransformer):
617
625
 
618
626
  return rv
619
627
 
620
- def _align_expected_output_names(
621
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
622
- ) -> List[str]:
628
+ def _align_expected_output(
629
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
630
+ ) -> Tuple[List[str], pd.DataFrame]:
631
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
632
+ and output dataframe with 1 line.
633
+ If the method is fit_predict, run 2 lines of data.
634
+ """
623
635
  # in case the inferred output column names dimension is different
624
636
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
625
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
637
+
638
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
639
+ # so change the minimum of number of rows to 2
640
+ num_examples = 2
641
+ statement_params = telemetry.get_function_usage_statement_params(
642
+ project=_PROJECT,
643
+ subproject=_SUBPROJECT,
644
+ function_name=telemetry.get_statement_params_full_func_name(
645
+ inspect.currentframe(), NuSVR.__class__.__name__
646
+ ),
647
+ api_calls=[Session.call],
648
+ custom_tags={"autogen": True} if self._autogenerated else None,
649
+ )
650
+ if output_cols_prefix == "fit_predict_":
651
+ if hasattr(self._sklearn_object, "n_clusters"):
652
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
653
+ num_examples = self._sklearn_object.n_clusters
654
+ elif hasattr(self._sklearn_object, "min_samples"):
655
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
656
+ num_examples = self._sklearn_object.min_samples
657
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
658
+ # LocalOutlierFactor expects n_neighbors <= n_samples
659
+ num_examples = self._sklearn_object.n_neighbors
660
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
661
+ else:
662
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
626
663
 
627
664
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
628
665
  # seen during the fit.
@@ -634,12 +671,14 @@ class NuSVR(BaseTransformer):
634
671
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
635
672
  if self.sample_weight_col:
636
673
  output_df_columns_set -= set(self.sample_weight_col)
674
+
637
675
  # if the dimension of inferred output column names is correct; use it
638
676
  if len(expected_output_cols_list) == len(output_df_columns_set):
639
- return expected_output_cols_list
677
+ return expected_output_cols_list, output_df_pd
640
678
  # otherwise, use the sklearn estimator's output
641
679
  else:
642
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
680
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
681
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
643
682
 
644
683
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
645
684
  @telemetry.send_api_usage_telemetry(
@@ -685,7 +724,7 @@ class NuSVR(BaseTransformer):
685
724
  drop_input_cols=self._drop_input_cols,
686
725
  expected_output_cols_type="float",
687
726
  )
688
- expected_output_cols = self._align_expected_output_names(
727
+ expected_output_cols, _ = self._align_expected_output(
689
728
  inference_method, dataset, expected_output_cols, output_cols_prefix
690
729
  )
691
730
 
@@ -751,7 +790,7 @@ class NuSVR(BaseTransformer):
751
790
  drop_input_cols=self._drop_input_cols,
752
791
  expected_output_cols_type="float",
753
792
  )
754
- expected_output_cols = self._align_expected_output_names(
793
+ expected_output_cols, _ = self._align_expected_output(
755
794
  inference_method, dataset, expected_output_cols, output_cols_prefix
756
795
  )
757
796
  elif isinstance(dataset, pd.DataFrame):
@@ -814,7 +853,7 @@ class NuSVR(BaseTransformer):
814
853
  drop_input_cols=self._drop_input_cols,
815
854
  expected_output_cols_type="float",
816
855
  )
817
- expected_output_cols = self._align_expected_output_names(
856
+ expected_output_cols, _ = self._align_expected_output(
818
857
  inference_method, dataset, expected_output_cols, output_cols_prefix
819
858
  )
820
859
 
@@ -879,7 +918,7 @@ class NuSVR(BaseTransformer):
879
918
  drop_input_cols = self._drop_input_cols,
880
919
  expected_output_cols_type="float",
881
920
  )
882
- expected_output_cols = self._align_expected_output_names(
921
+ expected_output_cols, _ = self._align_expected_output(
883
922
  inference_method, dataset, expected_output_cols, output_cols_prefix
884
923
  )
885
924
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -575,12 +572,23 @@ class SVC(BaseTransformer):
575
572
  autogenerated=self._autogenerated,
576
573
  subproject=_SUBPROJECT,
577
574
  )
578
- output_result, fitted_estimator = model_trainer.train_fit_predict(
579
- drop_input_cols=self._drop_input_cols,
580
- expected_output_cols_list=(
581
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
582
- ),
575
+ expected_output_cols = (
576
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
583
577
  )
578
+ if isinstance(dataset, DataFrame):
579
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
580
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
581
+ )
582
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
583
+ drop_input_cols=self._drop_input_cols,
584
+ expected_output_cols_list=expected_output_cols,
585
+ example_output_pd_df=example_output_pd_df,
586
+ )
587
+ else:
588
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
589
+ drop_input_cols=self._drop_input_cols,
590
+ expected_output_cols_list=expected_output_cols,
591
+ )
584
592
  self._sklearn_object = fitted_estimator
585
593
  self._is_fitted = True
586
594
  return output_result
@@ -659,12 +667,41 @@ class SVC(BaseTransformer):
659
667
 
660
668
  return rv
661
669
 
662
- def _align_expected_output_names(
663
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
664
- ) -> List[str]:
670
+ def _align_expected_output(
671
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
672
+ ) -> Tuple[List[str], pd.DataFrame]:
673
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
674
+ and output dataframe with 1 line.
675
+ If the method is fit_predict, run 2 lines of data.
676
+ """
665
677
  # in case the inferred output column names dimension is different
666
678
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
667
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
679
+
680
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
681
+ # so change the minimum of number of rows to 2
682
+ num_examples = 2
683
+ statement_params = telemetry.get_function_usage_statement_params(
684
+ project=_PROJECT,
685
+ subproject=_SUBPROJECT,
686
+ function_name=telemetry.get_statement_params_full_func_name(
687
+ inspect.currentframe(), SVC.__class__.__name__
688
+ ),
689
+ api_calls=[Session.call],
690
+ custom_tags={"autogen": True} if self._autogenerated else None,
691
+ )
692
+ if output_cols_prefix == "fit_predict_":
693
+ if hasattr(self._sklearn_object, "n_clusters"):
694
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
695
+ num_examples = self._sklearn_object.n_clusters
696
+ elif hasattr(self._sklearn_object, "min_samples"):
697
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
698
+ num_examples = self._sklearn_object.min_samples
699
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
700
+ # LocalOutlierFactor expects n_neighbors <= n_samples
701
+ num_examples = self._sklearn_object.n_neighbors
702
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
703
+ else:
704
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
668
705
 
669
706
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
670
707
  # seen during the fit.
@@ -676,12 +713,14 @@ class SVC(BaseTransformer):
676
713
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
677
714
  if self.sample_weight_col:
678
715
  output_df_columns_set -= set(self.sample_weight_col)
716
+
679
717
  # if the dimension of inferred output column names is correct; use it
680
718
  if len(expected_output_cols_list) == len(output_df_columns_set):
681
- return expected_output_cols_list
719
+ return expected_output_cols_list, output_df_pd
682
720
  # otherwise, use the sklearn estimator's output
683
721
  else:
684
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
722
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
723
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
685
724
 
686
725
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
687
726
  @telemetry.send_api_usage_telemetry(
@@ -729,7 +768,7 @@ class SVC(BaseTransformer):
729
768
  drop_input_cols=self._drop_input_cols,
730
769
  expected_output_cols_type="float",
731
770
  )
732
- expected_output_cols = self._align_expected_output_names(
771
+ expected_output_cols, _ = self._align_expected_output(
733
772
  inference_method, dataset, expected_output_cols, output_cols_prefix
734
773
  )
735
774
 
@@ -797,7 +836,7 @@ class SVC(BaseTransformer):
797
836
  drop_input_cols=self._drop_input_cols,
798
837
  expected_output_cols_type="float",
799
838
  )
800
- expected_output_cols = self._align_expected_output_names(
839
+ expected_output_cols, _ = self._align_expected_output(
801
840
  inference_method, dataset, expected_output_cols, output_cols_prefix
802
841
  )
803
842
  elif isinstance(dataset, pd.DataFrame):
@@ -862,7 +901,7 @@ class SVC(BaseTransformer):
862
901
  drop_input_cols=self._drop_input_cols,
863
902
  expected_output_cols_type="float",
864
903
  )
865
- expected_output_cols = self._align_expected_output_names(
904
+ expected_output_cols, _ = self._align_expected_output(
866
905
  inference_method, dataset, expected_output_cols, output_cols_prefix
867
906
  )
868
907
 
@@ -927,7 +966,7 @@ class SVC(BaseTransformer):
927
966
  drop_input_cols = self._drop_input_cols,
928
967
  expected_output_cols_type="float",
929
968
  )
930
- expected_output_cols = self._align_expected_output_names(
969
+ expected_output_cols, _ = self._align_expected_output(
931
970
  inference_method, dataset, expected_output_cols, output_cols_prefix
932
971
  )
933
972