snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. snowflake/ml/_internal/telemetry.py +142 -20
  2. snowflake/ml/_internal/utils/identifier.py +48 -11
  3. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  4. snowflake/ml/_internal/utils/sql_identifier.py +1 -1
  5. snowflake/ml/_internal/utils/table_manager.py +19 -1
  6. snowflake/ml/_internal/utils/uri.py +2 -2
  7. snowflake/ml/data/data_connector.py +33 -7
  8. snowflake/ml/data/torch_utils.py +68 -0
  9. snowflake/ml/dataset/dataset.py +1 -3
  10. snowflake/ml/feature_store/feature_store.py +41 -17
  11. snowflake/ml/feature_store/feature_view.py +2 -2
  12. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  13. snowflake/ml/fileset/fileset.py +1 -1
  14. snowflake/ml/fileset/sfcfs.py +9 -3
  15. snowflake/ml/model/_client/model/model_version_impl.py +22 -7
  16. snowflake/ml/model/_client/ops/model_ops.py +39 -3
  17. snowflake/ml/model/_client/ops/service_ops.py +198 -7
  18. snowflake/ml/model/_client/service/model_deployment_spec.py +4 -5
  19. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +1 -2
  20. snowflake/ml/model/_client/sql/service.py +85 -18
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +1 -1
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +3 -3
  23. snowflake/ml/model/_model_composer/model_composer.py +2 -0
  24. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +3 -8
  25. snowflake/ml/model/_packager/model_handlers/_utils.py +46 -14
  26. snowflake/ml/model/_packager/model_handlers/catboost.py +17 -15
  27. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +23 -15
  28. snowflake/ml/model/_packager/model_handlers/lightgbm.py +15 -57
  29. snowflake/ml/model/_packager/model_handlers/llm.py +4 -2
  30. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +116 -0
  31. snowflake/ml/model/_packager/model_handlers/sklearn.py +36 -24
  32. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +119 -6
  33. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  34. snowflake/ml/model/_packager/model_handlers/xgboost.py +48 -48
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +10 -7
  36. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +0 -8
  37. snowflake/ml/model/_packager/model_packager.py +2 -0
  38. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  39. snowflake/ml/model/_signatures/utils.py +9 -0
  40. snowflake/ml/model/models/llm.py +3 -1
  41. snowflake/ml/model/type_hints.py +9 -1
  42. snowflake/ml/modeling/_internal/constants.py +1 -0
  43. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  44. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  45. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  46. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  47. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -2
  48. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  49. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +113 -160
  50. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +60 -21
  51. snowflake/ml/modeling/cluster/affinity_propagation.py +60 -21
  52. snowflake/ml/modeling/cluster/agglomerative_clustering.py +60 -21
  53. snowflake/ml/modeling/cluster/birch.py +60 -21
  54. snowflake/ml/modeling/cluster/bisecting_k_means.py +60 -21
  55. snowflake/ml/modeling/cluster/dbscan.py +60 -21
  56. snowflake/ml/modeling/cluster/feature_agglomeration.py +60 -21
  57. snowflake/ml/modeling/cluster/k_means.py +60 -21
  58. snowflake/ml/modeling/cluster/mean_shift.py +60 -21
  59. snowflake/ml/modeling/cluster/mini_batch_k_means.py +60 -21
  60. snowflake/ml/modeling/cluster/optics.py +60 -21
  61. snowflake/ml/modeling/cluster/spectral_biclustering.py +60 -21
  62. snowflake/ml/modeling/cluster/spectral_clustering.py +60 -21
  63. snowflake/ml/modeling/cluster/spectral_coclustering.py +60 -21
  64. snowflake/ml/modeling/compose/column_transformer.py +60 -21
  65. snowflake/ml/modeling/compose/transformed_target_regressor.py +60 -21
  66. snowflake/ml/modeling/covariance/elliptic_envelope.py +60 -21
  67. snowflake/ml/modeling/covariance/empirical_covariance.py +60 -21
  68. snowflake/ml/modeling/covariance/graphical_lasso.py +60 -21
  69. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +60 -21
  70. snowflake/ml/modeling/covariance/ledoit_wolf.py +60 -21
  71. snowflake/ml/modeling/covariance/min_cov_det.py +60 -21
  72. snowflake/ml/modeling/covariance/oas.py +60 -21
  73. snowflake/ml/modeling/covariance/shrunk_covariance.py +60 -21
  74. snowflake/ml/modeling/decomposition/dictionary_learning.py +60 -21
  75. snowflake/ml/modeling/decomposition/factor_analysis.py +60 -21
  76. snowflake/ml/modeling/decomposition/fast_ica.py +60 -21
  77. snowflake/ml/modeling/decomposition/incremental_pca.py +60 -21
  78. snowflake/ml/modeling/decomposition/kernel_pca.py +60 -21
  79. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +60 -21
  80. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +60 -21
  81. snowflake/ml/modeling/decomposition/pca.py +60 -21
  82. snowflake/ml/modeling/decomposition/sparse_pca.py +60 -21
  83. snowflake/ml/modeling/decomposition/truncated_svd.py +60 -21
  84. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +60 -21
  85. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +60 -21
  86. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +60 -21
  87. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +60 -21
  88. snowflake/ml/modeling/ensemble/bagging_classifier.py +60 -21
  89. snowflake/ml/modeling/ensemble/bagging_regressor.py +60 -21
  90. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +60 -21
  91. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +60 -21
  92. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +60 -21
  93. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +60 -21
  94. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +60 -21
  95. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +60 -21
  96. snowflake/ml/modeling/ensemble/isolation_forest.py +60 -21
  97. snowflake/ml/modeling/ensemble/random_forest_classifier.py +60 -21
  98. snowflake/ml/modeling/ensemble/random_forest_regressor.py +60 -21
  99. snowflake/ml/modeling/ensemble/stacking_regressor.py +60 -21
  100. snowflake/ml/modeling/ensemble/voting_classifier.py +60 -21
  101. snowflake/ml/modeling/ensemble/voting_regressor.py +60 -21
  102. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +60 -21
  103. snowflake/ml/modeling/feature_selection/select_fdr.py +60 -21
  104. snowflake/ml/modeling/feature_selection/select_fpr.py +60 -21
  105. snowflake/ml/modeling/feature_selection/select_fwe.py +60 -21
  106. snowflake/ml/modeling/feature_selection/select_k_best.py +60 -21
  107. snowflake/ml/modeling/feature_selection/select_percentile.py +60 -21
  108. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +60 -21
  109. snowflake/ml/modeling/feature_selection/variance_threshold.py +60 -21
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +60 -21
  111. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +60 -21
  112. snowflake/ml/modeling/impute/iterative_imputer.py +60 -21
  113. snowflake/ml/modeling/impute/knn_imputer.py +60 -21
  114. snowflake/ml/modeling/impute/missing_indicator.py +60 -21
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +60 -21
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +60 -21
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +60 -21
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +60 -21
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +60 -21
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +60 -21
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +60 -21
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +60 -21
  123. snowflake/ml/modeling/linear_model/ard_regression.py +60 -21
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +60 -21
  125. snowflake/ml/modeling/linear_model/elastic_net.py +60 -21
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +60 -21
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +60 -21
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +60 -21
  129. snowflake/ml/modeling/linear_model/lars.py +60 -21
  130. snowflake/ml/modeling/linear_model/lars_cv.py +60 -21
  131. snowflake/ml/modeling/linear_model/lasso.py +60 -21
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +60 -21
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +60 -21
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +60 -21
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +60 -21
  136. snowflake/ml/modeling/linear_model/linear_regression.py +60 -21
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +60 -21
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +60 -21
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +60 -21
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +60 -21
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +60 -21
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +60 -21
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +60 -21
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +60 -21
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +60 -21
  146. snowflake/ml/modeling/linear_model/perceptron.py +60 -21
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +60 -21
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +60 -21
  149. snowflake/ml/modeling/linear_model/ridge.py +60 -21
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +60 -21
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +60 -21
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +60 -21
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +60 -21
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +60 -21
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +60 -21
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +60 -21
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +60 -21
  158. snowflake/ml/modeling/manifold/isomap.py +60 -21
  159. snowflake/ml/modeling/manifold/mds.py +60 -21
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +60 -21
  161. snowflake/ml/modeling/manifold/tsne.py +60 -21
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +60 -21
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +60 -21
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +60 -21
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +60 -21
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +60 -21
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +60 -21
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +60 -21
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +60 -21
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +60 -21
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +60 -21
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +60 -21
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +60 -21
  174. snowflake/ml/modeling/neighbors/kernel_density.py +60 -21
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +60 -21
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +60 -21
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +60 -21
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +60 -21
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +60 -21
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +60 -21
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +60 -21
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +60 -21
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +60 -21
  184. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  185. snowflake/ml/modeling/pipeline/pipeline.py +1 -12
  186. snowflake/ml/modeling/preprocessing/polynomial_features.py +60 -21
  187. snowflake/ml/modeling/semi_supervised/label_propagation.py +60 -21
  188. snowflake/ml/modeling/semi_supervised/label_spreading.py +60 -21
  189. snowflake/ml/modeling/svm/linear_svc.py +60 -21
  190. snowflake/ml/modeling/svm/linear_svr.py +60 -21
  191. snowflake/ml/modeling/svm/nu_svc.py +60 -21
  192. snowflake/ml/modeling/svm/nu_svr.py +60 -21
  193. snowflake/ml/modeling/svm/svc.py +60 -21
  194. snowflake/ml/modeling/svm/svr.py +60 -21
  195. snowflake/ml/modeling/tree/decision_tree_classifier.py +60 -21
  196. snowflake/ml/modeling/tree/decision_tree_regressor.py +60 -21
  197. snowflake/ml/modeling/tree/extra_tree_classifier.py +60 -21
  198. snowflake/ml/modeling/tree/extra_tree_regressor.py +60 -21
  199. snowflake/ml/modeling/xgboost/xgb_classifier.py +63 -23
  200. snowflake/ml/modeling/xgboost/xgb_regressor.py +63 -23
  201. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +63 -23
  202. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +63 -23
  203. snowflake/ml/registry/_manager/model_manager.py +4 -0
  204. snowflake/ml/registry/model_registry.py +1 -1
  205. snowflake/ml/registry/registry.py +1 -2
  206. snowflake/ml/version.py +1 -1
  207. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/METADATA +23 -4
  208. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/RECORD +211 -209
  209. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/WHEEL +1 -1
  210. snowflake/ml/data/torch_dataset.py +0 -33
  211. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/LICENSE.txt +0 -0
  212. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.2.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -585,12 +582,23 @@ class DecisionTreeRegressor(BaseTransformer):
585
582
  autogenerated=self._autogenerated,
586
583
  subproject=_SUBPROJECT,
587
584
  )
588
- output_result, fitted_estimator = model_trainer.train_fit_predict(
589
- drop_input_cols=self._drop_input_cols,
590
- expected_output_cols_list=(
591
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
592
- ),
585
+ expected_output_cols = (
586
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
593
587
  )
588
+ if isinstance(dataset, DataFrame):
589
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
590
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
591
+ )
592
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
593
+ drop_input_cols=self._drop_input_cols,
594
+ expected_output_cols_list=expected_output_cols,
595
+ example_output_pd_df=example_output_pd_df,
596
+ )
597
+ else:
598
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
599
+ drop_input_cols=self._drop_input_cols,
600
+ expected_output_cols_list=expected_output_cols,
601
+ )
594
602
  self._sklearn_object = fitted_estimator
595
603
  self._is_fitted = True
596
604
  return output_result
@@ -669,12 +677,41 @@ class DecisionTreeRegressor(BaseTransformer):
669
677
 
670
678
  return rv
671
679
 
672
- def _align_expected_output_names(
673
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
674
- ) -> List[str]:
680
+ def _align_expected_output(
681
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
682
+ ) -> Tuple[List[str], pd.DataFrame]:
683
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
684
+ and output dataframe with 1 line.
685
+ If the method is fit_predict, run 2 lines of data.
686
+ """
675
687
  # in case the inferred output column names dimension is different
676
688
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
677
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
689
+
690
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
691
+ # so change the minimum of number of rows to 2
692
+ num_examples = 2
693
+ statement_params = telemetry.get_function_usage_statement_params(
694
+ project=_PROJECT,
695
+ subproject=_SUBPROJECT,
696
+ function_name=telemetry.get_statement_params_full_func_name(
697
+ inspect.currentframe(), DecisionTreeRegressor.__class__.__name__
698
+ ),
699
+ api_calls=[Session.call],
700
+ custom_tags={"autogen": True} if self._autogenerated else None,
701
+ )
702
+ if output_cols_prefix == "fit_predict_":
703
+ if hasattr(self._sklearn_object, "n_clusters"):
704
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
705
+ num_examples = self._sklearn_object.n_clusters
706
+ elif hasattr(self._sklearn_object, "min_samples"):
707
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
708
+ num_examples = self._sklearn_object.min_samples
709
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
710
+ # LocalOutlierFactor expects n_neighbors <= n_samples
711
+ num_examples = self._sklearn_object.n_neighbors
712
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
713
+ else:
714
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
678
715
 
679
716
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
680
717
  # seen during the fit.
@@ -686,12 +723,14 @@ class DecisionTreeRegressor(BaseTransformer):
686
723
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
687
724
  if self.sample_weight_col:
688
725
  output_df_columns_set -= set(self.sample_weight_col)
726
+
689
727
  # if the dimension of inferred output column names is correct; use it
690
728
  if len(expected_output_cols_list) == len(output_df_columns_set):
691
- return expected_output_cols_list
729
+ return expected_output_cols_list, output_df_pd
692
730
  # otherwise, use the sklearn estimator's output
693
731
  else:
694
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
732
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
733
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
695
734
 
696
735
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
697
736
  @telemetry.send_api_usage_telemetry(
@@ -737,7 +776,7 @@ class DecisionTreeRegressor(BaseTransformer):
737
776
  drop_input_cols=self._drop_input_cols,
738
777
  expected_output_cols_type="float",
739
778
  )
740
- expected_output_cols = self._align_expected_output_names(
779
+ expected_output_cols, _ = self._align_expected_output(
741
780
  inference_method, dataset, expected_output_cols, output_cols_prefix
742
781
  )
743
782
 
@@ -803,7 +842,7 @@ class DecisionTreeRegressor(BaseTransformer):
803
842
  drop_input_cols=self._drop_input_cols,
804
843
  expected_output_cols_type="float",
805
844
  )
806
- expected_output_cols = self._align_expected_output_names(
845
+ expected_output_cols, _ = self._align_expected_output(
807
846
  inference_method, dataset, expected_output_cols, output_cols_prefix
808
847
  )
809
848
  elif isinstance(dataset, pd.DataFrame):
@@ -866,7 +905,7 @@ class DecisionTreeRegressor(BaseTransformer):
866
905
  drop_input_cols=self._drop_input_cols,
867
906
  expected_output_cols_type="float",
868
907
  )
869
- expected_output_cols = self._align_expected_output_names(
908
+ expected_output_cols, _ = self._align_expected_output(
870
909
  inference_method, dataset, expected_output_cols, output_cols_prefix
871
910
  )
872
911
 
@@ -931,7 +970,7 @@ class DecisionTreeRegressor(BaseTransformer):
931
970
  drop_input_cols = self._drop_input_cols,
932
971
  expected_output_cols_type="float",
933
972
  )
934
- expected_output_cols = self._align_expected_output_names(
973
+ expected_output_cols, _ = self._align_expected_output(
935
974
  inference_method, dataset, expected_output_cols, output_cols_prefix
936
975
  )
937
976
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -595,12 +592,23 @@ class ExtraTreeClassifier(BaseTransformer):
595
592
  autogenerated=self._autogenerated,
596
593
  subproject=_SUBPROJECT,
597
594
  )
598
- output_result, fitted_estimator = model_trainer.train_fit_predict(
599
- drop_input_cols=self._drop_input_cols,
600
- expected_output_cols_list=(
601
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
602
- ),
595
+ expected_output_cols = (
596
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
603
597
  )
598
+ if isinstance(dataset, DataFrame):
599
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
600
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
601
+ )
602
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
603
+ drop_input_cols=self._drop_input_cols,
604
+ expected_output_cols_list=expected_output_cols,
605
+ example_output_pd_df=example_output_pd_df,
606
+ )
607
+ else:
608
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
609
+ drop_input_cols=self._drop_input_cols,
610
+ expected_output_cols_list=expected_output_cols,
611
+ )
604
612
  self._sklearn_object = fitted_estimator
605
613
  self._is_fitted = True
606
614
  return output_result
@@ -679,12 +687,41 @@ class ExtraTreeClassifier(BaseTransformer):
679
687
 
680
688
  return rv
681
689
 
682
- def _align_expected_output_names(
683
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
684
- ) -> List[str]:
690
+ def _align_expected_output(
691
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
692
+ ) -> Tuple[List[str], pd.DataFrame]:
693
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
694
+ and output dataframe with 1 line.
695
+ If the method is fit_predict, run 2 lines of data.
696
+ """
685
697
  # in case the inferred output column names dimension is different
686
698
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
687
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
699
+
700
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
701
+ # so change the minimum of number of rows to 2
702
+ num_examples = 2
703
+ statement_params = telemetry.get_function_usage_statement_params(
704
+ project=_PROJECT,
705
+ subproject=_SUBPROJECT,
706
+ function_name=telemetry.get_statement_params_full_func_name(
707
+ inspect.currentframe(), ExtraTreeClassifier.__class__.__name__
708
+ ),
709
+ api_calls=[Session.call],
710
+ custom_tags={"autogen": True} if self._autogenerated else None,
711
+ )
712
+ if output_cols_prefix == "fit_predict_":
713
+ if hasattr(self._sklearn_object, "n_clusters"):
714
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
715
+ num_examples = self._sklearn_object.n_clusters
716
+ elif hasattr(self._sklearn_object, "min_samples"):
717
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
718
+ num_examples = self._sklearn_object.min_samples
719
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
720
+ # LocalOutlierFactor expects n_neighbors <= n_samples
721
+ num_examples = self._sklearn_object.n_neighbors
722
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
723
+ else:
724
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
688
725
 
689
726
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
690
727
  # seen during the fit.
@@ -696,12 +733,14 @@ class ExtraTreeClassifier(BaseTransformer):
696
733
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
697
734
  if self.sample_weight_col:
698
735
  output_df_columns_set -= set(self.sample_weight_col)
736
+
699
737
  # if the dimension of inferred output column names is correct; use it
700
738
  if len(expected_output_cols_list) == len(output_df_columns_set):
701
- return expected_output_cols_list
739
+ return expected_output_cols_list, output_df_pd
702
740
  # otherwise, use the sklearn estimator's output
703
741
  else:
704
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
742
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
743
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
705
744
 
706
745
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
707
746
  @telemetry.send_api_usage_telemetry(
@@ -749,7 +788,7 @@ class ExtraTreeClassifier(BaseTransformer):
749
788
  drop_input_cols=self._drop_input_cols,
750
789
  expected_output_cols_type="float",
751
790
  )
752
- expected_output_cols = self._align_expected_output_names(
791
+ expected_output_cols, _ = self._align_expected_output(
753
792
  inference_method, dataset, expected_output_cols, output_cols_prefix
754
793
  )
755
794
 
@@ -817,7 +856,7 @@ class ExtraTreeClassifier(BaseTransformer):
817
856
  drop_input_cols=self._drop_input_cols,
818
857
  expected_output_cols_type="float",
819
858
  )
820
- expected_output_cols = self._align_expected_output_names(
859
+ expected_output_cols, _ = self._align_expected_output(
821
860
  inference_method, dataset, expected_output_cols, output_cols_prefix
822
861
  )
823
862
  elif isinstance(dataset, pd.DataFrame):
@@ -880,7 +919,7 @@ class ExtraTreeClassifier(BaseTransformer):
880
919
  drop_input_cols=self._drop_input_cols,
881
920
  expected_output_cols_type="float",
882
921
  )
883
- expected_output_cols = self._align_expected_output_names(
922
+ expected_output_cols, _ = self._align_expected_output(
884
923
  inference_method, dataset, expected_output_cols, output_cols_prefix
885
924
  )
886
925
 
@@ -945,7 +984,7 @@ class ExtraTreeClassifier(BaseTransformer):
945
984
  drop_input_cols = self._drop_input_cols,
946
985
  expected_output_cols_type="float",
947
986
  )
948
- expected_output_cols = self._align_expected_output_names(
987
+ expected_output_cols, _ = self._align_expected_output(
949
988
  inference_method, dataset, expected_output_cols, output_cols_prefix
950
989
  )
951
990
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -577,12 +574,23 @@ class ExtraTreeRegressor(BaseTransformer):
577
574
  autogenerated=self._autogenerated,
578
575
  subproject=_SUBPROJECT,
579
576
  )
580
- output_result, fitted_estimator = model_trainer.train_fit_predict(
581
- drop_input_cols=self._drop_input_cols,
582
- expected_output_cols_list=(
583
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
584
- ),
577
+ expected_output_cols = (
578
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
585
579
  )
580
+ if isinstance(dataset, DataFrame):
581
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
582
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
583
+ )
584
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
585
+ drop_input_cols=self._drop_input_cols,
586
+ expected_output_cols_list=expected_output_cols,
587
+ example_output_pd_df=example_output_pd_df,
588
+ )
589
+ else:
590
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
591
+ drop_input_cols=self._drop_input_cols,
592
+ expected_output_cols_list=expected_output_cols,
593
+ )
586
594
  self._sklearn_object = fitted_estimator
587
595
  self._is_fitted = True
588
596
  return output_result
@@ -661,12 +669,41 @@ class ExtraTreeRegressor(BaseTransformer):
661
669
 
662
670
  return rv
663
671
 
664
- def _align_expected_output_names(
665
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
666
- ) -> List[str]:
672
+ def _align_expected_output(
673
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
674
+ ) -> Tuple[List[str], pd.DataFrame]:
675
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
676
+ and output dataframe with 1 line.
677
+ If the method is fit_predict, run 2 lines of data.
678
+ """
667
679
  # in case the inferred output column names dimension is different
668
680
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
669
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
681
+
682
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
683
+ # so change the minimum of number of rows to 2
684
+ num_examples = 2
685
+ statement_params = telemetry.get_function_usage_statement_params(
686
+ project=_PROJECT,
687
+ subproject=_SUBPROJECT,
688
+ function_name=telemetry.get_statement_params_full_func_name(
689
+ inspect.currentframe(), ExtraTreeRegressor.__class__.__name__
690
+ ),
691
+ api_calls=[Session.call],
692
+ custom_tags={"autogen": True} if self._autogenerated else None,
693
+ )
694
+ if output_cols_prefix == "fit_predict_":
695
+ if hasattr(self._sklearn_object, "n_clusters"):
696
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
697
+ num_examples = self._sklearn_object.n_clusters
698
+ elif hasattr(self._sklearn_object, "min_samples"):
699
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
700
+ num_examples = self._sklearn_object.min_samples
701
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
702
+ # LocalOutlierFactor expects n_neighbors <= n_samples
703
+ num_examples = self._sklearn_object.n_neighbors
704
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
705
+ else:
706
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
670
707
 
671
708
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
672
709
  # seen during the fit.
@@ -678,12 +715,14 @@ class ExtraTreeRegressor(BaseTransformer):
678
715
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
679
716
  if self.sample_weight_col:
680
717
  output_df_columns_set -= set(self.sample_weight_col)
718
+
681
719
  # if the dimension of inferred output column names is correct; use it
682
720
  if len(expected_output_cols_list) == len(output_df_columns_set):
683
- return expected_output_cols_list
721
+ return expected_output_cols_list, output_df_pd
684
722
  # otherwise, use the sklearn estimator's output
685
723
  else:
686
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
724
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
725
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
687
726
 
688
727
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
689
728
  @telemetry.send_api_usage_telemetry(
@@ -729,7 +768,7 @@ class ExtraTreeRegressor(BaseTransformer):
729
768
  drop_input_cols=self._drop_input_cols,
730
769
  expected_output_cols_type="float",
731
770
  )
732
- expected_output_cols = self._align_expected_output_names(
771
+ expected_output_cols, _ = self._align_expected_output(
733
772
  inference_method, dataset, expected_output_cols, output_cols_prefix
734
773
  )
735
774
 
@@ -795,7 +834,7 @@ class ExtraTreeRegressor(BaseTransformer):
795
834
  drop_input_cols=self._drop_input_cols,
796
835
  expected_output_cols_type="float",
797
836
  )
798
- expected_output_cols = self._align_expected_output_names(
837
+ expected_output_cols, _ = self._align_expected_output(
799
838
  inference_method, dataset, expected_output_cols, output_cols_prefix
800
839
  )
801
840
  elif isinstance(dataset, pd.DataFrame):
@@ -858,7 +897,7 @@ class ExtraTreeRegressor(BaseTransformer):
858
897
  drop_input_cols=self._drop_input_cols,
859
898
  expected_output_cols_type="float",
860
899
  )
861
- expected_output_cols = self._align_expected_output_names(
900
+ expected_output_cols, _ = self._align_expected_output(
862
901
  inference_method, dataset, expected_output_cols, output_cols_prefix
863
902
  )
864
903
 
@@ -923,7 +962,7 @@ class ExtraTreeRegressor(BaseTransformer):
923
962
  drop_input_cols = self._drop_input_cols,
924
963
  expected_output_cols_type="float",
925
964
  )
926
- expected_output_cols = self._align_expected_output_names(
965
+ expected_output_cols, _ = self._align_expected_output(
927
966
  inference_method, dataset, expected_output_cols, output_cols_prefix
928
967
  )
929
968
 
@@ -4,18 +4,17 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
18
16
  import numpy
17
+ import sklearn
19
18
  import xgboost
20
19
  from sklearn.utils.metaestimators import available_if
21
20
 
@@ -23,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
23
22
  from snowflake.ml._internal import telemetry
24
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
25
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
26
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
27
26
  from snowflake.snowpark import DataFrame, Session
28
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
29
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
30
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
31
- ModelTransformHandlers,
32
30
  BatchInferenceKwargsTypedDict,
33
31
  ScoreKwargsTypedDict
34
32
  )
@@ -361,7 +359,7 @@ class XGBClassifier(BaseTransformer):
361
359
  self.set_sample_weight_col(sample_weight_col)
362
360
  self._use_external_memory_version = use_external_memory_version
363
361
  self._batch_size = batch_size
364
- deps: Set[str] = set([f'numpy=={np.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'])
362
+ deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'])
365
363
 
366
364
  self._deps = list(deps)
367
365
 
@@ -695,12 +693,23 @@ class XGBClassifier(BaseTransformer):
695
693
  autogenerated=self._autogenerated,
696
694
  subproject=_SUBPROJECT,
697
695
  )
698
- output_result, fitted_estimator = model_trainer.train_fit_predict(
699
- drop_input_cols=self._drop_input_cols,
700
- expected_output_cols_list=(
701
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
702
- ),
696
+ expected_output_cols = (
697
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
703
698
  )
699
+ if isinstance(dataset, DataFrame):
700
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
701
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
702
+ )
703
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
704
+ drop_input_cols=self._drop_input_cols,
705
+ expected_output_cols_list=expected_output_cols,
706
+ example_output_pd_df=example_output_pd_df,
707
+ )
708
+ else:
709
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
710
+ drop_input_cols=self._drop_input_cols,
711
+ expected_output_cols_list=expected_output_cols,
712
+ )
704
713
  self._sklearn_object = fitted_estimator
705
714
  self._is_fitted = True
706
715
  return output_result
@@ -779,12 +788,41 @@ class XGBClassifier(BaseTransformer):
779
788
 
780
789
  return rv
781
790
 
782
- def _align_expected_output_names(
783
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
784
- ) -> List[str]:
791
+ def _align_expected_output(
792
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
793
+ ) -> Tuple[List[str], pd.DataFrame]:
794
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
795
+ and output dataframe with 1 line.
796
+ If the method is fit_predict, run 2 lines of data.
797
+ """
785
798
  # in case the inferred output column names dimension is different
786
799
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
787
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
800
+
801
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
802
+ # so change the minimum of number of rows to 2
803
+ num_examples = 2
804
+ statement_params = telemetry.get_function_usage_statement_params(
805
+ project=_PROJECT,
806
+ subproject=_SUBPROJECT,
807
+ function_name=telemetry.get_statement_params_full_func_name(
808
+ inspect.currentframe(), XGBClassifier.__class__.__name__
809
+ ),
810
+ api_calls=[Session.call],
811
+ custom_tags={"autogen": True} if self._autogenerated else None,
812
+ )
813
+ if output_cols_prefix == "fit_predict_":
814
+ if hasattr(self._sklearn_object, "n_clusters"):
815
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
816
+ num_examples = self._sklearn_object.n_clusters
817
+ elif hasattr(self._sklearn_object, "min_samples"):
818
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
819
+ num_examples = self._sklearn_object.min_samples
820
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
821
+ # LocalOutlierFactor expects n_neighbors <= n_samples
822
+ num_examples = self._sklearn_object.n_neighbors
823
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
824
+ else:
825
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
788
826
 
789
827
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
790
828
  # seen during the fit.
@@ -796,12 +834,14 @@ class XGBClassifier(BaseTransformer):
796
834
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
797
835
  if self.sample_weight_col:
798
836
  output_df_columns_set -= set(self.sample_weight_col)
837
+
799
838
  # if the dimension of inferred output column names is correct; use it
800
839
  if len(expected_output_cols_list) == len(output_df_columns_set):
801
- return expected_output_cols_list
840
+ return expected_output_cols_list, output_df_pd
802
841
  # otherwise, use the sklearn estimator's output
803
842
  else:
804
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
843
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
844
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
805
845
 
806
846
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
807
847
  @telemetry.send_api_usage_telemetry(
@@ -849,7 +889,7 @@ class XGBClassifier(BaseTransformer):
849
889
  drop_input_cols=self._drop_input_cols,
850
890
  expected_output_cols_type="float",
851
891
  )
852
- expected_output_cols = self._align_expected_output_names(
892
+ expected_output_cols, _ = self._align_expected_output(
853
893
  inference_method, dataset, expected_output_cols, output_cols_prefix
854
894
  )
855
895
 
@@ -917,7 +957,7 @@ class XGBClassifier(BaseTransformer):
917
957
  drop_input_cols=self._drop_input_cols,
918
958
  expected_output_cols_type="float",
919
959
  )
920
- expected_output_cols = self._align_expected_output_names(
960
+ expected_output_cols, _ = self._align_expected_output(
921
961
  inference_method, dataset, expected_output_cols, output_cols_prefix
922
962
  )
923
963
  elif isinstance(dataset, pd.DataFrame):
@@ -980,7 +1020,7 @@ class XGBClassifier(BaseTransformer):
980
1020
  drop_input_cols=self._drop_input_cols,
981
1021
  expected_output_cols_type="float",
982
1022
  )
983
- expected_output_cols = self._align_expected_output_names(
1023
+ expected_output_cols, _ = self._align_expected_output(
984
1024
  inference_method, dataset, expected_output_cols, output_cols_prefix
985
1025
  )
986
1026
 
@@ -1045,7 +1085,7 @@ class XGBClassifier(BaseTransformer):
1045
1085
  drop_input_cols = self._drop_input_cols,
1046
1086
  expected_output_cols_type="float",
1047
1087
  )
1048
- expected_output_cols = self._align_expected_output_names(
1088
+ expected_output_cols, _ = self._align_expected_output(
1049
1089
  inference_method, dataset, expected_output_cols, output_cols_prefix
1050
1090
  )
1051
1091
 
@@ -1110,7 +1150,7 @@ class XGBClassifier(BaseTransformer):
1110
1150
  transform_kwargs = dict(
1111
1151
  session=dataset._session,
1112
1152
  dependencies=self._deps,
1113
- score_sproc_imports=['xgboost'],
1153
+ score_sproc_imports=['xgboost', 'sklearn'],
1114
1154
  )
1115
1155
  elif isinstance(dataset, pd.DataFrame):
1116
1156
  # pandas_handler.score() does not require any extra kwargs.