snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff compares the contents of two publicly available package versions, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (212)
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -503,12 +500,23 @@ class EllipticEnvelope(BaseTransformer):
503
500
  autogenerated=self._autogenerated,
504
501
  subproject=_SUBPROJECT,
505
502
  )
506
- output_result, fitted_estimator = model_trainer.train_fit_predict(
507
- drop_input_cols=self._drop_input_cols,
508
- expected_output_cols_list=(
509
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
510
- ),
503
+ expected_output_cols = (
504
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
511
505
  )
506
+ if isinstance(dataset, DataFrame):
507
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
508
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
509
+ )
510
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
511
+ drop_input_cols=self._drop_input_cols,
512
+ expected_output_cols_list=expected_output_cols,
513
+ example_output_pd_df=example_output_pd_df,
514
+ )
515
+ else:
516
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
517
+ drop_input_cols=self._drop_input_cols,
518
+ expected_output_cols_list=expected_output_cols,
519
+ )
512
520
  self._sklearn_object = fitted_estimator
513
521
  self._is_fitted = True
514
522
  return output_result
@@ -587,12 +595,41 @@ class EllipticEnvelope(BaseTransformer):
587
595
 
588
596
  return rv
589
597
 
590
- def _align_expected_output_names(
591
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
592
- ) -> List[str]:
598
+ def _align_expected_output(
599
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
600
+ ) -> Tuple[List[str], pd.DataFrame]:
601
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
602
+ and output dataframe with 1 line.
603
+ If the method is fit_predict, run 2 lines of data.
604
+ """
593
605
  # in case the inferred output column names dimension is different
594
606
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
595
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
607
+
608
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
609
+ # so change the minimum of number of rows to 2
610
+ num_examples = 2
611
+ statement_params = telemetry.get_function_usage_statement_params(
612
+ project=_PROJECT,
613
+ subproject=_SUBPROJECT,
614
+ function_name=telemetry.get_statement_params_full_func_name(
615
+ inspect.currentframe(), EllipticEnvelope.__class__.__name__
616
+ ),
617
+ api_calls=[Session.call],
618
+ custom_tags={"autogen": True} if self._autogenerated else None,
619
+ )
620
+ if output_cols_prefix == "fit_predict_":
621
+ if hasattr(self._sklearn_object, "n_clusters"):
622
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
623
+ num_examples = self._sklearn_object.n_clusters
624
+ elif hasattr(self._sklearn_object, "min_samples"):
625
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
626
+ num_examples = self._sklearn_object.min_samples
627
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
628
+ # LocalOutlierFactor expects n_neighbors <= n_samples
629
+ num_examples = self._sklearn_object.n_neighbors
630
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
631
+ else:
632
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
596
633
 
597
634
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
598
635
  # seen during the fit.
@@ -604,12 +641,14 @@ class EllipticEnvelope(BaseTransformer):
604
641
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
605
642
  if self.sample_weight_col:
606
643
  output_df_columns_set -= set(self.sample_weight_col)
644
+
607
645
  # if the dimension of inferred output column names is correct; use it
608
646
  if len(expected_output_cols_list) == len(output_df_columns_set):
609
- return expected_output_cols_list
647
+ return expected_output_cols_list, output_df_pd
610
648
  # otherwise, use the sklearn estimator's output
611
649
  else:
612
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
650
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
651
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
613
652
 
614
653
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
615
654
  @telemetry.send_api_usage_telemetry(
@@ -655,7 +694,7 @@ class EllipticEnvelope(BaseTransformer):
655
694
  drop_input_cols=self._drop_input_cols,
656
695
  expected_output_cols_type="float",
657
696
  )
658
- expected_output_cols = self._align_expected_output_names(
697
+ expected_output_cols, _ = self._align_expected_output(
659
698
  inference_method, dataset, expected_output_cols, output_cols_prefix
660
699
  )
661
700
 
@@ -721,7 +760,7 @@ class EllipticEnvelope(BaseTransformer):
721
760
  drop_input_cols=self._drop_input_cols,
722
761
  expected_output_cols_type="float",
723
762
  )
724
- expected_output_cols = self._align_expected_output_names(
763
+ expected_output_cols, _ = self._align_expected_output(
725
764
  inference_method, dataset, expected_output_cols, output_cols_prefix
726
765
  )
727
766
  elif isinstance(dataset, pd.DataFrame):
@@ -786,7 +825,7 @@ class EllipticEnvelope(BaseTransformer):
786
825
  drop_input_cols=self._drop_input_cols,
787
826
  expected_output_cols_type="float",
788
827
  )
789
- expected_output_cols = self._align_expected_output_names(
828
+ expected_output_cols, _ = self._align_expected_output(
790
829
  inference_method, dataset, expected_output_cols, output_cols_prefix
791
830
  )
792
831
 
@@ -853,7 +892,7 @@ class EllipticEnvelope(BaseTransformer):
853
892
  drop_input_cols = self._drop_input_cols,
854
893
  expected_output_cols_type="float",
855
894
  )
856
- expected_output_cols = self._align_expected_output_names(
895
+ expected_output_cols, _ = self._align_expected_output(
857
896
  inference_method, dataset, expected_output_cols, output_cols_prefix
858
897
  )
859
898
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -475,12 +472,23 @@ class EmpiricalCovariance(BaseTransformer):
475
472
  autogenerated=self._autogenerated,
476
473
  subproject=_SUBPROJECT,
477
474
  )
478
- output_result, fitted_estimator = model_trainer.train_fit_predict(
479
- drop_input_cols=self._drop_input_cols,
480
- expected_output_cols_list=(
481
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
482
- ),
475
+ expected_output_cols = (
476
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
483
477
  )
478
+ if isinstance(dataset, DataFrame):
479
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
480
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
481
+ )
482
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
483
+ drop_input_cols=self._drop_input_cols,
484
+ expected_output_cols_list=expected_output_cols,
485
+ example_output_pd_df=example_output_pd_df,
486
+ )
487
+ else:
488
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
489
+ drop_input_cols=self._drop_input_cols,
490
+ expected_output_cols_list=expected_output_cols,
491
+ )
484
492
  self._sklearn_object = fitted_estimator
485
493
  self._is_fitted = True
486
494
  return output_result
@@ -559,12 +567,41 @@ class EmpiricalCovariance(BaseTransformer):
559
567
 
560
568
  return rv
561
569
 
562
- def _align_expected_output_names(
563
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
564
- ) -> List[str]:
570
+ def _align_expected_output(
571
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
572
+ ) -> Tuple[List[str], pd.DataFrame]:
573
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
574
+ and output dataframe with 1 line.
575
+ If the method is fit_predict, run 2 lines of data.
576
+ """
565
577
  # in case the inferred output column names dimension is different
566
578
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
567
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
579
+
580
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
581
+ # so change the minimum of number of rows to 2
582
+ num_examples = 2
583
+ statement_params = telemetry.get_function_usage_statement_params(
584
+ project=_PROJECT,
585
+ subproject=_SUBPROJECT,
586
+ function_name=telemetry.get_statement_params_full_func_name(
587
+ inspect.currentframe(), EmpiricalCovariance.__class__.__name__
588
+ ),
589
+ api_calls=[Session.call],
590
+ custom_tags={"autogen": True} if self._autogenerated else None,
591
+ )
592
+ if output_cols_prefix == "fit_predict_":
593
+ if hasattr(self._sklearn_object, "n_clusters"):
594
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
595
+ num_examples = self._sklearn_object.n_clusters
596
+ elif hasattr(self._sklearn_object, "min_samples"):
597
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
598
+ num_examples = self._sklearn_object.min_samples
599
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
600
+ # LocalOutlierFactor expects n_neighbors <= n_samples
601
+ num_examples = self._sklearn_object.n_neighbors
602
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
603
+ else:
604
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
568
605
 
569
606
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
570
607
  # seen during the fit.
@@ -576,12 +613,14 @@ class EmpiricalCovariance(BaseTransformer):
576
613
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
577
614
  if self.sample_weight_col:
578
615
  output_df_columns_set -= set(self.sample_weight_col)
616
+
579
617
  # if the dimension of inferred output column names is correct; use it
580
618
  if len(expected_output_cols_list) == len(output_df_columns_set):
581
- return expected_output_cols_list
619
+ return expected_output_cols_list, output_df_pd
582
620
  # otherwise, use the sklearn estimator's output
583
621
  else:
584
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
622
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
623
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
585
624
 
586
625
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
587
626
  @telemetry.send_api_usage_telemetry(
@@ -627,7 +666,7 @@ class EmpiricalCovariance(BaseTransformer):
627
666
  drop_input_cols=self._drop_input_cols,
628
667
  expected_output_cols_type="float",
629
668
  )
630
- expected_output_cols = self._align_expected_output_names(
669
+ expected_output_cols, _ = self._align_expected_output(
631
670
  inference_method, dataset, expected_output_cols, output_cols_prefix
632
671
  )
633
672
 
@@ -693,7 +732,7 @@ class EmpiricalCovariance(BaseTransformer):
693
732
  drop_input_cols=self._drop_input_cols,
694
733
  expected_output_cols_type="float",
695
734
  )
696
- expected_output_cols = self._align_expected_output_names(
735
+ expected_output_cols, _ = self._align_expected_output(
697
736
  inference_method, dataset, expected_output_cols, output_cols_prefix
698
737
  )
699
738
  elif isinstance(dataset, pd.DataFrame):
@@ -756,7 +795,7 @@ class EmpiricalCovariance(BaseTransformer):
756
795
  drop_input_cols=self._drop_input_cols,
757
796
  expected_output_cols_type="float",
758
797
  )
759
- expected_output_cols = self._align_expected_output_names(
798
+ expected_output_cols, _ = self._align_expected_output(
760
799
  inference_method, dataset, expected_output_cols, output_cols_prefix
761
800
  )
762
801
 
@@ -821,7 +860,7 @@ class EmpiricalCovariance(BaseTransformer):
821
860
  drop_input_cols = self._drop_input_cols,
822
861
  expected_output_cols_type="float",
823
862
  )
824
- expected_output_cols = self._align_expected_output_names(
863
+ expected_output_cols, _ = self._align_expected_output(
825
864
  inference_method, dataset, expected_output_cols, output_cols_prefix
826
865
  )
827
866
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -523,12 +520,23 @@ class GraphicalLasso(BaseTransformer):
523
520
  autogenerated=self._autogenerated,
524
521
  subproject=_SUBPROJECT,
525
522
  )
526
- output_result, fitted_estimator = model_trainer.train_fit_predict(
527
- drop_input_cols=self._drop_input_cols,
528
- expected_output_cols_list=(
529
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
530
- ),
523
+ expected_output_cols = (
524
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
531
525
  )
526
+ if isinstance(dataset, DataFrame):
527
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
528
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
529
+ )
530
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
531
+ drop_input_cols=self._drop_input_cols,
532
+ expected_output_cols_list=expected_output_cols,
533
+ example_output_pd_df=example_output_pd_df,
534
+ )
535
+ else:
536
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
537
+ drop_input_cols=self._drop_input_cols,
538
+ expected_output_cols_list=expected_output_cols,
539
+ )
532
540
  self._sklearn_object = fitted_estimator
533
541
  self._is_fitted = True
534
542
  return output_result
@@ -607,12 +615,41 @@ class GraphicalLasso(BaseTransformer):
607
615
 
608
616
  return rv
609
617
 
610
- def _align_expected_output_names(
611
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
612
- ) -> List[str]:
618
+ def _align_expected_output(
619
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
620
+ ) -> Tuple[List[str], pd.DataFrame]:
621
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
622
+ and output dataframe with 1 line.
623
+ If the method is fit_predict, run 2 lines of data.
624
+ """
613
625
  # in case the inferred output column names dimension is different
614
626
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
615
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
627
+
628
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
629
+ # so change the minimum of number of rows to 2
630
+ num_examples = 2
631
+ statement_params = telemetry.get_function_usage_statement_params(
632
+ project=_PROJECT,
633
+ subproject=_SUBPROJECT,
634
+ function_name=telemetry.get_statement_params_full_func_name(
635
+ inspect.currentframe(), GraphicalLasso.__class__.__name__
636
+ ),
637
+ api_calls=[Session.call],
638
+ custom_tags={"autogen": True} if self._autogenerated else None,
639
+ )
640
+ if output_cols_prefix == "fit_predict_":
641
+ if hasattr(self._sklearn_object, "n_clusters"):
642
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
643
+ num_examples = self._sklearn_object.n_clusters
644
+ elif hasattr(self._sklearn_object, "min_samples"):
645
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
646
+ num_examples = self._sklearn_object.min_samples
647
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
648
+ # LocalOutlierFactor expects n_neighbors <= n_samples
649
+ num_examples = self._sklearn_object.n_neighbors
650
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
651
+ else:
652
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
616
653
 
617
654
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
618
655
  # seen during the fit.
@@ -624,12 +661,14 @@ class GraphicalLasso(BaseTransformer):
624
661
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
625
662
  if self.sample_weight_col:
626
663
  output_df_columns_set -= set(self.sample_weight_col)
664
+
627
665
  # if the dimension of inferred output column names is correct; use it
628
666
  if len(expected_output_cols_list) == len(output_df_columns_set):
629
- return expected_output_cols_list
667
+ return expected_output_cols_list, output_df_pd
630
668
  # otherwise, use the sklearn estimator's output
631
669
  else:
632
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
670
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
671
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
633
672
 
634
673
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
635
674
  @telemetry.send_api_usage_telemetry(
@@ -675,7 +714,7 @@ class GraphicalLasso(BaseTransformer):
675
714
  drop_input_cols=self._drop_input_cols,
676
715
  expected_output_cols_type="float",
677
716
  )
678
- expected_output_cols = self._align_expected_output_names(
717
+ expected_output_cols, _ = self._align_expected_output(
679
718
  inference_method, dataset, expected_output_cols, output_cols_prefix
680
719
  )
681
720
 
@@ -741,7 +780,7 @@ class GraphicalLasso(BaseTransformer):
741
780
  drop_input_cols=self._drop_input_cols,
742
781
  expected_output_cols_type="float",
743
782
  )
744
- expected_output_cols = self._align_expected_output_names(
783
+ expected_output_cols, _ = self._align_expected_output(
745
784
  inference_method, dataset, expected_output_cols, output_cols_prefix
746
785
  )
747
786
  elif isinstance(dataset, pd.DataFrame):
@@ -804,7 +843,7 @@ class GraphicalLasso(BaseTransformer):
804
843
  drop_input_cols=self._drop_input_cols,
805
844
  expected_output_cols_type="float",
806
845
  )
807
- expected_output_cols = self._align_expected_output_names(
846
+ expected_output_cols, _ = self._align_expected_output(
808
847
  inference_method, dataset, expected_output_cols, output_cols_prefix
809
848
  )
810
849
 
@@ -869,7 +908,7 @@ class GraphicalLasso(BaseTransformer):
869
908
  drop_input_cols = self._drop_input_cols,
870
909
  expected_output_cols_type="float",
871
910
  )
872
- expected_output_cols = self._align_expected_output_names(
911
+ expected_output_cols, _ = self._align_expected_output(
873
912
  inference_method, dataset, expected_output_cols, output_cols_prefix
874
913
  )
875
914
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -549,12 +546,23 @@ class GraphicalLassoCV(BaseTransformer):
549
546
  autogenerated=self._autogenerated,
550
547
  subproject=_SUBPROJECT,
551
548
  )
552
- output_result, fitted_estimator = model_trainer.train_fit_predict(
553
- drop_input_cols=self._drop_input_cols,
554
- expected_output_cols_list=(
555
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
556
- ),
549
+ expected_output_cols = (
550
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
557
551
  )
552
+ if isinstance(dataset, DataFrame):
553
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
554
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
555
+ )
556
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
557
+ drop_input_cols=self._drop_input_cols,
558
+ expected_output_cols_list=expected_output_cols,
559
+ example_output_pd_df=example_output_pd_df,
560
+ )
561
+ else:
562
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
563
+ drop_input_cols=self._drop_input_cols,
564
+ expected_output_cols_list=expected_output_cols,
565
+ )
558
566
  self._sklearn_object = fitted_estimator
559
567
  self._is_fitted = True
560
568
  return output_result
@@ -633,12 +641,41 @@ class GraphicalLassoCV(BaseTransformer):
633
641
 
634
642
  return rv
635
643
 
636
- def _align_expected_output_names(
637
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
638
- ) -> List[str]:
644
+ def _align_expected_output(
645
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
646
+ ) -> Tuple[List[str], pd.DataFrame]:
647
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
648
+ and output dataframe with 1 line.
649
+ If the method is fit_predict, run 2 lines of data.
650
+ """
639
651
  # in case the inferred output column names dimension is different
640
652
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
641
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
653
+
654
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
655
+ # so change the minimum of number of rows to 2
656
+ num_examples = 2
657
+ statement_params = telemetry.get_function_usage_statement_params(
658
+ project=_PROJECT,
659
+ subproject=_SUBPROJECT,
660
+ function_name=telemetry.get_statement_params_full_func_name(
661
+ inspect.currentframe(), GraphicalLassoCV.__class__.__name__
662
+ ),
663
+ api_calls=[Session.call],
664
+ custom_tags={"autogen": True} if self._autogenerated else None,
665
+ )
666
+ if output_cols_prefix == "fit_predict_":
667
+ if hasattr(self._sklearn_object, "n_clusters"):
668
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
669
+ num_examples = self._sklearn_object.n_clusters
670
+ elif hasattr(self._sklearn_object, "min_samples"):
671
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
672
+ num_examples = self._sklearn_object.min_samples
673
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
674
+ # LocalOutlierFactor expects n_neighbors <= n_samples
675
+ num_examples = self._sklearn_object.n_neighbors
676
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
677
+ else:
678
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
642
679
 
643
680
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
644
681
  # seen during the fit.
@@ -650,12 +687,14 @@ class GraphicalLassoCV(BaseTransformer):
650
687
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
651
688
  if self.sample_weight_col:
652
689
  output_df_columns_set -= set(self.sample_weight_col)
690
+
653
691
  # if the dimension of inferred output column names is correct; use it
654
692
  if len(expected_output_cols_list) == len(output_df_columns_set):
655
- return expected_output_cols_list
693
+ return expected_output_cols_list, output_df_pd
656
694
  # otherwise, use the sklearn estimator's output
657
695
  else:
658
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
696
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
697
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
659
698
 
660
699
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
661
700
  @telemetry.send_api_usage_telemetry(
@@ -701,7 +740,7 @@ class GraphicalLassoCV(BaseTransformer):
701
740
  drop_input_cols=self._drop_input_cols,
702
741
  expected_output_cols_type="float",
703
742
  )
704
- expected_output_cols = self._align_expected_output_names(
743
+ expected_output_cols, _ = self._align_expected_output(
705
744
  inference_method, dataset, expected_output_cols, output_cols_prefix
706
745
  )
707
746
 
@@ -767,7 +806,7 @@ class GraphicalLassoCV(BaseTransformer):
767
806
  drop_input_cols=self._drop_input_cols,
768
807
  expected_output_cols_type="float",
769
808
  )
770
- expected_output_cols = self._align_expected_output_names(
809
+ expected_output_cols, _ = self._align_expected_output(
771
810
  inference_method, dataset, expected_output_cols, output_cols_prefix
772
811
  )
773
812
  elif isinstance(dataset, pd.DataFrame):
@@ -830,7 +869,7 @@ class GraphicalLassoCV(BaseTransformer):
830
869
  drop_input_cols=self._drop_input_cols,
831
870
  expected_output_cols_type="float",
832
871
  )
833
- expected_output_cols = self._align_expected_output_names(
872
+ expected_output_cols, _ = self._align_expected_output(
834
873
  inference_method, dataset, expected_output_cols, output_cols_prefix
835
874
  )
836
875
 
@@ -895,7 +934,7 @@ class GraphicalLassoCV(BaseTransformer):
895
934
  drop_input_cols = self._drop_input_cols,
896
935
  expected_output_cols_type="float",
897
936
  )
898
- expected_output_cols = self._align_expected_output_names(
937
+ expected_output_cols, _ = self._align_expected_output(
899
938
  inference_method, dataset, expected_output_cols, output_cols_prefix
900
939
  )
901
940