snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252) hide show
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -482,12 +479,23 @@ class LedoitWolf(BaseTransformer):
482
479
  autogenerated=self._autogenerated,
483
480
  subproject=_SUBPROJECT,
484
481
  )
485
- output_result, fitted_estimator = model_trainer.train_fit_predict(
486
- drop_input_cols=self._drop_input_cols,
487
- expected_output_cols_list=(
488
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
489
- ),
482
+ expected_output_cols = (
483
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
490
484
  )
485
+ if isinstance(dataset, DataFrame):
486
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
487
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
488
+ )
489
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
490
+ drop_input_cols=self._drop_input_cols,
491
+ expected_output_cols_list=expected_output_cols,
492
+ example_output_pd_df=example_output_pd_df,
493
+ )
494
+ else:
495
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
496
+ drop_input_cols=self._drop_input_cols,
497
+ expected_output_cols_list=expected_output_cols,
498
+ )
491
499
  self._sklearn_object = fitted_estimator
492
500
  self._is_fitted = True
493
501
  return output_result
@@ -566,12 +574,41 @@ class LedoitWolf(BaseTransformer):
566
574
 
567
575
  return rv
568
576
 
569
- def _align_expected_output_names(
570
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
571
- ) -> List[str]:
577
+ def _align_expected_output(
578
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
579
+ ) -> Tuple[List[str], pd.DataFrame]:
580
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
581
+ and output dataframe with 1 line.
582
+ If the method is fit_predict, run 2 lines of data.
583
+ """
572
584
  # in case the inferred output column names dimension is different
573
585
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
574
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
586
+
587
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
588
+ # so change the minimum of number of rows to 2
589
+ num_examples = 2
590
+ statement_params = telemetry.get_function_usage_statement_params(
591
+ project=_PROJECT,
592
+ subproject=_SUBPROJECT,
593
+ function_name=telemetry.get_statement_params_full_func_name(
594
+ inspect.currentframe(), LedoitWolf.__class__.__name__
595
+ ),
596
+ api_calls=[Session.call],
597
+ custom_tags={"autogen": True} if self._autogenerated else None,
598
+ )
599
+ if output_cols_prefix == "fit_predict_":
600
+ if hasattr(self._sklearn_object, "n_clusters"):
601
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
602
+ num_examples = self._sklearn_object.n_clusters
603
+ elif hasattr(self._sklearn_object, "min_samples"):
604
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
605
+ num_examples = self._sklearn_object.min_samples
606
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
607
+ # LocalOutlierFactor expects n_neighbors <= n_samples
608
+ num_examples = self._sklearn_object.n_neighbors
609
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
610
+ else:
611
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
575
612
 
576
613
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
577
614
  # seen during the fit.
@@ -583,12 +620,14 @@ class LedoitWolf(BaseTransformer):
583
620
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
584
621
  if self.sample_weight_col:
585
622
  output_df_columns_set -= set(self.sample_weight_col)
623
+
586
624
  # if the dimension of inferred output column names is correct; use it
587
625
  if len(expected_output_cols_list) == len(output_df_columns_set):
588
- return expected_output_cols_list
626
+ return expected_output_cols_list, output_df_pd
589
627
  # otherwise, use the sklearn estimator's output
590
628
  else:
591
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
629
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
630
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
592
631
 
593
632
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
594
633
  @telemetry.send_api_usage_telemetry(
@@ -634,7 +673,7 @@ class LedoitWolf(BaseTransformer):
634
673
  drop_input_cols=self._drop_input_cols,
635
674
  expected_output_cols_type="float",
636
675
  )
637
- expected_output_cols = self._align_expected_output_names(
676
+ expected_output_cols, _ = self._align_expected_output(
638
677
  inference_method, dataset, expected_output_cols, output_cols_prefix
639
678
  )
640
679
 
@@ -700,7 +739,7 @@ class LedoitWolf(BaseTransformer):
700
739
  drop_input_cols=self._drop_input_cols,
701
740
  expected_output_cols_type="float",
702
741
  )
703
- expected_output_cols = self._align_expected_output_names(
742
+ expected_output_cols, _ = self._align_expected_output(
704
743
  inference_method, dataset, expected_output_cols, output_cols_prefix
705
744
  )
706
745
  elif isinstance(dataset, pd.DataFrame):
@@ -763,7 +802,7 @@ class LedoitWolf(BaseTransformer):
763
802
  drop_input_cols=self._drop_input_cols,
764
803
  expected_output_cols_type="float",
765
804
  )
766
- expected_output_cols = self._align_expected_output_names(
805
+ expected_output_cols, _ = self._align_expected_output(
767
806
  inference_method, dataset, expected_output_cols, output_cols_prefix
768
807
  )
769
808
 
@@ -828,7 +867,7 @@ class LedoitWolf(BaseTransformer):
828
867
  drop_input_cols = self._drop_input_cols,
829
868
  expected_output_cols_type="float",
830
869
  )
831
- expected_output_cols = self._align_expected_output_names(
870
+ expected_output_cols, _ = self._align_expected_output(
832
871
  inference_method, dataset, expected_output_cols, output_cols_prefix
833
872
  )
834
873
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -494,12 +491,23 @@ class MinCovDet(BaseTransformer):
494
491
  autogenerated=self._autogenerated,
495
492
  subproject=_SUBPROJECT,
496
493
  )
497
- output_result, fitted_estimator = model_trainer.train_fit_predict(
498
- drop_input_cols=self._drop_input_cols,
499
- expected_output_cols_list=(
500
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
501
- ),
494
+ expected_output_cols = (
495
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
502
496
  )
497
+ if isinstance(dataset, DataFrame):
498
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
499
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
500
+ )
501
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
502
+ drop_input_cols=self._drop_input_cols,
503
+ expected_output_cols_list=expected_output_cols,
504
+ example_output_pd_df=example_output_pd_df,
505
+ )
506
+ else:
507
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
508
+ drop_input_cols=self._drop_input_cols,
509
+ expected_output_cols_list=expected_output_cols,
510
+ )
503
511
  self._sklearn_object = fitted_estimator
504
512
  self._is_fitted = True
505
513
  return output_result
@@ -578,12 +586,41 @@ class MinCovDet(BaseTransformer):
578
586
 
579
587
  return rv
580
588
 
581
- def _align_expected_output_names(
582
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
583
- ) -> List[str]:
589
+ def _align_expected_output(
590
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
591
+ ) -> Tuple[List[str], pd.DataFrame]:
592
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
593
+ and output dataframe with 1 line.
594
+ If the method is fit_predict, run 2 lines of data.
595
+ """
584
596
  # in case the inferred output column names dimension is different
585
597
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
586
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
598
+
599
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
600
+ # so change the minimum of number of rows to 2
601
+ num_examples = 2
602
+ statement_params = telemetry.get_function_usage_statement_params(
603
+ project=_PROJECT,
604
+ subproject=_SUBPROJECT,
605
+ function_name=telemetry.get_statement_params_full_func_name(
606
+ inspect.currentframe(), MinCovDet.__class__.__name__
607
+ ),
608
+ api_calls=[Session.call],
609
+ custom_tags={"autogen": True} if self._autogenerated else None,
610
+ )
611
+ if output_cols_prefix == "fit_predict_":
612
+ if hasattr(self._sklearn_object, "n_clusters"):
613
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
614
+ num_examples = self._sklearn_object.n_clusters
615
+ elif hasattr(self._sklearn_object, "min_samples"):
616
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
617
+ num_examples = self._sklearn_object.min_samples
618
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
619
+ # LocalOutlierFactor expects n_neighbors <= n_samples
620
+ num_examples = self._sklearn_object.n_neighbors
621
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
622
+ else:
623
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
587
624
 
588
625
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
589
626
  # seen during the fit.
@@ -595,12 +632,14 @@ class MinCovDet(BaseTransformer):
595
632
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
596
633
  if self.sample_weight_col:
597
634
  output_df_columns_set -= set(self.sample_weight_col)
635
+
598
636
  # if the dimension of inferred output column names is correct; use it
599
637
  if len(expected_output_cols_list) == len(output_df_columns_set):
600
- return expected_output_cols_list
638
+ return expected_output_cols_list, output_df_pd
601
639
  # otherwise, use the sklearn estimator's output
602
640
  else:
603
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
641
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
642
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
604
643
 
605
644
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
606
645
  @telemetry.send_api_usage_telemetry(
@@ -646,7 +685,7 @@ class MinCovDet(BaseTransformer):
646
685
  drop_input_cols=self._drop_input_cols,
647
686
  expected_output_cols_type="float",
648
687
  )
649
- expected_output_cols = self._align_expected_output_names(
688
+ expected_output_cols, _ = self._align_expected_output(
650
689
  inference_method, dataset, expected_output_cols, output_cols_prefix
651
690
  )
652
691
 
@@ -712,7 +751,7 @@ class MinCovDet(BaseTransformer):
712
751
  drop_input_cols=self._drop_input_cols,
713
752
  expected_output_cols_type="float",
714
753
  )
715
- expected_output_cols = self._align_expected_output_names(
754
+ expected_output_cols, _ = self._align_expected_output(
716
755
  inference_method, dataset, expected_output_cols, output_cols_prefix
717
756
  )
718
757
  elif isinstance(dataset, pd.DataFrame):
@@ -775,7 +814,7 @@ class MinCovDet(BaseTransformer):
775
814
  drop_input_cols=self._drop_input_cols,
776
815
  expected_output_cols_type="float",
777
816
  )
778
- expected_output_cols = self._align_expected_output_names(
817
+ expected_output_cols, _ = self._align_expected_output(
779
818
  inference_method, dataset, expected_output_cols, output_cols_prefix
780
819
  )
781
820
 
@@ -840,7 +879,7 @@ class MinCovDet(BaseTransformer):
840
879
  drop_input_cols = self._drop_input_cols,
841
880
  expected_output_cols_type="float",
842
881
  )
843
- expected_output_cols = self._align_expected_output_names(
882
+ expected_output_cols, _ = self._align_expected_output(
844
883
  inference_method, dataset, expected_output_cols, output_cols_prefix
845
884
  )
846
885
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -475,12 +472,23 @@ class OAS(BaseTransformer):
475
472
  autogenerated=self._autogenerated,
476
473
  subproject=_SUBPROJECT,
477
474
  )
478
- output_result, fitted_estimator = model_trainer.train_fit_predict(
479
- drop_input_cols=self._drop_input_cols,
480
- expected_output_cols_list=(
481
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
482
- ),
475
+ expected_output_cols = (
476
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
483
477
  )
478
+ if isinstance(dataset, DataFrame):
479
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
480
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
481
+ )
482
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
483
+ drop_input_cols=self._drop_input_cols,
484
+ expected_output_cols_list=expected_output_cols,
485
+ example_output_pd_df=example_output_pd_df,
486
+ )
487
+ else:
488
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
489
+ drop_input_cols=self._drop_input_cols,
490
+ expected_output_cols_list=expected_output_cols,
491
+ )
484
492
  self._sklearn_object = fitted_estimator
485
493
  self._is_fitted = True
486
494
  return output_result
@@ -559,12 +567,41 @@ class OAS(BaseTransformer):
559
567
 
560
568
  return rv
561
569
 
562
- def _align_expected_output_names(
563
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
564
- ) -> List[str]:
570
+ def _align_expected_output(
571
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
572
+ ) -> Tuple[List[str], pd.DataFrame]:
573
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
574
+ and output dataframe with 1 line.
575
+ If the method is fit_predict, run 2 lines of data.
576
+ """
565
577
  # in case the inferred output column names dimension is different
566
578
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
567
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
579
+
580
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
581
+ # so change the minimum of number of rows to 2
582
+ num_examples = 2
583
+ statement_params = telemetry.get_function_usage_statement_params(
584
+ project=_PROJECT,
585
+ subproject=_SUBPROJECT,
586
+ function_name=telemetry.get_statement_params_full_func_name(
587
+ inspect.currentframe(), OAS.__class__.__name__
588
+ ),
589
+ api_calls=[Session.call],
590
+ custom_tags={"autogen": True} if self._autogenerated else None,
591
+ )
592
+ if output_cols_prefix == "fit_predict_":
593
+ if hasattr(self._sklearn_object, "n_clusters"):
594
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
595
+ num_examples = self._sklearn_object.n_clusters
596
+ elif hasattr(self._sklearn_object, "min_samples"):
597
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
598
+ num_examples = self._sklearn_object.min_samples
599
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
600
+ # LocalOutlierFactor expects n_neighbors <= n_samples
601
+ num_examples = self._sklearn_object.n_neighbors
602
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
603
+ else:
604
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
568
605
 
569
606
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
570
607
  # seen during the fit.
@@ -576,12 +613,14 @@ class OAS(BaseTransformer):
576
613
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
577
614
  if self.sample_weight_col:
578
615
  output_df_columns_set -= set(self.sample_weight_col)
616
+
579
617
  # if the dimension of inferred output column names is correct; use it
580
618
  if len(expected_output_cols_list) == len(output_df_columns_set):
581
- return expected_output_cols_list
619
+ return expected_output_cols_list, output_df_pd
582
620
  # otherwise, use the sklearn estimator's output
583
621
  else:
584
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
622
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
623
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
585
624
 
586
625
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
587
626
  @telemetry.send_api_usage_telemetry(
@@ -627,7 +666,7 @@ class OAS(BaseTransformer):
627
666
  drop_input_cols=self._drop_input_cols,
628
667
  expected_output_cols_type="float",
629
668
  )
630
- expected_output_cols = self._align_expected_output_names(
669
+ expected_output_cols, _ = self._align_expected_output(
631
670
  inference_method, dataset, expected_output_cols, output_cols_prefix
632
671
  )
633
672
 
@@ -693,7 +732,7 @@ class OAS(BaseTransformer):
693
732
  drop_input_cols=self._drop_input_cols,
694
733
  expected_output_cols_type="float",
695
734
  )
696
- expected_output_cols = self._align_expected_output_names(
735
+ expected_output_cols, _ = self._align_expected_output(
697
736
  inference_method, dataset, expected_output_cols, output_cols_prefix
698
737
  )
699
738
  elif isinstance(dataset, pd.DataFrame):
@@ -756,7 +795,7 @@ class OAS(BaseTransformer):
756
795
  drop_input_cols=self._drop_input_cols,
757
796
  expected_output_cols_type="float",
758
797
  )
759
- expected_output_cols = self._align_expected_output_names(
798
+ expected_output_cols, _ = self._align_expected_output(
760
799
  inference_method, dataset, expected_output_cols, output_cols_prefix
761
800
  )
762
801
 
@@ -821,7 +860,7 @@ class OAS(BaseTransformer):
821
860
  drop_input_cols = self._drop_input_cols,
822
861
  expected_output_cols_type="float",
823
862
  )
824
- expected_output_cols = self._align_expected_output_names(
863
+ expected_output_cols, _ = self._align_expected_output(
825
864
  inference_method, dataset, expected_output_cols, output_cols_prefix
826
865
  )
827
866
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -481,12 +478,23 @@ class ShrunkCovariance(BaseTransformer):
481
478
  autogenerated=self._autogenerated,
482
479
  subproject=_SUBPROJECT,
483
480
  )
484
- output_result, fitted_estimator = model_trainer.train_fit_predict(
485
- drop_input_cols=self._drop_input_cols,
486
- expected_output_cols_list=(
487
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
488
- ),
481
+ expected_output_cols = (
482
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
489
483
  )
484
+ if isinstance(dataset, DataFrame):
485
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
486
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
487
+ )
488
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
489
+ drop_input_cols=self._drop_input_cols,
490
+ expected_output_cols_list=expected_output_cols,
491
+ example_output_pd_df=example_output_pd_df,
492
+ )
493
+ else:
494
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
495
+ drop_input_cols=self._drop_input_cols,
496
+ expected_output_cols_list=expected_output_cols,
497
+ )
490
498
  self._sklearn_object = fitted_estimator
491
499
  self._is_fitted = True
492
500
  return output_result
@@ -565,12 +573,41 @@ class ShrunkCovariance(BaseTransformer):
565
573
 
566
574
  return rv
567
575
 
568
- def _align_expected_output_names(
569
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
570
- ) -> List[str]:
576
+ def _align_expected_output(
577
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
578
+ ) -> Tuple[List[str], pd.DataFrame]:
579
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
580
+ and output dataframe with 1 line.
581
+ If the method is fit_predict, run 2 lines of data.
582
+ """
571
583
  # in case the inferred output column names dimension is different
572
584
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
573
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
585
+
586
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
587
+ # so change the minimum of number of rows to 2
588
+ num_examples = 2
589
+ statement_params = telemetry.get_function_usage_statement_params(
590
+ project=_PROJECT,
591
+ subproject=_SUBPROJECT,
592
+ function_name=telemetry.get_statement_params_full_func_name(
593
+ inspect.currentframe(), ShrunkCovariance.__class__.__name__
594
+ ),
595
+ api_calls=[Session.call],
596
+ custom_tags={"autogen": True} if self._autogenerated else None,
597
+ )
598
+ if output_cols_prefix == "fit_predict_":
599
+ if hasattr(self._sklearn_object, "n_clusters"):
600
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
601
+ num_examples = self._sklearn_object.n_clusters
602
+ elif hasattr(self._sklearn_object, "min_samples"):
603
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
604
+ num_examples = self._sklearn_object.min_samples
605
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
606
+ # LocalOutlierFactor expects n_neighbors <= n_samples
607
+ num_examples = self._sklearn_object.n_neighbors
608
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
609
+ else:
610
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
574
611
 
575
612
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
576
613
  # seen during the fit.
@@ -582,12 +619,14 @@ class ShrunkCovariance(BaseTransformer):
582
619
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
583
620
  if self.sample_weight_col:
584
621
  output_df_columns_set -= set(self.sample_weight_col)
622
+
585
623
  # if the dimension of inferred output column names is correct; use it
586
624
  if len(expected_output_cols_list) == len(output_df_columns_set):
587
- return expected_output_cols_list
625
+ return expected_output_cols_list, output_df_pd
588
626
  # otherwise, use the sklearn estimator's output
589
627
  else:
590
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
628
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
629
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
591
630
 
592
631
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
593
632
  @telemetry.send_api_usage_telemetry(
@@ -633,7 +672,7 @@ class ShrunkCovariance(BaseTransformer):
633
672
  drop_input_cols=self._drop_input_cols,
634
673
  expected_output_cols_type="float",
635
674
  )
636
- expected_output_cols = self._align_expected_output_names(
675
+ expected_output_cols, _ = self._align_expected_output(
637
676
  inference_method, dataset, expected_output_cols, output_cols_prefix
638
677
  )
639
678
 
@@ -699,7 +738,7 @@ class ShrunkCovariance(BaseTransformer):
699
738
  drop_input_cols=self._drop_input_cols,
700
739
  expected_output_cols_type="float",
701
740
  )
702
- expected_output_cols = self._align_expected_output_names(
741
+ expected_output_cols, _ = self._align_expected_output(
703
742
  inference_method, dataset, expected_output_cols, output_cols_prefix
704
743
  )
705
744
  elif isinstance(dataset, pd.DataFrame):
@@ -762,7 +801,7 @@ class ShrunkCovariance(BaseTransformer):
762
801
  drop_input_cols=self._drop_input_cols,
763
802
  expected_output_cols_type="float",
764
803
  )
765
- expected_output_cols = self._align_expected_output_names(
804
+ expected_output_cols, _ = self._align_expected_output(
766
805
  inference_method, dataset, expected_output_cols, output_cols_prefix
767
806
  )
768
807
 
@@ -827,7 +866,7 @@ class ShrunkCovariance(BaseTransformer):
827
866
  drop_input_cols = self._drop_input_cols,
828
867
  expected_output_cols_type="float",
829
868
  )
830
- expected_output_cols = self._align_expected_output_names(
869
+ expected_output_cols, _ = self._align_expected_output(
831
870
  inference_method, dataset, expected_output_cols, output_cols_prefix
832
871
  )
833
872