snowflake-ml-python 1.1.2__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (215)
  1. snowflake/ml/{model/_deploy_client/utils → _internal/container_services/image_registry}/imagelib.py +3 -1
  2. snowflake/ml/{model/_deploy_client/utils/image_registry_client.py → _internal/container_services/image_registry/registry_client.py} +4 -2
  3. snowflake/ml/_internal/env_utils.py +31 -52
  4. snowflake/ml/_internal/file_utils.py +17 -0
  5. snowflake/ml/_internal/telemetry.py +19 -0
  6. snowflake/ml/_internal/utils/query_result_checker.py +8 -5
  7. snowflake/ml/_internal/utils/snowflake_env.py +95 -0
  8. snowflake/ml/fileset/parquet_parser.py +31 -1
  9. snowflake/ml/model/__init__.py +6 -0
  10. snowflake/ml/model/_client/model/model_impl.py +172 -13
  11. snowflake/ml/model/_client/model/model_version_impl.py +96 -52
  12. snowflake/ml/model/_client/ops/metadata_ops.py +1 -3
  13. snowflake/ml/model/_client/ops/model_ops.py +155 -9
  14. snowflake/ml/model/_client/sql/model.py +55 -10
  15. snowflake/ml/model/_client/sql/model_version.py +72 -61
  16. snowflake/ml/model/_client/sql/stage.py +10 -4
  17. snowflake/ml/model/_client/sql/tag.py +118 -0
  18. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +2 -2
  19. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +8 -8
  20. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +4 -6
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +6 -7
  22. snowflake/ml/model/_deploy_client/snowservice/deploy.py +4 -5
  23. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +9 -1
  24. snowflake/ml/model/_deploy_client/warehouse/deploy.py +20 -11
  25. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +45 -1
  26. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +30 -0
  27. snowflake/ml/model/_model_composer/model_method/function_generator.py +2 -1
  28. snowflake/ml/model/_model_composer/model_runtime/_runtime_requirements.py +10 -1
  29. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +10 -7
  30. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +1 -1
  31. snowflake/ml/model/_packager/model_handlers/xgboost.py +13 -2
  32. snowflake/ml/model/_packager/model_meta/_core_requirements.py +11 -1
  33. snowflake/ml/model/_packager/model_meta/_packaging_requirements.py +3 -0
  34. snowflake/ml/model/_packager/model_meta/model_meta.py +17 -3
  35. snowflake/ml/model/_signatures/core.py +20 -17
  36. snowflake/ml/model/custom_model.py +30 -27
  37. snowflake/ml/model/model_signature.py +16 -17
  38. snowflake/ml/model/type_hints.py +3 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +185 -98
  40. snowflake/ml/modeling/_internal/estimator_utils.py +21 -0
  41. snowflake/ml/modeling/_internal/model_specifications.py +3 -10
  42. snowflake/ml/modeling/_internal/model_trainer_builder.py +55 -11
  43. snowflake/ml/modeling/_internal/snowpark_handlers.py +9 -6
  44. snowflake/ml/modeling/_internal/snowpark_trainer.py +10 -2
  45. snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py +444 -0
  46. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +51 -16
  47. snowflake/ml/modeling/cluster/affinity_propagation.py +51 -16
  48. snowflake/ml/modeling/cluster/agglomerative_clustering.py +51 -16
  49. snowflake/ml/modeling/cluster/birch.py +51 -16
  50. snowflake/ml/modeling/cluster/bisecting_k_means.py +51 -16
  51. snowflake/ml/modeling/cluster/dbscan.py +51 -16
  52. snowflake/ml/modeling/cluster/feature_agglomeration.py +51 -16
  53. snowflake/ml/modeling/cluster/k_means.py +51 -16
  54. snowflake/ml/modeling/cluster/mean_shift.py +51 -16
  55. snowflake/ml/modeling/cluster/mini_batch_k_means.py +51 -16
  56. snowflake/ml/modeling/cluster/optics.py +51 -16
  57. snowflake/ml/modeling/cluster/spectral_biclustering.py +51 -16
  58. snowflake/ml/modeling/cluster/spectral_clustering.py +51 -16
  59. snowflake/ml/modeling/cluster/spectral_coclustering.py +51 -16
  60. snowflake/ml/modeling/compose/column_transformer.py +51 -16
  61. snowflake/ml/modeling/compose/transformed_target_regressor.py +51 -16
  62. snowflake/ml/modeling/covariance/elliptic_envelope.py +51 -16
  63. snowflake/ml/modeling/covariance/empirical_covariance.py +51 -16
  64. snowflake/ml/modeling/covariance/graphical_lasso.py +51 -16
  65. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +51 -16
  66. snowflake/ml/modeling/covariance/ledoit_wolf.py +51 -16
  67. snowflake/ml/modeling/covariance/min_cov_det.py +51 -16
  68. snowflake/ml/modeling/covariance/oas.py +51 -16
  69. snowflake/ml/modeling/covariance/shrunk_covariance.py +51 -16
  70. snowflake/ml/modeling/decomposition/dictionary_learning.py +51 -16
  71. snowflake/ml/modeling/decomposition/factor_analysis.py +51 -16
  72. snowflake/ml/modeling/decomposition/fast_ica.py +51 -16
  73. snowflake/ml/modeling/decomposition/incremental_pca.py +51 -16
  74. snowflake/ml/modeling/decomposition/kernel_pca.py +51 -16
  75. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +51 -16
  76. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +51 -16
  77. snowflake/ml/modeling/decomposition/pca.py +51 -16
  78. snowflake/ml/modeling/decomposition/sparse_pca.py +51 -16
  79. snowflake/ml/modeling/decomposition/truncated_svd.py +51 -16
  80. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +51 -16
  81. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +51 -16
  82. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +51 -16
  83. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +51 -16
  84. snowflake/ml/modeling/ensemble/bagging_classifier.py +51 -16
  85. snowflake/ml/modeling/ensemble/bagging_regressor.py +51 -16
  86. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +51 -16
  87. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +51 -16
  88. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +51 -16
  89. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +51 -16
  90. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +51 -16
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +51 -16
  92. snowflake/ml/modeling/ensemble/isolation_forest.py +51 -16
  93. snowflake/ml/modeling/ensemble/random_forest_classifier.py +51 -16
  94. snowflake/ml/modeling/ensemble/random_forest_regressor.py +51 -16
  95. snowflake/ml/modeling/ensemble/stacking_regressor.py +51 -16
  96. snowflake/ml/modeling/ensemble/voting_classifier.py +51 -16
  97. snowflake/ml/modeling/ensemble/voting_regressor.py +51 -16
  98. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +51 -16
  99. snowflake/ml/modeling/feature_selection/select_fdr.py +51 -16
  100. snowflake/ml/modeling/feature_selection/select_fpr.py +51 -16
  101. snowflake/ml/modeling/feature_selection/select_fwe.py +51 -16
  102. snowflake/ml/modeling/feature_selection/select_k_best.py +51 -16
  103. snowflake/ml/modeling/feature_selection/select_percentile.py +51 -16
  104. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +51 -16
  105. snowflake/ml/modeling/feature_selection/variance_threshold.py +51 -16
  106. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +51 -16
  107. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +51 -16
  108. snowflake/ml/modeling/impute/iterative_imputer.py +51 -16
  109. snowflake/ml/modeling/impute/knn_imputer.py +51 -16
  110. snowflake/ml/modeling/impute/missing_indicator.py +51 -16
  111. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +51 -16
  112. snowflake/ml/modeling/kernel_approximation/nystroem.py +51 -16
  113. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +51 -16
  114. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +51 -16
  115. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +51 -16
  116. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +51 -16
  117. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +51 -16
  118. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +51 -16
  119. snowflake/ml/modeling/linear_model/ard_regression.py +51 -16
  120. snowflake/ml/modeling/linear_model/bayesian_ridge.py +51 -16
  121. snowflake/ml/modeling/linear_model/elastic_net.py +51 -16
  122. snowflake/ml/modeling/linear_model/elastic_net_cv.py +51 -16
  123. snowflake/ml/modeling/linear_model/gamma_regressor.py +51 -16
  124. snowflake/ml/modeling/linear_model/huber_regressor.py +51 -16
  125. snowflake/ml/modeling/linear_model/lars.py +51 -16
  126. snowflake/ml/modeling/linear_model/lars_cv.py +51 -16
  127. snowflake/ml/modeling/linear_model/lasso.py +51 -16
  128. snowflake/ml/modeling/linear_model/lasso_cv.py +51 -16
  129. snowflake/ml/modeling/linear_model/lasso_lars.py +51 -16
  130. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +51 -16
  131. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +51 -16
  132. snowflake/ml/modeling/linear_model/linear_regression.py +51 -16
  133. snowflake/ml/modeling/linear_model/logistic_regression.py +51 -16
  134. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +51 -16
  135. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +51 -16
  136. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +51 -16
  137. snowflake/ml/modeling/linear_model/multi_task_lasso.py +51 -16
  138. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +51 -16
  139. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +51 -16
  140. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +51 -16
  141. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +51 -16
  142. snowflake/ml/modeling/linear_model/perceptron.py +51 -16
  143. snowflake/ml/modeling/linear_model/poisson_regressor.py +51 -16
  144. snowflake/ml/modeling/linear_model/ransac_regressor.py +51 -16
  145. snowflake/ml/modeling/linear_model/ridge.py +51 -16
  146. snowflake/ml/modeling/linear_model/ridge_classifier.py +51 -16
  147. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +51 -16
  148. snowflake/ml/modeling/linear_model/ridge_cv.py +51 -16
  149. snowflake/ml/modeling/linear_model/sgd_classifier.py +51 -16
  150. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +51 -16
  151. snowflake/ml/modeling/linear_model/sgd_regressor.py +51 -16
  152. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +51 -16
  153. snowflake/ml/modeling/linear_model/tweedie_regressor.py +51 -16
  154. snowflake/ml/modeling/manifold/isomap.py +51 -16
  155. snowflake/ml/modeling/manifold/mds.py +51 -16
  156. snowflake/ml/modeling/manifold/spectral_embedding.py +51 -16
  157. snowflake/ml/modeling/manifold/tsne.py +51 -16
  158. snowflake/ml/modeling/metrics/classification.py +5 -6
  159. snowflake/ml/modeling/metrics/metrics_utils.py +5 -3
  160. snowflake/ml/modeling/metrics/ranking.py +7 -3
  161. snowflake/ml/modeling/metrics/regression.py +6 -3
  162. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +51 -16
  163. snowflake/ml/modeling/mixture/gaussian_mixture.py +51 -16
  164. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +51 -16
  165. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +51 -16
  166. snowflake/ml/modeling/multiclass/output_code_classifier.py +51 -16
  167. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +51 -16
  168. snowflake/ml/modeling/naive_bayes/categorical_nb.py +51 -16
  169. snowflake/ml/modeling/naive_bayes/complement_nb.py +51 -16
  170. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +51 -16
  171. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +51 -16
  172. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +51 -16
  173. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +51 -16
  174. snowflake/ml/modeling/neighbors/kernel_density.py +51 -16
  175. snowflake/ml/modeling/neighbors/local_outlier_factor.py +51 -16
  176. snowflake/ml/modeling/neighbors/nearest_centroid.py +51 -16
  177. snowflake/ml/modeling/neighbors/nearest_neighbors.py +51 -16
  178. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +51 -16
  179. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +51 -16
  180. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +51 -16
  181. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +51 -16
  182. snowflake/ml/modeling/neural_network/mlp_classifier.py +51 -16
  183. snowflake/ml/modeling/neural_network/mlp_regressor.py +51 -16
  184. snowflake/ml/modeling/preprocessing/min_max_scaler.py +15 -1
  185. snowflake/ml/modeling/preprocessing/polynomial_features.py +51 -16
  186. snowflake/ml/modeling/semi_supervised/label_propagation.py +51 -16
  187. snowflake/ml/modeling/semi_supervised/label_spreading.py +51 -16
  188. snowflake/ml/modeling/svm/linear_svc.py +51 -16
  189. snowflake/ml/modeling/svm/linear_svr.py +51 -16
  190. snowflake/ml/modeling/svm/nu_svc.py +51 -16
  191. snowflake/ml/modeling/svm/nu_svr.py +51 -16
  192. snowflake/ml/modeling/svm/svc.py +51 -16
  193. snowflake/ml/modeling/svm/svr.py +51 -16
  194. snowflake/ml/modeling/tree/decision_tree_classifier.py +51 -16
  195. snowflake/ml/modeling/tree/decision_tree_regressor.py +51 -16
  196. snowflake/ml/modeling/tree/extra_tree_classifier.py +51 -16
  197. snowflake/ml/modeling/tree/extra_tree_regressor.py +51 -16
  198. snowflake/ml/modeling/xgboost/xgb_classifier.py +69 -16
  199. snowflake/ml/modeling/xgboost/xgb_regressor.py +69 -16
  200. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +69 -16
  201. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +69 -16
  202. snowflake/ml/registry/__init__.py +3 -0
  203. snowflake/ml/registry/_manager/model_manager.py +163 -0
  204. snowflake/ml/registry/model_registry.py +12 -0
  205. snowflake/ml/registry/registry.py +100 -90
  206. snowflake/ml/version.py +1 -1
  207. snowflake_ml_python-1.2.1.dist-info/LICENSE.txt +202 -0
  208. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/METADATA +295 -60
  209. snowflake_ml_python-1.2.1.dist-info/RECORD +355 -0
  210. {snowflake_ml_python-1.1.2.dist-info → snowflake_ml_python-1.2.1.dist-info}/WHEEL +2 -1
  211. snowflake_ml_python-1.2.1.dist-info/top_level.txt +1 -0
  212. snowflake/ml/model/_client/model/model_method_info.py +0 -19
  213. snowflake_ml_python-1.1.2.dist-info/RECORD +0 -347
  214. /snowflake/ml/_internal/{utils/spcs_image_registry.py → container_services/image_registry/credential.py} +0 -0
  215. /snowflake/ml/_internal/{utils/image_registry_http_client.py → container_services/image_registry/http_client.py} +0 -0
snowflake/ml/modeling/_internal/xgboost_external_memory_trainer.py (new file)
@@ -0,0 +1,444 @@
+import inspect
+import os
+import tempfile
+from typing import Any, Dict, List, Optional
+
+import cloudpickle as cp
+import pandas as pd
+import pyarrow.parquet as pq
+
+from snowflake.ml._internal import telemetry
+from snowflake.ml._internal.exceptions import (
+    error_codes,
+    exceptions,
+    modeling_error_messages,
+)
+from snowflake.ml._internal.utils import pkg_version_utils
+from snowflake.ml._internal.utils.query_result_checker import ResultValidator
+from snowflake.ml._internal.utils.snowpark_dataframe_utils import (
+    cast_snowpark_dataframe,
+)
+from snowflake.ml._internal.utils.temp_file_utils import get_temp_file_path
+from snowflake.ml.modeling._internal.model_specifications import (
+    ModelSpecifications,
+    ModelSpecificationsBuilder,
+)
+from snowflake.ml.modeling._internal.snowpark_trainer import SnowparkModelTrainer
+from snowflake.snowpark import (
+    DataFrame,
+    Session,
+    exceptions as snowpark_exceptions,
+    functions as F,
+)
+from snowflake.snowpark._internal.utils import (
+    TempObjectType,
+    random_name_for_temp_object,
+)
+
+_PROJECT = "ModelDevelopment"
+
+
+def get_data_iterator(
+    file_paths: List[str],
+    batch_size: int,
+    input_cols: List[str],
+    label_cols: List[str],
+    sample_weight_col: Optional[str] = None,
+) -> Any:
+    from typing import List, Optional
+
+    import xgboost
+
+    class ParquetDataIterator(xgboost.DataIter):
+        """
+        This iterator reads parquet data stored in the specified files and returns
+        deserialized data, enabling seamless integration with the XGBoost framework
+        for machine learning tasks.
+        """
+
+        def __init__(
+            self,
+            file_paths: List[str],
+            batch_size: int,
+            input_cols: List[str],
+            label_cols: List[str],
+            sample_weight_col: Optional[str] = None,
+        ) -> None:
+            """
+            Initialize the DataIterator.
+
+            Args:
+                file_paths: List of file paths containing the data.
+                batch_size: Target number of rows in each batch.
+                input_cols: The name(s) of one or more columns in a DataFrame containing a feature
+                    to be used for training.
+                label_cols: The name(s) of one or more columns in a DataFrame representing the target
+                    variable(s) to learn.
+                sample_weight_col: The column name representing the weight of training examples.
+            """
+            self._file_paths = file_paths
+            self._batch_size = batch_size
+            self._input_cols = input_cols
+            self._label_cols = label_cols
+            self._sample_weight_col = sample_weight_col
+
+            # File index
+            self._it = 0
+            # Pandas dataframe containing temp data
+            self._df = None
+            # XGBoost will generate some cache files under the current directory with the prefix
+            # "cache"
+            cache_dir_name = tempfile.mkdtemp()
+            super().__init__(cache_prefix=os.path.join(cache_dir_name, "cache"))
+
+        def next(self, batch_consumer_fn) -> int:  # type: ignore[no-untyped-def]
+            """Advance the iterator by 1 step and pass the data to XGBoost's batch_consumer_fn.
+            This function is called by XGBoost during the construction of ``DMatrix``.
+
+            Args:
+                batch_consumer_fn: batch consumer function
+
+            Returns:
+                0 if there is no more data, else 1.
+            """
+            while (self._df is None) or (self._df.shape[0] < self._batch_size):
+                # Read files and append data to the temp df until the batch size is reached.
+                if self._it == len(self._file_paths):
+                    break
+                new_df = pq.read_table(self._file_paths[self._it]).to_pandas()
+                self._it += 1
+
+                if self._df is None:
+                    self._df = new_df
+                else:
+                    self._df = pd.concat([self._df, new_df], ignore_index=True)
+
+            if (self._df is None) or (self._df.shape[0] == 0):
+                # No more data
+                return 0
+
+            # Slice the temp df and save the remainder in the temp df
+            batch_end_index = min(self._batch_size, self._df.shape[0])
+            batch_df = self._df.iloc[:batch_end_index]
+            self._df = self._df.truncate(before=batch_end_index).reset_index(drop=True)
+
+            # TODO(snandamuri): Make it proper to support categorical features, etc.
+            func_args = {
+                "data": batch_df[self._input_cols],
+                "label": batch_df[self._label_cols].squeeze(),
+            }
+            if self._sample_weight_col is not None:
+                func_args["weight"] = batch_df[self._sample_weight_col].squeeze()
+
+            batch_consumer_fn(**func_args)
+            # Return 1 to let XGBoost know we haven't seen all the files yet.
+            return 1
+
+        def reset(self) -> None:
+            """Reset the iterator to its beginning"""
+            self._it = 0
+
+    return ParquetDataIterator(
+        file_paths=file_paths,
+        batch_size=batch_size,
+        input_cols=input_cols,
+        label_cols=label_cols,
+        sample_weight_col=sample_weight_col,
+    )
+
+
+def train_xgboost_model(
+    estimator: object,
+    file_paths: List[str],
+    batch_size: int,
+    input_cols: List[str],
+    label_cols: List[str],
+    sample_weight_col: Optional[str] = None,
+) -> object:
+    """
+    Function to train XGBoost models using the external-memory version of XGBoost.
+    """
+    import xgboost
+
+    def _objective_decorator(func):  # type: ignore[no-untyped-def]
+        def inner(preds, dmatrix):  # type: ignore[no-untyped-def]
+            """internal function"""
+            labels = dmatrix.get_label()
+            return func(labels, preds)
+
+        return inner
+
+    assert isinstance(estimator, xgboost.XGBModel)
+    params = estimator.get_xgb_params()
+    obj = None
+
+    if isinstance(estimator, xgboost.XGBClassifier):
+        # TODO(snandamuri): Find a better way to get expected_classes
+        # Set: self.classes_, self.n_classes_
+        expected_classes = pd.unique(pq.read_table(file_paths[0]).to_pandas()[label_cols].squeeze())
+        estimator.n_classes_ = len(expected_classes)
+        if callable(estimator.objective):
+            obj = _objective_decorator(estimator.objective)  # type: ignore[no-untyped-call]
+            # Use default value. Is it really not used?
+            params["objective"] = "binary:logistic"
+
+        if len(expected_classes) > 2:
+            # Switch to using a multiclass objective in the underlying XGB instance
+            if params.get("objective", None) != "multi:softmax":
+                params["objective"] = "multi:softprob"
+            params["num_class"] = len(expected_classes)
+
+    if "tree_method" not in params.keys() or params["tree_method"] is None or params["tree_method"].lower() == "exact":
+        params["tree_method"] = "hist"
+
+    if (
+        "grow_policy" not in params.keys()
+        or params["grow_policy"] is None
+        or params["grow_policy"].lower() != "depthwise"
+    ):
+        params["grow_policy"] = "depthwise"
+
+    it = get_data_iterator(
+        file_paths=file_paths,
+        batch_size=batch_size,
+        input_cols=input_cols,
+        label_cols=label_cols,
+        sample_weight_col=sample_weight_col,
+    )
+    Xy = xgboost.DMatrix(it)
+    estimator._Booster = xgboost.train(
+        params,
+        Xy,
+        estimator.get_num_boosting_rounds(),
+        evals=[],
+        early_stopping_rounds=estimator.early_stopping_rounds,
+        evals_result=None,
+        obj=obj,
+        custom_metric=estimator.eval_metric,
+        verbose_eval=None,
+        xgb_model=None,
+        callbacks=None,
+    )
+    return estimator
+
+
+cp.register_pickle_by_value(inspect.getmodule(get_data_iterator))
+cp.register_pickle_by_value(inspect.getmodule(train_xgboost_model))
+
+
+class XGBoostExternalMemoryTrainer(SnowparkModelTrainer):
+    """
+    When working with large datasets, training XGBoost models traditionally requires loading the entire dataset into
+    memory, which can be costly and sometimes infeasible due to memory constraints. To solve this problem, XGBoost
+    provides support for loading data from external memory using a built-in data parser. With this feature enabled,
+    the training process occurs in a two-step approach:
+        Preprocessing Step: Input data is read and parsed into an internal format, such as CSR, CSC, or sorted CSC.
+            The processed state is appended to an in-memory buffer. Once the buffer reaches a predefined size, it is
+            written out to disk as a page.
+        Tree Construction Step: During the tree construction phase, the data pages stored on disk are streamed via
+            a multi-threaded pre-fetcher, allowing the model to efficiently access and process the data without
+            overloading memory.
+    """
+
+    def __init__(
+        self,
+        estimator: object,
+        dataset: DataFrame,
+        session: Session,
+        input_cols: List[str],
+        label_cols: Optional[List[str]],
+        sample_weight_col: Optional[str],
+        autogenerated: bool = False,
+        subproject: str = "",
+        batch_size: int = 10000,
+    ) -> None:
+        """
+        Initializes the XGBoostExternalMemoryTrainer with a model, a Snowpark DataFrame, feature and label column
+        names, etc.
+
+        Args:
+            estimator: SKLearn compatible estimator or transformer object.
+            dataset: The dataset used for training the model.
+            session: Snowflake session object to be used for training.
+            input_cols: The name(s) of one or more columns in a DataFrame containing a feature to be used for training.
+            label_cols: The name(s) of one or more columns in a DataFrame representing the target variable(s) to learn.
+            sample_weight_col: The column name representing the weight of training examples.
+            autogenerated: A boolean denoting whether the trainer is being used by autogenerated code.
+            subproject: subproject name to be used in telemetry.
+            batch_size: Number of rows in each batch processed during training.
+        """
+        super().__init__(
+            estimator=estimator,
+            dataset=dataset,
+            session=session,
+            input_cols=input_cols,
+            label_cols=label_cols,
+            sample_weight_col=sample_weight_col,
+            autogenerated=autogenerated,
+            subproject=subproject,
+        )
+        self._batch_size = batch_size
+
+    def _get_xgb_external_memory_fit_wrapper_sproc(
+        self,
+        model_spec: ModelSpecifications,
+        session: Session,
+        statement_params: Dict[str, str],
+        import_file_paths: List[str],
+    ) -> Any:
+        fit_sproc_name = random_name_for_temp_object(TempObjectType.PROCEDURE)
+
+        relaxed_dependencies = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
+            pkg_versions=model_spec.pkgDependencies, session=self.session
+        )
+
+        @F.sproc(
+            is_permanent=False,
+            name=fit_sproc_name,
+            packages=list(["snowflake-snowpark-python"] + relaxed_dependencies),
+            replace=True,
+            session=session,
+            statement_params=statement_params,
+            anonymous=True,
+            imports=list(import_file_paths),
+        )  # type: ignore[misc]
+        def fit_wrapper_sproc(
+            session: Session,
+            stage_transform_file_name: str,
+            stage_result_file_name: str,
+            dataset_stage_name: str,
+            batch_size: int,
+            input_cols: List[str],
+            label_cols: List[str],
+            sample_weight_col: Optional[str],
+            statement_params: Dict[str, str],
+        ) -> str:
+            import os
+            import sys
+
+            import cloudpickle as cp
+
+            local_transform_file_name = get_temp_file_path()
+
+            session.file.get(stage_transform_file_name, local_transform_file_name, statement_params=statement_params)
+
+            local_transform_file_path = os.path.join(
+                local_transform_file_name, os.listdir(local_transform_file_name)[0]
+            )
+            with open(local_transform_file_path, mode="r+b") as local_transform_file_obj:
+                estimator = cp.load(local_transform_file_obj)
+
+            data_files = [
+                os.path.join(sys._xoptions["snowflake_import_directory"], filename)
+                for filename in os.listdir(sys._xoptions["snowflake_import_directory"])
+                if filename.startswith(dataset_stage_name)
+            ]
+
+            estimator = train_xgboost_model(
+                estimator=estimator,
+                file_paths=data_files,
+                batch_size=batch_size,
+                input_cols=input_cols,
+                label_cols=label_cols,
+                sample_weight_col=sample_weight_col,
+            )
+
+            local_result_file_name = get_temp_file_path()
+            with open(local_result_file_name, mode="w+b") as local_result_file_obj:
+                cp.dump(estimator, local_result_file_obj)
+
+            session.file.put(
+                local_result_file_name,
+                stage_result_file_name,
+                auto_compress=False,
+                overwrite=True,
+                statement_params=statement_params,
+            )
+
+            # Note: you can add something like + "|" + str(df) to the return string
+            # to pass debug information to the caller.
+            return str(os.path.basename(local_result_file_name))
+
+        return fit_wrapper_sproc
+
+    def _write_training_data_to_stage(self, dataset_stage_name: str) -> List[str]:
+        """
+        Materializes the training data to the specified stage and returns the list of stage file paths.
+
+        Args:
+            dataset_stage_name: Target stage to materialize training data.
+
+        Returns:
+            List of stage file paths that contain the materialized data.
+        """
+        # Stage data.
+        dataset = cast_snowpark_dataframe(self.dataset)
+        remote_file_path = f"{dataset_stage_name}/{dataset_stage_name}.parquet"
+        copy_response = dataset.write.copy_into_location(  # type: ignore[call-overload]
+            remote_file_path, file_format_type="parquet", header=True, overwrite=True
+        )
+        ResultValidator(result=copy_response).has_dimensions(expected_rows=1).validate()
+        data_file_paths = [f"@{row.name}" for row in self.session.sql(f"LIST @{dataset_stage_name}").collect()]
+        return data_file_paths
+
+    def train(self) -> object:
+        """
+        Trains the model in XGBoost's external-memory mode by materializing the dataset to a stage and
+        fitting inside a stored procedure.
+
+        Returns:
+            Trained model
+
+        Raises:
+            SnowflakeMLException: For known types of user and system errors.
+            e: For every unexpected exception from the Snowflake client.
+        """
+        temp_stage_name = self._create_temp_stage()
+        (stage_transform_file_name, stage_result_file_name) = self._upload_model_to_stage(stage_name=temp_stage_name)
+        data_file_paths = self._write_training_data_to_stage(dataset_stage_name=temp_stage_name)
+
+        # Call fit sproc
+        statement_params = telemetry.get_function_usage_statement_params(
+            project=_PROJECT,
+            subproject=self._subproject,
+            function_name=telemetry.get_statement_params_full_func_name(inspect.currentframe(), self._class_name),
+            api_calls=[Session.call],
+            custom_tags=None,
+        )
+
+        model_spec = ModelSpecificationsBuilder.build(model=self.estimator)
+        fit_wrapper = self._get_xgb_external_memory_fit_wrapper_sproc(
+            model_spec=model_spec,
+            session=self.session,
+            statement_params=statement_params,
+            import_file_paths=data_file_paths,
+        )
+
+        try:
+            sproc_export_file_name = fit_wrapper(
+                self.session,
+                stage_transform_file_name,
+                stage_result_file_name,
+                temp_stage_name,
+                self._batch_size,
+                self.input_cols,
+                self.label_cols,
+                self.sample_weight_col,
+                statement_params,
+            )
+        except snowpark_exceptions.SnowparkClientException as e:
+            if "fit() missing 1 required positional argument: 'y'" in str(e):
+                raise exceptions.SnowflakeMLException(
+                    error_code=error_codes.NOT_FOUND,
+                    original_exception=RuntimeError(modeling_error_messages.ATTRIBUTE_NOT_SET.format("label_cols")),
+                ) from e
+            raise e
+
+        if "|" in sproc_export_file_name:
+            fields = sproc_export_file_name.strip().split("|")
+            sproc_export_file_name = fields[0]
+
+        return self._fetch_model_from_stage(
+            dir_path=stage_result_file_name,
+            file_name=sproc_export_file_name,
+            statement_params=statement_params,
+        )
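
The new module above plugs XGBoost's `DataIter`-based external-memory API into the Snowpark trainer: `get_data_iterator` pages parquet files into fixed-size batches, and `train_xgboost_model` feeds the iterator to `xgboost.DMatrix` so training streams from disk instead of holding the full dataset in memory. A minimal local sketch of the same flow, assuming a couple of parquet files already on disk (the paths and column names here are illustrative; in the package this logic runs inside the fit stored procedure against staged files):

    import xgboost
    from snowflake.ml.modeling._internal.xgboost_external_memory_trainer import (
        get_data_iterator,
        train_xgboost_model,
    )

    # Train an XGBoost classifier without materializing all rows in memory.
    clf = train_xgboost_model(
        estimator=xgboost.XGBClassifier(n_estimators=10),
        file_paths=["/tmp/train/part-0.parquet", "/tmp/train/part-1.parquet"],  # hypothetical files
        batch_size=10_000,
        input_cols=["F1", "F2", "F3"],
        label_cols=["TARGET"],
    )

    # Or drive the iterator directly; DMatrix pages batches out to on-disk cache files.
    it = get_data_iterator(
        file_paths=["/tmp/train/part-0.parquet"],
        batch_size=10_000,
        input_cols=["F1", "F2", "F3"],
        label_cols=["TARGET"],
    )
    dmatrix = xgboost.DMatrix(it)
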
snowflake/ml/modeling/calibration/calibrated_classifier_cv.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.calibration".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class CalibratedClassifierCV(BaseTransformer):
     r"""Probability calibration with isotonic regression or logistic regression
     For more details on this class, see [sklearn.calibration.CalibratedClassifierCV]
@@ -192,7 +204,9 @@ class CalibratedClassifierCV(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
         deps = deps | gather_dependencies(estimator)
         deps = deps | gather_dependencies(base_estimator)
@@ -275,11 +289,6 @@ class CalibratedClassifierCV(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -307,7 +316,9 @@ class CalibratedClassifierCV(BaseTransformer):
                 label_cols=self.label_cols,
                 sample_weight_col=self.sample_weight_col,
                 autogenerated=self._autogenerated,
-                subproject=_SUBPROJECT
+                subproject=_SUBPROJECT,
+                use_external_memory_version=self._use_external_memory_version,
+                batch_size=self._batch_size,
             )
             self._sklearn_object = model_trainer.train()
             self._is_fitted = True
@@ -578,6 +589,22 @@ class CalibratedClassifierCV(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal
+            # the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal
+            # the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -593,8 +620,8 @@ class CalibratedClassifierCV(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Method not supported for this class.
 
 
@@ -607,13 +634,21 @@ class CalibratedClassifierCV(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if False:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc. functions.
snowflake/ml/modeling/cluster/affinity_propagation.py
@@ -54,6 +54,18 @@ _PROJECT = "ModelDevelopment"
 _SUBPROJECT = "".join([s.capitalize() for s in "sklearn.cluster".replace("sklearn.", "").split("_")])
 
 
+def _is_fit_predict_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return True and callable(getattr(self._sklearn_object, "fit_predict", None))
+    return check
+
+
+def _is_fit_transform_method_enabled() -> Callable[[Any], bool]:
+    def check(self: BaseTransformer) -> TypeGuard[Callable[..., object]]:
+        return False and callable(getattr(self._sklearn_object, "fit_transform", None))
+    return check
+
+
 class AffinityPropagation(BaseTransformer):
     r"""Perform Affinity Propagation Clustering of data
     For more details on this class, see [sklearn.cluster.AffinityPropagation]
@@ -167,7 +179,9 @@ class AffinityPropagation(BaseTransformer):
         self.set_label_cols(label_cols)
         self.set_passthrough_cols(passthrough_cols)
         self.set_drop_input_cols(drop_input_cols)
-        self.set_sample_weight_col(sample_weight_col)
+        self.set_sample_weight_col(sample_weight_col)
+        self._use_external_memory_version = False
+        self._batch_size = -1
         deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'cloudpickle=={cp.__version__}'])
 
         self._deps = list(deps)
@@ -250,11 +264,6 @@ class AffinityPropagation(BaseTransformer):
         if isinstance(dataset, DataFrame):
             session = dataset._session
             assert session is not None  # keep mypy happy
-            # Validate that key package version in user workspace are supported in snowflake conda channel
-            # If customer doesn't have package in conda channel, replace the ones have the closest versions
-            self._deps = pkg_version_utils.get_valid_pkg_versions_supported_in_snowflake_conda_channel(
-                pkg_versions=self._get_dependencies(), session=session, subproject=_SUBPROJECT)
-
             # Specify input columns so column pruning will be enforced
             selected_cols = self._get_active_columns()
             if len(selected_cols) > 0:
@@ -282,7 +291,9 @@ class AffinityPropagation(BaseTransformer):
                 label_cols=self.label_cols,
                 sample_weight_col=self.sample_weight_col,
                 autogenerated=self._autogenerated,
-                subproject=_SUBPROJECT
+                subproject=_SUBPROJECT,
+                use_external_memory_version=self._use_external_memory_version,
+                batch_size=self._batch_size,
             )
             self._sklearn_object = model_trainer.train()
             self._is_fitted = True
@@ -553,6 +564,22 @@ class AffinityPropagation(BaseTransformer):
                 # each row containing a list of values.
                 expected_dtype = "ARRAY"
 
+        # If we were unable to assign a type to this transform in the factory, infer the type here.
+        if expected_dtype == "":
+            # If this is a clustering transformer and the number of output columns does not equal
+            # the number of clusters, the response will be an "ARRAY".
+            if hasattr(self._sklearn_object, "n_clusters") and getattr(self._sklearn_object, "n_clusters") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            # If this is a decomposition transformer and the number of output columns does not equal
+            # the number of components, the response will be an "ARRAY".
+            elif hasattr(self._sklearn_object, "n_components") and getattr(self._sklearn_object, "n_components") != len(self.output_cols):
+                expected_dtype = "ARRAY"
+            else:
+                output_types = [signature.as_snowpark_type() for signature in _infer_signature(dataset[self.input_cols], "output", use_snowflake_identifiers=True)]
+                # We can only infer the output types from the input types if the following two statements are true:
+                # 1) All of the output types are the same. Otherwise, we still have to fall back to variant because `_sklearn_inference` only accepts one type.
+                # 2) The length of the input columns equals the length of the output columns. Otherwise the transform will likely result in an `ARRAY`.
+                if all(x == output_types[0] for x in output_types) and len(output_types) == len(self.output_cols):
+                    expected_dtype = convert_sp_to_sf_type(output_types[0])
+
         output_df = self._batch_inference(
             dataset=dataset,
             inference_method="transform",
@@ -568,8 +595,8 @@ class AffinityPropagation(BaseTransformer):
 
         return output_df
 
-    @available_if(original_estimator_has_callable("fit_predict"))  # type: ignore[misc]
-    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> npt.NDArray[Any]:
+    @available_if(_is_fit_predict_method_enabled())  # type: ignore[misc]
+    def fit_predict(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
         """ Fit clustering from features/affinity matrix; return cluster labels
         For more details on this function, see [sklearn.cluster.AffinityPropagation.fit_predict]
         (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.AffinityPropagation.html#sklearn.cluster.AffinityPropagation.fit_predict)
@@ -584,13 +611,21 @@ class AffinityPropagation(BaseTransformer):
         Returns:
             Predicted dataset.
         """
-        if True:
-            self.fit(dataset)
-            assert self._sklearn_object is not None
-            labels : npt.NDArray[Any] = self._sklearn_object.labels_
-            return labels
-        else:
-            raise NotImplementedError
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.labels_
+
+
+    @available_if(_is_fit_transform_method_enabled())  # type: ignore[misc]
+    def fit_transform(self, dataset: Union[DataFrame, pd.DataFrame]) -> Union[Any, npt.NDArray[Any]]:
+        """
+        Returns:
+            Transformed dataset.
+        """
+        self.fit(dataset)
+        assert self._sklearn_object is not None
+        return self._sklearn_object.embedding_
+
 
     def _get_output_column_names(self, output_cols_prefix: str, output_cols: Optional[List[str]] = None) -> List[str]:
         """ Returns the list of output columns for predict_proba(), decision_function(), etc. functions.