snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -536,12 +533,23 @@ class SVR(BaseTransformer):
536
533
  autogenerated=self._autogenerated,
537
534
  subproject=_SUBPROJECT,
538
535
  )
539
- output_result, fitted_estimator = model_trainer.train_fit_predict(
540
- drop_input_cols=self._drop_input_cols,
541
- expected_output_cols_list=(
542
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
543
- ),
536
+ expected_output_cols = (
537
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
544
538
  )
539
+ if isinstance(dataset, DataFrame):
540
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
541
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
542
+ )
543
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
544
+ drop_input_cols=self._drop_input_cols,
545
+ expected_output_cols_list=expected_output_cols,
546
+ example_output_pd_df=example_output_pd_df,
547
+ )
548
+ else:
549
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
550
+ drop_input_cols=self._drop_input_cols,
551
+ expected_output_cols_list=expected_output_cols,
552
+ )
545
553
  self._sklearn_object = fitted_estimator
546
554
  self._is_fitted = True
547
555
  return output_result
@@ -564,6 +572,7 @@ class SVR(BaseTransformer):
564
572
  """
565
573
  self._infer_input_output_cols(dataset)
566
574
  super()._check_dataset_type(dataset)
575
+
567
576
  model_trainer = ModelTrainerBuilder.build_fit_transform(
568
577
  estimator=self._sklearn_object,
569
578
  dataset=dataset,
@@ -620,12 +629,41 @@ class SVR(BaseTransformer):
620
629
 
621
630
  return rv
622
631
 
623
- def _align_expected_output_names(
624
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
625
- ) -> List[str]:
632
+ def _align_expected_output(
633
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
634
+ ) -> Tuple[List[str], pd.DataFrame]:
635
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
636
+ and output dataframe with 1 line.
637
+ If the method is fit_predict, run 2 lines of data.
638
+ """
626
639
  # in case the inferred output column names dimension is different
627
640
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
628
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
641
+
642
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
643
+ # so change the minimum of number of rows to 2
644
+ num_examples = 2
645
+ statement_params = telemetry.get_function_usage_statement_params(
646
+ project=_PROJECT,
647
+ subproject=_SUBPROJECT,
648
+ function_name=telemetry.get_statement_params_full_func_name(
649
+ inspect.currentframe(), SVR.__class__.__name__
650
+ ),
651
+ api_calls=[Session.call],
652
+ custom_tags={"autogen": True} if self._autogenerated else None,
653
+ )
654
+ if output_cols_prefix == "fit_predict_":
655
+ if hasattr(self._sklearn_object, "n_clusters"):
656
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
657
+ num_examples = self._sklearn_object.n_clusters
658
+ elif hasattr(self._sklearn_object, "min_samples"):
659
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
660
+ num_examples = self._sklearn_object.min_samples
661
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
662
+ # LocalOutlierFactor expects n_neighbors <= n_samples
663
+ num_examples = self._sklearn_object.n_neighbors
664
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
665
+ else:
666
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
629
667
 
630
668
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
631
669
  # seen during the fit.
@@ -637,12 +675,14 @@ class SVR(BaseTransformer):
637
675
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
638
676
  if self.sample_weight_col:
639
677
  output_df_columns_set -= set(self.sample_weight_col)
678
+
640
679
  # if the dimension of inferred output column names is correct; use it
641
680
  if len(expected_output_cols_list) == len(output_df_columns_set):
642
- return expected_output_cols_list
681
+ return expected_output_cols_list, output_df_pd
643
682
  # otherwise, use the sklearn estimator's output
644
683
  else:
645
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
684
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
685
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
646
686
 
647
687
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
648
688
  @telemetry.send_api_usage_telemetry(
@@ -688,7 +728,7 @@ class SVR(BaseTransformer):
688
728
  drop_input_cols=self._drop_input_cols,
689
729
  expected_output_cols_type="float",
690
730
  )
691
- expected_output_cols = self._align_expected_output_names(
731
+ expected_output_cols, _ = self._align_expected_output(
692
732
  inference_method, dataset, expected_output_cols, output_cols_prefix
693
733
  )
694
734
 
@@ -754,7 +794,7 @@ class SVR(BaseTransformer):
754
794
  drop_input_cols=self._drop_input_cols,
755
795
  expected_output_cols_type="float",
756
796
  )
757
- expected_output_cols = self._align_expected_output_names(
797
+ expected_output_cols, _ = self._align_expected_output(
758
798
  inference_method, dataset, expected_output_cols, output_cols_prefix
759
799
  )
760
800
  elif isinstance(dataset, pd.DataFrame):
@@ -817,7 +857,7 @@ class SVR(BaseTransformer):
817
857
  drop_input_cols=self._drop_input_cols,
818
858
  expected_output_cols_type="float",
819
859
  )
820
- expected_output_cols = self._align_expected_output_names(
860
+ expected_output_cols, _ = self._align_expected_output(
821
861
  inference_method, dataset, expected_output_cols, output_cols_prefix
822
862
  )
823
863
 
@@ -882,7 +922,7 @@ class SVR(BaseTransformer):
882
922
  drop_input_cols = self._drop_input_cols,
883
923
  expected_output_cols_type="float",
884
924
  )
885
- expected_output_cols = self._align_expected_output_names(
925
+ expected_output_cols, _ = self._align_expected_output(
886
926
  inference_method, dataset, expected_output_cols, output_cols_prefix
887
927
  )
888
928
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -603,12 +600,23 @@ class DecisionTreeClassifier(BaseTransformer):
603
600
  autogenerated=self._autogenerated,
604
601
  subproject=_SUBPROJECT,
605
602
  )
606
- output_result, fitted_estimator = model_trainer.train_fit_predict(
607
- drop_input_cols=self._drop_input_cols,
608
- expected_output_cols_list=(
609
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
610
- ),
603
+ expected_output_cols = (
604
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
611
605
  )
606
+ if isinstance(dataset, DataFrame):
607
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
608
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
609
+ )
610
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
611
+ drop_input_cols=self._drop_input_cols,
612
+ expected_output_cols_list=expected_output_cols,
613
+ example_output_pd_df=example_output_pd_df,
614
+ )
615
+ else:
616
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
617
+ drop_input_cols=self._drop_input_cols,
618
+ expected_output_cols_list=expected_output_cols,
619
+ )
612
620
  self._sklearn_object = fitted_estimator
613
621
  self._is_fitted = True
614
622
  return output_result
@@ -631,6 +639,7 @@ class DecisionTreeClassifier(BaseTransformer):
631
639
  """
632
640
  self._infer_input_output_cols(dataset)
633
641
  super()._check_dataset_type(dataset)
642
+
634
643
  model_trainer = ModelTrainerBuilder.build_fit_transform(
635
644
  estimator=self._sklearn_object,
636
645
  dataset=dataset,
@@ -687,12 +696,41 @@ class DecisionTreeClassifier(BaseTransformer):
687
696
 
688
697
  return rv
689
698
 
690
- def _align_expected_output_names(
691
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
692
- ) -> List[str]:
699
+ def _align_expected_output(
700
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
701
+ ) -> Tuple[List[str], pd.DataFrame]:
702
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
703
+ and output dataframe with 1 line.
704
+ If the method is fit_predict, run 2 lines of data.
705
+ """
693
706
  # in case the inferred output column names dimension is different
694
707
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
695
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
708
+
709
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
710
+ # so change the minimum of number of rows to 2
711
+ num_examples = 2
712
+ statement_params = telemetry.get_function_usage_statement_params(
713
+ project=_PROJECT,
714
+ subproject=_SUBPROJECT,
715
+ function_name=telemetry.get_statement_params_full_func_name(
716
+ inspect.currentframe(), DecisionTreeClassifier.__class__.__name__
717
+ ),
718
+ api_calls=[Session.call],
719
+ custom_tags={"autogen": True} if self._autogenerated else None,
720
+ )
721
+ if output_cols_prefix == "fit_predict_":
722
+ if hasattr(self._sklearn_object, "n_clusters"):
723
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
724
+ num_examples = self._sklearn_object.n_clusters
725
+ elif hasattr(self._sklearn_object, "min_samples"):
726
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
727
+ num_examples = self._sklearn_object.min_samples
728
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
729
+ # LocalOutlierFactor expects n_neighbors <= n_samples
730
+ num_examples = self._sklearn_object.n_neighbors
731
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
732
+ else:
733
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
696
734
 
697
735
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
698
736
  # seen during the fit.
@@ -704,12 +742,14 @@ class DecisionTreeClassifier(BaseTransformer):
704
742
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
705
743
  if self.sample_weight_col:
706
744
  output_df_columns_set -= set(self.sample_weight_col)
745
+
707
746
  # if the dimension of inferred output column names is correct; use it
708
747
  if len(expected_output_cols_list) == len(output_df_columns_set):
709
- return expected_output_cols_list
748
+ return expected_output_cols_list, output_df_pd
710
749
  # otherwise, use the sklearn estimator's output
711
750
  else:
712
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
751
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
752
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
713
753
 
714
754
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
715
755
  @telemetry.send_api_usage_telemetry(
@@ -757,7 +797,7 @@ class DecisionTreeClassifier(BaseTransformer):
757
797
  drop_input_cols=self._drop_input_cols,
758
798
  expected_output_cols_type="float",
759
799
  )
760
- expected_output_cols = self._align_expected_output_names(
800
+ expected_output_cols, _ = self._align_expected_output(
761
801
  inference_method, dataset, expected_output_cols, output_cols_prefix
762
802
  )
763
803
 
@@ -825,7 +865,7 @@ class DecisionTreeClassifier(BaseTransformer):
825
865
  drop_input_cols=self._drop_input_cols,
826
866
  expected_output_cols_type="float",
827
867
  )
828
- expected_output_cols = self._align_expected_output_names(
868
+ expected_output_cols, _ = self._align_expected_output(
829
869
  inference_method, dataset, expected_output_cols, output_cols_prefix
830
870
  )
831
871
  elif isinstance(dataset, pd.DataFrame):
@@ -888,7 +928,7 @@ class DecisionTreeClassifier(BaseTransformer):
888
928
  drop_input_cols=self._drop_input_cols,
889
929
  expected_output_cols_type="float",
890
930
  )
891
- expected_output_cols = self._align_expected_output_names(
931
+ expected_output_cols, _ = self._align_expected_output(
892
932
  inference_method, dataset, expected_output_cols, output_cols_prefix
893
933
  )
894
934
 
@@ -953,7 +993,7 @@ class DecisionTreeClassifier(BaseTransformer):
953
993
  drop_input_cols = self._drop_input_cols,
954
994
  expected_output_cols_type="float",
955
995
  )
956
- expected_output_cols = self._align_expected_output_names(
996
+ expected_output_cols, _ = self._align_expected_output(
957
997
  inference_method, dataset, expected_output_cols, output_cols_prefix
958
998
  )
959
999
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -585,12 +582,23 @@ class DecisionTreeRegressor(BaseTransformer):
585
582
  autogenerated=self._autogenerated,
586
583
  subproject=_SUBPROJECT,
587
584
  )
588
- output_result, fitted_estimator = model_trainer.train_fit_predict(
589
- drop_input_cols=self._drop_input_cols,
590
- expected_output_cols_list=(
591
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
592
- ),
585
+ expected_output_cols = (
586
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
593
587
  )
588
+ if isinstance(dataset, DataFrame):
589
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
590
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
591
+ )
592
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
593
+ drop_input_cols=self._drop_input_cols,
594
+ expected_output_cols_list=expected_output_cols,
595
+ example_output_pd_df=example_output_pd_df,
596
+ )
597
+ else:
598
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
599
+ drop_input_cols=self._drop_input_cols,
600
+ expected_output_cols_list=expected_output_cols,
601
+ )
594
602
  self._sklearn_object = fitted_estimator
595
603
  self._is_fitted = True
596
604
  return output_result
@@ -613,6 +621,7 @@ class DecisionTreeRegressor(BaseTransformer):
613
621
  """
614
622
  self._infer_input_output_cols(dataset)
615
623
  super()._check_dataset_type(dataset)
624
+
616
625
  model_trainer = ModelTrainerBuilder.build_fit_transform(
617
626
  estimator=self._sklearn_object,
618
627
  dataset=dataset,
@@ -669,12 +678,41 @@ class DecisionTreeRegressor(BaseTransformer):
669
678
 
670
679
  return rv
671
680
 
672
- def _align_expected_output_names(
673
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
674
- ) -> List[str]:
681
+ def _align_expected_output(
682
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
683
+ ) -> Tuple[List[str], pd.DataFrame]:
684
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
685
+ and output dataframe with 1 line.
686
+ If the method is fit_predict, run 2 lines of data.
687
+ """
675
688
  # in case the inferred output column names dimension is different
676
689
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
677
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
690
+
691
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
692
+ # so change the minimum of number of rows to 2
693
+ num_examples = 2
694
+ statement_params = telemetry.get_function_usage_statement_params(
695
+ project=_PROJECT,
696
+ subproject=_SUBPROJECT,
697
+ function_name=telemetry.get_statement_params_full_func_name(
698
+ inspect.currentframe(), DecisionTreeRegressor.__class__.__name__
699
+ ),
700
+ api_calls=[Session.call],
701
+ custom_tags={"autogen": True} if self._autogenerated else None,
702
+ )
703
+ if output_cols_prefix == "fit_predict_":
704
+ if hasattr(self._sklearn_object, "n_clusters"):
705
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
706
+ num_examples = self._sklearn_object.n_clusters
707
+ elif hasattr(self._sklearn_object, "min_samples"):
708
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
709
+ num_examples = self._sklearn_object.min_samples
710
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
711
+ # LocalOutlierFactor expects n_neighbors <= n_samples
712
+ num_examples = self._sklearn_object.n_neighbors
713
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
714
+ else:
715
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
678
716
 
679
717
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
680
718
  # seen during the fit.
@@ -686,12 +724,14 @@ class DecisionTreeRegressor(BaseTransformer):
686
724
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
687
725
  if self.sample_weight_col:
688
726
  output_df_columns_set -= set(self.sample_weight_col)
727
+
689
728
  # if the dimension of inferred output column names is correct; use it
690
729
  if len(expected_output_cols_list) == len(output_df_columns_set):
691
- return expected_output_cols_list
730
+ return expected_output_cols_list, output_df_pd
692
731
  # otherwise, use the sklearn estimator's output
693
732
  else:
694
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
733
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
734
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
695
735
 
696
736
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
697
737
  @telemetry.send_api_usage_telemetry(
@@ -737,7 +777,7 @@ class DecisionTreeRegressor(BaseTransformer):
737
777
  drop_input_cols=self._drop_input_cols,
738
778
  expected_output_cols_type="float",
739
779
  )
740
- expected_output_cols = self._align_expected_output_names(
780
+ expected_output_cols, _ = self._align_expected_output(
741
781
  inference_method, dataset, expected_output_cols, output_cols_prefix
742
782
  )
743
783
 
@@ -803,7 +843,7 @@ class DecisionTreeRegressor(BaseTransformer):
803
843
  drop_input_cols=self._drop_input_cols,
804
844
  expected_output_cols_type="float",
805
845
  )
806
- expected_output_cols = self._align_expected_output_names(
846
+ expected_output_cols, _ = self._align_expected_output(
807
847
  inference_method, dataset, expected_output_cols, output_cols_prefix
808
848
  )
809
849
  elif isinstance(dataset, pd.DataFrame):
@@ -866,7 +906,7 @@ class DecisionTreeRegressor(BaseTransformer):
866
906
  drop_input_cols=self._drop_input_cols,
867
907
  expected_output_cols_type="float",
868
908
  )
869
- expected_output_cols = self._align_expected_output_names(
909
+ expected_output_cols, _ = self._align_expected_output(
870
910
  inference_method, dataset, expected_output_cols, output_cols_prefix
871
911
  )
872
912
 
@@ -931,7 +971,7 @@ class DecisionTreeRegressor(BaseTransformer):
931
971
  drop_input_cols = self._drop_input_cols,
932
972
  expected_output_cols_type="float",
933
973
  )
934
- expected_output_cols = self._align_expected_output_names(
974
+ expected_output_cols, _ = self._align_expected_output(
935
975
  inference_method, dataset, expected_output_cols, output_cols_prefix
936
976
  )
937
977
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -595,12 +592,23 @@ class ExtraTreeClassifier(BaseTransformer):
595
592
  autogenerated=self._autogenerated,
596
593
  subproject=_SUBPROJECT,
597
594
  )
598
- output_result, fitted_estimator = model_trainer.train_fit_predict(
599
- drop_input_cols=self._drop_input_cols,
600
- expected_output_cols_list=(
601
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
602
- ),
595
+ expected_output_cols = (
596
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
603
597
  )
598
+ if isinstance(dataset, DataFrame):
599
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
600
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
601
+ )
602
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
603
+ drop_input_cols=self._drop_input_cols,
604
+ expected_output_cols_list=expected_output_cols,
605
+ example_output_pd_df=example_output_pd_df,
606
+ )
607
+ else:
608
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
609
+ drop_input_cols=self._drop_input_cols,
610
+ expected_output_cols_list=expected_output_cols,
611
+ )
604
612
  self._sklearn_object = fitted_estimator
605
613
  self._is_fitted = True
606
614
  return output_result
@@ -623,6 +631,7 @@ class ExtraTreeClassifier(BaseTransformer):
623
631
  """
624
632
  self._infer_input_output_cols(dataset)
625
633
  super()._check_dataset_type(dataset)
634
+
626
635
  model_trainer = ModelTrainerBuilder.build_fit_transform(
627
636
  estimator=self._sklearn_object,
628
637
  dataset=dataset,
@@ -679,12 +688,41 @@ class ExtraTreeClassifier(BaseTransformer):
679
688
 
680
689
  return rv
681
690
 
682
- def _align_expected_output_names(
683
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
684
- ) -> List[str]:
691
+ def _align_expected_output(
692
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
693
+ ) -> Tuple[List[str], pd.DataFrame]:
694
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
695
+ and output dataframe with 1 line.
696
+ If the method is fit_predict, run 2 lines of data.
697
+ """
685
698
  # in case the inferred output column names dimension is different
686
699
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
687
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
700
+
701
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
702
+ # so change the minimum of number of rows to 2
703
+ num_examples = 2
704
+ statement_params = telemetry.get_function_usage_statement_params(
705
+ project=_PROJECT,
706
+ subproject=_SUBPROJECT,
707
+ function_name=telemetry.get_statement_params_full_func_name(
708
+ inspect.currentframe(), ExtraTreeClassifier.__class__.__name__
709
+ ),
710
+ api_calls=[Session.call],
711
+ custom_tags={"autogen": True} if self._autogenerated else None,
712
+ )
713
+ if output_cols_prefix == "fit_predict_":
714
+ if hasattr(self._sklearn_object, "n_clusters"):
715
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
716
+ num_examples = self._sklearn_object.n_clusters
717
+ elif hasattr(self._sklearn_object, "min_samples"):
718
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
719
+ num_examples = self._sklearn_object.min_samples
720
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
721
+ # LocalOutlierFactor expects n_neighbors <= n_samples
722
+ num_examples = self._sklearn_object.n_neighbors
723
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
724
+ else:
725
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
688
726
 
689
727
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
690
728
  # seen during the fit.
@@ -696,12 +734,14 @@ class ExtraTreeClassifier(BaseTransformer):
696
734
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
697
735
  if self.sample_weight_col:
698
736
  output_df_columns_set -= set(self.sample_weight_col)
737
+
699
738
  # if the dimension of inferred output column names is correct; use it
700
739
  if len(expected_output_cols_list) == len(output_df_columns_set):
701
- return expected_output_cols_list
740
+ return expected_output_cols_list, output_df_pd
702
741
  # otherwise, use the sklearn estimator's output
703
742
  else:
704
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
743
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
744
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
705
745
 
706
746
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
707
747
  @telemetry.send_api_usage_telemetry(
@@ -749,7 +789,7 @@ class ExtraTreeClassifier(BaseTransformer):
749
789
  drop_input_cols=self._drop_input_cols,
750
790
  expected_output_cols_type="float",
751
791
  )
752
- expected_output_cols = self._align_expected_output_names(
792
+ expected_output_cols, _ = self._align_expected_output(
753
793
  inference_method, dataset, expected_output_cols, output_cols_prefix
754
794
  )
755
795
 
@@ -817,7 +857,7 @@ class ExtraTreeClassifier(BaseTransformer):
817
857
  drop_input_cols=self._drop_input_cols,
818
858
  expected_output_cols_type="float",
819
859
  )
820
- expected_output_cols = self._align_expected_output_names(
860
+ expected_output_cols, _ = self._align_expected_output(
821
861
  inference_method, dataset, expected_output_cols, output_cols_prefix
822
862
  )
823
863
  elif isinstance(dataset, pd.DataFrame):
@@ -880,7 +920,7 @@ class ExtraTreeClassifier(BaseTransformer):
880
920
  drop_input_cols=self._drop_input_cols,
881
921
  expected_output_cols_type="float",
882
922
  )
883
- expected_output_cols = self._align_expected_output_names(
923
+ expected_output_cols, _ = self._align_expected_output(
884
924
  inference_method, dataset, expected_output_cols, output_cols_prefix
885
925
  )
886
926
 
@@ -945,7 +985,7 @@ class ExtraTreeClassifier(BaseTransformer):
945
985
  drop_input_cols = self._drop_input_cols,
946
986
  expected_output_cols_type="float",
947
987
  )
948
- expected_output_cols = self._align_expected_output_names(
988
+ expected_output_cols, _ = self._align_expected_output(
949
989
  inference_method, dataset, expected_output_cols, output_cols_prefix
950
990
  )
951
991