snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Files changed (284)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -577,12 +574,23 @@ class ElasticNetCV(BaseTransformer):
577
574
  autogenerated=self._autogenerated,
578
575
  subproject=_SUBPROJECT,
579
576
  )
580
- output_result, fitted_estimator = model_trainer.train_fit_predict(
581
- drop_input_cols=self._drop_input_cols,
582
- expected_output_cols_list=(
583
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
584
- ),
577
+ expected_output_cols = (
578
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
585
579
  )
580
+ if isinstance(dataset, DataFrame):
581
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
582
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
583
+ )
584
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
585
+ drop_input_cols=self._drop_input_cols,
586
+ expected_output_cols_list=expected_output_cols,
587
+ example_output_pd_df=example_output_pd_df,
588
+ )
589
+ else:
590
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
591
+ drop_input_cols=self._drop_input_cols,
592
+ expected_output_cols_list=expected_output_cols,
593
+ )
586
594
  self._sklearn_object = fitted_estimator
587
595
  self._is_fitted = True
588
596
  return output_result
@@ -605,6 +613,7 @@ class ElasticNetCV(BaseTransformer):
605
613
  """
606
614
  self._infer_input_output_cols(dataset)
607
615
  super()._check_dataset_type(dataset)
616
+
608
617
  model_trainer = ModelTrainerBuilder.build_fit_transform(
609
618
  estimator=self._sklearn_object,
610
619
  dataset=dataset,
@@ -661,12 +670,41 @@ class ElasticNetCV(BaseTransformer):
661
670
 
662
671
  return rv
663
672
 
664
- def _align_expected_output_names(
665
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
666
- ) -> List[str]:
673
+ def _align_expected_output(
674
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
675
+ ) -> Tuple[List[str], pd.DataFrame]:
676
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
677
+ and output dataframe with 1 line.
678
+ If the method is fit_predict, run 2 lines of data.
679
+ """
667
680
  # in case the inferred output column names dimension is different
668
681
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
669
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
682
+
683
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
684
+ # so change the minimum of number of rows to 2
685
+ num_examples = 2
686
+ statement_params = telemetry.get_function_usage_statement_params(
687
+ project=_PROJECT,
688
+ subproject=_SUBPROJECT,
689
+ function_name=telemetry.get_statement_params_full_func_name(
690
+ inspect.currentframe(), ElasticNetCV.__class__.__name__
691
+ ),
692
+ api_calls=[Session.call],
693
+ custom_tags={"autogen": True} if self._autogenerated else None,
694
+ )
695
+ if output_cols_prefix == "fit_predict_":
696
+ if hasattr(self._sklearn_object, "n_clusters"):
697
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
698
+ num_examples = self._sklearn_object.n_clusters
699
+ elif hasattr(self._sklearn_object, "min_samples"):
700
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
701
+ num_examples = self._sklearn_object.min_samples
702
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
703
+ # LocalOutlierFactor expects n_neighbors <= n_samples
704
+ num_examples = self._sklearn_object.n_neighbors
705
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
706
+ else:
707
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
670
708
 
671
709
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
672
710
  # seen during the fit.
@@ -678,12 +716,14 @@ class ElasticNetCV(BaseTransformer):
678
716
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
679
717
  if self.sample_weight_col:
680
718
  output_df_columns_set -= set(self.sample_weight_col)
719
+
681
720
  # if the dimension of inferred output column names is correct; use it
682
721
  if len(expected_output_cols_list) == len(output_df_columns_set):
683
- return expected_output_cols_list
722
+ return expected_output_cols_list, output_df_pd
684
723
  # otherwise, use the sklearn estimator's output
685
724
  else:
686
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
725
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
726
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
687
727
 
688
728
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
689
729
  @telemetry.send_api_usage_telemetry(
@@ -729,7 +769,7 @@ class ElasticNetCV(BaseTransformer):
729
769
  drop_input_cols=self._drop_input_cols,
730
770
  expected_output_cols_type="float",
731
771
  )
732
- expected_output_cols = self._align_expected_output_names(
772
+ expected_output_cols, _ = self._align_expected_output(
733
773
  inference_method, dataset, expected_output_cols, output_cols_prefix
734
774
  )
735
775
 
@@ -795,7 +835,7 @@ class ElasticNetCV(BaseTransformer):
795
835
  drop_input_cols=self._drop_input_cols,
796
836
  expected_output_cols_type="float",
797
837
  )
798
- expected_output_cols = self._align_expected_output_names(
838
+ expected_output_cols, _ = self._align_expected_output(
799
839
  inference_method, dataset, expected_output_cols, output_cols_prefix
800
840
  )
801
841
  elif isinstance(dataset, pd.DataFrame):
@@ -858,7 +898,7 @@ class ElasticNetCV(BaseTransformer):
858
898
  drop_input_cols=self._drop_input_cols,
859
899
  expected_output_cols_type="float",
860
900
  )
861
- expected_output_cols = self._align_expected_output_names(
901
+ expected_output_cols, _ = self._align_expected_output(
862
902
  inference_method, dataset, expected_output_cols, output_cols_prefix
863
903
  )
864
904
 
@@ -923,7 +963,7 @@ class ElasticNetCV(BaseTransformer):
923
963
  drop_input_cols = self._drop_input_cols,
924
964
  expected_output_cols_type="float",
925
965
  )
926
- expected_output_cols = self._align_expected_output_names(
966
+ expected_output_cols, _ = self._align_expected_output(
927
967
  inference_method, dataset, expected_output_cols, output_cols_prefix
928
968
  )
929
969
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -522,12 +519,23 @@ class GammaRegressor(BaseTransformer):
522
519
  autogenerated=self._autogenerated,
523
520
  subproject=_SUBPROJECT,
524
521
  )
525
- output_result, fitted_estimator = model_trainer.train_fit_predict(
526
- drop_input_cols=self._drop_input_cols,
527
- expected_output_cols_list=(
528
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
529
- ),
522
+ expected_output_cols = (
523
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
530
524
  )
525
+ if isinstance(dataset, DataFrame):
526
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
527
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
528
+ )
529
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
530
+ drop_input_cols=self._drop_input_cols,
531
+ expected_output_cols_list=expected_output_cols,
532
+ example_output_pd_df=example_output_pd_df,
533
+ )
534
+ else:
535
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
536
+ drop_input_cols=self._drop_input_cols,
537
+ expected_output_cols_list=expected_output_cols,
538
+ )
531
539
  self._sklearn_object = fitted_estimator
532
540
  self._is_fitted = True
533
541
  return output_result
@@ -550,6 +558,7 @@ class GammaRegressor(BaseTransformer):
550
558
  """
551
559
  self._infer_input_output_cols(dataset)
552
560
  super()._check_dataset_type(dataset)
561
+
553
562
  model_trainer = ModelTrainerBuilder.build_fit_transform(
554
563
  estimator=self._sklearn_object,
555
564
  dataset=dataset,
@@ -606,12 +615,41 @@ class GammaRegressor(BaseTransformer):
606
615
 
607
616
  return rv
608
617
 
609
- def _align_expected_output_names(
610
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
611
- ) -> List[str]:
618
+ def _align_expected_output(
619
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
620
+ ) -> Tuple[List[str], pd.DataFrame]:
621
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
622
+ and output dataframe with 1 line.
623
+ If the method is fit_predict, run 2 lines of data.
624
+ """
612
625
  # in case the inferred output column names dimension is different
613
626
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
614
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
627
+
628
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
629
+ # so change the minimum of number of rows to 2
630
+ num_examples = 2
631
+ statement_params = telemetry.get_function_usage_statement_params(
632
+ project=_PROJECT,
633
+ subproject=_SUBPROJECT,
634
+ function_name=telemetry.get_statement_params_full_func_name(
635
+ inspect.currentframe(), GammaRegressor.__class__.__name__
636
+ ),
637
+ api_calls=[Session.call],
638
+ custom_tags={"autogen": True} if self._autogenerated else None,
639
+ )
640
+ if output_cols_prefix == "fit_predict_":
641
+ if hasattr(self._sklearn_object, "n_clusters"):
642
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
643
+ num_examples = self._sklearn_object.n_clusters
644
+ elif hasattr(self._sklearn_object, "min_samples"):
645
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
646
+ num_examples = self._sklearn_object.min_samples
647
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
648
+ # LocalOutlierFactor expects n_neighbors <= n_samples
649
+ num_examples = self._sklearn_object.n_neighbors
650
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
651
+ else:
652
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
615
653
 
616
654
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
617
655
  # seen during the fit.
@@ -623,12 +661,14 @@ class GammaRegressor(BaseTransformer):
623
661
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
624
662
  if self.sample_weight_col:
625
663
  output_df_columns_set -= set(self.sample_weight_col)
664
+
626
665
  # if the dimension of inferred output column names is correct; use it
627
666
  if len(expected_output_cols_list) == len(output_df_columns_set):
628
- return expected_output_cols_list
667
+ return expected_output_cols_list, output_df_pd
629
668
  # otherwise, use the sklearn estimator's output
630
669
  else:
631
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
670
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
671
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
632
672
 
633
673
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
634
674
  @telemetry.send_api_usage_telemetry(
@@ -674,7 +714,7 @@ class GammaRegressor(BaseTransformer):
674
714
  drop_input_cols=self._drop_input_cols,
675
715
  expected_output_cols_type="float",
676
716
  )
677
- expected_output_cols = self._align_expected_output_names(
717
+ expected_output_cols, _ = self._align_expected_output(
678
718
  inference_method, dataset, expected_output_cols, output_cols_prefix
679
719
  )
680
720
 
@@ -740,7 +780,7 @@ class GammaRegressor(BaseTransformer):
740
780
  drop_input_cols=self._drop_input_cols,
741
781
  expected_output_cols_type="float",
742
782
  )
743
- expected_output_cols = self._align_expected_output_names(
783
+ expected_output_cols, _ = self._align_expected_output(
744
784
  inference_method, dataset, expected_output_cols, output_cols_prefix
745
785
  )
746
786
  elif isinstance(dataset, pd.DataFrame):
@@ -803,7 +843,7 @@ class GammaRegressor(BaseTransformer):
803
843
  drop_input_cols=self._drop_input_cols,
804
844
  expected_output_cols_type="float",
805
845
  )
806
- expected_output_cols = self._align_expected_output_names(
846
+ expected_output_cols, _ = self._align_expected_output(
807
847
  inference_method, dataset, expected_output_cols, output_cols_prefix
808
848
  )
809
849
 
@@ -868,7 +908,7 @@ class GammaRegressor(BaseTransformer):
868
908
  drop_input_cols = self._drop_input_cols,
869
909
  expected_output_cols_type="float",
870
910
  )
871
- expected_output_cols = self._align_expected_output_names(
911
+ expected_output_cols, _ = self._align_expected_output(
872
912
  inference_method, dataset, expected_output_cols, output_cols_prefix
873
913
  )
874
914
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -505,12 +502,23 @@ class HuberRegressor(BaseTransformer):
505
502
  autogenerated=self._autogenerated,
506
503
  subproject=_SUBPROJECT,
507
504
  )
508
- output_result, fitted_estimator = model_trainer.train_fit_predict(
509
- drop_input_cols=self._drop_input_cols,
510
- expected_output_cols_list=(
511
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
512
- ),
505
+ expected_output_cols = (
506
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
513
507
  )
508
+ if isinstance(dataset, DataFrame):
509
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
510
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
511
+ )
512
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
513
+ drop_input_cols=self._drop_input_cols,
514
+ expected_output_cols_list=expected_output_cols,
515
+ example_output_pd_df=example_output_pd_df,
516
+ )
517
+ else:
518
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
519
+ drop_input_cols=self._drop_input_cols,
520
+ expected_output_cols_list=expected_output_cols,
521
+ )
514
522
  self._sklearn_object = fitted_estimator
515
523
  self._is_fitted = True
516
524
  return output_result
@@ -533,6 +541,7 @@ class HuberRegressor(BaseTransformer):
533
541
  """
534
542
  self._infer_input_output_cols(dataset)
535
543
  super()._check_dataset_type(dataset)
544
+
536
545
  model_trainer = ModelTrainerBuilder.build_fit_transform(
537
546
  estimator=self._sklearn_object,
538
547
  dataset=dataset,
@@ -589,12 +598,41 @@ class HuberRegressor(BaseTransformer):
589
598
 
590
599
  return rv
591
600
 
592
- def _align_expected_output_names(
593
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
594
- ) -> List[str]:
601
+ def _align_expected_output(
602
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
603
+ ) -> Tuple[List[str], pd.DataFrame]:
604
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
605
+ and output dataframe with 1 line.
606
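+ If the method is fit_predict, run 2 lines of data by default, or as many lines as the estimator's n_clusters, min_samples or n_neighbors attribute requests.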
607
+ """
595
608
  # in case the inferred output column names dimension is different
596
609
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
597
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
610
+
611
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
612
+ # so the number of sample rows defaults to 2
613
+ num_examples = 2
614
+ statement_params = telemetry.get_function_usage_statement_params(
615
+ project=_PROJECT,
616
+ subproject=_SUBPROJECT,
617
+ function_name=telemetry.get_statement_params_full_func_name(
618
+ inspect.currentframe(), HuberRegressor.__class__.__name__
619
+ ),
620
+ api_calls=[Session.call],
621
+ custom_tags={"autogen": True} if self._autogenerated else None,
622
+ )
623
+ if output_cols_prefix == "fit_predict_":
624
+ if hasattr(self._sklearn_object, "n_clusters"):
625
+ # cluster estimators such as BisectingKMeans require the number of examples to be >= n_clusters
626
+ num_examples = self._sklearn_object.n_clusters
627
+ elif hasattr(self._sklearn_object, "min_samples"):
628
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
629
+ num_examples = self._sklearn_object.min_samples
630
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
631
+ # LocalOutlierFactor expects n_neighbors <= n_samples
632
+ num_examples = self._sklearn_object.n_neighbors
633
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
634
+ else:
635
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
598
636
 
599
637
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
600
638
  # seen during the fit.
@@ -606,12 +644,14 @@ class HuberRegressor(BaseTransformer):
606
644
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
607
645
  if self.sample_weight_col:
608
646
  output_df_columns_set -= set(self.sample_weight_col)
647
+
609
648
  # if the dimension of inferred output column names is correct; use it
610
649
  if len(expected_output_cols_list) == len(output_df_columns_set):
611
- return expected_output_cols_list
650
+ return expected_output_cols_list, output_df_pd
612
651
  # otherwise, use the sklearn estimator's output
613
652
  else:
614
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
653
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
654
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
615
655
 
616
656
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
617
657
  @telemetry.send_api_usage_telemetry(
@@ -657,7 +697,7 @@ class HuberRegressor(BaseTransformer):
657
697
  drop_input_cols=self._drop_input_cols,
658
698
  expected_output_cols_type="float",
659
699
  )
660
- expected_output_cols = self._align_expected_output_names(
700
+ expected_output_cols, _ = self._align_expected_output(
661
701
  inference_method, dataset, expected_output_cols, output_cols_prefix
662
702
  )
663
703
 
@@ -723,7 +763,7 @@ class HuberRegressor(BaseTransformer):
723
763
  drop_input_cols=self._drop_input_cols,
724
764
  expected_output_cols_type="float",
725
765
  )
726
- expected_output_cols = self._align_expected_output_names(
766
+ expected_output_cols, _ = self._align_expected_output(
727
767
  inference_method, dataset, expected_output_cols, output_cols_prefix
728
768
  )
729
769
  elif isinstance(dataset, pd.DataFrame):
@@ -786,7 +826,7 @@ class HuberRegressor(BaseTransformer):
786
826
  drop_input_cols=self._drop_input_cols,
787
827
  expected_output_cols_type="float",
788
828
  )
789
- expected_output_cols = self._align_expected_output_names(
829
+ expected_output_cols, _ = self._align_expected_output(
790
830
  inference_method, dataset, expected_output_cols, output_cols_prefix
791
831
  )
792
832
 
@@ -851,7 +891,7 @@ class HuberRegressor(BaseTransformer):
851
891
  drop_input_cols = self._drop_input_cols,
852
892
  expected_output_cols_type="float",
853
893
  )
854
- expected_output_cols = self._align_expected_output_names(
894
+ expected_output_cols, _ = self._align_expected_output(
855
895
  inference_method, dataset, expected_output_cols, output_cols_prefix
856
896
  )
857
897
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -534,12 +531,23 @@ class Lars(BaseTransformer):
534
531
  autogenerated=self._autogenerated,
535
532
  subproject=_SUBPROJECT,
536
533
  )
537
- output_result, fitted_estimator = model_trainer.train_fit_predict(
538
- drop_input_cols=self._drop_input_cols,
539
- expected_output_cols_list=(
540
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
541
- ),
534
+ expected_output_cols = (
535
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
542
536
  )
537
+ if isinstance(dataset, DataFrame):
538
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
539
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
540
+ )
541
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
542
+ drop_input_cols=self._drop_input_cols,
543
+ expected_output_cols_list=expected_output_cols,
544
+ example_output_pd_df=example_output_pd_df,
545
+ )
546
+ else:
547
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
548
+ drop_input_cols=self._drop_input_cols,
549
+ expected_output_cols_list=expected_output_cols,
550
+ )
543
551
  self._sklearn_object = fitted_estimator
544
552
  self._is_fitted = True
545
553
  return output_result
@@ -562,6 +570,7 @@ class Lars(BaseTransformer):
562
570
  """
563
571
  self._infer_input_output_cols(dataset)
564
572
  super()._check_dataset_type(dataset)
573
+
565
574
  model_trainer = ModelTrainerBuilder.build_fit_transform(
566
575
  estimator=self._sklearn_object,
567
576
  dataset=dataset,
@@ -618,12 +627,41 @@ class Lars(BaseTransformer):
618
627
 
619
628
  return rv
620
629
 
621
- def _align_expected_output_names(
622
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
623
- ) -> List[str]:
630
+ def _align_expected_output(
631
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
632
+ ) -> Tuple[List[str], pd.DataFrame]:
633
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
634
+ and output dataframe with 1 line.
635
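+ If the method is fit_predict, run 2 lines of data by default, or as many lines as the estimator's n_clusters, min_samples or n_neighbors attribute requests.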
636
+ """
624
637
  # in case the inferred output column names dimension is different
625
638
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
626
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
639
+
640
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
641
+ # so the number of sample rows defaults to 2
642
+ num_examples = 2
643
+ statement_params = telemetry.get_function_usage_statement_params(
644
+ project=_PROJECT,
645
+ subproject=_SUBPROJECT,
646
+ function_name=telemetry.get_statement_params_full_func_name(
647
+ inspect.currentframe(), Lars.__class__.__name__
648
+ ),
649
+ api_calls=[Session.call],
650
+ custom_tags={"autogen": True} if self._autogenerated else None,
651
+ )
652
+ if output_cols_prefix == "fit_predict_":
653
+ if hasattr(self._sklearn_object, "n_clusters"):
654
+ # cluster estimators such as BisectingKMeans require the number of examples to be >= n_clusters
655
+ num_examples = self._sklearn_object.n_clusters
656
+ elif hasattr(self._sklearn_object, "min_samples"):
657
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
658
+ num_examples = self._sklearn_object.min_samples
659
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
660
+ # LocalOutlierFactor expects n_neighbors <= n_samples
661
+ num_examples = self._sklearn_object.n_neighbors
662
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
663
+ else:
664
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
627
665
 
628
666
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
629
667
  # seen during the fit.
@@ -635,12 +673,14 @@ class Lars(BaseTransformer):
635
673
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
636
674
  if self.sample_weight_col:
637
675
  output_df_columns_set -= set(self.sample_weight_col)
676
+
638
677
  # if the dimension of inferred output column names is correct; use it
639
678
  if len(expected_output_cols_list) == len(output_df_columns_set):
640
- return expected_output_cols_list
679
+ return expected_output_cols_list, output_df_pd
641
680
  # otherwise, use the sklearn estimator's output
642
681
  else:
643
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
682
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
683
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
644
684
 
645
685
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
646
686
  @telemetry.send_api_usage_telemetry(
@@ -686,7 +726,7 @@ class Lars(BaseTransformer):
686
726
  drop_input_cols=self._drop_input_cols,
687
727
  expected_output_cols_type="float",
688
728
  )
689
- expected_output_cols = self._align_expected_output_names(
729
+ expected_output_cols, _ = self._align_expected_output(
690
730
  inference_method, dataset, expected_output_cols, output_cols_prefix
691
731
  )
692
732
 
@@ -752,7 +792,7 @@ class Lars(BaseTransformer):
752
792
  drop_input_cols=self._drop_input_cols,
753
793
  expected_output_cols_type="float",
754
794
  )
755
- expected_output_cols = self._align_expected_output_names(
795
+ expected_output_cols, _ = self._align_expected_output(
756
796
  inference_method, dataset, expected_output_cols, output_cols_prefix
757
797
  )
758
798
  elif isinstance(dataset, pd.DataFrame):
@@ -815,7 +855,7 @@ class Lars(BaseTransformer):
815
855
  drop_input_cols=self._drop_input_cols,
816
856
  expected_output_cols_type="float",
817
857
  )
818
- expected_output_cols = self._align_expected_output_names(
858
+ expected_output_cols, _ = self._align_expected_output(
819
859
  inference_method, dataset, expected_output_cols, output_cols_prefix
820
860
  )
821
861
 
@@ -880,7 +920,7 @@ class Lars(BaseTransformer):
880
920
  drop_input_cols = self._drop_input_cols,
881
921
  expected_output_cols_type="float",
882
922
  )
883
- expected_output_cols = self._align_expected_output_names(
923
+ expected_output_cols, _ = self._align_expected_output(
884
924
  inference_method, dataset, expected_output_cols, output_cols_prefix
885
925
  )
886
926