snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252) hide show
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -661,12 +658,23 @@ class SGDClassifier(BaseTransformer):
661
658
  autogenerated=self._autogenerated,
662
659
  subproject=_SUBPROJECT,
663
660
  )
664
- output_result, fitted_estimator = model_trainer.train_fit_predict(
665
- drop_input_cols=self._drop_input_cols,
666
- expected_output_cols_list=(
667
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
668
- ),
661
+ expected_output_cols = (
662
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
669
663
  )
664
+ if isinstance(dataset, DataFrame):
665
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
666
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
667
+ )
668
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
669
+ drop_input_cols=self._drop_input_cols,
670
+ expected_output_cols_list=expected_output_cols,
671
+ example_output_pd_df=example_output_pd_df,
672
+ )
673
+ else:
674
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
675
+ drop_input_cols=self._drop_input_cols,
676
+ expected_output_cols_list=expected_output_cols,
677
+ )
670
678
  self._sklearn_object = fitted_estimator
671
679
  self._is_fitted = True
672
680
  return output_result
@@ -745,12 +753,41 @@ class SGDClassifier(BaseTransformer):
745
753
 
746
754
  return rv
747
755
 
748
- def _align_expected_output_names(
749
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
750
- ) -> List[str]:
756
+ def _align_expected_output(
757
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
758
+ ) -> Tuple[List[str], pd.DataFrame]:
759
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
760
+ and output dataframe with 1 line.
761
+ If the method is fit_predict, run 2 lines of data.
762
+ """
751
763
  # in case the inferred output column names dimension is different
752
764
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
753
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
765
+
766
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
767
+ # so change the minimum of number of rows to 2
768
+ num_examples = 2
769
+ statement_params = telemetry.get_function_usage_statement_params(
770
+ project=_PROJECT,
771
+ subproject=_SUBPROJECT,
772
+ function_name=telemetry.get_statement_params_full_func_name(
773
+ inspect.currentframe(), SGDClassifier.__class__.__name__
774
+ ),
775
+ api_calls=[Session.call],
776
+ custom_tags={"autogen": True} if self._autogenerated else None,
777
+ )
778
+ if output_cols_prefix == "fit_predict_":
779
+ if hasattr(self._sklearn_object, "n_clusters"):
780
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
781
+ num_examples = self._sklearn_object.n_clusters
782
+ elif hasattr(self._sklearn_object, "min_samples"):
783
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
784
+ num_examples = self._sklearn_object.min_samples
785
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
786
+ # LocalOutlierFactor expects n_neighbors <= n_samples
787
+ num_examples = self._sklearn_object.n_neighbors
788
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
789
+ else:
790
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
754
791
 
755
792
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
756
793
  # seen during the fit.
@@ -762,12 +799,14 @@ class SGDClassifier(BaseTransformer):
762
799
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
763
800
  if self.sample_weight_col:
764
801
  output_df_columns_set -= set(self.sample_weight_col)
802
+
765
803
  # if the dimension of inferred output column names is correct; use it
766
804
  if len(expected_output_cols_list) == len(output_df_columns_set):
767
- return expected_output_cols_list
805
+ return expected_output_cols_list, output_df_pd
768
806
  # otherwise, use the sklearn estimator's output
769
807
  else:
770
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
808
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
809
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
771
810
 
772
811
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
773
812
  @telemetry.send_api_usage_telemetry(
@@ -815,7 +854,7 @@ class SGDClassifier(BaseTransformer):
815
854
  drop_input_cols=self._drop_input_cols,
816
855
  expected_output_cols_type="float",
817
856
  )
818
- expected_output_cols = self._align_expected_output_names(
857
+ expected_output_cols, _ = self._align_expected_output(
819
858
  inference_method, dataset, expected_output_cols, output_cols_prefix
820
859
  )
821
860
 
@@ -883,7 +922,7 @@ class SGDClassifier(BaseTransformer):
883
922
  drop_input_cols=self._drop_input_cols,
884
923
  expected_output_cols_type="float",
885
924
  )
886
- expected_output_cols = self._align_expected_output_names(
925
+ expected_output_cols, _ = self._align_expected_output(
887
926
  inference_method, dataset, expected_output_cols, output_cols_prefix
888
927
  )
889
928
  elif isinstance(dataset, pd.DataFrame):
@@ -948,7 +987,7 @@ class SGDClassifier(BaseTransformer):
948
987
  drop_input_cols=self._drop_input_cols,
949
988
  expected_output_cols_type="float",
950
989
  )
951
- expected_output_cols = self._align_expected_output_names(
990
+ expected_output_cols, _ = self._align_expected_output(
952
991
  inference_method, dataset, expected_output_cols, output_cols_prefix
953
992
  )
954
993
 
@@ -1013,7 +1052,7 @@ class SGDClassifier(BaseTransformer):
1013
1052
  drop_input_cols = self._drop_input_cols,
1014
1053
  expected_output_cols_type="float",
1015
1054
  )
1016
- expected_output_cols = self._align_expected_output_names(
1055
+ expected_output_cols, _ = self._align_expected_output(
1017
1056
  inference_method, dataset, expected_output_cols, output_cols_prefix
1018
1057
  )
1019
1058
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -563,12 +560,23 @@ class SGDOneClassSVM(BaseTransformer):
563
560
  autogenerated=self._autogenerated,
564
561
  subproject=_SUBPROJECT,
565
562
  )
566
- output_result, fitted_estimator = model_trainer.train_fit_predict(
567
- drop_input_cols=self._drop_input_cols,
568
- expected_output_cols_list=(
569
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
570
- ),
563
+ expected_output_cols = (
564
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
571
565
  )
566
+ if isinstance(dataset, DataFrame):
567
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
568
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
569
+ )
570
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
571
+ drop_input_cols=self._drop_input_cols,
572
+ expected_output_cols_list=expected_output_cols,
573
+ example_output_pd_df=example_output_pd_df,
574
+ )
575
+ else:
576
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
577
+ drop_input_cols=self._drop_input_cols,
578
+ expected_output_cols_list=expected_output_cols,
579
+ )
572
580
  self._sklearn_object = fitted_estimator
573
581
  self._is_fitted = True
574
582
  return output_result
@@ -647,12 +655,41 @@ class SGDOneClassSVM(BaseTransformer):
647
655
 
648
656
  return rv
649
657
 
650
- def _align_expected_output_names(
651
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
652
- ) -> List[str]:
658
+ def _align_expected_output(
659
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
660
+ ) -> Tuple[List[str], pd.DataFrame]:
661
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
662
+ and output dataframe with 1 line.
663
+ If the method is fit_predict, run 2 lines of data.
664
+ """
653
665
  # in case the inferred output column names dimension is different
654
666
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
655
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
667
+
668
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
669
+ # so change the minimum of number of rows to 2
670
+ num_examples = 2
671
+ statement_params = telemetry.get_function_usage_statement_params(
672
+ project=_PROJECT,
673
+ subproject=_SUBPROJECT,
674
+ function_name=telemetry.get_statement_params_full_func_name(
675
+ inspect.currentframe(), SGDOneClassSVM.__class__.__name__
676
+ ),
677
+ api_calls=[Session.call],
678
+ custom_tags={"autogen": True} if self._autogenerated else None,
679
+ )
680
+ if output_cols_prefix == "fit_predict_":
681
+ if hasattr(self._sklearn_object, "n_clusters"):
682
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
683
+ num_examples = self._sklearn_object.n_clusters
684
+ elif hasattr(self._sklearn_object, "min_samples"):
685
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
686
+ num_examples = self._sklearn_object.min_samples
687
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
688
+ # LocalOutlierFactor expects n_neighbors <= n_samples
689
+ num_examples = self._sklearn_object.n_neighbors
690
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
691
+ else:
692
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
656
693
 
657
694
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
658
695
  # seen during the fit.
@@ -664,12 +701,14 @@ class SGDOneClassSVM(BaseTransformer):
664
701
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
665
702
  if self.sample_weight_col:
666
703
  output_df_columns_set -= set(self.sample_weight_col)
704
+
667
705
  # if the dimension of inferred output column names is correct; use it
668
706
  if len(expected_output_cols_list) == len(output_df_columns_set):
669
- return expected_output_cols_list
707
+ return expected_output_cols_list, output_df_pd
670
708
  # otherwise, use the sklearn estimator's output
671
709
  else:
672
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
710
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
711
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
673
712
 
674
713
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
675
714
  @telemetry.send_api_usage_telemetry(
@@ -715,7 +754,7 @@ class SGDOneClassSVM(BaseTransformer):
715
754
  drop_input_cols=self._drop_input_cols,
716
755
  expected_output_cols_type="float",
717
756
  )
718
- expected_output_cols = self._align_expected_output_names(
757
+ expected_output_cols, _ = self._align_expected_output(
719
758
  inference_method, dataset, expected_output_cols, output_cols_prefix
720
759
  )
721
760
 
@@ -781,7 +820,7 @@ class SGDOneClassSVM(BaseTransformer):
781
820
  drop_input_cols=self._drop_input_cols,
782
821
  expected_output_cols_type="float",
783
822
  )
784
- expected_output_cols = self._align_expected_output_names(
823
+ expected_output_cols, _ = self._align_expected_output(
785
824
  inference_method, dataset, expected_output_cols, output_cols_prefix
786
825
  )
787
826
  elif isinstance(dataset, pd.DataFrame):
@@ -846,7 +885,7 @@ class SGDOneClassSVM(BaseTransformer):
846
885
  drop_input_cols=self._drop_input_cols,
847
886
  expected_output_cols_type="float",
848
887
  )
849
- expected_output_cols = self._align_expected_output_names(
888
+ expected_output_cols, _ = self._align_expected_output(
850
889
  inference_method, dataset, expected_output_cols, output_cols_prefix
851
890
  )
852
891
 
@@ -913,7 +952,7 @@ class SGDOneClassSVM(BaseTransformer):
913
952
  drop_input_cols = self._drop_input_cols,
914
953
  expected_output_cols_type="float",
915
954
  )
916
- expected_output_cols = self._align_expected_output_names(
955
+ expected_output_cols, _ = self._align_expected_output(
917
956
  inference_method, dataset, expected_output_cols, output_cols_prefix
918
957
  )
919
958
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -627,12 +624,23 @@ class SGDRegressor(BaseTransformer):
627
624
  autogenerated=self._autogenerated,
628
625
  subproject=_SUBPROJECT,
629
626
  )
630
- output_result, fitted_estimator = model_trainer.train_fit_predict(
631
- drop_input_cols=self._drop_input_cols,
632
- expected_output_cols_list=(
633
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
634
- ),
627
+ expected_output_cols = (
628
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
635
629
  )
630
+ if isinstance(dataset, DataFrame):
631
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
632
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
633
+ )
634
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
635
+ drop_input_cols=self._drop_input_cols,
636
+ expected_output_cols_list=expected_output_cols,
637
+ example_output_pd_df=example_output_pd_df,
638
+ )
639
+ else:
640
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
641
+ drop_input_cols=self._drop_input_cols,
642
+ expected_output_cols_list=expected_output_cols,
643
+ )
636
644
  self._sklearn_object = fitted_estimator
637
645
  self._is_fitted = True
638
646
  return output_result
@@ -711,12 +719,41 @@ class SGDRegressor(BaseTransformer):
711
719
 
712
720
  return rv
713
721
 
714
- def _align_expected_output_names(
715
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
716
- ) -> List[str]:
722
+ def _align_expected_output(
723
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
724
+ ) -> Tuple[List[str], pd.DataFrame]:
725
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
726
+ and output dataframe with 1 line.
727
+ If the method is fit_predict, run 2 lines of data.
728
+ """
717
729
  # in case the inferred output column names dimension is different
718
730
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
719
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
731
+
732
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
733
+ # so change the minimum of number of rows to 2
734
+ num_examples = 2
735
+ statement_params = telemetry.get_function_usage_statement_params(
736
+ project=_PROJECT,
737
+ subproject=_SUBPROJECT,
738
+ function_name=telemetry.get_statement_params_full_func_name(
739
+ inspect.currentframe(), SGDRegressor.__class__.__name__
740
+ ),
741
+ api_calls=[Session.call],
742
+ custom_tags={"autogen": True} if self._autogenerated else None,
743
+ )
744
+ if output_cols_prefix == "fit_predict_":
745
+ if hasattr(self._sklearn_object, "n_clusters"):
746
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
747
+ num_examples = self._sklearn_object.n_clusters
748
+ elif hasattr(self._sklearn_object, "min_samples"):
749
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
750
+ num_examples = self._sklearn_object.min_samples
751
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
752
+ # LocalOutlierFactor expects n_neighbors <= n_samples
753
+ num_examples = self._sklearn_object.n_neighbors
754
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
755
+ else:
756
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
720
757
 
721
758
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
722
759
  # seen during the fit.
@@ -728,12 +765,14 @@ class SGDRegressor(BaseTransformer):
728
765
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
729
766
  if self.sample_weight_col:
730
767
  output_df_columns_set -= set(self.sample_weight_col)
768
+
731
769
  # if the dimension of inferred output column names is correct; use it
732
770
  if len(expected_output_cols_list) == len(output_df_columns_set):
733
- return expected_output_cols_list
771
+ return expected_output_cols_list, output_df_pd
734
772
  # otherwise, use the sklearn estimator's output
735
773
  else:
736
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
774
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
775
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
737
776
 
738
777
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
739
778
  @telemetry.send_api_usage_telemetry(
@@ -779,7 +818,7 @@ class SGDRegressor(BaseTransformer):
779
818
  drop_input_cols=self._drop_input_cols,
780
819
  expected_output_cols_type="float",
781
820
  )
782
- expected_output_cols = self._align_expected_output_names(
821
+ expected_output_cols, _ = self._align_expected_output(
783
822
  inference_method, dataset, expected_output_cols, output_cols_prefix
784
823
  )
785
824
 
@@ -845,7 +884,7 @@ class SGDRegressor(BaseTransformer):
845
884
  drop_input_cols=self._drop_input_cols,
846
885
  expected_output_cols_type="float",
847
886
  )
848
- expected_output_cols = self._align_expected_output_names(
887
+ expected_output_cols, _ = self._align_expected_output(
849
888
  inference_method, dataset, expected_output_cols, output_cols_prefix
850
889
  )
851
890
  elif isinstance(dataset, pd.DataFrame):
@@ -908,7 +947,7 @@ class SGDRegressor(BaseTransformer):
908
947
  drop_input_cols=self._drop_input_cols,
909
948
  expected_output_cols_type="float",
910
949
  )
911
- expected_output_cols = self._align_expected_output_names(
950
+ expected_output_cols, _ = self._align_expected_output(
912
951
  inference_method, dataset, expected_output_cols, output_cols_prefix
913
952
  )
914
953
 
@@ -973,7 +1012,7 @@ class SGDRegressor(BaseTransformer):
973
1012
  drop_input_cols = self._drop_input_cols,
974
1013
  expected_output_cols_type="float",
975
1014
  )
976
- expected_output_cols = self._align_expected_output_names(
1015
+ expected_output_cols, _ = self._align_expected_output(
977
1016
  inference_method, dataset, expected_output_cols, output_cols_prefix
978
1017
  )
979
1018
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -529,12 +526,23 @@ class TheilSenRegressor(BaseTransformer):
529
526
  autogenerated=self._autogenerated,
530
527
  subproject=_SUBPROJECT,
531
528
  )
532
- output_result, fitted_estimator = model_trainer.train_fit_predict(
533
- drop_input_cols=self._drop_input_cols,
534
- expected_output_cols_list=(
535
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
536
- ),
529
+ expected_output_cols = (
530
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
537
531
  )
532
+ if isinstance(dataset, DataFrame):
533
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
534
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
535
+ )
536
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
537
+ drop_input_cols=self._drop_input_cols,
538
+ expected_output_cols_list=expected_output_cols,
539
+ example_output_pd_df=example_output_pd_df,
540
+ )
541
+ else:
542
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
543
+ drop_input_cols=self._drop_input_cols,
544
+ expected_output_cols_list=expected_output_cols,
545
+ )
538
546
  self._sklearn_object = fitted_estimator
539
547
  self._is_fitted = True
540
548
  return output_result
@@ -613,12 +621,41 @@ class TheilSenRegressor(BaseTransformer):
613
621
 
614
622
  return rv
615
623
 
616
- def _align_expected_output_names(
617
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
618
- ) -> List[str]:
624
+ def _align_expected_output(
625
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
626
+ ) -> Tuple[List[str], pd.DataFrame]:
627
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
628
+ and output dataframe with 1 line.
629
+ If the method is fit_predict, run 2 lines of data.
630
+ """
619
631
  # in case the inferred output column names dimension is different
620
632
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
621
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
633
+
634
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
635
+ # so change the minimum of number of rows to 2
636
+ num_examples = 2
637
+ statement_params = telemetry.get_function_usage_statement_params(
638
+ project=_PROJECT,
639
+ subproject=_SUBPROJECT,
640
+ function_name=telemetry.get_statement_params_full_func_name(
641
+ inspect.currentframe(), TheilSenRegressor.__class__.__name__
642
+ ),
643
+ api_calls=[Session.call],
644
+ custom_tags={"autogen": True} if self._autogenerated else None,
645
+ )
646
+ if output_cols_prefix == "fit_predict_":
647
+ if hasattr(self._sklearn_object, "n_clusters"):
648
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
649
+ num_examples = self._sklearn_object.n_clusters
650
+ elif hasattr(self._sklearn_object, "min_samples"):
651
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
652
+ num_examples = self._sklearn_object.min_samples
653
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
654
+ # LocalOutlierFactor expects n_neighbors <= n_samples
655
+ num_examples = self._sklearn_object.n_neighbors
656
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
657
+ else:
658
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
622
659
 
623
660
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
624
661
  # seen during the fit.
@@ -630,12 +667,14 @@ class TheilSenRegressor(BaseTransformer):
630
667
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
631
668
  if self.sample_weight_col:
632
669
  output_df_columns_set -= set(self.sample_weight_col)
670
+
633
671
  # if the dimension of inferred output column names is correct; use it
634
672
  if len(expected_output_cols_list) == len(output_df_columns_set):
635
- return expected_output_cols_list
673
+ return expected_output_cols_list, output_df_pd
636
674
  # otherwise, use the sklearn estimator's output
637
675
  else:
638
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
676
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
677
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
639
678
 
640
679
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
641
680
  @telemetry.send_api_usage_telemetry(
@@ -681,7 +720,7 @@ class TheilSenRegressor(BaseTransformer):
681
720
  drop_input_cols=self._drop_input_cols,
682
721
  expected_output_cols_type="float",
683
722
  )
684
- expected_output_cols = self._align_expected_output_names(
723
+ expected_output_cols, _ = self._align_expected_output(
685
724
  inference_method, dataset, expected_output_cols, output_cols_prefix
686
725
  )
687
726
 
@@ -747,7 +786,7 @@ class TheilSenRegressor(BaseTransformer):
747
786
  drop_input_cols=self._drop_input_cols,
748
787
  expected_output_cols_type="float",
749
788
  )
750
- expected_output_cols = self._align_expected_output_names(
789
+ expected_output_cols, _ = self._align_expected_output(
751
790
  inference_method, dataset, expected_output_cols, output_cols_prefix
752
791
  )
753
792
  elif isinstance(dataset, pd.DataFrame):
@@ -810,7 +849,7 @@ class TheilSenRegressor(BaseTransformer):
810
849
  drop_input_cols=self._drop_input_cols,
811
850
  expected_output_cols_type="float",
812
851
  )
813
- expected_output_cols = self._align_expected_output_names(
852
+ expected_output_cols, _ = self._align_expected_output(
814
853
  inference_method, dataset, expected_output_cols, output_cols_prefix
815
854
  )
816
855
 
@@ -875,7 +914,7 @@ class TheilSenRegressor(BaseTransformer):
875
914
  drop_input_cols = self._drop_input_cols,
876
915
  expected_output_cols_type="float",
877
916
  )
878
- expected_output_cols = self._align_expected_output_names(
917
+ expected_output_cols, _ = self._align_expected_output(
879
918
  inference_method, dataset, expected_output_cols, output_cols_prefix
880
919
  )
881
920