snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -645,12 +642,23 @@ class HistGradientBoostingClassifier(BaseTransformer):
645
642
  autogenerated=self._autogenerated,
646
643
  subproject=_SUBPROJECT,
647
644
  )
648
- output_result, fitted_estimator = model_trainer.train_fit_predict(
649
- drop_input_cols=self._drop_input_cols,
650
- expected_output_cols_list=(
651
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
652
- ),
645
+ expected_output_cols = (
646
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
653
647
  )
648
+ if isinstance(dataset, DataFrame):
649
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
650
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
651
+ )
652
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
653
+ drop_input_cols=self._drop_input_cols,
654
+ expected_output_cols_list=expected_output_cols,
655
+ example_output_pd_df=example_output_pd_df,
656
+ )
657
+ else:
658
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
659
+ drop_input_cols=self._drop_input_cols,
660
+ expected_output_cols_list=expected_output_cols,
661
+ )
654
662
  self._sklearn_object = fitted_estimator
655
663
  self._is_fitted = True
656
664
  return output_result
@@ -673,6 +681,7 @@ class HistGradientBoostingClassifier(BaseTransformer):
673
681
  """
674
682
  self._infer_input_output_cols(dataset)
675
683
  super()._check_dataset_type(dataset)
684
+
676
685
  model_trainer = ModelTrainerBuilder.build_fit_transform(
677
686
  estimator=self._sklearn_object,
678
687
  dataset=dataset,
@@ -729,12 +738,41 @@ class HistGradientBoostingClassifier(BaseTransformer):
729
738
 
730
739
  return rv
731
740
 
732
- def _align_expected_output_names(
733
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
734
- ) -> List[str]:
741
+ def _align_expected_output(
742
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
743
+ ) -> Tuple[List[str], pd.DataFrame]:
744
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
745
+ and output dataframe with 1 line.
746
+ If the method is fit_predict, run 2 lines of data.
747
+ """
735
748
  # in case the inferred output column names dimension is different
736
749
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
737
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
750
+
751
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
752
+ # so change the minimum of number of rows to 2
753
+ num_examples = 2
754
+ statement_params = telemetry.get_function_usage_statement_params(
755
+ project=_PROJECT,
756
+ subproject=_SUBPROJECT,
757
+ function_name=telemetry.get_statement_params_full_func_name(
758
+ inspect.currentframe(), HistGradientBoostingClassifier.__class__.__name__
759
+ ),
760
+ api_calls=[Session.call],
761
+ custom_tags={"autogen": True} if self._autogenerated else None,
762
+ )
763
+ if output_cols_prefix == "fit_predict_":
764
+ if hasattr(self._sklearn_object, "n_clusters"):
765
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
766
+ num_examples = self._sklearn_object.n_clusters
767
+ elif hasattr(self._sklearn_object, "min_samples"):
768
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
769
+ num_examples = self._sklearn_object.min_samples
770
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
771
+ # LocalOutlierFactor expects n_neighbors <= n_samples
772
+ num_examples = self._sklearn_object.n_neighbors
773
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
774
+ else:
775
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
738
776
 
739
777
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
740
778
  # seen during the fit.
@@ -746,12 +784,14 @@ class HistGradientBoostingClassifier(BaseTransformer):
746
784
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
747
785
  if self.sample_weight_col:
748
786
  output_df_columns_set -= set(self.sample_weight_col)
787
+
749
788
  # if the dimension of inferred output column names is correct; use it
750
789
  if len(expected_output_cols_list) == len(output_df_columns_set):
751
- return expected_output_cols_list
790
+ return expected_output_cols_list, output_df_pd
752
791
  # otherwise, use the sklearn estimator's output
753
792
  else:
754
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
793
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
794
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
755
795
 
756
796
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
757
797
  @telemetry.send_api_usage_telemetry(
@@ -799,7 +839,7 @@ class HistGradientBoostingClassifier(BaseTransformer):
799
839
  drop_input_cols=self._drop_input_cols,
800
840
  expected_output_cols_type="float",
801
841
  )
802
- expected_output_cols = self._align_expected_output_names(
842
+ expected_output_cols, _ = self._align_expected_output(
803
843
  inference_method, dataset, expected_output_cols, output_cols_prefix
804
844
  )
805
845
 
@@ -867,7 +907,7 @@ class HistGradientBoostingClassifier(BaseTransformer):
867
907
  drop_input_cols=self._drop_input_cols,
868
908
  expected_output_cols_type="float",
869
909
  )
870
- expected_output_cols = self._align_expected_output_names(
910
+ expected_output_cols, _ = self._align_expected_output(
871
911
  inference_method, dataset, expected_output_cols, output_cols_prefix
872
912
  )
873
913
  elif isinstance(dataset, pd.DataFrame):
@@ -932,7 +972,7 @@ class HistGradientBoostingClassifier(BaseTransformer):
932
972
  drop_input_cols=self._drop_input_cols,
933
973
  expected_output_cols_type="float",
934
974
  )
935
- expected_output_cols = self._align_expected_output_names(
975
+ expected_output_cols, _ = self._align_expected_output(
936
976
  inference_method, dataset, expected_output_cols, output_cols_prefix
937
977
  )
938
978
 
@@ -997,7 +1037,7 @@ class HistGradientBoostingClassifier(BaseTransformer):
997
1037
  drop_input_cols = self._drop_input_cols,
998
1038
  expected_output_cols_type="float",
999
1039
  )
1000
- expected_output_cols = self._align_expected_output_names(
1040
+ expected_output_cols, _ = self._align_expected_output(
1001
1041
  inference_method, dataset, expected_output_cols, output_cols_prefix
1002
1042
  )
1003
1043
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -636,12 +633,23 @@ class HistGradientBoostingRegressor(BaseTransformer):
636
633
  autogenerated=self._autogenerated,
637
634
  subproject=_SUBPROJECT,
638
635
  )
639
- output_result, fitted_estimator = model_trainer.train_fit_predict(
640
- drop_input_cols=self._drop_input_cols,
641
- expected_output_cols_list=(
642
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
643
- ),
636
+ expected_output_cols = (
637
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
644
638
  )
639
+ if isinstance(dataset, DataFrame):
640
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
641
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
642
+ )
643
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
644
+ drop_input_cols=self._drop_input_cols,
645
+ expected_output_cols_list=expected_output_cols,
646
+ example_output_pd_df=example_output_pd_df,
647
+ )
648
+ else:
649
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
650
+ drop_input_cols=self._drop_input_cols,
651
+ expected_output_cols_list=expected_output_cols,
652
+ )
645
653
  self._sklearn_object = fitted_estimator
646
654
  self._is_fitted = True
647
655
  return output_result
@@ -664,6 +672,7 @@ class HistGradientBoostingRegressor(BaseTransformer):
664
672
  """
665
673
  self._infer_input_output_cols(dataset)
666
674
  super()._check_dataset_type(dataset)
675
+
667
676
  model_trainer = ModelTrainerBuilder.build_fit_transform(
668
677
  estimator=self._sklearn_object,
669
678
  dataset=dataset,
@@ -720,12 +729,41 @@ class HistGradientBoostingRegressor(BaseTransformer):
720
729
 
721
730
  return rv
722
731
 
723
- def _align_expected_output_names(
724
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
725
- ) -> List[str]:
732
+ def _align_expected_output(
733
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
734
+ ) -> Tuple[List[str], pd.DataFrame]:
735
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
736
+ and output dataframe with 1 line.
737
+ If the method is fit_predict, run 2 lines of data.
738
+ """
726
739
  # in case the inferred output column names dimension is different
727
740
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
728
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
741
+
742
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
743
+ # so change the minimum of number of rows to 2
744
+ num_examples = 2
745
+ statement_params = telemetry.get_function_usage_statement_params(
746
+ project=_PROJECT,
747
+ subproject=_SUBPROJECT,
748
+ function_name=telemetry.get_statement_params_full_func_name(
749
+ inspect.currentframe(), HistGradientBoostingRegressor.__class__.__name__
750
+ ),
751
+ api_calls=[Session.call],
752
+ custom_tags={"autogen": True} if self._autogenerated else None,
753
+ )
754
+ if output_cols_prefix == "fit_predict_":
755
+ if hasattr(self._sklearn_object, "n_clusters"):
756
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
757
+ num_examples = self._sklearn_object.n_clusters
758
+ elif hasattr(self._sklearn_object, "min_samples"):
759
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
760
+ num_examples = self._sklearn_object.min_samples
761
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
762
+ # LocalOutlierFactor expects n_neighbors <= n_samples
763
+ num_examples = self._sklearn_object.n_neighbors
764
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
765
+ else:
766
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
729
767
 
730
768
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
731
769
  # seen during the fit.
@@ -737,12 +775,14 @@ class HistGradientBoostingRegressor(BaseTransformer):
737
775
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
738
776
  if self.sample_weight_col:
739
777
  output_df_columns_set -= set(self.sample_weight_col)
778
+
740
779
  # if the dimension of inferred output column names is correct; use it
741
780
  if len(expected_output_cols_list) == len(output_df_columns_set):
742
- return expected_output_cols_list
781
+ return expected_output_cols_list, output_df_pd
743
782
  # otherwise, use the sklearn estimator's output
744
783
  else:
745
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
784
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
785
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
746
786
 
747
787
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
748
788
  @telemetry.send_api_usage_telemetry(
@@ -788,7 +828,7 @@ class HistGradientBoostingRegressor(BaseTransformer):
788
828
  drop_input_cols=self._drop_input_cols,
789
829
  expected_output_cols_type="float",
790
830
  )
791
- expected_output_cols = self._align_expected_output_names(
831
+ expected_output_cols, _ = self._align_expected_output(
792
832
  inference_method, dataset, expected_output_cols, output_cols_prefix
793
833
  )
794
834
 
@@ -854,7 +894,7 @@ class HistGradientBoostingRegressor(BaseTransformer):
854
894
  drop_input_cols=self._drop_input_cols,
855
895
  expected_output_cols_type="float",
856
896
  )
857
- expected_output_cols = self._align_expected_output_names(
897
+ expected_output_cols, _ = self._align_expected_output(
858
898
  inference_method, dataset, expected_output_cols, output_cols_prefix
859
899
  )
860
900
  elif isinstance(dataset, pd.DataFrame):
@@ -917,7 +957,7 @@ class HistGradientBoostingRegressor(BaseTransformer):
917
957
  drop_input_cols=self._drop_input_cols,
918
958
  expected_output_cols_type="float",
919
959
  )
920
- expected_output_cols = self._align_expected_output_names(
960
+ expected_output_cols, _ = self._align_expected_output(
921
961
  inference_method, dataset, expected_output_cols, output_cols_prefix
922
962
  )
923
963
 
@@ -982,7 +1022,7 @@ class HistGradientBoostingRegressor(BaseTransformer):
982
1022
  drop_input_cols = self._drop_input_cols,
983
1023
  expected_output_cols_type="float",
984
1024
  )
985
- expected_output_cols = self._align_expected_output_names(
1025
+ expected_output_cols, _ = self._align_expected_output(
986
1026
  inference_method, dataset, expected_output_cols, output_cols_prefix
987
1027
  )
988
1028
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -540,12 +537,23 @@ class IsolationForest(BaseTransformer):
540
537
  autogenerated=self._autogenerated,
541
538
  subproject=_SUBPROJECT,
542
539
  )
543
- output_result, fitted_estimator = model_trainer.train_fit_predict(
544
- drop_input_cols=self._drop_input_cols,
545
- expected_output_cols_list=(
546
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
547
- ),
540
+ expected_output_cols = (
541
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
548
542
  )
543
+ if isinstance(dataset, DataFrame):
544
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
545
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
546
+ )
547
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
548
+ drop_input_cols=self._drop_input_cols,
549
+ expected_output_cols_list=expected_output_cols,
550
+ example_output_pd_df=example_output_pd_df,
551
+ )
552
+ else:
553
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
554
+ drop_input_cols=self._drop_input_cols,
555
+ expected_output_cols_list=expected_output_cols,
556
+ )
549
557
  self._sklearn_object = fitted_estimator
550
558
  self._is_fitted = True
551
559
  return output_result
@@ -568,6 +576,7 @@ class IsolationForest(BaseTransformer):
568
576
  """
569
577
  self._infer_input_output_cols(dataset)
570
578
  super()._check_dataset_type(dataset)
579
+
571
580
  model_trainer = ModelTrainerBuilder.build_fit_transform(
572
581
  estimator=self._sklearn_object,
573
582
  dataset=dataset,
@@ -624,12 +633,41 @@ class IsolationForest(BaseTransformer):
624
633
 
625
634
  return rv
626
635
 
627
- def _align_expected_output_names(
628
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
629
- ) -> List[str]:
636
+ def _align_expected_output(
637
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
638
+ ) -> Tuple[List[str], pd.DataFrame]:
639
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
640
+ and output dataframe with 1 line.
641
+ If the method is fit_predict, run 2 lines of data.
642
+ """
630
643
  # in case the inferred output column names dimension is different
631
644
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
632
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
645
+
646
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
647
+ # so change the minimum of number of rows to 2
648
+ num_examples = 2
649
+ statement_params = telemetry.get_function_usage_statement_params(
650
+ project=_PROJECT,
651
+ subproject=_SUBPROJECT,
652
+ function_name=telemetry.get_statement_params_full_func_name(
653
+ inspect.currentframe(), IsolationForest.__class__.__name__
654
+ ),
655
+ api_calls=[Session.call],
656
+ custom_tags={"autogen": True} if self._autogenerated else None,
657
+ )
658
+ if output_cols_prefix == "fit_predict_":
659
+ if hasattr(self._sklearn_object, "n_clusters"):
660
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
661
+ num_examples = self._sklearn_object.n_clusters
662
+ elif hasattr(self._sklearn_object, "min_samples"):
663
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
664
+ num_examples = self._sklearn_object.min_samples
665
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
666
+ # LocalOutlierFactor expects n_neighbors <= n_samples
667
+ num_examples = self._sklearn_object.n_neighbors
668
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
669
+ else:
670
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
633
671
 
634
672
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
635
673
  # seen during the fit.
@@ -641,12 +679,14 @@ class IsolationForest(BaseTransformer):
641
679
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
642
680
  if self.sample_weight_col:
643
681
  output_df_columns_set -= set(self.sample_weight_col)
682
+
644
683
  # if the dimension of inferred output column names is correct; use it
645
684
  if len(expected_output_cols_list) == len(output_df_columns_set):
646
- return expected_output_cols_list
685
+ return expected_output_cols_list, output_df_pd
647
686
  # otherwise, use the sklearn estimator's output
648
687
  else:
649
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
688
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
689
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
650
690
 
651
691
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
652
692
  @telemetry.send_api_usage_telemetry(
@@ -692,7 +732,7 @@ class IsolationForest(BaseTransformer):
692
732
  drop_input_cols=self._drop_input_cols,
693
733
  expected_output_cols_type="float",
694
734
  )
695
- expected_output_cols = self._align_expected_output_names(
735
+ expected_output_cols, _ = self._align_expected_output(
696
736
  inference_method, dataset, expected_output_cols, output_cols_prefix
697
737
  )
698
738
 
@@ -758,7 +798,7 @@ class IsolationForest(BaseTransformer):
758
798
  drop_input_cols=self._drop_input_cols,
759
799
  expected_output_cols_type="float",
760
800
  )
761
- expected_output_cols = self._align_expected_output_names(
801
+ expected_output_cols, _ = self._align_expected_output(
762
802
  inference_method, dataset, expected_output_cols, output_cols_prefix
763
803
  )
764
804
  elif isinstance(dataset, pd.DataFrame):
@@ -823,7 +863,7 @@ class IsolationForest(BaseTransformer):
823
863
  drop_input_cols=self._drop_input_cols,
824
864
  expected_output_cols_type="float",
825
865
  )
826
- expected_output_cols = self._align_expected_output_names(
866
+ expected_output_cols, _ = self._align_expected_output(
827
867
  inference_method, dataset, expected_output_cols, output_cols_prefix
828
868
  )
829
869
 
@@ -890,7 +930,7 @@ class IsolationForest(BaseTransformer):
890
930
  drop_input_cols = self._drop_input_cols,
891
931
  expected_output_cols_type="float",
892
932
  )
893
- expected_output_cols = self._align_expected_output_names(
933
+ expected_output_cols, _ = self._align_expected_output(
894
934
  inference_method, dataset, expected_output_cols, output_cols_prefix
895
935
  )
896
936
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -648,12 +645,23 @@ class RandomForestClassifier(BaseTransformer):
648
645
  autogenerated=self._autogenerated,
649
646
  subproject=_SUBPROJECT,
650
647
  )
651
- output_result, fitted_estimator = model_trainer.train_fit_predict(
652
- drop_input_cols=self._drop_input_cols,
653
- expected_output_cols_list=(
654
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
655
- ),
648
+ expected_output_cols = (
649
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
656
650
  )
651
+ if isinstance(dataset, DataFrame):
652
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
653
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
654
+ )
655
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
656
+ drop_input_cols=self._drop_input_cols,
657
+ expected_output_cols_list=expected_output_cols,
658
+ example_output_pd_df=example_output_pd_df,
659
+ )
660
+ else:
661
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
662
+ drop_input_cols=self._drop_input_cols,
663
+ expected_output_cols_list=expected_output_cols,
664
+ )
657
665
  self._sklearn_object = fitted_estimator
658
666
  self._is_fitted = True
659
667
  return output_result
@@ -676,6 +684,7 @@ class RandomForestClassifier(BaseTransformer):
676
684
  """
677
685
  self._infer_input_output_cols(dataset)
678
686
  super()._check_dataset_type(dataset)
687
+
679
688
  model_trainer = ModelTrainerBuilder.build_fit_transform(
680
689
  estimator=self._sklearn_object,
681
690
  dataset=dataset,
@@ -732,12 +741,41 @@ class RandomForestClassifier(BaseTransformer):
732
741
 
733
742
  return rv
734
743
 
735
- def _align_expected_output_names(
736
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
737
- ) -> List[str]:
744
+ def _align_expected_output(
745
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
746
+ ) -> Tuple[List[str], pd.DataFrame]:
747
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
748
+ and output dataframe with 1 line.
749
+ If the method is fit_predict, run 2 lines of data.
750
+ """
738
751
  # in case the inferred output column names dimension is different
739
752
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
740
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
753
+
754
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
755
+ # so change the minimum of number of rows to 2
756
+ num_examples = 2
757
+ statement_params = telemetry.get_function_usage_statement_params(
758
+ project=_PROJECT,
759
+ subproject=_SUBPROJECT,
760
+ function_name=telemetry.get_statement_params_full_func_name(
761
+ inspect.currentframe(), RandomForestClassifier.__class__.__name__
762
+ ),
763
+ api_calls=[Session.call],
764
+ custom_tags={"autogen": True} if self._autogenerated else None,
765
+ )
766
+ if output_cols_prefix == "fit_predict_":
767
+ if hasattr(self._sklearn_object, "n_clusters"):
768
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
769
+ num_examples = self._sklearn_object.n_clusters
770
+ elif hasattr(self._sklearn_object, "min_samples"):
771
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
772
+ num_examples = self._sklearn_object.min_samples
773
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
774
+ # LocalOutlierFactor expects n_neighbors <= n_samples
775
+ num_examples = self._sklearn_object.n_neighbors
776
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
777
+ else:
778
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
741
779
 
742
780
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
743
781
  # seen during the fit.
@@ -749,12 +787,14 @@ class RandomForestClassifier(BaseTransformer):
749
787
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
750
788
  if self.sample_weight_col:
751
789
  output_df_columns_set -= set(self.sample_weight_col)
790
+
752
791
  # if the dimension of inferred output column names is correct; use it
753
792
  if len(expected_output_cols_list) == len(output_df_columns_set):
754
- return expected_output_cols_list
793
+ return expected_output_cols_list, output_df_pd
755
794
  # otherwise, use the sklearn estimator's output
756
795
  else:
757
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
796
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
797
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
758
798
 
759
799
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
760
800
  @telemetry.send_api_usage_telemetry(
@@ -802,7 +842,7 @@ class RandomForestClassifier(BaseTransformer):
802
842
  drop_input_cols=self._drop_input_cols,
803
843
  expected_output_cols_type="float",
804
844
  )
805
- expected_output_cols = self._align_expected_output_names(
845
+ expected_output_cols, _ = self._align_expected_output(
806
846
  inference_method, dataset, expected_output_cols, output_cols_prefix
807
847
  )
808
848
 
@@ -870,7 +910,7 @@ class RandomForestClassifier(BaseTransformer):
870
910
  drop_input_cols=self._drop_input_cols,
871
911
  expected_output_cols_type="float",
872
912
  )
873
- expected_output_cols = self._align_expected_output_names(
913
+ expected_output_cols, _ = self._align_expected_output(
874
914
  inference_method, dataset, expected_output_cols, output_cols_prefix
875
915
  )
876
916
  elif isinstance(dataset, pd.DataFrame):
@@ -933,7 +973,7 @@ class RandomForestClassifier(BaseTransformer):
933
973
  drop_input_cols=self._drop_input_cols,
934
974
  expected_output_cols_type="float",
935
975
  )
936
- expected_output_cols = self._align_expected_output_names(
976
+ expected_output_cols, _ = self._align_expected_output(
937
977
  inference_method, dataset, expected_output_cols, output_cols_prefix
938
978
  )
939
979
 
@@ -998,7 +1038,7 @@ class RandomForestClassifier(BaseTransformer):
998
1038
  drop_input_cols = self._drop_input_cols,
999
1039
  expected_output_cols_type="float",
1000
1040
  )
1001
- expected_output_cols = self._align_expected_output_names(
1041
+ expected_output_cols, _ = self._align_expected_output(
1002
1042
  inference_method, dataset, expected_output_cols, output_cols_prefix
1003
1043
  )
1004
1044