snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -534,12 +531,23 @@ class Lars(BaseTransformer):
534
531
  autogenerated=self._autogenerated,
535
532
  subproject=_SUBPROJECT,
536
533
  )
537
- output_result, fitted_estimator = model_trainer.train_fit_predict(
538
- drop_input_cols=self._drop_input_cols,
539
- expected_output_cols_list=(
540
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
541
- ),
534
+ expected_output_cols = (
535
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
542
536
  )
537
+ if isinstance(dataset, DataFrame):
538
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
539
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
540
+ )
541
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
542
+ drop_input_cols=self._drop_input_cols,
543
+ expected_output_cols_list=expected_output_cols,
544
+ example_output_pd_df=example_output_pd_df,
545
+ )
546
+ else:
547
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
548
+ drop_input_cols=self._drop_input_cols,
549
+ expected_output_cols_list=expected_output_cols,
550
+ )
543
551
  self._sklearn_object = fitted_estimator
544
552
  self._is_fitted = True
545
553
  return output_result
@@ -618,12 +626,41 @@ class Lars(BaseTransformer):
618
626
 
619
627
  return rv
620
628
 
621
- def _align_expected_output_names(
622
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
623
- ) -> List[str]:
629
+ def _align_expected_output(
630
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
631
+ ) -> Tuple[List[str], pd.DataFrame]:
632
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
633
+ and output dataframe with 1 line.
634
+ If the method is fit_predict, run 2 lines of data.
635
+ """
624
636
  # in case the inferred output column names dimension is different
625
637
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
626
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
638
+
639
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
640
+ # so change the minimum of number of rows to 2
641
+ num_examples = 2
642
+ statement_params = telemetry.get_function_usage_statement_params(
643
+ project=_PROJECT,
644
+ subproject=_SUBPROJECT,
645
+ function_name=telemetry.get_statement_params_full_func_name(
646
+ inspect.currentframe(), Lars.__class__.__name__
647
+ ),
648
+ api_calls=[Session.call],
649
+ custom_tags={"autogen": True} if self._autogenerated else None,
650
+ )
651
+ if output_cols_prefix == "fit_predict_":
652
+ if hasattr(self._sklearn_object, "n_clusters"):
653
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
654
+ num_examples = self._sklearn_object.n_clusters
655
+ elif hasattr(self._sklearn_object, "min_samples"):
656
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
657
+ num_examples = self._sklearn_object.min_samples
658
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
659
+ # LocalOutlierFactor expects n_neighbors <= n_samples
660
+ num_examples = self._sklearn_object.n_neighbors
661
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
662
+ else:
663
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
627
664
 
628
665
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
629
666
  # seen during the fit.
@@ -635,12 +672,14 @@ class Lars(BaseTransformer):
635
672
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
636
673
  if self.sample_weight_col:
637
674
  output_df_columns_set -= set(self.sample_weight_col)
675
+
638
676
  # if the dimension of inferred output column names is correct; use it
639
677
  if len(expected_output_cols_list) == len(output_df_columns_set):
640
- return expected_output_cols_list
678
+ return expected_output_cols_list, output_df_pd
641
679
  # otherwise, use the sklearn estimator's output
642
680
  else:
643
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
681
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
682
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
644
683
 
645
684
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
646
685
  @telemetry.send_api_usage_telemetry(
@@ -686,7 +725,7 @@ class Lars(BaseTransformer):
686
725
  drop_input_cols=self._drop_input_cols,
687
726
  expected_output_cols_type="float",
688
727
  )
689
- expected_output_cols = self._align_expected_output_names(
728
+ expected_output_cols, _ = self._align_expected_output(
690
729
  inference_method, dataset, expected_output_cols, output_cols_prefix
691
730
  )
692
731
 
@@ -752,7 +791,7 @@ class Lars(BaseTransformer):
752
791
  drop_input_cols=self._drop_input_cols,
753
792
  expected_output_cols_type="float",
754
793
  )
755
- expected_output_cols = self._align_expected_output_names(
794
+ expected_output_cols, _ = self._align_expected_output(
756
795
  inference_method, dataset, expected_output_cols, output_cols_prefix
757
796
  )
758
797
  elif isinstance(dataset, pd.DataFrame):
@@ -815,7 +854,7 @@ class Lars(BaseTransformer):
815
854
  drop_input_cols=self._drop_input_cols,
816
855
  expected_output_cols_type="float",
817
856
  )
818
- expected_output_cols = self._align_expected_output_names(
857
+ expected_output_cols, _ = self._align_expected_output(
819
858
  inference_method, dataset, expected_output_cols, output_cols_prefix
820
859
  )
821
860
 
@@ -880,7 +919,7 @@ class Lars(BaseTransformer):
880
919
  drop_input_cols = self._drop_input_cols,
881
920
  expected_output_cols_type="float",
882
921
  )
883
- expected_output_cols = self._align_expected_output_names(
922
+ expected_output_cols, _ = self._align_expected_output(
884
923
  inference_method, dataset, expected_output_cols, output_cols_prefix
885
924
  )
886
925
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -542,12 +539,23 @@ class LarsCV(BaseTransformer):
542
539
  autogenerated=self._autogenerated,
543
540
  subproject=_SUBPROJECT,
544
541
  )
545
- output_result, fitted_estimator = model_trainer.train_fit_predict(
546
- drop_input_cols=self._drop_input_cols,
547
- expected_output_cols_list=(
548
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
549
- ),
542
+ expected_output_cols = (
543
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
550
544
  )
545
+ if isinstance(dataset, DataFrame):
546
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
547
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
548
+ )
549
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
550
+ drop_input_cols=self._drop_input_cols,
551
+ expected_output_cols_list=expected_output_cols,
552
+ example_output_pd_df=example_output_pd_df,
553
+ )
554
+ else:
555
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
556
+ drop_input_cols=self._drop_input_cols,
557
+ expected_output_cols_list=expected_output_cols,
558
+ )
551
559
  self._sklearn_object = fitted_estimator
552
560
  self._is_fitted = True
553
561
  return output_result
@@ -626,12 +634,41 @@ class LarsCV(BaseTransformer):
626
634
 
627
635
  return rv
628
636
 
629
- def _align_expected_output_names(
630
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
631
- ) -> List[str]:
637
+ def _align_expected_output(
638
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
639
+ ) -> Tuple[List[str], pd.DataFrame]:
640
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
641
+ and output dataframe with 1 line.
642
+ If the method is fit_predict, run 2 lines of data.
643
+ """
632
644
  # in case the inferred output column names dimension is different
633
645
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
634
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
646
+
647
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
648
+ # so change the minimum of number of rows to 2
649
+ num_examples = 2
650
+ statement_params = telemetry.get_function_usage_statement_params(
651
+ project=_PROJECT,
652
+ subproject=_SUBPROJECT,
653
+ function_name=telemetry.get_statement_params_full_func_name(
654
+ inspect.currentframe(), LarsCV.__class__.__name__
655
+ ),
656
+ api_calls=[Session.call],
657
+ custom_tags={"autogen": True} if self._autogenerated else None,
658
+ )
659
+ if output_cols_prefix == "fit_predict_":
660
+ if hasattr(self._sklearn_object, "n_clusters"):
661
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
662
+ num_examples = self._sklearn_object.n_clusters
663
+ elif hasattr(self._sklearn_object, "min_samples"):
664
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
665
+ num_examples = self._sklearn_object.min_samples
666
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
667
+ # LocalOutlierFactor expects n_neighbors <= n_samples
668
+ num_examples = self._sklearn_object.n_neighbors
669
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
670
+ else:
671
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
635
672
 
636
673
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
637
674
  # seen during the fit.
@@ -643,12 +680,14 @@ class LarsCV(BaseTransformer):
643
680
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
644
681
  if self.sample_weight_col:
645
682
  output_df_columns_set -= set(self.sample_weight_col)
683
+
646
684
  # if the dimension of inferred output column names is correct; use it
647
685
  if len(expected_output_cols_list) == len(output_df_columns_set):
648
- return expected_output_cols_list
686
+ return expected_output_cols_list, output_df_pd
649
687
  # otherwise, use the sklearn estimator's output
650
688
  else:
651
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
689
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
690
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
652
691
 
653
692
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
654
693
  @telemetry.send_api_usage_telemetry(
@@ -694,7 +733,7 @@ class LarsCV(BaseTransformer):
694
733
  drop_input_cols=self._drop_input_cols,
695
734
  expected_output_cols_type="float",
696
735
  )
697
- expected_output_cols = self._align_expected_output_names(
736
+ expected_output_cols, _ = self._align_expected_output(
698
737
  inference_method, dataset, expected_output_cols, output_cols_prefix
699
738
  )
700
739
 
@@ -760,7 +799,7 @@ class LarsCV(BaseTransformer):
760
799
  drop_input_cols=self._drop_input_cols,
761
800
  expected_output_cols_type="float",
762
801
  )
763
- expected_output_cols = self._align_expected_output_names(
802
+ expected_output_cols, _ = self._align_expected_output(
764
803
  inference_method, dataset, expected_output_cols, output_cols_prefix
765
804
  )
766
805
  elif isinstance(dataset, pd.DataFrame):
@@ -823,7 +862,7 @@ class LarsCV(BaseTransformer):
823
862
  drop_input_cols=self._drop_input_cols,
824
863
  expected_output_cols_type="float",
825
864
  )
826
- expected_output_cols = self._align_expected_output_names(
865
+ expected_output_cols, _ = self._align_expected_output(
827
866
  inference_method, dataset, expected_output_cols, output_cols_prefix
828
867
  )
829
868
 
@@ -888,7 +927,7 @@ class LarsCV(BaseTransformer):
888
927
  drop_input_cols = self._drop_input_cols,
889
928
  expected_output_cols_type="float",
890
929
  )
891
- expected_output_cols = self._align_expected_output_names(
930
+ expected_output_cols, _ = self._align_expected_output(
892
931
  inference_method, dataset, expected_output_cols, output_cols_prefix
893
932
  )
894
933
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -535,12 +532,23 @@ class Lasso(BaseTransformer):
535
532
  autogenerated=self._autogenerated,
536
533
  subproject=_SUBPROJECT,
537
534
  )
538
- output_result, fitted_estimator = model_trainer.train_fit_predict(
539
- drop_input_cols=self._drop_input_cols,
540
- expected_output_cols_list=(
541
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
542
- ),
535
+ expected_output_cols = (
536
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
543
537
  )
538
+ if isinstance(dataset, DataFrame):
539
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
540
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
541
+ )
542
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
543
+ drop_input_cols=self._drop_input_cols,
544
+ expected_output_cols_list=expected_output_cols,
545
+ example_output_pd_df=example_output_pd_df,
546
+ )
547
+ else:
548
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
549
+ drop_input_cols=self._drop_input_cols,
550
+ expected_output_cols_list=expected_output_cols,
551
+ )
544
552
  self._sklearn_object = fitted_estimator
545
553
  self._is_fitted = True
546
554
  return output_result
@@ -619,12 +627,41 @@ class Lasso(BaseTransformer):
619
627
 
620
628
  return rv
621
629
 
622
- def _align_expected_output_names(
623
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
624
- ) -> List[str]:
630
+ def _align_expected_output(
631
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
632
+ ) -> Tuple[List[str], pd.DataFrame]:
633
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
634
+ and output dataframe with 1 line.
635
+ If the method is fit_predict, run 2 lines of data.
636
+ """
625
637
  # in case the inferred output column names dimension is different
626
638
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
627
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
639
+
640
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
641
+ # so change the minimum of number of rows to 2
642
+ num_examples = 2
643
+ statement_params = telemetry.get_function_usage_statement_params(
644
+ project=_PROJECT,
645
+ subproject=_SUBPROJECT,
646
+ function_name=telemetry.get_statement_params_full_func_name(
647
+ inspect.currentframe(), Lasso.__class__.__name__
648
+ ),
649
+ api_calls=[Session.call],
650
+ custom_tags={"autogen": True} if self._autogenerated else None,
651
+ )
652
+ if output_cols_prefix == "fit_predict_":
653
+ if hasattr(self._sklearn_object, "n_clusters"):
654
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
655
+ num_examples = self._sklearn_object.n_clusters
656
+ elif hasattr(self._sklearn_object, "min_samples"):
657
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
658
+ num_examples = self._sklearn_object.min_samples
659
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
660
+ # LocalOutlierFactor expects n_neighbors <= n_samples
661
+ num_examples = self._sklearn_object.n_neighbors
662
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
663
+ else:
664
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
628
665
 
629
666
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
630
667
  # seen during the fit.
@@ -636,12 +673,14 @@ class Lasso(BaseTransformer):
636
673
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
637
674
  if self.sample_weight_col:
638
675
  output_df_columns_set -= set(self.sample_weight_col)
676
+
639
677
  # if the dimension of inferred output column names is correct; use it
640
678
  if len(expected_output_cols_list) == len(output_df_columns_set):
641
- return expected_output_cols_list
679
+ return expected_output_cols_list, output_df_pd
642
680
  # otherwise, use the sklearn estimator's output
643
681
  else:
644
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
682
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
683
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
645
684
 
646
685
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
647
686
  @telemetry.send_api_usage_telemetry(
@@ -687,7 +726,7 @@ class Lasso(BaseTransformer):
687
726
  drop_input_cols=self._drop_input_cols,
688
727
  expected_output_cols_type="float",
689
728
  )
690
- expected_output_cols = self._align_expected_output_names(
729
+ expected_output_cols, _ = self._align_expected_output(
691
730
  inference_method, dataset, expected_output_cols, output_cols_prefix
692
731
  )
693
732
 
@@ -753,7 +792,7 @@ class Lasso(BaseTransformer):
753
792
  drop_input_cols=self._drop_input_cols,
754
793
  expected_output_cols_type="float",
755
794
  )
756
- expected_output_cols = self._align_expected_output_names(
795
+ expected_output_cols, _ = self._align_expected_output(
757
796
  inference_method, dataset, expected_output_cols, output_cols_prefix
758
797
  )
759
798
  elif isinstance(dataset, pd.DataFrame):
@@ -816,7 +855,7 @@ class Lasso(BaseTransformer):
816
855
  drop_input_cols=self._drop_input_cols,
817
856
  expected_output_cols_type="float",
818
857
  )
819
- expected_output_cols = self._align_expected_output_names(
858
+ expected_output_cols, _ = self._align_expected_output(
820
859
  inference_method, dataset, expected_output_cols, output_cols_prefix
821
860
  )
822
861
 
@@ -881,7 +920,7 @@ class Lasso(BaseTransformer):
881
920
  drop_input_cols = self._drop_input_cols,
882
921
  expected_output_cols_type="float",
883
922
  )
884
- expected_output_cols = self._align_expected_output_names(
923
+ expected_output_cols, _ = self._align_expected_output(
885
924
  inference_method, dataset, expected_output_cols, output_cols_prefix
886
925
  )
887
926
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -563,12 +560,23 @@ class LassoCV(BaseTransformer):
563
560
  autogenerated=self._autogenerated,
564
561
  subproject=_SUBPROJECT,
565
562
  )
566
- output_result, fitted_estimator = model_trainer.train_fit_predict(
567
- drop_input_cols=self._drop_input_cols,
568
- expected_output_cols_list=(
569
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
570
- ),
563
+ expected_output_cols = (
564
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
571
565
  )
566
+ if isinstance(dataset, DataFrame):
567
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
568
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
569
+ )
570
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
571
+ drop_input_cols=self._drop_input_cols,
572
+ expected_output_cols_list=expected_output_cols,
573
+ example_output_pd_df=example_output_pd_df,
574
+ )
575
+ else:
576
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
577
+ drop_input_cols=self._drop_input_cols,
578
+ expected_output_cols_list=expected_output_cols,
579
+ )
572
580
  self._sklearn_object = fitted_estimator
573
581
  self._is_fitted = True
574
582
  return output_result
@@ -647,12 +655,41 @@ class LassoCV(BaseTransformer):
647
655
 
648
656
  return rv
649
657
 
650
- def _align_expected_output_names(
651
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
652
- ) -> List[str]:
658
+ def _align_expected_output(
659
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
660
+ ) -> Tuple[List[str], pd.DataFrame]:
661
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
662
+ and output dataframe with 1 line.
663
+ If the method is fit_predict, run 2 lines of data.
664
+ """
653
665
  # in case the inferred output column names dimension is different
654
666
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
655
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
667
+
668
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
669
+ # so change the minimum of number of rows to 2
670
+ num_examples = 2
671
+ statement_params = telemetry.get_function_usage_statement_params(
672
+ project=_PROJECT,
673
+ subproject=_SUBPROJECT,
674
+ function_name=telemetry.get_statement_params_full_func_name(
675
+ inspect.currentframe(), LassoCV.__class__.__name__
676
+ ),
677
+ api_calls=[Session.call],
678
+ custom_tags={"autogen": True} if self._autogenerated else None,
679
+ )
680
+ if output_cols_prefix == "fit_predict_":
681
+ if hasattr(self._sklearn_object, "n_clusters"):
682
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
683
+ num_examples = self._sklearn_object.n_clusters
684
+ elif hasattr(self._sklearn_object, "min_samples"):
685
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
686
+ num_examples = self._sklearn_object.min_samples
687
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
688
+ # LocalOutlierFactor expects n_neighbors <= n_samples
689
+ num_examples = self._sklearn_object.n_neighbors
690
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
691
+ else:
692
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
656
693
 
657
694
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
658
695
  # seen during the fit.
@@ -664,12 +701,14 @@ class LassoCV(BaseTransformer):
664
701
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
665
702
  if self.sample_weight_col:
666
703
  output_df_columns_set -= set(self.sample_weight_col)
704
+
667
705
  # if the dimension of inferred output column names is correct; use it
668
706
  if len(expected_output_cols_list) == len(output_df_columns_set):
669
- return expected_output_cols_list
707
+ return expected_output_cols_list, output_df_pd
670
708
  # otherwise, use the sklearn estimator's output
671
709
  else:
672
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
710
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
711
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
673
712
 
674
713
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
675
714
  @telemetry.send_api_usage_telemetry(
@@ -715,7 +754,7 @@ class LassoCV(BaseTransformer):
715
754
  drop_input_cols=self._drop_input_cols,
716
755
  expected_output_cols_type="float",
717
756
  )
718
- expected_output_cols = self._align_expected_output_names(
757
+ expected_output_cols, _ = self._align_expected_output(
719
758
  inference_method, dataset, expected_output_cols, output_cols_prefix
720
759
  )
721
760
 
@@ -781,7 +820,7 @@ class LassoCV(BaseTransformer):
781
820
  drop_input_cols=self._drop_input_cols,
782
821
  expected_output_cols_type="float",
783
822
  )
784
- expected_output_cols = self._align_expected_output_names(
823
+ expected_output_cols, _ = self._align_expected_output(
785
824
  inference_method, dataset, expected_output_cols, output_cols_prefix
786
825
  )
787
826
  elif isinstance(dataset, pd.DataFrame):
@@ -844,7 +883,7 @@ class LassoCV(BaseTransformer):
844
883
  drop_input_cols=self._drop_input_cols,
845
884
  expected_output_cols_type="float",
846
885
  )
847
- expected_output_cols = self._align_expected_output_names(
886
+ expected_output_cols, _ = self._align_expected_output(
848
887
  inference_method, dataset, expected_output_cols, output_cols_prefix
849
888
  )
850
889
 
@@ -909,7 +948,7 @@ class LassoCV(BaseTransformer):
909
948
  drop_input_cols = self._drop_input_cols,
910
949
  expected_output_cols_type="float",
911
950
  )
912
- expected_output_cols = self._align_expected_output_names(
951
+ expected_output_cols, _ = self._align_expected_output(
913
952
  inference_method, dataset, expected_output_cols, output_cols_prefix
914
953
  )
915
954