snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff shows the changes between two publicly available versions of a package that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (252)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -541,12 +538,23 @@ class ElasticNet(BaseTransformer):
541
538
  autogenerated=self._autogenerated,
542
539
  subproject=_SUBPROJECT,
543
540
  )
544
- output_result, fitted_estimator = model_trainer.train_fit_predict(
545
- drop_input_cols=self._drop_input_cols,
546
- expected_output_cols_list=(
547
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
548
- ),
541
+ expected_output_cols = (
542
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
549
543
  )
544
+ if isinstance(dataset, DataFrame):
545
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
546
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
547
+ )
548
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
549
+ drop_input_cols=self._drop_input_cols,
550
+ expected_output_cols_list=expected_output_cols,
551
+ example_output_pd_df=example_output_pd_df,
552
+ )
553
+ else:
554
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
555
+ drop_input_cols=self._drop_input_cols,
556
+ expected_output_cols_list=expected_output_cols,
557
+ )
550
558
  self._sklearn_object = fitted_estimator
551
559
  self._is_fitted = True
552
560
  return output_result
@@ -625,12 +633,41 @@ class ElasticNet(BaseTransformer):
625
633
 
626
634
  return rv
627
635
 
628
- def _align_expected_output_names(
629
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
630
- ) -> List[str]:
636
+ def _align_expected_output(
637
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
638
+ ) -> Tuple[List[str], pd.DataFrame]:
639
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
640
+ and output dataframe with 1 line.
641
+ If the method is fit_predict, run 2 lines of data.
642
+ """
631
643
  # in case the inferred output column names dimension is different
632
644
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
633
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
645
+
646
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
647
+ # so change the minimum of number of rows to 2
648
+ num_examples = 2
649
+ statement_params = telemetry.get_function_usage_statement_params(
650
+ project=_PROJECT,
651
+ subproject=_SUBPROJECT,
652
+ function_name=telemetry.get_statement_params_full_func_name(
653
+ inspect.currentframe(), ElasticNet.__class__.__name__
654
+ ),
655
+ api_calls=[Session.call],
656
+ custom_tags={"autogen": True} if self._autogenerated else None,
657
+ )
658
+ if output_cols_prefix == "fit_predict_":
659
+ if hasattr(self._sklearn_object, "n_clusters"):
660
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
661
+ num_examples = self._sklearn_object.n_clusters
662
+ elif hasattr(self._sklearn_object, "min_samples"):
663
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
664
+ num_examples = self._sklearn_object.min_samples
665
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
666
+ # LocalOutlierFactor expects n_neighbors <= n_samples
667
+ num_examples = self._sklearn_object.n_neighbors
668
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
669
+ else:
670
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
634
671
 
635
672
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
636
673
  # seen during the fit.
@@ -642,12 +679,14 @@ class ElasticNet(BaseTransformer):
642
679
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
643
680
  if self.sample_weight_col:
644
681
  output_df_columns_set -= set(self.sample_weight_col)
682
+
645
683
  # if the dimension of inferred output column names is correct; use it
646
684
  if len(expected_output_cols_list) == len(output_df_columns_set):
647
- return expected_output_cols_list
685
+ return expected_output_cols_list, output_df_pd
648
686
  # otherwise, use the sklearn estimator's output
649
687
  else:
650
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
688
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
689
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
651
690
 
652
691
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
653
692
  @telemetry.send_api_usage_telemetry(
@@ -693,7 +732,7 @@ class ElasticNet(BaseTransformer):
693
732
  drop_input_cols=self._drop_input_cols,
694
733
  expected_output_cols_type="float",
695
734
  )
696
- expected_output_cols = self._align_expected_output_names(
735
+ expected_output_cols, _ = self._align_expected_output(
697
736
  inference_method, dataset, expected_output_cols, output_cols_prefix
698
737
  )
699
738
 
@@ -759,7 +798,7 @@ class ElasticNet(BaseTransformer):
759
798
  drop_input_cols=self._drop_input_cols,
760
799
  expected_output_cols_type="float",
761
800
  )
762
- expected_output_cols = self._align_expected_output_names(
801
+ expected_output_cols, _ = self._align_expected_output(
763
802
  inference_method, dataset, expected_output_cols, output_cols_prefix
764
803
  )
765
804
  elif isinstance(dataset, pd.DataFrame):
@@ -822,7 +861,7 @@ class ElasticNet(BaseTransformer):
822
861
  drop_input_cols=self._drop_input_cols,
823
862
  expected_output_cols_type="float",
824
863
  )
825
- expected_output_cols = self._align_expected_output_names(
864
+ expected_output_cols, _ = self._align_expected_output(
826
865
  inference_method, dataset, expected_output_cols, output_cols_prefix
827
866
  )
828
867
 
@@ -887,7 +926,7 @@ class ElasticNet(BaseTransformer):
887
926
  drop_input_cols = self._drop_input_cols,
888
927
  expected_output_cols_type="float",
889
928
  )
890
- expected_output_cols = self._align_expected_output_names(
929
+ expected_output_cols, _ = self._align_expected_output(
891
930
  inference_method, dataset, expected_output_cols, output_cols_prefix
892
931
  )
893
932
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -577,12 +574,23 @@ class ElasticNetCV(BaseTransformer):
577
574
  autogenerated=self._autogenerated,
578
575
  subproject=_SUBPROJECT,
579
576
  )
580
- output_result, fitted_estimator = model_trainer.train_fit_predict(
581
- drop_input_cols=self._drop_input_cols,
582
- expected_output_cols_list=(
583
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
584
- ),
577
+ expected_output_cols = (
578
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
585
579
  )
580
+ if isinstance(dataset, DataFrame):
581
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
582
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
583
+ )
584
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
585
+ drop_input_cols=self._drop_input_cols,
586
+ expected_output_cols_list=expected_output_cols,
587
+ example_output_pd_df=example_output_pd_df,
588
+ )
589
+ else:
590
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
591
+ drop_input_cols=self._drop_input_cols,
592
+ expected_output_cols_list=expected_output_cols,
593
+ )
586
594
  self._sklearn_object = fitted_estimator
587
595
  self._is_fitted = True
588
596
  return output_result
@@ -661,12 +669,41 @@ class ElasticNetCV(BaseTransformer):
661
669
 
662
670
  return rv
663
671
 
664
- def _align_expected_output_names(
665
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
666
- ) -> List[str]:
672
+ def _align_expected_output(
673
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
674
+ ) -> Tuple[List[str], pd.DataFrame]:
675
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
676
+ and output dataframe with 1 line.
677
+ If the method is fit_predict, run 2 lines of data.
678
+ """
667
679
  # in case the inferred output column names dimension is different
668
680
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
669
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
681
+
682
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
683
+ # so change the minimum of number of rows to 2
684
+ num_examples = 2
685
+ statement_params = telemetry.get_function_usage_statement_params(
686
+ project=_PROJECT,
687
+ subproject=_SUBPROJECT,
688
+ function_name=telemetry.get_statement_params_full_func_name(
689
+ inspect.currentframe(), ElasticNetCV.__class__.__name__
690
+ ),
691
+ api_calls=[Session.call],
692
+ custom_tags={"autogen": True} if self._autogenerated else None,
693
+ )
694
+ if output_cols_prefix == "fit_predict_":
695
+ if hasattr(self._sklearn_object, "n_clusters"):
696
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
697
+ num_examples = self._sklearn_object.n_clusters
698
+ elif hasattr(self._sklearn_object, "min_samples"):
699
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
700
+ num_examples = self._sklearn_object.min_samples
701
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
702
+ # LocalOutlierFactor expects n_neighbors <= n_samples
703
+ num_examples = self._sklearn_object.n_neighbors
704
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
705
+ else:
706
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
670
707
 
671
708
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
672
709
  # seen during the fit.
@@ -678,12 +715,14 @@ class ElasticNetCV(BaseTransformer):
678
715
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
679
716
  if self.sample_weight_col:
680
717
  output_df_columns_set -= set(self.sample_weight_col)
718
+
681
719
  # if the dimension of inferred output column names is correct; use it
682
720
  if len(expected_output_cols_list) == len(output_df_columns_set):
683
- return expected_output_cols_list
721
+ return expected_output_cols_list, output_df_pd
684
722
  # otherwise, use the sklearn estimator's output
685
723
  else:
686
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
724
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
725
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
687
726
 
688
727
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
689
728
  @telemetry.send_api_usage_telemetry(
@@ -729,7 +768,7 @@ class ElasticNetCV(BaseTransformer):
729
768
  drop_input_cols=self._drop_input_cols,
730
769
  expected_output_cols_type="float",
731
770
  )
732
- expected_output_cols = self._align_expected_output_names(
771
+ expected_output_cols, _ = self._align_expected_output(
733
772
  inference_method, dataset, expected_output_cols, output_cols_prefix
734
773
  )
735
774
 
@@ -795,7 +834,7 @@ class ElasticNetCV(BaseTransformer):
795
834
  drop_input_cols=self._drop_input_cols,
796
835
  expected_output_cols_type="float",
797
836
  )
798
- expected_output_cols = self._align_expected_output_names(
837
+ expected_output_cols, _ = self._align_expected_output(
799
838
  inference_method, dataset, expected_output_cols, output_cols_prefix
800
839
  )
801
840
  elif isinstance(dataset, pd.DataFrame):
@@ -858,7 +897,7 @@ class ElasticNetCV(BaseTransformer):
858
897
  drop_input_cols=self._drop_input_cols,
859
898
  expected_output_cols_type="float",
860
899
  )
861
- expected_output_cols = self._align_expected_output_names(
900
+ expected_output_cols, _ = self._align_expected_output(
862
901
  inference_method, dataset, expected_output_cols, output_cols_prefix
863
902
  )
864
903
 
@@ -923,7 +962,7 @@ class ElasticNetCV(BaseTransformer):
923
962
  drop_input_cols = self._drop_input_cols,
924
963
  expected_output_cols_type="float",
925
964
  )
926
- expected_output_cols = self._align_expected_output_names(
965
+ expected_output_cols, _ = self._align_expected_output(
927
966
  inference_method, dataset, expected_output_cols, output_cols_prefix
928
967
  )
929
968
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -522,12 +519,23 @@ class GammaRegressor(BaseTransformer):
522
519
  autogenerated=self._autogenerated,
523
520
  subproject=_SUBPROJECT,
524
521
  )
525
- output_result, fitted_estimator = model_trainer.train_fit_predict(
526
- drop_input_cols=self._drop_input_cols,
527
- expected_output_cols_list=(
528
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
529
- ),
522
+ expected_output_cols = (
523
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
530
524
  )
525
+ if isinstance(dataset, DataFrame):
526
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
527
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
528
+ )
529
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
530
+ drop_input_cols=self._drop_input_cols,
531
+ expected_output_cols_list=expected_output_cols,
532
+ example_output_pd_df=example_output_pd_df,
533
+ )
534
+ else:
535
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
536
+ drop_input_cols=self._drop_input_cols,
537
+ expected_output_cols_list=expected_output_cols,
538
+ )
531
539
  self._sklearn_object = fitted_estimator
532
540
  self._is_fitted = True
533
541
  return output_result
@@ -606,12 +614,41 @@ class GammaRegressor(BaseTransformer):
606
614
 
607
615
  return rv
608
616
 
609
- def _align_expected_output_names(
610
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
611
- ) -> List[str]:
617
+ def _align_expected_output(
618
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
619
+ ) -> Tuple[List[str], pd.DataFrame]:
620
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
621
+ and output dataframe with 1 line.
622
+ If the method is fit_predict, run 2 lines of data.
623
+ """
612
624
  # in case the inferred output column names dimension is different
613
625
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
614
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
626
+
627
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
628
+ # so change the minimum of number of rows to 2
629
+ num_examples = 2
630
+ statement_params = telemetry.get_function_usage_statement_params(
631
+ project=_PROJECT,
632
+ subproject=_SUBPROJECT,
633
+ function_name=telemetry.get_statement_params_full_func_name(
634
+ inspect.currentframe(), GammaRegressor.__class__.__name__
635
+ ),
636
+ api_calls=[Session.call],
637
+ custom_tags={"autogen": True} if self._autogenerated else None,
638
+ )
639
+ if output_cols_prefix == "fit_predict_":
640
+ if hasattr(self._sklearn_object, "n_clusters"):
641
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
642
+ num_examples = self._sklearn_object.n_clusters
643
+ elif hasattr(self._sklearn_object, "min_samples"):
644
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
645
+ num_examples = self._sklearn_object.min_samples
646
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
647
+ # LocalOutlierFactor expects n_neighbors <= n_samples
648
+ num_examples = self._sklearn_object.n_neighbors
649
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
650
+ else:
651
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
615
652
 
616
653
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
617
654
  # seen during the fit.
@@ -623,12 +660,14 @@ class GammaRegressor(BaseTransformer):
623
660
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
624
661
  if self.sample_weight_col:
625
662
  output_df_columns_set -= set(self.sample_weight_col)
663
+
626
664
  # if the dimension of inferred output column names is correct; use it
627
665
  if len(expected_output_cols_list) == len(output_df_columns_set):
628
- return expected_output_cols_list
666
+ return expected_output_cols_list, output_df_pd
629
667
  # otherwise, use the sklearn estimator's output
630
668
  else:
631
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
669
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
670
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
632
671
 
633
672
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
634
673
  @telemetry.send_api_usage_telemetry(
@@ -674,7 +713,7 @@ class GammaRegressor(BaseTransformer):
674
713
  drop_input_cols=self._drop_input_cols,
675
714
  expected_output_cols_type="float",
676
715
  )
677
- expected_output_cols = self._align_expected_output_names(
716
+ expected_output_cols, _ = self._align_expected_output(
678
717
  inference_method, dataset, expected_output_cols, output_cols_prefix
679
718
  )
680
719
 
@@ -740,7 +779,7 @@ class GammaRegressor(BaseTransformer):
740
779
  drop_input_cols=self._drop_input_cols,
741
780
  expected_output_cols_type="float",
742
781
  )
743
- expected_output_cols = self._align_expected_output_names(
782
+ expected_output_cols, _ = self._align_expected_output(
744
783
  inference_method, dataset, expected_output_cols, output_cols_prefix
745
784
  )
746
785
  elif isinstance(dataset, pd.DataFrame):
@@ -803,7 +842,7 @@ class GammaRegressor(BaseTransformer):
803
842
  drop_input_cols=self._drop_input_cols,
804
843
  expected_output_cols_type="float",
805
844
  )
806
- expected_output_cols = self._align_expected_output_names(
845
+ expected_output_cols, _ = self._align_expected_output(
807
846
  inference_method, dataset, expected_output_cols, output_cols_prefix
808
847
  )
809
848
 
@@ -868,7 +907,7 @@ class GammaRegressor(BaseTransformer):
868
907
  drop_input_cols = self._drop_input_cols,
869
908
  expected_output_cols_type="float",
870
909
  )
871
- expected_output_cols = self._align_expected_output_names(
910
+ expected_output_cols, _ = self._align_expected_output(
872
911
  inference_method, dataset, expected_output_cols, output_cols_prefix
873
912
  )
874
913
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -505,12 +502,23 @@ class HuberRegressor(BaseTransformer):
505
502
  autogenerated=self._autogenerated,
506
503
  subproject=_SUBPROJECT,
507
504
  )
508
- output_result, fitted_estimator = model_trainer.train_fit_predict(
509
- drop_input_cols=self._drop_input_cols,
510
- expected_output_cols_list=(
511
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
512
- ),
505
+ expected_output_cols = (
506
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
513
507
  )
508
+ if isinstance(dataset, DataFrame):
509
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
510
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
511
+ )
512
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
513
+ drop_input_cols=self._drop_input_cols,
514
+ expected_output_cols_list=expected_output_cols,
515
+ example_output_pd_df=example_output_pd_df,
516
+ )
517
+ else:
518
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
519
+ drop_input_cols=self._drop_input_cols,
520
+ expected_output_cols_list=expected_output_cols,
521
+ )
514
522
  self._sklearn_object = fitted_estimator
515
523
  self._is_fitted = True
516
524
  return output_result
@@ -589,12 +597,41 @@ class HuberRegressor(BaseTransformer):
589
597
 
590
598
  return rv
591
599
 
592
- def _align_expected_output_names(
593
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
594
- ) -> List[str]:
600
+ def _align_expected_output(
601
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
602
+ ) -> Tuple[List[str], pd.DataFrame]:
603
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
604
+ and output dataframe with 1 line.
605
+ If the method is fit_predict, run 2 lines of data.
606
+ """
595
607
  # in case the inferred output column names dimension is different
596
608
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
597
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
609
+
610
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
611
+ # so change the minimum of number of rows to 2
612
+ num_examples = 2
613
+ statement_params = telemetry.get_function_usage_statement_params(
614
+ project=_PROJECT,
615
+ subproject=_SUBPROJECT,
616
+ function_name=telemetry.get_statement_params_full_func_name(
617
+ inspect.currentframe(), HuberRegressor.__class__.__name__
618
+ ),
619
+ api_calls=[Session.call],
620
+ custom_tags={"autogen": True} if self._autogenerated else None,
621
+ )
622
+ if output_cols_prefix == "fit_predict_":
623
+ if hasattr(self._sklearn_object, "n_clusters"):
624
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
625
+ num_examples = self._sklearn_object.n_clusters
626
+ elif hasattr(self._sklearn_object, "min_samples"):
627
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
628
+ num_examples = self._sklearn_object.min_samples
629
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
630
+ # LocalOutlierFactor expects n_neighbors <= n_samples
631
+ num_examples = self._sklearn_object.n_neighbors
632
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
633
+ else:
634
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
598
635
 
599
636
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
600
637
  # seen during the fit.
@@ -606,12 +643,14 @@ class HuberRegressor(BaseTransformer):
606
643
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
607
644
  if self.sample_weight_col:
608
645
  output_df_columns_set -= set(self.sample_weight_col)
646
+
609
647
  # if the dimension of inferred output column names is correct; use it
610
648
  if len(expected_output_cols_list) == len(output_df_columns_set):
611
- return expected_output_cols_list
649
+ return expected_output_cols_list, output_df_pd
612
650
  # otherwise, use the sklearn estimator's output
613
651
  else:
614
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
652
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
653
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
615
654
 
616
655
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
617
656
  @telemetry.send_api_usage_telemetry(
@@ -657,7 +696,7 @@ class HuberRegressor(BaseTransformer):
657
696
  drop_input_cols=self._drop_input_cols,
658
697
  expected_output_cols_type="float",
659
698
  )
660
- expected_output_cols = self._align_expected_output_names(
699
+ expected_output_cols, _ = self._align_expected_output(
661
700
  inference_method, dataset, expected_output_cols, output_cols_prefix
662
701
  )
663
702
 
@@ -723,7 +762,7 @@ class HuberRegressor(BaseTransformer):
723
762
  drop_input_cols=self._drop_input_cols,
724
763
  expected_output_cols_type="float",
725
764
  )
726
- expected_output_cols = self._align_expected_output_names(
765
+ expected_output_cols, _ = self._align_expected_output(
727
766
  inference_method, dataset, expected_output_cols, output_cols_prefix
728
767
  )
729
768
  elif isinstance(dataset, pd.DataFrame):
@@ -786,7 +825,7 @@ class HuberRegressor(BaseTransformer):
786
825
  drop_input_cols=self._drop_input_cols,
787
826
  expected_output_cols_type="float",
788
827
  )
789
- expected_output_cols = self._align_expected_output_names(
828
+ expected_output_cols, _ = self._align_expected_output(
790
829
  inference_method, dataset, expected_output_cols, output_cols_prefix
791
830
  )
792
831
 
@@ -851,7 +890,7 @@ class HuberRegressor(BaseTransformer):
851
890
  drop_input_cols = self._drop_input_cols,
852
891
  expected_output_cols_type="float",
853
892
  )
854
- expected_output_cols = self._align_expected_output_names(
893
+ expected_output_cols, _ = self._align_expected_output(
855
894
  inference_method, dataset, expected_output_cols, output_cols_prefix
856
895
  )
857
896