snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252) hide show
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -555,12 +552,23 @@ class LassoLars(BaseTransformer):
555
552
  autogenerated=self._autogenerated,
556
553
  subproject=_SUBPROJECT,
557
554
  )
558
- output_result, fitted_estimator = model_trainer.train_fit_predict(
559
- drop_input_cols=self._drop_input_cols,
560
- expected_output_cols_list=(
561
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
562
- ),
555
+ expected_output_cols = (
556
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
563
557
  )
558
+ if isinstance(dataset, DataFrame):
559
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
560
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
561
+ )
562
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
563
+ drop_input_cols=self._drop_input_cols,
564
+ expected_output_cols_list=expected_output_cols,
565
+ example_output_pd_df=example_output_pd_df,
566
+ )
567
+ else:
568
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
569
+ drop_input_cols=self._drop_input_cols,
570
+ expected_output_cols_list=expected_output_cols,
571
+ )
564
572
  self._sklearn_object = fitted_estimator
565
573
  self._is_fitted = True
566
574
  return output_result
@@ -639,12 +647,41 @@ class LassoLars(BaseTransformer):
639
647
 
640
648
  return rv
641
649
 
642
- def _align_expected_output_names(
643
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
644
- ) -> List[str]:
650
+ def _align_expected_output(
651
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
652
+ ) -> Tuple[List[str], pd.DataFrame]:
653
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
654
+ and output dataframe with 1 line.
655
+ If the method is fit_predict, run 2 lines of data.
656
+ """
645
657
  # in case the inferred output column names dimension is different
646
658
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
647
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
659
+
660
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
661
+ # so change the minimum of number of rows to 2
662
+ num_examples = 2
663
+ statement_params = telemetry.get_function_usage_statement_params(
664
+ project=_PROJECT,
665
+ subproject=_SUBPROJECT,
666
+ function_name=telemetry.get_statement_params_full_func_name(
667
+ inspect.currentframe(), LassoLars.__class__.__name__
668
+ ),
669
+ api_calls=[Session.call],
670
+ custom_tags={"autogen": True} if self._autogenerated else None,
671
+ )
672
+ if output_cols_prefix == "fit_predict_":
673
+ if hasattr(self._sklearn_object, "n_clusters"):
674
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
675
+ num_examples = self._sklearn_object.n_clusters
676
+ elif hasattr(self._sklearn_object, "min_samples"):
677
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
678
+ num_examples = self._sklearn_object.min_samples
679
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
680
+ # LocalOutlierFactor expects n_neighbors <= n_samples
681
+ num_examples = self._sklearn_object.n_neighbors
682
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
683
+ else:
684
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
648
685
 
649
686
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
650
687
  # seen during the fit.
@@ -656,12 +693,14 @@ class LassoLars(BaseTransformer):
656
693
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
657
694
  if self.sample_weight_col:
658
695
  output_df_columns_set -= set(self.sample_weight_col)
696
+
659
697
  # if the dimension of inferred output column names is correct; use it
660
698
  if len(expected_output_cols_list) == len(output_df_columns_set):
661
- return expected_output_cols_list
699
+ return expected_output_cols_list, output_df_pd
662
700
  # otherwise, use the sklearn estimator's output
663
701
  else:
664
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
702
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
703
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
665
704
 
666
705
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
667
706
  @telemetry.send_api_usage_telemetry(
@@ -707,7 +746,7 @@ class LassoLars(BaseTransformer):
707
746
  drop_input_cols=self._drop_input_cols,
708
747
  expected_output_cols_type="float",
709
748
  )
710
- expected_output_cols = self._align_expected_output_names(
749
+ expected_output_cols, _ = self._align_expected_output(
711
750
  inference_method, dataset, expected_output_cols, output_cols_prefix
712
751
  )
713
752
 
@@ -773,7 +812,7 @@ class LassoLars(BaseTransformer):
773
812
  drop_input_cols=self._drop_input_cols,
774
813
  expected_output_cols_type="float",
775
814
  )
776
- expected_output_cols = self._align_expected_output_names(
815
+ expected_output_cols, _ = self._align_expected_output(
777
816
  inference_method, dataset, expected_output_cols, output_cols_prefix
778
817
  )
779
818
  elif isinstance(dataset, pd.DataFrame):
@@ -836,7 +875,7 @@ class LassoLars(BaseTransformer):
836
875
  drop_input_cols=self._drop_input_cols,
837
876
  expected_output_cols_type="float",
838
877
  )
839
- expected_output_cols = self._align_expected_output_names(
878
+ expected_output_cols, _ = self._align_expected_output(
840
879
  inference_method, dataset, expected_output_cols, output_cols_prefix
841
880
  )
842
881
 
@@ -901,7 +940,7 @@ class LassoLars(BaseTransformer):
901
940
  drop_input_cols = self._drop_input_cols,
902
941
  expected_output_cols_type="float",
903
942
  )
904
- expected_output_cols = self._align_expected_output_names(
943
+ expected_output_cols, _ = self._align_expected_output(
905
944
  inference_method, dataset, expected_output_cols, output_cols_prefix
906
945
  )
907
946
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -556,12 +553,23 @@ class LassoLarsCV(BaseTransformer):
556
553
  autogenerated=self._autogenerated,
557
554
  subproject=_SUBPROJECT,
558
555
  )
559
- output_result, fitted_estimator = model_trainer.train_fit_predict(
560
- drop_input_cols=self._drop_input_cols,
561
- expected_output_cols_list=(
562
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
563
- ),
556
+ expected_output_cols = (
557
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
564
558
  )
559
+ if isinstance(dataset, DataFrame):
560
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
561
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
562
+ )
563
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
564
+ drop_input_cols=self._drop_input_cols,
565
+ expected_output_cols_list=expected_output_cols,
566
+ example_output_pd_df=example_output_pd_df,
567
+ )
568
+ else:
569
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
570
+ drop_input_cols=self._drop_input_cols,
571
+ expected_output_cols_list=expected_output_cols,
572
+ )
565
573
  self._sklearn_object = fitted_estimator
566
574
  self._is_fitted = True
567
575
  return output_result
@@ -640,12 +648,41 @@ class LassoLarsCV(BaseTransformer):
640
648
 
641
649
  return rv
642
650
 
643
- def _align_expected_output_names(
644
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
645
- ) -> List[str]:
651
+ def _align_expected_output(
652
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
653
+ ) -> Tuple[List[str], pd.DataFrame]:
654
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
655
+ and output dataframe with 1 line.
656
+ If the method is fit_predict, run 2 lines of data.
657
+ """
646
658
  # in case the inferred output column names dimension is different
647
659
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
648
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
660
+
661
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
662
+ # so change the minimum of number of rows to 2
663
+ num_examples = 2
664
+ statement_params = telemetry.get_function_usage_statement_params(
665
+ project=_PROJECT,
666
+ subproject=_SUBPROJECT,
667
+ function_name=telemetry.get_statement_params_full_func_name(
668
+ inspect.currentframe(), LassoLarsCV.__class__.__name__
669
+ ),
670
+ api_calls=[Session.call],
671
+ custom_tags={"autogen": True} if self._autogenerated else None,
672
+ )
673
+ if output_cols_prefix == "fit_predict_":
674
+ if hasattr(self._sklearn_object, "n_clusters"):
675
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
676
+ num_examples = self._sklearn_object.n_clusters
677
+ elif hasattr(self._sklearn_object, "min_samples"):
678
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
679
+ num_examples = self._sklearn_object.min_samples
680
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
681
+ # LocalOutlierFactor expects n_neighbors <= n_samples
682
+ num_examples = self._sklearn_object.n_neighbors
683
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
684
+ else:
685
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
649
686
 
650
687
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
651
688
  # seen during the fit.
@@ -657,12 +694,14 @@ class LassoLarsCV(BaseTransformer):
657
694
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
658
695
  if self.sample_weight_col:
659
696
  output_df_columns_set -= set(self.sample_weight_col)
697
+
660
698
  # if the dimension of inferred output column names is correct; use it
661
699
  if len(expected_output_cols_list) == len(output_df_columns_set):
662
- return expected_output_cols_list
700
+ return expected_output_cols_list, output_df_pd
663
701
  # otherwise, use the sklearn estimator's output
664
702
  else:
665
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
703
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
704
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
666
705
 
667
706
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
668
707
  @telemetry.send_api_usage_telemetry(
@@ -708,7 +747,7 @@ class LassoLarsCV(BaseTransformer):
708
747
  drop_input_cols=self._drop_input_cols,
709
748
  expected_output_cols_type="float",
710
749
  )
711
- expected_output_cols = self._align_expected_output_names(
750
+ expected_output_cols, _ = self._align_expected_output(
712
751
  inference_method, dataset, expected_output_cols, output_cols_prefix
713
752
  )
714
753
 
@@ -774,7 +813,7 @@ class LassoLarsCV(BaseTransformer):
774
813
  drop_input_cols=self._drop_input_cols,
775
814
  expected_output_cols_type="float",
776
815
  )
777
- expected_output_cols = self._align_expected_output_names(
816
+ expected_output_cols, _ = self._align_expected_output(
778
817
  inference_method, dataset, expected_output_cols, output_cols_prefix
779
818
  )
780
819
  elif isinstance(dataset, pd.DataFrame):
@@ -837,7 +876,7 @@ class LassoLarsCV(BaseTransformer):
837
876
  drop_input_cols=self._drop_input_cols,
838
877
  expected_output_cols_type="float",
839
878
  )
840
- expected_output_cols = self._align_expected_output_names(
879
+ expected_output_cols, _ = self._align_expected_output(
841
880
  inference_method, dataset, expected_output_cols, output_cols_prefix
842
881
  )
843
882
 
@@ -902,7 +941,7 @@ class LassoLarsCV(BaseTransformer):
902
941
  drop_input_cols = self._drop_input_cols,
903
942
  expected_output_cols_type="float",
904
943
  )
905
- expected_output_cols = self._align_expected_output_names(
944
+ expected_output_cols, _ = self._align_expected_output(
906
945
  inference_method, dataset, expected_output_cols, output_cols_prefix
907
946
  )
908
947
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -539,12 +536,23 @@ class LassoLarsIC(BaseTransformer):
539
536
  autogenerated=self._autogenerated,
540
537
  subproject=_SUBPROJECT,
541
538
  )
542
- output_result, fitted_estimator = model_trainer.train_fit_predict(
543
- drop_input_cols=self._drop_input_cols,
544
- expected_output_cols_list=(
545
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
546
- ),
539
+ expected_output_cols = (
540
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
547
541
  )
542
+ if isinstance(dataset, DataFrame):
543
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
544
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
545
+ )
546
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
547
+ drop_input_cols=self._drop_input_cols,
548
+ expected_output_cols_list=expected_output_cols,
549
+ example_output_pd_df=example_output_pd_df,
550
+ )
551
+ else:
552
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
553
+ drop_input_cols=self._drop_input_cols,
554
+ expected_output_cols_list=expected_output_cols,
555
+ )
548
556
  self._sklearn_object = fitted_estimator
549
557
  self._is_fitted = True
550
558
  return output_result
@@ -623,12 +631,41 @@ class LassoLarsIC(BaseTransformer):
623
631
 
624
632
  return rv
625
633
 
626
- def _align_expected_output_names(
627
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
628
- ) -> List[str]:
634
+ def _align_expected_output(
635
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
636
+ ) -> Tuple[List[str], pd.DataFrame]:
637
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
638
+ and output dataframe with 1 line.
639
+ If the method is fit_predict, run 2 lines of data.
640
+ """
629
641
  # in case the inferred output column names dimension is different
630
642
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
631
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
643
+
644
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
645
+ # so change the minimum of number of rows to 2
646
+ num_examples = 2
647
+ statement_params = telemetry.get_function_usage_statement_params(
648
+ project=_PROJECT,
649
+ subproject=_SUBPROJECT,
650
+ function_name=telemetry.get_statement_params_full_func_name(
651
+ inspect.currentframe(), LassoLarsIC.__class__.__name__
652
+ ),
653
+ api_calls=[Session.call],
654
+ custom_tags={"autogen": True} if self._autogenerated else None,
655
+ )
656
+ if output_cols_prefix == "fit_predict_":
657
+ if hasattr(self._sklearn_object, "n_clusters"):
658
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
659
+ num_examples = self._sklearn_object.n_clusters
660
+ elif hasattr(self._sklearn_object, "min_samples"):
661
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
662
+ num_examples = self._sklearn_object.min_samples
663
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
664
+ # LocalOutlierFactor expects n_neighbors <= n_samples
665
+ num_examples = self._sklearn_object.n_neighbors
666
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
667
+ else:
668
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
632
669
 
633
670
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
634
671
  # seen during the fit.
@@ -640,12 +677,14 @@ class LassoLarsIC(BaseTransformer):
640
677
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
641
678
  if self.sample_weight_col:
642
679
  output_df_columns_set -= set(self.sample_weight_col)
680
+
643
681
  # if the dimension of inferred output column names is correct; use it
644
682
  if len(expected_output_cols_list) == len(output_df_columns_set):
645
- return expected_output_cols_list
683
+ return expected_output_cols_list, output_df_pd
646
684
  # otherwise, use the sklearn estimator's output
647
685
  else:
648
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
686
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
687
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
649
688
 
650
689
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
651
690
  @telemetry.send_api_usage_telemetry(
@@ -691,7 +730,7 @@ class LassoLarsIC(BaseTransformer):
691
730
  drop_input_cols=self._drop_input_cols,
692
731
  expected_output_cols_type="float",
693
732
  )
694
- expected_output_cols = self._align_expected_output_names(
733
+ expected_output_cols, _ = self._align_expected_output(
695
734
  inference_method, dataset, expected_output_cols, output_cols_prefix
696
735
  )
697
736
 
@@ -757,7 +796,7 @@ class LassoLarsIC(BaseTransformer):
757
796
  drop_input_cols=self._drop_input_cols,
758
797
  expected_output_cols_type="float",
759
798
  )
760
- expected_output_cols = self._align_expected_output_names(
799
+ expected_output_cols, _ = self._align_expected_output(
761
800
  inference_method, dataset, expected_output_cols, output_cols_prefix
762
801
  )
763
802
  elif isinstance(dataset, pd.DataFrame):
@@ -820,7 +859,7 @@ class LassoLarsIC(BaseTransformer):
820
859
  drop_input_cols=self._drop_input_cols,
821
860
  expected_output_cols_type="float",
822
861
  )
823
- expected_output_cols = self._align_expected_output_names(
862
+ expected_output_cols, _ = self._align_expected_output(
824
863
  inference_method, dataset, expected_output_cols, output_cols_prefix
825
864
  )
826
865
 
@@ -885,7 +924,7 @@ class LassoLarsIC(BaseTransformer):
885
924
  drop_input_cols = self._drop_input_cols,
886
925
  expected_output_cols_type="float",
887
926
  )
888
- expected_output_cols = self._align_expected_output_names(
927
+ expected_output_cols, _ = self._align_expected_output(
889
928
  inference_method, dataset, expected_output_cols, output_cols_prefix
890
929
  )
891
930
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -492,12 +489,23 @@ class LinearRegression(BaseTransformer):
492
489
  autogenerated=self._autogenerated,
493
490
  subproject=_SUBPROJECT,
494
491
  )
495
- output_result, fitted_estimator = model_trainer.train_fit_predict(
496
- drop_input_cols=self._drop_input_cols,
497
- expected_output_cols_list=(
498
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
499
- ),
492
+ expected_output_cols = (
493
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
500
494
  )
495
+ if isinstance(dataset, DataFrame):
496
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
497
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
498
+ )
499
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
500
+ drop_input_cols=self._drop_input_cols,
501
+ expected_output_cols_list=expected_output_cols,
502
+ example_output_pd_df=example_output_pd_df,
503
+ )
504
+ else:
505
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
506
+ drop_input_cols=self._drop_input_cols,
507
+ expected_output_cols_list=expected_output_cols,
508
+ )
501
509
  self._sklearn_object = fitted_estimator
502
510
  self._is_fitted = True
503
511
  return output_result
@@ -576,12 +584,41 @@ class LinearRegression(BaseTransformer):
576
584
 
577
585
  return rv
578
586
 
579
- def _align_expected_output_names(
580
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
581
- ) -> List[str]:
587
+ def _align_expected_output(
588
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
589
+ ) -> Tuple[List[str], pd.DataFrame]:
590
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
591
+ and output dataframe with 1 line.
592
+ If the method is fit_predict, run 2 lines of data.
593
+ """
582
594
  # in case the inferred output column names dimension is different
583
595
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
584
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
596
+
597
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
598
+ # so change the minimum of number of rows to 2
599
+ num_examples = 2
600
+ statement_params = telemetry.get_function_usage_statement_params(
601
+ project=_PROJECT,
602
+ subproject=_SUBPROJECT,
603
+ function_name=telemetry.get_statement_params_full_func_name(
604
+ inspect.currentframe(), LinearRegression.__class__.__name__
605
+ ),
606
+ api_calls=[Session.call],
607
+ custom_tags={"autogen": True} if self._autogenerated else None,
608
+ )
609
+ if output_cols_prefix == "fit_predict_":
610
+ if hasattr(self._sklearn_object, "n_clusters"):
611
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
612
+ num_examples = self._sklearn_object.n_clusters
613
+ elif hasattr(self._sklearn_object, "min_samples"):
614
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
615
+ num_examples = self._sklearn_object.min_samples
616
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
617
+ # LocalOutlierFactor expects n_neighbors <= n_samples
618
+ num_examples = self._sklearn_object.n_neighbors
619
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
620
+ else:
621
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
585
622
 
586
623
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
587
624
  # seen during the fit.
@@ -593,12 +630,14 @@ class LinearRegression(BaseTransformer):
593
630
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
594
631
  if self.sample_weight_col:
595
632
  output_df_columns_set -= set(self.sample_weight_col)
633
+
596
634
  # if the dimension of inferred output column names is correct; use it
597
635
  if len(expected_output_cols_list) == len(output_df_columns_set):
598
- return expected_output_cols_list
636
+ return expected_output_cols_list, output_df_pd
599
637
  # otherwise, use the sklearn estimator's output
600
638
  else:
601
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
639
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
640
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
602
641
 
603
642
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
604
643
  @telemetry.send_api_usage_telemetry(
@@ -644,7 +683,7 @@ class LinearRegression(BaseTransformer):
644
683
  drop_input_cols=self._drop_input_cols,
645
684
  expected_output_cols_type="float",
646
685
  )
647
- expected_output_cols = self._align_expected_output_names(
686
+ expected_output_cols, _ = self._align_expected_output(
648
687
  inference_method, dataset, expected_output_cols, output_cols_prefix
649
688
  )
650
689
 
@@ -710,7 +749,7 @@ class LinearRegression(BaseTransformer):
710
749
  drop_input_cols=self._drop_input_cols,
711
750
  expected_output_cols_type="float",
712
751
  )
713
- expected_output_cols = self._align_expected_output_names(
752
+ expected_output_cols, _ = self._align_expected_output(
714
753
  inference_method, dataset, expected_output_cols, output_cols_prefix
715
754
  )
716
755
  elif isinstance(dataset, pd.DataFrame):
@@ -773,7 +812,7 @@ class LinearRegression(BaseTransformer):
773
812
  drop_input_cols=self._drop_input_cols,
774
813
  expected_output_cols_type="float",
775
814
  )
776
- expected_output_cols = self._align_expected_output_names(
815
+ expected_output_cols, _ = self._align_expected_output(
777
816
  inference_method, dataset, expected_output_cols, output_cols_prefix
778
817
  )
779
818
 
@@ -838,7 +877,7 @@ class LinearRegression(BaseTransformer):
838
877
  drop_input_cols = self._drop_input_cols,
839
878
  expected_output_cols_type="float",
840
879
  )
841
- expected_output_cols = self._align_expected_output_names(
880
+ expected_output_cols, _ = self._align_expected_output(
842
881
  inference_method, dataset, expected_output_cols, output_cols_prefix
843
882
  )
844
883