snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -530,12 +527,23 @@ class MeanShift(BaseTransformer):
530
527
  autogenerated=self._autogenerated,
531
528
  subproject=_SUBPROJECT,
532
529
  )
533
- output_result, fitted_estimator = model_trainer.train_fit_predict(
534
- drop_input_cols=self._drop_input_cols,
535
- expected_output_cols_list=(
536
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
537
- ),
530
+ expected_output_cols = (
531
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
538
532
  )
533
+ if isinstance(dataset, DataFrame):
534
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
535
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
536
+ )
537
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
538
+ drop_input_cols=self._drop_input_cols,
539
+ expected_output_cols_list=expected_output_cols,
540
+ example_output_pd_df=example_output_pd_df,
541
+ )
542
+ else:
543
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
544
+ drop_input_cols=self._drop_input_cols,
545
+ expected_output_cols_list=expected_output_cols,
546
+ )
539
547
  self._sklearn_object = fitted_estimator
540
548
  self._is_fitted = True
541
549
  return output_result
@@ -614,12 +622,41 @@ class MeanShift(BaseTransformer):
614
622
 
615
623
  return rv
616
624
 
617
- def _align_expected_output_names(
618
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
619
- ) -> List[str]:
625
+ def _align_expected_output(
626
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
627
+ ) -> Tuple[List[str], pd.DataFrame]:
628
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
629
+ and output dataframe with 1 line.
630
+ If the method is fit_predict, run 2 lines of data.
631
+ """
620
632
  # in case the inferred output column names dimension is different
621
633
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
622
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
634
+
635
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
636
+ # so change the minimum of number of rows to 2
637
+ num_examples = 2
638
+ statement_params = telemetry.get_function_usage_statement_params(
639
+ project=_PROJECT,
640
+ subproject=_SUBPROJECT,
641
+ function_name=telemetry.get_statement_params_full_func_name(
642
+ inspect.currentframe(), MeanShift.__class__.__name__
643
+ ),
644
+ api_calls=[Session.call],
645
+ custom_tags={"autogen": True} if self._autogenerated else None,
646
+ )
647
+ if output_cols_prefix == "fit_predict_":
648
+ if hasattr(self._sklearn_object, "n_clusters"):
649
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
650
+ num_examples = self._sklearn_object.n_clusters
651
+ elif hasattr(self._sklearn_object, "min_samples"):
652
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
653
+ num_examples = self._sklearn_object.min_samples
654
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
655
+ # LocalOutlierFactor expects n_neighbors <= n_samples
656
+ num_examples = self._sklearn_object.n_neighbors
657
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
658
+ else:
659
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
623
660
 
624
661
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
625
662
  # seen during the fit.
@@ -631,12 +668,14 @@ class MeanShift(BaseTransformer):
631
668
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
632
669
  if self.sample_weight_col:
633
670
  output_df_columns_set -= set(self.sample_weight_col)
671
+
634
672
  # if the dimension of inferred output column names is correct; use it
635
673
  if len(expected_output_cols_list) == len(output_df_columns_set):
636
- return expected_output_cols_list
674
+ return expected_output_cols_list, output_df_pd
637
675
  # otherwise, use the sklearn estimator's output
638
676
  else:
639
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
677
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
678
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
640
679
 
641
680
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
642
681
  @telemetry.send_api_usage_telemetry(
@@ -682,7 +721,7 @@ class MeanShift(BaseTransformer):
682
721
  drop_input_cols=self._drop_input_cols,
683
722
  expected_output_cols_type="float",
684
723
  )
685
- expected_output_cols = self._align_expected_output_names(
724
+ expected_output_cols, _ = self._align_expected_output(
686
725
  inference_method, dataset, expected_output_cols, output_cols_prefix
687
726
  )
688
727
 
@@ -748,7 +787,7 @@ class MeanShift(BaseTransformer):
748
787
  drop_input_cols=self._drop_input_cols,
749
788
  expected_output_cols_type="float",
750
789
  )
751
- expected_output_cols = self._align_expected_output_names(
790
+ expected_output_cols, _ = self._align_expected_output(
752
791
  inference_method, dataset, expected_output_cols, output_cols_prefix
753
792
  )
754
793
  elif isinstance(dataset, pd.DataFrame):
@@ -811,7 +850,7 @@ class MeanShift(BaseTransformer):
811
850
  drop_input_cols=self._drop_input_cols,
812
851
  expected_output_cols_type="float",
813
852
  )
814
- expected_output_cols = self._align_expected_output_names(
853
+ expected_output_cols, _ = self._align_expected_output(
815
854
  inference_method, dataset, expected_output_cols, output_cols_prefix
816
855
  )
817
856
 
@@ -876,7 +915,7 @@ class MeanShift(BaseTransformer):
876
915
  drop_input_cols = self._drop_input_cols,
877
916
  expected_output_cols_type="float",
878
917
  )
879
- expected_output_cols = self._align_expected_output_names(
918
+ expected_output_cols, _ = self._align_expected_output(
880
919
  inference_method, dataset, expected_output_cols, output_cols_prefix
881
920
  )
882
921
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -582,12 +579,23 @@ class MiniBatchKMeans(BaseTransformer):
582
579
  autogenerated=self._autogenerated,
583
580
  subproject=_SUBPROJECT,
584
581
  )
585
- output_result, fitted_estimator = model_trainer.train_fit_predict(
586
- drop_input_cols=self._drop_input_cols,
587
- expected_output_cols_list=(
588
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
589
- ),
582
+ expected_output_cols = (
583
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
590
584
  )
585
+ if isinstance(dataset, DataFrame):
586
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
587
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
588
+ )
589
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
590
+ drop_input_cols=self._drop_input_cols,
591
+ expected_output_cols_list=expected_output_cols,
592
+ example_output_pd_df=example_output_pd_df,
593
+ )
594
+ else:
595
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
596
+ drop_input_cols=self._drop_input_cols,
597
+ expected_output_cols_list=expected_output_cols,
598
+ )
591
599
  self._sklearn_object = fitted_estimator
592
600
  self._is_fitted = True
593
601
  return output_result
@@ -668,12 +676,41 @@ class MiniBatchKMeans(BaseTransformer):
668
676
 
669
677
  return rv
670
678
 
671
- def _align_expected_output_names(
672
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
673
- ) -> List[str]:
679
+ def _align_expected_output(
680
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
681
+ ) -> Tuple[List[str], pd.DataFrame]:
682
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
683
+ and output dataframe with 1 line.
684
+ If the method is fit_predict, run 2 lines of data.
685
+ """
674
686
  # in case the inferred output column names dimension is different
675
687
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
676
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
688
+
689
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
690
+ # so change the minimum of number of rows to 2
691
+ num_examples = 2
692
+ statement_params = telemetry.get_function_usage_statement_params(
693
+ project=_PROJECT,
694
+ subproject=_SUBPROJECT,
695
+ function_name=telemetry.get_statement_params_full_func_name(
696
+ inspect.currentframe(), MiniBatchKMeans.__class__.__name__
697
+ ),
698
+ api_calls=[Session.call],
699
+ custom_tags={"autogen": True} if self._autogenerated else None,
700
+ )
701
+ if output_cols_prefix == "fit_predict_":
702
+ if hasattr(self._sklearn_object, "n_clusters"):
703
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
704
+ num_examples = self._sklearn_object.n_clusters
705
+ elif hasattr(self._sklearn_object, "min_samples"):
706
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
707
+ num_examples = self._sklearn_object.min_samples
708
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
709
+ # LocalOutlierFactor expects n_neighbors <= n_samples
710
+ num_examples = self._sklearn_object.n_neighbors
711
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
712
+ else:
713
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
677
714
 
678
715
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
679
716
  # seen during the fit.
@@ -685,12 +722,14 @@ class MiniBatchKMeans(BaseTransformer):
685
722
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
686
723
  if self.sample_weight_col:
687
724
  output_df_columns_set -= set(self.sample_weight_col)
725
+
688
726
  # if the dimension of inferred output column names is correct; use it
689
727
  if len(expected_output_cols_list) == len(output_df_columns_set):
690
- return expected_output_cols_list
728
+ return expected_output_cols_list, output_df_pd
691
729
  # otherwise, use the sklearn estimator's output
692
730
  else:
693
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
731
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
732
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
694
733
 
695
734
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
696
735
  @telemetry.send_api_usage_telemetry(
@@ -736,7 +775,7 @@ class MiniBatchKMeans(BaseTransformer):
736
775
  drop_input_cols=self._drop_input_cols,
737
776
  expected_output_cols_type="float",
738
777
  )
739
- expected_output_cols = self._align_expected_output_names(
778
+ expected_output_cols, _ = self._align_expected_output(
740
779
  inference_method, dataset, expected_output_cols, output_cols_prefix
741
780
  )
742
781
 
@@ -802,7 +841,7 @@ class MiniBatchKMeans(BaseTransformer):
802
841
  drop_input_cols=self._drop_input_cols,
803
842
  expected_output_cols_type="float",
804
843
  )
805
- expected_output_cols = self._align_expected_output_names(
844
+ expected_output_cols, _ = self._align_expected_output(
806
845
  inference_method, dataset, expected_output_cols, output_cols_prefix
807
846
  )
808
847
  elif isinstance(dataset, pd.DataFrame):
@@ -865,7 +904,7 @@ class MiniBatchKMeans(BaseTransformer):
865
904
  drop_input_cols=self._drop_input_cols,
866
905
  expected_output_cols_type="float",
867
906
  )
868
- expected_output_cols = self._align_expected_output_names(
907
+ expected_output_cols, _ = self._align_expected_output(
869
908
  inference_method, dataset, expected_output_cols, output_cols_prefix
870
909
  )
871
910
 
@@ -930,7 +969,7 @@ class MiniBatchKMeans(BaseTransformer):
930
969
  drop_input_cols = self._drop_input_cols,
931
970
  expected_output_cols_type="float",
932
971
  )
933
- expected_output_cols = self._align_expected_output_names(
972
+ expected_output_cols, _ = self._align_expected_output(
934
973
  inference_method, dataset, expected_output_cols, output_cols_prefix
935
974
  )
936
975
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -598,12 +595,23 @@ class OPTICS(BaseTransformer):
598
595
  autogenerated=self._autogenerated,
599
596
  subproject=_SUBPROJECT,
600
597
  )
601
- output_result, fitted_estimator = model_trainer.train_fit_predict(
602
- drop_input_cols=self._drop_input_cols,
603
- expected_output_cols_list=(
604
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
605
- ),
598
+ expected_output_cols = (
599
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
606
600
  )
601
+ if isinstance(dataset, DataFrame):
602
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
603
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
604
+ )
605
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
606
+ drop_input_cols=self._drop_input_cols,
607
+ expected_output_cols_list=expected_output_cols,
608
+ example_output_pd_df=example_output_pd_df,
609
+ )
610
+ else:
611
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
612
+ drop_input_cols=self._drop_input_cols,
613
+ expected_output_cols_list=expected_output_cols,
614
+ )
607
615
  self._sklearn_object = fitted_estimator
608
616
  self._is_fitted = True
609
617
  return output_result
@@ -682,12 +690,41 @@ class OPTICS(BaseTransformer):
682
690
 
683
691
  return rv
684
692
 
685
- def _align_expected_output_names(
686
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
687
- ) -> List[str]:
693
+ def _align_expected_output(
694
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
695
+ ) -> Tuple[List[str], pd.DataFrame]:
696
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
697
+ and output dataframe with 1 line.
698
+ If the method is fit_predict, run 2 lines of data.
699
+ """
688
700
  # in case the inferred output column names dimension is different
689
701
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
690
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
702
+
703
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
704
+ # so change the minimum of number of rows to 2
705
+ num_examples = 2
706
+ statement_params = telemetry.get_function_usage_statement_params(
707
+ project=_PROJECT,
708
+ subproject=_SUBPROJECT,
709
+ function_name=telemetry.get_statement_params_full_func_name(
710
+ inspect.currentframe(), OPTICS.__class__.__name__
711
+ ),
712
+ api_calls=[Session.call],
713
+ custom_tags={"autogen": True} if self._autogenerated else None,
714
+ )
715
+ if output_cols_prefix == "fit_predict_":
716
+ if hasattr(self._sklearn_object, "n_clusters"):
717
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
718
+ num_examples = self._sklearn_object.n_clusters
719
+ elif hasattr(self._sklearn_object, "min_samples"):
720
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
721
+ num_examples = self._sklearn_object.min_samples
722
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
723
+ # LocalOutlierFactor expects n_neighbors <= n_samples
724
+ num_examples = self._sklearn_object.n_neighbors
725
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
726
+ else:
727
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
691
728
 
692
729
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
693
730
  # seen during the fit.
@@ -699,12 +736,14 @@ class OPTICS(BaseTransformer):
699
736
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
700
737
  if self.sample_weight_col:
701
738
  output_df_columns_set -= set(self.sample_weight_col)
739
+
702
740
  # if the dimension of inferred output column names is correct; use it
703
741
  if len(expected_output_cols_list) == len(output_df_columns_set):
704
- return expected_output_cols_list
742
+ return expected_output_cols_list, output_df_pd
705
743
  # otherwise, use the sklearn estimator's output
706
744
  else:
707
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
745
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
746
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
708
747
 
709
748
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
710
749
  @telemetry.send_api_usage_telemetry(
@@ -750,7 +789,7 @@ class OPTICS(BaseTransformer):
750
789
  drop_input_cols=self._drop_input_cols,
751
790
  expected_output_cols_type="float",
752
791
  )
753
- expected_output_cols = self._align_expected_output_names(
792
+ expected_output_cols, _ = self._align_expected_output(
754
793
  inference_method, dataset, expected_output_cols, output_cols_prefix
755
794
  )
756
795
 
@@ -816,7 +855,7 @@ class OPTICS(BaseTransformer):
816
855
  drop_input_cols=self._drop_input_cols,
817
856
  expected_output_cols_type="float",
818
857
  )
819
- expected_output_cols = self._align_expected_output_names(
858
+ expected_output_cols, _ = self._align_expected_output(
820
859
  inference_method, dataset, expected_output_cols, output_cols_prefix
821
860
  )
822
861
  elif isinstance(dataset, pd.DataFrame):
@@ -879,7 +918,7 @@ class OPTICS(BaseTransformer):
879
918
  drop_input_cols=self._drop_input_cols,
880
919
  expected_output_cols_type="float",
881
920
  )
882
- expected_output_cols = self._align_expected_output_names(
921
+ expected_output_cols, _ = self._align_expected_output(
883
922
  inference_method, dataset, expected_output_cols, output_cols_prefix
884
923
  )
885
924
 
@@ -944,7 +983,7 @@ class OPTICS(BaseTransformer):
944
983
  drop_input_cols = self._drop_input_cols,
945
984
  expected_output_cols_type="float",
946
985
  )
947
- expected_output_cols = self._align_expected_output_names(
986
+ expected_output_cols, _ = self._align_expected_output(
948
987
  inference_method, dataset, expected_output_cols, output_cols_prefix
949
988
  )
950
989
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -534,12 +531,23 @@ class SpectralBiclustering(BaseTransformer):
534
531
  autogenerated=self._autogenerated,
535
532
  subproject=_SUBPROJECT,
536
533
  )
537
- output_result, fitted_estimator = model_trainer.train_fit_predict(
538
- drop_input_cols=self._drop_input_cols,
539
- expected_output_cols_list=(
540
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
541
- ),
534
+ expected_output_cols = (
535
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
542
536
  )
537
+ if isinstance(dataset, DataFrame):
538
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
539
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
540
+ )
541
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
542
+ drop_input_cols=self._drop_input_cols,
543
+ expected_output_cols_list=expected_output_cols,
544
+ example_output_pd_df=example_output_pd_df,
545
+ )
546
+ else:
547
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
548
+ drop_input_cols=self._drop_input_cols,
549
+ expected_output_cols_list=expected_output_cols,
550
+ )
543
551
  self._sklearn_object = fitted_estimator
544
552
  self._is_fitted = True
545
553
  return output_result
@@ -618,12 +626,41 @@ class SpectralBiclustering(BaseTransformer):
618
626
 
619
627
  return rv
620
628
 
621
- def _align_expected_output_names(
622
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
623
- ) -> List[str]:
629
+ def _align_expected_output(
630
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
631
+ ) -> Tuple[List[str], pd.DataFrame]:
632
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
633
+ and output dataframe with 1 line.
634
+ If the method is fit_predict, run 2 lines of data.
635
+ """
624
636
  # in case the inferred output column names dimension is different
625
637
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
626
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
638
+
639
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
640
+ # so change the minimum of number of rows to 2
641
+ num_examples = 2
642
+ statement_params = telemetry.get_function_usage_statement_params(
643
+ project=_PROJECT,
644
+ subproject=_SUBPROJECT,
645
+ function_name=telemetry.get_statement_params_full_func_name(
646
+ inspect.currentframe(), SpectralBiclustering.__class__.__name__
647
+ ),
648
+ api_calls=[Session.call],
649
+ custom_tags={"autogen": True} if self._autogenerated else None,
650
+ )
651
+ if output_cols_prefix == "fit_predict_":
652
+ if hasattr(self._sklearn_object, "n_clusters"):
653
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
654
+ num_examples = self._sklearn_object.n_clusters
655
+ elif hasattr(self._sklearn_object, "min_samples"):
656
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
657
+ num_examples = self._sklearn_object.min_samples
658
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
659
+ # LocalOutlierFactor expects n_neighbors <= n_samples
660
+ num_examples = self._sklearn_object.n_neighbors
661
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
662
+ else:
663
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
627
664
 
628
665
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
629
666
  # seen during the fit.
@@ -635,12 +672,14 @@ class SpectralBiclustering(BaseTransformer):
635
672
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
636
673
  if self.sample_weight_col:
637
674
  output_df_columns_set -= set(self.sample_weight_col)
675
+
638
676
  # if the dimension of inferred output column names is correct; use it
639
677
  if len(expected_output_cols_list) == len(output_df_columns_set):
640
- return expected_output_cols_list
678
+ return expected_output_cols_list, output_df_pd
641
679
  # otherwise, use the sklearn estimator's output
642
680
  else:
643
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
681
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
682
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
644
683
 
645
684
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
646
685
  @telemetry.send_api_usage_telemetry(
@@ -686,7 +725,7 @@ class SpectralBiclustering(BaseTransformer):
686
725
  drop_input_cols=self._drop_input_cols,
687
726
  expected_output_cols_type="float",
688
727
  )
689
- expected_output_cols = self._align_expected_output_names(
728
+ expected_output_cols, _ = self._align_expected_output(
690
729
  inference_method, dataset, expected_output_cols, output_cols_prefix
691
730
  )
692
731
 
@@ -752,7 +791,7 @@ class SpectralBiclustering(BaseTransformer):
752
791
  drop_input_cols=self._drop_input_cols,
753
792
  expected_output_cols_type="float",
754
793
  )
755
- expected_output_cols = self._align_expected_output_names(
794
+ expected_output_cols, _ = self._align_expected_output(
756
795
  inference_method, dataset, expected_output_cols, output_cols_prefix
757
796
  )
758
797
  elif isinstance(dataset, pd.DataFrame):
@@ -815,7 +854,7 @@ class SpectralBiclustering(BaseTransformer):
815
854
  drop_input_cols=self._drop_input_cols,
816
855
  expected_output_cols_type="float",
817
856
  )
818
- expected_output_cols = self._align_expected_output_names(
857
+ expected_output_cols, _ = self._align_expected_output(
819
858
  inference_method, dataset, expected_output_cols, output_cols_prefix
820
859
  )
821
860
 
@@ -880,7 +919,7 @@ class SpectralBiclustering(BaseTransformer):
880
919
  drop_input_cols = self._drop_input_cols,
881
920
  expected_output_cols_type="float",
882
921
  )
883
- expected_output_cols = self._align_expected_output_names(
922
+ expected_output_cols, _ = self._align_expected_output(
884
923
  inference_method, dataset, expected_output_cols, output_cols_prefix
885
924
  )
886
925