snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -534,12 +531,23 @@ class SparsePCA(BaseTransformer):
534
531
  autogenerated=self._autogenerated,
535
532
  subproject=_SUBPROJECT,
536
533
  )
537
- output_result, fitted_estimator = model_trainer.train_fit_predict(
538
- drop_input_cols=self._drop_input_cols,
539
- expected_output_cols_list=(
540
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
541
- ),
534
+ expected_output_cols = (
535
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
542
536
  )
537
+ if isinstance(dataset, DataFrame):
538
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
539
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
540
+ )
541
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
542
+ drop_input_cols=self._drop_input_cols,
543
+ expected_output_cols_list=expected_output_cols,
544
+ example_output_pd_df=example_output_pd_df,
545
+ )
546
+ else:
547
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
548
+ drop_input_cols=self._drop_input_cols,
549
+ expected_output_cols_list=expected_output_cols,
550
+ )
543
551
  self._sklearn_object = fitted_estimator
544
552
  self._is_fitted = True
545
553
  return output_result
@@ -564,6 +572,7 @@ class SparsePCA(BaseTransformer):
564
572
  """
565
573
  self._infer_input_output_cols(dataset)
566
574
  super()._check_dataset_type(dataset)
575
+
567
576
  model_trainer = ModelTrainerBuilder.build_fit_transform(
568
577
  estimator=self._sklearn_object,
569
578
  dataset=dataset,
@@ -620,12 +629,41 @@ class SparsePCA(BaseTransformer):
620
629
 
621
630
  return rv
622
631
 
623
- def _align_expected_output_names(
624
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
625
- ) -> List[str]:
632
+ def _align_expected_output(
633
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
634
+ ) -> Tuple[List[str], pd.DataFrame]:
635
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
636
+ and output dataframe with 1 line.
637
+ If the method is fit_predict, run 2 lines of data.
638
+ """
626
639
  # in case the inferred output column names dimension is different
627
640
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
628
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
641
+
642
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
643
+ # so change the minimum of number of rows to 2
644
+ num_examples = 2
645
+ statement_params = telemetry.get_function_usage_statement_params(
646
+ project=_PROJECT,
647
+ subproject=_SUBPROJECT,
648
+ function_name=telemetry.get_statement_params_full_func_name(
649
+ inspect.currentframe(), SparsePCA.__class__.__name__
650
+ ),
651
+ api_calls=[Session.call],
652
+ custom_tags={"autogen": True} if self._autogenerated else None,
653
+ )
654
+ if output_cols_prefix == "fit_predict_":
655
+ if hasattr(self._sklearn_object, "n_clusters"):
656
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
657
+ num_examples = self._sklearn_object.n_clusters
658
+ elif hasattr(self._sklearn_object, "min_samples"):
659
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
660
+ num_examples = self._sklearn_object.min_samples
661
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
662
+ # LocalOutlierFactor expects n_neighbors <= n_samples
663
+ num_examples = self._sklearn_object.n_neighbors
664
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
665
+ else:
666
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
629
667
 
630
668
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
631
669
  # seen during the fit.
@@ -637,12 +675,14 @@ class SparsePCA(BaseTransformer):
637
675
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
638
676
  if self.sample_weight_col:
639
677
  output_df_columns_set -= set(self.sample_weight_col)
678
+
640
679
  # if the dimension of inferred output column names is correct; use it
641
680
  if len(expected_output_cols_list) == len(output_df_columns_set):
642
- return expected_output_cols_list
681
+ return expected_output_cols_list, output_df_pd
643
682
  # otherwise, use the sklearn estimator's output
644
683
  else:
645
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
684
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
685
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
646
686
 
647
687
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
648
688
  @telemetry.send_api_usage_telemetry(
@@ -688,7 +728,7 @@ class SparsePCA(BaseTransformer):
688
728
  drop_input_cols=self._drop_input_cols,
689
729
  expected_output_cols_type="float",
690
730
  )
691
- expected_output_cols = self._align_expected_output_names(
731
+ expected_output_cols, _ = self._align_expected_output(
692
732
  inference_method, dataset, expected_output_cols, output_cols_prefix
693
733
  )
694
734
 
@@ -754,7 +794,7 @@ class SparsePCA(BaseTransformer):
754
794
  drop_input_cols=self._drop_input_cols,
755
795
  expected_output_cols_type="float",
756
796
  )
757
- expected_output_cols = self._align_expected_output_names(
797
+ expected_output_cols, _ = self._align_expected_output(
758
798
  inference_method, dataset, expected_output_cols, output_cols_prefix
759
799
  )
760
800
  elif isinstance(dataset, pd.DataFrame):
@@ -817,7 +857,7 @@ class SparsePCA(BaseTransformer):
817
857
  drop_input_cols=self._drop_input_cols,
818
858
  expected_output_cols_type="float",
819
859
  )
820
- expected_output_cols = self._align_expected_output_names(
860
+ expected_output_cols, _ = self._align_expected_output(
821
861
  inference_method, dataset, expected_output_cols, output_cols_prefix
822
862
  )
823
863
 
@@ -882,7 +922,7 @@ class SparsePCA(BaseTransformer):
882
922
  drop_input_cols = self._drop_input_cols,
883
923
  expected_output_cols_type="float",
884
924
  )
885
- expected_output_cols = self._align_expected_output_names(
925
+ expected_output_cols, _ = self._align_expected_output(
886
926
  inference_method, dataset, expected_output_cols, output_cols_prefix
887
927
  )
888
928
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -515,12 +512,23 @@ class TruncatedSVD(BaseTransformer):
515
512
  autogenerated=self._autogenerated,
516
513
  subproject=_SUBPROJECT,
517
514
  )
518
- output_result, fitted_estimator = model_trainer.train_fit_predict(
519
- drop_input_cols=self._drop_input_cols,
520
- expected_output_cols_list=(
521
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
522
- ),
515
+ expected_output_cols = (
516
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
523
517
  )
518
+ if isinstance(dataset, DataFrame):
519
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
520
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
521
+ )
522
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
523
+ drop_input_cols=self._drop_input_cols,
524
+ expected_output_cols_list=expected_output_cols,
525
+ example_output_pd_df=example_output_pd_df,
526
+ )
527
+ else:
528
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
529
+ drop_input_cols=self._drop_input_cols,
530
+ expected_output_cols_list=expected_output_cols,
531
+ )
524
532
  self._sklearn_object = fitted_estimator
525
533
  self._is_fitted = True
526
534
  return output_result
@@ -545,6 +553,7 @@ class TruncatedSVD(BaseTransformer):
545
553
  """
546
554
  self._infer_input_output_cols(dataset)
547
555
  super()._check_dataset_type(dataset)
556
+
548
557
  model_trainer = ModelTrainerBuilder.build_fit_transform(
549
558
  estimator=self._sklearn_object,
550
559
  dataset=dataset,
@@ -601,12 +610,41 @@ class TruncatedSVD(BaseTransformer):
601
610
 
602
611
  return rv
603
612
 
604
- def _align_expected_output_names(
605
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
606
- ) -> List[str]:
613
+ def _align_expected_output(
614
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
615
+ ) -> Tuple[List[str], pd.DataFrame]:
616
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
617
+ and output dataframe with 1 line.
618
+ If the method is fit_predict, run 2 lines of data.
619
+ """
607
620
  # in case the inferred output column names dimension is different
608
621
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
609
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
622
+
623
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
624
+ # so change the minimum of number of rows to 2
625
+ num_examples = 2
626
+ statement_params = telemetry.get_function_usage_statement_params(
627
+ project=_PROJECT,
628
+ subproject=_SUBPROJECT,
629
+ function_name=telemetry.get_statement_params_full_func_name(
630
+ inspect.currentframe(), TruncatedSVD.__class__.__name__
631
+ ),
632
+ api_calls=[Session.call],
633
+ custom_tags={"autogen": True} if self._autogenerated else None,
634
+ )
635
+ if output_cols_prefix == "fit_predict_":
636
+ if hasattr(self._sklearn_object, "n_clusters"):
637
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
638
+ num_examples = self._sklearn_object.n_clusters
639
+ elif hasattr(self._sklearn_object, "min_samples"):
640
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
641
+ num_examples = self._sklearn_object.min_samples
642
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
643
+ # LocalOutlierFactor expects n_neighbors <= n_samples
644
+ num_examples = self._sklearn_object.n_neighbors
645
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
646
+ else:
647
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
610
648
 
611
649
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
612
650
  # seen during the fit.
@@ -618,12 +656,14 @@ class TruncatedSVD(BaseTransformer):
618
656
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
619
657
  if self.sample_weight_col:
620
658
  output_df_columns_set -= set(self.sample_weight_col)
659
+
621
660
  # if the dimension of inferred output column names is correct; use it
622
661
  if len(expected_output_cols_list) == len(output_df_columns_set):
623
- return expected_output_cols_list
662
+ return expected_output_cols_list, output_df_pd
624
663
  # otherwise, use the sklearn estimator's output
625
664
  else:
626
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
665
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
666
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
627
667
 
628
668
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
629
669
  @telemetry.send_api_usage_telemetry(
@@ -669,7 +709,7 @@ class TruncatedSVD(BaseTransformer):
669
709
  drop_input_cols=self._drop_input_cols,
670
710
  expected_output_cols_type="float",
671
711
  )
672
- expected_output_cols = self._align_expected_output_names(
712
+ expected_output_cols, _ = self._align_expected_output(
673
713
  inference_method, dataset, expected_output_cols, output_cols_prefix
674
714
  )
675
715
 
@@ -735,7 +775,7 @@ class TruncatedSVD(BaseTransformer):
735
775
  drop_input_cols=self._drop_input_cols,
736
776
  expected_output_cols_type="float",
737
777
  )
738
- expected_output_cols = self._align_expected_output_names(
778
+ expected_output_cols, _ = self._align_expected_output(
739
779
  inference_method, dataset, expected_output_cols, output_cols_prefix
740
780
  )
741
781
  elif isinstance(dataset, pd.DataFrame):
@@ -798,7 +838,7 @@ class TruncatedSVD(BaseTransformer):
798
838
  drop_input_cols=self._drop_input_cols,
799
839
  expected_output_cols_type="float",
800
840
  )
801
- expected_output_cols = self._align_expected_output_names(
841
+ expected_output_cols, _ = self._align_expected_output(
802
842
  inference_method, dataset, expected_output_cols, output_cols_prefix
803
843
  )
804
844
 
@@ -863,7 +903,7 @@ class TruncatedSVD(BaseTransformer):
863
903
  drop_input_cols = self._drop_input_cols,
864
904
  expected_output_cols_type="float",
865
905
  )
866
- expected_output_cols = self._align_expected_output_names(
906
+ expected_output_cols, _ = self._align_expected_output(
867
907
  inference_method, dataset, expected_output_cols, output_cols_prefix
868
908
  )
869
909
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -532,12 +529,23 @@ class LinearDiscriminantAnalysis(BaseTransformer):
532
529
  autogenerated=self._autogenerated,
533
530
  subproject=_SUBPROJECT,
534
531
  )
535
- output_result, fitted_estimator = model_trainer.train_fit_predict(
536
- drop_input_cols=self._drop_input_cols,
537
- expected_output_cols_list=(
538
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
539
- ),
532
+ expected_output_cols = (
533
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
540
534
  )
535
+ if isinstance(dataset, DataFrame):
536
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
537
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
538
+ )
539
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
540
+ drop_input_cols=self._drop_input_cols,
541
+ expected_output_cols_list=expected_output_cols,
542
+ example_output_pd_df=example_output_pd_df,
543
+ )
544
+ else:
545
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
546
+ drop_input_cols=self._drop_input_cols,
547
+ expected_output_cols_list=expected_output_cols,
548
+ )
541
549
  self._sklearn_object = fitted_estimator
542
550
  self._is_fitted = True
543
551
  return output_result
@@ -562,6 +570,7 @@ class LinearDiscriminantAnalysis(BaseTransformer):
562
570
  """
563
571
  self._infer_input_output_cols(dataset)
564
572
  super()._check_dataset_type(dataset)
573
+
565
574
  model_trainer = ModelTrainerBuilder.build_fit_transform(
566
575
  estimator=self._sklearn_object,
567
576
  dataset=dataset,
@@ -618,12 +627,41 @@ class LinearDiscriminantAnalysis(BaseTransformer):
618
627
 
619
628
  return rv
620
629
 
621
- def _align_expected_output_names(
622
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
623
- ) -> List[str]:
630
+ def _align_expected_output(
631
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
632
+ ) -> Tuple[List[str], pd.DataFrame]:
633
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
634
+ and output dataframe with 1 line.
635
+ If the method is fit_predict, run 2 lines of data.
636
+ """
624
637
  # in case the inferred output column names dimension is different
625
638
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
626
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
639
+
640
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
641
+ # so change the minimum of number of rows to 2
642
+ num_examples = 2
643
+ statement_params = telemetry.get_function_usage_statement_params(
644
+ project=_PROJECT,
645
+ subproject=_SUBPROJECT,
646
+ function_name=telemetry.get_statement_params_full_func_name(
647
+ inspect.currentframe(), LinearDiscriminantAnalysis.__class__.__name__
648
+ ),
649
+ api_calls=[Session.call],
650
+ custom_tags={"autogen": True} if self._autogenerated else None,
651
+ )
652
+ if output_cols_prefix == "fit_predict_":
653
+ if hasattr(self._sklearn_object, "n_clusters"):
654
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
655
+ num_examples = self._sklearn_object.n_clusters
656
+ elif hasattr(self._sklearn_object, "min_samples"):
657
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
658
+ num_examples = self._sklearn_object.min_samples
659
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
660
+ # LocalOutlierFactor expects n_neighbors <= n_samples
661
+ num_examples = self._sklearn_object.n_neighbors
662
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
663
+ else:
664
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
627
665
 
628
666
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
629
667
  # seen during the fit.
@@ -635,12 +673,14 @@ class LinearDiscriminantAnalysis(BaseTransformer):
635
673
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
636
674
  if self.sample_weight_col:
637
675
  output_df_columns_set -= set(self.sample_weight_col)
676
+
638
677
  # if the dimension of inferred output column names is correct; use it
639
678
  if len(expected_output_cols_list) == len(output_df_columns_set):
640
- return expected_output_cols_list
679
+ return expected_output_cols_list, output_df_pd
641
680
  # otherwise, use the sklearn estimator's output
642
681
  else:
643
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
682
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
683
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
644
684
 
645
685
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
646
686
  @telemetry.send_api_usage_telemetry(
@@ -688,7 +728,7 @@ class LinearDiscriminantAnalysis(BaseTransformer):
688
728
  drop_input_cols=self._drop_input_cols,
689
729
  expected_output_cols_type="float",
690
730
  )
691
- expected_output_cols = self._align_expected_output_names(
731
+ expected_output_cols, _ = self._align_expected_output(
692
732
  inference_method, dataset, expected_output_cols, output_cols_prefix
693
733
  )
694
734
 
@@ -756,7 +796,7 @@ class LinearDiscriminantAnalysis(BaseTransformer):
756
796
  drop_input_cols=self._drop_input_cols,
757
797
  expected_output_cols_type="float",
758
798
  )
759
- expected_output_cols = self._align_expected_output_names(
799
+ expected_output_cols, _ = self._align_expected_output(
760
800
  inference_method, dataset, expected_output_cols, output_cols_prefix
761
801
  )
762
802
  elif isinstance(dataset, pd.DataFrame):
@@ -821,7 +861,7 @@ class LinearDiscriminantAnalysis(BaseTransformer):
821
861
  drop_input_cols=self._drop_input_cols,
822
862
  expected_output_cols_type="float",
823
863
  )
824
- expected_output_cols = self._align_expected_output_names(
864
+ expected_output_cols, _ = self._align_expected_output(
825
865
  inference_method, dataset, expected_output_cols, output_cols_prefix
826
866
  )
827
867
 
@@ -886,7 +926,7 @@ class LinearDiscriminantAnalysis(BaseTransformer):
886
926
  drop_input_cols = self._drop_input_cols,
887
927
  expected_output_cols_type="float",
888
928
  )
889
- expected_output_cols = self._align_expected_output_names(
929
+ expected_output_cols, _ = self._align_expected_output(
890
930
  inference_method, dataset, expected_output_cols, output_cols_prefix
891
931
  )
892
932
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -492,12 +489,23 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
492
489
  autogenerated=self._autogenerated,
493
490
  subproject=_SUBPROJECT,
494
491
  )
495
- output_result, fitted_estimator = model_trainer.train_fit_predict(
496
- drop_input_cols=self._drop_input_cols,
497
- expected_output_cols_list=(
498
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
499
- ),
492
+ expected_output_cols = (
493
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
500
494
  )
495
+ if isinstance(dataset, DataFrame):
496
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
497
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
498
+ )
499
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
500
+ drop_input_cols=self._drop_input_cols,
501
+ expected_output_cols_list=expected_output_cols,
502
+ example_output_pd_df=example_output_pd_df,
503
+ )
504
+ else:
505
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
506
+ drop_input_cols=self._drop_input_cols,
507
+ expected_output_cols_list=expected_output_cols,
508
+ )
501
509
  self._sklearn_object = fitted_estimator
502
510
  self._is_fitted = True
503
511
  return output_result
@@ -520,6 +528,7 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
520
528
  """
521
529
  self._infer_input_output_cols(dataset)
522
530
  super()._check_dataset_type(dataset)
531
+
523
532
  model_trainer = ModelTrainerBuilder.build_fit_transform(
524
533
  estimator=self._sklearn_object,
525
534
  dataset=dataset,
@@ -576,12 +585,41 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
576
585
 
577
586
  return rv
578
587
 
579
- def _align_expected_output_names(
580
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
581
- ) -> List[str]:
588
+ def _align_expected_output(
589
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
590
+ ) -> Tuple[List[str], pd.DataFrame]:
591
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
592
+ and output dataframe with 1 line.
593
+ If the method is fit_predict, run 2 lines of data.
594
+ """
582
595
  # in case the inferred output column names dimension is different
583
596
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
584
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
597
+
598
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
599
+ # so change the minimum of number of rows to 2
600
+ num_examples = 2
601
+ statement_params = telemetry.get_function_usage_statement_params(
602
+ project=_PROJECT,
603
+ subproject=_SUBPROJECT,
604
+ function_name=telemetry.get_statement_params_full_func_name(
605
+ inspect.currentframe(), QuadraticDiscriminantAnalysis.__class__.__name__
606
+ ),
607
+ api_calls=[Session.call],
608
+ custom_tags={"autogen": True} if self._autogenerated else None,
609
+ )
610
+ if output_cols_prefix == "fit_predict_":
611
+ if hasattr(self._sklearn_object, "n_clusters"):
612
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
613
+ num_examples = self._sklearn_object.n_clusters
614
+ elif hasattr(self._sklearn_object, "min_samples"):
615
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
616
+ num_examples = self._sklearn_object.min_samples
617
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
618
+ # LocalOutlierFactor expects n_neighbors <= n_samples
619
+ num_examples = self._sklearn_object.n_neighbors
620
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
621
+ else:
622
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
585
623
 
586
624
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
587
625
  # seen during the fit.
@@ -593,12 +631,14 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
593
631
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
594
632
  if self.sample_weight_col:
595
633
  output_df_columns_set -= set(self.sample_weight_col)
634
+
596
635
  # if the dimension of inferred output column names is correct; use it
597
636
  if len(expected_output_cols_list) == len(output_df_columns_set):
598
- return expected_output_cols_list
637
+ return expected_output_cols_list, output_df_pd
599
638
  # otherwise, use the sklearn estimator's output
600
639
  else:
601
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
640
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
641
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
602
642
 
603
643
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
604
644
  @telemetry.send_api_usage_telemetry(
@@ -646,7 +686,7 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
646
686
  drop_input_cols=self._drop_input_cols,
647
687
  expected_output_cols_type="float",
648
688
  )
649
- expected_output_cols = self._align_expected_output_names(
689
+ expected_output_cols, _ = self._align_expected_output(
650
690
  inference_method, dataset, expected_output_cols, output_cols_prefix
651
691
  )
652
692
 
@@ -714,7 +754,7 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
714
754
  drop_input_cols=self._drop_input_cols,
715
755
  expected_output_cols_type="float",
716
756
  )
717
- expected_output_cols = self._align_expected_output_names(
757
+ expected_output_cols, _ = self._align_expected_output(
718
758
  inference_method, dataset, expected_output_cols, output_cols_prefix
719
759
  )
720
760
  elif isinstance(dataset, pd.DataFrame):
@@ -779,7 +819,7 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
779
819
  drop_input_cols=self._drop_input_cols,
780
820
  expected_output_cols_type="float",
781
821
  )
782
- expected_output_cols = self._align_expected_output_names(
822
+ expected_output_cols, _ = self._align_expected_output(
783
823
  inference_method, dataset, expected_output_cols, output_cols_prefix
784
824
  )
785
825
 
@@ -844,7 +884,7 @@ class QuadraticDiscriminantAnalysis(BaseTransformer):
844
884
  drop_input_cols = self._drop_input_cols,
845
885
  expected_output_cols_type="float",
846
886
  )
847
- expected_output_cols = self._align_expected_output_names(
887
+ expected_output_cols, _ = self._align_expected_output(
848
888
  inference_method, dataset, expected_output_cols, output_cols_prefix
849
889
  )
850
890