snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284) hide show
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -577,12 +574,23 @@ class ExtraTreeRegressor(BaseTransformer):
577
574
  autogenerated=self._autogenerated,
578
575
  subproject=_SUBPROJECT,
579
576
  )
580
- output_result, fitted_estimator = model_trainer.train_fit_predict(
581
- drop_input_cols=self._drop_input_cols,
582
- expected_output_cols_list=(
583
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
584
- ),
577
+ expected_output_cols = (
578
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
585
579
  )
580
+ if isinstance(dataset, DataFrame):
581
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
582
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
583
+ )
584
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
585
+ drop_input_cols=self._drop_input_cols,
586
+ expected_output_cols_list=expected_output_cols,
587
+ example_output_pd_df=example_output_pd_df,
588
+ )
589
+ else:
590
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
591
+ drop_input_cols=self._drop_input_cols,
592
+ expected_output_cols_list=expected_output_cols,
593
+ )
586
594
  self._sklearn_object = fitted_estimator
587
595
  self._is_fitted = True
588
596
  return output_result
@@ -605,6 +613,7 @@ class ExtraTreeRegressor(BaseTransformer):
605
613
  """
606
614
  self._infer_input_output_cols(dataset)
607
615
  super()._check_dataset_type(dataset)
616
+
608
617
  model_trainer = ModelTrainerBuilder.build_fit_transform(
609
618
  estimator=self._sklearn_object,
610
619
  dataset=dataset,
@@ -661,12 +670,41 @@ class ExtraTreeRegressor(BaseTransformer):
661
670
 
662
671
  return rv
663
672
 
664
- def _align_expected_output_names(
665
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
666
- ) -> List[str]:
673
+ def _align_expected_output(
674
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
675
+ ) -> Tuple[List[str], pd.DataFrame]:
676
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
677
+ and output dataframe with 1 line.
678
+ If the method is fit_predict, run 2 lines of data.
679
+ """
667
680
  # in case the inferred output column names dimension is different
668
681
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
669
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
682
+
683
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
684
+ # so change the minimum of number of rows to 2
685
+ num_examples = 2
686
+ statement_params = telemetry.get_function_usage_statement_params(
687
+ project=_PROJECT,
688
+ subproject=_SUBPROJECT,
689
+ function_name=telemetry.get_statement_params_full_func_name(
690
+ inspect.currentframe(), ExtraTreeRegressor.__class__.__name__
691
+ ),
692
+ api_calls=[Session.call],
693
+ custom_tags={"autogen": True} if self._autogenerated else None,
694
+ )
695
+ if output_cols_prefix == "fit_predict_":
696
+ if hasattr(self._sklearn_object, "n_clusters"):
697
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
698
+ num_examples = self._sklearn_object.n_clusters
699
+ elif hasattr(self._sklearn_object, "min_samples"):
700
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
701
+ num_examples = self._sklearn_object.min_samples
702
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
703
+ # LocalOutlierFactor expects n_neighbors <= n_samples
704
+ num_examples = self._sklearn_object.n_neighbors
705
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
706
+ else:
707
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
670
708
 
671
709
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
672
710
  # seen during the fit.
@@ -678,12 +716,14 @@ class ExtraTreeRegressor(BaseTransformer):
678
716
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
679
717
  if self.sample_weight_col:
680
718
  output_df_columns_set -= set(self.sample_weight_col)
719
+
681
720
  # if the dimension of inferred output column names is correct; use it
682
721
  if len(expected_output_cols_list) == len(output_df_columns_set):
683
- return expected_output_cols_list
722
+ return expected_output_cols_list, output_df_pd
684
723
  # otherwise, use the sklearn estimator's output
685
724
  else:
686
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
725
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
726
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
687
727
 
688
728
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
689
729
  @telemetry.send_api_usage_telemetry(
@@ -729,7 +769,7 @@ class ExtraTreeRegressor(BaseTransformer):
729
769
  drop_input_cols=self._drop_input_cols,
730
770
  expected_output_cols_type="float",
731
771
  )
732
- expected_output_cols = self._align_expected_output_names(
772
+ expected_output_cols, _ = self._align_expected_output(
733
773
  inference_method, dataset, expected_output_cols, output_cols_prefix
734
774
  )
735
775
 
@@ -795,7 +835,7 @@ class ExtraTreeRegressor(BaseTransformer):
795
835
  drop_input_cols=self._drop_input_cols,
796
836
  expected_output_cols_type="float",
797
837
  )
798
- expected_output_cols = self._align_expected_output_names(
838
+ expected_output_cols, _ = self._align_expected_output(
799
839
  inference_method, dataset, expected_output_cols, output_cols_prefix
800
840
  )
801
841
  elif isinstance(dataset, pd.DataFrame):
@@ -858,7 +898,7 @@ class ExtraTreeRegressor(BaseTransformer):
858
898
  drop_input_cols=self._drop_input_cols,
859
899
  expected_output_cols_type="float",
860
900
  )
861
- expected_output_cols = self._align_expected_output_names(
901
+ expected_output_cols, _ = self._align_expected_output(
862
902
  inference_method, dataset, expected_output_cols, output_cols_prefix
863
903
  )
864
904
 
@@ -923,7 +963,7 @@ class ExtraTreeRegressor(BaseTransformer):
923
963
  drop_input_cols = self._drop_input_cols,
924
964
  expected_output_cols_type="float",
925
965
  )
926
- expected_output_cols = self._align_expected_output_names(
966
+ expected_output_cols, _ = self._align_expected_output(
927
967
  inference_method, dataset, expected_output_cols, output_cols_prefix
928
968
  )
929
969
 
@@ -4,18 +4,17 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
18
16
  import numpy
17
+ import sklearn
19
18
  import xgboost
20
19
  from sklearn.utils.metaestimators import available_if
21
20
 
@@ -23,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
23
22
  from snowflake.ml._internal import telemetry
24
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
25
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
26
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
27
26
  from snowflake.snowpark import DataFrame, Session
28
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
29
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
30
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
31
- ModelTransformHandlers,
32
30
  BatchInferenceKwargsTypedDict,
33
31
  ScoreKwargsTypedDict
34
32
  )
@@ -361,7 +359,7 @@ class XGBClassifier(BaseTransformer):
361
359
  self.set_sample_weight_col(sample_weight_col)
362
360
  self._use_external_memory_version = use_external_memory_version
363
361
  self._batch_size = batch_size
364
- deps: Set[str] = set([f'numpy=={np.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'])
362
+ deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'])
365
363
 
366
364
  self._deps = list(deps)
367
365
 
@@ -695,12 +693,23 @@ class XGBClassifier(BaseTransformer):
695
693
  autogenerated=self._autogenerated,
696
694
  subproject=_SUBPROJECT,
697
695
  )
698
- output_result, fitted_estimator = model_trainer.train_fit_predict(
699
- drop_input_cols=self._drop_input_cols,
700
- expected_output_cols_list=(
701
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
702
- ),
696
+ expected_output_cols = (
697
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
703
698
  )
699
+ if isinstance(dataset, DataFrame):
700
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
701
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
702
+ )
703
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
704
+ drop_input_cols=self._drop_input_cols,
705
+ expected_output_cols_list=expected_output_cols,
706
+ example_output_pd_df=example_output_pd_df,
707
+ )
708
+ else:
709
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
710
+ drop_input_cols=self._drop_input_cols,
711
+ expected_output_cols_list=expected_output_cols,
712
+ )
704
713
  self._sklearn_object = fitted_estimator
705
714
  self._is_fitted = True
706
715
  return output_result
@@ -723,6 +732,7 @@ class XGBClassifier(BaseTransformer):
723
732
  """
724
733
  self._infer_input_output_cols(dataset)
725
734
  super()._check_dataset_type(dataset)
735
+
726
736
  model_trainer = ModelTrainerBuilder.build_fit_transform(
727
737
  estimator=self._sklearn_object,
728
738
  dataset=dataset,
@@ -779,12 +789,41 @@ class XGBClassifier(BaseTransformer):
779
789
 
780
790
  return rv
781
791
 
782
- def _align_expected_output_names(
783
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
784
- ) -> List[str]:
792
+ def _align_expected_output(
793
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
794
+ ) -> Tuple[List[str], pd.DataFrame]:
795
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
796
+ and output dataframe with 1 line.
797
+ If the method is fit_predict, run 2 lines of data.
798
+ """
785
799
  # in case the inferred output column names dimension is different
786
800
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
787
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
801
+
802
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
803
+ # so change the minimum of number of rows to 2
804
+ num_examples = 2
805
+ statement_params = telemetry.get_function_usage_statement_params(
806
+ project=_PROJECT,
807
+ subproject=_SUBPROJECT,
808
+ function_name=telemetry.get_statement_params_full_func_name(
809
+ inspect.currentframe(), XGBClassifier.__class__.__name__
810
+ ),
811
+ api_calls=[Session.call],
812
+ custom_tags={"autogen": True} if self._autogenerated else None,
813
+ )
814
+ if output_cols_prefix == "fit_predict_":
815
+ if hasattr(self._sklearn_object, "n_clusters"):
816
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
817
+ num_examples = self._sklearn_object.n_clusters
818
+ elif hasattr(self._sklearn_object, "min_samples"):
819
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
820
+ num_examples = self._sklearn_object.min_samples
821
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
822
+ # LocalOutlierFactor expects n_neighbors <= n_samples
823
+ num_examples = self._sklearn_object.n_neighbors
824
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
825
+ else:
826
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
788
827
 
789
828
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
790
829
  # seen during the fit.
@@ -796,12 +835,14 @@ class XGBClassifier(BaseTransformer):
796
835
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
797
836
  if self.sample_weight_col:
798
837
  output_df_columns_set -= set(self.sample_weight_col)
838
+
799
839
  # if the dimension of inferred output column names is correct; use it
800
840
  if len(expected_output_cols_list) == len(output_df_columns_set):
801
- return expected_output_cols_list
841
+ return expected_output_cols_list, output_df_pd
802
842
  # otherwise, use the sklearn estimator's output
803
843
  else:
804
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
844
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
845
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
805
846
 
806
847
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
807
848
  @telemetry.send_api_usage_telemetry(
@@ -849,7 +890,7 @@ class XGBClassifier(BaseTransformer):
849
890
  drop_input_cols=self._drop_input_cols,
850
891
  expected_output_cols_type="float",
851
892
  )
852
- expected_output_cols = self._align_expected_output_names(
893
+ expected_output_cols, _ = self._align_expected_output(
853
894
  inference_method, dataset, expected_output_cols, output_cols_prefix
854
895
  )
855
896
 
@@ -917,7 +958,7 @@ class XGBClassifier(BaseTransformer):
917
958
  drop_input_cols=self._drop_input_cols,
918
959
  expected_output_cols_type="float",
919
960
  )
920
- expected_output_cols = self._align_expected_output_names(
961
+ expected_output_cols, _ = self._align_expected_output(
921
962
  inference_method, dataset, expected_output_cols, output_cols_prefix
922
963
  )
923
964
  elif isinstance(dataset, pd.DataFrame):
@@ -980,7 +1021,7 @@ class XGBClassifier(BaseTransformer):
980
1021
  drop_input_cols=self._drop_input_cols,
981
1022
  expected_output_cols_type="float",
982
1023
  )
983
- expected_output_cols = self._align_expected_output_names(
1024
+ expected_output_cols, _ = self._align_expected_output(
984
1025
  inference_method, dataset, expected_output_cols, output_cols_prefix
985
1026
  )
986
1027
 
@@ -1045,7 +1086,7 @@ class XGBClassifier(BaseTransformer):
1045
1086
  drop_input_cols = self._drop_input_cols,
1046
1087
  expected_output_cols_type="float",
1047
1088
  )
1048
- expected_output_cols = self._align_expected_output_names(
1089
+ expected_output_cols, _ = self._align_expected_output(
1049
1090
  inference_method, dataset, expected_output_cols, output_cols_prefix
1050
1091
  )
1051
1092
 
@@ -1110,7 +1151,7 @@ class XGBClassifier(BaseTransformer):
1110
1151
  transform_kwargs = dict(
1111
1152
  session=dataset._session,
1112
1153
  dependencies=self._deps,
1113
- score_sproc_imports=['xgboost'],
1154
+ score_sproc_imports=['xgboost', 'sklearn'],
1114
1155
  )
1115
1156
  elif isinstance(dataset, pd.DataFrame):
1116
1157
  # pandas_handler.score() does not require any extra kwargs.
@@ -4,18 +4,17 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
18
16
  import numpy
17
+ import sklearn
19
18
  import xgboost
20
19
  from sklearn.utils.metaestimators import available_if
21
20
 
@@ -23,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
23
22
  from snowflake.ml._internal import telemetry
24
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
25
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
26
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
27
26
  from snowflake.snowpark import DataFrame, Session
28
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
29
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
30
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
31
- ModelTransformHandlers,
32
30
  BatchInferenceKwargsTypedDict,
33
31
  ScoreKwargsTypedDict
34
32
  )
@@ -361,7 +359,7 @@ class XGBRegressor(BaseTransformer):
361
359
  self.set_sample_weight_col(sample_weight_col)
362
360
  self._use_external_memory_version = use_external_memory_version
363
361
  self._batch_size = batch_size
364
- deps: Set[str] = set([f'numpy=={np.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'])
362
+ deps: Set[str] = set([f'numpy=={np.__version__}', f'scikit-learn=={sklearn.__version__}', f'xgboost=={xgboost.__version__}', f'cloudpickle=={cp.__version__}'])
365
363
 
366
364
  self._deps = list(deps)
367
365
 
@@ -694,12 +692,23 @@ class XGBRegressor(BaseTransformer):
694
692
  autogenerated=self._autogenerated,
695
693
  subproject=_SUBPROJECT,
696
694
  )
697
- output_result, fitted_estimator = model_trainer.train_fit_predict(
698
- drop_input_cols=self._drop_input_cols,
699
- expected_output_cols_list=(
700
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
701
- ),
695
+ expected_output_cols = (
696
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
702
697
  )
698
+ if isinstance(dataset, DataFrame):
699
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
700
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
701
+ )
702
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
703
+ drop_input_cols=self._drop_input_cols,
704
+ expected_output_cols_list=expected_output_cols,
705
+ example_output_pd_df=example_output_pd_df,
706
+ )
707
+ else:
708
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
709
+ drop_input_cols=self._drop_input_cols,
710
+ expected_output_cols_list=expected_output_cols,
711
+ )
703
712
  self._sklearn_object = fitted_estimator
704
713
  self._is_fitted = True
705
714
  return output_result
@@ -722,6 +731,7 @@ class XGBRegressor(BaseTransformer):
722
731
  """
723
732
  self._infer_input_output_cols(dataset)
724
733
  super()._check_dataset_type(dataset)
734
+
725
735
  model_trainer = ModelTrainerBuilder.build_fit_transform(
726
736
  estimator=self._sklearn_object,
727
737
  dataset=dataset,
@@ -778,12 +788,41 @@ class XGBRegressor(BaseTransformer):
778
788
 
779
789
  return rv
780
790
 
781
- def _align_expected_output_names(
782
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
783
- ) -> List[str]:
791
+ def _align_expected_output(
792
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
793
+ ) -> Tuple[List[str], pd.DataFrame]:
794
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
795
+ and output dataframe with 1 line.
796
+ If the method is fit_predict, run 2 lines of data.
797
+ """
784
798
  # in case the inferred output column names dimension is different
785
799
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
786
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
800
+
801
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
802
+ # so change the minimum of number of rows to 2
803
+ num_examples = 2
804
+ statement_params = telemetry.get_function_usage_statement_params(
805
+ project=_PROJECT,
806
+ subproject=_SUBPROJECT,
807
+ function_name=telemetry.get_statement_params_full_func_name(
808
+ inspect.currentframe(), XGBRegressor.__class__.__name__
809
+ ),
810
+ api_calls=[Session.call],
811
+ custom_tags={"autogen": True} if self._autogenerated else None,
812
+ )
813
+ if output_cols_prefix == "fit_predict_":
814
+ if hasattr(self._sklearn_object, "n_clusters"):
815
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
816
+ num_examples = self._sklearn_object.n_clusters
817
+ elif hasattr(self._sklearn_object, "min_samples"):
818
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
819
+ num_examples = self._sklearn_object.min_samples
820
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
821
+ # LocalOutlierFactor expects n_neighbors <= n_samples
822
+ num_examples = self._sklearn_object.n_neighbors
823
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
824
+ else:
825
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
787
826
 
788
827
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
789
828
  # seen during the fit.
@@ -795,12 +834,14 @@ class XGBRegressor(BaseTransformer):
795
834
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
796
835
  if self.sample_weight_col:
797
836
  output_df_columns_set -= set(self.sample_weight_col)
837
+
798
838
  # if the dimension of inferred output column names is correct; use it
799
839
  if len(expected_output_cols_list) == len(output_df_columns_set):
800
- return expected_output_cols_list
840
+ return expected_output_cols_list, output_df_pd
801
841
  # otherwise, use the sklearn estimator's output
802
842
  else:
803
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
843
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
844
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
804
845
 
805
846
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
806
847
  @telemetry.send_api_usage_telemetry(
@@ -846,7 +887,7 @@ class XGBRegressor(BaseTransformer):
846
887
  drop_input_cols=self._drop_input_cols,
847
888
  expected_output_cols_type="float",
848
889
  )
849
- expected_output_cols = self._align_expected_output_names(
890
+ expected_output_cols, _ = self._align_expected_output(
850
891
  inference_method, dataset, expected_output_cols, output_cols_prefix
851
892
  )
852
893
 
@@ -912,7 +953,7 @@ class XGBRegressor(BaseTransformer):
912
953
  drop_input_cols=self._drop_input_cols,
913
954
  expected_output_cols_type="float",
914
955
  )
915
- expected_output_cols = self._align_expected_output_names(
956
+ expected_output_cols, _ = self._align_expected_output(
916
957
  inference_method, dataset, expected_output_cols, output_cols_prefix
917
958
  )
918
959
  elif isinstance(dataset, pd.DataFrame):
@@ -975,7 +1016,7 @@ class XGBRegressor(BaseTransformer):
975
1016
  drop_input_cols=self._drop_input_cols,
976
1017
  expected_output_cols_type="float",
977
1018
  )
978
- expected_output_cols = self._align_expected_output_names(
1019
+ expected_output_cols, _ = self._align_expected_output(
979
1020
  inference_method, dataset, expected_output_cols, output_cols_prefix
980
1021
  )
981
1022
 
@@ -1040,7 +1081,7 @@ class XGBRegressor(BaseTransformer):
1040
1081
  drop_input_cols = self._drop_input_cols,
1041
1082
  expected_output_cols_type="float",
1042
1083
  )
1043
- expected_output_cols = self._align_expected_output_names(
1084
+ expected_output_cols, _ = self._align_expected_output(
1044
1085
  inference_method, dataset, expected_output_cols, output_cols_prefix
1045
1086
  )
1046
1087
 
@@ -1105,7 +1146,7 @@ class XGBRegressor(BaseTransformer):
1105
1146
  transform_kwargs = dict(
1106
1147
  session=dataset._session,
1107
1148
  dependencies=self._deps,
1108
- score_sproc_imports=['xgboost'],
1149
+ score_sproc_imports=['xgboost', 'sklearn'],
1109
1150
  )
1110
1151
  elif isinstance(dataset, pd.DataFrame):
1111
1152
  # pandas_handler.score() does not require any extra kwargs.