snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -573,12 +570,23 @@ class Perceptron(BaseTransformer):
573
570
  autogenerated=self._autogenerated,
574
571
  subproject=_SUBPROJECT,
575
572
  )
576
- output_result, fitted_estimator = model_trainer.train_fit_predict(
577
- drop_input_cols=self._drop_input_cols,
578
- expected_output_cols_list=(
579
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
580
- ),
573
+ expected_output_cols = (
574
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
581
575
  )
576
+ if isinstance(dataset, DataFrame):
577
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
578
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
579
+ )
580
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
581
+ drop_input_cols=self._drop_input_cols,
582
+ expected_output_cols_list=expected_output_cols,
583
+ example_output_pd_df=example_output_pd_df,
584
+ )
585
+ else:
586
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
587
+ drop_input_cols=self._drop_input_cols,
588
+ expected_output_cols_list=expected_output_cols,
589
+ )
582
590
  self._sklearn_object = fitted_estimator
583
591
  self._is_fitted = True
584
592
  return output_result
@@ -601,6 +609,7 @@ class Perceptron(BaseTransformer):
601
609
  """
602
610
  self._infer_input_output_cols(dataset)
603
611
  super()._check_dataset_type(dataset)
612
+
604
613
  model_trainer = ModelTrainerBuilder.build_fit_transform(
605
614
  estimator=self._sklearn_object,
606
615
  dataset=dataset,
@@ -657,12 +666,41 @@ class Perceptron(BaseTransformer):
657
666
 
658
667
  return rv
659
668
 
660
- def _align_expected_output_names(
661
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
662
- ) -> List[str]:
669
+ def _align_expected_output(
670
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
671
+ ) -> Tuple[List[str], pd.DataFrame]:
672
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
673
+ and output dataframe with 1 line.
674
+ If the method is fit_predict, run 2 lines of data.
675
+ """
663
676
  # in case the inferred output column names dimension is different
664
677
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
665
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
678
+
679
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
680
+ # so change the minimum of number of rows to 2
681
+ num_examples = 2
682
+ statement_params = telemetry.get_function_usage_statement_params(
683
+ project=_PROJECT,
684
+ subproject=_SUBPROJECT,
685
+ function_name=telemetry.get_statement_params_full_func_name(
686
+ inspect.currentframe(), Perceptron.__class__.__name__
687
+ ),
688
+ api_calls=[Session.call],
689
+ custom_tags={"autogen": True} if self._autogenerated else None,
690
+ )
691
+ if output_cols_prefix == "fit_predict_":
692
+ if hasattr(self._sklearn_object, "n_clusters"):
693
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
694
+ num_examples = self._sklearn_object.n_clusters
695
+ elif hasattr(self._sklearn_object, "min_samples"):
696
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
697
+ num_examples = self._sklearn_object.min_samples
698
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
699
+ # LocalOutlierFactor expects n_neighbors <= n_samples
700
+ num_examples = self._sklearn_object.n_neighbors
701
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
702
+ else:
703
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
666
704
 
667
705
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
668
706
  # seen during the fit.
@@ -674,12 +712,14 @@ class Perceptron(BaseTransformer):
674
712
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
675
713
  if self.sample_weight_col:
676
714
  output_df_columns_set -= set(self.sample_weight_col)
715
+
677
716
  # if the dimension of inferred output column names is correct; use it
678
717
  if len(expected_output_cols_list) == len(output_df_columns_set):
679
- return expected_output_cols_list
718
+ return expected_output_cols_list, output_df_pd
680
719
  # otherwise, use the sklearn estimator's output
681
720
  else:
682
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
721
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
722
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
683
723
 
684
724
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
685
725
  @telemetry.send_api_usage_telemetry(
@@ -725,7 +765,7 @@ class Perceptron(BaseTransformer):
725
765
  drop_input_cols=self._drop_input_cols,
726
766
  expected_output_cols_type="float",
727
767
  )
728
- expected_output_cols = self._align_expected_output_names(
768
+ expected_output_cols, _ = self._align_expected_output(
729
769
  inference_method, dataset, expected_output_cols, output_cols_prefix
730
770
  )
731
771
 
@@ -791,7 +831,7 @@ class Perceptron(BaseTransformer):
791
831
  drop_input_cols=self._drop_input_cols,
792
832
  expected_output_cols_type="float",
793
833
  )
794
- expected_output_cols = self._align_expected_output_names(
834
+ expected_output_cols, _ = self._align_expected_output(
795
835
  inference_method, dataset, expected_output_cols, output_cols_prefix
796
836
  )
797
837
  elif isinstance(dataset, pd.DataFrame):
@@ -856,7 +896,7 @@ class Perceptron(BaseTransformer):
856
896
  drop_input_cols=self._drop_input_cols,
857
897
  expected_output_cols_type="float",
858
898
  )
859
- expected_output_cols = self._align_expected_output_names(
899
+ expected_output_cols, _ = self._align_expected_output(
860
900
  inference_method, dataset, expected_output_cols, output_cols_prefix
861
901
  )
862
902
 
@@ -921,7 +961,7 @@ class Perceptron(BaseTransformer):
921
961
  drop_input_cols = self._drop_input_cols,
922
962
  expected_output_cols_type="float",
923
963
  )
924
- expected_output_cols = self._align_expected_output_names(
964
+ expected_output_cols, _ = self._align_expected_output(
925
965
  inference_method, dataset, expected_output_cols, output_cols_prefix
926
966
  )
927
967
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -522,12 +519,23 @@ class PoissonRegressor(BaseTransformer):
522
519
  autogenerated=self._autogenerated,
523
520
  subproject=_SUBPROJECT,
524
521
  )
525
- output_result, fitted_estimator = model_trainer.train_fit_predict(
526
- drop_input_cols=self._drop_input_cols,
527
- expected_output_cols_list=(
528
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
529
- ),
522
+ expected_output_cols = (
523
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
530
524
  )
525
+ if isinstance(dataset, DataFrame):
526
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
527
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
528
+ )
529
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
530
+ drop_input_cols=self._drop_input_cols,
531
+ expected_output_cols_list=expected_output_cols,
532
+ example_output_pd_df=example_output_pd_df,
533
+ )
534
+ else:
535
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
536
+ drop_input_cols=self._drop_input_cols,
537
+ expected_output_cols_list=expected_output_cols,
538
+ )
531
539
  self._sklearn_object = fitted_estimator
532
540
  self._is_fitted = True
533
541
  return output_result
@@ -550,6 +558,7 @@ class PoissonRegressor(BaseTransformer):
550
558
  """
551
559
  self._infer_input_output_cols(dataset)
552
560
  super()._check_dataset_type(dataset)
561
+
553
562
  model_trainer = ModelTrainerBuilder.build_fit_transform(
554
563
  estimator=self._sklearn_object,
555
564
  dataset=dataset,
@@ -606,12 +615,41 @@ class PoissonRegressor(BaseTransformer):
606
615
 
607
616
  return rv
608
617
 
609
- def _align_expected_output_names(
610
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
611
- ) -> List[str]:
618
+ def _align_expected_output(
619
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
620
+ ) -> Tuple[List[str], pd.DataFrame]:
621
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
622
+ and output dataframe with 1 line.
623
+ If the method is fit_predict, run 2 lines of data.
624
+ """
612
625
  # in case the inferred output column names dimension is different
613
626
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
614
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
627
+
628
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
629
+ # so change the minimum of number of rows to 2
630
+ num_examples = 2
631
+ statement_params = telemetry.get_function_usage_statement_params(
632
+ project=_PROJECT,
633
+ subproject=_SUBPROJECT,
634
+ function_name=telemetry.get_statement_params_full_func_name(
635
+ inspect.currentframe(), PoissonRegressor.__class__.__name__
636
+ ),
637
+ api_calls=[Session.call],
638
+ custom_tags={"autogen": True} if self._autogenerated else None,
639
+ )
640
+ if output_cols_prefix == "fit_predict_":
641
+ if hasattr(self._sklearn_object, "n_clusters"):
642
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
643
+ num_examples = self._sklearn_object.n_clusters
644
+ elif hasattr(self._sklearn_object, "min_samples"):
645
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
646
+ num_examples = self._sklearn_object.min_samples
647
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
648
+ # LocalOutlierFactor expects n_neighbors <= n_samples
649
+ num_examples = self._sklearn_object.n_neighbors
650
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
651
+ else:
652
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
615
653
 
616
654
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
617
655
  # seen during the fit.
@@ -623,12 +661,14 @@ class PoissonRegressor(BaseTransformer):
623
661
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
624
662
  if self.sample_weight_col:
625
663
  output_df_columns_set -= set(self.sample_weight_col)
664
+
626
665
  # if the dimension of inferred output column names is correct; use it
627
666
  if len(expected_output_cols_list) == len(output_df_columns_set):
628
- return expected_output_cols_list
667
+ return expected_output_cols_list, output_df_pd
629
668
  # otherwise, use the sklearn estimator's output
630
669
  else:
631
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
670
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
671
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
632
672
 
633
673
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
634
674
  @telemetry.send_api_usage_telemetry(
@@ -674,7 +714,7 @@ class PoissonRegressor(BaseTransformer):
674
714
  drop_input_cols=self._drop_input_cols,
675
715
  expected_output_cols_type="float",
676
716
  )
677
- expected_output_cols = self._align_expected_output_names(
717
+ expected_output_cols, _ = self._align_expected_output(
678
718
  inference_method, dataset, expected_output_cols, output_cols_prefix
679
719
  )
680
720
 
@@ -740,7 +780,7 @@ class PoissonRegressor(BaseTransformer):
740
780
  drop_input_cols=self._drop_input_cols,
741
781
  expected_output_cols_type="float",
742
782
  )
743
- expected_output_cols = self._align_expected_output_names(
783
+ expected_output_cols, _ = self._align_expected_output(
744
784
  inference_method, dataset, expected_output_cols, output_cols_prefix
745
785
  )
746
786
  elif isinstance(dataset, pd.DataFrame):
@@ -803,7 +843,7 @@ class PoissonRegressor(BaseTransformer):
803
843
  drop_input_cols=self._drop_input_cols,
804
844
  expected_output_cols_type="float",
805
845
  )
806
- expected_output_cols = self._align_expected_output_names(
846
+ expected_output_cols, _ = self._align_expected_output(
807
847
  inference_method, dataset, expected_output_cols, output_cols_prefix
808
848
  )
809
849
 
@@ -868,7 +908,7 @@ class PoissonRegressor(BaseTransformer):
868
908
  drop_input_cols = self._drop_input_cols,
869
909
  expected_output_cols_type="float",
870
910
  )
871
- expected_output_cols = self._align_expected_output_names(
911
+ expected_output_cols, _ = self._align_expected_output(
872
912
  inference_method, dataset, expected_output_cols, output_cols_prefix
873
913
  )
874
914
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -578,12 +575,23 @@ class RANSACRegressor(BaseTransformer):
578
575
  autogenerated=self._autogenerated,
579
576
  subproject=_SUBPROJECT,
580
577
  )
581
- output_result, fitted_estimator = model_trainer.train_fit_predict(
582
- drop_input_cols=self._drop_input_cols,
583
- expected_output_cols_list=(
584
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
585
- ),
578
+ expected_output_cols = (
579
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
586
580
  )
581
+ if isinstance(dataset, DataFrame):
582
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
583
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
584
+ )
585
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
586
+ drop_input_cols=self._drop_input_cols,
587
+ expected_output_cols_list=expected_output_cols,
588
+ example_output_pd_df=example_output_pd_df,
589
+ )
590
+ else:
591
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
592
+ drop_input_cols=self._drop_input_cols,
593
+ expected_output_cols_list=expected_output_cols,
594
+ )
587
595
  self._sklearn_object = fitted_estimator
588
596
  self._is_fitted = True
589
597
  return output_result
@@ -606,6 +614,7 @@ class RANSACRegressor(BaseTransformer):
606
614
  """
607
615
  self._infer_input_output_cols(dataset)
608
616
  super()._check_dataset_type(dataset)
617
+
609
618
  model_trainer = ModelTrainerBuilder.build_fit_transform(
610
619
  estimator=self._sklearn_object,
611
620
  dataset=dataset,
@@ -662,12 +671,41 @@ class RANSACRegressor(BaseTransformer):
662
671
 
663
672
  return rv
664
673
 
665
- def _align_expected_output_names(
666
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
667
- ) -> List[str]:
674
+ def _align_expected_output(
675
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
676
+ ) -> Tuple[List[str], pd.DataFrame]:
677
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
678
+ and output dataframe with 1 line.
679
+ If the method is fit_predict, run 2 lines of data.
680
+ """
668
681
  # in case the inferred output column names dimension is different
669
682
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
670
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
683
+
684
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
685
+ # so change the minimum of number of rows to 2
686
+ num_examples = 2
687
+ statement_params = telemetry.get_function_usage_statement_params(
688
+ project=_PROJECT,
689
+ subproject=_SUBPROJECT,
690
+ function_name=telemetry.get_statement_params_full_func_name(
691
+ inspect.currentframe(), RANSACRegressor.__class__.__name__
692
+ ),
693
+ api_calls=[Session.call],
694
+ custom_tags={"autogen": True} if self._autogenerated else None,
695
+ )
696
+ if output_cols_prefix == "fit_predict_":
697
+ if hasattr(self._sklearn_object, "n_clusters"):
698
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
699
+ num_examples = self._sklearn_object.n_clusters
700
+ elif hasattr(self._sklearn_object, "min_samples"):
701
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
702
+ num_examples = self._sklearn_object.min_samples
703
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
704
+ # LocalOutlierFactor expects n_neighbors <= n_samples
705
+ num_examples = self._sklearn_object.n_neighbors
706
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
707
+ else:
708
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
671
709
 
672
710
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
673
711
  # seen during the fit.
@@ -679,12 +717,14 @@ class RANSACRegressor(BaseTransformer):
679
717
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
680
718
  if self.sample_weight_col:
681
719
  output_df_columns_set -= set(self.sample_weight_col)
720
+
682
721
  # if the dimension of inferred output column names is correct; use it
683
722
  if len(expected_output_cols_list) == len(output_df_columns_set):
684
- return expected_output_cols_list
723
+ return expected_output_cols_list, output_df_pd
685
724
  # otherwise, use the sklearn estimator's output
686
725
  else:
687
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
726
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
727
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
688
728
 
689
729
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
690
730
  @telemetry.send_api_usage_telemetry(
@@ -730,7 +770,7 @@ class RANSACRegressor(BaseTransformer):
730
770
  drop_input_cols=self._drop_input_cols,
731
771
  expected_output_cols_type="float",
732
772
  )
733
- expected_output_cols = self._align_expected_output_names(
773
+ expected_output_cols, _ = self._align_expected_output(
734
774
  inference_method, dataset, expected_output_cols, output_cols_prefix
735
775
  )
736
776
 
@@ -796,7 +836,7 @@ class RANSACRegressor(BaseTransformer):
796
836
  drop_input_cols=self._drop_input_cols,
797
837
  expected_output_cols_type="float",
798
838
  )
799
- expected_output_cols = self._align_expected_output_names(
839
+ expected_output_cols, _ = self._align_expected_output(
800
840
  inference_method, dataset, expected_output_cols, output_cols_prefix
801
841
  )
802
842
  elif isinstance(dataset, pd.DataFrame):
@@ -859,7 +899,7 @@ class RANSACRegressor(BaseTransformer):
859
899
  drop_input_cols=self._drop_input_cols,
860
900
  expected_output_cols_type="float",
861
901
  )
862
- expected_output_cols = self._align_expected_output_names(
902
+ expected_output_cols, _ = self._align_expected_output(
863
903
  inference_method, dataset, expected_output_cols, output_cols_prefix
864
904
  )
865
905
 
@@ -924,7 +964,7 @@ class RANSACRegressor(BaseTransformer):
924
964
  drop_input_cols = self._drop_input_cols,
925
965
  expected_output_cols_type="float",
926
966
  )
927
- expected_output_cols = self._align_expected_output_names(
967
+ expected_output_cols, _ = self._align_expected_output(
928
968
  inference_method, dataset, expected_output_cols, output_cols_prefix
929
969
  )
930
970
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -570,12 +567,23 @@ class Ridge(BaseTransformer):
570
567
  autogenerated=self._autogenerated,
571
568
  subproject=_SUBPROJECT,
572
569
  )
573
- output_result, fitted_estimator = model_trainer.train_fit_predict(
574
- drop_input_cols=self._drop_input_cols,
575
- expected_output_cols_list=(
576
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
577
- ),
570
+ expected_output_cols = (
571
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
578
572
  )
573
+ if isinstance(dataset, DataFrame):
574
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
575
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
576
+ )
577
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
578
+ drop_input_cols=self._drop_input_cols,
579
+ expected_output_cols_list=expected_output_cols,
580
+ example_output_pd_df=example_output_pd_df,
581
+ )
582
+ else:
583
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
584
+ drop_input_cols=self._drop_input_cols,
585
+ expected_output_cols_list=expected_output_cols,
586
+ )
579
587
  self._sklearn_object = fitted_estimator
580
588
  self._is_fitted = True
581
589
  return output_result
@@ -598,6 +606,7 @@ class Ridge(BaseTransformer):
598
606
  """
599
607
  self._infer_input_output_cols(dataset)
600
608
  super()._check_dataset_type(dataset)
609
+
601
610
  model_trainer = ModelTrainerBuilder.build_fit_transform(
602
611
  estimator=self._sklearn_object,
603
612
  dataset=dataset,
@@ -654,12 +663,41 @@ class Ridge(BaseTransformer):
654
663
 
655
664
  return rv
656
665
 
657
- def _align_expected_output_names(
658
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
659
- ) -> List[str]:
666
+ def _align_expected_output(
667
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
668
+ ) -> Tuple[List[str], pd.DataFrame]:
669
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
670
+ and output dataframe with 1 line.
671
+ If the method is fit_predict, run 2 lines of data.
672
+ """
660
673
  # in case the inferred output column names dimension is different
661
674
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
662
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
675
+
676
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
677
+ # so change the minimum of number of rows to 2
678
+ num_examples = 2
679
+ statement_params = telemetry.get_function_usage_statement_params(
680
+ project=_PROJECT,
681
+ subproject=_SUBPROJECT,
682
+ function_name=telemetry.get_statement_params_full_func_name(
683
+ inspect.currentframe(), Ridge.__class__.__name__
684
+ ),
685
+ api_calls=[Session.call],
686
+ custom_tags={"autogen": True} if self._autogenerated else None,
687
+ )
688
+ if output_cols_prefix == "fit_predict_":
689
+ if hasattr(self._sklearn_object, "n_clusters"):
690
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
691
+ num_examples = self._sklearn_object.n_clusters
692
+ elif hasattr(self._sklearn_object, "min_samples"):
693
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
694
+ num_examples = self._sklearn_object.min_samples
695
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
696
+ # LocalOutlierFactor expects n_neighbors <= n_samples
697
+ num_examples = self._sklearn_object.n_neighbors
698
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
699
+ else:
700
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
663
701
 
664
702
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
665
703
  # seen during the fit.
@@ -671,12 +709,14 @@ class Ridge(BaseTransformer):
671
709
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
672
710
  if self.sample_weight_col:
673
711
  output_df_columns_set -= set(self.sample_weight_col)
712
+
674
713
  # if the dimension of inferred output column names is correct; use it
675
714
  if len(expected_output_cols_list) == len(output_df_columns_set):
676
- return expected_output_cols_list
715
+ return expected_output_cols_list, output_df_pd
677
716
  # otherwise, use the sklearn estimator's output
678
717
  else:
679
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
718
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
719
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
680
720
 
681
721
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
682
722
  @telemetry.send_api_usage_telemetry(
@@ -722,7 +762,7 @@ class Ridge(BaseTransformer):
722
762
  drop_input_cols=self._drop_input_cols,
723
763
  expected_output_cols_type="float",
724
764
  )
725
- expected_output_cols = self._align_expected_output_names(
765
+ expected_output_cols, _ = self._align_expected_output(
726
766
  inference_method, dataset, expected_output_cols, output_cols_prefix
727
767
  )
728
768
 
@@ -788,7 +828,7 @@ class Ridge(BaseTransformer):
788
828
  drop_input_cols=self._drop_input_cols,
789
829
  expected_output_cols_type="float",
790
830
  )
791
- expected_output_cols = self._align_expected_output_names(
831
+ expected_output_cols, _ = self._align_expected_output(
792
832
  inference_method, dataset, expected_output_cols, output_cols_prefix
793
833
  )
794
834
  elif isinstance(dataset, pd.DataFrame):
@@ -851,7 +891,7 @@ class Ridge(BaseTransformer):
851
891
  drop_input_cols=self._drop_input_cols,
852
892
  expected_output_cols_type="float",
853
893
  )
854
- expected_output_cols = self._align_expected_output_names(
894
+ expected_output_cols, _ = self._align_expected_output(
855
895
  inference_method, dataset, expected_output_cols, output_cols_prefix
856
896
  )
857
897
 
@@ -916,7 +956,7 @@ class Ridge(BaseTransformer):
916
956
  drop_input_cols = self._drop_input_cols,
917
957
  expected_output_cols_type="float",
918
958
  )
919
- expected_output_cols = self._align_expected_output_names(
959
+ expected_output_cols, _ = self._align_expected_output(
920
960
  inference_method, dataset, expected_output_cols, output_cols_prefix
921
961
  )
922
962