snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284):
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -495,12 +492,23 @@ class ComplementNB(BaseTransformer):
495
492
  autogenerated=self._autogenerated,
496
493
  subproject=_SUBPROJECT,
497
494
  )
498
- output_result, fitted_estimator = model_trainer.train_fit_predict(
499
- drop_input_cols=self._drop_input_cols,
500
- expected_output_cols_list=(
501
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
502
- ),
495
+ expected_output_cols = (
496
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
503
497
  )
498
+ if isinstance(dataset, DataFrame):
499
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
500
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
501
+ )
502
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
503
+ drop_input_cols=self._drop_input_cols,
504
+ expected_output_cols_list=expected_output_cols,
505
+ example_output_pd_df=example_output_pd_df,
506
+ )
507
+ else:
508
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
509
+ drop_input_cols=self._drop_input_cols,
510
+ expected_output_cols_list=expected_output_cols,
511
+ )
504
512
  self._sklearn_object = fitted_estimator
505
513
  self._is_fitted = True
506
514
  return output_result
@@ -523,6 +531,7 @@ class ComplementNB(BaseTransformer):
523
531
  """
524
532
  self._infer_input_output_cols(dataset)
525
533
  super()._check_dataset_type(dataset)
534
+
526
535
  model_trainer = ModelTrainerBuilder.build_fit_transform(
527
536
  estimator=self._sklearn_object,
528
537
  dataset=dataset,
@@ -579,12 +588,41 @@ class ComplementNB(BaseTransformer):
579
588
 
580
589
  return rv
581
590
 
582
- def _align_expected_output_names(
583
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
584
- ) -> List[str]:
591
+ def _align_expected_output(
592
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
593
+ ) -> Tuple[List[str], pd.DataFrame]:
594
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
595
+ and output dataframe with 1 line.
596
+ If the method is fit_predict, run 2 lines of data.
597
+ """
585
598
  # in case the inferred output column names dimension is different
586
599
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
587
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
600
+
601
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
602
+ # so change the minimum of number of rows to 2
603
+ num_examples = 2
604
+ statement_params = telemetry.get_function_usage_statement_params(
605
+ project=_PROJECT,
606
+ subproject=_SUBPROJECT,
607
+ function_name=telemetry.get_statement_params_full_func_name(
608
+ inspect.currentframe(), ComplementNB.__class__.__name__
609
+ ),
610
+ api_calls=[Session.call],
611
+ custom_tags={"autogen": True} if self._autogenerated else None,
612
+ )
613
+ if output_cols_prefix == "fit_predict_":
614
+ if hasattr(self._sklearn_object, "n_clusters"):
615
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
616
+ num_examples = self._sklearn_object.n_clusters
617
+ elif hasattr(self._sklearn_object, "min_samples"):
618
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
619
+ num_examples = self._sklearn_object.min_samples
620
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
621
+ # LocalOutlierFactor expects n_neighbors <= n_samples
622
+ num_examples = self._sklearn_object.n_neighbors
623
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
624
+ else:
625
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
588
626
 
589
627
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
590
628
  # seen during the fit.
@@ -596,12 +634,14 @@ class ComplementNB(BaseTransformer):
596
634
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
597
635
  if self.sample_weight_col:
598
636
  output_df_columns_set -= set(self.sample_weight_col)
637
+
599
638
  # if the dimension of inferred output column names is correct; use it
600
639
  if len(expected_output_cols_list) == len(output_df_columns_set):
601
- return expected_output_cols_list
640
+ return expected_output_cols_list, output_df_pd
602
641
  # otherwise, use the sklearn estimator's output
603
642
  else:
604
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
643
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
644
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
605
645
 
606
646
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
607
647
  @telemetry.send_api_usage_telemetry(
@@ -649,7 +689,7 @@ class ComplementNB(BaseTransformer):
649
689
  drop_input_cols=self._drop_input_cols,
650
690
  expected_output_cols_type="float",
651
691
  )
652
- expected_output_cols = self._align_expected_output_names(
692
+ expected_output_cols, _ = self._align_expected_output(
653
693
  inference_method, dataset, expected_output_cols, output_cols_prefix
654
694
  )
655
695
 
@@ -717,7 +757,7 @@ class ComplementNB(BaseTransformer):
717
757
  drop_input_cols=self._drop_input_cols,
718
758
  expected_output_cols_type="float",
719
759
  )
720
- expected_output_cols = self._align_expected_output_names(
760
+ expected_output_cols, _ = self._align_expected_output(
721
761
  inference_method, dataset, expected_output_cols, output_cols_prefix
722
762
  )
723
763
  elif isinstance(dataset, pd.DataFrame):
@@ -780,7 +820,7 @@ class ComplementNB(BaseTransformer):
780
820
  drop_input_cols=self._drop_input_cols,
781
821
  expected_output_cols_type="float",
782
822
  )
783
- expected_output_cols = self._align_expected_output_names(
823
+ expected_output_cols, _ = self._align_expected_output(
784
824
  inference_method, dataset, expected_output_cols, output_cols_prefix
785
825
  )
786
826
 
@@ -845,7 +885,7 @@ class ComplementNB(BaseTransformer):
845
885
  drop_input_cols = self._drop_input_cols,
846
886
  expected_output_cols_type="float",
847
887
  )
848
- expected_output_cols = self._align_expected_output_names(
888
+ expected_output_cols, _ = self._align_expected_output(
849
889
  inference_method, dataset, expected_output_cols, output_cols_prefix
850
890
  )
851
891
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -476,12 +473,23 @@ class GaussianNB(BaseTransformer):
476
473
  autogenerated=self._autogenerated,
477
474
  subproject=_SUBPROJECT,
478
475
  )
479
- output_result, fitted_estimator = model_trainer.train_fit_predict(
480
- drop_input_cols=self._drop_input_cols,
481
- expected_output_cols_list=(
482
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
483
- ),
476
+ expected_output_cols = (
477
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
484
478
  )
479
+ if isinstance(dataset, DataFrame):
480
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
481
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
482
+ )
483
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
484
+ drop_input_cols=self._drop_input_cols,
485
+ expected_output_cols_list=expected_output_cols,
486
+ example_output_pd_df=example_output_pd_df,
487
+ )
488
+ else:
489
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
490
+ drop_input_cols=self._drop_input_cols,
491
+ expected_output_cols_list=expected_output_cols,
492
+ )
485
493
  self._sklearn_object = fitted_estimator
486
494
  self._is_fitted = True
487
495
  return output_result
@@ -504,6 +512,7 @@ class GaussianNB(BaseTransformer):
504
512
  """
505
513
  self._infer_input_output_cols(dataset)
506
514
  super()._check_dataset_type(dataset)
515
+
507
516
  model_trainer = ModelTrainerBuilder.build_fit_transform(
508
517
  estimator=self._sklearn_object,
509
518
  dataset=dataset,
@@ -560,12 +569,41 @@ class GaussianNB(BaseTransformer):
560
569
 
561
570
  return rv
562
571
 
563
- def _align_expected_output_names(
564
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
565
- ) -> List[str]:
572
+ def _align_expected_output(
573
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
574
+ ) -> Tuple[List[str], pd.DataFrame]:
575
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
576
+ and output dataframe with 1 line.
577
+ If the method is fit_predict, run 2 lines of data.
578
+ """
566
579
  # in case the inferred output column names dimension is different
567
580
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
568
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
581
+
582
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
583
+ # so change the minimum of number of rows to 2
584
+ num_examples = 2
585
+ statement_params = telemetry.get_function_usage_statement_params(
586
+ project=_PROJECT,
587
+ subproject=_SUBPROJECT,
588
+ function_name=telemetry.get_statement_params_full_func_name(
589
+ inspect.currentframe(), GaussianNB.__class__.__name__
590
+ ),
591
+ api_calls=[Session.call],
592
+ custom_tags={"autogen": True} if self._autogenerated else None,
593
+ )
594
+ if output_cols_prefix == "fit_predict_":
595
+ if hasattr(self._sklearn_object, "n_clusters"):
596
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
597
+ num_examples = self._sklearn_object.n_clusters
598
+ elif hasattr(self._sklearn_object, "min_samples"):
599
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
600
+ num_examples = self._sklearn_object.min_samples
601
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
602
+ # LocalOutlierFactor expects n_neighbors <= n_samples
603
+ num_examples = self._sklearn_object.n_neighbors
604
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
605
+ else:
606
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
569
607
 
570
608
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
571
609
  # seen during the fit.
@@ -577,12 +615,14 @@ class GaussianNB(BaseTransformer):
577
615
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
578
616
  if self.sample_weight_col:
579
617
  output_df_columns_set -= set(self.sample_weight_col)
618
+
580
619
  # if the dimension of inferred output column names is correct; use it
581
620
  if len(expected_output_cols_list) == len(output_df_columns_set):
582
- return expected_output_cols_list
621
+ return expected_output_cols_list, output_df_pd
583
622
  # otherwise, use the sklearn estimator's output
584
623
  else:
585
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
624
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
625
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
586
626
 
587
627
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
588
628
  @telemetry.send_api_usage_telemetry(
@@ -630,7 +670,7 @@ class GaussianNB(BaseTransformer):
630
670
  drop_input_cols=self._drop_input_cols,
631
671
  expected_output_cols_type="float",
632
672
  )
633
- expected_output_cols = self._align_expected_output_names(
673
+ expected_output_cols, _ = self._align_expected_output(
634
674
  inference_method, dataset, expected_output_cols, output_cols_prefix
635
675
  )
636
676
 
@@ -698,7 +738,7 @@ class GaussianNB(BaseTransformer):
698
738
  drop_input_cols=self._drop_input_cols,
699
739
  expected_output_cols_type="float",
700
740
  )
701
- expected_output_cols = self._align_expected_output_names(
741
+ expected_output_cols, _ = self._align_expected_output(
702
742
  inference_method, dataset, expected_output_cols, output_cols_prefix
703
743
  )
704
744
  elif isinstance(dataset, pd.DataFrame):
@@ -761,7 +801,7 @@ class GaussianNB(BaseTransformer):
761
801
  drop_input_cols=self._drop_input_cols,
762
802
  expected_output_cols_type="float",
763
803
  )
764
- expected_output_cols = self._align_expected_output_names(
804
+ expected_output_cols, _ = self._align_expected_output(
765
805
  inference_method, dataset, expected_output_cols, output_cols_prefix
766
806
  )
767
807
 
@@ -826,7 +866,7 @@ class GaussianNB(BaseTransformer):
826
866
  drop_input_cols = self._drop_input_cols,
827
867
  expected_output_cols_type="float",
828
868
  )
829
- expected_output_cols = self._align_expected_output_names(
869
+ expected_output_cols, _ = self._align_expected_output(
830
870
  inference_method, dataset, expected_output_cols, output_cols_prefix
831
871
  )
832
872
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -489,12 +486,23 @@ class MultinomialNB(BaseTransformer):
489
486
  autogenerated=self._autogenerated,
490
487
  subproject=_SUBPROJECT,
491
488
  )
492
- output_result, fitted_estimator = model_trainer.train_fit_predict(
493
- drop_input_cols=self._drop_input_cols,
494
- expected_output_cols_list=(
495
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
496
- ),
489
+ expected_output_cols = (
490
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
497
491
  )
492
+ if isinstance(dataset, DataFrame):
493
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
494
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
495
+ )
496
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
497
+ drop_input_cols=self._drop_input_cols,
498
+ expected_output_cols_list=expected_output_cols,
499
+ example_output_pd_df=example_output_pd_df,
500
+ )
501
+ else:
502
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
503
+ drop_input_cols=self._drop_input_cols,
504
+ expected_output_cols_list=expected_output_cols,
505
+ )
498
506
  self._sklearn_object = fitted_estimator
499
507
  self._is_fitted = True
500
508
  return output_result
@@ -517,6 +525,7 @@ class MultinomialNB(BaseTransformer):
517
525
  """
518
526
  self._infer_input_output_cols(dataset)
519
527
  super()._check_dataset_type(dataset)
528
+
520
529
  model_trainer = ModelTrainerBuilder.build_fit_transform(
521
530
  estimator=self._sklearn_object,
522
531
  dataset=dataset,
@@ -573,12 +582,41 @@ class MultinomialNB(BaseTransformer):
573
582
 
574
583
  return rv
575
584
 
576
- def _align_expected_output_names(
577
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
578
- ) -> List[str]:
585
+ def _align_expected_output(
586
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
587
+ ) -> Tuple[List[str], pd.DataFrame]:
588
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
589
+ and output dataframe with 1 line.
590
+ If the method is fit_predict, run 2 lines of data.
591
+ """
579
592
  # in case the inferred output column names dimension is different
580
593
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
581
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
594
+
595
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
596
+ # so change the minimum of number of rows to 2
597
+ num_examples = 2
598
+ statement_params = telemetry.get_function_usage_statement_params(
599
+ project=_PROJECT,
600
+ subproject=_SUBPROJECT,
601
+ function_name=telemetry.get_statement_params_full_func_name(
602
+ inspect.currentframe(), MultinomialNB.__class__.__name__
603
+ ),
604
+ api_calls=[Session.call],
605
+ custom_tags={"autogen": True} if self._autogenerated else None,
606
+ )
607
+ if output_cols_prefix == "fit_predict_":
608
+ if hasattr(self._sklearn_object, "n_clusters"):
609
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
610
+ num_examples = self._sklearn_object.n_clusters
611
+ elif hasattr(self._sklearn_object, "min_samples"):
612
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
613
+ num_examples = self._sklearn_object.min_samples
614
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
615
+ # LocalOutlierFactor expects n_neighbors <= n_samples
616
+ num_examples = self._sklearn_object.n_neighbors
617
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
618
+ else:
619
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
582
620
 
583
621
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
584
622
  # seen during the fit.
@@ -590,12 +628,14 @@ class MultinomialNB(BaseTransformer):
590
628
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
591
629
  if self.sample_weight_col:
592
630
  output_df_columns_set -= set(self.sample_weight_col)
631
+
593
632
  # if the dimension of inferred output column names is correct; use it
594
633
  if len(expected_output_cols_list) == len(output_df_columns_set):
595
- return expected_output_cols_list
634
+ return expected_output_cols_list, output_df_pd
596
635
  # otherwise, use the sklearn estimator's output
597
636
  else:
598
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
637
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
638
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
599
639
 
600
640
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
601
641
  @telemetry.send_api_usage_telemetry(
@@ -643,7 +683,7 @@ class MultinomialNB(BaseTransformer):
643
683
  drop_input_cols=self._drop_input_cols,
644
684
  expected_output_cols_type="float",
645
685
  )
646
- expected_output_cols = self._align_expected_output_names(
686
+ expected_output_cols, _ = self._align_expected_output(
647
687
  inference_method, dataset, expected_output_cols, output_cols_prefix
648
688
  )
649
689
 
@@ -711,7 +751,7 @@ class MultinomialNB(BaseTransformer):
711
751
  drop_input_cols=self._drop_input_cols,
712
752
  expected_output_cols_type="float",
713
753
  )
714
- expected_output_cols = self._align_expected_output_names(
754
+ expected_output_cols, _ = self._align_expected_output(
715
755
  inference_method, dataset, expected_output_cols, output_cols_prefix
716
756
  )
717
757
  elif isinstance(dataset, pd.DataFrame):
@@ -774,7 +814,7 @@ class MultinomialNB(BaseTransformer):
774
814
  drop_input_cols=self._drop_input_cols,
775
815
  expected_output_cols_type="float",
776
816
  )
777
- expected_output_cols = self._align_expected_output_names(
817
+ expected_output_cols, _ = self._align_expected_output(
778
818
  inference_method, dataset, expected_output_cols, output_cols_prefix
779
819
  )
780
820
 
@@ -839,7 +879,7 @@ class MultinomialNB(BaseTransformer):
839
879
  drop_input_cols = self._drop_input_cols,
840
880
  expected_output_cols_type="float",
841
881
  )
842
- expected_output_cols = self._align_expected_output_names(
882
+ expected_output_cols, _ = self._align_expected_output(
843
883
  inference_method, dataset, expected_output_cols, output_cols_prefix
844
884
  )
845
885
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -546,12 +543,23 @@ class KNeighborsClassifier(BaseTransformer):
546
543
  autogenerated=self._autogenerated,
547
544
  subproject=_SUBPROJECT,
548
545
  )
549
- output_result, fitted_estimator = model_trainer.train_fit_predict(
550
- drop_input_cols=self._drop_input_cols,
551
- expected_output_cols_list=(
552
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
553
- ),
546
+ expected_output_cols = (
547
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
554
548
  )
549
+ if isinstance(dataset, DataFrame):
550
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
551
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
552
+ )
553
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
554
+ drop_input_cols=self._drop_input_cols,
555
+ expected_output_cols_list=expected_output_cols,
556
+ example_output_pd_df=example_output_pd_df,
557
+ )
558
+ else:
559
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
560
+ drop_input_cols=self._drop_input_cols,
561
+ expected_output_cols_list=expected_output_cols,
562
+ )
555
563
  self._sklearn_object = fitted_estimator
556
564
  self._is_fitted = True
557
565
  return output_result
@@ -574,6 +582,7 @@ class KNeighborsClassifier(BaseTransformer):
574
582
  """
575
583
  self._infer_input_output_cols(dataset)
576
584
  super()._check_dataset_type(dataset)
585
+
577
586
  model_trainer = ModelTrainerBuilder.build_fit_transform(
578
587
  estimator=self._sklearn_object,
579
588
  dataset=dataset,
@@ -630,12 +639,41 @@ class KNeighborsClassifier(BaseTransformer):
630
639
 
631
640
  return rv
632
641
 
633
- def _align_expected_output_names(
634
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
635
- ) -> List[str]:
642
+ def _align_expected_output(
643
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
644
+ ) -> Tuple[List[str], pd.DataFrame]:
645
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
646
+ and output dataframe with 1 line.
647
+ If the method is fit_predict, run 2 lines of data.
648
+ """
636
649
  # in case the inferred output column names dimension is different
637
650
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
638
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
651
+
652
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
653
+ # so change the minimum of number of rows to 2
654
+ num_examples = 2
655
+ statement_params = telemetry.get_function_usage_statement_params(
656
+ project=_PROJECT,
657
+ subproject=_SUBPROJECT,
658
+ function_name=telemetry.get_statement_params_full_func_name(
659
+ inspect.currentframe(), KNeighborsClassifier.__class__.__name__
660
+ ),
661
+ api_calls=[Session.call],
662
+ custom_tags={"autogen": True} if self._autogenerated else None,
663
+ )
664
+ if output_cols_prefix == "fit_predict_":
665
+ if hasattr(self._sklearn_object, "n_clusters"):
666
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
667
+ num_examples = self._sklearn_object.n_clusters
668
+ elif hasattr(self._sklearn_object, "min_samples"):
669
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
670
+ num_examples = self._sklearn_object.min_samples
671
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
672
+ # LocalOutlierFactor expects n_neighbors <= n_samples
673
+ num_examples = self._sklearn_object.n_neighbors
674
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
675
+ else:
676
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
639
677
 
640
678
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
641
679
  # seen during the fit.
@@ -647,12 +685,14 @@ class KNeighborsClassifier(BaseTransformer):
647
685
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
648
686
  if self.sample_weight_col:
649
687
  output_df_columns_set -= set(self.sample_weight_col)
688
+
650
689
  # if the dimension of inferred output column names is correct; use it
651
690
  if len(expected_output_cols_list) == len(output_df_columns_set):
652
- return expected_output_cols_list
691
+ return expected_output_cols_list, output_df_pd
653
692
  # otherwise, use the sklearn estimator's output
654
693
  else:
655
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
694
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
695
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
656
696
 
657
697
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
658
698
  @telemetry.send_api_usage_telemetry(
@@ -700,7 +740,7 @@ class KNeighborsClassifier(BaseTransformer):
700
740
  drop_input_cols=self._drop_input_cols,
701
741
  expected_output_cols_type="float",
702
742
  )
703
- expected_output_cols = self._align_expected_output_names(
743
+ expected_output_cols, _ = self._align_expected_output(
704
744
  inference_method, dataset, expected_output_cols, output_cols_prefix
705
745
  )
706
746
 
@@ -768,7 +808,7 @@ class KNeighborsClassifier(BaseTransformer):
768
808
  drop_input_cols=self._drop_input_cols,
769
809
  expected_output_cols_type="float",
770
810
  )
771
- expected_output_cols = self._align_expected_output_names(
811
+ expected_output_cols, _ = self._align_expected_output(
772
812
  inference_method, dataset, expected_output_cols, output_cols_prefix
773
813
  )
774
814
  elif isinstance(dataset, pd.DataFrame):
@@ -831,7 +871,7 @@ class KNeighborsClassifier(BaseTransformer):
831
871
  drop_input_cols=self._drop_input_cols,
832
872
  expected_output_cols_type="float",
833
873
  )
834
- expected_output_cols = self._align_expected_output_names(
874
+ expected_output_cols, _ = self._align_expected_output(
835
875
  inference_method, dataset, expected_output_cols, output_cols_prefix
836
876
  )
837
877
 
@@ -896,7 +936,7 @@ class KNeighborsClassifier(BaseTransformer):
896
936
  drop_input_cols = self._drop_input_cols,
897
937
  expected_output_cols_type="float",
898
938
  )
899
- expected_output_cols = self._align_expected_output_names(
939
+ expected_output_cols, _ = self._align_expected_output(
900
940
  inference_method, dataset, expected_output_cols, output_cols_prefix
901
941
  )
902
942