snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (284) hide show
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -485,12 +482,23 @@ class RBFSampler(BaseTransformer):
485
482
  autogenerated=self._autogenerated,
486
483
  subproject=_SUBPROJECT,
487
484
  )
488
- output_result, fitted_estimator = model_trainer.train_fit_predict(
489
- drop_input_cols=self._drop_input_cols,
490
- expected_output_cols_list=(
491
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
492
- ),
485
+ expected_output_cols = (
486
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
493
487
  )
488
+ if isinstance(dataset, DataFrame):
489
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
490
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
491
+ )
492
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
493
+ drop_input_cols=self._drop_input_cols,
494
+ expected_output_cols_list=expected_output_cols,
495
+ example_output_pd_df=example_output_pd_df,
496
+ )
497
+ else:
498
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
499
+ drop_input_cols=self._drop_input_cols,
500
+ expected_output_cols_list=expected_output_cols,
501
+ )
494
502
  self._sklearn_object = fitted_estimator
495
503
  self._is_fitted = True
496
504
  return output_result
@@ -515,6 +523,7 @@ class RBFSampler(BaseTransformer):
515
523
  """
516
524
  self._infer_input_output_cols(dataset)
517
525
  super()._check_dataset_type(dataset)
526
+
518
527
  model_trainer = ModelTrainerBuilder.build_fit_transform(
519
528
  estimator=self._sklearn_object,
520
529
  dataset=dataset,
@@ -571,12 +580,41 @@ class RBFSampler(BaseTransformer):
571
580
 
572
581
  return rv
573
582
 
574
- def _align_expected_output_names(
575
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
576
- ) -> List[str]:
583
+ def _align_expected_output(
584
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
585
+ ) -> Tuple[List[str], pd.DataFrame]:
586
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
587
+ and output dataframe with 1 line.
588
+ If the method is fit_predict, run 2 lines of data.
589
+ """
577
590
  # in case the inferred output column names dimension is different
578
591
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
579
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
592
+
593
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
594
+ # so change the minimum of number of rows to 2
595
+ num_examples = 2
596
+ statement_params = telemetry.get_function_usage_statement_params(
597
+ project=_PROJECT,
598
+ subproject=_SUBPROJECT,
599
+ function_name=telemetry.get_statement_params_full_func_name(
600
+ inspect.currentframe(), RBFSampler.__class__.__name__
601
+ ),
602
+ api_calls=[Session.call],
603
+ custom_tags={"autogen": True} if self._autogenerated else None,
604
+ )
605
+ if output_cols_prefix == "fit_predict_":
606
+ if hasattr(self._sklearn_object, "n_clusters"):
607
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
608
+ num_examples = self._sklearn_object.n_clusters
609
+ elif hasattr(self._sklearn_object, "min_samples"):
610
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
611
+ num_examples = self._sklearn_object.min_samples
612
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
613
+ # LocalOutlierFactor expects n_neighbors <= n_samples
614
+ num_examples = self._sklearn_object.n_neighbors
615
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
616
+ else:
617
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
580
618
 
581
619
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
582
620
  # seen during the fit.
@@ -588,12 +626,14 @@ class RBFSampler(BaseTransformer):
588
626
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
589
627
  if self.sample_weight_col:
590
628
  output_df_columns_set -= set(self.sample_weight_col)
629
+
591
630
  # if the dimension of inferred output column names is correct; use it
592
631
  if len(expected_output_cols_list) == len(output_df_columns_set):
593
- return expected_output_cols_list
632
+ return expected_output_cols_list, output_df_pd
594
633
  # otherwise, use the sklearn estimator's output
595
634
  else:
596
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
635
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
636
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
597
637
 
598
638
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
599
639
  @telemetry.send_api_usage_telemetry(
@@ -639,7 +679,7 @@ class RBFSampler(BaseTransformer):
639
679
  drop_input_cols=self._drop_input_cols,
640
680
  expected_output_cols_type="float",
641
681
  )
642
- expected_output_cols = self._align_expected_output_names(
682
+ expected_output_cols, _ = self._align_expected_output(
643
683
  inference_method, dataset, expected_output_cols, output_cols_prefix
644
684
  )
645
685
 
@@ -705,7 +745,7 @@ class RBFSampler(BaseTransformer):
705
745
  drop_input_cols=self._drop_input_cols,
706
746
  expected_output_cols_type="float",
707
747
  )
708
- expected_output_cols = self._align_expected_output_names(
748
+ expected_output_cols, _ = self._align_expected_output(
709
749
  inference_method, dataset, expected_output_cols, output_cols_prefix
710
750
  )
711
751
  elif isinstance(dataset, pd.DataFrame):
@@ -768,7 +808,7 @@ class RBFSampler(BaseTransformer):
768
808
  drop_input_cols=self._drop_input_cols,
769
809
  expected_output_cols_type="float",
770
810
  )
771
- expected_output_cols = self._align_expected_output_names(
811
+ expected_output_cols, _ = self._align_expected_output(
772
812
  inference_method, dataset, expected_output_cols, output_cols_prefix
773
813
  )
774
814
 
@@ -833,7 +873,7 @@ class RBFSampler(BaseTransformer):
833
873
  drop_input_cols = self._drop_input_cols,
834
874
  expected_output_cols_type="float",
835
875
  )
836
- expected_output_cols = self._align_expected_output_names(
876
+ expected_output_cols, _ = self._align_expected_output(
837
877
  inference_method, dataset, expected_output_cols, output_cols_prefix
838
878
  )
839
879
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -483,12 +480,23 @@ class SkewedChi2Sampler(BaseTransformer):
483
480
  autogenerated=self._autogenerated,
484
481
  subproject=_SUBPROJECT,
485
482
  )
486
- output_result, fitted_estimator = model_trainer.train_fit_predict(
487
- drop_input_cols=self._drop_input_cols,
488
- expected_output_cols_list=(
489
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
490
- ),
483
+ expected_output_cols = (
484
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
491
485
  )
486
+ if isinstance(dataset, DataFrame):
487
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
488
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
489
+ )
490
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
491
+ drop_input_cols=self._drop_input_cols,
492
+ expected_output_cols_list=expected_output_cols,
493
+ example_output_pd_df=example_output_pd_df,
494
+ )
495
+ else:
496
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
497
+ drop_input_cols=self._drop_input_cols,
498
+ expected_output_cols_list=expected_output_cols,
499
+ )
492
500
  self._sklearn_object = fitted_estimator
493
501
  self._is_fitted = True
494
502
  return output_result
@@ -513,6 +521,7 @@ class SkewedChi2Sampler(BaseTransformer):
513
521
  """
514
522
  self._infer_input_output_cols(dataset)
515
523
  super()._check_dataset_type(dataset)
524
+
516
525
  model_trainer = ModelTrainerBuilder.build_fit_transform(
517
526
  estimator=self._sklearn_object,
518
527
  dataset=dataset,
@@ -569,12 +578,41 @@ class SkewedChi2Sampler(BaseTransformer):
569
578
 
570
579
  return rv
571
580
 
572
- def _align_expected_output_names(
573
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
574
- ) -> List[str]:
581
+ def _align_expected_output(
582
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
583
+ ) -> Tuple[List[str], pd.DataFrame]:
584
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
585
+ and output dataframe with 1 line.
586
+ If the method is fit_predict, run 2 lines of data.
587
+ """
575
588
  # in case the inferred output column names dimension is different
576
589
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
577
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
590
+
591
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
592
+ # so change the minimum of number of rows to 2
593
+ num_examples = 2
594
+ statement_params = telemetry.get_function_usage_statement_params(
595
+ project=_PROJECT,
596
+ subproject=_SUBPROJECT,
597
+ function_name=telemetry.get_statement_params_full_func_name(
598
+ inspect.currentframe(), SkewedChi2Sampler.__class__.__name__
599
+ ),
600
+ api_calls=[Session.call],
601
+ custom_tags={"autogen": True} if self._autogenerated else None,
602
+ )
603
+ if output_cols_prefix == "fit_predict_":
604
+ if hasattr(self._sklearn_object, "n_clusters"):
605
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
606
+ num_examples = self._sklearn_object.n_clusters
607
+ elif hasattr(self._sklearn_object, "min_samples"):
608
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
609
+ num_examples = self._sklearn_object.min_samples
610
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
611
+ # LocalOutlierFactor expects n_neighbors <= n_samples
612
+ num_examples = self._sklearn_object.n_neighbors
613
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
614
+ else:
615
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
578
616
 
579
617
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
580
618
  # seen during the fit.
@@ -586,12 +624,14 @@ class SkewedChi2Sampler(BaseTransformer):
586
624
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
587
625
  if self.sample_weight_col:
588
626
  output_df_columns_set -= set(self.sample_weight_col)
627
+
589
628
  # if the dimension of inferred output column names is correct; use it
590
629
  if len(expected_output_cols_list) == len(output_df_columns_set):
591
- return expected_output_cols_list
630
+ return expected_output_cols_list, output_df_pd
592
631
  # otherwise, use the sklearn estimator's output
593
632
  else:
594
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
633
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
634
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
595
635
 
596
636
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
597
637
  @telemetry.send_api_usage_telemetry(
@@ -637,7 +677,7 @@ class SkewedChi2Sampler(BaseTransformer):
637
677
  drop_input_cols=self._drop_input_cols,
638
678
  expected_output_cols_type="float",
639
679
  )
640
- expected_output_cols = self._align_expected_output_names(
680
+ expected_output_cols, _ = self._align_expected_output(
641
681
  inference_method, dataset, expected_output_cols, output_cols_prefix
642
682
  )
643
683
 
@@ -703,7 +743,7 @@ class SkewedChi2Sampler(BaseTransformer):
703
743
  drop_input_cols=self._drop_input_cols,
704
744
  expected_output_cols_type="float",
705
745
  )
706
- expected_output_cols = self._align_expected_output_names(
746
+ expected_output_cols, _ = self._align_expected_output(
707
747
  inference_method, dataset, expected_output_cols, output_cols_prefix
708
748
  )
709
749
  elif isinstance(dataset, pd.DataFrame):
@@ -766,7 +806,7 @@ class SkewedChi2Sampler(BaseTransformer):
766
806
  drop_input_cols=self._drop_input_cols,
767
807
  expected_output_cols_type="float",
768
808
  )
769
- expected_output_cols = self._align_expected_output_names(
809
+ expected_output_cols, _ = self._align_expected_output(
770
810
  inference_method, dataset, expected_output_cols, output_cols_prefix
771
811
  )
772
812
 
@@ -831,7 +871,7 @@ class SkewedChi2Sampler(BaseTransformer):
831
871
  drop_input_cols = self._drop_input_cols,
832
872
  expected_output_cols_type="float",
833
873
  )
834
- expected_output_cols = self._align_expected_output_names(
874
+ expected_output_cols, _ = self._align_expected_output(
835
875
  inference_method, dataset, expected_output_cols, output_cols_prefix
836
876
  )
837
877
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -517,12 +514,23 @@ class KernelRidge(BaseTransformer):
517
514
  autogenerated=self._autogenerated,
518
515
  subproject=_SUBPROJECT,
519
516
  )
520
- output_result, fitted_estimator = model_trainer.train_fit_predict(
521
- drop_input_cols=self._drop_input_cols,
522
- expected_output_cols_list=(
523
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
524
- ),
517
+ expected_output_cols = (
518
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
525
519
  )
520
+ if isinstance(dataset, DataFrame):
521
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
522
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
523
+ )
524
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
525
+ drop_input_cols=self._drop_input_cols,
526
+ expected_output_cols_list=expected_output_cols,
527
+ example_output_pd_df=example_output_pd_df,
528
+ )
529
+ else:
530
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
531
+ drop_input_cols=self._drop_input_cols,
532
+ expected_output_cols_list=expected_output_cols,
533
+ )
526
534
  self._sklearn_object = fitted_estimator
527
535
  self._is_fitted = True
528
536
  return output_result
@@ -545,6 +553,7 @@ class KernelRidge(BaseTransformer):
545
553
  """
546
554
  self._infer_input_output_cols(dataset)
547
555
  super()._check_dataset_type(dataset)
556
+
548
557
  model_trainer = ModelTrainerBuilder.build_fit_transform(
549
558
  estimator=self._sklearn_object,
550
559
  dataset=dataset,
@@ -601,12 +610,41 @@ class KernelRidge(BaseTransformer):
601
610
 
602
611
  return rv
603
612
 
604
- def _align_expected_output_names(
605
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
606
- ) -> List[str]:
613
+ def _align_expected_output(
614
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
615
+ ) -> Tuple[List[str], pd.DataFrame]:
616
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
617
+ and output dataframe with 1 line.
618
+ If the method is fit_predict, run 2 lines of data.
619
+ """
607
620
  # in case the inferred output column names dimension is different
608
621
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
609
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
622
+
623
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
624
+ # so change the minimum of number of rows to 2
625
+ num_examples = 2
626
+ statement_params = telemetry.get_function_usage_statement_params(
627
+ project=_PROJECT,
628
+ subproject=_SUBPROJECT,
629
+ function_name=telemetry.get_statement_params_full_func_name(
630
+ inspect.currentframe(), KernelRidge.__class__.__name__
631
+ ),
632
+ api_calls=[Session.call],
633
+ custom_tags={"autogen": True} if self._autogenerated else None,
634
+ )
635
+ if output_cols_prefix == "fit_predict_":
636
+ if hasattr(self._sklearn_object, "n_clusters"):
637
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
638
+ num_examples = self._sklearn_object.n_clusters
639
+ elif hasattr(self._sklearn_object, "min_samples"):
640
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
641
+ num_examples = self._sklearn_object.min_samples
642
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
643
+ # LocalOutlierFactor expects n_neighbors <= n_samples
644
+ num_examples = self._sklearn_object.n_neighbors
645
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
646
+ else:
647
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
610
648
 
611
649
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
612
650
  # seen during the fit.
@@ -618,12 +656,14 @@ class KernelRidge(BaseTransformer):
618
656
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
619
657
  if self.sample_weight_col:
620
658
  output_df_columns_set -= set(self.sample_weight_col)
659
+
621
660
  # if the dimension of inferred output column names is correct; use it
622
661
  if len(expected_output_cols_list) == len(output_df_columns_set):
623
- return expected_output_cols_list
662
+ return expected_output_cols_list, output_df_pd
624
663
  # otherwise, use the sklearn estimator's output
625
664
  else:
626
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
665
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
666
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
627
667
 
628
668
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
629
669
  @telemetry.send_api_usage_telemetry(
@@ -669,7 +709,7 @@ class KernelRidge(BaseTransformer):
669
709
  drop_input_cols=self._drop_input_cols,
670
710
  expected_output_cols_type="float",
671
711
  )
672
- expected_output_cols = self._align_expected_output_names(
712
+ expected_output_cols, _ = self._align_expected_output(
673
713
  inference_method, dataset, expected_output_cols, output_cols_prefix
674
714
  )
675
715
 
@@ -735,7 +775,7 @@ class KernelRidge(BaseTransformer):
735
775
  drop_input_cols=self._drop_input_cols,
736
776
  expected_output_cols_type="float",
737
777
  )
738
- expected_output_cols = self._align_expected_output_names(
778
+ expected_output_cols, _ = self._align_expected_output(
739
779
  inference_method, dataset, expected_output_cols, output_cols_prefix
740
780
  )
741
781
  elif isinstance(dataset, pd.DataFrame):
@@ -798,7 +838,7 @@ class KernelRidge(BaseTransformer):
798
838
  drop_input_cols=self._drop_input_cols,
799
839
  expected_output_cols_type="float",
800
840
  )
801
- expected_output_cols = self._align_expected_output_names(
841
+ expected_output_cols, _ = self._align_expected_output(
802
842
  inference_method, dataset, expected_output_cols, output_cols_prefix
803
843
  )
804
844
 
@@ -863,7 +903,7 @@ class KernelRidge(BaseTransformer):
863
903
  drop_input_cols = self._drop_input_cols,
864
904
  expected_output_cols_type="float",
865
905
  )
866
- expected_output_cols = self._align_expected_output_names(
906
+ expected_output_cols, _ = self._align_expected_output(
867
907
  inference_method, dataset, expected_output_cols, output_cols_prefix
868
908
  )
869
909
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -506,12 +503,23 @@ class LGBMClassifier(BaseTransformer):
506
503
  autogenerated=self._autogenerated,
507
504
  subproject=_SUBPROJECT,
508
505
  )
509
- output_result, fitted_estimator = model_trainer.train_fit_predict(
510
- drop_input_cols=self._drop_input_cols,
511
- expected_output_cols_list=(
512
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
513
- ),
506
+ expected_output_cols = (
507
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
514
508
  )
509
+ if isinstance(dataset, DataFrame):
510
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
511
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
512
+ )
513
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
514
+ drop_input_cols=self._drop_input_cols,
515
+ expected_output_cols_list=expected_output_cols,
516
+ example_output_pd_df=example_output_pd_df,
517
+ )
518
+ else:
519
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
520
+ drop_input_cols=self._drop_input_cols,
521
+ expected_output_cols_list=expected_output_cols,
522
+ )
515
523
  self._sklearn_object = fitted_estimator
516
524
  self._is_fitted = True
517
525
  return output_result
@@ -534,6 +542,7 @@ class LGBMClassifier(BaseTransformer):
534
542
  """
535
543
  self._infer_input_output_cols(dataset)
536
544
  super()._check_dataset_type(dataset)
545
+
537
546
  model_trainer = ModelTrainerBuilder.build_fit_transform(
538
547
  estimator=self._sklearn_object,
539
548
  dataset=dataset,
@@ -590,12 +599,41 @@ class LGBMClassifier(BaseTransformer):
590
599
 
591
600
  return rv
592
601
 
593
- def _align_expected_output_names(
594
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
595
- ) -> List[str]:
602
+ def _align_expected_output(
603
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
604
+ ) -> Tuple[List[str], pd.DataFrame]:
605
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
606
+ and output dataframe with 1 line.
607
+ If the method is fit_predict, run 2 lines of data.
608
+ """
596
609
  # in case the inferred output column names dimension is different
597
610
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
598
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
611
+
612
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
613
+ # so change the minimum of number of rows to 2
614
+ num_examples = 2
615
+ statement_params = telemetry.get_function_usage_statement_params(
616
+ project=_PROJECT,
617
+ subproject=_SUBPROJECT,
618
+ function_name=telemetry.get_statement_params_full_func_name(
619
+ inspect.currentframe(), LGBMClassifier.__class__.__name__
620
+ ),
621
+ api_calls=[Session.call],
622
+ custom_tags={"autogen": True} if self._autogenerated else None,
623
+ )
624
+ if output_cols_prefix == "fit_predict_":
625
+ if hasattr(self._sklearn_object, "n_clusters"):
626
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
627
+ num_examples = self._sklearn_object.n_clusters
628
+ elif hasattr(self._sklearn_object, "min_samples"):
629
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
630
+ num_examples = self._sklearn_object.min_samples
631
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
632
+ # LocalOutlierFactor expects n_neighbors <= n_samples
633
+ num_examples = self._sklearn_object.n_neighbors
634
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
635
+ else:
636
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
599
637
 
600
638
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
601
639
  # seen during the fit.
@@ -607,12 +645,14 @@ class LGBMClassifier(BaseTransformer):
607
645
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
608
646
  if self.sample_weight_col:
609
647
  output_df_columns_set -= set(self.sample_weight_col)
648
+
610
649
  # if the dimension of inferred output column names is correct; use it
611
650
  if len(expected_output_cols_list) == len(output_df_columns_set):
612
- return expected_output_cols_list
651
+ return expected_output_cols_list, output_df_pd
613
652
  # otherwise, use the sklearn estimator's output
614
653
  else:
615
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
654
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
655
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
616
656
 
617
657
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
618
658
  @telemetry.send_api_usage_telemetry(
@@ -660,7 +700,7 @@ class LGBMClassifier(BaseTransformer):
660
700
  drop_input_cols=self._drop_input_cols,
661
701
  expected_output_cols_type="float",
662
702
  )
663
- expected_output_cols = self._align_expected_output_names(
703
+ expected_output_cols, _ = self._align_expected_output(
664
704
  inference_method, dataset, expected_output_cols, output_cols_prefix
665
705
  )
666
706
 
@@ -728,7 +768,7 @@ class LGBMClassifier(BaseTransformer):
728
768
  drop_input_cols=self._drop_input_cols,
729
769
  expected_output_cols_type="float",
730
770
  )
731
- expected_output_cols = self._align_expected_output_names(
771
+ expected_output_cols, _ = self._align_expected_output(
732
772
  inference_method, dataset, expected_output_cols, output_cols_prefix
733
773
  )
734
774
  elif isinstance(dataset, pd.DataFrame):
@@ -791,7 +831,7 @@ class LGBMClassifier(BaseTransformer):
791
831
  drop_input_cols=self._drop_input_cols,
792
832
  expected_output_cols_type="float",
793
833
  )
794
- expected_output_cols = self._align_expected_output_names(
834
+ expected_output_cols, _ = self._align_expected_output(
795
835
  inference_method, dataset, expected_output_cols, output_cols_prefix
796
836
  )
797
837
 
@@ -856,7 +896,7 @@ class LGBMClassifier(BaseTransformer):
856
896
  drop_input_cols = self._drop_input_cols,
857
897
  expected_output_cols_type="float",
858
898
  )
859
- expected_output_cols = self._align_expected_output_names(
899
+ expected_output_cols, _ = self._align_expected_output(
860
900
  inference_method, dataset, expected_output_cols, output_cols_prefix
861
901
  )
862
902