snowflake-ml-python: snowflake_ml_python-1.6.0-py3-none-any.whl → snowflake_ml_python-1.6.2-py3-none-any.whl

This diff compares the contents of two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (252)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -589,12 +586,23 @@ class DictionaryLearning(BaseTransformer):
589
586
  autogenerated=self._autogenerated,
590
587
  subproject=_SUBPROJECT,
591
588
  )
592
- output_result, fitted_estimator = model_trainer.train_fit_predict(
593
- drop_input_cols=self._drop_input_cols,
594
- expected_output_cols_list=(
595
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
596
- ),
589
+ expected_output_cols = (
590
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
597
591
  )
592
+ if isinstance(dataset, DataFrame):
593
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
594
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
595
+ )
596
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
597
+ drop_input_cols=self._drop_input_cols,
598
+ expected_output_cols_list=expected_output_cols,
599
+ example_output_pd_df=example_output_pd_df,
600
+ )
601
+ else:
602
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
603
+ drop_input_cols=self._drop_input_cols,
604
+ expected_output_cols_list=expected_output_cols,
605
+ )
598
606
  self._sklearn_object = fitted_estimator
599
607
  self._is_fitted = True
600
608
  return output_result
@@ -675,12 +683,41 @@ class DictionaryLearning(BaseTransformer):
675
683
 
676
684
  return rv
677
685
 
678
- def _align_expected_output_names(
679
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
680
- ) -> List[str]:
686
+ def _align_expected_output(
687
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
688
+ ) -> Tuple[List[str], pd.DataFrame]:
689
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
690
+ and output dataframe with 1 line.
691
+ If the method is fit_predict, run 2 lines of data.
692
+ """
681
693
  # in case the inferred output column names dimension is different
682
694
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
683
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
695
+
696
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
697
+ # so change the minimum of number of rows to 2
698
+ num_examples = 2
699
+ statement_params = telemetry.get_function_usage_statement_params(
700
+ project=_PROJECT,
701
+ subproject=_SUBPROJECT,
702
+ function_name=telemetry.get_statement_params_full_func_name(
703
+ inspect.currentframe(), DictionaryLearning.__class__.__name__
704
+ ),
705
+ api_calls=[Session.call],
706
+ custom_tags={"autogen": True} if self._autogenerated else None,
707
+ )
708
+ if output_cols_prefix == "fit_predict_":
709
+ if hasattr(self._sklearn_object, "n_clusters"):
710
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
711
+ num_examples = self._sklearn_object.n_clusters
712
+ elif hasattr(self._sklearn_object, "min_samples"):
713
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
714
+ num_examples = self._sklearn_object.min_samples
715
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
716
+ # LocalOutlierFactor expects n_neighbors <= n_samples
717
+ num_examples = self._sklearn_object.n_neighbors
718
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
719
+ else:
720
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
684
721
 
685
722
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
686
723
  # seen during the fit.
@@ -692,12 +729,14 @@ class DictionaryLearning(BaseTransformer):
692
729
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
693
730
  if self.sample_weight_col:
694
731
  output_df_columns_set -= set(self.sample_weight_col)
732
+
695
733
  # if the dimension of inferred output column names is correct; use it
696
734
  if len(expected_output_cols_list) == len(output_df_columns_set):
697
- return expected_output_cols_list
735
+ return expected_output_cols_list, output_df_pd
698
736
  # otherwise, use the sklearn estimator's output
699
737
  else:
700
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
738
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
739
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
701
740
 
702
741
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
703
742
  @telemetry.send_api_usage_telemetry(
@@ -743,7 +782,7 @@ class DictionaryLearning(BaseTransformer):
743
782
  drop_input_cols=self._drop_input_cols,
744
783
  expected_output_cols_type="float",
745
784
  )
746
- expected_output_cols = self._align_expected_output_names(
785
+ expected_output_cols, _ = self._align_expected_output(
747
786
  inference_method, dataset, expected_output_cols, output_cols_prefix
748
787
  )
749
788
 
@@ -809,7 +848,7 @@ class DictionaryLearning(BaseTransformer):
809
848
  drop_input_cols=self._drop_input_cols,
810
849
  expected_output_cols_type="float",
811
850
  )
812
- expected_output_cols = self._align_expected_output_names(
851
+ expected_output_cols, _ = self._align_expected_output(
813
852
  inference_method, dataset, expected_output_cols, output_cols_prefix
814
853
  )
815
854
  elif isinstance(dataset, pd.DataFrame):
@@ -872,7 +911,7 @@ class DictionaryLearning(BaseTransformer):
872
911
  drop_input_cols=self._drop_input_cols,
873
912
  expected_output_cols_type="float",
874
913
  )
875
- expected_output_cols = self._align_expected_output_names(
914
+ expected_output_cols, _ = self._align_expected_output(
876
915
  inference_method, dataset, expected_output_cols, output_cols_prefix
877
916
  )
878
917
 
@@ -937,7 +976,7 @@ class DictionaryLearning(BaseTransformer):
937
976
  drop_input_cols = self._drop_input_cols,
938
977
  expected_output_cols_type="float",
939
978
  )
940
- expected_output_cols = self._align_expected_output_names(
979
+ expected_output_cols, _ = self._align_expected_output(
941
980
  inference_method, dataset, expected_output_cols, output_cols_prefix
942
981
  )
943
982
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -526,12 +523,23 @@ class FactorAnalysis(BaseTransformer):
526
523
  autogenerated=self._autogenerated,
527
524
  subproject=_SUBPROJECT,
528
525
  )
529
- output_result, fitted_estimator = model_trainer.train_fit_predict(
530
- drop_input_cols=self._drop_input_cols,
531
- expected_output_cols_list=(
532
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
533
- ),
526
+ expected_output_cols = (
527
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
534
528
  )
529
+ if isinstance(dataset, DataFrame):
530
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
531
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
532
+ )
533
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
534
+ drop_input_cols=self._drop_input_cols,
535
+ expected_output_cols_list=expected_output_cols,
536
+ example_output_pd_df=example_output_pd_df,
537
+ )
538
+ else:
539
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
540
+ drop_input_cols=self._drop_input_cols,
541
+ expected_output_cols_list=expected_output_cols,
542
+ )
535
543
  self._sklearn_object = fitted_estimator
536
544
  self._is_fitted = True
537
545
  return output_result
@@ -612,12 +620,41 @@ class FactorAnalysis(BaseTransformer):
612
620
 
613
621
  return rv
614
622
 
615
- def _align_expected_output_names(
616
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
617
- ) -> List[str]:
623
+ def _align_expected_output(
624
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
625
+ ) -> Tuple[List[str], pd.DataFrame]:
626
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
627
+ and output dataframe with 1 line.
628
+ If the method is fit_predict, run 2 lines of data.
629
+ """
618
630
  # in case the inferred output column names dimension is different
619
631
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
620
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
632
+
633
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
634
+ # so change the minimum of number of rows to 2
635
+ num_examples = 2
636
+ statement_params = telemetry.get_function_usage_statement_params(
637
+ project=_PROJECT,
638
+ subproject=_SUBPROJECT,
639
+ function_name=telemetry.get_statement_params_full_func_name(
640
+ inspect.currentframe(), FactorAnalysis.__class__.__name__
641
+ ),
642
+ api_calls=[Session.call],
643
+ custom_tags={"autogen": True} if self._autogenerated else None,
644
+ )
645
+ if output_cols_prefix == "fit_predict_":
646
+ if hasattr(self._sklearn_object, "n_clusters"):
647
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
648
+ num_examples = self._sklearn_object.n_clusters
649
+ elif hasattr(self._sklearn_object, "min_samples"):
650
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
651
+ num_examples = self._sklearn_object.min_samples
652
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
653
+ # LocalOutlierFactor expects n_neighbors <= n_samples
654
+ num_examples = self._sklearn_object.n_neighbors
655
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
656
+ else:
657
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
621
658
 
622
659
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
623
660
  # seen during the fit.
@@ -629,12 +666,14 @@ class FactorAnalysis(BaseTransformer):
629
666
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
630
667
  if self.sample_weight_col:
631
668
  output_df_columns_set -= set(self.sample_weight_col)
669
+
632
670
  # if the dimension of inferred output column names is correct; use it
633
671
  if len(expected_output_cols_list) == len(output_df_columns_set):
634
- return expected_output_cols_list
672
+ return expected_output_cols_list, output_df_pd
635
673
  # otherwise, use the sklearn estimator's output
636
674
  else:
637
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
675
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
676
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
638
677
 
639
678
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
640
679
  @telemetry.send_api_usage_telemetry(
@@ -680,7 +719,7 @@ class FactorAnalysis(BaseTransformer):
680
719
  drop_input_cols=self._drop_input_cols,
681
720
  expected_output_cols_type="float",
682
721
  )
683
- expected_output_cols = self._align_expected_output_names(
722
+ expected_output_cols, _ = self._align_expected_output(
684
723
  inference_method, dataset, expected_output_cols, output_cols_prefix
685
724
  )
686
725
 
@@ -746,7 +785,7 @@ class FactorAnalysis(BaseTransformer):
746
785
  drop_input_cols=self._drop_input_cols,
747
786
  expected_output_cols_type="float",
748
787
  )
749
- expected_output_cols = self._align_expected_output_names(
788
+ expected_output_cols, _ = self._align_expected_output(
750
789
  inference_method, dataset, expected_output_cols, output_cols_prefix
751
790
  )
752
791
  elif isinstance(dataset, pd.DataFrame):
@@ -809,7 +848,7 @@ class FactorAnalysis(BaseTransformer):
809
848
  drop_input_cols=self._drop_input_cols,
810
849
  expected_output_cols_type="float",
811
850
  )
812
- expected_output_cols = self._align_expected_output_names(
851
+ expected_output_cols, _ = self._align_expected_output(
813
852
  inference_method, dataset, expected_output_cols, output_cols_prefix
814
853
  )
815
854
 
@@ -876,7 +915,7 @@ class FactorAnalysis(BaseTransformer):
876
915
  drop_input_cols = self._drop_input_cols,
877
916
  expected_output_cols_type="float",
878
917
  )
879
- expected_output_cols = self._align_expected_output_names(
918
+ expected_output_cols, _ = self._align_expected_output(
880
919
  inference_method, dataset, expected_output_cols, output_cols_prefix
881
920
  )
882
921
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -544,12 +541,23 @@ class FastICA(BaseTransformer):
544
541
  autogenerated=self._autogenerated,
545
542
  subproject=_SUBPROJECT,
546
543
  )
547
- output_result, fitted_estimator = model_trainer.train_fit_predict(
548
- drop_input_cols=self._drop_input_cols,
549
- expected_output_cols_list=(
550
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
551
- ),
544
+ expected_output_cols = (
545
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
552
546
  )
547
+ if isinstance(dataset, DataFrame):
548
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
549
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
550
+ )
551
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
552
+ drop_input_cols=self._drop_input_cols,
553
+ expected_output_cols_list=expected_output_cols,
554
+ example_output_pd_df=example_output_pd_df,
555
+ )
556
+ else:
557
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
558
+ drop_input_cols=self._drop_input_cols,
559
+ expected_output_cols_list=expected_output_cols,
560
+ )
553
561
  self._sklearn_object = fitted_estimator
554
562
  self._is_fitted = True
555
563
  return output_result
@@ -630,12 +638,41 @@ class FastICA(BaseTransformer):
630
638
 
631
639
  return rv
632
640
 
633
- def _align_expected_output_names(
634
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
635
- ) -> List[str]:
641
+ def _align_expected_output(
642
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
643
+ ) -> Tuple[List[str], pd.DataFrame]:
644
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
645
+ and output dataframe with 1 line.
646
+ If the method is fit_predict, run 2 lines of data.
647
+ """
636
648
  # in case the inferred output column names dimension is different
637
649
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
638
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
650
+
651
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
652
+ # so change the minimum of number of rows to 2
653
+ num_examples = 2
654
+ statement_params = telemetry.get_function_usage_statement_params(
655
+ project=_PROJECT,
656
+ subproject=_SUBPROJECT,
657
+ function_name=telemetry.get_statement_params_full_func_name(
658
+ inspect.currentframe(), FastICA.__class__.__name__
659
+ ),
660
+ api_calls=[Session.call],
661
+ custom_tags={"autogen": True} if self._autogenerated else None,
662
+ )
663
+ if output_cols_prefix == "fit_predict_":
664
+ if hasattr(self._sklearn_object, "n_clusters"):
665
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
666
+ num_examples = self._sklearn_object.n_clusters
667
+ elif hasattr(self._sklearn_object, "min_samples"):
668
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
669
+ num_examples = self._sklearn_object.min_samples
670
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
671
+ # LocalOutlierFactor expects n_neighbors <= n_samples
672
+ num_examples = self._sklearn_object.n_neighbors
673
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
674
+ else:
675
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
639
676
 
640
677
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
641
678
  # seen during the fit.
@@ -647,12 +684,14 @@ class FastICA(BaseTransformer):
647
684
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
648
685
  if self.sample_weight_col:
649
686
  output_df_columns_set -= set(self.sample_weight_col)
687
+
650
688
  # if the dimension of inferred output column names is correct; use it
651
689
  if len(expected_output_cols_list) == len(output_df_columns_set):
652
- return expected_output_cols_list
690
+ return expected_output_cols_list, output_df_pd
653
691
  # otherwise, use the sklearn estimator's output
654
692
  else:
655
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
693
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
694
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
656
695
 
657
696
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
658
697
  @telemetry.send_api_usage_telemetry(
@@ -698,7 +737,7 @@ class FastICA(BaseTransformer):
698
737
  drop_input_cols=self._drop_input_cols,
699
738
  expected_output_cols_type="float",
700
739
  )
701
- expected_output_cols = self._align_expected_output_names(
740
+ expected_output_cols, _ = self._align_expected_output(
702
741
  inference_method, dataset, expected_output_cols, output_cols_prefix
703
742
  )
704
743
 
@@ -764,7 +803,7 @@ class FastICA(BaseTransformer):
764
803
  drop_input_cols=self._drop_input_cols,
765
804
  expected_output_cols_type="float",
766
805
  )
767
- expected_output_cols = self._align_expected_output_names(
806
+ expected_output_cols, _ = self._align_expected_output(
768
807
  inference_method, dataset, expected_output_cols, output_cols_prefix
769
808
  )
770
809
  elif isinstance(dataset, pd.DataFrame):
@@ -827,7 +866,7 @@ class FastICA(BaseTransformer):
827
866
  drop_input_cols=self._drop_input_cols,
828
867
  expected_output_cols_type="float",
829
868
  )
830
- expected_output_cols = self._align_expected_output_names(
869
+ expected_output_cols, _ = self._align_expected_output(
831
870
  inference_method, dataset, expected_output_cols, output_cols_prefix
832
871
  )
833
872
 
@@ -892,7 +931,7 @@ class FastICA(BaseTransformer):
892
931
  drop_input_cols = self._drop_input_cols,
893
932
  expected_output_cols_type="float",
894
933
  )
895
- expected_output_cols = self._align_expected_output_names(
934
+ expected_output_cols, _ = self._align_expected_output(
896
935
  inference_method, dataset, expected_output_cols, output_cols_prefix
897
936
  )
898
937
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -496,12 +493,23 @@ class IncrementalPCA(BaseTransformer):
496
493
  autogenerated=self._autogenerated,
497
494
  subproject=_SUBPROJECT,
498
495
  )
499
- output_result, fitted_estimator = model_trainer.train_fit_predict(
500
- drop_input_cols=self._drop_input_cols,
501
- expected_output_cols_list=(
502
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
503
- ),
496
+ expected_output_cols = (
497
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
504
498
  )
499
+ if isinstance(dataset, DataFrame):
500
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
501
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
502
+ )
503
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
504
+ drop_input_cols=self._drop_input_cols,
505
+ expected_output_cols_list=expected_output_cols,
506
+ example_output_pd_df=example_output_pd_df,
507
+ )
508
+ else:
509
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
510
+ drop_input_cols=self._drop_input_cols,
511
+ expected_output_cols_list=expected_output_cols,
512
+ )
505
513
  self._sklearn_object = fitted_estimator
506
514
  self._is_fitted = True
507
515
  return output_result
@@ -582,12 +590,41 @@ class IncrementalPCA(BaseTransformer):
582
590
 
583
591
  return rv
584
592
 
585
- def _align_expected_output_names(
586
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
587
- ) -> List[str]:
593
+ def _align_expected_output(
594
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
595
+ ) -> Tuple[List[str], pd.DataFrame]:
596
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
597
+ and output dataframe with 1 line.
598
+ If the method is fit_predict, run 2 lines of data.
599
+ """
588
600
  # in case the inferred output column names dimension is different
589
601
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
590
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
602
+
603
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
604
+ # so change the minimum of number of rows to 2
605
+ num_examples = 2
606
+ statement_params = telemetry.get_function_usage_statement_params(
607
+ project=_PROJECT,
608
+ subproject=_SUBPROJECT,
609
+ function_name=telemetry.get_statement_params_full_func_name(
610
+ inspect.currentframe(), IncrementalPCA.__class__.__name__
611
+ ),
612
+ api_calls=[Session.call],
613
+ custom_tags={"autogen": True} if self._autogenerated else None,
614
+ )
615
+ if output_cols_prefix == "fit_predict_":
616
+ if hasattr(self._sklearn_object, "n_clusters"):
617
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
618
+ num_examples = self._sklearn_object.n_clusters
619
+ elif hasattr(self._sklearn_object, "min_samples"):
620
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
621
+ num_examples = self._sklearn_object.min_samples
622
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
623
+ # LocalOutlierFactor expects n_neighbors <= n_samples
624
+ num_examples = self._sklearn_object.n_neighbors
625
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
626
+ else:
627
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
591
628
 
592
629
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
593
630
  # seen during the fit.
@@ -599,12 +636,14 @@ class IncrementalPCA(BaseTransformer):
599
636
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
600
637
  if self.sample_weight_col:
601
638
  output_df_columns_set -= set(self.sample_weight_col)
639
+
602
640
  # if the dimension of inferred output column names is correct; use it
603
641
  if len(expected_output_cols_list) == len(output_df_columns_set):
604
- return expected_output_cols_list
642
+ return expected_output_cols_list, output_df_pd
605
643
  # otherwise, use the sklearn estimator's output
606
644
  else:
607
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
645
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
646
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
608
647
 
609
648
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
610
649
  @telemetry.send_api_usage_telemetry(
@@ -650,7 +689,7 @@ class IncrementalPCA(BaseTransformer):
650
689
  drop_input_cols=self._drop_input_cols,
651
690
  expected_output_cols_type="float",
652
691
  )
653
- expected_output_cols = self._align_expected_output_names(
692
+ expected_output_cols, _ = self._align_expected_output(
654
693
  inference_method, dataset, expected_output_cols, output_cols_prefix
655
694
  )
656
695
 
@@ -716,7 +755,7 @@ class IncrementalPCA(BaseTransformer):
716
755
  drop_input_cols=self._drop_input_cols,
717
756
  expected_output_cols_type="float",
718
757
  )
719
- expected_output_cols = self._align_expected_output_names(
758
+ expected_output_cols, _ = self._align_expected_output(
720
759
  inference_method, dataset, expected_output_cols, output_cols_prefix
721
760
  )
722
761
  elif isinstance(dataset, pd.DataFrame):
@@ -779,7 +818,7 @@ class IncrementalPCA(BaseTransformer):
779
818
  drop_input_cols=self._drop_input_cols,
780
819
  expected_output_cols_type="float",
781
820
  )
782
- expected_output_cols = self._align_expected_output_names(
821
+ expected_output_cols, _ = self._align_expected_output(
783
822
  inference_method, dataset, expected_output_cols, output_cols_prefix
784
823
  )
785
824
 
@@ -844,7 +883,7 @@ class IncrementalPCA(BaseTransformer):
844
883
  drop_input_cols = self._drop_input_cols,
845
884
  expected_output_cols_type="float",
846
885
  )
847
- expected_output_cols = self._align_expected_output_names(
886
+ expected_output_cols, _ = self._align_expected_output(
848
887
  inference_method, dataset, expected_output_cols, output_cols_prefix
849
888
  )
850
889