snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -492,12 +489,23 @@ class OneVsRestClassifier(BaseTransformer):
492
489
  autogenerated=self._autogenerated,
493
490
  subproject=_SUBPROJECT,
494
491
  )
495
- output_result, fitted_estimator = model_trainer.train_fit_predict(
496
- drop_input_cols=self._drop_input_cols,
497
- expected_output_cols_list=(
498
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
499
- ),
492
+ expected_output_cols = (
493
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
500
494
  )
495
+ if isinstance(dataset, DataFrame):
496
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
497
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
498
+ )
499
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
500
+ drop_input_cols=self._drop_input_cols,
501
+ expected_output_cols_list=expected_output_cols,
502
+ example_output_pd_df=example_output_pd_df,
503
+ )
504
+ else:
505
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
506
+ drop_input_cols=self._drop_input_cols,
507
+ expected_output_cols_list=expected_output_cols,
508
+ )
501
509
  self._sklearn_object = fitted_estimator
502
510
  self._is_fitted = True
503
511
  return output_result
@@ -576,12 +584,41 @@ class OneVsRestClassifier(BaseTransformer):
576
584
 
577
585
  return rv
578
586
 
579
- def _align_expected_output_names(
580
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
581
- ) -> List[str]:
587
+ def _align_expected_output(
588
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
589
+ ) -> Tuple[List[str], pd.DataFrame]:
590
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
591
+ and output dataframe with 1 line.
592
+ If the method is fit_predict, run 2 lines of data.
593
+ """
582
594
  # in case the inferred output column names dimension is different
583
595
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
584
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
596
+
597
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
598
+ # so change the minimum of number of rows to 2
599
+ num_examples = 2
600
+ statement_params = telemetry.get_function_usage_statement_params(
601
+ project=_PROJECT,
602
+ subproject=_SUBPROJECT,
603
+ function_name=telemetry.get_statement_params_full_func_name(
604
+ inspect.currentframe(), OneVsRestClassifier.__class__.__name__
605
+ ),
606
+ api_calls=[Session.call],
607
+ custom_tags={"autogen": True} if self._autogenerated else None,
608
+ )
609
+ if output_cols_prefix == "fit_predict_":
610
+ if hasattr(self._sklearn_object, "n_clusters"):
611
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
612
+ num_examples = self._sklearn_object.n_clusters
613
+ elif hasattr(self._sklearn_object, "min_samples"):
614
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
615
+ num_examples = self._sklearn_object.min_samples
616
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
617
+ # LocalOutlierFactor expects n_neighbors <= n_samples
618
+ num_examples = self._sklearn_object.n_neighbors
619
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
620
+ else:
621
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
585
622
 
586
623
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
587
624
  # seen during the fit.
@@ -593,12 +630,14 @@ class OneVsRestClassifier(BaseTransformer):
593
630
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
594
631
  if self.sample_weight_col:
595
632
  output_df_columns_set -= set(self.sample_weight_col)
633
+
596
634
  # if the dimension of inferred output column names is correct; use it
597
635
  if len(expected_output_cols_list) == len(output_df_columns_set):
598
- return expected_output_cols_list
636
+ return expected_output_cols_list, output_df_pd
599
637
  # otherwise, use the sklearn estimator's output
600
638
  else:
601
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
639
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
640
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
602
641
 
603
642
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
604
643
  @telemetry.send_api_usage_telemetry(
@@ -646,7 +685,7 @@ class OneVsRestClassifier(BaseTransformer):
646
685
  drop_input_cols=self._drop_input_cols,
647
686
  expected_output_cols_type="float",
648
687
  )
649
- expected_output_cols = self._align_expected_output_names(
688
+ expected_output_cols, _ = self._align_expected_output(
650
689
  inference_method, dataset, expected_output_cols, output_cols_prefix
651
690
  )
652
691
 
@@ -714,7 +753,7 @@ class OneVsRestClassifier(BaseTransformer):
714
753
  drop_input_cols=self._drop_input_cols,
715
754
  expected_output_cols_type="float",
716
755
  )
717
- expected_output_cols = self._align_expected_output_names(
756
+ expected_output_cols, _ = self._align_expected_output(
718
757
  inference_method, dataset, expected_output_cols, output_cols_prefix
719
758
  )
720
759
  elif isinstance(dataset, pd.DataFrame):
@@ -779,7 +818,7 @@ class OneVsRestClassifier(BaseTransformer):
779
818
  drop_input_cols=self._drop_input_cols,
780
819
  expected_output_cols_type="float",
781
820
  )
782
- expected_output_cols = self._align_expected_output_names(
821
+ expected_output_cols, _ = self._align_expected_output(
783
822
  inference_method, dataset, expected_output_cols, output_cols_prefix
784
823
  )
785
824
 
@@ -844,7 +883,7 @@ class OneVsRestClassifier(BaseTransformer):
844
883
  drop_input_cols = self._drop_input_cols,
845
884
  expected_output_cols_type="float",
846
885
  )
847
- expected_output_cols = self._align_expected_output_names(
886
+ expected_output_cols, _ = self._align_expected_output(
848
887
  inference_method, dataset, expected_output_cols, output_cols_prefix
849
888
  )
850
889
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -495,12 +492,23 @@ class OutputCodeClassifier(BaseTransformer):
495
492
  autogenerated=self._autogenerated,
496
493
  subproject=_SUBPROJECT,
497
494
  )
498
- output_result, fitted_estimator = model_trainer.train_fit_predict(
499
- drop_input_cols=self._drop_input_cols,
500
- expected_output_cols_list=(
501
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
502
- ),
495
+ expected_output_cols = (
496
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
503
497
  )
498
+ if isinstance(dataset, DataFrame):
499
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
500
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
501
+ )
502
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
503
+ drop_input_cols=self._drop_input_cols,
504
+ expected_output_cols_list=expected_output_cols,
505
+ example_output_pd_df=example_output_pd_df,
506
+ )
507
+ else:
508
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
509
+ drop_input_cols=self._drop_input_cols,
510
+ expected_output_cols_list=expected_output_cols,
511
+ )
504
512
  self._sklearn_object = fitted_estimator
505
513
  self._is_fitted = True
506
514
  return output_result
@@ -579,12 +587,41 @@ class OutputCodeClassifier(BaseTransformer):
579
587
 
580
588
  return rv
581
589
 
582
- def _align_expected_output_names(
583
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
584
- ) -> List[str]:
590
+ def _align_expected_output(
591
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
592
+ ) -> Tuple[List[str], pd.DataFrame]:
593
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
594
+ and output dataframe with 1 line.
595
+ If the method is fit_predict, run 2 lines of data.
596
+ """
585
597
  # in case the inferred output column names dimension is different
586
598
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
587
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
599
+
600
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
601
+ # so change the minimum of number of rows to 2
602
+ num_examples = 2
603
+ statement_params = telemetry.get_function_usage_statement_params(
604
+ project=_PROJECT,
605
+ subproject=_SUBPROJECT,
606
+ function_name=telemetry.get_statement_params_full_func_name(
607
+ inspect.currentframe(), OutputCodeClassifier.__class__.__name__
608
+ ),
609
+ api_calls=[Session.call],
610
+ custom_tags={"autogen": True} if self._autogenerated else None,
611
+ )
612
+ if output_cols_prefix == "fit_predict_":
613
+ if hasattr(self._sklearn_object, "n_clusters"):
614
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
615
+ num_examples = self._sklearn_object.n_clusters
616
+ elif hasattr(self._sklearn_object, "min_samples"):
617
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
618
+ num_examples = self._sklearn_object.min_samples
619
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
620
+ # LocalOutlierFactor expects n_neighbors <= n_samples
621
+ num_examples = self._sklearn_object.n_neighbors
622
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
623
+ else:
624
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
588
625
 
589
626
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
590
627
  # seen during the fit.
@@ -596,12 +633,14 @@ class OutputCodeClassifier(BaseTransformer):
596
633
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
597
634
  if self.sample_weight_col:
598
635
  output_df_columns_set -= set(self.sample_weight_col)
636
+
599
637
  # if the dimension of inferred output column names is correct; use it
600
638
  if len(expected_output_cols_list) == len(output_df_columns_set):
601
- return expected_output_cols_list
639
+ return expected_output_cols_list, output_df_pd
602
640
  # otherwise, use the sklearn estimator's output
603
641
  else:
604
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
642
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
643
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
605
644
 
606
645
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
607
646
  @telemetry.send_api_usage_telemetry(
@@ -647,7 +686,7 @@ class OutputCodeClassifier(BaseTransformer):
647
686
  drop_input_cols=self._drop_input_cols,
648
687
  expected_output_cols_type="float",
649
688
  )
650
- expected_output_cols = self._align_expected_output_names(
689
+ expected_output_cols, _ = self._align_expected_output(
651
690
  inference_method, dataset, expected_output_cols, output_cols_prefix
652
691
  )
653
692
 
@@ -713,7 +752,7 @@ class OutputCodeClassifier(BaseTransformer):
713
752
  drop_input_cols=self._drop_input_cols,
714
753
  expected_output_cols_type="float",
715
754
  )
716
- expected_output_cols = self._align_expected_output_names(
755
+ expected_output_cols, _ = self._align_expected_output(
717
756
  inference_method, dataset, expected_output_cols, output_cols_prefix
718
757
  )
719
758
  elif isinstance(dataset, pd.DataFrame):
@@ -776,7 +815,7 @@ class OutputCodeClassifier(BaseTransformer):
776
815
  drop_input_cols=self._drop_input_cols,
777
816
  expected_output_cols_type="float",
778
817
  )
779
- expected_output_cols = self._align_expected_output_names(
818
+ expected_output_cols, _ = self._align_expected_output(
780
819
  inference_method, dataset, expected_output_cols, output_cols_prefix
781
820
  )
782
821
 
@@ -841,7 +880,7 @@ class OutputCodeClassifier(BaseTransformer):
841
880
  drop_input_cols = self._drop_input_cols,
842
881
  expected_output_cols_type="float",
843
882
  )
844
- expected_output_cols = self._align_expected_output_names(
883
+ expected_output_cols, _ = self._align_expected_output(
845
884
  inference_method, dataset, expected_output_cols, output_cols_prefix
846
885
  )
847
886
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -495,12 +492,23 @@ class BernoulliNB(BaseTransformer):
495
492
  autogenerated=self._autogenerated,
496
493
  subproject=_SUBPROJECT,
497
494
  )
498
- output_result, fitted_estimator = model_trainer.train_fit_predict(
499
- drop_input_cols=self._drop_input_cols,
500
- expected_output_cols_list=(
501
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
502
- ),
495
+ expected_output_cols = (
496
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
503
497
  )
498
+ if isinstance(dataset, DataFrame):
499
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
500
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
501
+ )
502
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
503
+ drop_input_cols=self._drop_input_cols,
504
+ expected_output_cols_list=expected_output_cols,
505
+ example_output_pd_df=example_output_pd_df,
506
+ )
507
+ else:
508
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
509
+ drop_input_cols=self._drop_input_cols,
510
+ expected_output_cols_list=expected_output_cols,
511
+ )
504
512
  self._sklearn_object = fitted_estimator
505
513
  self._is_fitted = True
506
514
  return output_result
@@ -579,12 +587,41 @@ class BernoulliNB(BaseTransformer):
579
587
 
580
588
  return rv
581
589
 
582
- def _align_expected_output_names(
583
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
584
- ) -> List[str]:
590
+ def _align_expected_output(
591
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
592
+ ) -> Tuple[List[str], pd.DataFrame]:
593
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
594
+ and output dataframe with 1 line.
595
+ If the method is fit_predict, run 2 lines of data.
596
+ """
585
597
  # in case the inferred output column names dimension is different
586
598
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
587
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
599
+
600
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
601
+ # so change the minimum of number of rows to 2
602
+ num_examples = 2
603
+ statement_params = telemetry.get_function_usage_statement_params(
604
+ project=_PROJECT,
605
+ subproject=_SUBPROJECT,
606
+ function_name=telemetry.get_statement_params_full_func_name(
607
+ inspect.currentframe(), BernoulliNB.__class__.__name__
608
+ ),
609
+ api_calls=[Session.call],
610
+ custom_tags={"autogen": True} if self._autogenerated else None,
611
+ )
612
+ if output_cols_prefix == "fit_predict_":
613
+ if hasattr(self._sklearn_object, "n_clusters"):
614
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
615
+ num_examples = self._sklearn_object.n_clusters
616
+ elif hasattr(self._sklearn_object, "min_samples"):
617
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
618
+ num_examples = self._sklearn_object.min_samples
619
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
620
+ # LocalOutlierFactor expects n_neighbors <= n_samples
621
+ num_examples = self._sklearn_object.n_neighbors
622
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
623
+ else:
624
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
588
625
 
589
626
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
590
627
  # seen during the fit.
@@ -596,12 +633,14 @@ class BernoulliNB(BaseTransformer):
596
633
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
597
634
  if self.sample_weight_col:
598
635
  output_df_columns_set -= set(self.sample_weight_col)
636
+
599
637
  # if the dimension of inferred output column names is correct; use it
600
638
  if len(expected_output_cols_list) == len(output_df_columns_set):
601
- return expected_output_cols_list
639
+ return expected_output_cols_list, output_df_pd
602
640
  # otherwise, use the sklearn estimator's output
603
641
  else:
604
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
642
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
643
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
605
644
 
606
645
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
607
646
  @telemetry.send_api_usage_telemetry(
@@ -649,7 +688,7 @@ class BernoulliNB(BaseTransformer):
649
688
  drop_input_cols=self._drop_input_cols,
650
689
  expected_output_cols_type="float",
651
690
  )
652
- expected_output_cols = self._align_expected_output_names(
691
+ expected_output_cols, _ = self._align_expected_output(
653
692
  inference_method, dataset, expected_output_cols, output_cols_prefix
654
693
  )
655
694
 
@@ -717,7 +756,7 @@ class BernoulliNB(BaseTransformer):
717
756
  drop_input_cols=self._drop_input_cols,
718
757
  expected_output_cols_type="float",
719
758
  )
720
- expected_output_cols = self._align_expected_output_names(
759
+ expected_output_cols, _ = self._align_expected_output(
721
760
  inference_method, dataset, expected_output_cols, output_cols_prefix
722
761
  )
723
762
  elif isinstance(dataset, pd.DataFrame):
@@ -780,7 +819,7 @@ class BernoulliNB(BaseTransformer):
780
819
  drop_input_cols=self._drop_input_cols,
781
820
  expected_output_cols_type="float",
782
821
  )
783
- expected_output_cols = self._align_expected_output_names(
822
+ expected_output_cols, _ = self._align_expected_output(
784
823
  inference_method, dataset, expected_output_cols, output_cols_prefix
785
824
  )
786
825
 
@@ -845,7 +884,7 @@ class BernoulliNB(BaseTransformer):
845
884
  drop_input_cols = self._drop_input_cols,
846
885
  expected_output_cols_type="float",
847
886
  )
848
- expected_output_cols = self._align_expected_output_names(
887
+ expected_output_cols, _ = self._align_expected_output(
849
888
  inference_method, dataset, expected_output_cols, output_cols_prefix
850
889
  )
851
890
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -501,12 +498,23 @@ class CategoricalNB(BaseTransformer):
501
498
  autogenerated=self._autogenerated,
502
499
  subproject=_SUBPROJECT,
503
500
  )
504
- output_result, fitted_estimator = model_trainer.train_fit_predict(
505
- drop_input_cols=self._drop_input_cols,
506
- expected_output_cols_list=(
507
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
508
- ),
501
+ expected_output_cols = (
502
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
509
503
  )
504
+ if isinstance(dataset, DataFrame):
505
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
506
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
507
+ )
508
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
509
+ drop_input_cols=self._drop_input_cols,
510
+ expected_output_cols_list=expected_output_cols,
511
+ example_output_pd_df=example_output_pd_df,
512
+ )
513
+ else:
514
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
515
+ drop_input_cols=self._drop_input_cols,
516
+ expected_output_cols_list=expected_output_cols,
517
+ )
510
518
  self._sklearn_object = fitted_estimator
511
519
  self._is_fitted = True
512
520
  return output_result
@@ -585,12 +593,41 @@ class CategoricalNB(BaseTransformer):
585
593
 
586
594
  return rv
587
595
 
588
- def _align_expected_output_names(
589
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
590
- ) -> List[str]:
596
+ def _align_expected_output(
597
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
598
+ ) -> Tuple[List[str], pd.DataFrame]:
599
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
600
+ and output dataframe with 1 line.
601
+ If the method is fit_predict, run 2 lines of data.
602
+ """
591
603
  # in case the inferred output column names dimension is different
592
604
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
593
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
605
+
606
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
607
+ # so change the minimum of number of rows to 2
608
+ num_examples = 2
609
+ statement_params = telemetry.get_function_usage_statement_params(
610
+ project=_PROJECT,
611
+ subproject=_SUBPROJECT,
612
+ function_name=telemetry.get_statement_params_full_func_name(
613
+ inspect.currentframe(), CategoricalNB.__class__.__name__
614
+ ),
615
+ api_calls=[Session.call],
616
+ custom_tags={"autogen": True} if self._autogenerated else None,
617
+ )
618
+ if output_cols_prefix == "fit_predict_":
619
+ if hasattr(self._sklearn_object, "n_clusters"):
620
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
621
+ num_examples = self._sklearn_object.n_clusters
622
+ elif hasattr(self._sklearn_object, "min_samples"):
623
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
624
+ num_examples = self._sklearn_object.min_samples
625
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
626
+ # LocalOutlierFactor expects n_neighbors <= n_samples
627
+ num_examples = self._sklearn_object.n_neighbors
628
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
629
+ else:
630
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
594
631
 
595
632
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
596
633
  # seen during the fit.
@@ -602,12 +639,14 @@ class CategoricalNB(BaseTransformer):
602
639
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
603
640
  if self.sample_weight_col:
604
641
  output_df_columns_set -= set(self.sample_weight_col)
642
+
605
643
  # if the dimension of inferred output column names is correct; use it
606
644
  if len(expected_output_cols_list) == len(output_df_columns_set):
607
- return expected_output_cols_list
645
+ return expected_output_cols_list, output_df_pd
608
646
  # otherwise, use the sklearn estimator's output
609
647
  else:
610
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
648
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
649
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
611
650
 
612
651
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
613
652
  @telemetry.send_api_usage_telemetry(
@@ -655,7 +694,7 @@ class CategoricalNB(BaseTransformer):
655
694
  drop_input_cols=self._drop_input_cols,
656
695
  expected_output_cols_type="float",
657
696
  )
658
- expected_output_cols = self._align_expected_output_names(
697
+ expected_output_cols, _ = self._align_expected_output(
659
698
  inference_method, dataset, expected_output_cols, output_cols_prefix
660
699
  )
661
700
 
@@ -723,7 +762,7 @@ class CategoricalNB(BaseTransformer):
723
762
  drop_input_cols=self._drop_input_cols,
724
763
  expected_output_cols_type="float",
725
764
  )
726
- expected_output_cols = self._align_expected_output_names(
765
+ expected_output_cols, _ = self._align_expected_output(
727
766
  inference_method, dataset, expected_output_cols, output_cols_prefix
728
767
  )
729
768
  elif isinstance(dataset, pd.DataFrame):
@@ -786,7 +825,7 @@ class CategoricalNB(BaseTransformer):
786
825
  drop_input_cols=self._drop_input_cols,
787
826
  expected_output_cols_type="float",
788
827
  )
789
- expected_output_cols = self._align_expected_output_names(
828
+ expected_output_cols, _ = self._align_expected_output(
790
829
  inference_method, dataset, expected_output_cols, output_cols_prefix
791
830
  )
792
831
 
@@ -851,7 +890,7 @@ class CategoricalNB(BaseTransformer):
851
890
  drop_input_cols = self._drop_input_cols,
852
891
  expected_output_cols_type="float",
853
892
  )
854
- expected_output_cols = self._align_expected_output_names(
893
+ expected_output_cols, _ = self._align_expected_output(
855
894
  inference_method, dataset, expected_output_cols, output_cols_prefix
856
895
  )
857
896