snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284) hide show
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -627,12 +624,23 @@ class LogisticRegressionCV(BaseTransformer):
627
624
  autogenerated=self._autogenerated,
628
625
  subproject=_SUBPROJECT,
629
626
  )
630
- output_result, fitted_estimator = model_trainer.train_fit_predict(
631
- drop_input_cols=self._drop_input_cols,
632
- expected_output_cols_list=(
633
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
634
- ),
627
+ expected_output_cols = (
628
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
635
629
  )
630
+ if isinstance(dataset, DataFrame):
631
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
632
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
633
+ )
634
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
635
+ drop_input_cols=self._drop_input_cols,
636
+ expected_output_cols_list=expected_output_cols,
637
+ example_output_pd_df=example_output_pd_df,
638
+ )
639
+ else:
640
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
641
+ drop_input_cols=self._drop_input_cols,
642
+ expected_output_cols_list=expected_output_cols,
643
+ )
636
644
  self._sklearn_object = fitted_estimator
637
645
  self._is_fitted = True
638
646
  return output_result
@@ -655,6 +663,7 @@ class LogisticRegressionCV(BaseTransformer):
655
663
  """
656
664
  self._infer_input_output_cols(dataset)
657
665
  super()._check_dataset_type(dataset)
666
+
658
667
  model_trainer = ModelTrainerBuilder.build_fit_transform(
659
668
  estimator=self._sklearn_object,
660
669
  dataset=dataset,
@@ -711,12 +720,41 @@ class LogisticRegressionCV(BaseTransformer):
711
720
 
712
721
  return rv
713
722
 
714
- def _align_expected_output_names(
715
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
716
- ) -> List[str]:
723
+ def _align_expected_output(
724
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
725
+ ) -> Tuple[List[str], pd.DataFrame]:
726
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
727
+ and output dataframe with 1 line.
728
+ If the method is fit_predict, run 2 lines of data.
729
+ """
717
730
  # in case the inferred output column names dimension is different
718
731
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
719
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
732
+
733
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
734
+ # so change the minimum of number of rows to 2
735
+ num_examples = 2
736
+ statement_params = telemetry.get_function_usage_statement_params(
737
+ project=_PROJECT,
738
+ subproject=_SUBPROJECT,
739
+ function_name=telemetry.get_statement_params_full_func_name(
740
+ inspect.currentframe(), LogisticRegressionCV.__class__.__name__
741
+ ),
742
+ api_calls=[Session.call],
743
+ custom_tags={"autogen": True} if self._autogenerated else None,
744
+ )
745
+ if output_cols_prefix == "fit_predict_":
746
+ if hasattr(self._sklearn_object, "n_clusters"):
747
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
748
+ num_examples = self._sklearn_object.n_clusters
749
+ elif hasattr(self._sklearn_object, "min_samples"):
750
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
751
+ num_examples = self._sklearn_object.min_samples
752
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
753
+ # LocalOutlierFactor expects n_neighbors <= n_samples
754
+ num_examples = self._sklearn_object.n_neighbors
755
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
756
+ else:
757
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
720
758
 
721
759
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
722
760
  # seen during the fit.
@@ -728,12 +766,14 @@ class LogisticRegressionCV(BaseTransformer):
728
766
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
729
767
  if self.sample_weight_col:
730
768
  output_df_columns_set -= set(self.sample_weight_col)
769
+
731
770
  # if the dimension of inferred output column names is correct; use it
732
771
  if len(expected_output_cols_list) == len(output_df_columns_set):
733
- return expected_output_cols_list
772
+ return expected_output_cols_list, output_df_pd
734
773
  # otherwise, use the sklearn estimator's output
735
774
  else:
736
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
775
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
776
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
737
777
 
738
778
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
739
779
  @telemetry.send_api_usage_telemetry(
@@ -781,7 +821,7 @@ class LogisticRegressionCV(BaseTransformer):
781
821
  drop_input_cols=self._drop_input_cols,
782
822
  expected_output_cols_type="float",
783
823
  )
784
- expected_output_cols = self._align_expected_output_names(
824
+ expected_output_cols, _ = self._align_expected_output(
785
825
  inference_method, dataset, expected_output_cols, output_cols_prefix
786
826
  )
787
827
 
@@ -849,7 +889,7 @@ class LogisticRegressionCV(BaseTransformer):
849
889
  drop_input_cols=self._drop_input_cols,
850
890
  expected_output_cols_type="float",
851
891
  )
852
- expected_output_cols = self._align_expected_output_names(
892
+ expected_output_cols, _ = self._align_expected_output(
853
893
  inference_method, dataset, expected_output_cols, output_cols_prefix
854
894
  )
855
895
  elif isinstance(dataset, pd.DataFrame):
@@ -914,7 +954,7 @@ class LogisticRegressionCV(BaseTransformer):
914
954
  drop_input_cols=self._drop_input_cols,
915
955
  expected_output_cols_type="float",
916
956
  )
917
- expected_output_cols = self._align_expected_output_names(
957
+ expected_output_cols, _ = self._align_expected_output(
918
958
  inference_method, dataset, expected_output_cols, output_cols_prefix
919
959
  )
920
960
 
@@ -979,7 +1019,7 @@ class LogisticRegressionCV(BaseTransformer):
979
1019
  drop_input_cols = self._drop_input_cols,
980
1020
  expected_output_cols_type="float",
981
1021
  )
982
- expected_output_cols = self._align_expected_output_names(
1022
+ expected_output_cols, _ = self._align_expected_output(
983
1023
  inference_method, dataset, expected_output_cols, output_cols_prefix
984
1024
  )
985
1025
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -525,12 +522,23 @@ class MultiTaskElasticNet(BaseTransformer):
525
522
  autogenerated=self._autogenerated,
526
523
  subproject=_SUBPROJECT,
527
524
  )
528
- output_result, fitted_estimator = model_trainer.train_fit_predict(
529
- drop_input_cols=self._drop_input_cols,
530
- expected_output_cols_list=(
531
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
532
- ),
525
+ expected_output_cols = (
526
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
533
527
  )
528
+ if isinstance(dataset, DataFrame):
529
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
530
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
531
+ )
532
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
533
+ drop_input_cols=self._drop_input_cols,
534
+ expected_output_cols_list=expected_output_cols,
535
+ example_output_pd_df=example_output_pd_df,
536
+ )
537
+ else:
538
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
539
+ drop_input_cols=self._drop_input_cols,
540
+ expected_output_cols_list=expected_output_cols,
541
+ )
534
542
  self._sklearn_object = fitted_estimator
535
543
  self._is_fitted = True
536
544
  return output_result
@@ -553,6 +561,7 @@ class MultiTaskElasticNet(BaseTransformer):
553
561
  """
554
562
  self._infer_input_output_cols(dataset)
555
563
  super()._check_dataset_type(dataset)
564
+
556
565
  model_trainer = ModelTrainerBuilder.build_fit_transform(
557
566
  estimator=self._sklearn_object,
558
567
  dataset=dataset,
@@ -609,12 +618,41 @@ class MultiTaskElasticNet(BaseTransformer):
609
618
 
610
619
  return rv
611
620
 
612
- def _align_expected_output_names(
613
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
614
- ) -> List[str]:
621
+ def _align_expected_output(
622
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
623
+ ) -> Tuple[List[str], pd.DataFrame]:
624
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
625
+ and output dataframe with 1 line.
626
+ If the method is fit_predict, run 2 lines of data.
627
+ """
615
628
  # in case the inferred output column names dimension is different
616
629
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
617
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
630
+
631
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
632
+ # so change the minimum of number of rows to 2
633
+ num_examples = 2
634
+ statement_params = telemetry.get_function_usage_statement_params(
635
+ project=_PROJECT,
636
+ subproject=_SUBPROJECT,
637
+ function_name=telemetry.get_statement_params_full_func_name(
638
+ inspect.currentframe(), MultiTaskElasticNet.__class__.__name__
639
+ ),
640
+ api_calls=[Session.call],
641
+ custom_tags={"autogen": True} if self._autogenerated else None,
642
+ )
643
+ if output_cols_prefix == "fit_predict_":
644
+ if hasattr(self._sklearn_object, "n_clusters"):
645
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
646
+ num_examples = self._sklearn_object.n_clusters
647
+ elif hasattr(self._sklearn_object, "min_samples"):
648
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
649
+ num_examples = self._sklearn_object.min_samples
650
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
651
+ # LocalOutlierFactor expects n_neighbors <= n_samples
652
+ num_examples = self._sklearn_object.n_neighbors
653
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
654
+ else:
655
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
618
656
 
619
657
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
620
658
  # seen during the fit.
@@ -626,12 +664,14 @@ class MultiTaskElasticNet(BaseTransformer):
626
664
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
627
665
  if self.sample_weight_col:
628
666
  output_df_columns_set -= set(self.sample_weight_col)
667
+
629
668
  # if the dimension of inferred output column names is correct; use it
630
669
  if len(expected_output_cols_list) == len(output_df_columns_set):
631
- return expected_output_cols_list
670
+ return expected_output_cols_list, output_df_pd
632
671
  # otherwise, use the sklearn estimator's output
633
672
  else:
634
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
673
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
674
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
635
675
 
636
676
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
637
677
  @telemetry.send_api_usage_telemetry(
@@ -677,7 +717,7 @@ class MultiTaskElasticNet(BaseTransformer):
677
717
  drop_input_cols=self._drop_input_cols,
678
718
  expected_output_cols_type="float",
679
719
  )
680
- expected_output_cols = self._align_expected_output_names(
720
+ expected_output_cols, _ = self._align_expected_output(
681
721
  inference_method, dataset, expected_output_cols, output_cols_prefix
682
722
  )
683
723
 
@@ -743,7 +783,7 @@ class MultiTaskElasticNet(BaseTransformer):
743
783
  drop_input_cols=self._drop_input_cols,
744
784
  expected_output_cols_type="float",
745
785
  )
746
- expected_output_cols = self._align_expected_output_names(
786
+ expected_output_cols, _ = self._align_expected_output(
747
787
  inference_method, dataset, expected_output_cols, output_cols_prefix
748
788
  )
749
789
  elif isinstance(dataset, pd.DataFrame):
@@ -806,7 +846,7 @@ class MultiTaskElasticNet(BaseTransformer):
806
846
  drop_input_cols=self._drop_input_cols,
807
847
  expected_output_cols_type="float",
808
848
  )
809
- expected_output_cols = self._align_expected_output_names(
849
+ expected_output_cols, _ = self._align_expected_output(
810
850
  inference_method, dataset, expected_output_cols, output_cols_prefix
811
851
  )
812
852
 
@@ -871,7 +911,7 @@ class MultiTaskElasticNet(BaseTransformer):
871
911
  drop_input_cols = self._drop_input_cols,
872
912
  expected_output_cols_type="float",
873
913
  )
874
- expected_output_cols = self._align_expected_output_names(
914
+ expected_output_cols, _ = self._align_expected_output(
875
915
  inference_method, dataset, expected_output_cols, output_cols_prefix
876
916
  )
877
917
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -566,12 +563,23 @@ class MultiTaskElasticNetCV(BaseTransformer):
566
563
  autogenerated=self._autogenerated,
567
564
  subproject=_SUBPROJECT,
568
565
  )
569
- output_result, fitted_estimator = model_trainer.train_fit_predict(
570
- drop_input_cols=self._drop_input_cols,
571
- expected_output_cols_list=(
572
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
573
- ),
566
+ expected_output_cols = (
567
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
574
568
  )
569
+ if isinstance(dataset, DataFrame):
570
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
571
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
572
+ )
573
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
574
+ drop_input_cols=self._drop_input_cols,
575
+ expected_output_cols_list=expected_output_cols,
576
+ example_output_pd_df=example_output_pd_df,
577
+ )
578
+ else:
579
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
580
+ drop_input_cols=self._drop_input_cols,
581
+ expected_output_cols_list=expected_output_cols,
582
+ )
575
583
  self._sklearn_object = fitted_estimator
576
584
  self._is_fitted = True
577
585
  return output_result
@@ -594,6 +602,7 @@ class MultiTaskElasticNetCV(BaseTransformer):
594
602
  """
595
603
  self._infer_input_output_cols(dataset)
596
604
  super()._check_dataset_type(dataset)
605
+
597
606
  model_trainer = ModelTrainerBuilder.build_fit_transform(
598
607
  estimator=self._sklearn_object,
599
608
  dataset=dataset,
@@ -650,12 +659,41 @@ class MultiTaskElasticNetCV(BaseTransformer):
650
659
 
651
660
  return rv
652
661
 
653
- def _align_expected_output_names(
654
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
655
- ) -> List[str]:
662
+ def _align_expected_output(
663
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
664
+ ) -> Tuple[List[str], pd.DataFrame]:
665
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
666
+ and output dataframe with 1 line.
667
+ If the method is fit_predict, run 2 lines of data by default; the estimator's n_clusters, min_samples or n_neighbors may override this count.
668
+ """
656
669
  # in case the inferred output column names dimension is different
657
670
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
658
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
671
+
672
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
673
+ # so the minimum number of rows is set to 2
674
+ num_examples = 2
675
+ statement_params = telemetry.get_function_usage_statement_params(
676
+ project=_PROJECT,
677
+ subproject=_SUBPROJECT,
678
+ function_name=telemetry.get_statement_params_full_func_name(
679
+ inspect.currentframe(), MultiTaskElasticNetCV.__class__.__name__
680
+ ),
681
+ api_calls=[Session.call],
682
+ custom_tags={"autogen": True} if self._autogenerated else None,
683
+ )
684
+ if output_cols_prefix == "fit_predict_":
685
+ if hasattr(self._sklearn_object, "n_clusters"):
686
+ # cluster classes such as BisectingKMeans require the number of examples to be >= n_clusters
687
+ num_examples = self._sklearn_object.n_clusters
688
+ elif hasattr(self._sklearn_object, "min_samples"):
689
+ # OPTICS defaults to min_samples=5, which requires at least 5 lines of data
690
+ num_examples = self._sklearn_object.min_samples
691
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
692
+ # LocalOutlierFactor expects n_neighbors <= n_samples
693
+ num_examples = self._sklearn_object.n_neighbors
694
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
695
+ else:
696
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
659
697
 
660
698
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
661
699
  # seen during the fit.
@@ -667,12 +705,14 @@ class MultiTaskElasticNetCV(BaseTransformer):
667
705
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
668
706
  if self.sample_weight_col:
669
707
  output_df_columns_set -= set(self.sample_weight_col)
708
+
670
709
  # if the dimension of inferred output column names is correct; use it
671
710
  if len(expected_output_cols_list) == len(output_df_columns_set):
672
- return expected_output_cols_list
711
+ return expected_output_cols_list, output_df_pd
673
712
  # otherwise, use the sklearn estimator's output
674
713
  else:
675
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
714
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
715
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
676
716
 
677
717
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
678
718
  @telemetry.send_api_usage_telemetry(
@@ -718,7 +758,7 @@ class MultiTaskElasticNetCV(BaseTransformer):
718
758
  drop_input_cols=self._drop_input_cols,
719
759
  expected_output_cols_type="float",
720
760
  )
721
- expected_output_cols = self._align_expected_output_names(
761
+ expected_output_cols, _ = self._align_expected_output(
722
762
  inference_method, dataset, expected_output_cols, output_cols_prefix
723
763
  )
724
764
 
@@ -784,7 +824,7 @@ class MultiTaskElasticNetCV(BaseTransformer):
784
824
  drop_input_cols=self._drop_input_cols,
785
825
  expected_output_cols_type="float",
786
826
  )
787
- expected_output_cols = self._align_expected_output_names(
827
+ expected_output_cols, _ = self._align_expected_output(
788
828
  inference_method, dataset, expected_output_cols, output_cols_prefix
789
829
  )
790
830
  elif isinstance(dataset, pd.DataFrame):
@@ -847,7 +887,7 @@ class MultiTaskElasticNetCV(BaseTransformer):
847
887
  drop_input_cols=self._drop_input_cols,
848
888
  expected_output_cols_type="float",
849
889
  )
850
- expected_output_cols = self._align_expected_output_names(
890
+ expected_output_cols, _ = self._align_expected_output(
851
891
  inference_method, dataset, expected_output_cols, output_cols_prefix
852
892
  )
853
893
 
@@ -912,7 +952,7 @@ class MultiTaskElasticNetCV(BaseTransformer):
912
952
  drop_input_cols = self._drop_input_cols,
913
953
  expected_output_cols_type="float",
914
954
  )
915
- expected_output_cols = self._align_expected_output_names(
955
+ expected_output_cols, _ = self._align_expected_output(
916
956
  inference_method, dataset, expected_output_cols, output_cols_prefix
917
957
  )
918
958
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -517,12 +514,23 @@ class MultiTaskLasso(BaseTransformer):
517
514
  autogenerated=self._autogenerated,
518
515
  subproject=_SUBPROJECT,
519
516
  )
520
- output_result, fitted_estimator = model_trainer.train_fit_predict(
521
- drop_input_cols=self._drop_input_cols,
522
- expected_output_cols_list=(
523
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
524
- ),
517
+ expected_output_cols = (
518
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
525
519
  )
520
+ if isinstance(dataset, DataFrame):
521
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
522
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
523
+ )
524
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
525
+ drop_input_cols=self._drop_input_cols,
526
+ expected_output_cols_list=expected_output_cols,
527
+ example_output_pd_df=example_output_pd_df,
528
+ )
529
+ else:
530
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
531
+ drop_input_cols=self._drop_input_cols,
532
+ expected_output_cols_list=expected_output_cols,
533
+ )
526
534
  self._sklearn_object = fitted_estimator
527
535
  self._is_fitted = True
528
536
  return output_result
@@ -545,6 +553,7 @@ class MultiTaskLasso(BaseTransformer):
545
553
  """
546
554
  self._infer_input_output_cols(dataset)
547
555
  super()._check_dataset_type(dataset)
556
+
548
557
  model_trainer = ModelTrainerBuilder.build_fit_transform(
549
558
  estimator=self._sklearn_object,
550
559
  dataset=dataset,
@@ -601,12 +610,41 @@ class MultiTaskLasso(BaseTransformer):
601
610
 
602
611
  return rv
603
612
 
604
- def _align_expected_output_names(
605
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
606
- ) -> List[str]:
613
+ def _align_expected_output(
614
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
615
+ ) -> Tuple[List[str], pd.DataFrame]:
616
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
617
+ and output dataframe with 1 line.
618
+ If the method is fit_predict, run 2 lines of data by default; the estimator's n_clusters, min_samples or n_neighbors may override this count.
619
+ """
607
620
  # in case the inferred output column names dimension is different
608
621
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
609
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
622
+
623
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
624
+ # so the minimum number of rows is set to 2
625
+ num_examples = 2
626
+ statement_params = telemetry.get_function_usage_statement_params(
627
+ project=_PROJECT,
628
+ subproject=_SUBPROJECT,
629
+ function_name=telemetry.get_statement_params_full_func_name(
630
+ inspect.currentframe(), MultiTaskLasso.__class__.__name__
631
+ ),
632
+ api_calls=[Session.call],
633
+ custom_tags={"autogen": True} if self._autogenerated else None,
634
+ )
635
+ if output_cols_prefix == "fit_predict_":
636
+ if hasattr(self._sklearn_object, "n_clusters"):
637
+ # cluster classes such as BisectingKMeans require the number of examples to be >= n_clusters
638
+ num_examples = self._sklearn_object.n_clusters
639
+ elif hasattr(self._sklearn_object, "min_samples"):
640
+ # OPTICS defaults to min_samples=5, which requires at least 5 lines of data
641
+ num_examples = self._sklearn_object.min_samples
642
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
643
+ # LocalOutlierFactor expects n_neighbors <= n_samples
644
+ num_examples = self._sklearn_object.n_neighbors
645
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
646
+ else:
647
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
610
648
 
611
649
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
612
650
  # seen during the fit.
@@ -618,12 +656,14 @@ class MultiTaskLasso(BaseTransformer):
618
656
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
619
657
  if self.sample_weight_col:
620
658
  output_df_columns_set -= set(self.sample_weight_col)
659
+
621
660
  # if the dimension of inferred output column names is correct; use it
622
661
  if len(expected_output_cols_list) == len(output_df_columns_set):
623
- return expected_output_cols_list
662
+ return expected_output_cols_list, output_df_pd
624
663
  # otherwise, use the sklearn estimator's output
625
664
  else:
626
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
665
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
666
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
627
667
 
628
668
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
629
669
  @telemetry.send_api_usage_telemetry(
@@ -669,7 +709,7 @@ class MultiTaskLasso(BaseTransformer):
669
709
  drop_input_cols=self._drop_input_cols,
670
710
  expected_output_cols_type="float",
671
711
  )
672
- expected_output_cols = self._align_expected_output_names(
712
+ expected_output_cols, _ = self._align_expected_output(
673
713
  inference_method, dataset, expected_output_cols, output_cols_prefix
674
714
  )
675
715
 
@@ -735,7 +775,7 @@ class MultiTaskLasso(BaseTransformer):
735
775
  drop_input_cols=self._drop_input_cols,
736
776
  expected_output_cols_type="float",
737
777
  )
738
- expected_output_cols = self._align_expected_output_names(
778
+ expected_output_cols, _ = self._align_expected_output(
739
779
  inference_method, dataset, expected_output_cols, output_cols_prefix
740
780
  )
741
781
  elif isinstance(dataset, pd.DataFrame):
@@ -798,7 +838,7 @@ class MultiTaskLasso(BaseTransformer):
798
838
  drop_input_cols=self._drop_input_cols,
799
839
  expected_output_cols_type="float",
800
840
  )
801
- expected_output_cols = self._align_expected_output_names(
841
+ expected_output_cols, _ = self._align_expected_output(
802
842
  inference_method, dataset, expected_output_cols, output_cols_prefix
803
843
  )
804
844
 
@@ -863,7 +903,7 @@ class MultiTaskLasso(BaseTransformer):
863
903
  drop_input_cols = self._drop_input_cols,
864
904
  expected_output_cols_type="float",
865
905
  )
866
- expected_output_cols = self._align_expected_output_names(
906
+ expected_output_cols, _ = self._align_expected_output(
867
907
  inference_method, dataset, expected_output_cols, output_cols_prefix
868
908
  )
869
909