snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -564,12 +561,23 @@ class GaussianProcessClassifier(BaseTransformer):
564
561
  autogenerated=self._autogenerated,
565
562
  subproject=_SUBPROJECT,
566
563
  )
567
- output_result, fitted_estimator = model_trainer.train_fit_predict(
568
- drop_input_cols=self._drop_input_cols,
569
- expected_output_cols_list=(
570
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
571
- ),
564
+ expected_output_cols = (
565
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
572
566
  )
567
+ if isinstance(dataset, DataFrame):
568
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
569
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
570
+ )
571
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
572
+ drop_input_cols=self._drop_input_cols,
573
+ expected_output_cols_list=expected_output_cols,
574
+ example_output_pd_df=example_output_pd_df,
575
+ )
576
+ else:
577
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
578
+ drop_input_cols=self._drop_input_cols,
579
+ expected_output_cols_list=expected_output_cols,
580
+ )
573
581
  self._sklearn_object = fitted_estimator
574
582
  self._is_fitted = True
575
583
  return output_result
@@ -592,6 +600,7 @@ class GaussianProcessClassifier(BaseTransformer):
592
600
  """
593
601
  self._infer_input_output_cols(dataset)
594
602
  super()._check_dataset_type(dataset)
603
+
595
604
  model_trainer = ModelTrainerBuilder.build_fit_transform(
596
605
  estimator=self._sklearn_object,
597
606
  dataset=dataset,
@@ -648,12 +657,41 @@ class GaussianProcessClassifier(BaseTransformer):
648
657
 
649
658
  return rv
650
659
 
651
- def _align_expected_output_names(
652
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
653
- ) -> List[str]:
660
+ def _align_expected_output(
661
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
662
+ ) -> Tuple[List[str], pd.DataFrame]:
663
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
664
+ and output dataframe with 1 line.
665
+ If the method is fit_predict, run 2 lines of data.
666
+ """
654
667
  # in case the inferred output column names dimension is different
655
668
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
656
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
669
+
670
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
671
+ # so change the minimum of number of rows to 2
672
+ num_examples = 2
673
+ statement_params = telemetry.get_function_usage_statement_params(
674
+ project=_PROJECT,
675
+ subproject=_SUBPROJECT,
676
+ function_name=telemetry.get_statement_params_full_func_name(
677
+ inspect.currentframe(), GaussianProcessClassifier.__class__.__name__
678
+ ),
679
+ api_calls=[Session.call],
680
+ custom_tags={"autogen": True} if self._autogenerated else None,
681
+ )
682
+ if output_cols_prefix == "fit_predict_":
683
+ if hasattr(self._sklearn_object, "n_clusters"):
684
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
685
+ num_examples = self._sklearn_object.n_clusters
686
+ elif hasattr(self._sklearn_object, "min_samples"):
687
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
688
+ num_examples = self._sklearn_object.min_samples
689
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
690
+ # LocalOutlierFactor expects n_neighbors <= n_samples
691
+ num_examples = self._sklearn_object.n_neighbors
692
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
693
+ else:
694
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
657
695
 
658
696
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
659
697
  # seen during the fit.
@@ -665,12 +703,14 @@ class GaussianProcessClassifier(BaseTransformer):
665
703
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
666
704
  if self.sample_weight_col:
667
705
  output_df_columns_set -= set(self.sample_weight_col)
706
+
668
707
  # if the dimension of inferred output column names is correct; use it
669
708
  if len(expected_output_cols_list) == len(output_df_columns_set):
670
- return expected_output_cols_list
709
+ return expected_output_cols_list, output_df_pd
671
710
  # otherwise, use the sklearn estimator's output
672
711
  else:
673
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
712
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
713
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
674
714
 
675
715
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
676
716
  @telemetry.send_api_usage_telemetry(
@@ -718,7 +758,7 @@ class GaussianProcessClassifier(BaseTransformer):
718
758
  drop_input_cols=self._drop_input_cols,
719
759
  expected_output_cols_type="float",
720
760
  )
721
- expected_output_cols = self._align_expected_output_names(
761
+ expected_output_cols, _ = self._align_expected_output(
722
762
  inference_method, dataset, expected_output_cols, output_cols_prefix
723
763
  )
724
764
 
@@ -786,7 +826,7 @@ class GaussianProcessClassifier(BaseTransformer):
786
826
  drop_input_cols=self._drop_input_cols,
787
827
  expected_output_cols_type="float",
788
828
  )
789
- expected_output_cols = self._align_expected_output_names(
829
+ expected_output_cols, _ = self._align_expected_output(
790
830
  inference_method, dataset, expected_output_cols, output_cols_prefix
791
831
  )
792
832
  elif isinstance(dataset, pd.DataFrame):
@@ -849,7 +889,7 @@ class GaussianProcessClassifier(BaseTransformer):
849
889
  drop_input_cols=self._drop_input_cols,
850
890
  expected_output_cols_type="float",
851
891
  )
852
- expected_output_cols = self._align_expected_output_names(
892
+ expected_output_cols, _ = self._align_expected_output(
853
893
  inference_method, dataset, expected_output_cols, output_cols_prefix
854
894
  )
855
895
 
@@ -914,7 +954,7 @@ class GaussianProcessClassifier(BaseTransformer):
914
954
  drop_input_cols = self._drop_input_cols,
915
955
  expected_output_cols_type="float",
916
956
  )
917
- expected_output_cols = self._align_expected_output_names(
957
+ expected_output_cols, _ = self._align_expected_output(
918
958
  inference_method, dataset, expected_output_cols, output_cols_prefix
919
959
  )
920
960
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -555,12 +552,23 @@ class GaussianProcessRegressor(BaseTransformer):
555
552
  autogenerated=self._autogenerated,
556
553
  subproject=_SUBPROJECT,
557
554
  )
558
- output_result, fitted_estimator = model_trainer.train_fit_predict(
559
- drop_input_cols=self._drop_input_cols,
560
- expected_output_cols_list=(
561
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
562
- ),
555
+ expected_output_cols = (
556
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
563
557
  )
558
+ if isinstance(dataset, DataFrame):
559
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
560
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
561
+ )
562
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
563
+ drop_input_cols=self._drop_input_cols,
564
+ expected_output_cols_list=expected_output_cols,
565
+ example_output_pd_df=example_output_pd_df,
566
+ )
567
+ else:
568
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
569
+ drop_input_cols=self._drop_input_cols,
570
+ expected_output_cols_list=expected_output_cols,
571
+ )
564
572
  self._sklearn_object = fitted_estimator
565
573
  self._is_fitted = True
566
574
  return output_result
@@ -583,6 +591,7 @@ class GaussianProcessRegressor(BaseTransformer):
583
591
  """
584
592
  self._infer_input_output_cols(dataset)
585
593
  super()._check_dataset_type(dataset)
594
+
586
595
  model_trainer = ModelTrainerBuilder.build_fit_transform(
587
596
  estimator=self._sklearn_object,
588
597
  dataset=dataset,
@@ -639,12 +648,41 @@ class GaussianProcessRegressor(BaseTransformer):
639
648
 
640
649
  return rv
641
650
 
642
- def _align_expected_output_names(
643
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
644
- ) -> List[str]:
651
+ def _align_expected_output(
652
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
653
+ ) -> Tuple[List[str], pd.DataFrame]:
654
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
655
+ and output dataframe with 1 line.
656
+ If the method is fit_predict, run 2 lines of data.
657
+ """
645
658
  # in case the inferred output column names dimension is different
646
659
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
647
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
660
+
661
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
662
+ # so change the minimum of number of rows to 2
663
+ num_examples = 2
664
+ statement_params = telemetry.get_function_usage_statement_params(
665
+ project=_PROJECT,
666
+ subproject=_SUBPROJECT,
667
+ function_name=telemetry.get_statement_params_full_func_name(
668
+ inspect.currentframe(), GaussianProcessRegressor.__class__.__name__
669
+ ),
670
+ api_calls=[Session.call],
671
+ custom_tags={"autogen": True} if self._autogenerated else None,
672
+ )
673
+ if output_cols_prefix == "fit_predict_":
674
+ if hasattr(self._sklearn_object, "n_clusters"):
675
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
676
+ num_examples = self._sklearn_object.n_clusters
677
+ elif hasattr(self._sklearn_object, "min_samples"):
678
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
679
+ num_examples = self._sklearn_object.min_samples
680
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
681
+ # LocalOutlierFactor expects n_neighbors <= n_samples
682
+ num_examples = self._sklearn_object.n_neighbors
683
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
684
+ else:
685
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
648
686
 
649
687
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
650
688
  # seen during the fit.
@@ -656,12 +694,14 @@ class GaussianProcessRegressor(BaseTransformer):
656
694
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
657
695
  if self.sample_weight_col:
658
696
  output_df_columns_set -= set(self.sample_weight_col)
697
+
659
698
  # if the dimension of inferred output column names is correct; use it
660
699
  if len(expected_output_cols_list) == len(output_df_columns_set):
661
- return expected_output_cols_list
700
+ return expected_output_cols_list, output_df_pd
662
701
  # otherwise, use the sklearn estimator's output
663
702
  else:
664
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
703
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
704
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
665
705
 
666
706
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
667
707
  @telemetry.send_api_usage_telemetry(
@@ -707,7 +747,7 @@ class GaussianProcessRegressor(BaseTransformer):
707
747
  drop_input_cols=self._drop_input_cols,
708
748
  expected_output_cols_type="float",
709
749
  )
710
- expected_output_cols = self._align_expected_output_names(
750
+ expected_output_cols, _ = self._align_expected_output(
711
751
  inference_method, dataset, expected_output_cols, output_cols_prefix
712
752
  )
713
753
 
@@ -773,7 +813,7 @@ class GaussianProcessRegressor(BaseTransformer):
773
813
  drop_input_cols=self._drop_input_cols,
774
814
  expected_output_cols_type="float",
775
815
  )
776
- expected_output_cols = self._align_expected_output_names(
816
+ expected_output_cols, _ = self._align_expected_output(
777
817
  inference_method, dataset, expected_output_cols, output_cols_prefix
778
818
  )
779
819
  elif isinstance(dataset, pd.DataFrame):
@@ -836,7 +876,7 @@ class GaussianProcessRegressor(BaseTransformer):
836
876
  drop_input_cols=self._drop_input_cols,
837
877
  expected_output_cols_type="float",
838
878
  )
839
- expected_output_cols = self._align_expected_output_names(
879
+ expected_output_cols, _ = self._align_expected_output(
840
880
  inference_method, dataset, expected_output_cols, output_cols_prefix
841
881
  )
842
882
 
@@ -901,7 +941,7 @@ class GaussianProcessRegressor(BaseTransformer):
901
941
  drop_input_cols = self._drop_input_cols,
902
942
  expected_output_cols_type="float",
903
943
  )
904
- expected_output_cols = self._align_expected_output_names(
944
+ expected_output_cols, _ = self._align_expected_output(
905
945
  inference_method, dataset, expected_output_cols, output_cols_prefix
906
946
  )
907
947
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -25,12 +23,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
25
23
  from snowflake.ml._internal import telemetry
26
24
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
27
25
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
28
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
26
+ from snowflake.ml._internal.utils import identifier
29
27
  from snowflake.snowpark import DataFrame, Session
30
28
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
31
29
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
32
30
  from snowflake.ml.modeling._internal.transformer_protocols import (
33
- ModelTransformHandlers,
34
31
  BatchInferenceKwargsTypedDict,
35
32
  ScoreKwargsTypedDict
36
33
  )
@@ -599,12 +596,23 @@ class IterativeImputer(BaseTransformer):
599
596
  autogenerated=self._autogenerated,
600
597
  subproject=_SUBPROJECT,
601
598
  )
602
- output_result, fitted_estimator = model_trainer.train_fit_predict(
603
- drop_input_cols=self._drop_input_cols,
604
- expected_output_cols_list=(
605
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
606
- ),
599
+ expected_output_cols = (
600
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
607
601
  )
602
+ if isinstance(dataset, DataFrame):
603
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
604
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
605
+ )
606
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
607
+ drop_input_cols=self._drop_input_cols,
608
+ expected_output_cols_list=expected_output_cols,
609
+ example_output_pd_df=example_output_pd_df,
610
+ )
611
+ else:
612
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
613
+ drop_input_cols=self._drop_input_cols,
614
+ expected_output_cols_list=expected_output_cols,
615
+ )
608
616
  self._sklearn_object = fitted_estimator
609
617
  self._is_fitted = True
610
618
  return output_result
@@ -629,6 +637,7 @@ class IterativeImputer(BaseTransformer):
629
637
  """
630
638
  self._infer_input_output_cols(dataset)
631
639
  super()._check_dataset_type(dataset)
640
+
632
641
  model_trainer = ModelTrainerBuilder.build_fit_transform(
633
642
  estimator=self._sklearn_object,
634
643
  dataset=dataset,
@@ -685,12 +694,41 @@ class IterativeImputer(BaseTransformer):
685
694
 
686
695
  return rv
687
696
 
688
- def _align_expected_output_names(
689
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
690
- ) -> List[str]:
697
+ def _align_expected_output(
698
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
699
+ ) -> Tuple[List[str], pd.DataFrame]:
700
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
701
+ and output dataframe with 1 line.
702
+ If the method is fit_predict, run 2 lines of data.
703
+ """
691
704
  # in case the inferred output column names dimension is different
692
705
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
693
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
706
+
707
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
708
+ # so change the minimum of number of rows to 2
709
+ num_examples = 2
710
+ statement_params = telemetry.get_function_usage_statement_params(
711
+ project=_PROJECT,
712
+ subproject=_SUBPROJECT,
713
+ function_name=telemetry.get_statement_params_full_func_name(
714
+ inspect.currentframe(), IterativeImputer.__class__.__name__
715
+ ),
716
+ api_calls=[Session.call],
717
+ custom_tags={"autogen": True} if self._autogenerated else None,
718
+ )
719
+ if output_cols_prefix == "fit_predict_":
720
+ if hasattr(self._sklearn_object, "n_clusters"):
721
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
722
+ num_examples = self._sklearn_object.n_clusters
723
+ elif hasattr(self._sklearn_object, "min_samples"):
724
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
725
+ num_examples = self._sklearn_object.min_samples
726
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
727
+ # LocalOutlierFactor expects n_neighbors <= n_samples
728
+ num_examples = self._sklearn_object.n_neighbors
729
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
730
+ else:
731
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
694
732
 
695
733
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
696
734
  # seen during the fit.
@@ -702,12 +740,14 @@ class IterativeImputer(BaseTransformer):
702
740
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
703
741
  if self.sample_weight_col:
704
742
  output_df_columns_set -= set(self.sample_weight_col)
743
+
705
744
  # if the dimension of inferred output column names is correct; use it
706
745
  if len(expected_output_cols_list) == len(output_df_columns_set):
707
- return expected_output_cols_list
746
+ return expected_output_cols_list, output_df_pd
708
747
  # otherwise, use the sklearn estimator's output
709
748
  else:
710
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
749
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
750
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
711
751
 
712
752
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
713
753
  @telemetry.send_api_usage_telemetry(
@@ -753,7 +793,7 @@ class IterativeImputer(BaseTransformer):
753
793
  drop_input_cols=self._drop_input_cols,
754
794
  expected_output_cols_type="float",
755
795
  )
756
- expected_output_cols = self._align_expected_output_names(
796
+ expected_output_cols, _ = self._align_expected_output(
757
797
  inference_method, dataset, expected_output_cols, output_cols_prefix
758
798
  )
759
799
 
@@ -819,7 +859,7 @@ class IterativeImputer(BaseTransformer):
819
859
  drop_input_cols=self._drop_input_cols,
820
860
  expected_output_cols_type="float",
821
861
  )
822
- expected_output_cols = self._align_expected_output_names(
862
+ expected_output_cols, _ = self._align_expected_output(
823
863
  inference_method, dataset, expected_output_cols, output_cols_prefix
824
864
  )
825
865
  elif isinstance(dataset, pd.DataFrame):
@@ -882,7 +922,7 @@ class IterativeImputer(BaseTransformer):
882
922
  drop_input_cols=self._drop_input_cols,
883
923
  expected_output_cols_type="float",
884
924
  )
885
- expected_output_cols = self._align_expected_output_names(
925
+ expected_output_cols, _ = self._align_expected_output(
886
926
  inference_method, dataset, expected_output_cols, output_cols_prefix
887
927
  )
888
928
 
@@ -947,7 +987,7 @@ class IterativeImputer(BaseTransformer):
947
987
  drop_input_cols = self._drop_input_cols,
948
988
  expected_output_cols_type="float",
949
989
  )
950
- expected_output_cols = self._align_expected_output_names(
990
+ expected_output_cols, _ = self._align_expected_output(
951
991
  inference_method, dataset, expected_output_cols, output_cols_prefix
952
992
  )
953
993
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -525,12 +522,23 @@ class KNNImputer(BaseTransformer):
525
522
  autogenerated=self._autogenerated,
526
523
  subproject=_SUBPROJECT,
527
524
  )
528
- output_result, fitted_estimator = model_trainer.train_fit_predict(
529
- drop_input_cols=self._drop_input_cols,
530
- expected_output_cols_list=(
531
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
532
- ),
525
+ expected_output_cols = (
526
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
533
527
  )
528
+ if isinstance(dataset, DataFrame):
529
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
530
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
531
+ )
532
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
533
+ drop_input_cols=self._drop_input_cols,
534
+ expected_output_cols_list=expected_output_cols,
535
+ example_output_pd_df=example_output_pd_df,
536
+ )
537
+ else:
538
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
539
+ drop_input_cols=self._drop_input_cols,
540
+ expected_output_cols_list=expected_output_cols,
541
+ )
534
542
  self._sklearn_object = fitted_estimator
535
543
  self._is_fitted = True
536
544
  return output_result
@@ -555,6 +563,7 @@ class KNNImputer(BaseTransformer):
555
563
  """
556
564
  self._infer_input_output_cols(dataset)
557
565
  super()._check_dataset_type(dataset)
566
+
558
567
  model_trainer = ModelTrainerBuilder.build_fit_transform(
559
568
  estimator=self._sklearn_object,
560
569
  dataset=dataset,
@@ -611,12 +620,41 @@ class KNNImputer(BaseTransformer):
611
620
 
612
621
  return rv
613
622
 
614
- def _align_expected_output_names(
615
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
616
- ) -> List[str]:
623
+ def _align_expected_output(
624
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
625
+ ) -> Tuple[List[str], pd.DataFrame]:
626
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
627
+ and output dataframe with 1 line.
628
+ If the method is fit_predict, run 2 lines of data.
629
+ """
617
630
  # in case the inferred output column names dimension is different
618
631
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
619
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
632
+
633
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
634
+ # so change the minimum of number of rows to 2
635
+ num_examples = 2
636
+ statement_params = telemetry.get_function_usage_statement_params(
637
+ project=_PROJECT,
638
+ subproject=_SUBPROJECT,
639
+ function_name=telemetry.get_statement_params_full_func_name(
640
+ inspect.currentframe(), KNNImputer.__class__.__name__
641
+ ),
642
+ api_calls=[Session.call],
643
+ custom_tags={"autogen": True} if self._autogenerated else None,
644
+ )
645
+ if output_cols_prefix == "fit_predict_":
646
+ if hasattr(self._sklearn_object, "n_clusters"):
647
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
648
+ num_examples = self._sklearn_object.n_clusters
649
+ elif hasattr(self._sklearn_object, "min_samples"):
650
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
651
+ num_examples = self._sklearn_object.min_samples
652
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
653
+ # LocalOutlierFactor expects n_neighbors <= n_samples
654
+ num_examples = self._sklearn_object.n_neighbors
655
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
656
+ else:
657
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
620
658
 
621
659
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
622
660
  # seen during the fit.
@@ -628,12 +666,14 @@ class KNNImputer(BaseTransformer):
628
666
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
629
667
  if self.sample_weight_col:
630
668
  output_df_columns_set -= set(self.sample_weight_col)
669
+
631
670
  # if the dimension of inferred output column names is correct; use it
632
671
  if len(expected_output_cols_list) == len(output_df_columns_set):
633
- return expected_output_cols_list
672
+ return expected_output_cols_list, output_df_pd
634
673
  # otherwise, use the sklearn estimator's output
635
674
  else:
636
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
675
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
676
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
637
677
 
638
678
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
639
679
  @telemetry.send_api_usage_telemetry(
@@ -679,7 +719,7 @@ class KNNImputer(BaseTransformer):
679
719
  drop_input_cols=self._drop_input_cols,
680
720
  expected_output_cols_type="float",
681
721
  )
682
- expected_output_cols = self._align_expected_output_names(
722
+ expected_output_cols, _ = self._align_expected_output(
683
723
  inference_method, dataset, expected_output_cols, output_cols_prefix
684
724
  )
685
725
 
@@ -745,7 +785,7 @@ class KNNImputer(BaseTransformer):
745
785
  drop_input_cols=self._drop_input_cols,
746
786
  expected_output_cols_type="float",
747
787
  )
748
- expected_output_cols = self._align_expected_output_names(
788
+ expected_output_cols, _ = self._align_expected_output(
749
789
  inference_method, dataset, expected_output_cols, output_cols_prefix
750
790
  )
751
791
  elif isinstance(dataset, pd.DataFrame):
@@ -808,7 +848,7 @@ class KNNImputer(BaseTransformer):
808
848
  drop_input_cols=self._drop_input_cols,
809
849
  expected_output_cols_type="float",
810
850
  )
811
- expected_output_cols = self._align_expected_output_names(
851
+ expected_output_cols, _ = self._align_expected_output(
812
852
  inference_method, dataset, expected_output_cols, output_cols_prefix
813
853
  )
814
854
 
@@ -873,7 +913,7 @@ class KNNImputer(BaseTransformer):
873
913
  drop_input_cols = self._drop_input_cols,
874
914
  expected_output_cols_type="float",
875
915
  )
876
- expected_output_cols = self._align_expected_output_names(
916
+ expected_output_cols, _ = self._align_expected_output(
877
917
  inference_method, dataset, expected_output_cols, output_cols_prefix
878
918
  )
879
919