snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284) hide show
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -556,12 +553,23 @@ class LassoLarsCV(BaseTransformer):
556
553
  autogenerated=self._autogenerated,
557
554
  subproject=_SUBPROJECT,
558
555
  )
559
- output_result, fitted_estimator = model_trainer.train_fit_predict(
560
- drop_input_cols=self._drop_input_cols,
561
- expected_output_cols_list=(
562
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
563
- ),
556
+ expected_output_cols = (
557
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
564
558
  )
559
+ if isinstance(dataset, DataFrame):
560
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
561
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
562
+ )
563
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
564
+ drop_input_cols=self._drop_input_cols,
565
+ expected_output_cols_list=expected_output_cols,
566
+ example_output_pd_df=example_output_pd_df,
567
+ )
568
+ else:
569
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
570
+ drop_input_cols=self._drop_input_cols,
571
+ expected_output_cols_list=expected_output_cols,
572
+ )
565
573
  self._sklearn_object = fitted_estimator
566
574
  self._is_fitted = True
567
575
  return output_result
@@ -584,6 +592,7 @@ class LassoLarsCV(BaseTransformer):
584
592
  """
585
593
  self._infer_input_output_cols(dataset)
586
594
  super()._check_dataset_type(dataset)
595
+
587
596
  model_trainer = ModelTrainerBuilder.build_fit_transform(
588
597
  estimator=self._sklearn_object,
589
598
  dataset=dataset,
@@ -640,12 +649,41 @@ class LassoLarsCV(BaseTransformer):
640
649
 
641
650
  return rv
642
651
 
643
- def _align_expected_output_names(
644
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
645
- ) -> List[str]:
652
+ def _align_expected_output(
653
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
654
+ ) -> Tuple[List[str], pd.DataFrame]:
655
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
656
+ and output dataframe with 1 line.
657
+ If the method is fit_predict, run 2 lines of data.
658
+ """
646
659
  # in case the inferred output column names dimension is different
647
660
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
648
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
661
+
662
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
663
+ # so change the minimum of number of rows to 2
664
+ num_examples = 2
665
+ statement_params = telemetry.get_function_usage_statement_params(
666
+ project=_PROJECT,
667
+ subproject=_SUBPROJECT,
668
+ function_name=telemetry.get_statement_params_full_func_name(
669
+ inspect.currentframe(), LassoLarsCV.__class__.__name__
670
+ ),
671
+ api_calls=[Session.call],
672
+ custom_tags={"autogen": True} if self._autogenerated else None,
673
+ )
674
+ if output_cols_prefix == "fit_predict_":
675
+ if hasattr(self._sklearn_object, "n_clusters"):
676
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
677
+ num_examples = self._sklearn_object.n_clusters
678
+ elif hasattr(self._sklearn_object, "min_samples"):
679
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
680
+ num_examples = self._sklearn_object.min_samples
681
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
682
+ # LocalOutlierFactor expects n_neighbors <= n_samples
683
+ num_examples = self._sklearn_object.n_neighbors
684
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
685
+ else:
686
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
649
687
 
650
688
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
651
689
  # seen during the fit.
@@ -657,12 +695,14 @@ class LassoLarsCV(BaseTransformer):
657
695
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
658
696
  if self.sample_weight_col:
659
697
  output_df_columns_set -= set(self.sample_weight_col)
698
+
660
699
  # if the dimension of inferred output column names is correct; use it
661
700
  if len(expected_output_cols_list) == len(output_df_columns_set):
662
- return expected_output_cols_list
701
+ return expected_output_cols_list, output_df_pd
663
702
  # otherwise, use the sklearn estimator's output
664
703
  else:
665
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
704
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
705
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
666
706
 
667
707
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
668
708
  @telemetry.send_api_usage_telemetry(
@@ -708,7 +748,7 @@ class LassoLarsCV(BaseTransformer):
708
748
  drop_input_cols=self._drop_input_cols,
709
749
  expected_output_cols_type="float",
710
750
  )
711
- expected_output_cols = self._align_expected_output_names(
751
+ expected_output_cols, _ = self._align_expected_output(
712
752
  inference_method, dataset, expected_output_cols, output_cols_prefix
713
753
  )
714
754
 
@@ -774,7 +814,7 @@ class LassoLarsCV(BaseTransformer):
774
814
  drop_input_cols=self._drop_input_cols,
775
815
  expected_output_cols_type="float",
776
816
  )
777
- expected_output_cols = self._align_expected_output_names(
817
+ expected_output_cols, _ = self._align_expected_output(
778
818
  inference_method, dataset, expected_output_cols, output_cols_prefix
779
819
  )
780
820
  elif isinstance(dataset, pd.DataFrame):
@@ -837,7 +877,7 @@ class LassoLarsCV(BaseTransformer):
837
877
  drop_input_cols=self._drop_input_cols,
838
878
  expected_output_cols_type="float",
839
879
  )
840
- expected_output_cols = self._align_expected_output_names(
880
+ expected_output_cols, _ = self._align_expected_output(
841
881
  inference_method, dataset, expected_output_cols, output_cols_prefix
842
882
  )
843
883
 
@@ -902,7 +942,7 @@ class LassoLarsCV(BaseTransformer):
902
942
  drop_input_cols = self._drop_input_cols,
903
943
  expected_output_cols_type="float",
904
944
  )
905
- expected_output_cols = self._align_expected_output_names(
945
+ expected_output_cols, _ = self._align_expected_output(
906
946
  inference_method, dataset, expected_output_cols, output_cols_prefix
907
947
  )
908
948
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -539,12 +536,23 @@ class LassoLarsIC(BaseTransformer):
539
536
  autogenerated=self._autogenerated,
540
537
  subproject=_SUBPROJECT,
541
538
  )
542
- output_result, fitted_estimator = model_trainer.train_fit_predict(
543
- drop_input_cols=self._drop_input_cols,
544
- expected_output_cols_list=(
545
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
546
- ),
539
+ expected_output_cols = (
540
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
547
541
  )
542
+ if isinstance(dataset, DataFrame):
543
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
544
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
545
+ )
546
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
547
+ drop_input_cols=self._drop_input_cols,
548
+ expected_output_cols_list=expected_output_cols,
549
+ example_output_pd_df=example_output_pd_df,
550
+ )
551
+ else:
552
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
553
+ drop_input_cols=self._drop_input_cols,
554
+ expected_output_cols_list=expected_output_cols,
555
+ )
548
556
  self._sklearn_object = fitted_estimator
549
557
  self._is_fitted = True
550
558
  return output_result
@@ -567,6 +575,7 @@ class LassoLarsIC(BaseTransformer):
567
575
  """
568
576
  self._infer_input_output_cols(dataset)
569
577
  super()._check_dataset_type(dataset)
578
+
570
579
  model_trainer = ModelTrainerBuilder.build_fit_transform(
571
580
  estimator=self._sklearn_object,
572
581
  dataset=dataset,
@@ -623,12 +632,41 @@ class LassoLarsIC(BaseTransformer):
623
632
 
624
633
  return rv
625
634
 
626
- def _align_expected_output_names(
627
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
628
- ) -> List[str]:
635
+ def _align_expected_output(
636
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
637
+ ) -> Tuple[List[str], pd.DataFrame]:
638
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
639
+ and output dataframe with 1 line.
640
+ If the method is fit_predict, run 2 lines of data.
641
+ """
629
642
  # in case the inferred output column names dimension is different
630
643
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
631
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
644
+
645
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
646
+ # so change the minimum of number of rows to 2
647
+ num_examples = 2
648
+ statement_params = telemetry.get_function_usage_statement_params(
649
+ project=_PROJECT,
650
+ subproject=_SUBPROJECT,
651
+ function_name=telemetry.get_statement_params_full_func_name(
652
+ inspect.currentframe(), LassoLarsIC.__class__.__name__
653
+ ),
654
+ api_calls=[Session.call],
655
+ custom_tags={"autogen": True} if self._autogenerated else None,
656
+ )
657
+ if output_cols_prefix == "fit_predict_":
658
+ if hasattr(self._sklearn_object, "n_clusters"):
659
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
660
+ num_examples = self._sklearn_object.n_clusters
661
+ elif hasattr(self._sklearn_object, "min_samples"):
662
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
663
+ num_examples = self._sklearn_object.min_samples
664
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
665
+ # LocalOutlierFactor expects n_neighbors <= n_samples
666
+ num_examples = self._sklearn_object.n_neighbors
667
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
668
+ else:
669
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
632
670
 
633
671
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
634
672
  # seen during the fit.
@@ -640,12 +678,14 @@ class LassoLarsIC(BaseTransformer):
640
678
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
641
679
  if self.sample_weight_col:
642
680
  output_df_columns_set -= set(self.sample_weight_col)
681
+
643
682
  # if the dimension of inferred output column names is correct; use it
644
683
  if len(expected_output_cols_list) == len(output_df_columns_set):
645
- return expected_output_cols_list
684
+ return expected_output_cols_list, output_df_pd
646
685
  # otherwise, use the sklearn estimator's output
647
686
  else:
648
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
687
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
688
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
649
689
 
650
690
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
651
691
  @telemetry.send_api_usage_telemetry(
@@ -691,7 +731,7 @@ class LassoLarsIC(BaseTransformer):
691
731
  drop_input_cols=self._drop_input_cols,
692
732
  expected_output_cols_type="float",
693
733
  )
694
- expected_output_cols = self._align_expected_output_names(
734
+ expected_output_cols, _ = self._align_expected_output(
695
735
  inference_method, dataset, expected_output_cols, output_cols_prefix
696
736
  )
697
737
 
@@ -757,7 +797,7 @@ class LassoLarsIC(BaseTransformer):
757
797
  drop_input_cols=self._drop_input_cols,
758
798
  expected_output_cols_type="float",
759
799
  )
760
- expected_output_cols = self._align_expected_output_names(
800
+ expected_output_cols, _ = self._align_expected_output(
761
801
  inference_method, dataset, expected_output_cols, output_cols_prefix
762
802
  )
763
803
  elif isinstance(dataset, pd.DataFrame):
@@ -820,7 +860,7 @@ class LassoLarsIC(BaseTransformer):
820
860
  drop_input_cols=self._drop_input_cols,
821
861
  expected_output_cols_type="float",
822
862
  )
823
- expected_output_cols = self._align_expected_output_names(
863
+ expected_output_cols, _ = self._align_expected_output(
824
864
  inference_method, dataset, expected_output_cols, output_cols_prefix
825
865
  )
826
866
 
@@ -885,7 +925,7 @@ class LassoLarsIC(BaseTransformer):
885
925
  drop_input_cols = self._drop_input_cols,
886
926
  expected_output_cols_type="float",
887
927
  )
888
- expected_output_cols = self._align_expected_output_names(
928
+ expected_output_cols, _ = self._align_expected_output(
889
929
  inference_method, dataset, expected_output_cols, output_cols_prefix
890
930
  )
891
931
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -492,12 +489,23 @@ class LinearRegression(BaseTransformer):
492
489
  autogenerated=self._autogenerated,
493
490
  subproject=_SUBPROJECT,
494
491
  )
495
- output_result, fitted_estimator = model_trainer.train_fit_predict(
496
- drop_input_cols=self._drop_input_cols,
497
- expected_output_cols_list=(
498
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
499
- ),
492
+ expected_output_cols = (
493
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
500
494
  )
495
+ if isinstance(dataset, DataFrame):
496
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
497
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
498
+ )
499
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
500
+ drop_input_cols=self._drop_input_cols,
501
+ expected_output_cols_list=expected_output_cols,
502
+ example_output_pd_df=example_output_pd_df,
503
+ )
504
+ else:
505
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
506
+ drop_input_cols=self._drop_input_cols,
507
+ expected_output_cols_list=expected_output_cols,
508
+ )
501
509
  self._sklearn_object = fitted_estimator
502
510
  self._is_fitted = True
503
511
  return output_result
@@ -520,6 +528,7 @@ class LinearRegression(BaseTransformer):
520
528
  """
521
529
  self._infer_input_output_cols(dataset)
522
530
  super()._check_dataset_type(dataset)
531
+
523
532
  model_trainer = ModelTrainerBuilder.build_fit_transform(
524
533
  estimator=self._sklearn_object,
525
534
  dataset=dataset,
@@ -576,12 +585,41 @@ class LinearRegression(BaseTransformer):
576
585
 
577
586
  return rv
578
587
 
579
- def _align_expected_output_names(
580
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
581
- ) -> List[str]:
588
+ def _align_expected_output(
589
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
590
+ ) -> Tuple[List[str], pd.DataFrame]:
591
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
592
+ and output dataframe with 1 line.
593
+ If the method is fit_predict, run 2 lines of data.
594
+ """
582
595
  # in case the inferred output column names dimension is different
583
596
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
584
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
597
+
598
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
599
+ # so change the minimum of number of rows to 2
600
+ num_examples = 2
601
+ statement_params = telemetry.get_function_usage_statement_params(
602
+ project=_PROJECT,
603
+ subproject=_SUBPROJECT,
604
+ function_name=telemetry.get_statement_params_full_func_name(
605
+ inspect.currentframe(), LinearRegression.__class__.__name__
606
+ ),
607
+ api_calls=[Session.call],
608
+ custom_tags={"autogen": True} if self._autogenerated else None,
609
+ )
610
+ if output_cols_prefix == "fit_predict_":
611
+ if hasattr(self._sklearn_object, "n_clusters"):
612
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
613
+ num_examples = self._sklearn_object.n_clusters
614
+ elif hasattr(self._sklearn_object, "min_samples"):
615
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
616
+ num_examples = self._sklearn_object.min_samples
617
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
618
+ # LocalOutlierFactor expects n_neighbors <= n_samples
619
+ num_examples = self._sklearn_object.n_neighbors
620
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
621
+ else:
622
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
585
623
 
586
624
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
587
625
  # seen during the fit.
@@ -593,12 +631,14 @@ class LinearRegression(BaseTransformer):
593
631
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
594
632
  if self.sample_weight_col:
595
633
  output_df_columns_set -= set(self.sample_weight_col)
634
+
596
635
  # if the dimension of inferred output column names is correct; use it
597
636
  if len(expected_output_cols_list) == len(output_df_columns_set):
598
- return expected_output_cols_list
637
+ return expected_output_cols_list, output_df_pd
599
638
  # otherwise, use the sklearn estimator's output
600
639
  else:
601
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
640
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
641
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
602
642
 
603
643
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
604
644
  @telemetry.send_api_usage_telemetry(
@@ -644,7 +684,7 @@ class LinearRegression(BaseTransformer):
644
684
  drop_input_cols=self._drop_input_cols,
645
685
  expected_output_cols_type="float",
646
686
  )
647
- expected_output_cols = self._align_expected_output_names(
687
+ expected_output_cols, _ = self._align_expected_output(
648
688
  inference_method, dataset, expected_output_cols, output_cols_prefix
649
689
  )
650
690
 
@@ -710,7 +750,7 @@ class LinearRegression(BaseTransformer):
710
750
  drop_input_cols=self._drop_input_cols,
711
751
  expected_output_cols_type="float",
712
752
  )
713
- expected_output_cols = self._align_expected_output_names(
753
+ expected_output_cols, _ = self._align_expected_output(
714
754
  inference_method, dataset, expected_output_cols, output_cols_prefix
715
755
  )
716
756
  elif isinstance(dataset, pd.DataFrame):
@@ -773,7 +813,7 @@ class LinearRegression(BaseTransformer):
773
813
  drop_input_cols=self._drop_input_cols,
774
814
  expected_output_cols_type="float",
775
815
  )
776
- expected_output_cols = self._align_expected_output_names(
816
+ expected_output_cols, _ = self._align_expected_output(
777
817
  inference_method, dataset, expected_output_cols, output_cols_prefix
778
818
  )
779
819
 
@@ -838,7 +878,7 @@ class LinearRegression(BaseTransformer):
838
878
  drop_input_cols = self._drop_input_cols,
839
879
  expected_output_cols_type="float",
840
880
  )
841
- expected_output_cols = self._align_expected_output_names(
881
+ expected_output_cols, _ = self._align_expected_output(
842
882
  inference_method, dataset, expected_output_cols, output_cols_prefix
843
883
  )
844
884
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -606,12 +603,23 @@ class LogisticRegression(BaseTransformer):
606
603
  autogenerated=self._autogenerated,
607
604
  subproject=_SUBPROJECT,
608
605
  )
609
- output_result, fitted_estimator = model_trainer.train_fit_predict(
610
- drop_input_cols=self._drop_input_cols,
611
- expected_output_cols_list=(
612
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
613
- ),
606
+ expected_output_cols = (
607
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
614
608
  )
609
+ if isinstance(dataset, DataFrame):
610
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
611
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
612
+ )
613
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
614
+ drop_input_cols=self._drop_input_cols,
615
+ expected_output_cols_list=expected_output_cols,
616
+ example_output_pd_df=example_output_pd_df,
617
+ )
618
+ else:
619
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
620
+ drop_input_cols=self._drop_input_cols,
621
+ expected_output_cols_list=expected_output_cols,
622
+ )
615
623
  self._sklearn_object = fitted_estimator
616
624
  self._is_fitted = True
617
625
  return output_result
@@ -634,6 +642,7 @@ class LogisticRegression(BaseTransformer):
634
642
  """
635
643
  self._infer_input_output_cols(dataset)
636
644
  super()._check_dataset_type(dataset)
645
+
637
646
  model_trainer = ModelTrainerBuilder.build_fit_transform(
638
647
  estimator=self._sklearn_object,
639
648
  dataset=dataset,
@@ -690,12 +699,41 @@ class LogisticRegression(BaseTransformer):
690
699
 
691
700
  return rv
692
701
 
693
- def _align_expected_output_names(
694
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
695
- ) -> List[str]:
702
+ def _align_expected_output(
703
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
704
+ ) -> Tuple[List[str], pd.DataFrame]:
705
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
706
+ and output dataframe with 1 line.
707
+ If the method is fit_predict, run 2 lines of data.
708
+ """
696
709
  # in case the inferred output column names dimension is different
697
710
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
698
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
711
+
712
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
713
+ # so change the minimum of number of rows to 2
714
+ num_examples = 2
715
+ statement_params = telemetry.get_function_usage_statement_params(
716
+ project=_PROJECT,
717
+ subproject=_SUBPROJECT,
718
+ function_name=telemetry.get_statement_params_full_func_name(
719
+ inspect.currentframe(), LogisticRegression.__class__.__name__
720
+ ),
721
+ api_calls=[Session.call],
722
+ custom_tags={"autogen": True} if self._autogenerated else None,
723
+ )
724
+ if output_cols_prefix == "fit_predict_":
725
+ if hasattr(self._sklearn_object, "n_clusters"):
726
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
727
+ num_examples = self._sklearn_object.n_clusters
728
+ elif hasattr(self._sklearn_object, "min_samples"):
729
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
730
+ num_examples = self._sklearn_object.min_samples
731
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
732
+ # LocalOutlierFactor expects n_neighbors <= n_samples
733
+ num_examples = self._sklearn_object.n_neighbors
734
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
735
+ else:
736
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
699
737
 
700
738
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
701
739
  # seen during the fit.
@@ -707,12 +745,14 @@ class LogisticRegression(BaseTransformer):
707
745
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
708
746
  if self.sample_weight_col:
709
747
  output_df_columns_set -= set(self.sample_weight_col)
748
+
710
749
  # if the dimension of inferred output column names is correct; use it
711
750
  if len(expected_output_cols_list) == len(output_df_columns_set):
712
- return expected_output_cols_list
751
+ return expected_output_cols_list, output_df_pd
713
752
  # otherwise, use the sklearn estimator's output
714
753
  else:
715
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
754
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
755
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
716
756
 
717
757
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
718
758
  @telemetry.send_api_usage_telemetry(
@@ -760,7 +800,7 @@ class LogisticRegression(BaseTransformer):
760
800
  drop_input_cols=self._drop_input_cols,
761
801
  expected_output_cols_type="float",
762
802
  )
763
- expected_output_cols = self._align_expected_output_names(
803
+ expected_output_cols, _ = self._align_expected_output(
764
804
  inference_method, dataset, expected_output_cols, output_cols_prefix
765
805
  )
766
806
 
@@ -828,7 +868,7 @@ class LogisticRegression(BaseTransformer):
828
868
  drop_input_cols=self._drop_input_cols,
829
869
  expected_output_cols_type="float",
830
870
  )
831
- expected_output_cols = self._align_expected_output_names(
871
+ expected_output_cols, _ = self._align_expected_output(
832
872
  inference_method, dataset, expected_output_cols, output_cols_prefix
833
873
  )
834
874
  elif isinstance(dataset, pd.DataFrame):
@@ -893,7 +933,7 @@ class LogisticRegression(BaseTransformer):
893
933
  drop_input_cols=self._drop_input_cols,
894
934
  expected_output_cols_type="float",
895
935
  )
896
- expected_output_cols = self._align_expected_output_names(
936
+ expected_output_cols, _ = self._align_expected_output(
897
937
  inference_method, dataset, expected_output_cols, output_cols_prefix
898
938
  )
899
939
 
@@ -958,7 +998,7 @@ class LogisticRegression(BaseTransformer):
958
998
  drop_input_cols = self._drop_input_cols,
959
999
  expected_output_cols_type="float",
960
1000
  )
961
- expected_output_cols = self._align_expected_output_names(
1001
+ expected_output_cols, _ = self._align_expected_output(
962
1002
  inference_method, dataset, expected_output_cols, output_cols_prefix
963
1003
  )
964
1004