snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -652,12 +649,23 @@ class ExtraTreesClassifier(BaseTransformer):
652
649
  autogenerated=self._autogenerated,
653
650
  subproject=_SUBPROJECT,
654
651
  )
655
- output_result, fitted_estimator = model_trainer.train_fit_predict(
656
- drop_input_cols=self._drop_input_cols,
657
- expected_output_cols_list=(
658
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
659
- ),
652
+ expected_output_cols = (
653
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
660
654
  )
655
+ if isinstance(dataset, DataFrame):
656
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
657
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
658
+ )
659
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
660
+ drop_input_cols=self._drop_input_cols,
661
+ expected_output_cols_list=expected_output_cols,
662
+ example_output_pd_df=example_output_pd_df,
663
+ )
664
+ else:
665
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
666
+ drop_input_cols=self._drop_input_cols,
667
+ expected_output_cols_list=expected_output_cols,
668
+ )
661
669
  self._sklearn_object = fitted_estimator
662
670
  self._is_fitted = True
663
671
  return output_result
@@ -680,6 +688,7 @@ class ExtraTreesClassifier(BaseTransformer):
680
688
  """
681
689
  self._infer_input_output_cols(dataset)
682
690
  super()._check_dataset_type(dataset)
691
+
683
692
  model_trainer = ModelTrainerBuilder.build_fit_transform(
684
693
  estimator=self._sklearn_object,
685
694
  dataset=dataset,
@@ -736,12 +745,41 @@ class ExtraTreesClassifier(BaseTransformer):
736
745
 
737
746
  return rv
738
747
 
739
- def _align_expected_output_names(
740
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
741
- ) -> List[str]:
748
+ def _align_expected_output(
749
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
750
+ ) -> Tuple[List[str], pd.DataFrame]:
751
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
752
+ and output dataframe with 1 line.
753
+ If the method is fit_predict, run 2 lines of data.
754
+ """
742
755
  # in case the inferred output column names dimension is different
743
756
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
744
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
757
+
758
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
759
+ # so change the minimum of number of rows to 2
760
+ num_examples = 2
761
+ statement_params = telemetry.get_function_usage_statement_params(
762
+ project=_PROJECT,
763
+ subproject=_SUBPROJECT,
764
+ function_name=telemetry.get_statement_params_full_func_name(
765
+ inspect.currentframe(), ExtraTreesClassifier.__class__.__name__
766
+ ),
767
+ api_calls=[Session.call],
768
+ custom_tags={"autogen": True} if self._autogenerated else None,
769
+ )
770
+ if output_cols_prefix == "fit_predict_":
771
+ if hasattr(self._sklearn_object, "n_clusters"):
772
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
773
+ num_examples = self._sklearn_object.n_clusters
774
+ elif hasattr(self._sklearn_object, "min_samples"):
775
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
776
+ num_examples = self._sklearn_object.min_samples
777
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
778
+ # LocalOutlierFactor expects n_neighbors <= n_samples
779
+ num_examples = self._sklearn_object.n_neighbors
780
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
781
+ else:
782
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
745
783
 
746
784
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
747
785
  # seen during the fit.
@@ -753,12 +791,14 @@ class ExtraTreesClassifier(BaseTransformer):
753
791
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
754
792
  if self.sample_weight_col:
755
793
  output_df_columns_set -= set(self.sample_weight_col)
794
+
756
795
  # if the dimension of inferred output column names is correct; use it
757
796
  if len(expected_output_cols_list) == len(output_df_columns_set):
758
- return expected_output_cols_list
797
+ return expected_output_cols_list, output_df_pd
759
798
  # otherwise, use the sklearn estimator's output
760
799
  else:
761
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
800
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
801
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
762
802
 
763
803
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
764
804
  @telemetry.send_api_usage_telemetry(
@@ -806,7 +846,7 @@ class ExtraTreesClassifier(BaseTransformer):
806
846
  drop_input_cols=self._drop_input_cols,
807
847
  expected_output_cols_type="float",
808
848
  )
809
- expected_output_cols = self._align_expected_output_names(
849
+ expected_output_cols, _ = self._align_expected_output(
810
850
  inference_method, dataset, expected_output_cols, output_cols_prefix
811
851
  )
812
852
 
@@ -874,7 +914,7 @@ class ExtraTreesClassifier(BaseTransformer):
874
914
  drop_input_cols=self._drop_input_cols,
875
915
  expected_output_cols_type="float",
876
916
  )
877
- expected_output_cols = self._align_expected_output_names(
917
+ expected_output_cols, _ = self._align_expected_output(
878
918
  inference_method, dataset, expected_output_cols, output_cols_prefix
879
919
  )
880
920
  elif isinstance(dataset, pd.DataFrame):
@@ -937,7 +977,7 @@ class ExtraTreesClassifier(BaseTransformer):
937
977
  drop_input_cols=self._drop_input_cols,
938
978
  expected_output_cols_type="float",
939
979
  )
940
- expected_output_cols = self._align_expected_output_names(
980
+ expected_output_cols, _ = self._align_expected_output(
941
981
  inference_method, dataset, expected_output_cols, output_cols_prefix
942
982
  )
943
983
 
@@ -1002,7 +1042,7 @@ class ExtraTreesClassifier(BaseTransformer):
1002
1042
  drop_input_cols = self._drop_input_cols,
1003
1043
  expected_output_cols_type="float",
1004
1044
  )
1005
- expected_output_cols = self._align_expected_output_names(
1045
+ expected_output_cols, _ = self._align_expected_output(
1006
1046
  inference_method, dataset, expected_output_cols, output_cols_prefix
1007
1047
  )
1008
1048
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -631,12 +628,23 @@ class ExtraTreesRegressor(BaseTransformer):
631
628
  autogenerated=self._autogenerated,
632
629
  subproject=_SUBPROJECT,
633
630
  )
634
- output_result, fitted_estimator = model_trainer.train_fit_predict(
635
- drop_input_cols=self._drop_input_cols,
636
- expected_output_cols_list=(
637
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
638
- ),
631
+ expected_output_cols = (
632
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
639
633
  )
634
+ if isinstance(dataset, DataFrame):
635
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
636
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
637
+ )
638
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
639
+ drop_input_cols=self._drop_input_cols,
640
+ expected_output_cols_list=expected_output_cols,
641
+ example_output_pd_df=example_output_pd_df,
642
+ )
643
+ else:
644
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
645
+ drop_input_cols=self._drop_input_cols,
646
+ expected_output_cols_list=expected_output_cols,
647
+ )
640
648
  self._sklearn_object = fitted_estimator
641
649
  self._is_fitted = True
642
650
  return output_result
@@ -659,6 +667,7 @@ class ExtraTreesRegressor(BaseTransformer):
659
667
  """
660
668
  self._infer_input_output_cols(dataset)
661
669
  super()._check_dataset_type(dataset)
670
+
662
671
  model_trainer = ModelTrainerBuilder.build_fit_transform(
663
672
  estimator=self._sklearn_object,
664
673
  dataset=dataset,
@@ -715,12 +724,41 @@ class ExtraTreesRegressor(BaseTransformer):
715
724
 
716
725
  return rv
717
726
 
718
- def _align_expected_output_names(
719
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
720
- ) -> List[str]:
727
+ def _align_expected_output(
728
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
729
+ ) -> Tuple[List[str], pd.DataFrame]:
730
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
731
+ and output dataframe with 1 line.
732
+ If the method is fit_predict, run 2 lines of data.
733
+ """
721
734
  # in case the inferred output column names dimension is different
722
735
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
723
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
736
+
737
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
738
+ # so change the minimum of number of rows to 2
739
+ num_examples = 2
740
+ statement_params = telemetry.get_function_usage_statement_params(
741
+ project=_PROJECT,
742
+ subproject=_SUBPROJECT,
743
+ function_name=telemetry.get_statement_params_full_func_name(
744
+ inspect.currentframe(), ExtraTreesRegressor.__class__.__name__
745
+ ),
746
+ api_calls=[Session.call],
747
+ custom_tags={"autogen": True} if self._autogenerated else None,
748
+ )
749
+ if output_cols_prefix == "fit_predict_":
750
+ if hasattr(self._sklearn_object, "n_clusters"):
751
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
752
+ num_examples = self._sklearn_object.n_clusters
753
+ elif hasattr(self._sklearn_object, "min_samples"):
754
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
755
+ num_examples = self._sklearn_object.min_samples
756
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
757
+ # LocalOutlierFactor expects n_neighbors <= n_samples
758
+ num_examples = self._sklearn_object.n_neighbors
759
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
760
+ else:
761
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
724
762
 
725
763
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
726
764
  # seen during the fit.
@@ -732,12 +770,14 @@ class ExtraTreesRegressor(BaseTransformer):
732
770
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
733
771
  if self.sample_weight_col:
734
772
  output_df_columns_set -= set(self.sample_weight_col)
773
+
735
774
  # if the dimension of inferred output column names is correct; use it
736
775
  if len(expected_output_cols_list) == len(output_df_columns_set):
737
- return expected_output_cols_list
776
+ return expected_output_cols_list, output_df_pd
738
777
  # otherwise, use the sklearn estimator's output
739
778
  else:
740
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
779
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
780
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
741
781
 
742
782
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
743
783
  @telemetry.send_api_usage_telemetry(
@@ -783,7 +823,7 @@ class ExtraTreesRegressor(BaseTransformer):
783
823
  drop_input_cols=self._drop_input_cols,
784
824
  expected_output_cols_type="float",
785
825
  )
786
- expected_output_cols = self._align_expected_output_names(
826
+ expected_output_cols, _ = self._align_expected_output(
787
827
  inference_method, dataset, expected_output_cols, output_cols_prefix
788
828
  )
789
829
 
@@ -849,7 +889,7 @@ class ExtraTreesRegressor(BaseTransformer):
849
889
  drop_input_cols=self._drop_input_cols,
850
890
  expected_output_cols_type="float",
851
891
  )
852
- expected_output_cols = self._align_expected_output_names(
892
+ expected_output_cols, _ = self._align_expected_output(
853
893
  inference_method, dataset, expected_output_cols, output_cols_prefix
854
894
  )
855
895
  elif isinstance(dataset, pd.DataFrame):
@@ -912,7 +952,7 @@ class ExtraTreesRegressor(BaseTransformer):
912
952
  drop_input_cols=self._drop_input_cols,
913
953
  expected_output_cols_type="float",
914
954
  )
915
- expected_output_cols = self._align_expected_output_names(
955
+ expected_output_cols, _ = self._align_expected_output(
916
956
  inference_method, dataset, expected_output_cols, output_cols_prefix
917
957
  )
918
958
 
@@ -977,7 +1017,7 @@ class ExtraTreesRegressor(BaseTransformer):
977
1017
  drop_input_cols = self._drop_input_cols,
978
1018
  expected_output_cols_type="float",
979
1019
  )
980
- expected_output_cols = self._align_expected_output_names(
1020
+ expected_output_cols, _ = self._align_expected_output(
981
1021
  inference_method, dataset, expected_output_cols, output_cols_prefix
982
1022
  )
983
1023
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -664,12 +661,23 @@ class GradientBoostingClassifier(BaseTransformer):
664
661
  autogenerated=self._autogenerated,
665
662
  subproject=_SUBPROJECT,
666
663
  )
667
- output_result, fitted_estimator = model_trainer.train_fit_predict(
668
- drop_input_cols=self._drop_input_cols,
669
- expected_output_cols_list=(
670
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
671
- ),
664
+ expected_output_cols = (
665
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
672
666
  )
667
+ if isinstance(dataset, DataFrame):
668
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
669
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
670
+ )
671
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
672
+ drop_input_cols=self._drop_input_cols,
673
+ expected_output_cols_list=expected_output_cols,
674
+ example_output_pd_df=example_output_pd_df,
675
+ )
676
+ else:
677
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
678
+ drop_input_cols=self._drop_input_cols,
679
+ expected_output_cols_list=expected_output_cols,
680
+ )
673
681
  self._sklearn_object = fitted_estimator
674
682
  self._is_fitted = True
675
683
  return output_result
@@ -692,6 +700,7 @@ class GradientBoostingClassifier(BaseTransformer):
692
700
  """
693
701
  self._infer_input_output_cols(dataset)
694
702
  super()._check_dataset_type(dataset)
703
+
695
704
  model_trainer = ModelTrainerBuilder.build_fit_transform(
696
705
  estimator=self._sklearn_object,
697
706
  dataset=dataset,
@@ -748,12 +757,41 @@ class GradientBoostingClassifier(BaseTransformer):
748
757
 
749
758
  return rv
750
759
 
751
- def _align_expected_output_names(
752
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
753
- ) -> List[str]:
760
+ def _align_expected_output(
761
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
762
+ ) -> Tuple[List[str], pd.DataFrame]:
763
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
764
+ and output dataframe with 1 line.
765
+ If the method is fit_predict, run 2 lines of data.
766
+ """
754
767
  # in case the inferred output column names dimension is different
755
768
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
756
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
769
+
770
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
771
+ # so change the minimum of number of rows to 2
772
+ num_examples = 2
773
+ statement_params = telemetry.get_function_usage_statement_params(
774
+ project=_PROJECT,
775
+ subproject=_SUBPROJECT,
776
+ function_name=telemetry.get_statement_params_full_func_name(
777
+ inspect.currentframe(), GradientBoostingClassifier.__class__.__name__
778
+ ),
779
+ api_calls=[Session.call],
780
+ custom_tags={"autogen": True} if self._autogenerated else None,
781
+ )
782
+ if output_cols_prefix == "fit_predict_":
783
+ if hasattr(self._sklearn_object, "n_clusters"):
784
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
785
+ num_examples = self._sklearn_object.n_clusters
786
+ elif hasattr(self._sklearn_object, "min_samples"):
787
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
788
+ num_examples = self._sklearn_object.min_samples
789
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
790
+ # LocalOutlierFactor expects n_neighbors <= n_samples
791
+ num_examples = self._sklearn_object.n_neighbors
792
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
793
+ else:
794
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
757
795
 
758
796
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
759
797
  # seen during the fit.
@@ -765,12 +803,14 @@ class GradientBoostingClassifier(BaseTransformer):
765
803
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
766
804
  if self.sample_weight_col:
767
805
  output_df_columns_set -= set(self.sample_weight_col)
806
+
768
807
  # if the dimension of inferred output column names is correct; use it
769
808
  if len(expected_output_cols_list) == len(output_df_columns_set):
770
- return expected_output_cols_list
809
+ return expected_output_cols_list, output_df_pd
771
810
  # otherwise, use the sklearn estimator's output
772
811
  else:
773
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
812
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
813
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
774
814
 
775
815
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
776
816
  @telemetry.send_api_usage_telemetry(
@@ -818,7 +858,7 @@ class GradientBoostingClassifier(BaseTransformer):
818
858
  drop_input_cols=self._drop_input_cols,
819
859
  expected_output_cols_type="float",
820
860
  )
821
- expected_output_cols = self._align_expected_output_names(
861
+ expected_output_cols, _ = self._align_expected_output(
822
862
  inference_method, dataset, expected_output_cols, output_cols_prefix
823
863
  )
824
864
 
@@ -886,7 +926,7 @@ class GradientBoostingClassifier(BaseTransformer):
886
926
  drop_input_cols=self._drop_input_cols,
887
927
  expected_output_cols_type="float",
888
928
  )
889
- expected_output_cols = self._align_expected_output_names(
929
+ expected_output_cols, _ = self._align_expected_output(
890
930
  inference_method, dataset, expected_output_cols, output_cols_prefix
891
931
  )
892
932
  elif isinstance(dataset, pd.DataFrame):
@@ -951,7 +991,7 @@ class GradientBoostingClassifier(BaseTransformer):
951
991
  drop_input_cols=self._drop_input_cols,
952
992
  expected_output_cols_type="float",
953
993
  )
954
- expected_output_cols = self._align_expected_output_names(
994
+ expected_output_cols, _ = self._align_expected_output(
955
995
  inference_method, dataset, expected_output_cols, output_cols_prefix
956
996
  )
957
997
 
@@ -1016,7 +1056,7 @@ class GradientBoostingClassifier(BaseTransformer):
1016
1056
  drop_input_cols = self._drop_input_cols,
1017
1057
  expected_output_cols_type="float",
1018
1058
  )
1019
- expected_output_cols = self._align_expected_output_names(
1059
+ expected_output_cols, _ = self._align_expected_output(
1020
1060
  inference_method, dataset, expected_output_cols, output_cols_prefix
1021
1061
  )
1022
1062
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -673,12 +670,23 @@ class GradientBoostingRegressor(BaseTransformer):
673
670
  autogenerated=self._autogenerated,
674
671
  subproject=_SUBPROJECT,
675
672
  )
676
- output_result, fitted_estimator = model_trainer.train_fit_predict(
677
- drop_input_cols=self._drop_input_cols,
678
- expected_output_cols_list=(
679
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
680
- ),
673
+ expected_output_cols = (
674
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
681
675
  )
676
+ if isinstance(dataset, DataFrame):
677
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
678
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
679
+ )
680
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
681
+ drop_input_cols=self._drop_input_cols,
682
+ expected_output_cols_list=expected_output_cols,
683
+ example_output_pd_df=example_output_pd_df,
684
+ )
685
+ else:
686
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
687
+ drop_input_cols=self._drop_input_cols,
688
+ expected_output_cols_list=expected_output_cols,
689
+ )
682
690
  self._sklearn_object = fitted_estimator
683
691
  self._is_fitted = True
684
692
  return output_result
@@ -701,6 +709,7 @@ class GradientBoostingRegressor(BaseTransformer):
701
709
  """
702
710
  self._infer_input_output_cols(dataset)
703
711
  super()._check_dataset_type(dataset)
712
+
704
713
  model_trainer = ModelTrainerBuilder.build_fit_transform(
705
714
  estimator=self._sklearn_object,
706
715
  dataset=dataset,
@@ -757,12 +766,41 @@ class GradientBoostingRegressor(BaseTransformer):
757
766
 
758
767
  return rv
759
768
 
760
- def _align_expected_output_names(
761
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
762
- ) -> List[str]:
769
+ def _align_expected_output(
770
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
771
+ ) -> Tuple[List[str], pd.DataFrame]:
772
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
773
+ and output dataframe with 1 line.
774
+ If the method is fit_predict, run 2 lines of data.
775
+ """
763
776
  # in case the inferred output column names dimension is different
764
777
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
765
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
778
+
779
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
780
+ # so change the minimum of number of rows to 2
781
+ num_examples = 2
782
+ statement_params = telemetry.get_function_usage_statement_params(
783
+ project=_PROJECT,
784
+ subproject=_SUBPROJECT,
785
+ function_name=telemetry.get_statement_params_full_func_name(
786
+ inspect.currentframe(), GradientBoostingRegressor.__class__.__name__
787
+ ),
788
+ api_calls=[Session.call],
789
+ custom_tags={"autogen": True} if self._autogenerated else None,
790
+ )
791
+ if output_cols_prefix == "fit_predict_":
792
+ if hasattr(self._sklearn_object, "n_clusters"):
793
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
794
+ num_examples = self._sklearn_object.n_clusters
795
+ elif hasattr(self._sklearn_object, "min_samples"):
796
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
797
+ num_examples = self._sklearn_object.min_samples
798
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
799
+ # LocalOutlierFactor expects n_neighbors <= n_samples
800
+ num_examples = self._sklearn_object.n_neighbors
801
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
802
+ else:
803
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
766
804
 
767
805
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
768
806
  # seen during the fit.
@@ -774,12 +812,14 @@ class GradientBoostingRegressor(BaseTransformer):
774
812
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
775
813
  if self.sample_weight_col:
776
814
  output_df_columns_set -= set(self.sample_weight_col)
815
+
777
816
  # if the dimension of inferred output column names is correct; use it
778
817
  if len(expected_output_cols_list) == len(output_df_columns_set):
779
- return expected_output_cols_list
818
+ return expected_output_cols_list, output_df_pd
780
819
  # otherwise, use the sklearn estimator's output
781
820
  else:
782
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
821
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
822
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
783
823
 
784
824
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
785
825
  @telemetry.send_api_usage_telemetry(
@@ -825,7 +865,7 @@ class GradientBoostingRegressor(BaseTransformer):
825
865
  drop_input_cols=self._drop_input_cols,
826
866
  expected_output_cols_type="float",
827
867
  )
828
- expected_output_cols = self._align_expected_output_names(
868
+ expected_output_cols, _ = self._align_expected_output(
829
869
  inference_method, dataset, expected_output_cols, output_cols_prefix
830
870
  )
831
871
 
@@ -891,7 +931,7 @@ class GradientBoostingRegressor(BaseTransformer):
891
931
  drop_input_cols=self._drop_input_cols,
892
932
  expected_output_cols_type="float",
893
933
  )
894
- expected_output_cols = self._align_expected_output_names(
934
+ expected_output_cols, _ = self._align_expected_output(
895
935
  inference_method, dataset, expected_output_cols, output_cols_prefix
896
936
  )
897
937
  elif isinstance(dataset, pd.DataFrame):
@@ -954,7 +994,7 @@ class GradientBoostingRegressor(BaseTransformer):
954
994
  drop_input_cols=self._drop_input_cols,
955
995
  expected_output_cols_type="float",
956
996
  )
957
- expected_output_cols = self._align_expected_output_names(
997
+ expected_output_cols, _ = self._align_expected_output(
958
998
  inference_method, dataset, expected_output_cols, output_cols_prefix
959
999
  )
960
1000
 
@@ -1019,7 +1059,7 @@ class GradientBoostingRegressor(BaseTransformer):
1019
1059
  drop_input_cols = self._drop_input_cols,
1020
1060
  expected_output_cols_type="float",
1021
1061
  )
1022
- expected_output_cols = self._align_expected_output_names(
1062
+ expected_output_cols, _ = self._align_expected_output(
1023
1063
  inference_method, dataset, expected_output_cols, output_cols_prefix
1024
1064
  )
1025
1065