snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284) hide show
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -25,12 +23,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
25
23
  from snowflake.ml._internal import telemetry
26
24
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
27
25
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
28
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
26
+ from snowflake.ml._internal.utils import identifier
29
27
  from snowflake.snowpark import DataFrame, Session
30
28
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
31
29
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
32
30
  from snowflake.ml.modeling._internal.transformer_protocols import (
33
- ModelTransformHandlers,
34
31
  BatchInferenceKwargsTypedDict,
35
32
  ScoreKwargsTypedDict
36
33
  )
@@ -479,12 +476,23 @@ class SelectKBest(BaseTransformer):
479
476
  autogenerated=self._autogenerated,
480
477
  subproject=_SUBPROJECT,
481
478
  )
482
- output_result, fitted_estimator = model_trainer.train_fit_predict(
483
- drop_input_cols=self._drop_input_cols,
484
- expected_output_cols_list=(
485
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
486
- ),
479
+ expected_output_cols = (
480
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
487
481
  )
482
+ if isinstance(dataset, DataFrame):
483
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
484
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
485
+ )
486
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
487
+ drop_input_cols=self._drop_input_cols,
488
+ expected_output_cols_list=expected_output_cols,
489
+ example_output_pd_df=example_output_pd_df,
490
+ )
491
+ else:
492
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
493
+ drop_input_cols=self._drop_input_cols,
494
+ expected_output_cols_list=expected_output_cols,
495
+ )
488
496
  self._sklearn_object = fitted_estimator
489
497
  self._is_fitted = True
490
498
  return output_result
@@ -509,6 +517,7 @@ class SelectKBest(BaseTransformer):
509
517
  """
510
518
  self._infer_input_output_cols(dataset)
511
519
  super()._check_dataset_type(dataset)
520
+
512
521
  model_trainer = ModelTrainerBuilder.build_fit_transform(
513
522
  estimator=self._sklearn_object,
514
523
  dataset=dataset,
@@ -565,12 +574,41 @@ class SelectKBest(BaseTransformer):
565
574
 
566
575
  return rv
567
576
 
568
- def _align_expected_output_names(
569
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
570
- ) -> List[str]:
577
+ def _align_expected_output(
578
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
579
+ ) -> Tuple[List[str], pd.DataFrame]:
580
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
581
+ and output dataframe with 1 line.
582
+ If the method is fit_predict, run 2 lines of data.
583
+ """
571
584
  # in case the inferred output column names dimension is different
572
585
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
573
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
586
+
587
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
588
+ # so change the minimum of number of rows to 2
589
+ num_examples = 2
590
+ statement_params = telemetry.get_function_usage_statement_params(
591
+ project=_PROJECT,
592
+ subproject=_SUBPROJECT,
593
+ function_name=telemetry.get_statement_params_full_func_name(
594
+ inspect.currentframe(), SelectKBest.__class__.__name__
595
+ ),
596
+ api_calls=[Session.call],
597
+ custom_tags={"autogen": True} if self._autogenerated else None,
598
+ )
599
+ if output_cols_prefix == "fit_predict_":
600
+ if hasattr(self._sklearn_object, "n_clusters"):
601
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
602
+ num_examples = self._sklearn_object.n_clusters
603
+ elif hasattr(self._sklearn_object, "min_samples"):
604
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
605
+ num_examples = self._sklearn_object.min_samples
606
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
607
+ # LocalOutlierFactor expects n_neighbors <= n_samples
608
+ num_examples = self._sklearn_object.n_neighbors
609
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
610
+ else:
611
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
574
612
 
575
613
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
576
614
  # seen during the fit.
@@ -582,12 +620,14 @@ class SelectKBest(BaseTransformer):
582
620
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
583
621
  if self.sample_weight_col:
584
622
  output_df_columns_set -= set(self.sample_weight_col)
623
+
585
624
  # if the dimension of inferred output column names is correct; use it
586
625
  if len(expected_output_cols_list) == len(output_df_columns_set):
587
- return expected_output_cols_list
626
+ return expected_output_cols_list, output_df_pd
588
627
  # otherwise, use the sklearn estimator's output
589
628
  else:
590
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
629
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
630
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
591
631
 
592
632
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
593
633
  @telemetry.send_api_usage_telemetry(
@@ -633,7 +673,7 @@ class SelectKBest(BaseTransformer):
633
673
  drop_input_cols=self._drop_input_cols,
634
674
  expected_output_cols_type="float",
635
675
  )
636
- expected_output_cols = self._align_expected_output_names(
676
+ expected_output_cols, _ = self._align_expected_output(
637
677
  inference_method, dataset, expected_output_cols, output_cols_prefix
638
678
  )
639
679
 
@@ -699,7 +739,7 @@ class SelectKBest(BaseTransformer):
699
739
  drop_input_cols=self._drop_input_cols,
700
740
  expected_output_cols_type="float",
701
741
  )
702
- expected_output_cols = self._align_expected_output_names(
742
+ expected_output_cols, _ = self._align_expected_output(
703
743
  inference_method, dataset, expected_output_cols, output_cols_prefix
704
744
  )
705
745
  elif isinstance(dataset, pd.DataFrame):
@@ -762,7 +802,7 @@ class SelectKBest(BaseTransformer):
762
802
  drop_input_cols=self._drop_input_cols,
763
803
  expected_output_cols_type="float",
764
804
  )
765
- expected_output_cols = self._align_expected_output_names(
805
+ expected_output_cols, _ = self._align_expected_output(
766
806
  inference_method, dataset, expected_output_cols, output_cols_prefix
767
807
  )
768
808
 
@@ -827,7 +867,7 @@ class SelectKBest(BaseTransformer):
827
867
  drop_input_cols = self._drop_input_cols,
828
868
  expected_output_cols_type="float",
829
869
  )
830
- expected_output_cols = self._align_expected_output_names(
870
+ expected_output_cols, _ = self._align_expected_output(
831
871
  inference_method, dataset, expected_output_cols, output_cols_prefix
832
872
  )
833
873
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -25,12 +23,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
25
23
  from snowflake.ml._internal import telemetry
26
24
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
27
25
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
28
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
26
+ from snowflake.ml._internal.utils import identifier
29
27
  from snowflake.snowpark import DataFrame, Session
30
28
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
31
29
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
32
30
  from snowflake.ml.modeling._internal.transformer_protocols import (
33
- ModelTransformHandlers,
34
31
  BatchInferenceKwargsTypedDict,
35
32
  ScoreKwargsTypedDict
36
33
  )
@@ -478,12 +475,23 @@ class SelectPercentile(BaseTransformer):
478
475
  autogenerated=self._autogenerated,
479
476
  subproject=_SUBPROJECT,
480
477
  )
481
- output_result, fitted_estimator = model_trainer.train_fit_predict(
482
- drop_input_cols=self._drop_input_cols,
483
- expected_output_cols_list=(
484
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
485
- ),
478
+ expected_output_cols = (
479
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
486
480
  )
481
+ if isinstance(dataset, DataFrame):
482
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
483
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
484
+ )
485
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
486
+ drop_input_cols=self._drop_input_cols,
487
+ expected_output_cols_list=expected_output_cols,
488
+ example_output_pd_df=example_output_pd_df,
489
+ )
490
+ else:
491
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
492
+ drop_input_cols=self._drop_input_cols,
493
+ expected_output_cols_list=expected_output_cols,
494
+ )
487
495
  self._sklearn_object = fitted_estimator
488
496
  self._is_fitted = True
489
497
  return output_result
@@ -508,6 +516,7 @@ class SelectPercentile(BaseTransformer):
508
516
  """
509
517
  self._infer_input_output_cols(dataset)
510
518
  super()._check_dataset_type(dataset)
519
+
511
520
  model_trainer = ModelTrainerBuilder.build_fit_transform(
512
521
  estimator=self._sklearn_object,
513
522
  dataset=dataset,
@@ -564,12 +573,41 @@ class SelectPercentile(BaseTransformer):
564
573
 
565
574
  return rv
566
575
 
567
- def _align_expected_output_names(
568
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
569
- ) -> List[str]:
576
+ def _align_expected_output(
577
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
578
+ ) -> Tuple[List[str], pd.DataFrame]:
579
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
580
+ and output dataframe with 1 line.
581
+ If the method is fit_predict, run 2 lines of data.
582
+ """
570
583
  # in case the inferred output column names dimension is different
571
584
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
572
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
585
+
586
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
587
+ # so change the minimum of number of rows to 2
588
+ num_examples = 2
589
+ statement_params = telemetry.get_function_usage_statement_params(
590
+ project=_PROJECT,
591
+ subproject=_SUBPROJECT,
592
+ function_name=telemetry.get_statement_params_full_func_name(
593
+ inspect.currentframe(), SelectPercentile.__class__.__name__
594
+ ),
595
+ api_calls=[Session.call],
596
+ custom_tags={"autogen": True} if self._autogenerated else None,
597
+ )
598
+ if output_cols_prefix == "fit_predict_":
599
+ if hasattr(self._sklearn_object, "n_clusters"):
600
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
601
+ num_examples = self._sklearn_object.n_clusters
602
+ elif hasattr(self._sklearn_object, "min_samples"):
603
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
604
+ num_examples = self._sklearn_object.min_samples
605
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
606
+ # LocalOutlierFactor expects n_neighbors <= n_samples
607
+ num_examples = self._sklearn_object.n_neighbors
608
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
609
+ else:
610
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
573
611
 
574
612
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
575
613
  # seen during the fit.
@@ -581,12 +619,14 @@ class SelectPercentile(BaseTransformer):
581
619
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
582
620
  if self.sample_weight_col:
583
621
  output_df_columns_set -= set(self.sample_weight_col)
622
+
584
623
  # if the dimension of inferred output column names is correct; use it
585
624
  if len(expected_output_cols_list) == len(output_df_columns_set):
586
- return expected_output_cols_list
625
+ return expected_output_cols_list, output_df_pd
587
626
  # otherwise, use the sklearn estimator's output
588
627
  else:
589
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
628
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
629
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
590
630
 
591
631
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
592
632
  @telemetry.send_api_usage_telemetry(
@@ -632,7 +672,7 @@ class SelectPercentile(BaseTransformer):
632
672
  drop_input_cols=self._drop_input_cols,
633
673
  expected_output_cols_type="float",
634
674
  )
635
- expected_output_cols = self._align_expected_output_names(
675
+ expected_output_cols, _ = self._align_expected_output(
636
676
  inference_method, dataset, expected_output_cols, output_cols_prefix
637
677
  )
638
678
 
@@ -698,7 +738,7 @@ class SelectPercentile(BaseTransformer):
698
738
  drop_input_cols=self._drop_input_cols,
699
739
  expected_output_cols_type="float",
700
740
  )
701
- expected_output_cols = self._align_expected_output_names(
741
+ expected_output_cols, _ = self._align_expected_output(
702
742
  inference_method, dataset, expected_output_cols, output_cols_prefix
703
743
  )
704
744
  elif isinstance(dataset, pd.DataFrame):
@@ -761,7 +801,7 @@ class SelectPercentile(BaseTransformer):
761
801
  drop_input_cols=self._drop_input_cols,
762
802
  expected_output_cols_type="float",
763
803
  )
764
- expected_output_cols = self._align_expected_output_names(
804
+ expected_output_cols, _ = self._align_expected_output(
765
805
  inference_method, dataset, expected_output_cols, output_cols_prefix
766
806
  )
767
807
 
@@ -826,7 +866,7 @@ class SelectPercentile(BaseTransformer):
826
866
  drop_input_cols = self._drop_input_cols,
827
867
  expected_output_cols_type="float",
828
868
  )
829
- expected_output_cols = self._align_expected_output_names(
869
+ expected_output_cols, _ = self._align_expected_output(
830
870
  inference_method, dataset, expected_output_cols, output_cols_prefix
831
871
  )
832
872
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -538,12 +535,23 @@ class SequentialFeatureSelector(BaseTransformer):
538
535
  autogenerated=self._autogenerated,
539
536
  subproject=_SUBPROJECT,
540
537
  )
541
- output_result, fitted_estimator = model_trainer.train_fit_predict(
542
- drop_input_cols=self._drop_input_cols,
543
- expected_output_cols_list=(
544
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
545
- ),
538
+ expected_output_cols = (
539
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
546
540
  )
541
+ if isinstance(dataset, DataFrame):
542
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
543
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
544
+ )
545
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
546
+ drop_input_cols=self._drop_input_cols,
547
+ expected_output_cols_list=expected_output_cols,
548
+ example_output_pd_df=example_output_pd_df,
549
+ )
550
+ else:
551
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
552
+ drop_input_cols=self._drop_input_cols,
553
+ expected_output_cols_list=expected_output_cols,
554
+ )
547
555
  self._sklearn_object = fitted_estimator
548
556
  self._is_fitted = True
549
557
  return output_result
@@ -568,6 +576,7 @@ class SequentialFeatureSelector(BaseTransformer):
568
576
  """
569
577
  self._infer_input_output_cols(dataset)
570
578
  super()._check_dataset_type(dataset)
579
+
571
580
  model_trainer = ModelTrainerBuilder.build_fit_transform(
572
581
  estimator=self._sklearn_object,
573
582
  dataset=dataset,
@@ -624,12 +633,41 @@ class SequentialFeatureSelector(BaseTransformer):
624
633
 
625
634
  return rv
626
635
 
627
- def _align_expected_output_names(
628
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
629
- ) -> List[str]:
636
+ def _align_expected_output(
637
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
638
+ ) -> Tuple[List[str], pd.DataFrame]:
639
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
640
+ and output dataframe with 1 line.
641
+ If the method is fit_predict, run 2 lines of data.
642
+ """
630
643
  # in case the inferred output column names dimension is different
631
644
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
632
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
645
+
646
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
647
+ # so change the minimum of number of rows to 2
648
+ num_examples = 2
649
+ statement_params = telemetry.get_function_usage_statement_params(
650
+ project=_PROJECT,
651
+ subproject=_SUBPROJECT,
652
+ function_name=telemetry.get_statement_params_full_func_name(
653
+ inspect.currentframe(), SequentialFeatureSelector.__class__.__name__
654
+ ),
655
+ api_calls=[Session.call],
656
+ custom_tags={"autogen": True} if self._autogenerated else None,
657
+ )
658
+ if output_cols_prefix == "fit_predict_":
659
+ if hasattr(self._sklearn_object, "n_clusters"):
660
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
661
+ num_examples = self._sklearn_object.n_clusters
662
+ elif hasattr(self._sklearn_object, "min_samples"):
663
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
664
+ num_examples = self._sklearn_object.min_samples
665
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
666
+ # LocalOutlierFactor expects n_neighbors <= n_samples
667
+ num_examples = self._sklearn_object.n_neighbors
668
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
669
+ else:
670
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
633
671
 
634
672
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
635
673
  # seen during the fit.
@@ -641,12 +679,14 @@ class SequentialFeatureSelector(BaseTransformer):
641
679
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
642
680
  if self.sample_weight_col:
643
681
  output_df_columns_set -= set(self.sample_weight_col)
682
+
644
683
  # if the dimension of inferred output column names is correct; use it
645
684
  if len(expected_output_cols_list) == len(output_df_columns_set):
646
- return expected_output_cols_list
685
+ return expected_output_cols_list, output_df_pd
647
686
  # otherwise, use the sklearn estimator's output
648
687
  else:
649
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
688
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
689
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
650
690
 
651
691
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
652
692
  @telemetry.send_api_usage_telemetry(
@@ -692,7 +732,7 @@ class SequentialFeatureSelector(BaseTransformer):
692
732
  drop_input_cols=self._drop_input_cols,
693
733
  expected_output_cols_type="float",
694
734
  )
695
- expected_output_cols = self._align_expected_output_names(
735
+ expected_output_cols, _ = self._align_expected_output(
696
736
  inference_method, dataset, expected_output_cols, output_cols_prefix
697
737
  )
698
738
 
@@ -758,7 +798,7 @@ class SequentialFeatureSelector(BaseTransformer):
758
798
  drop_input_cols=self._drop_input_cols,
759
799
  expected_output_cols_type="float",
760
800
  )
761
- expected_output_cols = self._align_expected_output_names(
801
+ expected_output_cols, _ = self._align_expected_output(
762
802
  inference_method, dataset, expected_output_cols, output_cols_prefix
763
803
  )
764
804
  elif isinstance(dataset, pd.DataFrame):
@@ -821,7 +861,7 @@ class SequentialFeatureSelector(BaseTransformer):
821
861
  drop_input_cols=self._drop_input_cols,
822
862
  expected_output_cols_type="float",
823
863
  )
824
- expected_output_cols = self._align_expected_output_names(
864
+ expected_output_cols, _ = self._align_expected_output(
825
865
  inference_method, dataset, expected_output_cols, output_cols_prefix
826
866
  )
827
867
 
@@ -886,7 +926,7 @@ class SequentialFeatureSelector(BaseTransformer):
886
926
  drop_input_cols = self._drop_input_cols,
887
927
  expected_output_cols_type="float",
888
928
  )
889
- expected_output_cols = self._align_expected_output_names(
929
+ expected_output_cols, _ = self._align_expected_output(
890
930
  inference_method, dataset, expected_output_cols, output_cols_prefix
891
931
  )
892
932
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -471,12 +468,23 @@ class VarianceThreshold(BaseTransformer):
471
468
  autogenerated=self._autogenerated,
472
469
  subproject=_SUBPROJECT,
473
470
  )
474
- output_result, fitted_estimator = model_trainer.train_fit_predict(
475
- drop_input_cols=self._drop_input_cols,
476
- expected_output_cols_list=(
477
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
478
- ),
471
+ expected_output_cols = (
472
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
479
473
  )
474
+ if isinstance(dataset, DataFrame):
475
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
476
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
477
+ )
478
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
479
+ drop_input_cols=self._drop_input_cols,
480
+ expected_output_cols_list=expected_output_cols,
481
+ example_output_pd_df=example_output_pd_df,
482
+ )
483
+ else:
484
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
485
+ drop_input_cols=self._drop_input_cols,
486
+ expected_output_cols_list=expected_output_cols,
487
+ )
480
488
  self._sklearn_object = fitted_estimator
481
489
  self._is_fitted = True
482
490
  return output_result
@@ -501,6 +509,7 @@ class VarianceThreshold(BaseTransformer):
501
509
  """
502
510
  self._infer_input_output_cols(dataset)
503
511
  super()._check_dataset_type(dataset)
512
+
504
513
  model_trainer = ModelTrainerBuilder.build_fit_transform(
505
514
  estimator=self._sklearn_object,
506
515
  dataset=dataset,
@@ -557,12 +566,41 @@ class VarianceThreshold(BaseTransformer):
557
566
 
558
567
  return rv
559
568
 
560
- def _align_expected_output_names(
561
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
562
- ) -> List[str]:
569
+ def _align_expected_output(
570
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
571
+ ) -> Tuple[List[str], pd.DataFrame]:
572
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
573
+ and output dataframe with 1 line.
574
+ If the method is fit_predict, run 2 lines of data.
575
+ """
563
576
  # in case the inferred output column names dimension is different
564
577
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
565
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
578
+
579
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
580
+ # so change the minimum of number of rows to 2
581
+ num_examples = 2
582
+ statement_params = telemetry.get_function_usage_statement_params(
583
+ project=_PROJECT,
584
+ subproject=_SUBPROJECT,
585
+ function_name=telemetry.get_statement_params_full_func_name(
586
+ inspect.currentframe(), VarianceThreshold.__class__.__name__
587
+ ),
588
+ api_calls=[Session.call],
589
+ custom_tags={"autogen": True} if self._autogenerated else None,
590
+ )
591
+ if output_cols_prefix == "fit_predict_":
592
+ if hasattr(self._sklearn_object, "n_clusters"):
593
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
594
+ num_examples = self._sklearn_object.n_clusters
595
+ elif hasattr(self._sklearn_object, "min_samples"):
596
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
597
+ num_examples = self._sklearn_object.min_samples
598
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
599
+ # LocalOutlierFactor expects n_neighbors <= n_samples
600
+ num_examples = self._sklearn_object.n_neighbors
601
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
602
+ else:
603
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
566
604
 
567
605
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
568
606
  # seen during the fit.
@@ -574,12 +612,14 @@ class VarianceThreshold(BaseTransformer):
574
612
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
575
613
  if self.sample_weight_col:
576
614
  output_df_columns_set -= set(self.sample_weight_col)
615
+
577
616
  # if the dimension of inferred output column names is correct; use it
578
617
  if len(expected_output_cols_list) == len(output_df_columns_set):
579
- return expected_output_cols_list
618
+ return expected_output_cols_list, output_df_pd
580
619
  # otherwise, use the sklearn estimator's output
581
620
  else:
582
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
621
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
622
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
583
623
 
584
624
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
585
625
  @telemetry.send_api_usage_telemetry(
@@ -625,7 +665,7 @@ class VarianceThreshold(BaseTransformer):
625
665
  drop_input_cols=self._drop_input_cols,
626
666
  expected_output_cols_type="float",
627
667
  )
628
- expected_output_cols = self._align_expected_output_names(
668
+ expected_output_cols, _ = self._align_expected_output(
629
669
  inference_method, dataset, expected_output_cols, output_cols_prefix
630
670
  )
631
671
 
@@ -691,7 +731,7 @@ class VarianceThreshold(BaseTransformer):
691
731
  drop_input_cols=self._drop_input_cols,
692
732
  expected_output_cols_type="float",
693
733
  )
694
- expected_output_cols = self._align_expected_output_names(
734
+ expected_output_cols, _ = self._align_expected_output(
695
735
  inference_method, dataset, expected_output_cols, output_cols_prefix
696
736
  )
697
737
  elif isinstance(dataset, pd.DataFrame):
@@ -754,7 +794,7 @@ class VarianceThreshold(BaseTransformer):
754
794
  drop_input_cols=self._drop_input_cols,
755
795
  expected_output_cols_type="float",
756
796
  )
757
- expected_output_cols = self._align_expected_output_names(
797
+ expected_output_cols, _ = self._align_expected_output(
758
798
  inference_method, dataset, expected_output_cols, output_cols_prefix
759
799
  )
760
800
 
@@ -819,7 +859,7 @@ class VarianceThreshold(BaseTransformer):
819
859
  drop_input_cols = self._drop_input_cols,
820
860
  expected_output_cols_type="float",
821
861
  )
822
- expected_output_cols = self._align_expected_output_names(
862
+ expected_output_cols, _ = self._align_expected_output(
823
863
  inference_method, dataset, expected_output_cols, output_cols_prefix
824
864
  )
825
865