snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -482,12 +479,23 @@ class LedoitWolf(BaseTransformer):
482
479
  autogenerated=self._autogenerated,
483
480
  subproject=_SUBPROJECT,
484
481
  )
485
- output_result, fitted_estimator = model_trainer.train_fit_predict(
486
- drop_input_cols=self._drop_input_cols,
487
- expected_output_cols_list=(
488
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
489
- ),
482
+ expected_output_cols = (
483
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
490
484
  )
485
+ if isinstance(dataset, DataFrame):
486
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
487
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
488
+ )
489
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
490
+ drop_input_cols=self._drop_input_cols,
491
+ expected_output_cols_list=expected_output_cols,
492
+ example_output_pd_df=example_output_pd_df,
493
+ )
494
+ else:
495
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
496
+ drop_input_cols=self._drop_input_cols,
497
+ expected_output_cols_list=expected_output_cols,
498
+ )
491
499
  self._sklearn_object = fitted_estimator
492
500
  self._is_fitted = True
493
501
  return output_result
@@ -510,6 +518,7 @@ class LedoitWolf(BaseTransformer):
510
518
  """
511
519
  self._infer_input_output_cols(dataset)
512
520
  super()._check_dataset_type(dataset)
521
+
513
522
  model_trainer = ModelTrainerBuilder.build_fit_transform(
514
523
  estimator=self._sklearn_object,
515
524
  dataset=dataset,
@@ -566,12 +575,41 @@ class LedoitWolf(BaseTransformer):
566
575
 
567
576
  return rv
568
577
 
569
- def _align_expected_output_names(
570
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
571
- ) -> List[str]:
578
+ def _align_expected_output(
579
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
580
+ ) -> Tuple[List[str], pd.DataFrame]:
581
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
582
+ and output dataframe with 1 line.
583
+ If the method is fit_predict, run 2 lines of data.
584
+ """
572
585
  # in case the inferred output column names dimension is different
573
586
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
574
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
587
+
588
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
589
+ # so change the minimum of number of rows to 2
590
+ num_examples = 2
591
+ statement_params = telemetry.get_function_usage_statement_params(
592
+ project=_PROJECT,
593
+ subproject=_SUBPROJECT,
594
+ function_name=telemetry.get_statement_params_full_func_name(
595
+ inspect.currentframe(), LedoitWolf.__class__.__name__
596
+ ),
597
+ api_calls=[Session.call],
598
+ custom_tags={"autogen": True} if self._autogenerated else None,
599
+ )
600
+ if output_cols_prefix == "fit_predict_":
601
+ if hasattr(self._sklearn_object, "n_clusters"):
602
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
603
+ num_examples = self._sklearn_object.n_clusters
604
+ elif hasattr(self._sklearn_object, "min_samples"):
605
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
606
+ num_examples = self._sklearn_object.min_samples
607
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
608
+ # LocalOutlierFactor expects n_neighbors <= n_samples
609
+ num_examples = self._sklearn_object.n_neighbors
610
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
611
+ else:
612
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
575
613
 
576
614
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
577
615
  # seen during the fit.
@@ -583,12 +621,14 @@ class LedoitWolf(BaseTransformer):
583
621
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
584
622
  if self.sample_weight_col:
585
623
  output_df_columns_set -= set(self.sample_weight_col)
624
+
586
625
  # if the dimension of inferred output column names is correct; use it
587
626
  if len(expected_output_cols_list) == len(output_df_columns_set):
588
- return expected_output_cols_list
627
+ return expected_output_cols_list, output_df_pd
589
628
  # otherwise, use the sklearn estimator's output
590
629
  else:
591
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
630
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
631
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
592
632
 
593
633
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
594
634
  @telemetry.send_api_usage_telemetry(
@@ -634,7 +674,7 @@ class LedoitWolf(BaseTransformer):
634
674
  drop_input_cols=self._drop_input_cols,
635
675
  expected_output_cols_type="float",
636
676
  )
637
- expected_output_cols = self._align_expected_output_names(
677
+ expected_output_cols, _ = self._align_expected_output(
638
678
  inference_method, dataset, expected_output_cols, output_cols_prefix
639
679
  )
640
680
 
@@ -700,7 +740,7 @@ class LedoitWolf(BaseTransformer):
700
740
  drop_input_cols=self._drop_input_cols,
701
741
  expected_output_cols_type="float",
702
742
  )
703
- expected_output_cols = self._align_expected_output_names(
743
+ expected_output_cols, _ = self._align_expected_output(
704
744
  inference_method, dataset, expected_output_cols, output_cols_prefix
705
745
  )
706
746
  elif isinstance(dataset, pd.DataFrame):
@@ -763,7 +803,7 @@ class LedoitWolf(BaseTransformer):
763
803
  drop_input_cols=self._drop_input_cols,
764
804
  expected_output_cols_type="float",
765
805
  )
766
- expected_output_cols = self._align_expected_output_names(
806
+ expected_output_cols, _ = self._align_expected_output(
767
807
  inference_method, dataset, expected_output_cols, output_cols_prefix
768
808
  )
769
809
 
@@ -828,7 +868,7 @@ class LedoitWolf(BaseTransformer):
828
868
  drop_input_cols = self._drop_input_cols,
829
869
  expected_output_cols_type="float",
830
870
  )
831
- expected_output_cols = self._align_expected_output_names(
871
+ expected_output_cols, _ = self._align_expected_output(
832
872
  inference_method, dataset, expected_output_cols, output_cols_prefix
833
873
  )
834
874
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -494,12 +491,23 @@ class MinCovDet(BaseTransformer):
494
491
  autogenerated=self._autogenerated,
495
492
  subproject=_SUBPROJECT,
496
493
  )
497
- output_result, fitted_estimator = model_trainer.train_fit_predict(
498
- drop_input_cols=self._drop_input_cols,
499
- expected_output_cols_list=(
500
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
501
- ),
494
+ expected_output_cols = (
495
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
502
496
  )
497
+ if isinstance(dataset, DataFrame):
498
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
499
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
500
+ )
501
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
502
+ drop_input_cols=self._drop_input_cols,
503
+ expected_output_cols_list=expected_output_cols,
504
+ example_output_pd_df=example_output_pd_df,
505
+ )
506
+ else:
507
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
508
+ drop_input_cols=self._drop_input_cols,
509
+ expected_output_cols_list=expected_output_cols,
510
+ )
503
511
  self._sklearn_object = fitted_estimator
504
512
  self._is_fitted = True
505
513
  return output_result
@@ -522,6 +530,7 @@ class MinCovDet(BaseTransformer):
522
530
  """
523
531
  self._infer_input_output_cols(dataset)
524
532
  super()._check_dataset_type(dataset)
533
+
525
534
  model_trainer = ModelTrainerBuilder.build_fit_transform(
526
535
  estimator=self._sklearn_object,
527
536
  dataset=dataset,
@@ -578,12 +587,41 @@ class MinCovDet(BaseTransformer):
578
587
 
579
588
  return rv
580
589
 
581
- def _align_expected_output_names(
582
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
583
- ) -> List[str]:
590
+ def _align_expected_output(
591
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
592
+ ) -> Tuple[List[str], pd.DataFrame]:
593
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
594
+ and output dataframe with 1 line.
595
+ If the method is fit_predict, run 2 lines of data.
596
+ """
584
597
  # in case the inferred output column names dimension is different
585
598
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
586
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
599
+
600
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
601
+ # so change the minimum of number of rows to 2
602
+ num_examples = 2
603
+ statement_params = telemetry.get_function_usage_statement_params(
604
+ project=_PROJECT,
605
+ subproject=_SUBPROJECT,
606
+ function_name=telemetry.get_statement_params_full_func_name(
607
+ inspect.currentframe(), MinCovDet.__class__.__name__
608
+ ),
609
+ api_calls=[Session.call],
610
+ custom_tags={"autogen": True} if self._autogenerated else None,
611
+ )
612
+ if output_cols_prefix == "fit_predict_":
613
+ if hasattr(self._sklearn_object, "n_clusters"):
614
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
615
+ num_examples = self._sklearn_object.n_clusters
616
+ elif hasattr(self._sklearn_object, "min_samples"):
617
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
618
+ num_examples = self._sklearn_object.min_samples
619
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
620
+ # LocalOutlierFactor expects n_neighbors <= n_samples
621
+ num_examples = self._sklearn_object.n_neighbors
622
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
623
+ else:
624
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
587
625
 
588
626
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
589
627
  # seen during the fit.
@@ -595,12 +633,14 @@ class MinCovDet(BaseTransformer):
595
633
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
596
634
  if self.sample_weight_col:
597
635
  output_df_columns_set -= set(self.sample_weight_col)
636
+
598
637
  # if the dimension of inferred output column names is correct; use it
599
638
  if len(expected_output_cols_list) == len(output_df_columns_set):
600
- return expected_output_cols_list
639
+ return expected_output_cols_list, output_df_pd
601
640
  # otherwise, use the sklearn estimator's output
602
641
  else:
603
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
642
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
643
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
604
644
 
605
645
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
606
646
  @telemetry.send_api_usage_telemetry(
@@ -646,7 +686,7 @@ class MinCovDet(BaseTransformer):
646
686
  drop_input_cols=self._drop_input_cols,
647
687
  expected_output_cols_type="float",
648
688
  )
649
- expected_output_cols = self._align_expected_output_names(
689
+ expected_output_cols, _ = self._align_expected_output(
650
690
  inference_method, dataset, expected_output_cols, output_cols_prefix
651
691
  )
652
692
 
@@ -712,7 +752,7 @@ class MinCovDet(BaseTransformer):
712
752
  drop_input_cols=self._drop_input_cols,
713
753
  expected_output_cols_type="float",
714
754
  )
715
- expected_output_cols = self._align_expected_output_names(
755
+ expected_output_cols, _ = self._align_expected_output(
716
756
  inference_method, dataset, expected_output_cols, output_cols_prefix
717
757
  )
718
758
  elif isinstance(dataset, pd.DataFrame):
@@ -775,7 +815,7 @@ class MinCovDet(BaseTransformer):
775
815
  drop_input_cols=self._drop_input_cols,
776
816
  expected_output_cols_type="float",
777
817
  )
778
- expected_output_cols = self._align_expected_output_names(
818
+ expected_output_cols, _ = self._align_expected_output(
779
819
  inference_method, dataset, expected_output_cols, output_cols_prefix
780
820
  )
781
821
 
@@ -840,7 +880,7 @@ class MinCovDet(BaseTransformer):
840
880
  drop_input_cols = self._drop_input_cols,
841
881
  expected_output_cols_type="float",
842
882
  )
843
- expected_output_cols = self._align_expected_output_names(
883
+ expected_output_cols, _ = self._align_expected_output(
844
884
  inference_method, dataset, expected_output_cols, output_cols_prefix
845
885
  )
846
886
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -475,12 +472,23 @@ class OAS(BaseTransformer):
475
472
  autogenerated=self._autogenerated,
476
473
  subproject=_SUBPROJECT,
477
474
  )
478
- output_result, fitted_estimator = model_trainer.train_fit_predict(
479
- drop_input_cols=self._drop_input_cols,
480
- expected_output_cols_list=(
481
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
482
- ),
475
+ expected_output_cols = (
476
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
483
477
  )
478
+ if isinstance(dataset, DataFrame):
479
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
480
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
481
+ )
482
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
483
+ drop_input_cols=self._drop_input_cols,
484
+ expected_output_cols_list=expected_output_cols,
485
+ example_output_pd_df=example_output_pd_df,
486
+ )
487
+ else:
488
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
489
+ drop_input_cols=self._drop_input_cols,
490
+ expected_output_cols_list=expected_output_cols,
491
+ )
484
492
  self._sklearn_object = fitted_estimator
485
493
  self._is_fitted = True
486
494
  return output_result
@@ -503,6 +511,7 @@ class OAS(BaseTransformer):
503
511
  """
504
512
  self._infer_input_output_cols(dataset)
505
513
  super()._check_dataset_type(dataset)
514
+
506
515
  model_trainer = ModelTrainerBuilder.build_fit_transform(
507
516
  estimator=self._sklearn_object,
508
517
  dataset=dataset,
@@ -559,12 +568,41 @@ class OAS(BaseTransformer):
559
568
 
560
569
  return rv
561
570
 
562
- def _align_expected_output_names(
563
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
564
- ) -> List[str]:
571
+ def _align_expected_output(
572
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
573
+ ) -> Tuple[List[str], pd.DataFrame]:
574
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
575
+ and output dataframe with 1 line.
576
+ If the method is fit_predict, run 2 lines of data.
577
+ """
565
578
  # in case the inferred output column names dimension is different
566
579
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
567
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
580
+
581
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
582
+ # so change the minimum of number of rows to 2
583
+ num_examples = 2
584
+ statement_params = telemetry.get_function_usage_statement_params(
585
+ project=_PROJECT,
586
+ subproject=_SUBPROJECT,
587
+ function_name=telemetry.get_statement_params_full_func_name(
588
+ inspect.currentframe(), OAS.__class__.__name__
589
+ ),
590
+ api_calls=[Session.call],
591
+ custom_tags={"autogen": True} if self._autogenerated else None,
592
+ )
593
+ if output_cols_prefix == "fit_predict_":
594
+ if hasattr(self._sklearn_object, "n_clusters"):
595
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
596
+ num_examples = self._sklearn_object.n_clusters
597
+ elif hasattr(self._sklearn_object, "min_samples"):
598
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
599
+ num_examples = self._sklearn_object.min_samples
600
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
601
+ # LocalOutlierFactor expects n_neighbors <= n_samples
602
+ num_examples = self._sklearn_object.n_neighbors
603
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
604
+ else:
605
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
568
606
 
569
607
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
570
608
  # seen during the fit.
@@ -576,12 +614,14 @@ class OAS(BaseTransformer):
576
614
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
577
615
  if self.sample_weight_col:
578
616
  output_df_columns_set -= set(self.sample_weight_col)
617
+
579
618
  # if the dimension of inferred output column names is correct; use it
580
619
  if len(expected_output_cols_list) == len(output_df_columns_set):
581
- return expected_output_cols_list
620
+ return expected_output_cols_list, output_df_pd
582
621
  # otherwise, use the sklearn estimator's output
583
622
  else:
584
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
623
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
624
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
585
625
 
586
626
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
587
627
  @telemetry.send_api_usage_telemetry(
@@ -627,7 +667,7 @@ class OAS(BaseTransformer):
627
667
  drop_input_cols=self._drop_input_cols,
628
668
  expected_output_cols_type="float",
629
669
  )
630
- expected_output_cols = self._align_expected_output_names(
670
+ expected_output_cols, _ = self._align_expected_output(
631
671
  inference_method, dataset, expected_output_cols, output_cols_prefix
632
672
  )
633
673
 
@@ -693,7 +733,7 @@ class OAS(BaseTransformer):
693
733
  drop_input_cols=self._drop_input_cols,
694
734
  expected_output_cols_type="float",
695
735
  )
696
- expected_output_cols = self._align_expected_output_names(
736
+ expected_output_cols, _ = self._align_expected_output(
697
737
  inference_method, dataset, expected_output_cols, output_cols_prefix
698
738
  )
699
739
  elif isinstance(dataset, pd.DataFrame):
@@ -756,7 +796,7 @@ class OAS(BaseTransformer):
756
796
  drop_input_cols=self._drop_input_cols,
757
797
  expected_output_cols_type="float",
758
798
  )
759
- expected_output_cols = self._align_expected_output_names(
799
+ expected_output_cols, _ = self._align_expected_output(
760
800
  inference_method, dataset, expected_output_cols, output_cols_prefix
761
801
  )
762
802
 
@@ -821,7 +861,7 @@ class OAS(BaseTransformer):
821
861
  drop_input_cols = self._drop_input_cols,
822
862
  expected_output_cols_type="float",
823
863
  )
824
- expected_output_cols = self._align_expected_output_names(
864
+ expected_output_cols, _ = self._align_expected_output(
825
865
  inference_method, dataset, expected_output_cols, output_cols_prefix
826
866
  )
827
867
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -481,12 +478,23 @@ class ShrunkCovariance(BaseTransformer):
481
478
  autogenerated=self._autogenerated,
482
479
  subproject=_SUBPROJECT,
483
480
  )
484
- output_result, fitted_estimator = model_trainer.train_fit_predict(
485
- drop_input_cols=self._drop_input_cols,
486
- expected_output_cols_list=(
487
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
488
- ),
481
+ expected_output_cols = (
482
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
489
483
  )
484
+ if isinstance(dataset, DataFrame):
485
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
486
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
487
+ )
488
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
489
+ drop_input_cols=self._drop_input_cols,
490
+ expected_output_cols_list=expected_output_cols,
491
+ example_output_pd_df=example_output_pd_df,
492
+ )
493
+ else:
494
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
495
+ drop_input_cols=self._drop_input_cols,
496
+ expected_output_cols_list=expected_output_cols,
497
+ )
490
498
  self._sklearn_object = fitted_estimator
491
499
  self._is_fitted = True
492
500
  return output_result
@@ -509,6 +517,7 @@ class ShrunkCovariance(BaseTransformer):
509
517
  """
510
518
  self._infer_input_output_cols(dataset)
511
519
  super()._check_dataset_type(dataset)
520
+
512
521
  model_trainer = ModelTrainerBuilder.build_fit_transform(
513
522
  estimator=self._sklearn_object,
514
523
  dataset=dataset,
@@ -565,12 +574,41 @@ class ShrunkCovariance(BaseTransformer):
565
574
 
566
575
  return rv
567
576
 
568
- def _align_expected_output_names(
569
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
570
- ) -> List[str]:
577
+ def _align_expected_output(
578
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
579
+ ) -> Tuple[List[str], pd.DataFrame]:
580
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
581
+ and output dataframe with 1 line.
582
+ If the method is fit_predict, run 2 lines of data.
583
+ """
571
584
  # in case the inferred output column names dimension is different
572
585
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
573
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
586
+
587
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
588
+ # so change the minimum of number of rows to 2
589
+ num_examples = 2
590
+ statement_params = telemetry.get_function_usage_statement_params(
591
+ project=_PROJECT,
592
+ subproject=_SUBPROJECT,
593
+ function_name=telemetry.get_statement_params_full_func_name(
594
+ inspect.currentframe(), ShrunkCovariance.__class__.__name__
595
+ ),
596
+ api_calls=[Session.call],
597
+ custom_tags={"autogen": True} if self._autogenerated else None,
598
+ )
599
+ if output_cols_prefix == "fit_predict_":
600
+ if hasattr(self._sklearn_object, "n_clusters"):
601
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
602
+ num_examples = self._sklearn_object.n_clusters
603
+ elif hasattr(self._sklearn_object, "min_samples"):
604
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
605
+ num_examples = self._sklearn_object.min_samples
606
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
607
+ # LocalOutlierFactor expects n_neighbors <= n_samples
608
+ num_examples = self._sklearn_object.n_neighbors
609
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
610
+ else:
611
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
574
612
 
575
613
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
576
614
  # seen during the fit.
@@ -582,12 +620,14 @@ class ShrunkCovariance(BaseTransformer):
582
620
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
583
621
  if self.sample_weight_col:
584
622
  output_df_columns_set -= set(self.sample_weight_col)
623
+
585
624
  # if the dimension of inferred output column names is correct; use it
586
625
  if len(expected_output_cols_list) == len(output_df_columns_set):
587
- return expected_output_cols_list
626
+ return expected_output_cols_list, output_df_pd
588
627
  # otherwise, use the sklearn estimator's output
589
628
  else:
590
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
629
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
630
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
591
631
 
592
632
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
593
633
  @telemetry.send_api_usage_telemetry(
@@ -633,7 +673,7 @@ class ShrunkCovariance(BaseTransformer):
633
673
  drop_input_cols=self._drop_input_cols,
634
674
  expected_output_cols_type="float",
635
675
  )
636
- expected_output_cols = self._align_expected_output_names(
676
+ expected_output_cols, _ = self._align_expected_output(
637
677
  inference_method, dataset, expected_output_cols, output_cols_prefix
638
678
  )
639
679
 
@@ -699,7 +739,7 @@ class ShrunkCovariance(BaseTransformer):
699
739
  drop_input_cols=self._drop_input_cols,
700
740
  expected_output_cols_type="float",
701
741
  )
702
- expected_output_cols = self._align_expected_output_names(
742
+ expected_output_cols, _ = self._align_expected_output(
703
743
  inference_method, dataset, expected_output_cols, output_cols_prefix
704
744
  )
705
745
  elif isinstance(dataset, pd.DataFrame):
@@ -762,7 +802,7 @@ class ShrunkCovariance(BaseTransformer):
762
802
  drop_input_cols=self._drop_input_cols,
763
803
  expected_output_cols_type="float",
764
804
  )
765
- expected_output_cols = self._align_expected_output_names(
805
+ expected_output_cols, _ = self._align_expected_output(
766
806
  inference_method, dataset, expected_output_cols, output_cols_prefix
767
807
  )
768
808
 
@@ -827,7 +867,7 @@ class ShrunkCovariance(BaseTransformer):
827
867
  drop_input_cols = self._drop_input_cols,
828
868
  expected_output_cols_type="float",
829
869
  )
830
- expected_output_cols = self._align_expected_output_names(
870
+ expected_output_cols, _ = self._align_expected_output(
831
871
  inference_method, dataset, expected_output_cols, output_cols_prefix
832
872
  )
833
873