snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -497,12 +494,23 @@ class PolynomialFeatures(BaseTransformer):
497
494
  autogenerated=self._autogenerated,
498
495
  subproject=_SUBPROJECT,
499
496
  )
500
- output_result, fitted_estimator = model_trainer.train_fit_predict(
501
- drop_input_cols=self._drop_input_cols,
502
- expected_output_cols_list=(
503
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
504
- ),
497
+ expected_output_cols = (
498
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
505
499
  )
500
+ if isinstance(dataset, DataFrame):
501
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
502
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
503
+ )
504
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
505
+ drop_input_cols=self._drop_input_cols,
506
+ expected_output_cols_list=expected_output_cols,
507
+ example_output_pd_df=example_output_pd_df,
508
+ )
509
+ else:
510
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
511
+ drop_input_cols=self._drop_input_cols,
512
+ expected_output_cols_list=expected_output_cols,
513
+ )
506
514
  self._sklearn_object = fitted_estimator
507
515
  self._is_fitted = True
508
516
  return output_result
@@ -527,6 +535,7 @@ class PolynomialFeatures(BaseTransformer):
527
535
  """
528
536
  self._infer_input_output_cols(dataset)
529
537
  super()._check_dataset_type(dataset)
538
+
530
539
  model_trainer = ModelTrainerBuilder.build_fit_transform(
531
540
  estimator=self._sklearn_object,
532
541
  dataset=dataset,
@@ -583,12 +592,41 @@ class PolynomialFeatures(BaseTransformer):
583
592
 
584
593
  return rv
585
594
 
586
- def _align_expected_output_names(
587
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
588
- ) -> List[str]:
595
+ def _align_expected_output(
596
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
597
+ ) -> Tuple[List[str], pd.DataFrame]:
598
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
599
+ and output dataframe with 1 line.
600
+ If the method is fit_predict, run 2 lines of data.
601
+ """
589
602
  # in case the inferred output column names dimension is different
590
603
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
591
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
604
+
605
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
606
+ # so change the minimum of number of rows to 2
607
+ num_examples = 2
608
+ statement_params = telemetry.get_function_usage_statement_params(
609
+ project=_PROJECT,
610
+ subproject=_SUBPROJECT,
611
+ function_name=telemetry.get_statement_params_full_func_name(
612
+ inspect.currentframe(), PolynomialFeatures.__class__.__name__
613
+ ),
614
+ api_calls=[Session.call],
615
+ custom_tags={"autogen": True} if self._autogenerated else None,
616
+ )
617
+ if output_cols_prefix == "fit_predict_":
618
+ if hasattr(self._sklearn_object, "n_clusters"):
619
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
620
+ num_examples = self._sklearn_object.n_clusters
621
+ elif hasattr(self._sklearn_object, "min_samples"):
622
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
623
+ num_examples = self._sklearn_object.min_samples
624
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
625
+ # LocalOutlierFactor expects n_neighbors <= n_samples
626
+ num_examples = self._sklearn_object.n_neighbors
627
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
628
+ else:
629
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
592
630
 
593
631
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
594
632
  # seen during the fit.
@@ -600,12 +638,14 @@ class PolynomialFeatures(BaseTransformer):
600
638
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
601
639
  if self.sample_weight_col:
602
640
  output_df_columns_set -= set(self.sample_weight_col)
641
+
603
642
  # if the dimension of inferred output column names is correct; use it
604
643
  if len(expected_output_cols_list) == len(output_df_columns_set):
605
- return expected_output_cols_list
644
+ return expected_output_cols_list, output_df_pd
606
645
  # otherwise, use the sklearn estimator's output
607
646
  else:
608
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
647
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
648
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
609
649
 
610
650
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
611
651
  @telemetry.send_api_usage_telemetry(
@@ -651,7 +691,7 @@ class PolynomialFeatures(BaseTransformer):
651
691
  drop_input_cols=self._drop_input_cols,
652
692
  expected_output_cols_type="float",
653
693
  )
654
- expected_output_cols = self._align_expected_output_names(
694
+ expected_output_cols, _ = self._align_expected_output(
655
695
  inference_method, dataset, expected_output_cols, output_cols_prefix
656
696
  )
657
697
 
@@ -717,7 +757,7 @@ class PolynomialFeatures(BaseTransformer):
717
757
  drop_input_cols=self._drop_input_cols,
718
758
  expected_output_cols_type="float",
719
759
  )
720
- expected_output_cols = self._align_expected_output_names(
760
+ expected_output_cols, _ = self._align_expected_output(
721
761
  inference_method, dataset, expected_output_cols, output_cols_prefix
722
762
  )
723
763
  elif isinstance(dataset, pd.DataFrame):
@@ -780,7 +820,7 @@ class PolynomialFeatures(BaseTransformer):
780
820
  drop_input_cols=self._drop_input_cols,
781
821
  expected_output_cols_type="float",
782
822
  )
783
- expected_output_cols = self._align_expected_output_names(
823
+ expected_output_cols, _ = self._align_expected_output(
784
824
  inference_method, dataset, expected_output_cols, output_cols_prefix
785
825
  )
786
826
 
@@ -845,7 +885,7 @@ class PolynomialFeatures(BaseTransformer):
845
885
  drop_input_cols = self._drop_input_cols,
846
886
  expected_output_cols_type="float",
847
887
  )
848
- expected_output_cols = self._align_expected_output_names(
888
+ expected_output_cols, _ = self._align_expected_output(
849
889
  inference_method, dataset, expected_output_cols, output_cols_prefix
850
890
  )
851
891
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -501,12 +498,23 @@ class LabelPropagation(BaseTransformer):
501
498
  autogenerated=self._autogenerated,
502
499
  subproject=_SUBPROJECT,
503
500
  )
504
- output_result, fitted_estimator = model_trainer.train_fit_predict(
505
- drop_input_cols=self._drop_input_cols,
506
- expected_output_cols_list=(
507
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
508
- ),
501
+ expected_output_cols = (
502
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
509
503
  )
504
+ if isinstance(dataset, DataFrame):
505
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
506
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
507
+ )
508
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
509
+ drop_input_cols=self._drop_input_cols,
510
+ expected_output_cols_list=expected_output_cols,
511
+ example_output_pd_df=example_output_pd_df,
512
+ )
513
+ else:
514
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
515
+ drop_input_cols=self._drop_input_cols,
516
+ expected_output_cols_list=expected_output_cols,
517
+ )
510
518
  self._sklearn_object = fitted_estimator
511
519
  self._is_fitted = True
512
520
  return output_result
@@ -529,6 +537,7 @@ class LabelPropagation(BaseTransformer):
529
537
  """
530
538
  self._infer_input_output_cols(dataset)
531
539
  super()._check_dataset_type(dataset)
540
+
532
541
  model_trainer = ModelTrainerBuilder.build_fit_transform(
533
542
  estimator=self._sklearn_object,
534
543
  dataset=dataset,
@@ -585,12 +594,41 @@ class LabelPropagation(BaseTransformer):
585
594
 
586
595
  return rv
587
596
 
588
- def _align_expected_output_names(
589
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
590
- ) -> List[str]:
597
+ def _align_expected_output(
598
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
599
+ ) -> Tuple[List[str], pd.DataFrame]:
600
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
601
+ and output dataframe with 1 line.
602
+ If the method is fit_predict, run 2 lines of data.
603
+ """
591
604
  # in case the inferred output column names dimension is different
592
605
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
593
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
606
+
607
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
608
+ # so change the minimum of number of rows to 2
609
+ num_examples = 2
610
+ statement_params = telemetry.get_function_usage_statement_params(
611
+ project=_PROJECT,
612
+ subproject=_SUBPROJECT,
613
+ function_name=telemetry.get_statement_params_full_func_name(
614
+ inspect.currentframe(), LabelPropagation.__class__.__name__
615
+ ),
616
+ api_calls=[Session.call],
617
+ custom_tags={"autogen": True} if self._autogenerated else None,
618
+ )
619
+ if output_cols_prefix == "fit_predict_":
620
+ if hasattr(self._sklearn_object, "n_clusters"):
621
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
622
+ num_examples = self._sklearn_object.n_clusters
623
+ elif hasattr(self._sklearn_object, "min_samples"):
624
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
625
+ num_examples = self._sklearn_object.min_samples
626
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
627
+ # LocalOutlierFactor expects n_neighbors <= n_samples
628
+ num_examples = self._sklearn_object.n_neighbors
629
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
630
+ else:
631
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
594
632
 
595
633
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
596
634
  # seen during the fit.
@@ -602,12 +640,14 @@ class LabelPropagation(BaseTransformer):
602
640
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
603
641
  if self.sample_weight_col:
604
642
  output_df_columns_set -= set(self.sample_weight_col)
643
+
605
644
  # if the dimension of inferred output column names is correct; use it
606
645
  if len(expected_output_cols_list) == len(output_df_columns_set):
607
- return expected_output_cols_list
646
+ return expected_output_cols_list, output_df_pd
608
647
  # otherwise, use the sklearn estimator's output
609
648
  else:
610
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
649
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
650
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
611
651
 
612
652
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
613
653
  @telemetry.send_api_usage_telemetry(
@@ -655,7 +695,7 @@ class LabelPropagation(BaseTransformer):
655
695
  drop_input_cols=self._drop_input_cols,
656
696
  expected_output_cols_type="float",
657
697
  )
658
- expected_output_cols = self._align_expected_output_names(
698
+ expected_output_cols, _ = self._align_expected_output(
659
699
  inference_method, dataset, expected_output_cols, output_cols_prefix
660
700
  )
661
701
 
@@ -723,7 +763,7 @@ class LabelPropagation(BaseTransformer):
723
763
  drop_input_cols=self._drop_input_cols,
724
764
  expected_output_cols_type="float",
725
765
  )
726
- expected_output_cols = self._align_expected_output_names(
766
+ expected_output_cols, _ = self._align_expected_output(
727
767
  inference_method, dataset, expected_output_cols, output_cols_prefix
728
768
  )
729
769
  elif isinstance(dataset, pd.DataFrame):
@@ -786,7 +826,7 @@ class LabelPropagation(BaseTransformer):
786
826
  drop_input_cols=self._drop_input_cols,
787
827
  expected_output_cols_type="float",
788
828
  )
789
- expected_output_cols = self._align_expected_output_names(
829
+ expected_output_cols, _ = self._align_expected_output(
790
830
  inference_method, dataset, expected_output_cols, output_cols_prefix
791
831
  )
792
832
 
@@ -851,7 +891,7 @@ class LabelPropagation(BaseTransformer):
851
891
  drop_input_cols = self._drop_input_cols,
852
892
  expected_output_cols_type="float",
853
893
  )
854
- expected_output_cols = self._align_expected_output_names(
894
+ expected_output_cols, _ = self._align_expected_output(
855
895
  inference_method, dataset, expected_output_cols, output_cols_prefix
856
896
  )
857
897
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -510,12 +507,23 @@ class LabelSpreading(BaseTransformer):
510
507
  autogenerated=self._autogenerated,
511
508
  subproject=_SUBPROJECT,
512
509
  )
513
- output_result, fitted_estimator = model_trainer.train_fit_predict(
514
- drop_input_cols=self._drop_input_cols,
515
- expected_output_cols_list=(
516
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
517
- ),
510
+ expected_output_cols = (
511
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
518
512
  )
513
+ if isinstance(dataset, DataFrame):
514
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
515
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
516
+ )
517
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
518
+ drop_input_cols=self._drop_input_cols,
519
+ expected_output_cols_list=expected_output_cols,
520
+ example_output_pd_df=example_output_pd_df,
521
+ )
522
+ else:
523
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
524
+ drop_input_cols=self._drop_input_cols,
525
+ expected_output_cols_list=expected_output_cols,
526
+ )
519
527
  self._sklearn_object = fitted_estimator
520
528
  self._is_fitted = True
521
529
  return output_result
@@ -538,6 +546,7 @@ class LabelSpreading(BaseTransformer):
538
546
  """
539
547
  self._infer_input_output_cols(dataset)
540
548
  super()._check_dataset_type(dataset)
549
+
541
550
  model_trainer = ModelTrainerBuilder.build_fit_transform(
542
551
  estimator=self._sklearn_object,
543
552
  dataset=dataset,
@@ -594,12 +603,41 @@ class LabelSpreading(BaseTransformer):
594
603
 
595
604
  return rv
596
605
 
597
- def _align_expected_output_names(
598
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
599
- ) -> List[str]:
606
+ def _align_expected_output(
607
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
608
+ ) -> Tuple[List[str], pd.DataFrame]:
609
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
610
+ and output dataframe with 1 line.
611
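+ If the method is fit_predict, run 2 lines of data by default; estimators exposing n_clusters, min_samples, or n_neighbors (together with n_samples) use that value as the number of lines instead.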
612
+ """
600
613
  # in case the inferred output column names dimension is different
601
614
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
602
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
615
+
616
+ # For the fit_predict method, MinCovDet and BayesianGaussianMixture require a minimum of 2 rows,
617
+ # so the default number of sampled rows is set to 2
618
+ num_examples = 2
619
+ statement_params = telemetry.get_function_usage_statement_params(
620
+ project=_PROJECT,
621
+ subproject=_SUBPROJECT,
622
+ function_name=telemetry.get_statement_params_full_func_name(
623
+ inspect.currentframe(), LabelSpreading.__class__.__name__
624
+ ),
625
+ api_calls=[Session.call],
626
+ custom_tags={"autogen": True} if self._autogenerated else None,
627
+ )
628
+ if output_cols_prefix == "fit_predict_":
629
+ if hasattr(self._sklearn_object, "n_clusters"):
630
+ # cluster classes such as BisectingKMeans require the number of examples to be >= n_clusters
631
+ num_examples = self._sklearn_object.n_clusters
632
+ elif hasattr(self._sklearn_object, "min_samples"):
633
+ # OPTICS defaults to min_samples=5, which requires at least 5 lines of data
634
+ num_examples = self._sklearn_object.min_samples
635
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
636
+ # LocalOutlierFactor expects n_neighbors <= n_samples
637
+ num_examples = self._sklearn_object.n_neighbors
638
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
639
+ else:
640
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
603
641
 
604
642
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
605
643
  # seen during the fit.
@@ -611,12 +649,14 @@ class LabelSpreading(BaseTransformer):
611
649
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
612
650
  if self.sample_weight_col:
613
651
  output_df_columns_set -= set(self.sample_weight_col)
652
+
614
653
  # if the dimension of inferred output column names is correct; use it
615
654
  if len(expected_output_cols_list) == len(output_df_columns_set):
616
- return expected_output_cols_list
655
+ return expected_output_cols_list, output_df_pd
617
656
  # otherwise, use the sklearn estimator's output
618
657
  else:
619
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
658
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
659
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
620
660
 
621
661
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
622
662
  @telemetry.send_api_usage_telemetry(
@@ -664,7 +704,7 @@ class LabelSpreading(BaseTransformer):
664
704
  drop_input_cols=self._drop_input_cols,
665
705
  expected_output_cols_type="float",
666
706
  )
667
- expected_output_cols = self._align_expected_output_names(
707
+ expected_output_cols, _ = self._align_expected_output(
668
708
  inference_method, dataset, expected_output_cols, output_cols_prefix
669
709
  )
670
710
 
@@ -732,7 +772,7 @@ class LabelSpreading(BaseTransformer):
732
772
  drop_input_cols=self._drop_input_cols,
733
773
  expected_output_cols_type="float",
734
774
  )
735
- expected_output_cols = self._align_expected_output_names(
775
+ expected_output_cols, _ = self._align_expected_output(
736
776
  inference_method, dataset, expected_output_cols, output_cols_prefix
737
777
  )
738
778
  elif isinstance(dataset, pd.DataFrame):
@@ -795,7 +835,7 @@ class LabelSpreading(BaseTransformer):
795
835
  drop_input_cols=self._drop_input_cols,
796
836
  expected_output_cols_type="float",
797
837
  )
798
- expected_output_cols = self._align_expected_output_names(
838
+ expected_output_cols, _ = self._align_expected_output(
799
839
  inference_method, dataset, expected_output_cols, output_cols_prefix
800
840
  )
801
841
 
@@ -860,7 +900,7 @@ class LabelSpreading(BaseTransformer):
860
900
  drop_input_cols = self._drop_input_cols,
861
901
  expected_output_cols_type="float",
862
902
  )
863
- expected_output_cols = self._align_expected_output_names(
903
+ expected_output_cols, _ = self._align_expected_output(
864
904
  inference_method, dataset, expected_output_cols, output_cols_prefix
865
905
  )
866
906
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -566,12 +563,23 @@ class LinearSVC(BaseTransformer):
566
563
  autogenerated=self._autogenerated,
567
564
  subproject=_SUBPROJECT,
568
565
  )
569
- output_result, fitted_estimator = model_trainer.train_fit_predict(
570
- drop_input_cols=self._drop_input_cols,
571
- expected_output_cols_list=(
572
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
573
- ),
566
+ expected_output_cols = (
567
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
574
568
  )
569
+ if isinstance(dataset, DataFrame):
570
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
571
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
572
+ )
573
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
574
+ drop_input_cols=self._drop_input_cols,
575
+ expected_output_cols_list=expected_output_cols,
576
+ example_output_pd_df=example_output_pd_df,
577
+ )
578
+ else:
579
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
580
+ drop_input_cols=self._drop_input_cols,
581
+ expected_output_cols_list=expected_output_cols,
582
+ )
575
583
  self._sklearn_object = fitted_estimator
576
584
  self._is_fitted = True
577
585
  return output_result
@@ -594,6 +602,7 @@ class LinearSVC(BaseTransformer):
594
602
  """
595
603
  self._infer_input_output_cols(dataset)
596
604
  super()._check_dataset_type(dataset)
605
+
597
606
  model_trainer = ModelTrainerBuilder.build_fit_transform(
598
607
  estimator=self._sklearn_object,
599
608
  dataset=dataset,
@@ -650,12 +659,41 @@ class LinearSVC(BaseTransformer):
650
659
 
651
660
  return rv
652
661
 
653
- def _align_expected_output_names(
654
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
655
- ) -> List[str]:
662
+ def _align_expected_output(
663
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
664
+ ) -> Tuple[List[str], pd.DataFrame]:
665
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
666
+ and output dataframe with 1 line.
667
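+ If the method is fit_predict, run 2 lines of data by default; estimators exposing n_clusters, min_samples, or n_neighbors (together with n_samples) use that value as the number of lines instead.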
668
+ """
656
669
  # in case the inferred output column names dimension is different
657
670
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
658
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
671
+
672
+ # For the fit_predict method, MinCovDet and BayesianGaussianMixture require a minimum of 2 rows,
673
+ # so the default number of sampled rows is set to 2
674
+ num_examples = 2
675
+ statement_params = telemetry.get_function_usage_statement_params(
676
+ project=_PROJECT,
677
+ subproject=_SUBPROJECT,
678
+ function_name=telemetry.get_statement_params_full_func_name(
679
+ inspect.currentframe(), LinearSVC.__class__.__name__
680
+ ),
681
+ api_calls=[Session.call],
682
+ custom_tags={"autogen": True} if self._autogenerated else None,
683
+ )
684
+ if output_cols_prefix == "fit_predict_":
685
+ if hasattr(self._sklearn_object, "n_clusters"):
686
+ # cluster classes such as BisectingKMeans require the number of examples to be >= n_clusters
687
+ num_examples = self._sklearn_object.n_clusters
688
+ elif hasattr(self._sklearn_object, "min_samples"):
689
+ # OPTICS defaults to min_samples=5, which requires at least 5 lines of data
690
+ num_examples = self._sklearn_object.min_samples
691
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
692
+ # LocalOutlierFactor expects n_neighbors <= n_samples
693
+ num_examples = self._sklearn_object.n_neighbors
694
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
695
+ else:
696
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
659
697
 
660
698
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
661
699
  # seen during the fit.
@@ -667,12 +705,14 @@ class LinearSVC(BaseTransformer):
667
705
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
668
706
  if self.sample_weight_col:
669
707
  output_df_columns_set -= set(self.sample_weight_col)
708
+
670
709
  # if the dimension of inferred output column names is correct; use it
671
710
  if len(expected_output_cols_list) == len(output_df_columns_set):
672
- return expected_output_cols_list
711
+ return expected_output_cols_list, output_df_pd
673
712
  # otherwise, use the sklearn estimator's output
674
713
  else:
675
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
714
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
715
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
676
716
 
677
717
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
678
718
  @telemetry.send_api_usage_telemetry(
@@ -718,7 +758,7 @@ class LinearSVC(BaseTransformer):
718
758
  drop_input_cols=self._drop_input_cols,
719
759
  expected_output_cols_type="float",
720
760
  )
721
- expected_output_cols = self._align_expected_output_names(
761
+ expected_output_cols, _ = self._align_expected_output(
722
762
  inference_method, dataset, expected_output_cols, output_cols_prefix
723
763
  )
724
764
 
@@ -784,7 +824,7 @@ class LinearSVC(BaseTransformer):
784
824
  drop_input_cols=self._drop_input_cols,
785
825
  expected_output_cols_type="float",
786
826
  )
787
- expected_output_cols = self._align_expected_output_names(
827
+ expected_output_cols, _ = self._align_expected_output(
788
828
  inference_method, dataset, expected_output_cols, output_cols_prefix
789
829
  )
790
830
  elif isinstance(dataset, pd.DataFrame):
@@ -849,7 +889,7 @@ class LinearSVC(BaseTransformer):
849
889
  drop_input_cols=self._drop_input_cols,
850
890
  expected_output_cols_type="float",
851
891
  )
852
- expected_output_cols = self._align_expected_output_names(
892
+ expected_output_cols, _ = self._align_expected_output(
853
893
  inference_method, dataset, expected_output_cols, output_cols_prefix
854
894
  )
855
895
 
@@ -914,7 +954,7 @@ class LinearSVC(BaseTransformer):
914
954
  drop_input_cols = self._drop_input_cols,
915
955
  expected_output_cols_type="float",
916
956
  )
917
- expected_output_cols = self._align_expected_output_names(
957
+ expected_output_cols, _ = self._align_expected_output(
918
958
  inference_method, dataset, expected_output_cols, output_cols_prefix
919
959
  )
920
960