snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284):
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -548,12 +545,23 @@ class KNeighborsRegressor(BaseTransformer):
548
545
  autogenerated=self._autogenerated,
549
546
  subproject=_SUBPROJECT,
550
547
  )
551
- output_result, fitted_estimator = model_trainer.train_fit_predict(
552
- drop_input_cols=self._drop_input_cols,
553
- expected_output_cols_list=(
554
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
555
- ),
548
+ expected_output_cols = (
549
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
556
550
  )
551
+ if isinstance(dataset, DataFrame):
552
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
553
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
554
+ )
555
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
556
+ drop_input_cols=self._drop_input_cols,
557
+ expected_output_cols_list=expected_output_cols,
558
+ example_output_pd_df=example_output_pd_df,
559
+ )
560
+ else:
561
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
562
+ drop_input_cols=self._drop_input_cols,
563
+ expected_output_cols_list=expected_output_cols,
564
+ )
557
565
  self._sklearn_object = fitted_estimator
558
566
  self._is_fitted = True
559
567
  return output_result
@@ -576,6 +584,7 @@ class KNeighborsRegressor(BaseTransformer):
576
584
  """
577
585
  self._infer_input_output_cols(dataset)
578
586
  super()._check_dataset_type(dataset)
587
+
579
588
  model_trainer = ModelTrainerBuilder.build_fit_transform(
580
589
  estimator=self._sklearn_object,
581
590
  dataset=dataset,
@@ -632,12 +641,41 @@ class KNeighborsRegressor(BaseTransformer):
632
641
 
633
642
  return rv
634
643
 
635
- def _align_expected_output_names(
636
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
637
- ) -> List[str]:
644
+ def _align_expected_output(
645
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
646
+ ) -> Tuple[List[str], pd.DataFrame]:
647
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
648
+ and output dataframe with 1 line.
649
+ If the method is fit_predict, run 2 lines of data.
650
+ """
638
651
  # in case the inferred output column names dimension is different
639
652
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
640
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
653
+
654
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
655
+ # so change the minimum of number of rows to 2
656
+ num_examples = 2
657
+ statement_params = telemetry.get_function_usage_statement_params(
658
+ project=_PROJECT,
659
+ subproject=_SUBPROJECT,
660
+ function_name=telemetry.get_statement_params_full_func_name(
661
+ inspect.currentframe(), KNeighborsRegressor.__class__.__name__
662
+ ),
663
+ api_calls=[Session.call],
664
+ custom_tags={"autogen": True} if self._autogenerated else None,
665
+ )
666
+ if output_cols_prefix == "fit_predict_":
667
+ if hasattr(self._sklearn_object, "n_clusters"):
668
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
669
+ num_examples = self._sklearn_object.n_clusters
670
+ elif hasattr(self._sklearn_object, "min_samples"):
671
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
672
+ num_examples = self._sklearn_object.min_samples
673
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
674
+ # LocalOutlierFactor expects n_neighbors <= n_samples
675
+ num_examples = self._sklearn_object.n_neighbors
676
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
677
+ else:
678
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
641
679
 
642
680
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
643
681
  # seen during the fit.
@@ -649,12 +687,14 @@ class KNeighborsRegressor(BaseTransformer):
649
687
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
650
688
  if self.sample_weight_col:
651
689
  output_df_columns_set -= set(self.sample_weight_col)
690
+
652
691
  # if the dimension of inferred output column names is correct; use it
653
692
  if len(expected_output_cols_list) == len(output_df_columns_set):
654
- return expected_output_cols_list
693
+ return expected_output_cols_list, output_df_pd
655
694
  # otherwise, use the sklearn estimator's output
656
695
  else:
657
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
696
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
697
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
658
698
 
659
699
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
660
700
  @telemetry.send_api_usage_telemetry(
@@ -700,7 +740,7 @@ class KNeighborsRegressor(BaseTransformer):
700
740
  drop_input_cols=self._drop_input_cols,
701
741
  expected_output_cols_type="float",
702
742
  )
703
- expected_output_cols = self._align_expected_output_names(
743
+ expected_output_cols, _ = self._align_expected_output(
704
744
  inference_method, dataset, expected_output_cols, output_cols_prefix
705
745
  )
706
746
 
@@ -766,7 +806,7 @@ class KNeighborsRegressor(BaseTransformer):
766
806
  drop_input_cols=self._drop_input_cols,
767
807
  expected_output_cols_type="float",
768
808
  )
769
- expected_output_cols = self._align_expected_output_names(
809
+ expected_output_cols, _ = self._align_expected_output(
770
810
  inference_method, dataset, expected_output_cols, output_cols_prefix
771
811
  )
772
812
  elif isinstance(dataset, pd.DataFrame):
@@ -829,7 +869,7 @@ class KNeighborsRegressor(BaseTransformer):
829
869
  drop_input_cols=self._drop_input_cols,
830
870
  expected_output_cols_type="float",
831
871
  )
832
- expected_output_cols = self._align_expected_output_names(
872
+ expected_output_cols, _ = self._align_expected_output(
833
873
  inference_method, dataset, expected_output_cols, output_cols_prefix
834
874
  )
835
875
 
@@ -894,7 +934,7 @@ class KNeighborsRegressor(BaseTransformer):
894
934
  drop_input_cols = self._drop_input_cols,
895
935
  expected_output_cols_type="float",
896
936
  )
897
- expected_output_cols = self._align_expected_output_names(
937
+ expected_output_cols, _ = self._align_expected_output(
898
938
  inference_method, dataset, expected_output_cols, output_cols_prefix
899
939
  )
900
940
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -525,12 +522,23 @@ class KernelDensity(BaseTransformer):
525
522
  autogenerated=self._autogenerated,
526
523
  subproject=_SUBPROJECT,
527
524
  )
528
- output_result, fitted_estimator = model_trainer.train_fit_predict(
529
- drop_input_cols=self._drop_input_cols,
530
- expected_output_cols_list=(
531
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
532
- ),
525
+ expected_output_cols = (
526
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
533
527
  )
528
+ if isinstance(dataset, DataFrame):
529
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
530
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
531
+ )
532
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
533
+ drop_input_cols=self._drop_input_cols,
534
+ expected_output_cols_list=expected_output_cols,
535
+ example_output_pd_df=example_output_pd_df,
536
+ )
537
+ else:
538
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
539
+ drop_input_cols=self._drop_input_cols,
540
+ expected_output_cols_list=expected_output_cols,
541
+ )
534
542
  self._sklearn_object = fitted_estimator
535
543
  self._is_fitted = True
536
544
  return output_result
@@ -553,6 +561,7 @@ class KernelDensity(BaseTransformer):
553
561
  """
554
562
  self._infer_input_output_cols(dataset)
555
563
  super()._check_dataset_type(dataset)
564
+
556
565
  model_trainer = ModelTrainerBuilder.build_fit_transform(
557
566
  estimator=self._sklearn_object,
558
567
  dataset=dataset,
@@ -609,12 +618,41 @@ class KernelDensity(BaseTransformer):
609
618
 
610
619
  return rv
611
620
 
612
- def _align_expected_output_names(
613
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
614
- ) -> List[str]:
621
+ def _align_expected_output(
622
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
623
+ ) -> Tuple[List[str], pd.DataFrame]:
624
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
625
+ and output dataframe with 1 line.
626
+ If the method is fit_predict, run 2 lines of data.
627
+ """
615
628
  # in case the inferred output column names dimension is different
616
629
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
617
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
630
+
631
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
632
+ # so change the minimum of number of rows to 2
633
+ num_examples = 2
634
+ statement_params = telemetry.get_function_usage_statement_params(
635
+ project=_PROJECT,
636
+ subproject=_SUBPROJECT,
637
+ function_name=telemetry.get_statement_params_full_func_name(
638
+ inspect.currentframe(), KernelDensity.__class__.__name__
639
+ ),
640
+ api_calls=[Session.call],
641
+ custom_tags={"autogen": True} if self._autogenerated else None,
642
+ )
643
+ if output_cols_prefix == "fit_predict_":
644
+ if hasattr(self._sklearn_object, "n_clusters"):
645
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
646
+ num_examples = self._sklearn_object.n_clusters
647
+ elif hasattr(self._sklearn_object, "min_samples"):
648
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
649
+ num_examples = self._sklearn_object.min_samples
650
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
651
+ # LocalOutlierFactor expects n_neighbors <= n_samples
652
+ num_examples = self._sklearn_object.n_neighbors
653
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
654
+ else:
655
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
618
656
 
619
657
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
620
658
  # seen during the fit.
@@ -626,12 +664,14 @@ class KernelDensity(BaseTransformer):
626
664
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
627
665
  if self.sample_weight_col:
628
666
  output_df_columns_set -= set(self.sample_weight_col)
667
+
629
668
  # if the dimension of inferred output column names is correct; use it
630
669
  if len(expected_output_cols_list) == len(output_df_columns_set):
631
- return expected_output_cols_list
670
+ return expected_output_cols_list, output_df_pd
632
671
  # otherwise, use the sklearn estimator's output
633
672
  else:
634
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
673
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
674
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
635
675
 
636
676
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
637
677
  @telemetry.send_api_usage_telemetry(
@@ -677,7 +717,7 @@ class KernelDensity(BaseTransformer):
677
717
  drop_input_cols=self._drop_input_cols,
678
718
  expected_output_cols_type="float",
679
719
  )
680
- expected_output_cols = self._align_expected_output_names(
720
+ expected_output_cols, _ = self._align_expected_output(
681
721
  inference_method, dataset, expected_output_cols, output_cols_prefix
682
722
  )
683
723
 
@@ -743,7 +783,7 @@ class KernelDensity(BaseTransformer):
743
783
  drop_input_cols=self._drop_input_cols,
744
784
  expected_output_cols_type="float",
745
785
  )
746
- expected_output_cols = self._align_expected_output_names(
786
+ expected_output_cols, _ = self._align_expected_output(
747
787
  inference_method, dataset, expected_output_cols, output_cols_prefix
748
788
  )
749
789
  elif isinstance(dataset, pd.DataFrame):
@@ -806,7 +846,7 @@ class KernelDensity(BaseTransformer):
806
846
  drop_input_cols=self._drop_input_cols,
807
847
  expected_output_cols_type="float",
808
848
  )
809
- expected_output_cols = self._align_expected_output_names(
849
+ expected_output_cols, _ = self._align_expected_output(
810
850
  inference_method, dataset, expected_output_cols, output_cols_prefix
811
851
  )
812
852
 
@@ -873,7 +913,7 @@ class KernelDensity(BaseTransformer):
873
913
  drop_input_cols = self._drop_input_cols,
874
914
  expected_output_cols_type="float",
875
915
  )
876
- expected_output_cols = self._align_expected_output_names(
916
+ expected_output_cols, _ = self._align_expected_output(
877
917
  inference_method, dataset, expected_output_cols, output_cols_prefix
878
918
  )
879
919
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -557,12 +554,23 @@ class LocalOutlierFactor(BaseTransformer):
557
554
  autogenerated=self._autogenerated,
558
555
  subproject=_SUBPROJECT,
559
556
  )
560
- output_result, fitted_estimator = model_trainer.train_fit_predict(
561
- drop_input_cols=self._drop_input_cols,
562
- expected_output_cols_list=(
563
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
564
- ),
557
+ expected_output_cols = (
558
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
565
559
  )
560
+ if isinstance(dataset, DataFrame):
561
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
562
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
563
+ )
564
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
565
+ drop_input_cols=self._drop_input_cols,
566
+ expected_output_cols_list=expected_output_cols,
567
+ example_output_pd_df=example_output_pd_df,
568
+ )
569
+ else:
570
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
571
+ drop_input_cols=self._drop_input_cols,
572
+ expected_output_cols_list=expected_output_cols,
573
+ )
566
574
  self._sklearn_object = fitted_estimator
567
575
  self._is_fitted = True
568
576
  return output_result
@@ -585,6 +593,7 @@ class LocalOutlierFactor(BaseTransformer):
585
593
  """
586
594
  self._infer_input_output_cols(dataset)
587
595
  super()._check_dataset_type(dataset)
596
+
588
597
  model_trainer = ModelTrainerBuilder.build_fit_transform(
589
598
  estimator=self._sklearn_object,
590
599
  dataset=dataset,
@@ -641,12 +650,41 @@ class LocalOutlierFactor(BaseTransformer):
641
650
 
642
651
  return rv
643
652
 
644
- def _align_expected_output_names(
645
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
646
- ) -> List[str]:
653
+ def _align_expected_output(
654
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
655
+ ) -> Tuple[List[str], pd.DataFrame]:
656
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
657
+ and output dataframe with 1 line.
658
+ If the method is fit_predict, run 2 lines of data.
659
+ """
647
660
  # in case the inferred output column names dimension is different
648
661
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
649
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
662
+
663
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
664
+ # so change the minimum of number of rows to 2
665
+ num_examples = 2
666
+ statement_params = telemetry.get_function_usage_statement_params(
667
+ project=_PROJECT,
668
+ subproject=_SUBPROJECT,
669
+ function_name=telemetry.get_statement_params_full_func_name(
670
+ inspect.currentframe(), LocalOutlierFactor.__class__.__name__
671
+ ),
672
+ api_calls=[Session.call],
673
+ custom_tags={"autogen": True} if self._autogenerated else None,
674
+ )
675
+ if output_cols_prefix == "fit_predict_":
676
+ if hasattr(self._sklearn_object, "n_clusters"):
677
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
678
+ num_examples = self._sklearn_object.n_clusters
679
+ elif hasattr(self._sklearn_object, "min_samples"):
680
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
681
+ num_examples = self._sklearn_object.min_samples
682
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
683
+ # LocalOutlierFactor expects n_neighbors <= n_samples
684
+ num_examples = self._sklearn_object.n_neighbors
685
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
686
+ else:
687
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
650
688
 
651
689
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
652
690
  # seen during the fit.
@@ -658,12 +696,14 @@ class LocalOutlierFactor(BaseTransformer):
658
696
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
659
697
  if self.sample_weight_col:
660
698
  output_df_columns_set -= set(self.sample_weight_col)
699
+
661
700
  # if the dimension of inferred output column names is correct; use it
662
701
  if len(expected_output_cols_list) == len(output_df_columns_set):
663
- return expected_output_cols_list
702
+ return expected_output_cols_list, output_df_pd
664
703
  # otherwise, use the sklearn estimator's output
665
704
  else:
666
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
705
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
706
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
667
707
 
668
708
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
669
709
  @telemetry.send_api_usage_telemetry(
@@ -709,7 +749,7 @@ class LocalOutlierFactor(BaseTransformer):
709
749
  drop_input_cols=self._drop_input_cols,
710
750
  expected_output_cols_type="float",
711
751
  )
712
- expected_output_cols = self._align_expected_output_names(
752
+ expected_output_cols, _ = self._align_expected_output(
713
753
  inference_method, dataset, expected_output_cols, output_cols_prefix
714
754
  )
715
755
 
@@ -775,7 +815,7 @@ class LocalOutlierFactor(BaseTransformer):
775
815
  drop_input_cols=self._drop_input_cols,
776
816
  expected_output_cols_type="float",
777
817
  )
778
- expected_output_cols = self._align_expected_output_names(
818
+ expected_output_cols, _ = self._align_expected_output(
779
819
  inference_method, dataset, expected_output_cols, output_cols_prefix
780
820
  )
781
821
  elif isinstance(dataset, pd.DataFrame):
@@ -840,7 +880,7 @@ class LocalOutlierFactor(BaseTransformer):
840
880
  drop_input_cols=self._drop_input_cols,
841
881
  expected_output_cols_type="float",
842
882
  )
843
- expected_output_cols = self._align_expected_output_names(
883
+ expected_output_cols, _ = self._align_expected_output(
844
884
  inference_method, dataset, expected_output_cols, output_cols_prefix
845
885
  )
846
886
 
@@ -907,7 +947,7 @@ class LocalOutlierFactor(BaseTransformer):
907
947
  drop_input_cols = self._drop_input_cols,
908
948
  expected_output_cols_type="float",
909
949
  )
910
- expected_output_cols = self._align_expected_output_names(
950
+ expected_output_cols, _ = self._align_expected_output(
911
951
  inference_method, dataset, expected_output_cols, output_cols_prefix
912
952
  )
913
953
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -486,12 +483,23 @@ class NearestCentroid(BaseTransformer):
486
483
  autogenerated=self._autogenerated,
487
484
  subproject=_SUBPROJECT,
488
485
  )
489
- output_result, fitted_estimator = model_trainer.train_fit_predict(
490
- drop_input_cols=self._drop_input_cols,
491
- expected_output_cols_list=(
492
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
493
- ),
486
+ expected_output_cols = (
487
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
494
488
  )
489
+ if isinstance(dataset, DataFrame):
490
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
491
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
492
+ )
493
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
494
+ drop_input_cols=self._drop_input_cols,
495
+ expected_output_cols_list=expected_output_cols,
496
+ example_output_pd_df=example_output_pd_df,
497
+ )
498
+ else:
499
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
500
+ drop_input_cols=self._drop_input_cols,
501
+ expected_output_cols_list=expected_output_cols,
502
+ )
495
503
  self._sklearn_object = fitted_estimator
496
504
  self._is_fitted = True
497
505
  return output_result
@@ -514,6 +522,7 @@ class NearestCentroid(BaseTransformer):
514
522
  """
515
523
  self._infer_input_output_cols(dataset)
516
524
  super()._check_dataset_type(dataset)
525
+
517
526
  model_trainer = ModelTrainerBuilder.build_fit_transform(
518
527
  estimator=self._sklearn_object,
519
528
  dataset=dataset,
@@ -570,12 +579,41 @@ class NearestCentroid(BaseTransformer):
570
579
 
571
580
  return rv
572
581
 
573
- def _align_expected_output_names(
574
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
575
- ) -> List[str]:
582
+ def _align_expected_output(
583
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
584
+ ) -> Tuple[List[str], pd.DataFrame]:
585
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
586
+ and output dataframe with 1 line.
587
+ If the method is fit_predict, run 2 lines of data.
588
+ """
576
589
  # in case the inferred output column names dimension is different
577
590
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
578
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
591
+
592
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
593
+ # so change the minimum of number of rows to 2
594
+ num_examples = 2
595
+ statement_params = telemetry.get_function_usage_statement_params(
596
+ project=_PROJECT,
597
+ subproject=_SUBPROJECT,
598
+ function_name=telemetry.get_statement_params_full_func_name(
599
+ inspect.currentframe(), NearestCentroid.__class__.__name__
600
+ ),
601
+ api_calls=[Session.call],
602
+ custom_tags={"autogen": True} if self._autogenerated else None,
603
+ )
604
+ if output_cols_prefix == "fit_predict_":
605
+ if hasattr(self._sklearn_object, "n_clusters"):
606
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
607
+ num_examples = self._sklearn_object.n_clusters
608
+ elif hasattr(self._sklearn_object, "min_samples"):
609
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
610
+ num_examples = self._sklearn_object.min_samples
611
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
612
+ # LocalOutlierFactor expects n_neighbors <= n_samples
613
+ num_examples = self._sklearn_object.n_neighbors
614
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
615
+ else:
616
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
579
617
 
580
618
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
581
619
  # seen during the fit.
@@ -587,12 +625,14 @@ class NearestCentroid(BaseTransformer):
587
625
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
588
626
  if self.sample_weight_col:
589
627
  output_df_columns_set -= set(self.sample_weight_col)
628
+
590
629
  # if the dimension of inferred output column names is correct; use it
591
630
  if len(expected_output_cols_list) == len(output_df_columns_set):
592
- return expected_output_cols_list
631
+ return expected_output_cols_list, output_df_pd
593
632
  # otherwise, use the sklearn estimator's output
594
633
  else:
595
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
634
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
635
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
596
636
 
597
637
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
598
638
  @telemetry.send_api_usage_telemetry(
@@ -638,7 +678,7 @@ class NearestCentroid(BaseTransformer):
638
678
  drop_input_cols=self._drop_input_cols,
639
679
  expected_output_cols_type="float",
640
680
  )
641
- expected_output_cols = self._align_expected_output_names(
681
+ expected_output_cols, _ = self._align_expected_output(
642
682
  inference_method, dataset, expected_output_cols, output_cols_prefix
643
683
  )
644
684
 
@@ -704,7 +744,7 @@ class NearestCentroid(BaseTransformer):
704
744
  drop_input_cols=self._drop_input_cols,
705
745
  expected_output_cols_type="float",
706
746
  )
707
- expected_output_cols = self._align_expected_output_names(
747
+ expected_output_cols, _ = self._align_expected_output(
708
748
  inference_method, dataset, expected_output_cols, output_cols_prefix
709
749
  )
710
750
  elif isinstance(dataset, pd.DataFrame):
@@ -767,7 +807,7 @@ class NearestCentroid(BaseTransformer):
767
807
  drop_input_cols=self._drop_input_cols,
768
808
  expected_output_cols_type="float",
769
809
  )
770
- expected_output_cols = self._align_expected_output_names(
810
+ expected_output_cols, _ = self._align_expected_output(
771
811
  inference_method, dataset, expected_output_cols, output_cols_prefix
772
812
  )
773
813
 
@@ -832,7 +872,7 @@ class NearestCentroid(BaseTransformer):
832
872
  drop_input_cols = self._drop_input_cols,
833
873
  expected_output_cols_type="float",
834
874
  )
835
- expected_output_cols = self._align_expected_output_names(
875
+ expected_output_cols, _ = self._align_expected_output(
836
876
  inference_method, dataset, expected_output_cols, output_cols_prefix
837
877
  )
838
878