snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284) hide show
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -592,12 +589,23 @@ class KernelPCA(BaseTransformer):
592
589
  autogenerated=self._autogenerated,
593
590
  subproject=_SUBPROJECT,
594
591
  )
595
- output_result, fitted_estimator = model_trainer.train_fit_predict(
596
- drop_input_cols=self._drop_input_cols,
597
- expected_output_cols_list=(
598
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
599
- ),
592
+ expected_output_cols = (
593
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
600
594
  )
595
+ if isinstance(dataset, DataFrame):
596
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
597
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
598
+ )
599
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
600
+ drop_input_cols=self._drop_input_cols,
601
+ expected_output_cols_list=expected_output_cols,
602
+ example_output_pd_df=example_output_pd_df,
603
+ )
604
+ else:
605
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
606
+ drop_input_cols=self._drop_input_cols,
607
+ expected_output_cols_list=expected_output_cols,
608
+ )
601
609
  self._sklearn_object = fitted_estimator
602
610
  self._is_fitted = True
603
611
  return output_result
@@ -622,6 +630,7 @@ class KernelPCA(BaseTransformer):
622
630
  """
623
631
  self._infer_input_output_cols(dataset)
624
632
  super()._check_dataset_type(dataset)
633
+
625
634
  model_trainer = ModelTrainerBuilder.build_fit_transform(
626
635
  estimator=self._sklearn_object,
627
636
  dataset=dataset,
@@ -678,12 +687,41 @@ class KernelPCA(BaseTransformer):
678
687
 
679
688
  return rv
680
689
 
681
- def _align_expected_output_names(
682
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
683
- ) -> List[str]:
690
+ def _align_expected_output(
691
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
692
+ ) -> Tuple[List[str], pd.DataFrame]:
693
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
694
+ and output dataframe with 1 line.
695
+ If the method is fit_predict, run 2 lines of data.
696
+ """
684
697
  # in case the inferred output column names dimension is different
685
698
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
686
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
699
+
700
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
701
+ # so change the minimum of number of rows to 2
702
+ num_examples = 2
703
+ statement_params = telemetry.get_function_usage_statement_params(
704
+ project=_PROJECT,
705
+ subproject=_SUBPROJECT,
706
+ function_name=telemetry.get_statement_params_full_func_name(
707
+ inspect.currentframe(), KernelPCA.__class__.__name__
708
+ ),
709
+ api_calls=[Session.call],
710
+ custom_tags={"autogen": True} if self._autogenerated else None,
711
+ )
712
+ if output_cols_prefix == "fit_predict_":
713
+ if hasattr(self._sklearn_object, "n_clusters"):
714
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
715
+ num_examples = self._sklearn_object.n_clusters
716
+ elif hasattr(self._sklearn_object, "min_samples"):
717
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
718
+ num_examples = self._sklearn_object.min_samples
719
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
720
+ # LocalOutlierFactor expects n_neighbors <= n_samples
721
+ num_examples = self._sklearn_object.n_neighbors
722
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
723
+ else:
724
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
687
725
 
688
726
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
689
727
  # seen during the fit.
@@ -695,12 +733,14 @@ class KernelPCA(BaseTransformer):
695
733
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
696
734
  if self.sample_weight_col:
697
735
  output_df_columns_set -= set(self.sample_weight_col)
736
+
698
737
  # if the dimension of inferred output column names is correct; use it
699
738
  if len(expected_output_cols_list) == len(output_df_columns_set):
700
- return expected_output_cols_list
739
+ return expected_output_cols_list, output_df_pd
701
740
  # otherwise, use the sklearn estimator's output
702
741
  else:
703
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
742
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
743
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
704
744
 
705
745
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
706
746
  @telemetry.send_api_usage_telemetry(
@@ -746,7 +786,7 @@ class KernelPCA(BaseTransformer):
746
786
  drop_input_cols=self._drop_input_cols,
747
787
  expected_output_cols_type="float",
748
788
  )
749
- expected_output_cols = self._align_expected_output_names(
789
+ expected_output_cols, _ = self._align_expected_output(
750
790
  inference_method, dataset, expected_output_cols, output_cols_prefix
751
791
  )
752
792
 
@@ -812,7 +852,7 @@ class KernelPCA(BaseTransformer):
812
852
  drop_input_cols=self._drop_input_cols,
813
853
  expected_output_cols_type="float",
814
854
  )
815
- expected_output_cols = self._align_expected_output_names(
855
+ expected_output_cols, _ = self._align_expected_output(
816
856
  inference_method, dataset, expected_output_cols, output_cols_prefix
817
857
  )
818
858
  elif isinstance(dataset, pd.DataFrame):
@@ -875,7 +915,7 @@ class KernelPCA(BaseTransformer):
875
915
  drop_input_cols=self._drop_input_cols,
876
916
  expected_output_cols_type="float",
877
917
  )
878
- expected_output_cols = self._align_expected_output_names(
918
+ expected_output_cols, _ = self._align_expected_output(
879
919
  inference_method, dataset, expected_output_cols, output_cols_prefix
880
920
  )
881
921
 
@@ -940,7 +980,7 @@ class KernelPCA(BaseTransformer):
940
980
  drop_input_cols = self._drop_input_cols,
941
981
  expected_output_cols_type="float",
942
982
  )
943
- expected_output_cols = self._align_expected_output_names(
983
+ expected_output_cols, _ = self._align_expected_output(
944
984
  inference_method, dataset, expected_output_cols, output_cols_prefix
945
985
  )
946
986
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -614,12 +611,23 @@ class MiniBatchDictionaryLearning(BaseTransformer):
614
611
  autogenerated=self._autogenerated,
615
612
  subproject=_SUBPROJECT,
616
613
  )
617
- output_result, fitted_estimator = model_trainer.train_fit_predict(
618
- drop_input_cols=self._drop_input_cols,
619
- expected_output_cols_list=(
620
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
621
- ),
614
+ expected_output_cols = (
615
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
622
616
  )
617
+ if isinstance(dataset, DataFrame):
618
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
619
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
620
+ )
621
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
622
+ drop_input_cols=self._drop_input_cols,
623
+ expected_output_cols_list=expected_output_cols,
624
+ example_output_pd_df=example_output_pd_df,
625
+ )
626
+ else:
627
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
628
+ drop_input_cols=self._drop_input_cols,
629
+ expected_output_cols_list=expected_output_cols,
630
+ )
623
631
  self._sklearn_object = fitted_estimator
624
632
  self._is_fitted = True
625
633
  return output_result
@@ -644,6 +652,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
644
652
  """
645
653
  self._infer_input_output_cols(dataset)
646
654
  super()._check_dataset_type(dataset)
655
+
647
656
  model_trainer = ModelTrainerBuilder.build_fit_transform(
648
657
  estimator=self._sklearn_object,
649
658
  dataset=dataset,
@@ -700,12 +709,41 @@ class MiniBatchDictionaryLearning(BaseTransformer):
700
709
 
701
710
  return rv
702
711
 
703
- def _align_expected_output_names(
704
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
705
- ) -> List[str]:
712
+ def _align_expected_output(
713
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
714
+ ) -> Tuple[List[str], pd.DataFrame]:
715
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
716
+ and output dataframe with 1 line.
717
+ If the method is fit_predict, run 2 lines of data.
718
+ """
706
719
  # in case the inferred output column names dimension is different
707
720
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
708
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
721
+
722
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
723
+ # so change the minimum of number of rows to 2
724
+ num_examples = 2
725
+ statement_params = telemetry.get_function_usage_statement_params(
726
+ project=_PROJECT,
727
+ subproject=_SUBPROJECT,
728
+ function_name=telemetry.get_statement_params_full_func_name(
729
+ inspect.currentframe(), MiniBatchDictionaryLearning.__class__.__name__
730
+ ),
731
+ api_calls=[Session.call],
732
+ custom_tags={"autogen": True} if self._autogenerated else None,
733
+ )
734
+ if output_cols_prefix == "fit_predict_":
735
+ if hasattr(self._sklearn_object, "n_clusters"):
736
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
737
+ num_examples = self._sklearn_object.n_clusters
738
+ elif hasattr(self._sklearn_object, "min_samples"):
739
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
740
+ num_examples = self._sklearn_object.min_samples
741
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
742
+ # LocalOutlierFactor expects n_neighbors <= n_samples
743
+ num_examples = self._sklearn_object.n_neighbors
744
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
745
+ else:
746
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
709
747
 
710
748
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
711
749
  # seen during the fit.
@@ -717,12 +755,14 @@ class MiniBatchDictionaryLearning(BaseTransformer):
717
755
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
718
756
  if self.sample_weight_col:
719
757
  output_df_columns_set -= set(self.sample_weight_col)
758
+
720
759
  # if the dimension of inferred output column names is correct; use it
721
760
  if len(expected_output_cols_list) == len(output_df_columns_set):
722
- return expected_output_cols_list
761
+ return expected_output_cols_list, output_df_pd
723
762
  # otherwise, use the sklearn estimator's output
724
763
  else:
725
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
764
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
765
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
726
766
 
727
767
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
728
768
  @telemetry.send_api_usage_telemetry(
@@ -768,7 +808,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
768
808
  drop_input_cols=self._drop_input_cols,
769
809
  expected_output_cols_type="float",
770
810
  )
771
- expected_output_cols = self._align_expected_output_names(
811
+ expected_output_cols, _ = self._align_expected_output(
772
812
  inference_method, dataset, expected_output_cols, output_cols_prefix
773
813
  )
774
814
 
@@ -834,7 +874,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
834
874
  drop_input_cols=self._drop_input_cols,
835
875
  expected_output_cols_type="float",
836
876
  )
837
- expected_output_cols = self._align_expected_output_names(
877
+ expected_output_cols, _ = self._align_expected_output(
838
878
  inference_method, dataset, expected_output_cols, output_cols_prefix
839
879
  )
840
880
  elif isinstance(dataset, pd.DataFrame):
@@ -897,7 +937,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
897
937
  drop_input_cols=self._drop_input_cols,
898
938
  expected_output_cols_type="float",
899
939
  )
900
- expected_output_cols = self._align_expected_output_names(
940
+ expected_output_cols, _ = self._align_expected_output(
901
941
  inference_method, dataset, expected_output_cols, output_cols_prefix
902
942
  )
903
943
 
@@ -962,7 +1002,7 @@ class MiniBatchDictionaryLearning(BaseTransformer):
962
1002
  drop_input_cols = self._drop_input_cols,
963
1003
  expected_output_cols_type="float",
964
1004
  )
965
- expected_output_cols = self._align_expected_output_names(
1005
+ expected_output_cols, _ = self._align_expected_output(
966
1006
  inference_method, dataset, expected_output_cols, output_cols_prefix
967
1007
  )
968
1008
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -559,12 +556,23 @@ class MiniBatchSparsePCA(BaseTransformer):
559
556
  autogenerated=self._autogenerated,
560
557
  subproject=_SUBPROJECT,
561
558
  )
562
- output_result, fitted_estimator = model_trainer.train_fit_predict(
563
- drop_input_cols=self._drop_input_cols,
564
- expected_output_cols_list=(
565
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
566
- ),
559
+ expected_output_cols = (
560
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
567
561
  )
562
+ if isinstance(dataset, DataFrame):
563
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
564
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
565
+ )
566
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
567
+ drop_input_cols=self._drop_input_cols,
568
+ expected_output_cols_list=expected_output_cols,
569
+ example_output_pd_df=example_output_pd_df,
570
+ )
571
+ else:
572
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
573
+ drop_input_cols=self._drop_input_cols,
574
+ expected_output_cols_list=expected_output_cols,
575
+ )
568
576
  self._sklearn_object = fitted_estimator
569
577
  self._is_fitted = True
570
578
  return output_result
@@ -589,6 +597,7 @@ class MiniBatchSparsePCA(BaseTransformer):
589
597
  """
590
598
  self._infer_input_output_cols(dataset)
591
599
  super()._check_dataset_type(dataset)
600
+
592
601
  model_trainer = ModelTrainerBuilder.build_fit_transform(
593
602
  estimator=self._sklearn_object,
594
603
  dataset=dataset,
@@ -645,12 +654,41 @@ class MiniBatchSparsePCA(BaseTransformer):
645
654
 
646
655
  return rv
647
656
 
648
- def _align_expected_output_names(
649
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
650
- ) -> List[str]:
657
+ def _align_expected_output(
658
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
659
+ ) -> Tuple[List[str], pd.DataFrame]:
660
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
661
+ and output dataframe with 1 line.
662
+ If the method is fit_predict, run 2 lines of data.
663
+ """
651
664
  # in case the inferred output column names dimension is different
652
665
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
653
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
666
+
667
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
668
+ # so change the minimum of number of rows to 2
669
+ num_examples = 2
670
+ statement_params = telemetry.get_function_usage_statement_params(
671
+ project=_PROJECT,
672
+ subproject=_SUBPROJECT,
673
+ function_name=telemetry.get_statement_params_full_func_name(
674
+ inspect.currentframe(), MiniBatchSparsePCA.__class__.__name__
675
+ ),
676
+ api_calls=[Session.call],
677
+ custom_tags={"autogen": True} if self._autogenerated else None,
678
+ )
679
+ if output_cols_prefix == "fit_predict_":
680
+ if hasattr(self._sklearn_object, "n_clusters"):
681
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
682
+ num_examples = self._sklearn_object.n_clusters
683
+ elif hasattr(self._sklearn_object, "min_samples"):
684
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
685
+ num_examples = self._sklearn_object.min_samples
686
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
687
+ # LocalOutlierFactor expects n_neighbors <= n_samples
688
+ num_examples = self._sklearn_object.n_neighbors
689
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
690
+ else:
691
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
654
692
 
655
693
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
656
694
  # seen during the fit.
@@ -662,12 +700,14 @@ class MiniBatchSparsePCA(BaseTransformer):
662
700
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
663
701
  if self.sample_weight_col:
664
702
  output_df_columns_set -= set(self.sample_weight_col)
703
+
665
704
  # if the dimension of inferred output column names is correct; use it
666
705
  if len(expected_output_cols_list) == len(output_df_columns_set):
667
- return expected_output_cols_list
706
+ return expected_output_cols_list, output_df_pd
668
707
  # otherwise, use the sklearn estimator's output
669
708
  else:
670
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
709
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
710
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
671
711
 
672
712
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
673
713
  @telemetry.send_api_usage_telemetry(
@@ -713,7 +753,7 @@ class MiniBatchSparsePCA(BaseTransformer):
713
753
  drop_input_cols=self._drop_input_cols,
714
754
  expected_output_cols_type="float",
715
755
  )
716
- expected_output_cols = self._align_expected_output_names(
756
+ expected_output_cols, _ = self._align_expected_output(
717
757
  inference_method, dataset, expected_output_cols, output_cols_prefix
718
758
  )
719
759
 
@@ -779,7 +819,7 @@ class MiniBatchSparsePCA(BaseTransformer):
779
819
  drop_input_cols=self._drop_input_cols,
780
820
  expected_output_cols_type="float",
781
821
  )
782
- expected_output_cols = self._align_expected_output_names(
822
+ expected_output_cols, _ = self._align_expected_output(
783
823
  inference_method, dataset, expected_output_cols, output_cols_prefix
784
824
  )
785
825
  elif isinstance(dataset, pd.DataFrame):
@@ -842,7 +882,7 @@ class MiniBatchSparsePCA(BaseTransformer):
842
882
  drop_input_cols=self._drop_input_cols,
843
883
  expected_output_cols_type="float",
844
884
  )
845
- expected_output_cols = self._align_expected_output_names(
885
+ expected_output_cols, _ = self._align_expected_output(
846
886
  inference_method, dataset, expected_output_cols, output_cols_prefix
847
887
  )
848
888
 
@@ -907,7 +947,7 @@ class MiniBatchSparsePCA(BaseTransformer):
907
947
  drop_input_cols = self._drop_input_cols,
908
948
  expected_output_cols_type="float",
909
949
  )
910
- expected_output_cols = self._align_expected_output_names(
950
+ expected_output_cols, _ = self._align_expected_output(
911
951
  inference_method, dataset, expected_output_cols, output_cols_prefix
912
952
  )
913
953
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -561,12 +558,23 @@ class PCA(BaseTransformer):
561
558
  autogenerated=self._autogenerated,
562
559
  subproject=_SUBPROJECT,
563
560
  )
564
- output_result, fitted_estimator = model_trainer.train_fit_predict(
565
- drop_input_cols=self._drop_input_cols,
566
- expected_output_cols_list=(
567
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
568
- ),
561
+ expected_output_cols = (
562
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
569
563
  )
564
+ if isinstance(dataset, DataFrame):
565
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
566
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
567
+ )
568
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
569
+ drop_input_cols=self._drop_input_cols,
570
+ expected_output_cols_list=expected_output_cols,
571
+ example_output_pd_df=example_output_pd_df,
572
+ )
573
+ else:
574
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
575
+ drop_input_cols=self._drop_input_cols,
576
+ expected_output_cols_list=expected_output_cols,
577
+ )
570
578
  self._sklearn_object = fitted_estimator
571
579
  self._is_fitted = True
572
580
  return output_result
@@ -591,6 +599,7 @@ class PCA(BaseTransformer):
591
599
  """
592
600
  self._infer_input_output_cols(dataset)
593
601
  super()._check_dataset_type(dataset)
602
+
594
603
  model_trainer = ModelTrainerBuilder.build_fit_transform(
595
604
  estimator=self._sklearn_object,
596
605
  dataset=dataset,
@@ -647,12 +656,41 @@ class PCA(BaseTransformer):
647
656
 
648
657
  return rv
649
658
 
650
- def _align_expected_output_names(
651
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
652
- ) -> List[str]:
659
+ def _align_expected_output(
660
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
661
+ ) -> Tuple[List[str], pd.DataFrame]:
662
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
663
+ and output dataframe with 1 line.
664
+ If the method is fit_predict, run 2 lines of data.
665
+ """
653
666
  # in case the inferred output column names dimension is different
654
667
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
655
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
668
+
669
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
670
+ # so change the minimum of number of rows to 2
671
+ num_examples = 2
672
+ statement_params = telemetry.get_function_usage_statement_params(
673
+ project=_PROJECT,
674
+ subproject=_SUBPROJECT,
675
+ function_name=telemetry.get_statement_params_full_func_name(
676
+ inspect.currentframe(), PCA.__class__.__name__
677
+ ),
678
+ api_calls=[Session.call],
679
+ custom_tags={"autogen": True} if self._autogenerated else None,
680
+ )
681
+ if output_cols_prefix == "fit_predict_":
682
+ if hasattr(self._sklearn_object, "n_clusters"):
683
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
684
+ num_examples = self._sklearn_object.n_clusters
685
+ elif hasattr(self._sklearn_object, "min_samples"):
686
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
687
+ num_examples = self._sklearn_object.min_samples
688
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
689
+ # LocalOutlierFactor expects n_neighbors <= n_samples
690
+ num_examples = self._sklearn_object.n_neighbors
691
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
692
+ else:
693
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
656
694
 
657
695
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
658
696
  # seen during the fit.
@@ -664,12 +702,14 @@ class PCA(BaseTransformer):
664
702
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
665
703
  if self.sample_weight_col:
666
704
  output_df_columns_set -= set(self.sample_weight_col)
705
+
667
706
  # if the dimension of inferred output column names is correct; use it
668
707
  if len(expected_output_cols_list) == len(output_df_columns_set):
669
- return expected_output_cols_list
708
+ return expected_output_cols_list, output_df_pd
670
709
  # otherwise, use the sklearn estimator's output
671
710
  else:
672
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
711
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
712
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
673
713
 
674
714
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
675
715
  @telemetry.send_api_usage_telemetry(
@@ -715,7 +755,7 @@ class PCA(BaseTransformer):
715
755
  drop_input_cols=self._drop_input_cols,
716
756
  expected_output_cols_type="float",
717
757
  )
718
- expected_output_cols = self._align_expected_output_names(
758
+ expected_output_cols, _ = self._align_expected_output(
719
759
  inference_method, dataset, expected_output_cols, output_cols_prefix
720
760
  )
721
761
 
@@ -781,7 +821,7 @@ class PCA(BaseTransformer):
781
821
  drop_input_cols=self._drop_input_cols,
782
822
  expected_output_cols_type="float",
783
823
  )
784
- expected_output_cols = self._align_expected_output_names(
824
+ expected_output_cols, _ = self._align_expected_output(
785
825
  inference_method, dataset, expected_output_cols, output_cols_prefix
786
826
  )
787
827
  elif isinstance(dataset, pd.DataFrame):
@@ -844,7 +884,7 @@ class PCA(BaseTransformer):
844
884
  drop_input_cols=self._drop_input_cols,
845
885
  expected_output_cols_type="float",
846
886
  )
847
- expected_output_cols = self._align_expected_output_names(
887
+ expected_output_cols, _ = self._align_expected_output(
848
888
  inference_method, dataset, expected_output_cols, output_cols_prefix
849
889
  )
850
890
 
@@ -911,7 +951,7 @@ class PCA(BaseTransformer):
911
951
  drop_input_cols = self._drop_input_cols,
912
952
  expected_output_cols_type="float",
913
953
  )
914
- expected_output_cols = self._align_expected_output_names(
954
+ expected_output_cols, _ = self._align_expected_output(
915
955
  inference_method, dataset, expected_output_cols, output_cols_prefix
916
956
  )
917
957