snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284) hide show
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -561,12 +558,23 @@ class BisectingKMeans(BaseTransformer):
561
558
  autogenerated=self._autogenerated,
562
559
  subproject=_SUBPROJECT,
563
560
  )
564
- output_result, fitted_estimator = model_trainer.train_fit_predict(
565
- drop_input_cols=self._drop_input_cols,
566
- expected_output_cols_list=(
567
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
568
- ),
561
+ expected_output_cols = (
562
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
569
563
  )
564
+ if isinstance(dataset, DataFrame):
565
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
566
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
567
+ )
568
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
569
+ drop_input_cols=self._drop_input_cols,
570
+ expected_output_cols_list=expected_output_cols,
571
+ example_output_pd_df=example_output_pd_df,
572
+ )
573
+ else:
574
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
575
+ drop_input_cols=self._drop_input_cols,
576
+ expected_output_cols_list=expected_output_cols,
577
+ )
570
578
  self._sklearn_object = fitted_estimator
571
579
  self._is_fitted = True
572
580
  return output_result
@@ -591,6 +599,7 @@ class BisectingKMeans(BaseTransformer):
591
599
  """
592
600
  self._infer_input_output_cols(dataset)
593
601
  super()._check_dataset_type(dataset)
602
+
594
603
  model_trainer = ModelTrainerBuilder.build_fit_transform(
595
604
  estimator=self._sklearn_object,
596
605
  dataset=dataset,
@@ -647,12 +656,41 @@ class BisectingKMeans(BaseTransformer):
647
656
 
648
657
  return rv
649
658
 
650
- def _align_expected_output_names(
651
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
652
- ) -> List[str]:
659
+ def _align_expected_output(
660
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
661
+ ) -> Tuple[List[str], pd.DataFrame]:
662
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
663
+ and output dataframe with 1 line.
664
+ If the method is fit_predict, run 2 lines of data.
665
+ """
653
666
  # in case the inferred output column names dimension is different
654
667
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
655
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
668
+
669
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
670
+ # so change the minimum of number of rows to 2
671
+ num_examples = 2
672
+ statement_params = telemetry.get_function_usage_statement_params(
673
+ project=_PROJECT,
674
+ subproject=_SUBPROJECT,
675
+ function_name=telemetry.get_statement_params_full_func_name(
676
+ inspect.currentframe(), BisectingKMeans.__class__.__name__
677
+ ),
678
+ api_calls=[Session.call],
679
+ custom_tags={"autogen": True} if self._autogenerated else None,
680
+ )
681
+ if output_cols_prefix == "fit_predict_":
682
+ if hasattr(self._sklearn_object, "n_clusters"):
683
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
684
+ num_examples = self._sklearn_object.n_clusters
685
+ elif hasattr(self._sklearn_object, "min_samples"):
686
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
687
+ num_examples = self._sklearn_object.min_samples
688
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
689
+ # LocalOutlierFactor expects n_neighbors <= n_samples
690
+ num_examples = self._sklearn_object.n_neighbors
691
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
692
+ else:
693
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
656
694
 
657
695
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
658
696
  # seen during the fit.
@@ -664,12 +702,14 @@ class BisectingKMeans(BaseTransformer):
664
702
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
665
703
  if self.sample_weight_col:
666
704
  output_df_columns_set -= set(self.sample_weight_col)
705
+
667
706
  # if the dimension of inferred output column names is correct; use it
668
707
  if len(expected_output_cols_list) == len(output_df_columns_set):
669
- return expected_output_cols_list
708
+ return expected_output_cols_list, output_df_pd
670
709
  # otherwise, use the sklearn estimator's output
671
710
  else:
672
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
711
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
712
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
673
713
 
674
714
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
675
715
  @telemetry.send_api_usage_telemetry(
@@ -715,7 +755,7 @@ class BisectingKMeans(BaseTransformer):
715
755
  drop_input_cols=self._drop_input_cols,
716
756
  expected_output_cols_type="float",
717
757
  )
718
- expected_output_cols = self._align_expected_output_names(
758
+ expected_output_cols, _ = self._align_expected_output(
719
759
  inference_method, dataset, expected_output_cols, output_cols_prefix
720
760
  )
721
761
 
@@ -781,7 +821,7 @@ class BisectingKMeans(BaseTransformer):
781
821
  drop_input_cols=self._drop_input_cols,
782
822
  expected_output_cols_type="float",
783
823
  )
784
- expected_output_cols = self._align_expected_output_names(
824
+ expected_output_cols, _ = self._align_expected_output(
785
825
  inference_method, dataset, expected_output_cols, output_cols_prefix
786
826
  )
787
827
  elif isinstance(dataset, pd.DataFrame):
@@ -844,7 +884,7 @@ class BisectingKMeans(BaseTransformer):
844
884
  drop_input_cols=self._drop_input_cols,
845
885
  expected_output_cols_type="float",
846
886
  )
847
- expected_output_cols = self._align_expected_output_names(
887
+ expected_output_cols, _ = self._align_expected_output(
848
888
  inference_method, dataset, expected_output_cols, output_cols_prefix
849
889
  )
850
890
 
@@ -909,7 +949,7 @@ class BisectingKMeans(BaseTransformer):
909
949
  drop_input_cols = self._drop_input_cols,
910
950
  expected_output_cols_type="float",
911
951
  )
912
- expected_output_cols = self._align_expected_output_names(
952
+ expected_output_cols, _ = self._align_expected_output(
913
953
  inference_method, dataset, expected_output_cols, output_cols_prefix
914
954
  )
915
955
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -525,12 +522,23 @@ class DBSCAN(BaseTransformer):
525
522
  autogenerated=self._autogenerated,
526
523
  subproject=_SUBPROJECT,
527
524
  )
528
- output_result, fitted_estimator = model_trainer.train_fit_predict(
529
- drop_input_cols=self._drop_input_cols,
530
- expected_output_cols_list=(
531
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
532
- ),
525
+ expected_output_cols = (
526
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
533
527
  )
528
+ if isinstance(dataset, DataFrame):
529
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
530
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
531
+ )
532
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
533
+ drop_input_cols=self._drop_input_cols,
534
+ expected_output_cols_list=expected_output_cols,
535
+ example_output_pd_df=example_output_pd_df,
536
+ )
537
+ else:
538
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
539
+ drop_input_cols=self._drop_input_cols,
540
+ expected_output_cols_list=expected_output_cols,
541
+ )
534
542
  self._sklearn_object = fitted_estimator
535
543
  self._is_fitted = True
536
544
  return output_result
@@ -553,6 +561,7 @@ class DBSCAN(BaseTransformer):
553
561
  """
554
562
  self._infer_input_output_cols(dataset)
555
563
  super()._check_dataset_type(dataset)
564
+
556
565
  model_trainer = ModelTrainerBuilder.build_fit_transform(
557
566
  estimator=self._sklearn_object,
558
567
  dataset=dataset,
@@ -609,12 +618,41 @@ class DBSCAN(BaseTransformer):
609
618
 
610
619
  return rv
611
620
 
612
- def _align_expected_output_names(
613
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
614
- ) -> List[str]:
621
+ def _align_expected_output(
622
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
623
+ ) -> Tuple[List[str], pd.DataFrame]:
624
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
625
+ and output dataframe with 1 line.
626
+ If the method is fit_predict, run 2 lines of data.
627
+ """
615
628
  # in case the inferred output column names dimension is different
616
629
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
617
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
630
+
631
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
632
+ # so change the minimum of number of rows to 2
633
+ num_examples = 2
634
+ statement_params = telemetry.get_function_usage_statement_params(
635
+ project=_PROJECT,
636
+ subproject=_SUBPROJECT,
637
+ function_name=telemetry.get_statement_params_full_func_name(
638
+ inspect.currentframe(), DBSCAN.__class__.__name__
639
+ ),
640
+ api_calls=[Session.call],
641
+ custom_tags={"autogen": True} if self._autogenerated else None,
642
+ )
643
+ if output_cols_prefix == "fit_predict_":
644
+ if hasattr(self._sklearn_object, "n_clusters"):
645
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
646
+ num_examples = self._sklearn_object.n_clusters
647
+ elif hasattr(self._sklearn_object, "min_samples"):
648
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
649
+ num_examples = self._sklearn_object.min_samples
650
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
651
+ # LocalOutlierFactor expects n_neighbors <= n_samples
652
+ num_examples = self._sklearn_object.n_neighbors
653
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
654
+ else:
655
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
618
656
 
619
657
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
620
658
  # seen during the fit.
@@ -626,12 +664,14 @@ class DBSCAN(BaseTransformer):
626
664
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
627
665
  if self.sample_weight_col:
628
666
  output_df_columns_set -= set(self.sample_weight_col)
667
+
629
668
  # if the dimension of inferred output column names is correct; use it
630
669
  if len(expected_output_cols_list) == len(output_df_columns_set):
631
- return expected_output_cols_list
670
+ return expected_output_cols_list, output_df_pd
632
671
  # otherwise, use the sklearn estimator's output
633
672
  else:
634
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
673
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
674
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
635
675
 
636
676
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
637
677
  @telemetry.send_api_usage_telemetry(
@@ -677,7 +717,7 @@ class DBSCAN(BaseTransformer):
677
717
  drop_input_cols=self._drop_input_cols,
678
718
  expected_output_cols_type="float",
679
719
  )
680
- expected_output_cols = self._align_expected_output_names(
720
+ expected_output_cols, _ = self._align_expected_output(
681
721
  inference_method, dataset, expected_output_cols, output_cols_prefix
682
722
  )
683
723
 
@@ -743,7 +783,7 @@ class DBSCAN(BaseTransformer):
743
783
  drop_input_cols=self._drop_input_cols,
744
784
  expected_output_cols_type="float",
745
785
  )
746
- expected_output_cols = self._align_expected_output_names(
786
+ expected_output_cols, _ = self._align_expected_output(
747
787
  inference_method, dataset, expected_output_cols, output_cols_prefix
748
788
  )
749
789
  elif isinstance(dataset, pd.DataFrame):
@@ -806,7 +846,7 @@ class DBSCAN(BaseTransformer):
806
846
  drop_input_cols=self._drop_input_cols,
807
847
  expected_output_cols_type="float",
808
848
  )
809
- expected_output_cols = self._align_expected_output_names(
849
+ expected_output_cols, _ = self._align_expected_output(
810
850
  inference_method, dataset, expected_output_cols, output_cols_prefix
811
851
  )
812
852
 
@@ -871,7 +911,7 @@ class DBSCAN(BaseTransformer):
871
911
  drop_input_cols = self._drop_input_cols,
872
912
  expected_output_cols_type="float",
873
913
  )
874
- expected_output_cols = self._align_expected_output_names(
914
+ expected_output_cols, _ = self._align_expected_output(
875
915
  inference_method, dataset, expected_output_cols, output_cols_prefix
876
916
  )
877
917
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -559,12 +556,23 @@ class FeatureAgglomeration(BaseTransformer):
559
556
  autogenerated=self._autogenerated,
560
557
  subproject=_SUBPROJECT,
561
558
  )
562
- output_result, fitted_estimator = model_trainer.train_fit_predict(
563
- drop_input_cols=self._drop_input_cols,
564
- expected_output_cols_list=(
565
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
566
- ),
559
+ expected_output_cols = (
560
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
567
561
  )
562
+ if isinstance(dataset, DataFrame):
563
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
564
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
565
+ )
566
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
567
+ drop_input_cols=self._drop_input_cols,
568
+ expected_output_cols_list=expected_output_cols,
569
+ example_output_pd_df=example_output_pd_df,
570
+ )
571
+ else:
572
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
573
+ drop_input_cols=self._drop_input_cols,
574
+ expected_output_cols_list=expected_output_cols,
575
+ )
568
576
  self._sklearn_object = fitted_estimator
569
577
  self._is_fitted = True
570
578
  return output_result
@@ -589,6 +597,7 @@ class FeatureAgglomeration(BaseTransformer):
589
597
  """
590
598
  self._infer_input_output_cols(dataset)
591
599
  super()._check_dataset_type(dataset)
600
+
592
601
  model_trainer = ModelTrainerBuilder.build_fit_transform(
593
602
  estimator=self._sklearn_object,
594
603
  dataset=dataset,
@@ -645,12 +654,41 @@ class FeatureAgglomeration(BaseTransformer):
645
654
 
646
655
  return rv
647
656
 
648
- def _align_expected_output_names(
649
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
650
- ) -> List[str]:
657
+ def _align_expected_output(
658
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
659
+ ) -> Tuple[List[str], pd.DataFrame]:
660
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
661
+ and output dataframe with 1 line.
662
+ If the method is fit_predict, run 2 lines of data.
663
+ """
651
664
  # in case the inferred output column names dimension is different
652
665
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
653
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
666
+
667
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
668
+ # so change the minimum of number of rows to 2
669
+ num_examples = 2
670
+ statement_params = telemetry.get_function_usage_statement_params(
671
+ project=_PROJECT,
672
+ subproject=_SUBPROJECT,
673
+ function_name=telemetry.get_statement_params_full_func_name(
674
+ inspect.currentframe(), FeatureAgglomeration.__class__.__name__
675
+ ),
676
+ api_calls=[Session.call],
677
+ custom_tags={"autogen": True} if self._autogenerated else None,
678
+ )
679
+ if output_cols_prefix == "fit_predict_":
680
+ if hasattr(self._sklearn_object, "n_clusters"):
681
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
682
+ num_examples = self._sklearn_object.n_clusters
683
+ elif hasattr(self._sklearn_object, "min_samples"):
684
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
685
+ num_examples = self._sklearn_object.min_samples
686
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
687
+ # LocalOutlierFactor expects n_neighbors <= n_samples
688
+ num_examples = self._sklearn_object.n_neighbors
689
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
690
+ else:
691
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
654
692
 
655
693
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
656
694
  # seen during the fit.
@@ -662,12 +700,14 @@ class FeatureAgglomeration(BaseTransformer):
662
700
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
663
701
  if self.sample_weight_col:
664
702
  output_df_columns_set -= set(self.sample_weight_col)
703
+
665
704
  # if the dimension of inferred output column names is correct; use it
666
705
  if len(expected_output_cols_list) == len(output_df_columns_set):
667
- return expected_output_cols_list
706
+ return expected_output_cols_list, output_df_pd
668
707
  # otherwise, use the sklearn estimator's output
669
708
  else:
670
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
709
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
710
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
671
711
 
672
712
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
673
713
  @telemetry.send_api_usage_telemetry(
@@ -713,7 +753,7 @@ class FeatureAgglomeration(BaseTransformer):
713
753
  drop_input_cols=self._drop_input_cols,
714
754
  expected_output_cols_type="float",
715
755
  )
716
- expected_output_cols = self._align_expected_output_names(
756
+ expected_output_cols, _ = self._align_expected_output(
717
757
  inference_method, dataset, expected_output_cols, output_cols_prefix
718
758
  )
719
759
 
@@ -779,7 +819,7 @@ class FeatureAgglomeration(BaseTransformer):
779
819
  drop_input_cols=self._drop_input_cols,
780
820
  expected_output_cols_type="float",
781
821
  )
782
- expected_output_cols = self._align_expected_output_names(
822
+ expected_output_cols, _ = self._align_expected_output(
783
823
  inference_method, dataset, expected_output_cols, output_cols_prefix
784
824
  )
785
825
  elif isinstance(dataset, pd.DataFrame):
@@ -842,7 +882,7 @@ class FeatureAgglomeration(BaseTransformer):
842
882
  drop_input_cols=self._drop_input_cols,
843
883
  expected_output_cols_type="float",
844
884
  )
845
- expected_output_cols = self._align_expected_output_names(
885
+ expected_output_cols, _ = self._align_expected_output(
846
886
  inference_method, dataset, expected_output_cols, output_cols_prefix
847
887
  )
848
888
 
@@ -907,7 +947,7 @@ class FeatureAgglomeration(BaseTransformer):
907
947
  drop_input_cols = self._drop_input_cols,
908
948
  expected_output_cols_type="float",
909
949
  )
910
- expected_output_cols = self._align_expected_output_names(
950
+ expected_output_cols, _ = self._align_expected_output(
911
951
  inference_method, dataset, expected_output_cols, output_cols_prefix
912
952
  )
913
953
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -556,12 +553,23 @@ class KMeans(BaseTransformer):
556
553
  autogenerated=self._autogenerated,
557
554
  subproject=_SUBPROJECT,
558
555
  )
559
- output_result, fitted_estimator = model_trainer.train_fit_predict(
560
- drop_input_cols=self._drop_input_cols,
561
- expected_output_cols_list=(
562
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
563
- ),
556
+ expected_output_cols = (
557
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
564
558
  )
559
+ if isinstance(dataset, DataFrame):
560
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
561
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
562
+ )
563
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
564
+ drop_input_cols=self._drop_input_cols,
565
+ expected_output_cols_list=expected_output_cols,
566
+ example_output_pd_df=example_output_pd_df,
567
+ )
568
+ else:
569
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
570
+ drop_input_cols=self._drop_input_cols,
571
+ expected_output_cols_list=expected_output_cols,
572
+ )
565
573
  self._sklearn_object = fitted_estimator
566
574
  self._is_fitted = True
567
575
  return output_result
@@ -586,6 +594,7 @@ class KMeans(BaseTransformer):
586
594
  """
587
595
  self._infer_input_output_cols(dataset)
588
596
  super()._check_dataset_type(dataset)
597
+
589
598
  model_trainer = ModelTrainerBuilder.build_fit_transform(
590
599
  estimator=self._sklearn_object,
591
600
  dataset=dataset,
@@ -642,12 +651,41 @@ class KMeans(BaseTransformer):
642
651
 
643
652
  return rv
644
653
 
645
- def _align_expected_output_names(
646
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
647
- ) -> List[str]:
654
+ def _align_expected_output(
655
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
656
+ ) -> Tuple[List[str], pd.DataFrame]:
657
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
658
+ and output dataframe with 1 line.
659
+ If the method is fit_predict, run 2 lines of data.
660
+ """
648
661
  # in case the inferred output column names dimension is different
649
662
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
650
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
663
+
664
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
665
+ # so change the minimum of number of rows to 2
666
+ num_examples = 2
667
+ statement_params = telemetry.get_function_usage_statement_params(
668
+ project=_PROJECT,
669
+ subproject=_SUBPROJECT,
670
+ function_name=telemetry.get_statement_params_full_func_name(
671
+ inspect.currentframe(), KMeans.__class__.__name__
672
+ ),
673
+ api_calls=[Session.call],
674
+ custom_tags={"autogen": True} if self._autogenerated else None,
675
+ )
676
+ if output_cols_prefix == "fit_predict_":
677
+ if hasattr(self._sklearn_object, "n_clusters"):
678
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
679
+ num_examples = self._sklearn_object.n_clusters
680
+ elif hasattr(self._sklearn_object, "min_samples"):
681
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
682
+ num_examples = self._sklearn_object.min_samples
683
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
684
+ # LocalOutlierFactor expects n_neighbors <= n_samples
685
+ num_examples = self._sklearn_object.n_neighbors
686
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
687
+ else:
688
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
651
689
 
652
690
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
653
691
  # seen during the fit.
@@ -659,12 +697,14 @@ class KMeans(BaseTransformer):
659
697
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
660
698
  if self.sample_weight_col:
661
699
  output_df_columns_set -= set(self.sample_weight_col)
700
+
662
701
  # if the dimension of inferred output column names is correct; use it
663
702
  if len(expected_output_cols_list) == len(output_df_columns_set):
664
- return expected_output_cols_list
703
+ return expected_output_cols_list, output_df_pd
665
704
  # otherwise, use the sklearn estimator's output
666
705
  else:
667
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
706
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
707
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
668
708
 
669
709
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
670
710
  @telemetry.send_api_usage_telemetry(
@@ -710,7 +750,7 @@ class KMeans(BaseTransformer):
710
750
  drop_input_cols=self._drop_input_cols,
711
751
  expected_output_cols_type="float",
712
752
  )
713
- expected_output_cols = self._align_expected_output_names(
753
+ expected_output_cols, _ = self._align_expected_output(
714
754
  inference_method, dataset, expected_output_cols, output_cols_prefix
715
755
  )
716
756
 
@@ -776,7 +816,7 @@ class KMeans(BaseTransformer):
776
816
  drop_input_cols=self._drop_input_cols,
777
817
  expected_output_cols_type="float",
778
818
  )
779
- expected_output_cols = self._align_expected_output_names(
819
+ expected_output_cols, _ = self._align_expected_output(
780
820
  inference_method, dataset, expected_output_cols, output_cols_prefix
781
821
  )
782
822
  elif isinstance(dataset, pd.DataFrame):
@@ -839,7 +879,7 @@ class KMeans(BaseTransformer):
839
879
  drop_input_cols=self._drop_input_cols,
840
880
  expected_output_cols_type="float",
841
881
  )
842
- expected_output_cols = self._align_expected_output_names(
882
+ expected_output_cols, _ = self._align_expected_output(
843
883
  inference_method, dataset, expected_output_cols, output_cols_prefix
844
884
  )
845
885
 
@@ -904,7 +944,7 @@ class KMeans(BaseTransformer):
904
944
  drop_input_cols = self._drop_input_cols,
905
945
  expected_output_cols_type="float",
906
946
  )
907
- expected_output_cols = self._align_expected_output_names(
947
+ expected_output_cols, _ = self._align_expected_output(
908
948
  inference_method, dataset, expected_output_cols, output_cols_prefix
909
949
  )
910
950