snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -530,12 +527,23 @@ class MeanShift(BaseTransformer):
530
527
  autogenerated=self._autogenerated,
531
528
  subproject=_SUBPROJECT,
532
529
  )
533
- output_result, fitted_estimator = model_trainer.train_fit_predict(
534
- drop_input_cols=self._drop_input_cols,
535
- expected_output_cols_list=(
536
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
537
- ),
530
+ expected_output_cols = (
531
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
538
532
  )
533
+ if isinstance(dataset, DataFrame):
534
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
535
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
536
+ )
537
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
538
+ drop_input_cols=self._drop_input_cols,
539
+ expected_output_cols_list=expected_output_cols,
540
+ example_output_pd_df=example_output_pd_df,
541
+ )
542
+ else:
543
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
544
+ drop_input_cols=self._drop_input_cols,
545
+ expected_output_cols_list=expected_output_cols,
546
+ )
539
547
  self._sklearn_object = fitted_estimator
540
548
  self._is_fitted = True
541
549
  return output_result
@@ -558,6 +566,7 @@ class MeanShift(BaseTransformer):
558
566
  """
559
567
  self._infer_input_output_cols(dataset)
560
568
  super()._check_dataset_type(dataset)
569
+
561
570
  model_trainer = ModelTrainerBuilder.build_fit_transform(
562
571
  estimator=self._sklearn_object,
563
572
  dataset=dataset,
@@ -614,12 +623,41 @@ class MeanShift(BaseTransformer):
614
623
 
615
624
  return rv
616
625
 
617
- def _align_expected_output_names(
618
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
619
- ) -> List[str]:
626
+ def _align_expected_output(
627
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
628
+ ) -> Tuple[List[str], pd.DataFrame]:
629
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
630
+ and output dataframe with 1 line.
631
+ If the method is fit_predict, run 2 lines of data.
632
+ """
620
633
  # in case the inferred output column names dimension is different
621
634
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
622
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
635
+
636
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
637
+ # so change the minimum of number of rows to 2
638
+ num_examples = 2
639
+ statement_params = telemetry.get_function_usage_statement_params(
640
+ project=_PROJECT,
641
+ subproject=_SUBPROJECT,
642
+ function_name=telemetry.get_statement_params_full_func_name(
643
+ inspect.currentframe(), MeanShift.__class__.__name__
644
+ ),
645
+ api_calls=[Session.call],
646
+ custom_tags={"autogen": True} if self._autogenerated else None,
647
+ )
648
+ if output_cols_prefix == "fit_predict_":
649
+ if hasattr(self._sklearn_object, "n_clusters"):
650
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
651
+ num_examples = self._sklearn_object.n_clusters
652
+ elif hasattr(self._sklearn_object, "min_samples"):
653
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
654
+ num_examples = self._sklearn_object.min_samples
655
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
656
+ # LocalOutlierFactor expects n_neighbors <= n_samples
657
+ num_examples = self._sklearn_object.n_neighbors
658
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
659
+ else:
660
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
623
661
 
624
662
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
625
663
  # seen during the fit.
@@ -631,12 +669,14 @@ class MeanShift(BaseTransformer):
631
669
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
632
670
  if self.sample_weight_col:
633
671
  output_df_columns_set -= set(self.sample_weight_col)
672
+
634
673
  # if the dimension of inferred output column names is correct; use it
635
674
  if len(expected_output_cols_list) == len(output_df_columns_set):
636
- return expected_output_cols_list
675
+ return expected_output_cols_list, output_df_pd
637
676
  # otherwise, use the sklearn estimator's output
638
677
  else:
639
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
678
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
679
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
640
680
 
641
681
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
642
682
  @telemetry.send_api_usage_telemetry(
@@ -682,7 +722,7 @@ class MeanShift(BaseTransformer):
682
722
  drop_input_cols=self._drop_input_cols,
683
723
  expected_output_cols_type="float",
684
724
  )
685
- expected_output_cols = self._align_expected_output_names(
725
+ expected_output_cols, _ = self._align_expected_output(
686
726
  inference_method, dataset, expected_output_cols, output_cols_prefix
687
727
  )
688
728
 
@@ -748,7 +788,7 @@ class MeanShift(BaseTransformer):
748
788
  drop_input_cols=self._drop_input_cols,
749
789
  expected_output_cols_type="float",
750
790
  )
751
- expected_output_cols = self._align_expected_output_names(
791
+ expected_output_cols, _ = self._align_expected_output(
752
792
  inference_method, dataset, expected_output_cols, output_cols_prefix
753
793
  )
754
794
  elif isinstance(dataset, pd.DataFrame):
@@ -811,7 +851,7 @@ class MeanShift(BaseTransformer):
811
851
  drop_input_cols=self._drop_input_cols,
812
852
  expected_output_cols_type="float",
813
853
  )
814
- expected_output_cols = self._align_expected_output_names(
854
+ expected_output_cols, _ = self._align_expected_output(
815
855
  inference_method, dataset, expected_output_cols, output_cols_prefix
816
856
  )
817
857
 
@@ -876,7 +916,7 @@ class MeanShift(BaseTransformer):
876
916
  drop_input_cols = self._drop_input_cols,
877
917
  expected_output_cols_type="float",
878
918
  )
879
- expected_output_cols = self._align_expected_output_names(
919
+ expected_output_cols, _ = self._align_expected_output(
880
920
  inference_method, dataset, expected_output_cols, output_cols_prefix
881
921
  )
882
922
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -582,12 +579,23 @@ class MiniBatchKMeans(BaseTransformer):
582
579
  autogenerated=self._autogenerated,
583
580
  subproject=_SUBPROJECT,
584
581
  )
585
- output_result, fitted_estimator = model_trainer.train_fit_predict(
586
- drop_input_cols=self._drop_input_cols,
587
- expected_output_cols_list=(
588
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
589
- ),
582
+ expected_output_cols = (
583
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
590
584
  )
585
+ if isinstance(dataset, DataFrame):
586
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
587
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
588
+ )
589
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
590
+ drop_input_cols=self._drop_input_cols,
591
+ expected_output_cols_list=expected_output_cols,
592
+ example_output_pd_df=example_output_pd_df,
593
+ )
594
+ else:
595
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
596
+ drop_input_cols=self._drop_input_cols,
597
+ expected_output_cols_list=expected_output_cols,
598
+ )
591
599
  self._sklearn_object = fitted_estimator
592
600
  self._is_fitted = True
593
601
  return output_result
@@ -612,6 +620,7 @@ class MiniBatchKMeans(BaseTransformer):
612
620
  """
613
621
  self._infer_input_output_cols(dataset)
614
622
  super()._check_dataset_type(dataset)
623
+
615
624
  model_trainer = ModelTrainerBuilder.build_fit_transform(
616
625
  estimator=self._sklearn_object,
617
626
  dataset=dataset,
@@ -668,12 +677,41 @@ class MiniBatchKMeans(BaseTransformer):
668
677
 
669
678
  return rv
670
679
 
671
- def _align_expected_output_names(
672
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
673
- ) -> List[str]:
680
+ def _align_expected_output(
681
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
682
+ ) -> Tuple[List[str], pd.DataFrame]:
683
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
684
+ and output dataframe with 1 line.
685
+ If the method is fit_predict, run 2 lines of data.
686
+ """
674
687
  # in case the inferred output column names dimension is different
675
688
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
676
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
689
+
690
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
691
+ # so change the minimum of number of rows to 2
692
+ num_examples = 2
693
+ statement_params = telemetry.get_function_usage_statement_params(
694
+ project=_PROJECT,
695
+ subproject=_SUBPROJECT,
696
+ function_name=telemetry.get_statement_params_full_func_name(
697
+ inspect.currentframe(), MiniBatchKMeans.__class__.__name__
698
+ ),
699
+ api_calls=[Session.call],
700
+ custom_tags={"autogen": True} if self._autogenerated else None,
701
+ )
702
+ if output_cols_prefix == "fit_predict_":
703
+ if hasattr(self._sklearn_object, "n_clusters"):
704
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
705
+ num_examples = self._sklearn_object.n_clusters
706
+ elif hasattr(self._sklearn_object, "min_samples"):
707
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
708
+ num_examples = self._sklearn_object.min_samples
709
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
710
+ # LocalOutlierFactor expects n_neighbors <= n_samples
711
+ num_examples = self._sklearn_object.n_neighbors
712
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
713
+ else:
714
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
677
715
 
678
716
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
679
717
  # seen during the fit.
@@ -685,12 +723,14 @@ class MiniBatchKMeans(BaseTransformer):
685
723
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
686
724
  if self.sample_weight_col:
687
725
  output_df_columns_set -= set(self.sample_weight_col)
726
+
688
727
  # if the dimension of inferred output column names is correct; use it
689
728
  if len(expected_output_cols_list) == len(output_df_columns_set):
690
- return expected_output_cols_list
729
+ return expected_output_cols_list, output_df_pd
691
730
  # otherwise, use the sklearn estimator's output
692
731
  else:
693
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
732
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
733
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
694
734
 
695
735
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
696
736
  @telemetry.send_api_usage_telemetry(
@@ -736,7 +776,7 @@ class MiniBatchKMeans(BaseTransformer):
736
776
  drop_input_cols=self._drop_input_cols,
737
777
  expected_output_cols_type="float",
738
778
  )
739
- expected_output_cols = self._align_expected_output_names(
779
+ expected_output_cols, _ = self._align_expected_output(
740
780
  inference_method, dataset, expected_output_cols, output_cols_prefix
741
781
  )
742
782
 
@@ -802,7 +842,7 @@ class MiniBatchKMeans(BaseTransformer):
802
842
  drop_input_cols=self._drop_input_cols,
803
843
  expected_output_cols_type="float",
804
844
  )
805
- expected_output_cols = self._align_expected_output_names(
845
+ expected_output_cols, _ = self._align_expected_output(
806
846
  inference_method, dataset, expected_output_cols, output_cols_prefix
807
847
  )
808
848
  elif isinstance(dataset, pd.DataFrame):
@@ -865,7 +905,7 @@ class MiniBatchKMeans(BaseTransformer):
865
905
  drop_input_cols=self._drop_input_cols,
866
906
  expected_output_cols_type="float",
867
907
  )
868
- expected_output_cols = self._align_expected_output_names(
908
+ expected_output_cols, _ = self._align_expected_output(
869
909
  inference_method, dataset, expected_output_cols, output_cols_prefix
870
910
  )
871
911
 
@@ -930,7 +970,7 @@ class MiniBatchKMeans(BaseTransformer):
930
970
  drop_input_cols = self._drop_input_cols,
931
971
  expected_output_cols_type="float",
932
972
  )
933
- expected_output_cols = self._align_expected_output_names(
973
+ expected_output_cols, _ = self._align_expected_output(
934
974
  inference_method, dataset, expected_output_cols, output_cols_prefix
935
975
  )
936
976
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -598,12 +595,23 @@ class OPTICS(BaseTransformer):
598
595
  autogenerated=self._autogenerated,
599
596
  subproject=_SUBPROJECT,
600
597
  )
601
- output_result, fitted_estimator = model_trainer.train_fit_predict(
602
- drop_input_cols=self._drop_input_cols,
603
- expected_output_cols_list=(
604
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
605
- ),
598
+ expected_output_cols = (
599
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
606
600
  )
601
+ if isinstance(dataset, DataFrame):
602
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
603
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
604
+ )
605
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
606
+ drop_input_cols=self._drop_input_cols,
607
+ expected_output_cols_list=expected_output_cols,
608
+ example_output_pd_df=example_output_pd_df,
609
+ )
610
+ else:
611
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
612
+ drop_input_cols=self._drop_input_cols,
613
+ expected_output_cols_list=expected_output_cols,
614
+ )
607
615
  self._sklearn_object = fitted_estimator
608
616
  self._is_fitted = True
609
617
  return output_result
@@ -626,6 +634,7 @@ class OPTICS(BaseTransformer):
626
634
  """
627
635
  self._infer_input_output_cols(dataset)
628
636
  super()._check_dataset_type(dataset)
637
+
629
638
  model_trainer = ModelTrainerBuilder.build_fit_transform(
630
639
  estimator=self._sklearn_object,
631
640
  dataset=dataset,
@@ -682,12 +691,41 @@ class OPTICS(BaseTransformer):
682
691
 
683
692
  return rv
684
693
 
685
- def _align_expected_output_names(
686
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
687
- ) -> List[str]:
694
+ def _align_expected_output(
695
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
696
+ ) -> Tuple[List[str], pd.DataFrame]:
697
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
698
+ and output dataframe with 1 line.
699
+ If the method is fit_predict, run 2 lines of data.
700
+ """
688
701
  # in case the inferred output column names dimension is different
689
702
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
690
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
703
+
704
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
705
+ # so change the minimum of number of rows to 2
706
+ num_examples = 2
707
+ statement_params = telemetry.get_function_usage_statement_params(
708
+ project=_PROJECT,
709
+ subproject=_SUBPROJECT,
710
+ function_name=telemetry.get_statement_params_full_func_name(
711
+ inspect.currentframe(), OPTICS.__class__.__name__
712
+ ),
713
+ api_calls=[Session.call],
714
+ custom_tags={"autogen": True} if self._autogenerated else None,
715
+ )
716
+ if output_cols_prefix == "fit_predict_":
717
+ if hasattr(self._sklearn_object, "n_clusters"):
718
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
719
+ num_examples = self._sklearn_object.n_clusters
720
+ elif hasattr(self._sklearn_object, "min_samples"):
721
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
722
+ num_examples = self._sklearn_object.min_samples
723
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
724
+ # LocalOutlierFactor expects n_neighbors <= n_samples
725
+ num_examples = self._sklearn_object.n_neighbors
726
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
727
+ else:
728
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
691
729
 
692
730
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
693
731
  # seen during the fit.
@@ -699,12 +737,14 @@ class OPTICS(BaseTransformer):
699
737
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
700
738
  if self.sample_weight_col:
701
739
  output_df_columns_set -= set(self.sample_weight_col)
740
+
702
741
  # if the dimension of inferred output column names is correct; use it
703
742
  if len(expected_output_cols_list) == len(output_df_columns_set):
704
- return expected_output_cols_list
743
+ return expected_output_cols_list, output_df_pd
705
744
  # otherwise, use the sklearn estimator's output
706
745
  else:
707
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
746
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
747
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
708
748
 
709
749
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
710
750
  @telemetry.send_api_usage_telemetry(
@@ -750,7 +790,7 @@ class OPTICS(BaseTransformer):
750
790
  drop_input_cols=self._drop_input_cols,
751
791
  expected_output_cols_type="float",
752
792
  )
753
- expected_output_cols = self._align_expected_output_names(
793
+ expected_output_cols, _ = self._align_expected_output(
754
794
  inference_method, dataset, expected_output_cols, output_cols_prefix
755
795
  )
756
796
 
@@ -816,7 +856,7 @@ class OPTICS(BaseTransformer):
816
856
  drop_input_cols=self._drop_input_cols,
817
857
  expected_output_cols_type="float",
818
858
  )
819
- expected_output_cols = self._align_expected_output_names(
859
+ expected_output_cols, _ = self._align_expected_output(
820
860
  inference_method, dataset, expected_output_cols, output_cols_prefix
821
861
  )
822
862
  elif isinstance(dataset, pd.DataFrame):
@@ -879,7 +919,7 @@ class OPTICS(BaseTransformer):
879
919
  drop_input_cols=self._drop_input_cols,
880
920
  expected_output_cols_type="float",
881
921
  )
882
- expected_output_cols = self._align_expected_output_names(
922
+ expected_output_cols, _ = self._align_expected_output(
883
923
  inference_method, dataset, expected_output_cols, output_cols_prefix
884
924
  )
885
925
 
@@ -944,7 +984,7 @@ class OPTICS(BaseTransformer):
944
984
  drop_input_cols = self._drop_input_cols,
945
985
  expected_output_cols_type="float",
946
986
  )
947
- expected_output_cols = self._align_expected_output_names(
987
+ expected_output_cols, _ = self._align_expected_output(
948
988
  inference_method, dataset, expected_output_cols, output_cols_prefix
949
989
  )
950
990
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -534,12 +531,23 @@ class SpectralBiclustering(BaseTransformer):
534
531
  autogenerated=self._autogenerated,
535
532
  subproject=_SUBPROJECT,
536
533
  )
537
- output_result, fitted_estimator = model_trainer.train_fit_predict(
538
- drop_input_cols=self._drop_input_cols,
539
- expected_output_cols_list=(
540
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
541
- ),
534
+ expected_output_cols = (
535
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
542
536
  )
537
+ if isinstance(dataset, DataFrame):
538
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
539
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
540
+ )
541
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
542
+ drop_input_cols=self._drop_input_cols,
543
+ expected_output_cols_list=expected_output_cols,
544
+ example_output_pd_df=example_output_pd_df,
545
+ )
546
+ else:
547
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
548
+ drop_input_cols=self._drop_input_cols,
549
+ expected_output_cols_list=expected_output_cols,
550
+ )
543
551
  self._sklearn_object = fitted_estimator
544
552
  self._is_fitted = True
545
553
  return output_result
@@ -562,6 +570,7 @@ class SpectralBiclustering(BaseTransformer):
562
570
  """
563
571
  self._infer_input_output_cols(dataset)
564
572
  super()._check_dataset_type(dataset)
573
+
565
574
  model_trainer = ModelTrainerBuilder.build_fit_transform(
566
575
  estimator=self._sklearn_object,
567
576
  dataset=dataset,
@@ -618,12 +627,41 @@ class SpectralBiclustering(BaseTransformer):
618
627
 
619
628
  return rv
620
629
 
621
- def _align_expected_output_names(
622
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
623
- ) -> List[str]:
630
+ def _align_expected_output(
631
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
632
+ ) -> Tuple[List[str], pd.DataFrame]:
633
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
634
+ and output dataframe with 1 line.
635
+ If the method is fit_predict, run 2 lines of data.
636
+ """
624
637
  # in case the inferred output column names dimension is different
625
638
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
626
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
639
+
640
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
641
+ # so change the minimum of number of rows to 2
642
+ num_examples = 2
643
+ statement_params = telemetry.get_function_usage_statement_params(
644
+ project=_PROJECT,
645
+ subproject=_SUBPROJECT,
646
+ function_name=telemetry.get_statement_params_full_func_name(
647
+ inspect.currentframe(), SpectralBiclustering.__class__.__name__
648
+ ),
649
+ api_calls=[Session.call],
650
+ custom_tags={"autogen": True} if self._autogenerated else None,
651
+ )
652
+ if output_cols_prefix == "fit_predict_":
653
+ if hasattr(self._sklearn_object, "n_clusters"):
654
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
655
+ num_examples = self._sklearn_object.n_clusters
656
+ elif hasattr(self._sklearn_object, "min_samples"):
657
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
658
+ num_examples = self._sklearn_object.min_samples
659
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
660
+ # LocalOutlierFactor expects n_neighbors <= n_samples
661
+ num_examples = self._sklearn_object.n_neighbors
662
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
663
+ else:
664
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
627
665
 
628
666
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
629
667
  # seen during the fit.
@@ -635,12 +673,14 @@ class SpectralBiclustering(BaseTransformer):
635
673
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
636
674
  if self.sample_weight_col:
637
675
  output_df_columns_set -= set(self.sample_weight_col)
676
+
638
677
  # if the dimension of inferred output column names is correct; use it
639
678
  if len(expected_output_cols_list) == len(output_df_columns_set):
640
- return expected_output_cols_list
679
+ return expected_output_cols_list, output_df_pd
641
680
  # otherwise, use the sklearn estimator's output
642
681
  else:
643
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
682
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
683
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
644
684
 
645
685
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
646
686
  @telemetry.send_api_usage_telemetry(
@@ -686,7 +726,7 @@ class SpectralBiclustering(BaseTransformer):
686
726
  drop_input_cols=self._drop_input_cols,
687
727
  expected_output_cols_type="float",
688
728
  )
689
- expected_output_cols = self._align_expected_output_names(
729
+ expected_output_cols, _ = self._align_expected_output(
690
730
  inference_method, dataset, expected_output_cols, output_cols_prefix
691
731
  )
692
732
 
@@ -752,7 +792,7 @@ class SpectralBiclustering(BaseTransformer):
752
792
  drop_input_cols=self._drop_input_cols,
753
793
  expected_output_cols_type="float",
754
794
  )
755
- expected_output_cols = self._align_expected_output_names(
795
+ expected_output_cols, _ = self._align_expected_output(
756
796
  inference_method, dataset, expected_output_cols, output_cols_prefix
757
797
  )
758
798
  elif isinstance(dataset, pd.DataFrame):
@@ -815,7 +855,7 @@ class SpectralBiclustering(BaseTransformer):
815
855
  drop_input_cols=self._drop_input_cols,
816
856
  expected_output_cols_type="float",
817
857
  )
818
- expected_output_cols = self._align_expected_output_names(
858
+ expected_output_cols, _ = self._align_expected_output(
819
859
  inference_method, dataset, expected_output_cols, output_cols_prefix
820
860
  )
821
861
 
@@ -880,7 +920,7 @@ class SpectralBiclustering(BaseTransformer):
880
920
  drop_input_cols = self._drop_input_cols,
881
921
  expected_output_cols_type="float",
882
922
  )
883
- expected_output_cols = self._align_expected_output_names(
923
+ expected_output_cols, _ = self._align_expected_output(
884
924
  inference_method, dataset, expected_output_cols, output_cols_prefix
885
925
  )
886
926