snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252) hide show
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -525,12 +522,23 @@ class KNNImputer(BaseTransformer):
525
522
  autogenerated=self._autogenerated,
526
523
  subproject=_SUBPROJECT,
527
524
  )
528
- output_result, fitted_estimator = model_trainer.train_fit_predict(
529
- drop_input_cols=self._drop_input_cols,
530
- expected_output_cols_list=(
531
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
532
- ),
525
+ expected_output_cols = (
526
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
533
527
  )
528
+ if isinstance(dataset, DataFrame):
529
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
530
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
531
+ )
532
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
533
+ drop_input_cols=self._drop_input_cols,
534
+ expected_output_cols_list=expected_output_cols,
535
+ example_output_pd_df=example_output_pd_df,
536
+ )
537
+ else:
538
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
539
+ drop_input_cols=self._drop_input_cols,
540
+ expected_output_cols_list=expected_output_cols,
541
+ )
534
542
  self._sklearn_object = fitted_estimator
535
543
  self._is_fitted = True
536
544
  return output_result
@@ -611,12 +619,41 @@ class KNNImputer(BaseTransformer):
611
619
 
612
620
  return rv
613
621
 
614
- def _align_expected_output_names(
615
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
616
- ) -> List[str]:
622
+ def _align_expected_output(
623
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
624
+ ) -> Tuple[List[str], pd.DataFrame]:
625
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
626
+ and output dataframe with 1 line.
627
+ If the method is fit_predict, run 2 lines of data.
628
+ """
617
629
  # in case the inferred output column names dimension is different
618
630
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
619
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
631
+
632
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
633
+ # so change the minimum of number of rows to 2
634
+ num_examples = 2
635
+ statement_params = telemetry.get_function_usage_statement_params(
636
+ project=_PROJECT,
637
+ subproject=_SUBPROJECT,
638
+ function_name=telemetry.get_statement_params_full_func_name(
639
+ inspect.currentframe(), KNNImputer.__class__.__name__
640
+ ),
641
+ api_calls=[Session.call],
642
+ custom_tags={"autogen": True} if self._autogenerated else None,
643
+ )
644
+ if output_cols_prefix == "fit_predict_":
645
+ if hasattr(self._sklearn_object, "n_clusters"):
646
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
647
+ num_examples = self._sklearn_object.n_clusters
648
+ elif hasattr(self._sklearn_object, "min_samples"):
649
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
650
+ num_examples = self._sklearn_object.min_samples
651
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
652
+ # LocalOutlierFactor expects n_neighbors <= n_samples
653
+ num_examples = self._sklearn_object.n_neighbors
654
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
655
+ else:
656
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
620
657
 
621
658
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
622
659
  # seen during the fit.
@@ -628,12 +665,14 @@ class KNNImputer(BaseTransformer):
628
665
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
629
666
  if self.sample_weight_col:
630
667
  output_df_columns_set -= set(self.sample_weight_col)
668
+
631
669
  # if the dimension of inferred output column names is correct; use it
632
670
  if len(expected_output_cols_list) == len(output_df_columns_set):
633
- return expected_output_cols_list
671
+ return expected_output_cols_list, output_df_pd
634
672
  # otherwise, use the sklearn estimator's output
635
673
  else:
636
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
674
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
675
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
637
676
 
638
677
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
639
678
  @telemetry.send_api_usage_telemetry(
@@ -679,7 +718,7 @@ class KNNImputer(BaseTransformer):
679
718
  drop_input_cols=self._drop_input_cols,
680
719
  expected_output_cols_type="float",
681
720
  )
682
- expected_output_cols = self._align_expected_output_names(
721
+ expected_output_cols, _ = self._align_expected_output(
683
722
  inference_method, dataset, expected_output_cols, output_cols_prefix
684
723
  )
685
724
 
@@ -745,7 +784,7 @@ class KNNImputer(BaseTransformer):
745
784
  drop_input_cols=self._drop_input_cols,
746
785
  expected_output_cols_type="float",
747
786
  )
748
- expected_output_cols = self._align_expected_output_names(
787
+ expected_output_cols, _ = self._align_expected_output(
749
788
  inference_method, dataset, expected_output_cols, output_cols_prefix
750
789
  )
751
790
  elif isinstance(dataset, pd.DataFrame):
@@ -808,7 +847,7 @@ class KNNImputer(BaseTransformer):
808
847
  drop_input_cols=self._drop_input_cols,
809
848
  expected_output_cols_type="float",
810
849
  )
811
- expected_output_cols = self._align_expected_output_names(
850
+ expected_output_cols, _ = self._align_expected_output(
812
851
  inference_method, dataset, expected_output_cols, output_cols_prefix
813
852
  )
814
853
 
@@ -873,7 +912,7 @@ class KNNImputer(BaseTransformer):
873
912
  drop_input_cols = self._drop_input_cols,
874
913
  expected_output_cols_type="float",
875
914
  )
876
- expected_output_cols = self._align_expected_output_names(
915
+ expected_output_cols, _ = self._align_expected_output(
877
916
  inference_method, dataset, expected_output_cols, output_cols_prefix
878
917
  )
879
918
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -499,12 +496,23 @@ class MissingIndicator(BaseTransformer):
499
496
  autogenerated=self._autogenerated,
500
497
  subproject=_SUBPROJECT,
501
498
  )
502
- output_result, fitted_estimator = model_trainer.train_fit_predict(
503
- drop_input_cols=self._drop_input_cols,
504
- expected_output_cols_list=(
505
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
506
- ),
499
+ expected_output_cols = (
500
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
507
501
  )
502
+ if isinstance(dataset, DataFrame):
503
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
504
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
505
+ )
506
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
507
+ drop_input_cols=self._drop_input_cols,
508
+ expected_output_cols_list=expected_output_cols,
509
+ example_output_pd_df=example_output_pd_df,
510
+ )
511
+ else:
512
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
513
+ drop_input_cols=self._drop_input_cols,
514
+ expected_output_cols_list=expected_output_cols,
515
+ )
508
516
  self._sklearn_object = fitted_estimator
509
517
  self._is_fitted = True
510
518
  return output_result
@@ -585,12 +593,41 @@ class MissingIndicator(BaseTransformer):
585
593
 
586
594
  return rv
587
595
 
588
- def _align_expected_output_names(
589
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
590
- ) -> List[str]:
596
+ def _align_expected_output(
597
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
598
+ ) -> Tuple[List[str], pd.DataFrame]:
599
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
600
+ and output dataframe with 1 line.
601
+ If the method is fit_predict, run 2 lines of data.
602
+ """
591
603
  # in case the inferred output column names dimension is different
592
604
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
593
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
605
+
606
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
607
+ # so change the minimum of number of rows to 2
608
+ num_examples = 2
609
+ statement_params = telemetry.get_function_usage_statement_params(
610
+ project=_PROJECT,
611
+ subproject=_SUBPROJECT,
612
+ function_name=telemetry.get_statement_params_full_func_name(
613
+ inspect.currentframe(), MissingIndicator.__class__.__name__
614
+ ),
615
+ api_calls=[Session.call],
616
+ custom_tags={"autogen": True} if self._autogenerated else None,
617
+ )
618
+ if output_cols_prefix == "fit_predict_":
619
+ if hasattr(self._sklearn_object, "n_clusters"):
620
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
621
+ num_examples = self._sklearn_object.n_clusters
622
+ elif hasattr(self._sklearn_object, "min_samples"):
623
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
624
+ num_examples = self._sklearn_object.min_samples
625
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
626
+ # LocalOutlierFactor expects n_neighbors <= n_samples
627
+ num_examples = self._sklearn_object.n_neighbors
628
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
629
+ else:
630
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
594
631
 
595
632
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
596
633
  # seen during the fit.
@@ -602,12 +639,14 @@ class MissingIndicator(BaseTransformer):
602
639
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
603
640
  if self.sample_weight_col:
604
641
  output_df_columns_set -= set(self.sample_weight_col)
642
+
605
643
  # if the dimension of inferred output column names is correct; use it
606
644
  if len(expected_output_cols_list) == len(output_df_columns_set):
607
- return expected_output_cols_list
645
+ return expected_output_cols_list, output_df_pd
608
646
  # otherwise, use the sklearn estimator's output
609
647
  else:
610
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
648
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
649
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
611
650
 
612
651
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
613
652
  @telemetry.send_api_usage_telemetry(
@@ -653,7 +692,7 @@ class MissingIndicator(BaseTransformer):
653
692
  drop_input_cols=self._drop_input_cols,
654
693
  expected_output_cols_type="float",
655
694
  )
656
- expected_output_cols = self._align_expected_output_names(
695
+ expected_output_cols, _ = self._align_expected_output(
657
696
  inference_method, dataset, expected_output_cols, output_cols_prefix
658
697
  )
659
698
 
@@ -719,7 +758,7 @@ class MissingIndicator(BaseTransformer):
719
758
  drop_input_cols=self._drop_input_cols,
720
759
  expected_output_cols_type="float",
721
760
  )
722
- expected_output_cols = self._align_expected_output_names(
761
+ expected_output_cols, _ = self._align_expected_output(
723
762
  inference_method, dataset, expected_output_cols, output_cols_prefix
724
763
  )
725
764
  elif isinstance(dataset, pd.DataFrame):
@@ -782,7 +821,7 @@ class MissingIndicator(BaseTransformer):
782
821
  drop_input_cols=self._drop_input_cols,
783
822
  expected_output_cols_type="float",
784
823
  )
785
- expected_output_cols = self._align_expected_output_names(
824
+ expected_output_cols, _ = self._align_expected_output(
786
825
  inference_method, dataset, expected_output_cols, output_cols_prefix
787
826
  )
788
827
 
@@ -847,7 +886,7 @@ class MissingIndicator(BaseTransformer):
847
886
  drop_input_cols = self._drop_input_cols,
848
887
  expected_output_cols_type="float",
849
888
  )
850
- expected_output_cols = self._align_expected_output_names(
889
+ expected_output_cols, _ = self._align_expected_output(
851
890
  inference_method, dataset, expected_output_cols, output_cols_prefix
852
891
  )
853
892
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -474,12 +471,23 @@ class AdditiveChi2Sampler(BaseTransformer):
474
471
  autogenerated=self._autogenerated,
475
472
  subproject=_SUBPROJECT,
476
473
  )
477
- output_result, fitted_estimator = model_trainer.train_fit_predict(
478
- drop_input_cols=self._drop_input_cols,
479
- expected_output_cols_list=(
480
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
481
- ),
474
+ expected_output_cols = (
475
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
482
476
  )
477
+ if isinstance(dataset, DataFrame):
478
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
479
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
480
+ )
481
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
482
+ drop_input_cols=self._drop_input_cols,
483
+ expected_output_cols_list=expected_output_cols,
484
+ example_output_pd_df=example_output_pd_df,
485
+ )
486
+ else:
487
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
488
+ drop_input_cols=self._drop_input_cols,
489
+ expected_output_cols_list=expected_output_cols,
490
+ )
483
491
  self._sklearn_object = fitted_estimator
484
492
  self._is_fitted = True
485
493
  return output_result
@@ -560,12 +568,41 @@ class AdditiveChi2Sampler(BaseTransformer):
560
568
 
561
569
  return rv
562
570
 
563
- def _align_expected_output_names(
564
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
565
- ) -> List[str]:
571
+ def _align_expected_output(
572
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
573
+ ) -> Tuple[List[str], pd.DataFrame]:
574
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
575
+ and output dataframe with 1 line.
576
+ If the method is fit_predict, run 2 lines of data.
577
+ """
566
578
  # in case the inferred output column names dimension is different
567
579
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
568
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
580
+
581
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
582
+ # so change the minimum of number of rows to 2
583
+ num_examples = 2
584
+ statement_params = telemetry.get_function_usage_statement_params(
585
+ project=_PROJECT,
586
+ subproject=_SUBPROJECT,
587
+ function_name=telemetry.get_statement_params_full_func_name(
588
+ inspect.currentframe(), AdditiveChi2Sampler.__class__.__name__
589
+ ),
590
+ api_calls=[Session.call],
591
+ custom_tags={"autogen": True} if self._autogenerated else None,
592
+ )
593
+ if output_cols_prefix == "fit_predict_":
594
+ if hasattr(self._sklearn_object, "n_clusters"):
595
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
596
+ num_examples = self._sklearn_object.n_clusters
597
+ elif hasattr(self._sklearn_object, "min_samples"):
598
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
599
+ num_examples = self._sklearn_object.min_samples
600
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
601
+ # LocalOutlierFactor expects n_neighbors <= n_samples
602
+ num_examples = self._sklearn_object.n_neighbors
603
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
604
+ else:
605
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
569
606
 
570
607
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
571
608
  # seen during the fit.
@@ -577,12 +614,14 @@ class AdditiveChi2Sampler(BaseTransformer):
577
614
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
578
615
  if self.sample_weight_col:
579
616
  output_df_columns_set -= set(self.sample_weight_col)
617
+
580
618
  # if the dimension of inferred output column names is correct; use it
581
619
  if len(expected_output_cols_list) == len(output_df_columns_set):
582
- return expected_output_cols_list
620
+ return expected_output_cols_list, output_df_pd
583
621
  # otherwise, use the sklearn estimator's output
584
622
  else:
585
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
623
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
624
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
586
625
 
587
626
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
588
627
  @telemetry.send_api_usage_telemetry(
@@ -628,7 +667,7 @@ class AdditiveChi2Sampler(BaseTransformer):
628
667
  drop_input_cols=self._drop_input_cols,
629
668
  expected_output_cols_type="float",
630
669
  )
631
- expected_output_cols = self._align_expected_output_names(
670
+ expected_output_cols, _ = self._align_expected_output(
632
671
  inference_method, dataset, expected_output_cols, output_cols_prefix
633
672
  )
634
673
 
@@ -694,7 +733,7 @@ class AdditiveChi2Sampler(BaseTransformer):
694
733
  drop_input_cols=self._drop_input_cols,
695
734
  expected_output_cols_type="float",
696
735
  )
697
- expected_output_cols = self._align_expected_output_names(
736
+ expected_output_cols, _ = self._align_expected_output(
698
737
  inference_method, dataset, expected_output_cols, output_cols_prefix
699
738
  )
700
739
  elif isinstance(dataset, pd.DataFrame):
@@ -757,7 +796,7 @@ class AdditiveChi2Sampler(BaseTransformer):
757
796
  drop_input_cols=self._drop_input_cols,
758
797
  expected_output_cols_type="float",
759
798
  )
760
- expected_output_cols = self._align_expected_output_names(
799
+ expected_output_cols, _ = self._align_expected_output(
761
800
  inference_method, dataset, expected_output_cols, output_cols_prefix
762
801
  )
763
802
 
@@ -822,7 +861,7 @@ class AdditiveChi2Sampler(BaseTransformer):
822
861
  drop_input_cols = self._drop_input_cols,
823
862
  expected_output_cols_type="float",
824
863
  )
825
- expected_output_cols = self._align_expected_output_names(
864
+ expected_output_cols, _ = self._align_expected_output(
826
865
  inference_method, dataset, expected_output_cols, output_cols_prefix
827
866
  )
828
867
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -522,12 +519,23 @@ class Nystroem(BaseTransformer):
522
519
  autogenerated=self._autogenerated,
523
520
  subproject=_SUBPROJECT,
524
521
  )
525
- output_result, fitted_estimator = model_trainer.train_fit_predict(
526
- drop_input_cols=self._drop_input_cols,
527
- expected_output_cols_list=(
528
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
529
- ),
522
+ expected_output_cols = (
523
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
530
524
  )
525
+ if isinstance(dataset, DataFrame):
526
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
527
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
528
+ )
529
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
530
+ drop_input_cols=self._drop_input_cols,
531
+ expected_output_cols_list=expected_output_cols,
532
+ example_output_pd_df=example_output_pd_df,
533
+ )
534
+ else:
535
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
536
+ drop_input_cols=self._drop_input_cols,
537
+ expected_output_cols_list=expected_output_cols,
538
+ )
531
539
  self._sklearn_object = fitted_estimator
532
540
  self._is_fitted = True
533
541
  return output_result
@@ -608,12 +616,41 @@ class Nystroem(BaseTransformer):
608
616
 
609
617
  return rv
610
618
 
611
- def _align_expected_output_names(
612
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
613
- ) -> List[str]:
619
+ def _align_expected_output(
620
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
621
+ ) -> Tuple[List[str], pd.DataFrame]:
622
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
623
+ and output dataframe with 1 line.
624
+ If the method is fit_predict, run 2 lines of data.
625
+ """
614
626
  # in case the inferred output column names dimension is different
615
627
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
616
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
628
+
629
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
630
+ # so change the minimum of number of rows to 2
631
+ num_examples = 2
632
+ statement_params = telemetry.get_function_usage_statement_params(
633
+ project=_PROJECT,
634
+ subproject=_SUBPROJECT,
635
+ function_name=telemetry.get_statement_params_full_func_name(
636
+ inspect.currentframe(), Nystroem.__class__.__name__
637
+ ),
638
+ api_calls=[Session.call],
639
+ custom_tags={"autogen": True} if self._autogenerated else None,
640
+ )
641
+ if output_cols_prefix == "fit_predict_":
642
+ if hasattr(self._sklearn_object, "n_clusters"):
643
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
644
+ num_examples = self._sklearn_object.n_clusters
645
+ elif hasattr(self._sklearn_object, "min_samples"):
646
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
647
+ num_examples = self._sklearn_object.min_samples
648
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
649
+ # LocalOutlierFactor expects n_neighbors <= n_samples
650
+ num_examples = self._sklearn_object.n_neighbors
651
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
652
+ else:
653
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
617
654
 
618
655
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
619
656
  # seen during the fit.
@@ -625,12 +662,14 @@ class Nystroem(BaseTransformer):
625
662
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
626
663
  if self.sample_weight_col:
627
664
  output_df_columns_set -= set(self.sample_weight_col)
665
+
628
666
  # if the dimension of inferred output column names is correct; use it
629
667
  if len(expected_output_cols_list) == len(output_df_columns_set):
630
- return expected_output_cols_list
668
+ return expected_output_cols_list, output_df_pd
631
669
  # otherwise, use the sklearn estimator's output
632
670
  else:
633
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
671
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
672
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
634
673
 
635
674
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
636
675
  @telemetry.send_api_usage_telemetry(
@@ -676,7 +715,7 @@ class Nystroem(BaseTransformer):
676
715
  drop_input_cols=self._drop_input_cols,
677
716
  expected_output_cols_type="float",
678
717
  )
679
- expected_output_cols = self._align_expected_output_names(
718
+ expected_output_cols, _ = self._align_expected_output(
680
719
  inference_method, dataset, expected_output_cols, output_cols_prefix
681
720
  )
682
721
 
@@ -742,7 +781,7 @@ class Nystroem(BaseTransformer):
742
781
  drop_input_cols=self._drop_input_cols,
743
782
  expected_output_cols_type="float",
744
783
  )
745
- expected_output_cols = self._align_expected_output_names(
784
+ expected_output_cols, _ = self._align_expected_output(
746
785
  inference_method, dataset, expected_output_cols, output_cols_prefix
747
786
  )
748
787
  elif isinstance(dataset, pd.DataFrame):
@@ -805,7 +844,7 @@ class Nystroem(BaseTransformer):
805
844
  drop_input_cols=self._drop_input_cols,
806
845
  expected_output_cols_type="float",
807
846
  )
808
- expected_output_cols = self._align_expected_output_names(
847
+ expected_output_cols, _ = self._align_expected_output(
809
848
  inference_method, dataset, expected_output_cols, output_cols_prefix
810
849
  )
811
850
 
@@ -870,7 +909,7 @@ class Nystroem(BaseTransformer):
870
909
  drop_input_cols = self._drop_input_cols,
871
910
  expected_output_cols_type="float",
872
911
  )
873
- expected_output_cols = self._align_expected_output_names(
912
+ expected_output_cols, _ = self._align_expected_output(
874
913
  inference_method, dataset, expected_output_cols, output_cols_prefix
875
914
  )
876
915