snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284) hide show
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -540,12 +537,23 @@ class CalibratedClassifierCV(BaseTransformer):
540
537
  autogenerated=self._autogenerated,
541
538
  subproject=_SUBPROJECT,
542
539
  )
543
- output_result, fitted_estimator = model_trainer.train_fit_predict(
544
- drop_input_cols=self._drop_input_cols,
545
- expected_output_cols_list=(
546
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
547
- ),
540
+ expected_output_cols = (
541
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
548
542
  )
543
+ if isinstance(dataset, DataFrame):
544
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
545
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
546
+ )
547
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
548
+ drop_input_cols=self._drop_input_cols,
549
+ expected_output_cols_list=expected_output_cols,
550
+ example_output_pd_df=example_output_pd_df,
551
+ )
552
+ else:
553
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
554
+ drop_input_cols=self._drop_input_cols,
555
+ expected_output_cols_list=expected_output_cols,
556
+ )
549
557
  self._sklearn_object = fitted_estimator
550
558
  self._is_fitted = True
551
559
  return output_result
@@ -568,6 +576,7 @@ class CalibratedClassifierCV(BaseTransformer):
568
576
  """
569
577
  self._infer_input_output_cols(dataset)
570
578
  super()._check_dataset_type(dataset)
579
+
571
580
  model_trainer = ModelTrainerBuilder.build_fit_transform(
572
581
  estimator=self._sklearn_object,
573
582
  dataset=dataset,
@@ -624,12 +633,41 @@ class CalibratedClassifierCV(BaseTransformer):
624
633
 
625
634
  return rv
626
635
 
627
- def _align_expected_output_names(
628
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
629
- ) -> List[str]:
636
+ def _align_expected_output(
637
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
638
+ ) -> Tuple[List[str], pd.DataFrame]:
639
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
640
+ and output dataframe with 1 line.
641
+ If the method is fit_predict, run 2 lines of data.
642
+ """
630
643
  # in case the inferred output column names dimension is different
631
644
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
632
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
645
+
646
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
647
+ # so change the minimum of number of rows to 2
648
+ num_examples = 2
649
+ statement_params = telemetry.get_function_usage_statement_params(
650
+ project=_PROJECT,
651
+ subproject=_SUBPROJECT,
652
+ function_name=telemetry.get_statement_params_full_func_name(
653
+ inspect.currentframe(), CalibratedClassifierCV.__class__.__name__
654
+ ),
655
+ api_calls=[Session.call],
656
+ custom_tags={"autogen": True} if self._autogenerated else None,
657
+ )
658
+ if output_cols_prefix == "fit_predict_":
659
+ if hasattr(self._sklearn_object, "n_clusters"):
660
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
661
+ num_examples = self._sklearn_object.n_clusters
662
+ elif hasattr(self._sklearn_object, "min_samples"):
663
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
664
+ num_examples = self._sklearn_object.min_samples
665
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
666
+ # LocalOutlierFactor expects n_neighbors <= n_samples
667
+ num_examples = self._sklearn_object.n_neighbors
668
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
669
+ else:
670
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
633
671
 
634
672
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
635
673
  # seen during the fit.
@@ -641,12 +679,14 @@ class CalibratedClassifierCV(BaseTransformer):
641
679
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
642
680
  if self.sample_weight_col:
643
681
  output_df_columns_set -= set(self.sample_weight_col)
682
+
644
683
  # if the dimension of inferred output column names is correct; use it
645
684
  if len(expected_output_cols_list) == len(output_df_columns_set):
646
- return expected_output_cols_list
685
+ return expected_output_cols_list, output_df_pd
647
686
  # otherwise, use the sklearn estimator's output
648
687
  else:
649
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
688
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
689
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
650
690
 
651
691
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
652
692
  @telemetry.send_api_usage_telemetry(
@@ -694,7 +734,7 @@ class CalibratedClassifierCV(BaseTransformer):
694
734
  drop_input_cols=self._drop_input_cols,
695
735
  expected_output_cols_type="float",
696
736
  )
697
- expected_output_cols = self._align_expected_output_names(
737
+ expected_output_cols, _ = self._align_expected_output(
698
738
  inference_method, dataset, expected_output_cols, output_cols_prefix
699
739
  )
700
740
 
@@ -762,7 +802,7 @@ class CalibratedClassifierCV(BaseTransformer):
762
802
  drop_input_cols=self._drop_input_cols,
763
803
  expected_output_cols_type="float",
764
804
  )
765
- expected_output_cols = self._align_expected_output_names(
805
+ expected_output_cols, _ = self._align_expected_output(
766
806
  inference_method, dataset, expected_output_cols, output_cols_prefix
767
807
  )
768
808
  elif isinstance(dataset, pd.DataFrame):
@@ -825,7 +865,7 @@ class CalibratedClassifierCV(BaseTransformer):
825
865
  drop_input_cols=self._drop_input_cols,
826
866
  expected_output_cols_type="float",
827
867
  )
828
- expected_output_cols = self._align_expected_output_names(
868
+ expected_output_cols, _ = self._align_expected_output(
829
869
  inference_method, dataset, expected_output_cols, output_cols_prefix
830
870
  )
831
871
 
@@ -890,7 +930,7 @@ class CalibratedClassifierCV(BaseTransformer):
890
930
  drop_input_cols = self._drop_input_cols,
891
931
  expected_output_cols_type="float",
892
932
  )
893
- expected_output_cols = self._align_expected_output_names(
933
+ expected_output_cols, _ = self._align_expected_output(
894
934
  inference_method, dataset, expected_output_cols, output_cols_prefix
895
935
  )
896
936
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -519,12 +516,23 @@ class AffinityPropagation(BaseTransformer):
519
516
  autogenerated=self._autogenerated,
520
517
  subproject=_SUBPROJECT,
521
518
  )
522
- output_result, fitted_estimator = model_trainer.train_fit_predict(
523
- drop_input_cols=self._drop_input_cols,
524
- expected_output_cols_list=(
525
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
526
- ),
519
+ expected_output_cols = (
520
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
527
521
  )
522
+ if isinstance(dataset, DataFrame):
523
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
524
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
525
+ )
526
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
527
+ drop_input_cols=self._drop_input_cols,
528
+ expected_output_cols_list=expected_output_cols,
529
+ example_output_pd_df=example_output_pd_df,
530
+ )
531
+ else:
532
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
533
+ drop_input_cols=self._drop_input_cols,
534
+ expected_output_cols_list=expected_output_cols,
535
+ )
528
536
  self._sklearn_object = fitted_estimator
529
537
  self._is_fitted = True
530
538
  return output_result
@@ -547,6 +555,7 @@ class AffinityPropagation(BaseTransformer):
547
555
  """
548
556
  self._infer_input_output_cols(dataset)
549
557
  super()._check_dataset_type(dataset)
558
+
550
559
  model_trainer = ModelTrainerBuilder.build_fit_transform(
551
560
  estimator=self._sklearn_object,
552
561
  dataset=dataset,
@@ -603,12 +612,41 @@ class AffinityPropagation(BaseTransformer):
603
612
 
604
613
  return rv
605
614
 
606
- def _align_expected_output_names(
607
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
608
- ) -> List[str]:
615
+ def _align_expected_output(
616
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
617
+ ) -> Tuple[List[str], pd.DataFrame]:
618
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
619
+ and output dataframe with 1 line.
620
+ If the method is fit_predict, run 2 lines of data.
621
+ """
609
622
  # in case the inferred output column names dimension is different
610
623
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
611
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
624
+
625
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
626
+ # so change the minimum of number of rows to 2
627
+ num_examples = 2
628
+ statement_params = telemetry.get_function_usage_statement_params(
629
+ project=_PROJECT,
630
+ subproject=_SUBPROJECT,
631
+ function_name=telemetry.get_statement_params_full_func_name(
632
+ inspect.currentframe(), AffinityPropagation.__class__.__name__
633
+ ),
634
+ api_calls=[Session.call],
635
+ custom_tags={"autogen": True} if self._autogenerated else None,
636
+ )
637
+ if output_cols_prefix == "fit_predict_":
638
+ if hasattr(self._sklearn_object, "n_clusters"):
639
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
640
+ num_examples = self._sklearn_object.n_clusters
641
+ elif hasattr(self._sklearn_object, "min_samples"):
642
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
643
+ num_examples = self._sklearn_object.min_samples
644
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
645
+ # LocalOutlierFactor expects n_neighbors <= n_samples
646
+ num_examples = self._sklearn_object.n_neighbors
647
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
648
+ else:
649
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
612
650
 
613
651
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
614
652
  # seen during the fit.
@@ -620,12 +658,14 @@ class AffinityPropagation(BaseTransformer):
620
658
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
621
659
  if self.sample_weight_col:
622
660
  output_df_columns_set -= set(self.sample_weight_col)
661
+
623
662
  # if the dimension of inferred output column names is correct; use it
624
663
  if len(expected_output_cols_list) == len(output_df_columns_set):
625
- return expected_output_cols_list
664
+ return expected_output_cols_list, output_df_pd
626
665
  # otherwise, use the sklearn estimator's output
627
666
  else:
628
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
667
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
668
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
629
669
 
630
670
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
631
671
  @telemetry.send_api_usage_telemetry(
@@ -671,7 +711,7 @@ class AffinityPropagation(BaseTransformer):
671
711
  drop_input_cols=self._drop_input_cols,
672
712
  expected_output_cols_type="float",
673
713
  )
674
- expected_output_cols = self._align_expected_output_names(
714
+ expected_output_cols, _ = self._align_expected_output(
675
715
  inference_method, dataset, expected_output_cols, output_cols_prefix
676
716
  )
677
717
 
@@ -737,7 +777,7 @@ class AffinityPropagation(BaseTransformer):
737
777
  drop_input_cols=self._drop_input_cols,
738
778
  expected_output_cols_type="float",
739
779
  )
740
- expected_output_cols = self._align_expected_output_names(
780
+ expected_output_cols, _ = self._align_expected_output(
741
781
  inference_method, dataset, expected_output_cols, output_cols_prefix
742
782
  )
743
783
  elif isinstance(dataset, pd.DataFrame):
@@ -800,7 +840,7 @@ class AffinityPropagation(BaseTransformer):
800
840
  drop_input_cols=self._drop_input_cols,
801
841
  expected_output_cols_type="float",
802
842
  )
803
- expected_output_cols = self._align_expected_output_names(
843
+ expected_output_cols, _ = self._align_expected_output(
804
844
  inference_method, dataset, expected_output_cols, output_cols_prefix
805
845
  )
806
846
 
@@ -865,7 +905,7 @@ class AffinityPropagation(BaseTransformer):
865
905
  drop_input_cols = self._drop_input_cols,
866
906
  expected_output_cols_type="float",
867
907
  )
868
- expected_output_cols = self._align_expected_output_names(
908
+ expected_output_cols, _ = self._align_expected_output(
869
909
  inference_method, dataset, expected_output_cols, output_cols_prefix
870
910
  )
871
911
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -550,12 +547,23 @@ class AgglomerativeClustering(BaseTransformer):
550
547
  autogenerated=self._autogenerated,
551
548
  subproject=_SUBPROJECT,
552
549
  )
553
- output_result, fitted_estimator = model_trainer.train_fit_predict(
554
- drop_input_cols=self._drop_input_cols,
555
- expected_output_cols_list=(
556
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
557
- ),
550
+ expected_output_cols = (
551
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
558
552
  )
553
+ if isinstance(dataset, DataFrame):
554
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
555
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
556
+ )
557
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
558
+ drop_input_cols=self._drop_input_cols,
559
+ expected_output_cols_list=expected_output_cols,
560
+ example_output_pd_df=example_output_pd_df,
561
+ )
562
+ else:
563
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
564
+ drop_input_cols=self._drop_input_cols,
565
+ expected_output_cols_list=expected_output_cols,
566
+ )
559
567
  self._sklearn_object = fitted_estimator
560
568
  self._is_fitted = True
561
569
  return output_result
@@ -578,6 +586,7 @@ class AgglomerativeClustering(BaseTransformer):
578
586
  """
579
587
  self._infer_input_output_cols(dataset)
580
588
  super()._check_dataset_type(dataset)
589
+
581
590
  model_trainer = ModelTrainerBuilder.build_fit_transform(
582
591
  estimator=self._sklearn_object,
583
592
  dataset=dataset,
@@ -634,12 +643,41 @@ class AgglomerativeClustering(BaseTransformer):
634
643
 
635
644
  return rv
636
645
 
637
- def _align_expected_output_names(
638
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
639
- ) -> List[str]:
646
+ def _align_expected_output(
647
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
648
+ ) -> Tuple[List[str], pd.DataFrame]:
649
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
650
+ and output dataframe with 1 line.
651
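+ If the method is fit_predict, run 2 lines of data by default, or as many lines as the estimator's n_clusters, min_samples, or n_neighbors value when one of those is selected below.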
652
+ """
640
653
  # in case the inferred output column names dimension is different
641
654
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
642
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
655
+
656
+ # For the fit_predict method, estimators such as MinCovDet and BayesianGaussianMixture need at least 2 rows,
657
+ # so the default number of sampled rows is 2
658
+ num_examples = 2
659
+ statement_params = telemetry.get_function_usage_statement_params(
660
+ project=_PROJECT,
661
+ subproject=_SUBPROJECT,
662
+ function_name=telemetry.get_statement_params_full_func_name(
663
+ inspect.currentframe(), AgglomerativeClustering.__class__.__name__
664
+ ),
665
+ api_calls=[Session.call],
666
+ custom_tags={"autogen": True} if self._autogenerated else None,
667
+ )
668
+ if output_cols_prefix == "fit_predict_":
669
+ if hasattr(self._sklearn_object, "n_clusters"):
670
+ # cluster classes such as BisectingKMeans require the number of examples to be >= n_clusters
671
+ num_examples = self._sklearn_object.n_clusters
672
+ elif hasattr(self._sklearn_object, "min_samples"):
673
+ # OPTICS defaults to min_samples=5, which requires at least 5 lines of data
674
+ num_examples = self._sklearn_object.min_samples
675
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
676
+ # LocalOutlierFactor expects n_neighbors <= n_samples
677
+ num_examples = self._sklearn_object.n_neighbors
678
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
679
+ else:
680
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
643
681
 
644
682
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
645
683
  # seen during the fit.
@@ -651,12 +689,14 @@ class AgglomerativeClustering(BaseTransformer):
651
689
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
652
690
  if self.sample_weight_col:
653
691
  output_df_columns_set -= set(self.sample_weight_col)
692
+
654
693
  # if the dimension of inferred output column names is correct; use it
655
694
  if len(expected_output_cols_list) == len(output_df_columns_set):
656
- return expected_output_cols_list
695
+ return expected_output_cols_list, output_df_pd
657
696
  # otherwise, use the sklearn estimator's output
658
697
  else:
659
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
698
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
699
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
660
700
 
661
701
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
662
702
  @telemetry.send_api_usage_telemetry(
@@ -702,7 +742,7 @@ class AgglomerativeClustering(BaseTransformer):
702
742
  drop_input_cols=self._drop_input_cols,
703
743
  expected_output_cols_type="float",
704
744
  )
705
- expected_output_cols = self._align_expected_output_names(
745
+ expected_output_cols, _ = self._align_expected_output(
706
746
  inference_method, dataset, expected_output_cols, output_cols_prefix
707
747
  )
708
748
 
@@ -768,7 +808,7 @@ class AgglomerativeClustering(BaseTransformer):
768
808
  drop_input_cols=self._drop_input_cols,
769
809
  expected_output_cols_type="float",
770
810
  )
771
- expected_output_cols = self._align_expected_output_names(
811
+ expected_output_cols, _ = self._align_expected_output(
772
812
  inference_method, dataset, expected_output_cols, output_cols_prefix
773
813
  )
774
814
  elif isinstance(dataset, pd.DataFrame):
@@ -831,7 +871,7 @@ class AgglomerativeClustering(BaseTransformer):
831
871
  drop_input_cols=self._drop_input_cols,
832
872
  expected_output_cols_type="float",
833
873
  )
834
- expected_output_cols = self._align_expected_output_names(
874
+ expected_output_cols, _ = self._align_expected_output(
835
875
  inference_method, dataset, expected_output_cols, output_cols_prefix
836
876
  )
837
877
 
@@ -896,7 +936,7 @@ class AgglomerativeClustering(BaseTransformer):
896
936
  drop_input_cols = self._drop_input_cols,
897
937
  expected_output_cols_type="float",
898
938
  )
899
- expected_output_cols = self._align_expected_output_names(
939
+ expected_output_cols, _ = self._align_expected_output(
900
940
  inference_method, dataset, expected_output_cols, output_cols_prefix
901
941
  )
902
942
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -512,12 +509,23 @@ class Birch(BaseTransformer):
512
509
  autogenerated=self._autogenerated,
513
510
  subproject=_SUBPROJECT,
514
511
  )
515
- output_result, fitted_estimator = model_trainer.train_fit_predict(
516
- drop_input_cols=self._drop_input_cols,
517
- expected_output_cols_list=(
518
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
519
- ),
512
+ expected_output_cols = (
513
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
520
514
  )
515
+ if isinstance(dataset, DataFrame):
516
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
517
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
518
+ )
519
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
520
+ drop_input_cols=self._drop_input_cols,
521
+ expected_output_cols_list=expected_output_cols,
522
+ example_output_pd_df=example_output_pd_df,
523
+ )
524
+ else:
525
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
526
+ drop_input_cols=self._drop_input_cols,
527
+ expected_output_cols_list=expected_output_cols,
528
+ )
521
529
  self._sklearn_object = fitted_estimator
522
530
  self._is_fitted = True
523
531
  return output_result
@@ -542,6 +550,7 @@ class Birch(BaseTransformer):
542
550
  """
543
551
  self._infer_input_output_cols(dataset)
544
552
  super()._check_dataset_type(dataset)
553
+
545
554
  model_trainer = ModelTrainerBuilder.build_fit_transform(
546
555
  estimator=self._sklearn_object,
547
556
  dataset=dataset,
@@ -598,12 +607,41 @@ class Birch(BaseTransformer):
598
607
 
599
608
  return rv
600
609
 
601
- def _align_expected_output_names(
602
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
603
- ) -> List[str]:
610
+ def _align_expected_output(
611
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
612
+ ) -> Tuple[List[str], pd.DataFrame]:
613
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
614
+ and output dataframe with 1 line.
615
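+ If the method is fit_predict, run 2 lines of data by default, or as many lines as the estimator's n_clusters, min_samples, or n_neighbors value when one of those is selected below.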
616
+ """
604
617
  # in case the inferred output column names dimension is different
605
618
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
606
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
619
+
620
+ # For the fit_predict method, estimators such as MinCovDet and BayesianGaussianMixture need at least 2 rows,
621
+ # so the default number of sampled rows is 2
622
+ num_examples = 2
623
+ statement_params = telemetry.get_function_usage_statement_params(
624
+ project=_PROJECT,
625
+ subproject=_SUBPROJECT,
626
+ function_name=telemetry.get_statement_params_full_func_name(
627
+ inspect.currentframe(), Birch.__class__.__name__
628
+ ),
629
+ api_calls=[Session.call],
630
+ custom_tags={"autogen": True} if self._autogenerated else None,
631
+ )
632
+ if output_cols_prefix == "fit_predict_":
633
+ if hasattr(self._sklearn_object, "n_clusters"):
634
+ # cluster classes such as BisectingKMeans require the number of examples to be >= n_clusters
635
+ num_examples = self._sklearn_object.n_clusters
636
+ elif hasattr(self._sklearn_object, "min_samples"):
637
+ # OPTICS defaults to min_samples=5, which requires at least 5 lines of data
638
+ num_examples = self._sklearn_object.min_samples
639
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
640
+ # LocalOutlierFactor expects n_neighbors <= n_samples
641
+ num_examples = self._sklearn_object.n_neighbors
642
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
643
+ else:
644
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
607
645
 
608
646
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
609
647
  # seen during the fit.
@@ -615,12 +653,14 @@ class Birch(BaseTransformer):
615
653
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
616
654
  if self.sample_weight_col:
617
655
  output_df_columns_set -= set(self.sample_weight_col)
656
+
618
657
  # if the dimension of inferred output column names is correct; use it
619
658
  if len(expected_output_cols_list) == len(output_df_columns_set):
620
- return expected_output_cols_list
659
+ return expected_output_cols_list, output_df_pd
621
660
  # otherwise, use the sklearn estimator's output
622
661
  else:
623
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
662
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
663
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
624
664
 
625
665
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
626
666
  @telemetry.send_api_usage_telemetry(
@@ -666,7 +706,7 @@ class Birch(BaseTransformer):
666
706
  drop_input_cols=self._drop_input_cols,
667
707
  expected_output_cols_type="float",
668
708
  )
669
- expected_output_cols = self._align_expected_output_names(
709
+ expected_output_cols, _ = self._align_expected_output(
670
710
  inference_method, dataset, expected_output_cols, output_cols_prefix
671
711
  )
672
712
 
@@ -732,7 +772,7 @@ class Birch(BaseTransformer):
732
772
  drop_input_cols=self._drop_input_cols,
733
773
  expected_output_cols_type="float",
734
774
  )
735
- expected_output_cols = self._align_expected_output_names(
775
+ expected_output_cols, _ = self._align_expected_output(
736
776
  inference_method, dataset, expected_output_cols, output_cols_prefix
737
777
  )
738
778
  elif isinstance(dataset, pd.DataFrame):
@@ -795,7 +835,7 @@ class Birch(BaseTransformer):
795
835
  drop_input_cols=self._drop_input_cols,
796
836
  expected_output_cols_type="float",
797
837
  )
798
- expected_output_cols = self._align_expected_output_names(
838
+ expected_output_cols, _ = self._align_expected_output(
799
839
  inference_method, dataset, expected_output_cols, output_cols_prefix
800
840
  )
801
841
 
@@ -860,7 +900,7 @@ class Birch(BaseTransformer):
860
900
  drop_input_cols = self._drop_input_cols,
861
901
  expected_output_cols_type="float",
862
902
  )
863
- expected_output_cols = self._align_expected_output_names(
903
+ expected_output_cols, _ = self._align_expected_output(
864
904
  inference_method, dataset, expected_output_cols, output_cols_prefix
865
905
  )
866
906