snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284) hide show
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -570,12 +567,23 @@ class RidgeClassifier(BaseTransformer):
570
567
  autogenerated=self._autogenerated,
571
568
  subproject=_SUBPROJECT,
572
569
  )
573
- output_result, fitted_estimator = model_trainer.train_fit_predict(
574
- drop_input_cols=self._drop_input_cols,
575
- expected_output_cols_list=(
576
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
577
- ),
570
+ expected_output_cols = (
571
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
578
572
  )
573
+ if isinstance(dataset, DataFrame):
574
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
575
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
576
+ )
577
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
578
+ drop_input_cols=self._drop_input_cols,
579
+ expected_output_cols_list=expected_output_cols,
580
+ example_output_pd_df=example_output_pd_df,
581
+ )
582
+ else:
583
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
584
+ drop_input_cols=self._drop_input_cols,
585
+ expected_output_cols_list=expected_output_cols,
586
+ )
579
587
  self._sklearn_object = fitted_estimator
580
588
  self._is_fitted = True
581
589
  return output_result
@@ -598,6 +606,7 @@ class RidgeClassifier(BaseTransformer):
598
606
  """
599
607
  self._infer_input_output_cols(dataset)
600
608
  super()._check_dataset_type(dataset)
609
+
601
610
  model_trainer = ModelTrainerBuilder.build_fit_transform(
602
611
  estimator=self._sklearn_object,
603
612
  dataset=dataset,
@@ -654,12 +663,41 @@ class RidgeClassifier(BaseTransformer):
654
663
 
655
664
  return rv
656
665
 
657
- def _align_expected_output_names(
658
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
659
- ) -> List[str]:
666
+ def _align_expected_output(
667
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
668
+ ) -> Tuple[List[str], pd.DataFrame]:
669
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
670
+ and output dataframe with 1 line.
671
+ If the method is fit_predict, run 2 lines of data.
672
+ """
660
673
  # in case the inferred output column names dimension is different
661
674
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
662
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
675
+
676
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
677
+ # so change the minimum of number of rows to 2
678
+ num_examples = 2
679
+ statement_params = telemetry.get_function_usage_statement_params(
680
+ project=_PROJECT,
681
+ subproject=_SUBPROJECT,
682
+ function_name=telemetry.get_statement_params_full_func_name(
683
+ inspect.currentframe(), RidgeClassifier.__class__.__name__
684
+ ),
685
+ api_calls=[Session.call],
686
+ custom_tags={"autogen": True} if self._autogenerated else None,
687
+ )
688
+ if output_cols_prefix == "fit_predict_":
689
+ if hasattr(self._sklearn_object, "n_clusters"):
690
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
691
+ num_examples = self._sklearn_object.n_clusters
692
+ elif hasattr(self._sklearn_object, "min_samples"):
693
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
694
+ num_examples = self._sklearn_object.min_samples
695
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
696
+ # LocalOutlierFactor expects n_neighbors <= n_samples
697
+ num_examples = self._sklearn_object.n_neighbors
698
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
699
+ else:
700
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
663
701
 
664
702
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
665
703
  # seen during the fit.
@@ -671,12 +709,14 @@ class RidgeClassifier(BaseTransformer):
671
709
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
672
710
  if self.sample_weight_col:
673
711
  output_df_columns_set -= set(self.sample_weight_col)
712
+
674
713
  # if the dimension of inferred output column names is correct; use it
675
714
  if len(expected_output_cols_list) == len(output_df_columns_set):
676
- return expected_output_cols_list
715
+ return expected_output_cols_list, output_df_pd
677
716
  # otherwise, use the sklearn estimator's output
678
717
  else:
679
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
718
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
719
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
680
720
 
681
721
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
682
722
  @telemetry.send_api_usage_telemetry(
@@ -722,7 +762,7 @@ class RidgeClassifier(BaseTransformer):
722
762
  drop_input_cols=self._drop_input_cols,
723
763
  expected_output_cols_type="float",
724
764
  )
725
- expected_output_cols = self._align_expected_output_names(
765
+ expected_output_cols, _ = self._align_expected_output(
726
766
  inference_method, dataset, expected_output_cols, output_cols_prefix
727
767
  )
728
768
 
@@ -788,7 +828,7 @@ class RidgeClassifier(BaseTransformer):
788
828
  drop_input_cols=self._drop_input_cols,
789
829
  expected_output_cols_type="float",
790
830
  )
791
- expected_output_cols = self._align_expected_output_names(
831
+ expected_output_cols, _ = self._align_expected_output(
792
832
  inference_method, dataset, expected_output_cols, output_cols_prefix
793
833
  )
794
834
  elif isinstance(dataset, pd.DataFrame):
@@ -853,7 +893,7 @@ class RidgeClassifier(BaseTransformer):
853
893
  drop_input_cols=self._drop_input_cols,
854
894
  expected_output_cols_type="float",
855
895
  )
856
- expected_output_cols = self._align_expected_output_names(
896
+ expected_output_cols, _ = self._align_expected_output(
857
897
  inference_method, dataset, expected_output_cols, output_cols_prefix
858
898
  )
859
899
 
@@ -918,7 +958,7 @@ class RidgeClassifier(BaseTransformer):
918
958
  drop_input_cols = self._drop_input_cols,
919
959
  expected_output_cols_type="float",
920
960
  )
921
- expected_output_cols = self._align_expected_output_names(
961
+ expected_output_cols, _ = self._align_expected_output(
922
962
  inference_method, dataset, expected_output_cols, output_cols_prefix
923
963
  )
924
964
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -521,12 +518,23 @@ class RidgeClassifierCV(BaseTransformer):
521
518
  autogenerated=self._autogenerated,
522
519
  subproject=_SUBPROJECT,
523
520
  )
524
- output_result, fitted_estimator = model_trainer.train_fit_predict(
525
- drop_input_cols=self._drop_input_cols,
526
- expected_output_cols_list=(
527
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
528
- ),
521
+ expected_output_cols = (
522
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
529
523
  )
524
+ if isinstance(dataset, DataFrame):
525
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
526
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
527
+ )
528
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
529
+ drop_input_cols=self._drop_input_cols,
530
+ expected_output_cols_list=expected_output_cols,
531
+ example_output_pd_df=example_output_pd_df,
532
+ )
533
+ else:
534
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
535
+ drop_input_cols=self._drop_input_cols,
536
+ expected_output_cols_list=expected_output_cols,
537
+ )
530
538
  self._sklearn_object = fitted_estimator
531
539
  self._is_fitted = True
532
540
  return output_result
@@ -549,6 +557,7 @@ class RidgeClassifierCV(BaseTransformer):
549
557
  """
550
558
  self._infer_input_output_cols(dataset)
551
559
  super()._check_dataset_type(dataset)
560
+
552
561
  model_trainer = ModelTrainerBuilder.build_fit_transform(
553
562
  estimator=self._sklearn_object,
554
563
  dataset=dataset,
@@ -605,12 +614,41 @@ class RidgeClassifierCV(BaseTransformer):
605
614
 
606
615
  return rv
607
616
 
608
- def _align_expected_output_names(
609
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
610
- ) -> List[str]:
617
+ def _align_expected_output(
618
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
619
+ ) -> Tuple[List[str], pd.DataFrame]:
620
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
621
+ and output dataframe with 1 line.
622
+ If the method is fit_predict, run 2 lines of data.
623
+ """
611
624
  # in case the inferred output column names dimension is different
612
625
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
613
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
626
+
627
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
628
+ # so change the minimum of number of rows to 2
629
+ num_examples = 2
630
+ statement_params = telemetry.get_function_usage_statement_params(
631
+ project=_PROJECT,
632
+ subproject=_SUBPROJECT,
633
+ function_name=telemetry.get_statement_params_full_func_name(
634
+ inspect.currentframe(), RidgeClassifierCV.__class__.__name__
635
+ ),
636
+ api_calls=[Session.call],
637
+ custom_tags={"autogen": True} if self._autogenerated else None,
638
+ )
639
+ if output_cols_prefix == "fit_predict_":
640
+ if hasattr(self._sklearn_object, "n_clusters"):
641
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
642
+ num_examples = self._sklearn_object.n_clusters
643
+ elif hasattr(self._sklearn_object, "min_samples"):
644
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
645
+ num_examples = self._sklearn_object.min_samples
646
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
647
+ # LocalOutlierFactor expects n_neighbors <= n_samples
648
+ num_examples = self._sklearn_object.n_neighbors
649
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
650
+ else:
651
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
614
652
 
615
653
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
616
654
  # seen during the fit.
@@ -622,12 +660,14 @@ class RidgeClassifierCV(BaseTransformer):
622
660
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
623
661
  if self.sample_weight_col:
624
662
  output_df_columns_set -= set(self.sample_weight_col)
663
+
625
664
  # if the dimension of inferred output column names is correct; use it
626
665
  if len(expected_output_cols_list) == len(output_df_columns_set):
627
- return expected_output_cols_list
666
+ return expected_output_cols_list, output_df_pd
628
667
  # otherwise, use the sklearn estimator's output
629
668
  else:
630
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
669
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
670
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
631
671
 
632
672
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
633
673
  @telemetry.send_api_usage_telemetry(
@@ -673,7 +713,7 @@ class RidgeClassifierCV(BaseTransformer):
673
713
  drop_input_cols=self._drop_input_cols,
674
714
  expected_output_cols_type="float",
675
715
  )
676
- expected_output_cols = self._align_expected_output_names(
716
+ expected_output_cols, _ = self._align_expected_output(
677
717
  inference_method, dataset, expected_output_cols, output_cols_prefix
678
718
  )
679
719
 
@@ -739,7 +779,7 @@ class RidgeClassifierCV(BaseTransformer):
739
779
  drop_input_cols=self._drop_input_cols,
740
780
  expected_output_cols_type="float",
741
781
  )
742
- expected_output_cols = self._align_expected_output_names(
782
+ expected_output_cols, _ = self._align_expected_output(
743
783
  inference_method, dataset, expected_output_cols, output_cols_prefix
744
784
  )
745
785
  elif isinstance(dataset, pd.DataFrame):
@@ -804,7 +844,7 @@ class RidgeClassifierCV(BaseTransformer):
804
844
  drop_input_cols=self._drop_input_cols,
805
845
  expected_output_cols_type="float",
806
846
  )
807
- expected_output_cols = self._align_expected_output_names(
847
+ expected_output_cols, _ = self._align_expected_output(
808
848
  inference_method, dataset, expected_output_cols, output_cols_prefix
809
849
  )
810
850
 
@@ -869,7 +909,7 @@ class RidgeClassifierCV(BaseTransformer):
869
909
  drop_input_cols = self._drop_input_cols,
870
910
  expected_output_cols_type="float",
871
911
  )
872
- expected_output_cols = self._align_expected_output_names(
912
+ expected_output_cols, _ = self._align_expected_output(
873
913
  inference_method, dataset, expected_output_cols, output_cols_prefix
874
914
  )
875
915
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -542,12 +539,23 @@ class RidgeCV(BaseTransformer):
542
539
  autogenerated=self._autogenerated,
543
540
  subproject=_SUBPROJECT,
544
541
  )
545
- output_result, fitted_estimator = model_trainer.train_fit_predict(
546
- drop_input_cols=self._drop_input_cols,
547
- expected_output_cols_list=(
548
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
549
- ),
542
+ expected_output_cols = (
543
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
550
544
  )
545
+ if isinstance(dataset, DataFrame):
546
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
547
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
548
+ )
549
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
550
+ drop_input_cols=self._drop_input_cols,
551
+ expected_output_cols_list=expected_output_cols,
552
+ example_output_pd_df=example_output_pd_df,
553
+ )
554
+ else:
555
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
556
+ drop_input_cols=self._drop_input_cols,
557
+ expected_output_cols_list=expected_output_cols,
558
+ )
551
559
  self._sklearn_object = fitted_estimator
552
560
  self._is_fitted = True
553
561
  return output_result
@@ -570,6 +578,7 @@ class RidgeCV(BaseTransformer):
570
578
  """
571
579
  self._infer_input_output_cols(dataset)
572
580
  super()._check_dataset_type(dataset)
581
+
573
582
  model_trainer = ModelTrainerBuilder.build_fit_transform(
574
583
  estimator=self._sklearn_object,
575
584
  dataset=dataset,
@@ -626,12 +635,41 @@ class RidgeCV(BaseTransformer):
626
635
 
627
636
  return rv
628
637
 
629
- def _align_expected_output_names(
630
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
631
- ) -> List[str]:
638
+ def _align_expected_output(
639
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
640
+ ) -> Tuple[List[str], pd.DataFrame]:
641
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
642
+ and output dataframe with 1 line.
643
+ If the method is fit_predict, run 2 lines of data.
644
+ """
632
645
  # in case the inferred output column names dimension is different
633
646
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
634
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
647
+
648
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
649
+ # so change the minimum of number of rows to 2
650
+ num_examples = 2
651
+ statement_params = telemetry.get_function_usage_statement_params(
652
+ project=_PROJECT,
653
+ subproject=_SUBPROJECT,
654
+ function_name=telemetry.get_statement_params_full_func_name(
655
+ inspect.currentframe(), RidgeCV.__class__.__name__
656
+ ),
657
+ api_calls=[Session.call],
658
+ custom_tags={"autogen": True} if self._autogenerated else None,
659
+ )
660
+ if output_cols_prefix == "fit_predict_":
661
+ if hasattr(self._sklearn_object, "n_clusters"):
662
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
663
+ num_examples = self._sklearn_object.n_clusters
664
+ elif hasattr(self._sklearn_object, "min_samples"):
665
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
666
+ num_examples = self._sklearn_object.min_samples
667
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
668
+ # LocalOutlierFactor expects n_neighbors <= n_samples
669
+ num_examples = self._sklearn_object.n_neighbors
670
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
671
+ else:
672
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
635
673
 
636
674
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
637
675
  # seen during the fit.
@@ -643,12 +681,14 @@ class RidgeCV(BaseTransformer):
643
681
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
644
682
  if self.sample_weight_col:
645
683
  output_df_columns_set -= set(self.sample_weight_col)
684
+
646
685
  # if the dimension of inferred output column names is correct; use it
647
686
  if len(expected_output_cols_list) == len(output_df_columns_set):
648
- return expected_output_cols_list
687
+ return expected_output_cols_list, output_df_pd
649
688
  # otherwise, use the sklearn estimator's output
650
689
  else:
651
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
690
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
691
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
652
692
 
653
693
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
654
694
  @telemetry.send_api_usage_telemetry(
@@ -694,7 +734,7 @@ class RidgeCV(BaseTransformer):
694
734
  drop_input_cols=self._drop_input_cols,
695
735
  expected_output_cols_type="float",
696
736
  )
697
- expected_output_cols = self._align_expected_output_names(
737
+ expected_output_cols, _ = self._align_expected_output(
698
738
  inference_method, dataset, expected_output_cols, output_cols_prefix
699
739
  )
700
740
 
@@ -760,7 +800,7 @@ class RidgeCV(BaseTransformer):
760
800
  drop_input_cols=self._drop_input_cols,
761
801
  expected_output_cols_type="float",
762
802
  )
763
- expected_output_cols = self._align_expected_output_names(
803
+ expected_output_cols, _ = self._align_expected_output(
764
804
  inference_method, dataset, expected_output_cols, output_cols_prefix
765
805
  )
766
806
  elif isinstance(dataset, pd.DataFrame):
@@ -823,7 +863,7 @@ class RidgeCV(BaseTransformer):
823
863
  drop_input_cols=self._drop_input_cols,
824
864
  expected_output_cols_type="float",
825
865
  )
826
- expected_output_cols = self._align_expected_output_names(
866
+ expected_output_cols, _ = self._align_expected_output(
827
867
  inference_method, dataset, expected_output_cols, output_cols_prefix
828
868
  )
829
869
 
@@ -888,7 +928,7 @@ class RidgeCV(BaseTransformer):
888
928
  drop_input_cols = self._drop_input_cols,
889
929
  expected_output_cols_type="float",
890
930
  )
891
- expected_output_cols = self._align_expected_output_names(
931
+ expected_output_cols, _ = self._align_expected_output(
892
932
  inference_method, dataset, expected_output_cols, output_cols_prefix
893
933
  )
894
934
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -661,12 +658,23 @@ class SGDClassifier(BaseTransformer):
661
658
  autogenerated=self._autogenerated,
662
659
  subproject=_SUBPROJECT,
663
660
  )
664
- output_result, fitted_estimator = model_trainer.train_fit_predict(
665
- drop_input_cols=self._drop_input_cols,
666
- expected_output_cols_list=(
667
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
668
- ),
661
+ expected_output_cols = (
662
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
669
663
  )
664
+ if isinstance(dataset, DataFrame):
665
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
666
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
667
+ )
668
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
669
+ drop_input_cols=self._drop_input_cols,
670
+ expected_output_cols_list=expected_output_cols,
671
+ example_output_pd_df=example_output_pd_df,
672
+ )
673
+ else:
674
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
675
+ drop_input_cols=self._drop_input_cols,
676
+ expected_output_cols_list=expected_output_cols,
677
+ )
670
678
  self._sklearn_object = fitted_estimator
671
679
  self._is_fitted = True
672
680
  return output_result
@@ -689,6 +697,7 @@ class SGDClassifier(BaseTransformer):
689
697
  """
690
698
  self._infer_input_output_cols(dataset)
691
699
  super()._check_dataset_type(dataset)
700
+
692
701
  model_trainer = ModelTrainerBuilder.build_fit_transform(
693
702
  estimator=self._sklearn_object,
694
703
  dataset=dataset,
@@ -745,12 +754,41 @@ class SGDClassifier(BaseTransformer):
745
754
 
746
755
  return rv
747
756
 
748
- def _align_expected_output_names(
749
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
750
- ) -> List[str]:
757
+ def _align_expected_output(
758
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
759
+ ) -> Tuple[List[str], pd.DataFrame]:
760
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
761
+ and output dataframe with 1 line.
762
+ If the method is fit_predict, run 2 lines of data.
763
+ """
751
764
  # in case the inferred output column names dimension is different
752
765
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
753
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
766
+
767
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
768
+ # so change the minimum of number of rows to 2
769
+ num_examples = 2
770
+ statement_params = telemetry.get_function_usage_statement_params(
771
+ project=_PROJECT,
772
+ subproject=_SUBPROJECT,
773
+ function_name=telemetry.get_statement_params_full_func_name(
774
+ inspect.currentframe(), SGDClassifier.__class__.__name__
775
+ ),
776
+ api_calls=[Session.call],
777
+ custom_tags={"autogen": True} if self._autogenerated else None,
778
+ )
779
+ if output_cols_prefix == "fit_predict_":
780
+ if hasattr(self._sklearn_object, "n_clusters"):
781
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
782
+ num_examples = self._sklearn_object.n_clusters
783
+ elif hasattr(self._sklearn_object, "min_samples"):
784
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
785
+ num_examples = self._sklearn_object.min_samples
786
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
787
+ # LocalOutlierFactor expects n_neighbors <= n_samples
788
+ num_examples = self._sklearn_object.n_neighbors
789
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
790
+ else:
791
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
754
792
 
755
793
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
756
794
  # seen during the fit.
@@ -762,12 +800,14 @@ class SGDClassifier(BaseTransformer):
762
800
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
763
801
  if self.sample_weight_col:
764
802
  output_df_columns_set -= set(self.sample_weight_col)
803
+
765
804
  # if the dimension of inferred output column names is correct; use it
766
805
  if len(expected_output_cols_list) == len(output_df_columns_set):
767
- return expected_output_cols_list
806
+ return expected_output_cols_list, output_df_pd
768
807
  # otherwise, use the sklearn estimator's output
769
808
  else:
770
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
809
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
810
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
771
811
 
772
812
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
773
813
  @telemetry.send_api_usage_telemetry(
@@ -815,7 +855,7 @@ class SGDClassifier(BaseTransformer):
815
855
  drop_input_cols=self._drop_input_cols,
816
856
  expected_output_cols_type="float",
817
857
  )
818
- expected_output_cols = self._align_expected_output_names(
858
+ expected_output_cols, _ = self._align_expected_output(
819
859
  inference_method, dataset, expected_output_cols, output_cols_prefix
820
860
  )
821
861
 
@@ -883,7 +923,7 @@ class SGDClassifier(BaseTransformer):
883
923
  drop_input_cols=self._drop_input_cols,
884
924
  expected_output_cols_type="float",
885
925
  )
886
- expected_output_cols = self._align_expected_output_names(
926
+ expected_output_cols, _ = self._align_expected_output(
887
927
  inference_method, dataset, expected_output_cols, output_cols_prefix
888
928
  )
889
929
  elif isinstance(dataset, pd.DataFrame):
@@ -948,7 +988,7 @@ class SGDClassifier(BaseTransformer):
948
988
  drop_input_cols=self._drop_input_cols,
949
989
  expected_output_cols_type="float",
950
990
  )
951
- expected_output_cols = self._align_expected_output_names(
991
+ expected_output_cols, _ = self._align_expected_output(
952
992
  inference_method, dataset, expected_output_cols, output_cols_prefix
953
993
  )
954
994
 
@@ -1013,7 +1053,7 @@ class SGDClassifier(BaseTransformer):
1013
1053
  drop_input_cols = self._drop_input_cols,
1014
1054
  expected_output_cols_type="float",
1015
1055
  )
1016
- expected_output_cols = self._align_expected_output_names(
1056
+ expected_output_cols, _ = self._align_expected_output(
1017
1057
  inference_method, dataset, expected_output_cols, output_cols_prefix
1018
1058
  )
1019
1059