snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py
@@ -4,14 +4,12 @@
  #
  import inspect
  import os
- import posixpath
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
- from typing_extensions import TypeGuard
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
  from uuid import uuid4

  import cloudpickle as cp
- import pandas as pd
  import numpy as np
+ import pandas as pd
  from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
+ from snowflake.ml._internal.utils import identifier
  from snowflake.snowpark import DataFrame, Session
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
  from snowflake.ml.modeling._internal.transformer_protocols import (
- ModelTransformHandlers,
  BatchInferenceKwargsTypedDict,
  ScoreKwargsTypedDict
  )
@@ -557,12 +554,23 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
  autogenerated=self._autogenerated,
  subproject=_SUBPROJECT,
  )
- output_result, fitted_estimator = model_trainer.train_fit_predict(
- drop_input_cols=self._drop_input_cols,
- expected_output_cols_list=(
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
- ),
+ expected_output_cols = (
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
  )
+ if isinstance(dataset, DataFrame):
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
+ )
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ example_output_pd_df=example_output_pd_df,
+ )
+ else:
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ )
  self._sklearn_object = fitted_estimator
  self._is_fitted = True
  return output_result
@@ -643,12 +651,41 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):

  return rv

- def _align_expected_output_names(
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
- ) -> List[str]:
+ def _align_expected_output(
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+ ) -> Tuple[List[str], pd.DataFrame]:
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+ and output dataframe with 1 line.
+ If the method is fit_predict, run 2 lines of data.
+ """
  # in case the inferred output column names dimension is different
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+ # so change the minimum of number of rows to 2
+ num_examples = 2
+ statement_params = telemetry.get_function_usage_statement_params(
+ project=_PROJECT,
+ subproject=_SUBPROJECT,
+ function_name=telemetry.get_statement_params_full_func_name(
+ inspect.currentframe(), NeighborhoodComponentsAnalysis.__class__.__name__
+ ),
+ api_calls=[Session.call],
+ custom_tags={"autogen": True} if self._autogenerated else None,
+ )
+ if output_cols_prefix == "fit_predict_":
+ if hasattr(self._sklearn_object, "n_clusters"):
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+ num_examples = self._sklearn_object.n_clusters
+ elif hasattr(self._sklearn_object, "min_samples"):
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
+ num_examples = self._sklearn_object.min_samples
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+ # LocalOutlierFactor expects n_neighbors <= n_samples
+ num_examples = self._sklearn_object.n_neighbors
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+ else:
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
  # seen during the fit.
@@ -660,12 +697,14 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
  if self.sample_weight_col:
  output_df_columns_set -= set(self.sample_weight_col)
+
  # if the dimension of inferred output column names is correct; use it
  if len(expected_output_cols_list) == len(output_df_columns_set):
- return expected_output_cols_list
+ return expected_output_cols_list, output_df_pd
  # otherwise, use the sklearn estimator's output
  else:
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]

  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
  @telemetry.send_api_usage_telemetry(
@@ -711,7 +750,7 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -777,7 +816,7 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )
  elif isinstance(dataset, pd.DataFrame):
@@ -840,7 +879,7 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -905,7 +944,7 @@ class NeighborhoodComponentsAnalysis(BaseTransformer):
  drop_input_cols = self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

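The same change is repeated, with only the class name swapped, in the remaining autogenerated estimator diffs below (RadiusNeighborsClassifier, RadiusNeighborsRegressor, BernoulliRBM): _align_expected_output_names becomes _align_expected_output, which now also returns the probe output DataFrame, and the fit_predict path forwards that example to train_fit_predict when the input is a Snowpark DataFrame. A condensed sketch of the new calling convention, assembled from the hunks above rather than copied verbatim:

    # 1.6.1: only the aligned output column names were returned
    expected_output_cols = self._align_expected_output_names(
        inference_method, dataset, expected_output_cols, output_cols_prefix
    )

    # 1.6.2: a (column_names, example_output_df) tuple is returned; the inference
    # methods discard the example, while fit_predict passes it to the trainer
    expected_output_cols, example_output_pd_df = self._align_expected_output(
        "fit_predict", dataset, expected_output_cols, output_cols_prefix
    )
    output_result, fitted_estimator = model_trainer.train_fit_predict(
        drop_input_cols=self._drop_input_cols,
        expected_output_cols_list=expected_output_cols,
        example_output_pd_df=example_output_pd_df,
    )
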
snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py
@@ -4,14 +4,12 @@
  #
  import inspect
  import os
- import posixpath
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
- from typing_extensions import TypeGuard
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
  from uuid import uuid4

  import cloudpickle as cp
- import pandas as pd
  import numpy as np
+ import pandas as pd
  from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
+ from snowflake.ml._internal.utils import identifier
  from snowflake.snowpark import DataFrame, Session
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
  from snowflake.ml.modeling._internal.transformer_protocols import (
- ModelTransformHandlers,
  BatchInferenceKwargsTypedDict,
  ScoreKwargsTypedDict
  )
@@ -558,12 +555,23 @@ class RadiusNeighborsClassifier(BaseTransformer):
  autogenerated=self._autogenerated,
  subproject=_SUBPROJECT,
  )
- output_result, fitted_estimator = model_trainer.train_fit_predict(
- drop_input_cols=self._drop_input_cols,
- expected_output_cols_list=(
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
- ),
+ expected_output_cols = (
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
  )
+ if isinstance(dataset, DataFrame):
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
+ )
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ example_output_pd_df=example_output_pd_df,
+ )
+ else:
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ )
  self._sklearn_object = fitted_estimator
  self._is_fitted = True
  return output_result
@@ -642,12 +650,41 @@ class RadiusNeighborsClassifier(BaseTransformer):

  return rv

- def _align_expected_output_names(
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
- ) -> List[str]:
+ def _align_expected_output(
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+ ) -> Tuple[List[str], pd.DataFrame]:
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+ and output dataframe with 1 line.
+ If the method is fit_predict, run 2 lines of data.
+ """
  # in case the inferred output column names dimension is different
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+ # so change the minimum of number of rows to 2
+ num_examples = 2
+ statement_params = telemetry.get_function_usage_statement_params(
+ project=_PROJECT,
+ subproject=_SUBPROJECT,
+ function_name=telemetry.get_statement_params_full_func_name(
+ inspect.currentframe(), RadiusNeighborsClassifier.__class__.__name__
+ ),
+ api_calls=[Session.call],
+ custom_tags={"autogen": True} if self._autogenerated else None,
+ )
+ if output_cols_prefix == "fit_predict_":
+ if hasattr(self._sklearn_object, "n_clusters"):
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+ num_examples = self._sklearn_object.n_clusters
+ elif hasattr(self._sklearn_object, "min_samples"):
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
+ num_examples = self._sklearn_object.min_samples
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+ # LocalOutlierFactor expects n_neighbors <= n_samples
+ num_examples = self._sklearn_object.n_neighbors
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+ else:
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
  # seen during the fit.
@@ -659,12 +696,14 @@ class RadiusNeighborsClassifier(BaseTransformer):
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
  if self.sample_weight_col:
  output_df_columns_set -= set(self.sample_weight_col)
+
  # if the dimension of inferred output column names is correct; use it
  if len(expected_output_cols_list) == len(output_df_columns_set):
- return expected_output_cols_list
+ return expected_output_cols_list, output_df_pd
  # otherwise, use the sklearn estimator's output
  else:
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]

  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
  @telemetry.send_api_usage_telemetry(
@@ -712,7 +751,7 @@ class RadiusNeighborsClassifier(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -780,7 +819,7 @@ class RadiusNeighborsClassifier(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )
  elif isinstance(dataset, pd.DataFrame):
@@ -843,7 +882,7 @@ class RadiusNeighborsClassifier(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -908,7 +947,7 @@ class RadiusNeighborsClassifier(BaseTransformer):
  drop_input_cols = self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py
@@ -4,14 +4,12 @@
  #
  import inspect
  import os
- import posixpath
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
- from typing_extensions import TypeGuard
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
  from uuid import uuid4

  import cloudpickle as cp
- import pandas as pd
  import numpy as np
+ import pandas as pd
  from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
+ from snowflake.ml._internal.utils import identifier
  from snowflake.snowpark import DataFrame, Session
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
  from snowflake.ml.modeling._internal.transformer_protocols import (
- ModelTransformHandlers,
  BatchInferenceKwargsTypedDict,
  ScoreKwargsTypedDict
  )
@@ -548,12 +545,23 @@ class RadiusNeighborsRegressor(BaseTransformer):
  autogenerated=self._autogenerated,
  subproject=_SUBPROJECT,
  )
- output_result, fitted_estimator = model_trainer.train_fit_predict(
- drop_input_cols=self._drop_input_cols,
- expected_output_cols_list=(
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
- ),
+ expected_output_cols = (
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
  )
+ if isinstance(dataset, DataFrame):
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
+ )
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ example_output_pd_df=example_output_pd_df,
+ )
+ else:
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ )
  self._sklearn_object = fitted_estimator
  self._is_fitted = True
  return output_result
@@ -632,12 +640,41 @@ class RadiusNeighborsRegressor(BaseTransformer):

  return rv

- def _align_expected_output_names(
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
- ) -> List[str]:
+ def _align_expected_output(
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+ ) -> Tuple[List[str], pd.DataFrame]:
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+ and output dataframe with 1 line.
+ If the method is fit_predict, run 2 lines of data.
+ """
  # in case the inferred output column names dimension is different
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+ # so change the minimum of number of rows to 2
+ num_examples = 2
+ statement_params = telemetry.get_function_usage_statement_params(
+ project=_PROJECT,
+ subproject=_SUBPROJECT,
+ function_name=telemetry.get_statement_params_full_func_name(
+ inspect.currentframe(), RadiusNeighborsRegressor.__class__.__name__
+ ),
+ api_calls=[Session.call],
+ custom_tags={"autogen": True} if self._autogenerated else None,
+ )
+ if output_cols_prefix == "fit_predict_":
+ if hasattr(self._sklearn_object, "n_clusters"):
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+ num_examples = self._sklearn_object.n_clusters
+ elif hasattr(self._sklearn_object, "min_samples"):
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
+ num_examples = self._sklearn_object.min_samples
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+ # LocalOutlierFactor expects n_neighbors <= n_samples
+ num_examples = self._sklearn_object.n_neighbors
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+ else:
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
  # seen during the fit.
@@ -649,12 +686,14 @@ class RadiusNeighborsRegressor(BaseTransformer):
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
  if self.sample_weight_col:
  output_df_columns_set -= set(self.sample_weight_col)
+
  # if the dimension of inferred output column names is correct; use it
  if len(expected_output_cols_list) == len(output_df_columns_set):
- return expected_output_cols_list
+ return expected_output_cols_list, output_df_pd
  # otherwise, use the sklearn estimator's output
  else:
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]

  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
  @telemetry.send_api_usage_telemetry(
@@ -700,7 +739,7 @@ class RadiusNeighborsRegressor(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -766,7 +805,7 @@ class RadiusNeighborsRegressor(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )
  elif isinstance(dataset, pd.DataFrame):
@@ -829,7 +868,7 @@ class RadiusNeighborsRegressor(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -894,7 +933,7 @@ class RadiusNeighborsRegressor(BaseTransformer):
  drop_input_cols = self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

snowflake/ml/modeling/neural_network/bernoulli_rbm.py
@@ -4,14 +4,12 @@
  #
  import inspect
  import os
- import posixpath
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
- from typing_extensions import TypeGuard
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
  from uuid import uuid4

  import cloudpickle as cp
- import pandas as pd
  import numpy as np
+ import pandas as pd
  from numpy import typing as npt


@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
  from snowflake.ml._internal import telemetry
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
+ from snowflake.ml._internal.utils import identifier
  from snowflake.snowpark import DataFrame, Session
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
  from snowflake.ml.modeling._internal.transformer_protocols import (
- ModelTransformHandlers,
  BatchInferenceKwargsTypedDict,
  ScoreKwargsTypedDict
  )
@@ -507,12 +504,23 @@ class BernoulliRBM(BaseTransformer):
  autogenerated=self._autogenerated,
  subproject=_SUBPROJECT,
  )
- output_result, fitted_estimator = model_trainer.train_fit_predict(
- drop_input_cols=self._drop_input_cols,
- expected_output_cols_list=(
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
- ),
+ expected_output_cols = (
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
  )
+ if isinstance(dataset, DataFrame):
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
+ )
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ example_output_pd_df=example_output_pd_df,
+ )
+ else:
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
+ drop_input_cols=self._drop_input_cols,
+ expected_output_cols_list=expected_output_cols,
+ )
  self._sklearn_object = fitted_estimator
  self._is_fitted = True
  return output_result
@@ -593,12 +601,41 @@ class BernoulliRBM(BaseTransformer):

  return rv

- def _align_expected_output_names(
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
- ) -> List[str]:
+ def _align_expected_output(
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
+ ) -> Tuple[List[str], pd.DataFrame]:
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
+ and output dataframe with 1 line.
+ If the method is fit_predict, run 2 lines of data.
+ """
  # in case the inferred output column names dimension is different
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
+
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
+ # so change the minimum of number of rows to 2
+ num_examples = 2
+ statement_params = telemetry.get_function_usage_statement_params(
+ project=_PROJECT,
+ subproject=_SUBPROJECT,
+ function_name=telemetry.get_statement_params_full_func_name(
+ inspect.currentframe(), BernoulliRBM.__class__.__name__
+ ),
+ api_calls=[Session.call],
+ custom_tags={"autogen": True} if self._autogenerated else None,
+ )
+ if output_cols_prefix == "fit_predict_":
+ if hasattr(self._sklearn_object, "n_clusters"):
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
+ num_examples = self._sklearn_object.n_clusters
+ elif hasattr(self._sklearn_object, "min_samples"):
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
+ num_examples = self._sklearn_object.min_samples
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
+ # LocalOutlierFactor expects n_neighbors <= n_samples
+ num_examples = self._sklearn_object.n_neighbors
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
+ else:
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)

  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
  # seen during the fit.
@@ -610,12 +647,14 @@ class BernoulliRBM(BaseTransformer):
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
  if self.sample_weight_col:
  output_df_columns_set -= set(self.sample_weight_col)
+
  # if the dimension of inferred output column names is correct; use it
  if len(expected_output_cols_list) == len(output_df_columns_set):
- return expected_output_cols_list
+ return expected_output_cols_list, output_df_pd
  # otherwise, use the sklearn estimator's output
  else:
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]

  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
  @telemetry.send_api_usage_telemetry(
@@ -661,7 +700,7 @@ class BernoulliRBM(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -727,7 +766,7 @@ class BernoulliRBM(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )
  elif isinstance(dataset, pd.DataFrame):
@@ -790,7 +829,7 @@ class BernoulliRBM(BaseTransformer):
  drop_input_cols=self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )

@@ -857,7 +896,7 @@ class BernoulliRBM(BaseTransformer):
  drop_input_cols = self._drop_input_cols,
  expected_output_cols_type="float",
  )
- expected_output_cols = self._align_expected_output_names(
+ expected_output_cols, _ = self._align_expected_output(
  inference_method, dataset, expected_output_cols, output_cols_prefix
  )
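None of these hunks change the public estimator API; only the internal schema probe that runs before training or inference on Snowpark data differs. Below is a minimal usage sketch of the path that exercises the new probe; the session configuration, table name, and column names are placeholders, not taken from the package:

    from snowflake.snowpark import Session
    from snowflake.ml.modeling.cluster import KMeans

    # connection_parameters is a placeholder dict of account/user/etc. settings
    session = Session.builder.configs(connection_parameters).create()
    df = session.table("MY_FEATURES")  # hypothetical table with columns F1..F4

    kmeans = KMeans(
        n_clusters=3,
        input_cols=["F1", "F2", "F3", "F4"],
        output_cols=["CLUSTER"],
    )
    # On a Snowpark DataFrame, fit_predict now probes the estimator's output
    # schema via _align_expected_output before handing off to train_fit_predict.
    result_df = kmeans.fit_predict(df)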