snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -538,12 +535,23 @@ class LinearSVR(BaseTransformer):
538
535
  autogenerated=self._autogenerated,
539
536
  subproject=_SUBPROJECT,
540
537
  )
541
- output_result, fitted_estimator = model_trainer.train_fit_predict(
542
- drop_input_cols=self._drop_input_cols,
543
- expected_output_cols_list=(
544
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
545
- ),
538
+ expected_output_cols = (
539
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
546
540
  )
541
+ if isinstance(dataset, DataFrame):
542
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
543
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
544
+ )
545
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
546
+ drop_input_cols=self._drop_input_cols,
547
+ expected_output_cols_list=expected_output_cols,
548
+ example_output_pd_df=example_output_pd_df,
549
+ )
550
+ else:
551
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
552
+ drop_input_cols=self._drop_input_cols,
553
+ expected_output_cols_list=expected_output_cols,
554
+ )
547
555
  self._sklearn_object = fitted_estimator
548
556
  self._is_fitted = True
549
557
  return output_result
@@ -566,6 +574,7 @@ class LinearSVR(BaseTransformer):
566
574
  """
567
575
  self._infer_input_output_cols(dataset)
568
576
  super()._check_dataset_type(dataset)
577
+
569
578
  model_trainer = ModelTrainerBuilder.build_fit_transform(
570
579
  estimator=self._sklearn_object,
571
580
  dataset=dataset,
@@ -622,12 +631,41 @@ class LinearSVR(BaseTransformer):
622
631
 
623
632
  return rv
624
633
 
625
- def _align_expected_output_names(
626
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
627
- ) -> List[str]:
634
+ def _align_expected_output(
635
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
636
+ ) -> Tuple[List[str], pd.DataFrame]:
637
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
638
+ and output dataframe with 1 line.
639
+ If the method is fit_predict, run 2 lines of data.
640
+ """
628
641
  # in case the inferred output column names dimension is different
629
642
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
630
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
643
+
644
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
645
+ # so change the minimum of number of rows to 2
646
+ num_examples = 2
647
+ statement_params = telemetry.get_function_usage_statement_params(
648
+ project=_PROJECT,
649
+ subproject=_SUBPROJECT,
650
+ function_name=telemetry.get_statement_params_full_func_name(
651
+ inspect.currentframe(), LinearSVR.__class__.__name__
652
+ ),
653
+ api_calls=[Session.call],
654
+ custom_tags={"autogen": True} if self._autogenerated else None,
655
+ )
656
+ if output_cols_prefix == "fit_predict_":
657
+ if hasattr(self._sklearn_object, "n_clusters"):
658
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
659
+ num_examples = self._sklearn_object.n_clusters
660
+ elif hasattr(self._sklearn_object, "min_samples"):
661
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
662
+ num_examples = self._sklearn_object.min_samples
663
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
664
+ # LocalOutlierFactor expects n_neighbors <= n_samples
665
+ num_examples = self._sklearn_object.n_neighbors
666
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
667
+ else:
668
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
631
669
 
632
670
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
633
671
  # seen during the fit.
@@ -639,12 +677,14 @@ class LinearSVR(BaseTransformer):
639
677
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
640
678
  if self.sample_weight_col:
641
679
  output_df_columns_set -= set(self.sample_weight_col)
680
+
642
681
  # if the dimension of inferred output column names is correct; use it
643
682
  if len(expected_output_cols_list) == len(output_df_columns_set):
644
- return expected_output_cols_list
683
+ return expected_output_cols_list, output_df_pd
645
684
  # otherwise, use the sklearn estimator's output
646
685
  else:
647
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
686
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
687
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
648
688
 
649
689
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
650
690
  @telemetry.send_api_usage_telemetry(
@@ -690,7 +730,7 @@ class LinearSVR(BaseTransformer):
690
730
  drop_input_cols=self._drop_input_cols,
691
731
  expected_output_cols_type="float",
692
732
  )
693
- expected_output_cols = self._align_expected_output_names(
733
+ expected_output_cols, _ = self._align_expected_output(
694
734
  inference_method, dataset, expected_output_cols, output_cols_prefix
695
735
  )
696
736
 
@@ -756,7 +796,7 @@ class LinearSVR(BaseTransformer):
756
796
  drop_input_cols=self._drop_input_cols,
757
797
  expected_output_cols_type="float",
758
798
  )
759
- expected_output_cols = self._align_expected_output_names(
799
+ expected_output_cols, _ = self._align_expected_output(
760
800
  inference_method, dataset, expected_output_cols, output_cols_prefix
761
801
  )
762
802
  elif isinstance(dataset, pd.DataFrame):
@@ -819,7 +859,7 @@ class LinearSVR(BaseTransformer):
819
859
  drop_input_cols=self._drop_input_cols,
820
860
  expected_output_cols_type="float",
821
861
  )
822
- expected_output_cols = self._align_expected_output_names(
862
+ expected_output_cols, _ = self._align_expected_output(
823
863
  inference_method, dataset, expected_output_cols, output_cols_prefix
824
864
  )
825
865
 
@@ -884,7 +924,7 @@ class LinearSVR(BaseTransformer):
884
924
  drop_input_cols = self._drop_input_cols,
885
925
  expected_output_cols_type="float",
886
926
  )
887
- expected_output_cols = self._align_expected_output_names(
927
+ expected_output_cols, _ = self._align_expected_output(
888
928
  inference_method, dataset, expected_output_cols, output_cols_prefix
889
929
  )
890
930
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -572,12 +569,23 @@ class NuSVC(BaseTransformer):
572
569
  autogenerated=self._autogenerated,
573
570
  subproject=_SUBPROJECT,
574
571
  )
575
- output_result, fitted_estimator = model_trainer.train_fit_predict(
576
- drop_input_cols=self._drop_input_cols,
577
- expected_output_cols_list=(
578
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
579
- ),
572
+ expected_output_cols = (
573
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
580
574
  )
575
+ if isinstance(dataset, DataFrame):
576
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
577
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
578
+ )
579
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
580
+ drop_input_cols=self._drop_input_cols,
581
+ expected_output_cols_list=expected_output_cols,
582
+ example_output_pd_df=example_output_pd_df,
583
+ )
584
+ else:
585
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
586
+ drop_input_cols=self._drop_input_cols,
587
+ expected_output_cols_list=expected_output_cols,
588
+ )
581
589
  self._sklearn_object = fitted_estimator
582
590
  self._is_fitted = True
583
591
  return output_result
@@ -600,6 +608,7 @@ class NuSVC(BaseTransformer):
600
608
  """
601
609
  self._infer_input_output_cols(dataset)
602
610
  super()._check_dataset_type(dataset)
611
+
603
612
  model_trainer = ModelTrainerBuilder.build_fit_transform(
604
613
  estimator=self._sklearn_object,
605
614
  dataset=dataset,
@@ -656,12 +665,41 @@ class NuSVC(BaseTransformer):
656
665
 
657
666
  return rv
658
667
 
659
- def _align_expected_output_names(
660
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
661
- ) -> List[str]:
668
+ def _align_expected_output(
669
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
670
+ ) -> Tuple[List[str], pd.DataFrame]:
671
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
672
+ and output dataframe with 1 line.
673
+ If the method is fit_predict, run 2 lines of data.
674
+ """
662
675
  # in case the inferred output column names dimension is different
663
676
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
664
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
677
+
678
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
679
+ # so change the minimum of number of rows to 2
680
+ num_examples = 2
681
+ statement_params = telemetry.get_function_usage_statement_params(
682
+ project=_PROJECT,
683
+ subproject=_SUBPROJECT,
684
+ function_name=telemetry.get_statement_params_full_func_name(
685
+ inspect.currentframe(), NuSVC.__class__.__name__
686
+ ),
687
+ api_calls=[Session.call],
688
+ custom_tags={"autogen": True} if self._autogenerated else None,
689
+ )
690
+ if output_cols_prefix == "fit_predict_":
691
+ if hasattr(self._sklearn_object, "n_clusters"):
692
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
693
+ num_examples = self._sklearn_object.n_clusters
694
+ elif hasattr(self._sklearn_object, "min_samples"):
695
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
696
+ num_examples = self._sklearn_object.min_samples
697
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
698
+ # LocalOutlierFactor expects n_neighbors <= n_samples
699
+ num_examples = self._sklearn_object.n_neighbors
700
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
701
+ else:
702
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
665
703
 
666
704
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
667
705
  # seen during the fit.
@@ -673,12 +711,14 @@ class NuSVC(BaseTransformer):
673
711
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
674
712
  if self.sample_weight_col:
675
713
  output_df_columns_set -= set(self.sample_weight_col)
714
+
676
715
  # if the dimension of inferred output column names is correct; use it
677
716
  if len(expected_output_cols_list) == len(output_df_columns_set):
678
- return expected_output_cols_list
717
+ return expected_output_cols_list, output_df_pd
679
718
  # otherwise, use the sklearn estimator's output
680
719
  else:
681
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
720
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
721
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
682
722
 
683
723
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
684
724
  @telemetry.send_api_usage_telemetry(
@@ -726,7 +766,7 @@ class NuSVC(BaseTransformer):
726
766
  drop_input_cols=self._drop_input_cols,
727
767
  expected_output_cols_type="float",
728
768
  )
729
- expected_output_cols = self._align_expected_output_names(
769
+ expected_output_cols, _ = self._align_expected_output(
730
770
  inference_method, dataset, expected_output_cols, output_cols_prefix
731
771
  )
732
772
 
@@ -794,7 +834,7 @@ class NuSVC(BaseTransformer):
794
834
  drop_input_cols=self._drop_input_cols,
795
835
  expected_output_cols_type="float",
796
836
  )
797
- expected_output_cols = self._align_expected_output_names(
837
+ expected_output_cols, _ = self._align_expected_output(
798
838
  inference_method, dataset, expected_output_cols, output_cols_prefix
799
839
  )
800
840
  elif isinstance(dataset, pd.DataFrame):
@@ -859,7 +899,7 @@ class NuSVC(BaseTransformer):
859
899
  drop_input_cols=self._drop_input_cols,
860
900
  expected_output_cols_type="float",
861
901
  )
862
- expected_output_cols = self._align_expected_output_names(
902
+ expected_output_cols, _ = self._align_expected_output(
863
903
  inference_method, dataset, expected_output_cols, output_cols_prefix
864
904
  )
865
905
 
@@ -924,7 +964,7 @@ class NuSVC(BaseTransformer):
924
964
  drop_input_cols = self._drop_input_cols,
925
965
  expected_output_cols_type="float",
926
966
  )
927
- expected_output_cols = self._align_expected_output_names(
967
+ expected_output_cols, _ = self._align_expected_output(
928
968
  inference_method, dataset, expected_output_cols, output_cols_prefix
929
969
  )
930
970
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -533,12 +530,23 @@ class NuSVR(BaseTransformer):
533
530
  autogenerated=self._autogenerated,
534
531
  subproject=_SUBPROJECT,
535
532
  )
536
- output_result, fitted_estimator = model_trainer.train_fit_predict(
537
- drop_input_cols=self._drop_input_cols,
538
- expected_output_cols_list=(
539
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
540
- ),
533
+ expected_output_cols = (
534
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
541
535
  )
536
+ if isinstance(dataset, DataFrame):
537
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
538
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
539
+ )
540
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
541
+ drop_input_cols=self._drop_input_cols,
542
+ expected_output_cols_list=expected_output_cols,
543
+ example_output_pd_df=example_output_pd_df,
544
+ )
545
+ else:
546
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
547
+ drop_input_cols=self._drop_input_cols,
548
+ expected_output_cols_list=expected_output_cols,
549
+ )
542
550
  self._sklearn_object = fitted_estimator
543
551
  self._is_fitted = True
544
552
  return output_result
@@ -561,6 +569,7 @@ class NuSVR(BaseTransformer):
561
569
  """
562
570
  self._infer_input_output_cols(dataset)
563
571
  super()._check_dataset_type(dataset)
572
+
564
573
  model_trainer = ModelTrainerBuilder.build_fit_transform(
565
574
  estimator=self._sklearn_object,
566
575
  dataset=dataset,
@@ -617,12 +626,41 @@ class NuSVR(BaseTransformer):
617
626
 
618
627
  return rv
619
628
 
620
- def _align_expected_output_names(
621
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
622
- ) -> List[str]:
629
+ def _align_expected_output(
630
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
631
+ ) -> Tuple[List[str], pd.DataFrame]:
632
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
633
+ and output dataframe with 1 line.
634
+ If the method is fit_predict, run 2 lines of data.
635
+ """
623
636
  # in case the inferred output column names dimension is different
624
637
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
625
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
638
+
639
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
640
+ # so change the minimum of number of rows to 2
641
+ num_examples = 2
642
+ statement_params = telemetry.get_function_usage_statement_params(
643
+ project=_PROJECT,
644
+ subproject=_SUBPROJECT,
645
+ function_name=telemetry.get_statement_params_full_func_name(
646
+ inspect.currentframe(), NuSVR.__class__.__name__
647
+ ),
648
+ api_calls=[Session.call],
649
+ custom_tags={"autogen": True} if self._autogenerated else None,
650
+ )
651
+ if output_cols_prefix == "fit_predict_":
652
+ if hasattr(self._sklearn_object, "n_clusters"):
653
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
654
+ num_examples = self._sklearn_object.n_clusters
655
+ elif hasattr(self._sklearn_object, "min_samples"):
656
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
657
+ num_examples = self._sklearn_object.min_samples
658
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
659
+ # LocalOutlierFactor expects n_neighbors <= n_samples
660
+ num_examples = self._sklearn_object.n_neighbors
661
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
662
+ else:
663
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
626
664
 
627
665
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
628
666
  # seen during the fit.
@@ -634,12 +672,14 @@ class NuSVR(BaseTransformer):
634
672
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
635
673
  if self.sample_weight_col:
636
674
  output_df_columns_set -= set(self.sample_weight_col)
675
+
637
676
  # if the dimension of inferred output column names is correct; use it
638
677
  if len(expected_output_cols_list) == len(output_df_columns_set):
639
- return expected_output_cols_list
678
+ return expected_output_cols_list, output_df_pd
640
679
  # otherwise, use the sklearn estimator's output
641
680
  else:
642
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
681
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
682
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
643
683
 
644
684
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
645
685
  @telemetry.send_api_usage_telemetry(
@@ -685,7 +725,7 @@ class NuSVR(BaseTransformer):
685
725
  drop_input_cols=self._drop_input_cols,
686
726
  expected_output_cols_type="float",
687
727
  )
688
- expected_output_cols = self._align_expected_output_names(
728
+ expected_output_cols, _ = self._align_expected_output(
689
729
  inference_method, dataset, expected_output_cols, output_cols_prefix
690
730
  )
691
731
 
@@ -751,7 +791,7 @@ class NuSVR(BaseTransformer):
751
791
  drop_input_cols=self._drop_input_cols,
752
792
  expected_output_cols_type="float",
753
793
  )
754
- expected_output_cols = self._align_expected_output_names(
794
+ expected_output_cols, _ = self._align_expected_output(
755
795
  inference_method, dataset, expected_output_cols, output_cols_prefix
756
796
  )
757
797
  elif isinstance(dataset, pd.DataFrame):
@@ -814,7 +854,7 @@ class NuSVR(BaseTransformer):
814
854
  drop_input_cols=self._drop_input_cols,
815
855
  expected_output_cols_type="float",
816
856
  )
817
- expected_output_cols = self._align_expected_output_names(
857
+ expected_output_cols, _ = self._align_expected_output(
818
858
  inference_method, dataset, expected_output_cols, output_cols_prefix
819
859
  )
820
860
 
@@ -879,7 +919,7 @@ class NuSVR(BaseTransformer):
879
919
  drop_input_cols = self._drop_input_cols,
880
920
  expected_output_cols_type="float",
881
921
  )
882
- expected_output_cols = self._align_expected_output_names(
922
+ expected_output_cols, _ = self._align_expected_output(
883
923
  inference_method, dataset, expected_output_cols, output_cols_prefix
884
924
  )
885
925
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -575,12 +572,23 @@ class SVC(BaseTransformer):
575
572
  autogenerated=self._autogenerated,
576
573
  subproject=_SUBPROJECT,
577
574
  )
578
- output_result, fitted_estimator = model_trainer.train_fit_predict(
579
- drop_input_cols=self._drop_input_cols,
580
- expected_output_cols_list=(
581
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
582
- ),
575
+ expected_output_cols = (
576
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
583
577
  )
578
+ if isinstance(dataset, DataFrame):
579
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
580
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
581
+ )
582
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
583
+ drop_input_cols=self._drop_input_cols,
584
+ expected_output_cols_list=expected_output_cols,
585
+ example_output_pd_df=example_output_pd_df,
586
+ )
587
+ else:
588
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
589
+ drop_input_cols=self._drop_input_cols,
590
+ expected_output_cols_list=expected_output_cols,
591
+ )
584
592
  self._sklearn_object = fitted_estimator
585
593
  self._is_fitted = True
586
594
  return output_result
@@ -603,6 +611,7 @@ class SVC(BaseTransformer):
603
611
  """
604
612
  self._infer_input_output_cols(dataset)
605
613
  super()._check_dataset_type(dataset)
614
+
606
615
  model_trainer = ModelTrainerBuilder.build_fit_transform(
607
616
  estimator=self._sklearn_object,
608
617
  dataset=dataset,
@@ -659,12 +668,41 @@ class SVC(BaseTransformer):
659
668
 
660
669
  return rv
661
670
 
662
- def _align_expected_output_names(
663
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
664
- ) -> List[str]:
671
+ def _align_expected_output(
672
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
673
+ ) -> Tuple[List[str], pd.DataFrame]:
674
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
675
+ and output dataframe with 1 line.
676
+ If the method is fit_predict, run 2 lines of data.
677
+ """
665
678
  # in case the inferred output column names dimension is different
666
679
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
667
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
680
+
681
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
682
+ # so change the minimum of number of rows to 2
683
+ num_examples = 2
684
+ statement_params = telemetry.get_function_usage_statement_params(
685
+ project=_PROJECT,
686
+ subproject=_SUBPROJECT,
687
+ function_name=telemetry.get_statement_params_full_func_name(
688
+ inspect.currentframe(), SVC.__class__.__name__
689
+ ),
690
+ api_calls=[Session.call],
691
+ custom_tags={"autogen": True} if self._autogenerated else None,
692
+ )
693
+ if output_cols_prefix == "fit_predict_":
694
+ if hasattr(self._sklearn_object, "n_clusters"):
695
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
696
+ num_examples = self._sklearn_object.n_clusters
697
+ elif hasattr(self._sklearn_object, "min_samples"):
698
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
699
+ num_examples = self._sklearn_object.min_samples
700
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
701
+ # LocalOutlierFactor expects n_neighbors <= n_samples
702
+ num_examples = self._sklearn_object.n_neighbors
703
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
704
+ else:
705
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
668
706
 
669
707
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
670
708
  # seen during the fit.
@@ -676,12 +714,14 @@ class SVC(BaseTransformer):
676
714
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
677
715
  if self.sample_weight_col:
678
716
  output_df_columns_set -= set(self.sample_weight_col)
717
+
679
718
  # if the dimension of inferred output column names is correct; use it
680
719
  if len(expected_output_cols_list) == len(output_df_columns_set):
681
- return expected_output_cols_list
720
+ return expected_output_cols_list, output_df_pd
682
721
  # otherwise, use the sklearn estimator's output
683
722
  else:
684
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
723
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
724
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
685
725
 
686
726
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
687
727
  @telemetry.send_api_usage_telemetry(
@@ -729,7 +769,7 @@ class SVC(BaseTransformer):
729
769
  drop_input_cols=self._drop_input_cols,
730
770
  expected_output_cols_type="float",
731
771
  )
732
- expected_output_cols = self._align_expected_output_names(
772
+ expected_output_cols, _ = self._align_expected_output(
733
773
  inference_method, dataset, expected_output_cols, output_cols_prefix
734
774
  )
735
775
 
@@ -797,7 +837,7 @@ class SVC(BaseTransformer):
797
837
  drop_input_cols=self._drop_input_cols,
798
838
  expected_output_cols_type="float",
799
839
  )
800
- expected_output_cols = self._align_expected_output_names(
840
+ expected_output_cols, _ = self._align_expected_output(
801
841
  inference_method, dataset, expected_output_cols, output_cols_prefix
802
842
  )
803
843
  elif isinstance(dataset, pd.DataFrame):
@@ -862,7 +902,7 @@ class SVC(BaseTransformer):
862
902
  drop_input_cols=self._drop_input_cols,
863
903
  expected_output_cols_type="float",
864
904
  )
865
- expected_output_cols = self._align_expected_output_names(
905
+ expected_output_cols, _ = self._align_expected_output(
866
906
  inference_method, dataset, expected_output_cols, output_cols_prefix
867
907
  )
868
908
 
@@ -927,7 +967,7 @@ class SVC(BaseTransformer):
927
967
  drop_input_cols = self._drop_input_cols,
928
968
  expected_output_cols_type="float",
929
969
  )
930
- expected_output_cols = self._align_expected_output_names(
970
+ expected_output_cols, _ = self._align_expected_output(
931
971
  inference_method, dataset, expected_output_cols, output_cols_prefix
932
972
  )
933
973