snowflake-ml-python 1.6.0__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (252)
  1. snowflake/cortex/_complete.py +7 -33
  2. snowflake/ml/_internal/env_utils.py +11 -5
  3. snowflake/ml/_internal/exceptions/modeling_error_messages.py +4 -1
  4. snowflake/ml/_internal/telemetry.py +156 -20
  5. snowflake/ml/_internal/utils/identifier.py +48 -11
  6. snowflake/ml/_internal/utils/pkg_version_utils.py +8 -22
  7. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  8. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  9. snowflake/ml/_internal/utils/table_manager.py +19 -1
  10. snowflake/ml/_internal/utils/uri.py +2 -2
  11. snowflake/ml/data/_internal/arrow_ingestor.py +66 -10
  12. snowflake/ml/data/data_connector.py +88 -9
  13. snowflake/ml/data/data_ingestor.py +18 -1
  14. snowflake/ml/data/{_internal/ingestor_utils.py → ingestor_utils.py} +5 -1
  15. snowflake/ml/data/torch_utils.py +68 -0
  16. snowflake/ml/dataset/dataset.py +1 -3
  17. snowflake/ml/dataset/dataset_metadata.py +3 -1
  18. snowflake/ml/dataset/dataset_reader.py +9 -3
  19. snowflake/ml/feature_store/examples/airline_features/entities.py +16 -0
  20. snowflake/ml/feature_store/examples/airline_features/features/plane_features.py +31 -0
  21. snowflake/ml/feature_store/examples/airline_features/features/weather_features.py +42 -0
  22. snowflake/ml/feature_store/examples/airline_features/source.yaml +7 -0
  23. snowflake/ml/feature_store/examples/citibike_trip_features/features/station_feature.py +10 -4
  24. snowflake/ml/feature_store/examples/citibike_trip_features/features/trip_feature.py +6 -0
  25. snowflake/ml/feature_store/examples/citibike_trip_features/source.yaml +3 -0
  26. snowflake/ml/feature_store/examples/example_helper.py +69 -31
  27. snowflake/ml/feature_store/examples/new_york_taxi_features/entities.py +3 -3
  28. snowflake/ml/feature_store/examples/new_york_taxi_features/features/{dropoff_features.py → location_features.py} +14 -9
  29. snowflake/ml/feature_store/examples/new_york_taxi_features/features/trip_features.py +36 -0
  30. snowflake/ml/feature_store/examples/new_york_taxi_features/source.yaml +5 -1
  31. snowflake/ml/feature_store/examples/source_data/airline.yaml +4 -0
  32. snowflake/ml/feature_store/examples/source_data/citibike_trips.yaml +1 -1
  33. snowflake/ml/feature_store/examples/wine_quality_features/entities.py +3 -3
  34. snowflake/ml/feature_store/examples/wine_quality_features/features/managed_wine_features.py +13 -6
  35. snowflake/ml/feature_store/examples/wine_quality_features/features/static_wine_features.py +8 -5
  36. snowflake/ml/feature_store/examples/wine_quality_features/source.yaml +3 -0
  37. snowflake/ml/feature_store/feature_store.py +100 -41
  38. snowflake/ml/feature_store/feature_view.py +149 -5
  39. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  40. snowflake/ml/fileset/fileset.py +1 -1
  41. snowflake/ml/fileset/sfcfs.py +9 -3
  42. snowflake/ml/model/_client/model/model_impl.py +11 -2
  43. snowflake/ml/model/_client/model/model_version_impl.py +186 -20
  44. snowflake/ml/model/_client/ops/model_ops.py +144 -30
  45. snowflake/ml/model/_client/ops/service_ops.py +312 -0
  46. snowflake/ml/model/_client/service/model_deployment_spec.py +94 -0
  47. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +30 -0
  48. snowflake/ml/model/_client/sql/model_version.py +13 -4
  49. snowflake/ml/model/_client/sql/service.py +196 -0
  50. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  51. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  52. snowflake/ml/model/_model_composer/model_composer.py +5 -0
  53. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +13 -10
  54. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +3 -0
  55. snowflake/ml/model/_packager/model_env/model_env.py +7 -2
  56. snowflake/ml/model/_packager/model_handlers/_base.py +29 -12
  57. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  58. snowflake/ml/model/_packager/model_handlers/catboost.py +25 -16
  59. snowflake/ml/model/_packager/model_handlers/custom.py +6 -2
  60. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +32 -20
  61. snowflake/ml/model/_packager/model_handlers/lightgbm.py +23 -56
  62. snowflake/ml/model/_packager/model_handlers/llm.py +11 -5
  63. snowflake/ml/model/_packager/model_handlers/mlflow.py +8 -3
  64. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  65. snowflake/ml/model/_packager/model_handlers/pytorch.py +8 -3
  66. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +8 -3
  67. snowflake/ml/model/_packager/model_handlers/sklearn.py +99 -4
  68. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +123 -5
  69. snowflake/ml/model/_packager/model_handlers/tensorflow.py +9 -4
  70. snowflake/ml/model/_packager/model_handlers/torchscript.py +10 -5
  71. snowflake/ml/model/_packager/model_handlers/xgboost.py +56 -47
  72. snowflake/ml/model/_packager/model_meta/model_meta.py +35 -2
  73. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +11 -0
  74. snowflake/ml/model/_packager/model_packager.py +4 -1
  75. snowflake/ml/model/_packager/model_runtime/model_runtime.py +4 -2
  76. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  77. snowflake/ml/model/_signatures/utils.py +9 -0
  78. snowflake/ml/model/models/llm.py +3 -1
  79. snowflake/ml/model/type_hints.py +10 -4
  80. snowflake/ml/modeling/_internal/constants.py +1 -0
  81. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  82. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  83. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  84. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  85. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  86. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  87. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  88. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  89. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  90. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  91. snowflake/ml/modeling/cluster/birch.py +60 -21
  92. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  93. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  94. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  95. snowflake/ml/modeling/cluster/k_means.py +60 -21
  96. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  97. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  98. snowflake/ml/modeling/cluster/optics.py +60 -21
  99. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  100. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  101. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  102. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  103. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  104. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  105. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  106. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  107. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  108. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  109. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  110. snowflake/ml/modeling/covariance/oas.py +60 -21
  111. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  112. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  113. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  114. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  115. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  116. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  117. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  118. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  119. snowflake/ml/modeling/decomposition/pca.py +60 -21
  120. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  121. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  122. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  123. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  124. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  125. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  126. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  127. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  128. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  129. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  130. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  131. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  132. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  133. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  134. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  135. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  136. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  137. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  138. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  139. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  140. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  141. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  142. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  143. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  144. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  145. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  146. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  147. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  148. snowflake/ml/modeling/framework/base.py +28 -19
  149. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  150. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  151. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  152. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  153. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  154. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  155. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  156. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  157. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  158. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  159. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  160. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  161. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  162. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  163. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  164. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  165. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  166. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  167. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  168. snowflake/ml/modeling/linear_model/lars.py +60 -21
  169. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  170. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  171. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  172. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  173. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  174. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  175. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  176. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  177. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  178. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  179. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  180. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  181. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  182. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  183. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  184. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  185. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  186. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  187. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  188. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  189. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  190. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  191. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  192. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  193. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  194. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  195. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  196. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  197. snowflake/ml/modeling/manifold/isomap.py +60 -21
  198. snowflake/ml/modeling/manifold/mds.py +60 -21
  199. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  200. snowflake/ml/modeling/manifold/tsne.py +60 -21
  201. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  202. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  203. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  204. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  205. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  206. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  207. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  208. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  209. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  210. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  211. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  212. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  213. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  214. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  215. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  216. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  217. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  218. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  219. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  220. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  221. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  222. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  223. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  224. snowflake/ml/modeling/pipeline/pipeline.py +4 -12
  225. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  226. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  227. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  228. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  229. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  230. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  231. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  232. snowflake/ml/modeling/svm/svc.py +60 -21
  233. snowflake/ml/modeling/svm/svr.py +60 -21
  234. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  235. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  236. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  237. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  238. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  239. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  240. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  241. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  242. snowflake/ml/registry/_manager/model_manager.py +20 -2
  243. snowflake/ml/registry/model_registry.py +1 -1
  244. snowflake/ml/registry/registry.py +1 -2
  245. snowflake/ml/utils/sql_client.py +22 -0
  246. snowflake/ml/version.py +1 -1
  247. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +55 -3
  248. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +251 -238
  249. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  250. snowflake/ml/feature_store/examples/new_york_taxi_features/features/pickup_features.py +0 -58
  251. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  252. {snowflake_ml_python-1.6.0.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -517,12 +514,23 @@ class MultiTaskLasso(BaseTransformer):
517
514
  autogenerated=self._autogenerated,
518
515
  subproject=_SUBPROJECT,
519
516
  )
520
- output_result, fitted_estimator = model_trainer.train_fit_predict(
521
- drop_input_cols=self._drop_input_cols,
522
- expected_output_cols_list=(
523
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
524
- ),
517
+ expected_output_cols = (
518
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
525
519
  )
520
+ if isinstance(dataset, DataFrame):
521
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
522
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
523
+ )
524
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
525
+ drop_input_cols=self._drop_input_cols,
526
+ expected_output_cols_list=expected_output_cols,
527
+ example_output_pd_df=example_output_pd_df,
528
+ )
529
+ else:
530
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
531
+ drop_input_cols=self._drop_input_cols,
532
+ expected_output_cols_list=expected_output_cols,
533
+ )
526
534
  self._sklearn_object = fitted_estimator
527
535
  self._is_fitted = True
528
536
  return output_result
@@ -601,12 +609,41 @@ class MultiTaskLasso(BaseTransformer):
601
609
 
602
610
  return rv
603
611
 
604
- def _align_expected_output_names(
605
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
606
- ) -> List[str]:
612
+ def _align_expected_output(
613
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
614
+ ) -> Tuple[List[str], pd.DataFrame]:
615
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
616
+ and output dataframe with 1 line.
617
+ If the method is fit_predict, run 2 lines of data.
618
+ """
607
619
  # in case the inferred output column names dimension is different
608
620
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
609
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
621
+
622
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
623
+ # so change the minimum of number of rows to 2
624
+ num_examples = 2
625
+ statement_params = telemetry.get_function_usage_statement_params(
626
+ project=_PROJECT,
627
+ subproject=_SUBPROJECT,
628
+ function_name=telemetry.get_statement_params_full_func_name(
629
+ inspect.currentframe(), MultiTaskLasso.__class__.__name__
630
+ ),
631
+ api_calls=[Session.call],
632
+ custom_tags={"autogen": True} if self._autogenerated else None,
633
+ )
634
+ if output_cols_prefix == "fit_predict_":
635
+ if hasattr(self._sklearn_object, "n_clusters"):
636
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
637
+ num_examples = self._sklearn_object.n_clusters
638
+ elif hasattr(self._sklearn_object, "min_samples"):
639
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
640
+ num_examples = self._sklearn_object.min_samples
641
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
642
+ # LocalOutlierFactor expects n_neighbors <= n_samples
643
+ num_examples = self._sklearn_object.n_neighbors
644
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
645
+ else:
646
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
610
647
 
611
648
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
612
649
  # seen during the fit.
@@ -618,12 +655,14 @@ class MultiTaskLasso(BaseTransformer):
618
655
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
619
656
  if self.sample_weight_col:
620
657
  output_df_columns_set -= set(self.sample_weight_col)
658
+
621
659
  # if the dimension of inferred output column names is correct; use it
622
660
  if len(expected_output_cols_list) == len(output_df_columns_set):
623
- return expected_output_cols_list
661
+ return expected_output_cols_list, output_df_pd
624
662
  # otherwise, use the sklearn estimator's output
625
663
  else:
626
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
664
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
665
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
627
666
 
628
667
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
629
668
  @telemetry.send_api_usage_telemetry(
@@ -669,7 +708,7 @@ class MultiTaskLasso(BaseTransformer):
669
708
  drop_input_cols=self._drop_input_cols,
670
709
  expected_output_cols_type="float",
671
710
  )
672
- expected_output_cols = self._align_expected_output_names(
711
+ expected_output_cols, _ = self._align_expected_output(
673
712
  inference_method, dataset, expected_output_cols, output_cols_prefix
674
713
  )
675
714
 
@@ -735,7 +774,7 @@ class MultiTaskLasso(BaseTransformer):
735
774
  drop_input_cols=self._drop_input_cols,
736
775
  expected_output_cols_type="float",
737
776
  )
738
- expected_output_cols = self._align_expected_output_names(
777
+ expected_output_cols, _ = self._align_expected_output(
739
778
  inference_method, dataset, expected_output_cols, output_cols_prefix
740
779
  )
741
780
  elif isinstance(dataset, pd.DataFrame):
@@ -798,7 +837,7 @@ class MultiTaskLasso(BaseTransformer):
798
837
  drop_input_cols=self._drop_input_cols,
799
838
  expected_output_cols_type="float",
800
839
  )
801
- expected_output_cols = self._align_expected_output_names(
840
+ expected_output_cols, _ = self._align_expected_output(
802
841
  inference_method, dataset, expected_output_cols, output_cols_prefix
803
842
  )
804
843
 
@@ -863,7 +902,7 @@ class MultiTaskLasso(BaseTransformer):
863
902
  drop_input_cols = self._drop_input_cols,
864
903
  expected_output_cols_type="float",
865
904
  )
866
- expected_output_cols = self._align_expected_output_names(
905
+ expected_output_cols, _ = self._align_expected_output(
867
906
  inference_method, dataset, expected_output_cols, output_cols_prefix
868
907
  )
869
908
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -552,12 +549,23 @@ class MultiTaskLassoCV(BaseTransformer):
552
549
  autogenerated=self._autogenerated,
553
550
  subproject=_SUBPROJECT,
554
551
  )
555
- output_result, fitted_estimator = model_trainer.train_fit_predict(
556
- drop_input_cols=self._drop_input_cols,
557
- expected_output_cols_list=(
558
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
559
- ),
552
+ expected_output_cols = (
553
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
560
554
  )
555
+ if isinstance(dataset, DataFrame):
556
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
557
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
558
+ )
559
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
560
+ drop_input_cols=self._drop_input_cols,
561
+ expected_output_cols_list=expected_output_cols,
562
+ example_output_pd_df=example_output_pd_df,
563
+ )
564
+ else:
565
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
566
+ drop_input_cols=self._drop_input_cols,
567
+ expected_output_cols_list=expected_output_cols,
568
+ )
561
569
  self._sklearn_object = fitted_estimator
562
570
  self._is_fitted = True
563
571
  return output_result
@@ -636,12 +644,41 @@ class MultiTaskLassoCV(BaseTransformer):
636
644
 
637
645
  return rv
638
646
 
639
- def _align_expected_output_names(
640
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
641
- ) -> List[str]:
647
+ def _align_expected_output(
648
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
649
+ ) -> Tuple[List[str], pd.DataFrame]:
650
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
651
+ and output dataframe with 1 line.
652
+ If the method is fit_predict, run 2 lines of data.
653
+ """
642
654
  # in case the inferred output column names dimension is different
643
655
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
644
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
656
+
657
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
658
+ # so change the minimum of number of rows to 2
659
+ num_examples = 2
660
+ statement_params = telemetry.get_function_usage_statement_params(
661
+ project=_PROJECT,
662
+ subproject=_SUBPROJECT,
663
+ function_name=telemetry.get_statement_params_full_func_name(
664
+ inspect.currentframe(), MultiTaskLassoCV.__class__.__name__
665
+ ),
666
+ api_calls=[Session.call],
667
+ custom_tags={"autogen": True} if self._autogenerated else None,
668
+ )
669
+ if output_cols_prefix == "fit_predict_":
670
+ if hasattr(self._sklearn_object, "n_clusters"):
671
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
672
+ num_examples = self._sklearn_object.n_clusters
673
+ elif hasattr(self._sklearn_object, "min_samples"):
674
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
675
+ num_examples = self._sklearn_object.min_samples
676
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
677
+ # LocalOutlierFactor expects n_neighbors <= n_samples
678
+ num_examples = self._sklearn_object.n_neighbors
679
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
680
+ else:
681
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
645
682
 
646
683
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
647
684
  # seen during the fit.
@@ -653,12 +690,14 @@ class MultiTaskLassoCV(BaseTransformer):
653
690
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
654
691
  if self.sample_weight_col:
655
692
  output_df_columns_set -= set(self.sample_weight_col)
693
+
656
694
  # if the dimension of inferred output column names is correct; use it
657
695
  if len(expected_output_cols_list) == len(output_df_columns_set):
658
- return expected_output_cols_list
696
+ return expected_output_cols_list, output_df_pd
659
697
  # otherwise, use the sklearn estimator's output
660
698
  else:
661
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
699
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
700
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
662
701
 
663
702
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
664
703
  @telemetry.send_api_usage_telemetry(
@@ -704,7 +743,7 @@ class MultiTaskLassoCV(BaseTransformer):
704
743
  drop_input_cols=self._drop_input_cols,
705
744
  expected_output_cols_type="float",
706
745
  )
707
- expected_output_cols = self._align_expected_output_names(
746
+ expected_output_cols, _ = self._align_expected_output(
708
747
  inference_method, dataset, expected_output_cols, output_cols_prefix
709
748
  )
710
749
 
@@ -770,7 +809,7 @@ class MultiTaskLassoCV(BaseTransformer):
770
809
  drop_input_cols=self._drop_input_cols,
771
810
  expected_output_cols_type="float",
772
811
  )
773
- expected_output_cols = self._align_expected_output_names(
812
+ expected_output_cols, _ = self._align_expected_output(
774
813
  inference_method, dataset, expected_output_cols, output_cols_prefix
775
814
  )
776
815
  elif isinstance(dataset, pd.DataFrame):
@@ -833,7 +872,7 @@ class MultiTaskLassoCV(BaseTransformer):
833
872
  drop_input_cols=self._drop_input_cols,
834
873
  expected_output_cols_type="float",
835
874
  )
836
- expected_output_cols = self._align_expected_output_names(
875
+ expected_output_cols, _ = self._align_expected_output(
837
876
  inference_method, dataset, expected_output_cols, output_cols_prefix
838
877
  )
839
878
 
@@ -898,7 +937,7 @@ class MultiTaskLassoCV(BaseTransformer):
898
937
  drop_input_cols = self._drop_input_cols,
899
938
  expected_output_cols_type="float",
900
939
  )
901
- expected_output_cols = self._align_expected_output_names(
940
+ expected_output_cols, _ = self._align_expected_output(
902
941
  inference_method, dataset, expected_output_cols, output_cols_prefix
903
942
  )
904
943
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -500,12 +497,23 @@ class OrthogonalMatchingPursuit(BaseTransformer):
500
497
  autogenerated=self._autogenerated,
501
498
  subproject=_SUBPROJECT,
502
499
  )
503
- output_result, fitted_estimator = model_trainer.train_fit_predict(
504
- drop_input_cols=self._drop_input_cols,
505
- expected_output_cols_list=(
506
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
507
- ),
500
+ expected_output_cols = (
501
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
508
502
  )
503
+ if isinstance(dataset, DataFrame):
504
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
505
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
506
+ )
507
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
508
+ drop_input_cols=self._drop_input_cols,
509
+ expected_output_cols_list=expected_output_cols,
510
+ example_output_pd_df=example_output_pd_df,
511
+ )
512
+ else:
513
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
514
+ drop_input_cols=self._drop_input_cols,
515
+ expected_output_cols_list=expected_output_cols,
516
+ )
509
517
  self._sklearn_object = fitted_estimator
510
518
  self._is_fitted = True
511
519
  return output_result
@@ -584,12 +592,41 @@ class OrthogonalMatchingPursuit(BaseTransformer):
584
592
 
585
593
  return rv
586
594
 
587
- def _align_expected_output_names(
588
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
589
- ) -> List[str]:
595
+ def _align_expected_output(
596
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
597
+ ) -> Tuple[List[str], pd.DataFrame]:
598
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
599
+ and output dataframe with 1 line.
600
+ If the method is fit_predict, run 2 lines of data.
601
+ """
590
602
  # in case the inferred output column names dimension is different
591
603
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
592
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
604
+
605
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
606
+ # so change the minimum of number of rows to 2
607
+ num_examples = 2
608
+ statement_params = telemetry.get_function_usage_statement_params(
609
+ project=_PROJECT,
610
+ subproject=_SUBPROJECT,
611
+ function_name=telemetry.get_statement_params_full_func_name(
612
+ inspect.currentframe(), OrthogonalMatchingPursuit.__class__.__name__
613
+ ),
614
+ api_calls=[Session.call],
615
+ custom_tags={"autogen": True} if self._autogenerated else None,
616
+ )
617
+ if output_cols_prefix == "fit_predict_":
618
+ if hasattr(self._sklearn_object, "n_clusters"):
619
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
620
+ num_examples = self._sklearn_object.n_clusters
621
+ elif hasattr(self._sklearn_object, "min_samples"):
622
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
623
+ num_examples = self._sklearn_object.min_samples
624
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
625
+ # LocalOutlierFactor expects n_neighbors <= n_samples
626
+ num_examples = self._sklearn_object.n_neighbors
627
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
628
+ else:
629
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
593
630
 
594
631
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
595
632
  # seen during the fit.
@@ -601,12 +638,14 @@ class OrthogonalMatchingPursuit(BaseTransformer):
601
638
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
602
639
  if self.sample_weight_col:
603
640
  output_df_columns_set -= set(self.sample_weight_col)
641
+
604
642
  # if the dimension of inferred output column names is correct; use it
605
643
  if len(expected_output_cols_list) == len(output_df_columns_set):
606
- return expected_output_cols_list
644
+ return expected_output_cols_list, output_df_pd
607
645
  # otherwise, use the sklearn estimator's output
608
646
  else:
609
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
647
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
648
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
610
649
 
611
650
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
612
651
  @telemetry.send_api_usage_telemetry(
@@ -652,7 +691,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
652
691
  drop_input_cols=self._drop_input_cols,
653
692
  expected_output_cols_type="float",
654
693
  )
655
- expected_output_cols = self._align_expected_output_names(
694
+ expected_output_cols, _ = self._align_expected_output(
656
695
  inference_method, dataset, expected_output_cols, output_cols_prefix
657
696
  )
658
697
 
@@ -718,7 +757,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
718
757
  drop_input_cols=self._drop_input_cols,
719
758
  expected_output_cols_type="float",
720
759
  )
721
- expected_output_cols = self._align_expected_output_names(
760
+ expected_output_cols, _ = self._align_expected_output(
722
761
  inference_method, dataset, expected_output_cols, output_cols_prefix
723
762
  )
724
763
  elif isinstance(dataset, pd.DataFrame):
@@ -781,7 +820,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
781
820
  drop_input_cols=self._drop_input_cols,
782
821
  expected_output_cols_type="float",
783
822
  )
784
- expected_output_cols = self._align_expected_output_names(
823
+ expected_output_cols, _ = self._align_expected_output(
785
824
  inference_method, dataset, expected_output_cols, output_cols_prefix
786
825
  )
787
826
 
@@ -846,7 +885,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
846
885
  drop_input_cols = self._drop_input_cols,
847
886
  expected_output_cols_type="float",
848
887
  )
849
- expected_output_cols = self._align_expected_output_names(
888
+ expected_output_cols, _ = self._align_expected_output(
850
889
  inference_method, dataset, expected_output_cols, output_cols_prefix
851
890
  )
852
891
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -574,12 +571,23 @@ class PassiveAggressiveClassifier(BaseTransformer):
574
571
  autogenerated=self._autogenerated,
575
572
  subproject=_SUBPROJECT,
576
573
  )
577
- output_result, fitted_estimator = model_trainer.train_fit_predict(
578
- drop_input_cols=self._drop_input_cols,
579
- expected_output_cols_list=(
580
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
581
- ),
574
+ expected_output_cols = (
575
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
582
576
  )
577
+ if isinstance(dataset, DataFrame):
578
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
579
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
580
+ )
581
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
582
+ drop_input_cols=self._drop_input_cols,
583
+ expected_output_cols_list=expected_output_cols,
584
+ example_output_pd_df=example_output_pd_df,
585
+ )
586
+ else:
587
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
588
+ drop_input_cols=self._drop_input_cols,
589
+ expected_output_cols_list=expected_output_cols,
590
+ )
583
591
  self._sklearn_object = fitted_estimator
584
592
  self._is_fitted = True
585
593
  return output_result
@@ -658,12 +666,41 @@ class PassiveAggressiveClassifier(BaseTransformer):
658
666
 
659
667
  return rv
660
668
 
661
- def _align_expected_output_names(
662
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
663
- ) -> List[str]:
669
+ def _align_expected_output(
670
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
671
+ ) -> Tuple[List[str], pd.DataFrame]:
672
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
673
+ and output dataframe with 1 line.
674
+ If the method is fit_predict, run 2 lines of data.
675
+ """
664
676
  # in case the inferred output column names dimension is different
665
677
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
666
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
678
+
679
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
680
+ # so change the minimum of number of rows to 2
681
+ num_examples = 2
682
+ statement_params = telemetry.get_function_usage_statement_params(
683
+ project=_PROJECT,
684
+ subproject=_SUBPROJECT,
685
+ function_name=telemetry.get_statement_params_full_func_name(
686
+ inspect.currentframe(), PassiveAggressiveClassifier.__class__.__name__
687
+ ),
688
+ api_calls=[Session.call],
689
+ custom_tags={"autogen": True} if self._autogenerated else None,
690
+ )
691
+ if output_cols_prefix == "fit_predict_":
692
+ if hasattr(self._sklearn_object, "n_clusters"):
693
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
694
+ num_examples = self._sklearn_object.n_clusters
695
+ elif hasattr(self._sklearn_object, "min_samples"):
696
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
697
+ num_examples = self._sklearn_object.min_samples
698
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
699
+ # LocalOutlierFactor expects n_neighbors <= n_samples
700
+ num_examples = self._sklearn_object.n_neighbors
701
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
702
+ else:
703
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
667
704
 
668
705
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
669
706
  # seen during the fit.
@@ -675,12 +712,14 @@ class PassiveAggressiveClassifier(BaseTransformer):
675
712
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
676
713
  if self.sample_weight_col:
677
714
  output_df_columns_set -= set(self.sample_weight_col)
715
+
678
716
  # if the dimension of inferred output column names is correct; use it
679
717
  if len(expected_output_cols_list) == len(output_df_columns_set):
680
- return expected_output_cols_list
718
+ return expected_output_cols_list, output_df_pd
681
719
  # otherwise, use the sklearn estimator's output
682
720
  else:
683
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
721
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
722
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
684
723
 
685
724
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
686
725
  @telemetry.send_api_usage_telemetry(
@@ -726,7 +765,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
726
765
  drop_input_cols=self._drop_input_cols,
727
766
  expected_output_cols_type="float",
728
767
  )
729
- expected_output_cols = self._align_expected_output_names(
768
+ expected_output_cols, _ = self._align_expected_output(
730
769
  inference_method, dataset, expected_output_cols, output_cols_prefix
731
770
  )
732
771
 
@@ -792,7 +831,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
792
831
  drop_input_cols=self._drop_input_cols,
793
832
  expected_output_cols_type="float",
794
833
  )
795
- expected_output_cols = self._align_expected_output_names(
834
+ expected_output_cols, _ = self._align_expected_output(
796
835
  inference_method, dataset, expected_output_cols, output_cols_prefix
797
836
  )
798
837
  elif isinstance(dataset, pd.DataFrame):
@@ -857,7 +896,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
857
896
  drop_input_cols=self._drop_input_cols,
858
897
  expected_output_cols_type="float",
859
898
  )
860
- expected_output_cols = self._align_expected_output_names(
899
+ expected_output_cols, _ = self._align_expected_output(
861
900
  inference_method, dataset, expected_output_cols, output_cols_prefix
862
901
  )
863
902
 
@@ -922,7 +961,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
922
961
  drop_input_cols = self._drop_input_cols,
923
962
  expected_output_cols_type="float",
924
963
  )
925
- expected_output_cols = self._align_expected_output_names(
964
+ expected_output_cols, _ = self._align_expected_output(
926
965
  inference_method, dataset, expected_output_cols, output_cols_prefix
927
966
  )
928
967