snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -542,12 +539,23 @@ class LarsCV(BaseTransformer):
542
539
  autogenerated=self._autogenerated,
543
540
  subproject=_SUBPROJECT,
544
541
  )
545
- output_result, fitted_estimator = model_trainer.train_fit_predict(
546
- drop_input_cols=self._drop_input_cols,
547
- expected_output_cols_list=(
548
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
549
- ),
542
+ expected_output_cols = (
543
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
550
544
  )
545
+ if isinstance(dataset, DataFrame):
546
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
547
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
548
+ )
549
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
550
+ drop_input_cols=self._drop_input_cols,
551
+ expected_output_cols_list=expected_output_cols,
552
+ example_output_pd_df=example_output_pd_df,
553
+ )
554
+ else:
555
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
556
+ drop_input_cols=self._drop_input_cols,
557
+ expected_output_cols_list=expected_output_cols,
558
+ )
551
559
  self._sklearn_object = fitted_estimator
552
560
  self._is_fitted = True
553
561
  return output_result
@@ -570,6 +578,7 @@ class LarsCV(BaseTransformer):
570
578
  """
571
579
  self._infer_input_output_cols(dataset)
572
580
  super()._check_dataset_type(dataset)
581
+
573
582
  model_trainer = ModelTrainerBuilder.build_fit_transform(
574
583
  estimator=self._sklearn_object,
575
584
  dataset=dataset,
@@ -626,12 +635,41 @@ class LarsCV(BaseTransformer):
626
635
 
627
636
  return rv
628
637
 
629
- def _align_expected_output_names(
630
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
631
- ) -> List[str]:
638
+ def _align_expected_output(
639
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
640
+ ) -> Tuple[List[str], pd.DataFrame]:
641
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
642
+ and output dataframe with 1 line.
643
+ If the method is fit_predict, run 2 lines of data.
644
+ """
632
645
  # in case the inferred output column names dimension is different
633
646
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
634
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
647
+
648
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
649
+ # so change the minimum of number of rows to 2
650
+ num_examples = 2
651
+ statement_params = telemetry.get_function_usage_statement_params(
652
+ project=_PROJECT,
653
+ subproject=_SUBPROJECT,
654
+ function_name=telemetry.get_statement_params_full_func_name(
655
+ inspect.currentframe(), LarsCV.__class__.__name__
656
+ ),
657
+ api_calls=[Session.call],
658
+ custom_tags={"autogen": True} if self._autogenerated else None,
659
+ )
660
+ if output_cols_prefix == "fit_predict_":
661
+ if hasattr(self._sklearn_object, "n_clusters"):
662
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
663
+ num_examples = self._sklearn_object.n_clusters
664
+ elif hasattr(self._sklearn_object, "min_samples"):
665
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
666
+ num_examples = self._sklearn_object.min_samples
667
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
668
+ # LocalOutlierFactor expects n_neighbors <= n_samples
669
+ num_examples = self._sklearn_object.n_neighbors
670
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
671
+ else:
672
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
635
673
 
636
674
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
637
675
  # seen during the fit.
@@ -643,12 +681,14 @@ class LarsCV(BaseTransformer):
643
681
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
644
682
  if self.sample_weight_col:
645
683
  output_df_columns_set -= set(self.sample_weight_col)
684
+
646
685
  # if the dimension of inferred output column names is correct; use it
647
686
  if len(expected_output_cols_list) == len(output_df_columns_set):
648
- return expected_output_cols_list
687
+ return expected_output_cols_list, output_df_pd
649
688
  # otherwise, use the sklearn estimator's output
650
689
  else:
651
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
690
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
691
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
652
692
 
653
693
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
654
694
  @telemetry.send_api_usage_telemetry(
@@ -694,7 +734,7 @@ class LarsCV(BaseTransformer):
694
734
  drop_input_cols=self._drop_input_cols,
695
735
  expected_output_cols_type="float",
696
736
  )
697
- expected_output_cols = self._align_expected_output_names(
737
+ expected_output_cols, _ = self._align_expected_output(
698
738
  inference_method, dataset, expected_output_cols, output_cols_prefix
699
739
  )
700
740
 
@@ -760,7 +800,7 @@ class LarsCV(BaseTransformer):
760
800
  drop_input_cols=self._drop_input_cols,
761
801
  expected_output_cols_type="float",
762
802
  )
763
- expected_output_cols = self._align_expected_output_names(
803
+ expected_output_cols, _ = self._align_expected_output(
764
804
  inference_method, dataset, expected_output_cols, output_cols_prefix
765
805
  )
766
806
  elif isinstance(dataset, pd.DataFrame):
@@ -823,7 +863,7 @@ class LarsCV(BaseTransformer):
823
863
  drop_input_cols=self._drop_input_cols,
824
864
  expected_output_cols_type="float",
825
865
  )
826
- expected_output_cols = self._align_expected_output_names(
866
+ expected_output_cols, _ = self._align_expected_output(
827
867
  inference_method, dataset, expected_output_cols, output_cols_prefix
828
868
  )
829
869
 
@@ -888,7 +928,7 @@ class LarsCV(BaseTransformer):
888
928
  drop_input_cols = self._drop_input_cols,
889
929
  expected_output_cols_type="float",
890
930
  )
891
- expected_output_cols = self._align_expected_output_names(
931
+ expected_output_cols, _ = self._align_expected_output(
892
932
  inference_method, dataset, expected_output_cols, output_cols_prefix
893
933
  )
894
934
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -535,12 +532,23 @@ class Lasso(BaseTransformer):
535
532
  autogenerated=self._autogenerated,
536
533
  subproject=_SUBPROJECT,
537
534
  )
538
- output_result, fitted_estimator = model_trainer.train_fit_predict(
539
- drop_input_cols=self._drop_input_cols,
540
- expected_output_cols_list=(
541
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
542
- ),
535
+ expected_output_cols = (
536
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
543
537
  )
538
+ if isinstance(dataset, DataFrame):
539
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
540
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
541
+ )
542
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
543
+ drop_input_cols=self._drop_input_cols,
544
+ expected_output_cols_list=expected_output_cols,
545
+ example_output_pd_df=example_output_pd_df,
546
+ )
547
+ else:
548
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
549
+ drop_input_cols=self._drop_input_cols,
550
+ expected_output_cols_list=expected_output_cols,
551
+ )
544
552
  self._sklearn_object = fitted_estimator
545
553
  self._is_fitted = True
546
554
  return output_result
@@ -563,6 +571,7 @@ class Lasso(BaseTransformer):
563
571
  """
564
572
  self._infer_input_output_cols(dataset)
565
573
  super()._check_dataset_type(dataset)
574
+
566
575
  model_trainer = ModelTrainerBuilder.build_fit_transform(
567
576
  estimator=self._sklearn_object,
568
577
  dataset=dataset,
@@ -619,12 +628,41 @@ class Lasso(BaseTransformer):
619
628
 
620
629
  return rv
621
630
 
622
- def _align_expected_output_names(
623
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
624
- ) -> List[str]:
631
+ def _align_expected_output(
632
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
633
+ ) -> Tuple[List[str], pd.DataFrame]:
634
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
635
+ and output dataframe with 1 line.
636
+ If the method is fit_predict, run 2 lines of data.
637
+ """
625
638
  # in case the inferred output column names dimension is different
626
639
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
627
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
640
+
641
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
642
+ # so change the minimum of number of rows to 2
643
+ num_examples = 2
644
+ statement_params = telemetry.get_function_usage_statement_params(
645
+ project=_PROJECT,
646
+ subproject=_SUBPROJECT,
647
+ function_name=telemetry.get_statement_params_full_func_name(
648
+ inspect.currentframe(), Lasso.__class__.__name__
649
+ ),
650
+ api_calls=[Session.call],
651
+ custom_tags={"autogen": True} if self._autogenerated else None,
652
+ )
653
+ if output_cols_prefix == "fit_predict_":
654
+ if hasattr(self._sklearn_object, "n_clusters"):
655
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
656
+ num_examples = self._sklearn_object.n_clusters
657
+ elif hasattr(self._sklearn_object, "min_samples"):
658
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
659
+ num_examples = self._sklearn_object.min_samples
660
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
661
+ # LocalOutlierFactor expects n_neighbors <= n_samples
662
+ num_examples = self._sklearn_object.n_neighbors
663
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
664
+ else:
665
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
628
666
 
629
667
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
630
668
  # seen during the fit.
@@ -636,12 +674,14 @@ class Lasso(BaseTransformer):
636
674
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
637
675
  if self.sample_weight_col:
638
676
  output_df_columns_set -= set(self.sample_weight_col)
677
+
639
678
  # if the dimension of inferred output column names is correct; use it
640
679
  if len(expected_output_cols_list) == len(output_df_columns_set):
641
- return expected_output_cols_list
680
+ return expected_output_cols_list, output_df_pd
642
681
  # otherwise, use the sklearn estimator's output
643
682
  else:
644
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
683
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
684
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
645
685
 
646
686
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
647
687
  @telemetry.send_api_usage_telemetry(
@@ -687,7 +727,7 @@ class Lasso(BaseTransformer):
687
727
  drop_input_cols=self._drop_input_cols,
688
728
  expected_output_cols_type="float",
689
729
  )
690
- expected_output_cols = self._align_expected_output_names(
730
+ expected_output_cols, _ = self._align_expected_output(
691
731
  inference_method, dataset, expected_output_cols, output_cols_prefix
692
732
  )
693
733
 
@@ -753,7 +793,7 @@ class Lasso(BaseTransformer):
753
793
  drop_input_cols=self._drop_input_cols,
754
794
  expected_output_cols_type="float",
755
795
  )
756
- expected_output_cols = self._align_expected_output_names(
796
+ expected_output_cols, _ = self._align_expected_output(
757
797
  inference_method, dataset, expected_output_cols, output_cols_prefix
758
798
  )
759
799
  elif isinstance(dataset, pd.DataFrame):
@@ -816,7 +856,7 @@ class Lasso(BaseTransformer):
816
856
  drop_input_cols=self._drop_input_cols,
817
857
  expected_output_cols_type="float",
818
858
  )
819
- expected_output_cols = self._align_expected_output_names(
859
+ expected_output_cols, _ = self._align_expected_output(
820
860
  inference_method, dataset, expected_output_cols, output_cols_prefix
821
861
  )
822
862
 
@@ -881,7 +921,7 @@ class Lasso(BaseTransformer):
881
921
  drop_input_cols = self._drop_input_cols,
882
922
  expected_output_cols_type="float",
883
923
  )
884
- expected_output_cols = self._align_expected_output_names(
924
+ expected_output_cols, _ = self._align_expected_output(
885
925
  inference_method, dataset, expected_output_cols, output_cols_prefix
886
926
  )
887
927
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -563,12 +560,23 @@ class LassoCV(BaseTransformer):
563
560
  autogenerated=self._autogenerated,
564
561
  subproject=_SUBPROJECT,
565
562
  )
566
- output_result, fitted_estimator = model_trainer.train_fit_predict(
567
- drop_input_cols=self._drop_input_cols,
568
- expected_output_cols_list=(
569
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
570
- ),
563
+ expected_output_cols = (
564
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
571
565
  )
566
+ if isinstance(dataset, DataFrame):
567
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
568
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
569
+ )
570
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
571
+ drop_input_cols=self._drop_input_cols,
572
+ expected_output_cols_list=expected_output_cols,
573
+ example_output_pd_df=example_output_pd_df,
574
+ )
575
+ else:
576
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
577
+ drop_input_cols=self._drop_input_cols,
578
+ expected_output_cols_list=expected_output_cols,
579
+ )
572
580
  self._sklearn_object = fitted_estimator
573
581
  self._is_fitted = True
574
582
  return output_result
@@ -591,6 +599,7 @@ class LassoCV(BaseTransformer):
591
599
  """
592
600
  self._infer_input_output_cols(dataset)
593
601
  super()._check_dataset_type(dataset)
602
+
594
603
  model_trainer = ModelTrainerBuilder.build_fit_transform(
595
604
  estimator=self._sklearn_object,
596
605
  dataset=dataset,
@@ -647,12 +656,41 @@ class LassoCV(BaseTransformer):
647
656
 
648
657
  return rv
649
658
 
650
- def _align_expected_output_names(
651
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
652
- ) -> List[str]:
659
+ def _align_expected_output(
660
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
661
+ ) -> Tuple[List[str], pd.DataFrame]:
662
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
663
+ and output dataframe with 1 line.
664
+ If the method is fit_predict, run 2 lines of data.
665
+ """
653
666
  # in case the inferred output column names dimension is different
654
667
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
655
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
668
+
669
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
670
+ # so change the minimum of number of rows to 2
671
+ num_examples = 2
672
+ statement_params = telemetry.get_function_usage_statement_params(
673
+ project=_PROJECT,
674
+ subproject=_SUBPROJECT,
675
+ function_name=telemetry.get_statement_params_full_func_name(
676
+ inspect.currentframe(), LassoCV.__class__.__name__
677
+ ),
678
+ api_calls=[Session.call],
679
+ custom_tags={"autogen": True} if self._autogenerated else None,
680
+ )
681
+ if output_cols_prefix == "fit_predict_":
682
+ if hasattr(self._sklearn_object, "n_clusters"):
683
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
684
+ num_examples = self._sklearn_object.n_clusters
685
+ elif hasattr(self._sklearn_object, "min_samples"):
686
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
687
+ num_examples = self._sklearn_object.min_samples
688
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
689
+ # LocalOutlierFactor expects n_neighbors <= n_samples
690
+ num_examples = self._sklearn_object.n_neighbors
691
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
692
+ else:
693
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
656
694
 
657
695
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
658
696
  # seen during the fit.
@@ -664,12 +702,14 @@ class LassoCV(BaseTransformer):
664
702
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
665
703
  if self.sample_weight_col:
666
704
  output_df_columns_set -= set(self.sample_weight_col)
705
+
667
706
  # if the dimension of inferred output column names is correct; use it
668
707
  if len(expected_output_cols_list) == len(output_df_columns_set):
669
- return expected_output_cols_list
708
+ return expected_output_cols_list, output_df_pd
670
709
  # otherwise, use the sklearn estimator's output
671
710
  else:
672
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
711
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
712
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
673
713
 
674
714
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
675
715
  @telemetry.send_api_usage_telemetry(
@@ -715,7 +755,7 @@ class LassoCV(BaseTransformer):
715
755
  drop_input_cols=self._drop_input_cols,
716
756
  expected_output_cols_type="float",
717
757
  )
718
- expected_output_cols = self._align_expected_output_names(
758
+ expected_output_cols, _ = self._align_expected_output(
719
759
  inference_method, dataset, expected_output_cols, output_cols_prefix
720
760
  )
721
761
 
@@ -781,7 +821,7 @@ class LassoCV(BaseTransformer):
781
821
  drop_input_cols=self._drop_input_cols,
782
822
  expected_output_cols_type="float",
783
823
  )
784
- expected_output_cols = self._align_expected_output_names(
824
+ expected_output_cols, _ = self._align_expected_output(
785
825
  inference_method, dataset, expected_output_cols, output_cols_prefix
786
826
  )
787
827
  elif isinstance(dataset, pd.DataFrame):
@@ -844,7 +884,7 @@ class LassoCV(BaseTransformer):
844
884
  drop_input_cols=self._drop_input_cols,
845
885
  expected_output_cols_type="float",
846
886
  )
847
- expected_output_cols = self._align_expected_output_names(
887
+ expected_output_cols, _ = self._align_expected_output(
848
888
  inference_method, dataset, expected_output_cols, output_cols_prefix
849
889
  )
850
890
 
@@ -909,7 +949,7 @@ class LassoCV(BaseTransformer):
909
949
  drop_input_cols = self._drop_input_cols,
910
950
  expected_output_cols_type="float",
911
951
  )
912
- expected_output_cols = self._align_expected_output_names(
952
+ expected_output_cols, _ = self._align_expected_output(
913
953
  inference_method, dataset, expected_output_cols, output_cols_prefix
914
954
  )
915
955
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -555,12 +552,23 @@ class LassoLars(BaseTransformer):
555
552
  autogenerated=self._autogenerated,
556
553
  subproject=_SUBPROJECT,
557
554
  )
558
- output_result, fitted_estimator = model_trainer.train_fit_predict(
559
- drop_input_cols=self._drop_input_cols,
560
- expected_output_cols_list=(
561
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
562
- ),
555
+ expected_output_cols = (
556
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
563
557
  )
558
+ if isinstance(dataset, DataFrame):
559
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
560
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
561
+ )
562
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
563
+ drop_input_cols=self._drop_input_cols,
564
+ expected_output_cols_list=expected_output_cols,
565
+ example_output_pd_df=example_output_pd_df,
566
+ )
567
+ else:
568
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
569
+ drop_input_cols=self._drop_input_cols,
570
+ expected_output_cols_list=expected_output_cols,
571
+ )
564
572
  self._sklearn_object = fitted_estimator
565
573
  self._is_fitted = True
566
574
  return output_result
@@ -583,6 +591,7 @@ class LassoLars(BaseTransformer):
583
591
  """
584
592
  self._infer_input_output_cols(dataset)
585
593
  super()._check_dataset_type(dataset)
594
+
586
595
  model_trainer = ModelTrainerBuilder.build_fit_transform(
587
596
  estimator=self._sklearn_object,
588
597
  dataset=dataset,
@@ -639,12 +648,41 @@ class LassoLars(BaseTransformer):
639
648
 
640
649
  return rv
641
650
 
642
- def _align_expected_output_names(
643
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
644
- ) -> List[str]:
651
+ def _align_expected_output(
652
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
653
+ ) -> Tuple[List[str], pd.DataFrame]:
654
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
655
+ and output dataframe with 1 line.
656
+ If the method is fit_predict, run 2 lines of data.
657
+ """
645
658
  # in case the inferred output column names dimension is different
646
659
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
647
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
660
+
661
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
662
+ # so change the minimum of number of rows to 2
663
+ num_examples = 2
664
+ statement_params = telemetry.get_function_usage_statement_params(
665
+ project=_PROJECT,
666
+ subproject=_SUBPROJECT,
667
+ function_name=telemetry.get_statement_params_full_func_name(
668
+ inspect.currentframe(), LassoLars.__class__.__name__
669
+ ),
670
+ api_calls=[Session.call],
671
+ custom_tags={"autogen": True} if self._autogenerated else None,
672
+ )
673
+ if output_cols_prefix == "fit_predict_":
674
+ if hasattr(self._sklearn_object, "n_clusters"):
675
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
676
+ num_examples = self._sklearn_object.n_clusters
677
+ elif hasattr(self._sklearn_object, "min_samples"):
678
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
679
+ num_examples = self._sklearn_object.min_samples
680
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
681
+ # LocalOutlierFactor expects n_neighbors <= n_samples
682
+ num_examples = self._sklearn_object.n_neighbors
683
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
684
+ else:
685
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
648
686
 
649
687
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
650
688
  # seen during the fit.
@@ -656,12 +694,14 @@ class LassoLars(BaseTransformer):
656
694
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
657
695
  if self.sample_weight_col:
658
696
  output_df_columns_set -= set(self.sample_weight_col)
697
+
659
698
  # if the dimension of inferred output column names is correct; use it
660
699
  if len(expected_output_cols_list) == len(output_df_columns_set):
661
- return expected_output_cols_list
700
+ return expected_output_cols_list, output_df_pd
662
701
  # otherwise, use the sklearn estimator's output
663
702
  else:
664
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
703
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
704
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
665
705
 
666
706
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
667
707
  @telemetry.send_api_usage_telemetry(
@@ -707,7 +747,7 @@ class LassoLars(BaseTransformer):
707
747
  drop_input_cols=self._drop_input_cols,
708
748
  expected_output_cols_type="float",
709
749
  )
710
- expected_output_cols = self._align_expected_output_names(
750
+ expected_output_cols, _ = self._align_expected_output(
711
751
  inference_method, dataset, expected_output_cols, output_cols_prefix
712
752
  )
713
753
 
@@ -773,7 +813,7 @@ class LassoLars(BaseTransformer):
773
813
  drop_input_cols=self._drop_input_cols,
774
814
  expected_output_cols_type="float",
775
815
  )
776
- expected_output_cols = self._align_expected_output_names(
816
+ expected_output_cols, _ = self._align_expected_output(
777
817
  inference_method, dataset, expected_output_cols, output_cols_prefix
778
818
  )
779
819
  elif isinstance(dataset, pd.DataFrame):
@@ -836,7 +876,7 @@ class LassoLars(BaseTransformer):
836
876
  drop_input_cols=self._drop_input_cols,
837
877
  expected_output_cols_type="float",
838
878
  )
839
- expected_output_cols = self._align_expected_output_names(
879
+ expected_output_cols, _ = self._align_expected_output(
840
880
  inference_method, dataset, expected_output_cols, output_cols_prefix
841
881
  )
842
882
 
@@ -901,7 +941,7 @@ class LassoLars(BaseTransformer):
901
941
  drop_input_cols = self._drop_input_cols,
902
942
  expected_output_cols_type="float",
903
943
  )
904
- expected_output_cols = self._align_expected_output_names(
944
+ expected_output_cols, _ = self._align_expected_output(
905
945
  inference_method, dataset, expected_output_cols, output_cols_prefix
906
946
  )
907
947