snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284) hide show
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -506,12 +503,23 @@ class LGBMRegressor(BaseTransformer):
506
503
  autogenerated=self._autogenerated,
507
504
  subproject=_SUBPROJECT,
508
505
  )
509
- output_result, fitted_estimator = model_trainer.train_fit_predict(
510
- drop_input_cols=self._drop_input_cols,
511
- expected_output_cols_list=(
512
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
513
- ),
506
+ expected_output_cols = (
507
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
514
508
  )
509
+ if isinstance(dataset, DataFrame):
510
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
511
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
512
+ )
513
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
514
+ drop_input_cols=self._drop_input_cols,
515
+ expected_output_cols_list=expected_output_cols,
516
+ example_output_pd_df=example_output_pd_df,
517
+ )
518
+ else:
519
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
520
+ drop_input_cols=self._drop_input_cols,
521
+ expected_output_cols_list=expected_output_cols,
522
+ )
515
523
  self._sklearn_object = fitted_estimator
516
524
  self._is_fitted = True
517
525
  return output_result
@@ -534,6 +542,7 @@ class LGBMRegressor(BaseTransformer):
534
542
  """
535
543
  self._infer_input_output_cols(dataset)
536
544
  super()._check_dataset_type(dataset)
545
+
537
546
  model_trainer = ModelTrainerBuilder.build_fit_transform(
538
547
  estimator=self._sklearn_object,
539
548
  dataset=dataset,
@@ -590,12 +599,41 @@ class LGBMRegressor(BaseTransformer):
590
599
 
591
600
  return rv
592
601
 
593
- def _align_expected_output_names(
594
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
595
- ) -> List[str]:
602
+ def _align_expected_output(
603
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
604
+ ) -> Tuple[List[str], pd.DataFrame]:
605
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
606
+ and output dataframe with 1 line.
607
+ If the method is fit_predict, run 2 lines of data.
608
+ """
596
609
  # in case the inferred output column names dimension is different
597
610
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
598
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
611
+
612
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
613
+ # so change the minimum of number of rows to 2
614
+ num_examples = 2
615
+ statement_params = telemetry.get_function_usage_statement_params(
616
+ project=_PROJECT,
617
+ subproject=_SUBPROJECT,
618
+ function_name=telemetry.get_statement_params_full_func_name(
619
+ inspect.currentframe(), LGBMRegressor.__class__.__name__
620
+ ),
621
+ api_calls=[Session.call],
622
+ custom_tags={"autogen": True} if self._autogenerated else None,
623
+ )
624
+ if output_cols_prefix == "fit_predict_":
625
+ if hasattr(self._sklearn_object, "n_clusters"):
626
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
627
+ num_examples = self._sklearn_object.n_clusters
628
+ elif hasattr(self._sklearn_object, "min_samples"):
629
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
630
+ num_examples = self._sklearn_object.min_samples
631
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
632
+ # LocalOutlierFactor expects n_neighbors <= n_samples
633
+ num_examples = self._sklearn_object.n_neighbors
634
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
635
+ else:
636
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
599
637
 
600
638
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
601
639
  # seen during the fit.
@@ -607,12 +645,14 @@ class LGBMRegressor(BaseTransformer):
607
645
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
608
646
  if self.sample_weight_col:
609
647
  output_df_columns_set -= set(self.sample_weight_col)
648
+
610
649
  # if the dimension of inferred output column names is correct; use it
611
650
  if len(expected_output_cols_list) == len(output_df_columns_set):
612
- return expected_output_cols_list
651
+ return expected_output_cols_list, output_df_pd
613
652
  # otherwise, use the sklearn estimator's output
614
653
  else:
615
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
654
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
655
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
616
656
 
617
657
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
618
658
  @telemetry.send_api_usage_telemetry(
@@ -658,7 +698,7 @@ class LGBMRegressor(BaseTransformer):
658
698
  drop_input_cols=self._drop_input_cols,
659
699
  expected_output_cols_type="float",
660
700
  )
661
- expected_output_cols = self._align_expected_output_names(
701
+ expected_output_cols, _ = self._align_expected_output(
662
702
  inference_method, dataset, expected_output_cols, output_cols_prefix
663
703
  )
664
704
 
@@ -724,7 +764,7 @@ class LGBMRegressor(BaseTransformer):
724
764
  drop_input_cols=self._drop_input_cols,
725
765
  expected_output_cols_type="float",
726
766
  )
727
- expected_output_cols = self._align_expected_output_names(
767
+ expected_output_cols, _ = self._align_expected_output(
728
768
  inference_method, dataset, expected_output_cols, output_cols_prefix
729
769
  )
730
770
  elif isinstance(dataset, pd.DataFrame):
@@ -787,7 +827,7 @@ class LGBMRegressor(BaseTransformer):
787
827
  drop_input_cols=self._drop_input_cols,
788
828
  expected_output_cols_type="float",
789
829
  )
790
- expected_output_cols = self._align_expected_output_names(
830
+ expected_output_cols, _ = self._align_expected_output(
791
831
  inference_method, dataset, expected_output_cols, output_cols_prefix
792
832
  )
793
833
 
@@ -852,7 +892,7 @@ class LGBMRegressor(BaseTransformer):
852
892
  drop_input_cols = self._drop_input_cols,
853
893
  expected_output_cols_type="float",
854
894
  )
855
- expected_output_cols = self._align_expected_output_names(
895
+ expected_output_cols, _ = self._align_expected_output(
856
896
  inference_method, dataset, expected_output_cols, output_cols_prefix
857
897
  )
858
898
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -531,12 +528,23 @@ class ARDRegression(BaseTransformer):
531
528
  autogenerated=self._autogenerated,
532
529
  subproject=_SUBPROJECT,
533
530
  )
534
- output_result, fitted_estimator = model_trainer.train_fit_predict(
535
- drop_input_cols=self._drop_input_cols,
536
- expected_output_cols_list=(
537
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
538
- ),
531
+ expected_output_cols = (
532
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
539
533
  )
534
+ if isinstance(dataset, DataFrame):
535
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
536
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
537
+ )
538
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
539
+ drop_input_cols=self._drop_input_cols,
540
+ expected_output_cols_list=expected_output_cols,
541
+ example_output_pd_df=example_output_pd_df,
542
+ )
543
+ else:
544
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
545
+ drop_input_cols=self._drop_input_cols,
546
+ expected_output_cols_list=expected_output_cols,
547
+ )
540
548
  self._sklearn_object = fitted_estimator
541
549
  self._is_fitted = True
542
550
  return output_result
@@ -559,6 +567,7 @@ class ARDRegression(BaseTransformer):
559
567
  """
560
568
  self._infer_input_output_cols(dataset)
561
569
  super()._check_dataset_type(dataset)
570
+
562
571
  model_trainer = ModelTrainerBuilder.build_fit_transform(
563
572
  estimator=self._sklearn_object,
564
573
  dataset=dataset,
@@ -615,12 +624,41 @@ class ARDRegression(BaseTransformer):
615
624
 
616
625
  return rv
617
626
 
618
- def _align_expected_output_names(
619
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
620
- ) -> List[str]:
627
+ def _align_expected_output(
628
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
629
+ ) -> Tuple[List[str], pd.DataFrame]:
630
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
631
+ and output dataframe with 1 line.
632
+ If the method is fit_predict, run 2 lines of data.
633
+ """
621
634
  # in case the inferred output column names dimension is different
622
635
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
623
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
636
+
637
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
638
+ # so change the minimum of number of rows to 2
639
+ num_examples = 2
640
+ statement_params = telemetry.get_function_usage_statement_params(
641
+ project=_PROJECT,
642
+ subproject=_SUBPROJECT,
643
+ function_name=telemetry.get_statement_params_full_func_name(
644
+ inspect.currentframe(), ARDRegression.__class__.__name__
645
+ ),
646
+ api_calls=[Session.call],
647
+ custom_tags={"autogen": True} if self._autogenerated else None,
648
+ )
649
+ if output_cols_prefix == "fit_predict_":
650
+ if hasattr(self._sklearn_object, "n_clusters"):
651
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
652
+ num_examples = self._sklearn_object.n_clusters
653
+ elif hasattr(self._sklearn_object, "min_samples"):
654
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
655
+ num_examples = self._sklearn_object.min_samples
656
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
657
+ # LocalOutlierFactor expects n_neighbors <= n_samples
658
+ num_examples = self._sklearn_object.n_neighbors
659
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
660
+ else:
661
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
624
662
 
625
663
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
626
664
  # seen during the fit.
@@ -632,12 +670,14 @@ class ARDRegression(BaseTransformer):
632
670
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
633
671
  if self.sample_weight_col:
634
672
  output_df_columns_set -= set(self.sample_weight_col)
673
+
635
674
  # if the dimension of inferred output column names is correct; use it
636
675
  if len(expected_output_cols_list) == len(output_df_columns_set):
637
- return expected_output_cols_list
676
+ return expected_output_cols_list, output_df_pd
638
677
  # otherwise, use the sklearn estimator's output
639
678
  else:
640
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
679
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
680
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
641
681
 
642
682
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
643
683
  @telemetry.send_api_usage_telemetry(
@@ -683,7 +723,7 @@ class ARDRegression(BaseTransformer):
683
723
  drop_input_cols=self._drop_input_cols,
684
724
  expected_output_cols_type="float",
685
725
  )
686
- expected_output_cols = self._align_expected_output_names(
726
+ expected_output_cols, _ = self._align_expected_output(
687
727
  inference_method, dataset, expected_output_cols, output_cols_prefix
688
728
  )
689
729
 
@@ -749,7 +789,7 @@ class ARDRegression(BaseTransformer):
749
789
  drop_input_cols=self._drop_input_cols,
750
790
  expected_output_cols_type="float",
751
791
  )
752
- expected_output_cols = self._align_expected_output_names(
792
+ expected_output_cols, _ = self._align_expected_output(
753
793
  inference_method, dataset, expected_output_cols, output_cols_prefix
754
794
  )
755
795
  elif isinstance(dataset, pd.DataFrame):
@@ -812,7 +852,7 @@ class ARDRegression(BaseTransformer):
812
852
  drop_input_cols=self._drop_input_cols,
813
853
  expected_output_cols_type="float",
814
854
  )
815
- expected_output_cols = self._align_expected_output_names(
855
+ expected_output_cols, _ = self._align_expected_output(
816
856
  inference_method, dataset, expected_output_cols, output_cols_prefix
817
857
  )
818
858
 
@@ -877,7 +917,7 @@ class ARDRegression(BaseTransformer):
877
917
  drop_input_cols = self._drop_input_cols,
878
918
  expected_output_cols_type="float",
879
919
  )
880
- expected_output_cols = self._align_expected_output_names(
920
+ expected_output_cols, _ = self._align_expected_output(
881
921
  inference_method, dataset, expected_output_cols, output_cols_prefix
882
922
  )
883
923
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -542,12 +539,23 @@ class BayesianRidge(BaseTransformer):
542
539
  autogenerated=self._autogenerated,
543
540
  subproject=_SUBPROJECT,
544
541
  )
545
- output_result, fitted_estimator = model_trainer.train_fit_predict(
546
- drop_input_cols=self._drop_input_cols,
547
- expected_output_cols_list=(
548
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
549
- ),
542
+ expected_output_cols = (
543
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
550
544
  )
545
+ if isinstance(dataset, DataFrame):
546
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
547
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
548
+ )
549
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
550
+ drop_input_cols=self._drop_input_cols,
551
+ expected_output_cols_list=expected_output_cols,
552
+ example_output_pd_df=example_output_pd_df,
553
+ )
554
+ else:
555
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
556
+ drop_input_cols=self._drop_input_cols,
557
+ expected_output_cols_list=expected_output_cols,
558
+ )
551
559
  self._sklearn_object = fitted_estimator
552
560
  self._is_fitted = True
553
561
  return output_result
@@ -570,6 +578,7 @@ class BayesianRidge(BaseTransformer):
570
578
  """
571
579
  self._infer_input_output_cols(dataset)
572
580
  super()._check_dataset_type(dataset)
581
+
573
582
  model_trainer = ModelTrainerBuilder.build_fit_transform(
574
583
  estimator=self._sklearn_object,
575
584
  dataset=dataset,
@@ -626,12 +635,41 @@ class BayesianRidge(BaseTransformer):
626
635
 
627
636
  return rv
628
637
 
629
- def _align_expected_output_names(
630
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
631
- ) -> List[str]:
638
+ def _align_expected_output(
639
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
640
+ ) -> Tuple[List[str], pd.DataFrame]:
641
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
642
+ and output dataframe with 1 line.
643
+ If the method is fit_predict, run 2 lines of data.
644
+ """
632
645
  # in case the inferred output column names dimension is different
633
646
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
634
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
647
+
648
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
649
+ # so change the minimum of number of rows to 2
650
+ num_examples = 2
651
+ statement_params = telemetry.get_function_usage_statement_params(
652
+ project=_PROJECT,
653
+ subproject=_SUBPROJECT,
654
+ function_name=telemetry.get_statement_params_full_func_name(
655
+ inspect.currentframe(), BayesianRidge.__class__.__name__
656
+ ),
657
+ api_calls=[Session.call],
658
+ custom_tags={"autogen": True} if self._autogenerated else None,
659
+ )
660
+ if output_cols_prefix == "fit_predict_":
661
+ if hasattr(self._sklearn_object, "n_clusters"):
662
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
663
+ num_examples = self._sklearn_object.n_clusters
664
+ elif hasattr(self._sklearn_object, "min_samples"):
665
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
666
+ num_examples = self._sklearn_object.min_samples
667
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
668
+ # LocalOutlierFactor expects n_neighbors <= n_samples
669
+ num_examples = self._sklearn_object.n_neighbors
670
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
671
+ else:
672
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
635
673
 
636
674
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
637
675
  # seen during the fit.
@@ -643,12 +681,14 @@ class BayesianRidge(BaseTransformer):
643
681
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
644
682
  if self.sample_weight_col:
645
683
  output_df_columns_set -= set(self.sample_weight_col)
684
+
646
685
  # if the dimension of inferred output column names is correct; use it
647
686
  if len(expected_output_cols_list) == len(output_df_columns_set):
648
- return expected_output_cols_list
687
+ return expected_output_cols_list, output_df_pd
649
688
  # otherwise, use the sklearn estimator's output
650
689
  else:
651
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
690
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
691
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
652
692
 
653
693
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
654
694
  @telemetry.send_api_usage_telemetry(
@@ -694,7 +734,7 @@ class BayesianRidge(BaseTransformer):
694
734
  drop_input_cols=self._drop_input_cols,
695
735
  expected_output_cols_type="float",
696
736
  )
697
- expected_output_cols = self._align_expected_output_names(
737
+ expected_output_cols, _ = self._align_expected_output(
698
738
  inference_method, dataset, expected_output_cols, output_cols_prefix
699
739
  )
700
740
 
@@ -760,7 +800,7 @@ class BayesianRidge(BaseTransformer):
760
800
  drop_input_cols=self._drop_input_cols,
761
801
  expected_output_cols_type="float",
762
802
  )
763
- expected_output_cols = self._align_expected_output_names(
803
+ expected_output_cols, _ = self._align_expected_output(
764
804
  inference_method, dataset, expected_output_cols, output_cols_prefix
765
805
  )
766
806
  elif isinstance(dataset, pd.DataFrame):
@@ -823,7 +863,7 @@ class BayesianRidge(BaseTransformer):
823
863
  drop_input_cols=self._drop_input_cols,
824
864
  expected_output_cols_type="float",
825
865
  )
826
- expected_output_cols = self._align_expected_output_names(
866
+ expected_output_cols, _ = self._align_expected_output(
827
867
  inference_method, dataset, expected_output_cols, output_cols_prefix
828
868
  )
829
869
 
@@ -888,7 +928,7 @@ class BayesianRidge(BaseTransformer):
888
928
  drop_input_cols = self._drop_input_cols,
889
929
  expected_output_cols_type="float",
890
930
  )
891
- expected_output_cols = self._align_expected_output_names(
931
+ expected_output_cols, _ = self._align_expected_output(
892
932
  inference_method, dataset, expected_output_cols, output_cols_prefix
893
933
  )
894
934
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -541,12 +538,23 @@ class ElasticNet(BaseTransformer):
541
538
  autogenerated=self._autogenerated,
542
539
  subproject=_SUBPROJECT,
543
540
  )
544
- output_result, fitted_estimator = model_trainer.train_fit_predict(
545
- drop_input_cols=self._drop_input_cols,
546
- expected_output_cols_list=(
547
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
548
- ),
541
+ expected_output_cols = (
542
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
549
543
  )
544
+ if isinstance(dataset, DataFrame):
545
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
546
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
547
+ )
548
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
549
+ drop_input_cols=self._drop_input_cols,
550
+ expected_output_cols_list=expected_output_cols,
551
+ example_output_pd_df=example_output_pd_df,
552
+ )
553
+ else:
554
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
555
+ drop_input_cols=self._drop_input_cols,
556
+ expected_output_cols_list=expected_output_cols,
557
+ )
550
558
  self._sklearn_object = fitted_estimator
551
559
  self._is_fitted = True
552
560
  return output_result
@@ -569,6 +577,7 @@ class ElasticNet(BaseTransformer):
569
577
  """
570
578
  self._infer_input_output_cols(dataset)
571
579
  super()._check_dataset_type(dataset)
580
+
572
581
  model_trainer = ModelTrainerBuilder.build_fit_transform(
573
582
  estimator=self._sklearn_object,
574
583
  dataset=dataset,
@@ -625,12 +634,41 @@ class ElasticNet(BaseTransformer):
625
634
 
626
635
  return rv
627
636
 
628
- def _align_expected_output_names(
629
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
630
- ) -> List[str]:
637
+ def _align_expected_output(
638
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
639
+ ) -> Tuple[List[str], pd.DataFrame]:
640
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
641
+ and output dataframe with 1 line.
642
+ If the method is fit_predict, run 2 lines of data.
643
+ """
631
644
  # in case the inferred output column names dimension is different
632
645
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
633
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
646
+
647
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
648
+ # so change the minimum of number of rows to 2
649
+ num_examples = 2
650
+ statement_params = telemetry.get_function_usage_statement_params(
651
+ project=_PROJECT,
652
+ subproject=_SUBPROJECT,
653
+ function_name=telemetry.get_statement_params_full_func_name(
654
+ inspect.currentframe(), ElasticNet.__class__.__name__
655
+ ),
656
+ api_calls=[Session.call],
657
+ custom_tags={"autogen": True} if self._autogenerated else None,
658
+ )
659
+ if output_cols_prefix == "fit_predict_":
660
+ if hasattr(self._sklearn_object, "n_clusters"):
661
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
662
+ num_examples = self._sklearn_object.n_clusters
663
+ elif hasattr(self._sklearn_object, "min_samples"):
664
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
665
+ num_examples = self._sklearn_object.min_samples
666
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
667
+ # LocalOutlierFactor expects n_neighbors <= n_samples
668
+ num_examples = self._sklearn_object.n_neighbors
669
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
670
+ else:
671
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
634
672
 
635
673
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
636
674
  # seen during the fit.
@@ -642,12 +680,14 @@ class ElasticNet(BaseTransformer):
642
680
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
643
681
  if self.sample_weight_col:
644
682
  output_df_columns_set -= set(self.sample_weight_col)
683
+
645
684
  # if the dimension of inferred output column names is correct; use it
646
685
  if len(expected_output_cols_list) == len(output_df_columns_set):
647
- return expected_output_cols_list
686
+ return expected_output_cols_list, output_df_pd
648
687
  # otherwise, use the sklearn estimator's output
649
688
  else:
650
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
689
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
690
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
651
691
 
652
692
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
653
693
  @telemetry.send_api_usage_telemetry(
@@ -693,7 +733,7 @@ class ElasticNet(BaseTransformer):
693
733
  drop_input_cols=self._drop_input_cols,
694
734
  expected_output_cols_type="float",
695
735
  )
696
- expected_output_cols = self._align_expected_output_names(
736
+ expected_output_cols, _ = self._align_expected_output(
697
737
  inference_method, dataset, expected_output_cols, output_cols_prefix
698
738
  )
699
739
 
@@ -759,7 +799,7 @@ class ElasticNet(BaseTransformer):
759
799
  drop_input_cols=self._drop_input_cols,
760
800
  expected_output_cols_type="float",
761
801
  )
762
- expected_output_cols = self._align_expected_output_names(
802
+ expected_output_cols, _ = self._align_expected_output(
763
803
  inference_method, dataset, expected_output_cols, output_cols_prefix
764
804
  )
765
805
  elif isinstance(dataset, pd.DataFrame):
@@ -822,7 +862,7 @@ class ElasticNet(BaseTransformer):
822
862
  drop_input_cols=self._drop_input_cols,
823
863
  expected_output_cols_type="float",
824
864
  )
825
- expected_output_cols = self._align_expected_output_names(
865
+ expected_output_cols, _ = self._align_expected_output(
826
866
  inference_method, dataset, expected_output_cols, output_cols_prefix
827
867
  )
828
868
 
@@ -887,7 +927,7 @@ class ElasticNet(BaseTransformer):
887
927
  drop_input_cols = self._drop_input_cols,
888
928
  expected_output_cols_type="float",
889
929
  )
890
- expected_output_cols = self._align_expected_output_names(
930
+ expected_output_cols, _ = self._align_expected_output(
891
931
  inference_method, dataset, expected_output_cols, output_cols_prefix
892
932
  )
893
933