snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284) hide show
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -492,12 +489,23 @@ class OneVsRestClassifier(BaseTransformer):
492
489
  autogenerated=self._autogenerated,
493
490
  subproject=_SUBPROJECT,
494
491
  )
495
- output_result, fitted_estimator = model_trainer.train_fit_predict(
496
- drop_input_cols=self._drop_input_cols,
497
- expected_output_cols_list=(
498
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
499
- ),
492
+ expected_output_cols = (
493
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
500
494
  )
495
+ if isinstance(dataset, DataFrame):
496
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
497
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
498
+ )
499
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
500
+ drop_input_cols=self._drop_input_cols,
501
+ expected_output_cols_list=expected_output_cols,
502
+ example_output_pd_df=example_output_pd_df,
503
+ )
504
+ else:
505
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
506
+ drop_input_cols=self._drop_input_cols,
507
+ expected_output_cols_list=expected_output_cols,
508
+ )
501
509
  self._sklearn_object = fitted_estimator
502
510
  self._is_fitted = True
503
511
  return output_result
@@ -520,6 +528,7 @@ class OneVsRestClassifier(BaseTransformer):
520
528
  """
521
529
  self._infer_input_output_cols(dataset)
522
530
  super()._check_dataset_type(dataset)
531
+
523
532
  model_trainer = ModelTrainerBuilder.build_fit_transform(
524
533
  estimator=self._sklearn_object,
525
534
  dataset=dataset,
@@ -576,12 +585,41 @@ class OneVsRestClassifier(BaseTransformer):
576
585
 
577
586
  return rv
578
587
 
579
- def _align_expected_output_names(
580
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
581
- ) -> List[str]:
588
+ def _align_expected_output(
589
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
590
+ ) -> Tuple[List[str], pd.DataFrame]:
591
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
592
+ and output dataframe with 1 line.
593
+ If the method is fit_predict, run 2 lines of data.
594
+ """
582
595
  # in case the inferred output column names dimension is different
583
596
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
584
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
597
+
598
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
599
+ # so change the minimum of number of rows to 2
600
+ num_examples = 2
601
+ statement_params = telemetry.get_function_usage_statement_params(
602
+ project=_PROJECT,
603
+ subproject=_SUBPROJECT,
604
+ function_name=telemetry.get_statement_params_full_func_name(
605
+ inspect.currentframe(), OneVsRestClassifier.__class__.__name__
606
+ ),
607
+ api_calls=[Session.call],
608
+ custom_tags={"autogen": True} if self._autogenerated else None,
609
+ )
610
+ if output_cols_prefix == "fit_predict_":
611
+ if hasattr(self._sklearn_object, "n_clusters"):
612
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
613
+ num_examples = self._sklearn_object.n_clusters
614
+ elif hasattr(self._sklearn_object, "min_samples"):
615
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
616
+ num_examples = self._sklearn_object.min_samples
617
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
618
+ # LocalOutlierFactor expects n_neighbors <= n_samples
619
+ num_examples = self._sklearn_object.n_neighbors
620
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
621
+ else:
622
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
585
623
 
586
624
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
587
625
  # seen during the fit.
@@ -593,12 +631,14 @@ class OneVsRestClassifier(BaseTransformer):
593
631
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
594
632
  if self.sample_weight_col:
595
633
  output_df_columns_set -= set(self.sample_weight_col)
634
+
596
635
  # if the dimension of inferred output column names is correct; use it
597
636
  if len(expected_output_cols_list) == len(output_df_columns_set):
598
- return expected_output_cols_list
637
+ return expected_output_cols_list, output_df_pd
599
638
  # otherwise, use the sklearn estimator's output
600
639
  else:
601
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
640
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
641
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
602
642
 
603
643
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
604
644
  @telemetry.send_api_usage_telemetry(
@@ -646,7 +686,7 @@ class OneVsRestClassifier(BaseTransformer):
646
686
  drop_input_cols=self._drop_input_cols,
647
687
  expected_output_cols_type="float",
648
688
  )
649
- expected_output_cols = self._align_expected_output_names(
689
+ expected_output_cols, _ = self._align_expected_output(
650
690
  inference_method, dataset, expected_output_cols, output_cols_prefix
651
691
  )
652
692
 
@@ -714,7 +754,7 @@ class OneVsRestClassifier(BaseTransformer):
714
754
  drop_input_cols=self._drop_input_cols,
715
755
  expected_output_cols_type="float",
716
756
  )
717
- expected_output_cols = self._align_expected_output_names(
757
+ expected_output_cols, _ = self._align_expected_output(
718
758
  inference_method, dataset, expected_output_cols, output_cols_prefix
719
759
  )
720
760
  elif isinstance(dataset, pd.DataFrame):
@@ -779,7 +819,7 @@ class OneVsRestClassifier(BaseTransformer):
779
819
  drop_input_cols=self._drop_input_cols,
780
820
  expected_output_cols_type="float",
781
821
  )
782
- expected_output_cols = self._align_expected_output_names(
822
+ expected_output_cols, _ = self._align_expected_output(
783
823
  inference_method, dataset, expected_output_cols, output_cols_prefix
784
824
  )
785
825
 
@@ -844,7 +884,7 @@ class OneVsRestClassifier(BaseTransformer):
844
884
  drop_input_cols = self._drop_input_cols,
845
885
  expected_output_cols_type="float",
846
886
  )
847
- expected_output_cols = self._align_expected_output_names(
887
+ expected_output_cols, _ = self._align_expected_output(
848
888
  inference_method, dataset, expected_output_cols, output_cols_prefix
849
889
  )
850
890
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -495,12 +492,23 @@ class OutputCodeClassifier(BaseTransformer):
495
492
  autogenerated=self._autogenerated,
496
493
  subproject=_SUBPROJECT,
497
494
  )
498
- output_result, fitted_estimator = model_trainer.train_fit_predict(
499
- drop_input_cols=self._drop_input_cols,
500
- expected_output_cols_list=(
501
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
502
- ),
495
+ expected_output_cols = (
496
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
503
497
  )
498
+ if isinstance(dataset, DataFrame):
499
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
500
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
501
+ )
502
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
503
+ drop_input_cols=self._drop_input_cols,
504
+ expected_output_cols_list=expected_output_cols,
505
+ example_output_pd_df=example_output_pd_df,
506
+ )
507
+ else:
508
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
509
+ drop_input_cols=self._drop_input_cols,
510
+ expected_output_cols_list=expected_output_cols,
511
+ )
504
512
  self._sklearn_object = fitted_estimator
505
513
  self._is_fitted = True
506
514
  return output_result
@@ -523,6 +531,7 @@ class OutputCodeClassifier(BaseTransformer):
523
531
  """
524
532
  self._infer_input_output_cols(dataset)
525
533
  super()._check_dataset_type(dataset)
534
+
526
535
  model_trainer = ModelTrainerBuilder.build_fit_transform(
527
536
  estimator=self._sklearn_object,
528
537
  dataset=dataset,
@@ -579,12 +588,41 @@ class OutputCodeClassifier(BaseTransformer):
579
588
 
580
589
  return rv
581
590
 
582
- def _align_expected_output_names(
583
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
584
- ) -> List[str]:
591
+ def _align_expected_output(
592
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
593
+ ) -> Tuple[List[str], pd.DataFrame]:
594
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
595
+ and output dataframe with 1 line.
596
+ If the method is fit_predict, run 2 lines of data.
597
+ """
585
598
  # in case the inferred output column names dimension is different
586
599
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
587
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
600
+
601
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
602
+ # so change the minimum of number of rows to 2
603
+ num_examples = 2
604
+ statement_params = telemetry.get_function_usage_statement_params(
605
+ project=_PROJECT,
606
+ subproject=_SUBPROJECT,
607
+ function_name=telemetry.get_statement_params_full_func_name(
608
+ inspect.currentframe(), OutputCodeClassifier.__class__.__name__
609
+ ),
610
+ api_calls=[Session.call],
611
+ custom_tags={"autogen": True} if self._autogenerated else None,
612
+ )
613
+ if output_cols_prefix == "fit_predict_":
614
+ if hasattr(self._sklearn_object, "n_clusters"):
615
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
616
+ num_examples = self._sklearn_object.n_clusters
617
+ elif hasattr(self._sklearn_object, "min_samples"):
618
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
619
+ num_examples = self._sklearn_object.min_samples
620
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
621
+ # LocalOutlierFactor expects n_neighbors <= n_samples
622
+ num_examples = self._sklearn_object.n_neighbors
623
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
624
+ else:
625
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
588
626
 
589
627
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
590
628
  # seen during the fit.
@@ -596,12 +634,14 @@ class OutputCodeClassifier(BaseTransformer):
596
634
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
597
635
  if self.sample_weight_col:
598
636
  output_df_columns_set -= set(self.sample_weight_col)
637
+
599
638
  # if the dimension of inferred output column names is correct; use it
600
639
  if len(expected_output_cols_list) == len(output_df_columns_set):
601
- return expected_output_cols_list
640
+ return expected_output_cols_list, output_df_pd
602
641
  # otherwise, use the sklearn estimator's output
603
642
  else:
604
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
643
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
644
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
605
645
 
606
646
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
607
647
  @telemetry.send_api_usage_telemetry(
@@ -647,7 +687,7 @@ class OutputCodeClassifier(BaseTransformer):
647
687
  drop_input_cols=self._drop_input_cols,
648
688
  expected_output_cols_type="float",
649
689
  )
650
- expected_output_cols = self._align_expected_output_names(
690
+ expected_output_cols, _ = self._align_expected_output(
651
691
  inference_method, dataset, expected_output_cols, output_cols_prefix
652
692
  )
653
693
 
@@ -713,7 +753,7 @@ class OutputCodeClassifier(BaseTransformer):
713
753
  drop_input_cols=self._drop_input_cols,
714
754
  expected_output_cols_type="float",
715
755
  )
716
- expected_output_cols = self._align_expected_output_names(
756
+ expected_output_cols, _ = self._align_expected_output(
717
757
  inference_method, dataset, expected_output_cols, output_cols_prefix
718
758
  )
719
759
  elif isinstance(dataset, pd.DataFrame):
@@ -776,7 +816,7 @@ class OutputCodeClassifier(BaseTransformer):
776
816
  drop_input_cols=self._drop_input_cols,
777
817
  expected_output_cols_type="float",
778
818
  )
779
- expected_output_cols = self._align_expected_output_names(
819
+ expected_output_cols, _ = self._align_expected_output(
780
820
  inference_method, dataset, expected_output_cols, output_cols_prefix
781
821
  )
782
822
 
@@ -841,7 +881,7 @@ class OutputCodeClassifier(BaseTransformer):
841
881
  drop_input_cols = self._drop_input_cols,
842
882
  expected_output_cols_type="float",
843
883
  )
844
- expected_output_cols = self._align_expected_output_names(
884
+ expected_output_cols, _ = self._align_expected_output(
845
885
  inference_method, dataset, expected_output_cols, output_cols_prefix
846
886
  )
847
887
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -495,12 +492,23 @@ class BernoulliNB(BaseTransformer):
495
492
  autogenerated=self._autogenerated,
496
493
  subproject=_SUBPROJECT,
497
494
  )
498
- output_result, fitted_estimator = model_trainer.train_fit_predict(
499
- drop_input_cols=self._drop_input_cols,
500
- expected_output_cols_list=(
501
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
502
- ),
495
+ expected_output_cols = (
496
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
503
497
  )
498
+ if isinstance(dataset, DataFrame):
499
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
500
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
501
+ )
502
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
503
+ drop_input_cols=self._drop_input_cols,
504
+ expected_output_cols_list=expected_output_cols,
505
+ example_output_pd_df=example_output_pd_df,
506
+ )
507
+ else:
508
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
509
+ drop_input_cols=self._drop_input_cols,
510
+ expected_output_cols_list=expected_output_cols,
511
+ )
504
512
  self._sklearn_object = fitted_estimator
505
513
  self._is_fitted = True
506
514
  return output_result
@@ -523,6 +531,7 @@ class BernoulliNB(BaseTransformer):
523
531
  """
524
532
  self._infer_input_output_cols(dataset)
525
533
  super()._check_dataset_type(dataset)
534
+
526
535
  model_trainer = ModelTrainerBuilder.build_fit_transform(
527
536
  estimator=self._sklearn_object,
528
537
  dataset=dataset,
@@ -579,12 +588,41 @@ class BernoulliNB(BaseTransformer):
579
588
 
580
589
  return rv
581
590
 
582
- def _align_expected_output_names(
583
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
584
- ) -> List[str]:
591
+ def _align_expected_output(
592
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
593
+ ) -> Tuple[List[str], pd.DataFrame]:
594
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
595
+ and output dataframe with 1 line.
596
+ If the method is fit_predict, run 2 lines of data.
597
+ """
585
598
  # in case the inferred output column names dimension is different
586
599
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
587
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
600
+
601
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
602
+ # so change the minimum of number of rows to 2
603
+ num_examples = 2
604
+ statement_params = telemetry.get_function_usage_statement_params(
605
+ project=_PROJECT,
606
+ subproject=_SUBPROJECT,
607
+ function_name=telemetry.get_statement_params_full_func_name(
608
+ inspect.currentframe(), BernoulliNB.__class__.__name__
609
+ ),
610
+ api_calls=[Session.call],
611
+ custom_tags={"autogen": True} if self._autogenerated else None,
612
+ )
613
+ if output_cols_prefix == "fit_predict_":
614
+ if hasattr(self._sklearn_object, "n_clusters"):
615
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
616
+ num_examples = self._sklearn_object.n_clusters
617
+ elif hasattr(self._sklearn_object, "min_samples"):
618
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
619
+ num_examples = self._sklearn_object.min_samples
620
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
621
+ # LocalOutlierFactor expects n_neighbors <= n_samples
622
+ num_examples = self._sklearn_object.n_neighbors
623
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
624
+ else:
625
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
588
626
 
589
627
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
590
628
  # seen during the fit.
@@ -596,12 +634,14 @@ class BernoulliNB(BaseTransformer):
596
634
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
597
635
  if self.sample_weight_col:
598
636
  output_df_columns_set -= set(self.sample_weight_col)
637
+
599
638
  # if the dimension of inferred output column names is correct; use it
600
639
  if len(expected_output_cols_list) == len(output_df_columns_set):
601
- return expected_output_cols_list
640
+ return expected_output_cols_list, output_df_pd
602
641
  # otherwise, use the sklearn estimator's output
603
642
  else:
604
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
643
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
644
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
605
645
 
606
646
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
607
647
  @telemetry.send_api_usage_telemetry(
@@ -649,7 +689,7 @@ class BernoulliNB(BaseTransformer):
649
689
  drop_input_cols=self._drop_input_cols,
650
690
  expected_output_cols_type="float",
651
691
  )
652
- expected_output_cols = self._align_expected_output_names(
692
+ expected_output_cols, _ = self._align_expected_output(
653
693
  inference_method, dataset, expected_output_cols, output_cols_prefix
654
694
  )
655
695
 
@@ -717,7 +757,7 @@ class BernoulliNB(BaseTransformer):
717
757
  drop_input_cols=self._drop_input_cols,
718
758
  expected_output_cols_type="float",
719
759
  )
720
- expected_output_cols = self._align_expected_output_names(
760
+ expected_output_cols, _ = self._align_expected_output(
721
761
  inference_method, dataset, expected_output_cols, output_cols_prefix
722
762
  )
723
763
  elif isinstance(dataset, pd.DataFrame):
@@ -780,7 +820,7 @@ class BernoulliNB(BaseTransformer):
780
820
  drop_input_cols=self._drop_input_cols,
781
821
  expected_output_cols_type="float",
782
822
  )
783
- expected_output_cols = self._align_expected_output_names(
823
+ expected_output_cols, _ = self._align_expected_output(
784
824
  inference_method, dataset, expected_output_cols, output_cols_prefix
785
825
  )
786
826
 
@@ -845,7 +885,7 @@ class BernoulliNB(BaseTransformer):
845
885
  drop_input_cols = self._drop_input_cols,
846
886
  expected_output_cols_type="float",
847
887
  )
848
- expected_output_cols = self._align_expected_output_names(
888
+ expected_output_cols, _ = self._align_expected_output(
849
889
  inference_method, dataset, expected_output_cols, output_cols_prefix
850
890
  )
851
891
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -501,12 +498,23 @@ class CategoricalNB(BaseTransformer):
501
498
  autogenerated=self._autogenerated,
502
499
  subproject=_SUBPROJECT,
503
500
  )
504
- output_result, fitted_estimator = model_trainer.train_fit_predict(
505
- drop_input_cols=self._drop_input_cols,
506
- expected_output_cols_list=(
507
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
508
- ),
501
+ expected_output_cols = (
502
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
509
503
  )
504
+ if isinstance(dataset, DataFrame):
505
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
506
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
507
+ )
508
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
509
+ drop_input_cols=self._drop_input_cols,
510
+ expected_output_cols_list=expected_output_cols,
511
+ example_output_pd_df=example_output_pd_df,
512
+ )
513
+ else:
514
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
515
+ drop_input_cols=self._drop_input_cols,
516
+ expected_output_cols_list=expected_output_cols,
517
+ )
510
518
  self._sklearn_object = fitted_estimator
511
519
  self._is_fitted = True
512
520
  return output_result
@@ -529,6 +537,7 @@ class CategoricalNB(BaseTransformer):
529
537
  """
530
538
  self._infer_input_output_cols(dataset)
531
539
  super()._check_dataset_type(dataset)
540
+
532
541
  model_trainer = ModelTrainerBuilder.build_fit_transform(
533
542
  estimator=self._sklearn_object,
534
543
  dataset=dataset,
@@ -585,12 +594,41 @@ class CategoricalNB(BaseTransformer):
585
594
 
586
595
  return rv
587
596
 
588
- def _align_expected_output_names(
589
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
590
- ) -> List[str]:
597
+ def _align_expected_output(
598
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
599
+ ) -> Tuple[List[str], pd.DataFrame]:
600
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
601
+ and output dataframe with 1 line.
602
+ If the method is fit_predict, run 2 lines of data.
603
+ """
591
604
  # in case the inferred output column names dimension is different
592
605
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
593
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
606
+
607
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
608
+ # so change the minimum of number of rows to 2
609
+ num_examples = 2
610
+ statement_params = telemetry.get_function_usage_statement_params(
611
+ project=_PROJECT,
612
+ subproject=_SUBPROJECT,
613
+ function_name=telemetry.get_statement_params_full_func_name(
614
+ inspect.currentframe(), CategoricalNB.__class__.__name__
615
+ ),
616
+ api_calls=[Session.call],
617
+ custom_tags={"autogen": True} if self._autogenerated else None,
618
+ )
619
+ if output_cols_prefix == "fit_predict_":
620
+ if hasattr(self._sklearn_object, "n_clusters"):
621
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
622
+ num_examples = self._sklearn_object.n_clusters
623
+ elif hasattr(self._sklearn_object, "min_samples"):
624
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
625
+ num_examples = self._sklearn_object.min_samples
626
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
627
+ # LocalOutlierFactor expects n_neighbors <= n_samples
628
+ num_examples = self._sklearn_object.n_neighbors
629
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
630
+ else:
631
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
594
632
 
595
633
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
596
634
  # seen during the fit.
@@ -602,12 +640,14 @@ class CategoricalNB(BaseTransformer):
602
640
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
603
641
  if self.sample_weight_col:
604
642
  output_df_columns_set -= set(self.sample_weight_col)
643
+
605
644
  # if the dimension of inferred output column names is correct; use it
606
645
  if len(expected_output_cols_list) == len(output_df_columns_set):
607
- return expected_output_cols_list
646
+ return expected_output_cols_list, output_df_pd
608
647
  # otherwise, use the sklearn estimator's output
609
648
  else:
610
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
649
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
650
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
611
651
 
612
652
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
613
653
  @telemetry.send_api_usage_telemetry(
@@ -655,7 +695,7 @@ class CategoricalNB(BaseTransformer):
655
695
  drop_input_cols=self._drop_input_cols,
656
696
  expected_output_cols_type="float",
657
697
  )
658
- expected_output_cols = self._align_expected_output_names(
698
+ expected_output_cols, _ = self._align_expected_output(
659
699
  inference_method, dataset, expected_output_cols, output_cols_prefix
660
700
  )
661
701
 
@@ -723,7 +763,7 @@ class CategoricalNB(BaseTransformer):
723
763
  drop_input_cols=self._drop_input_cols,
724
764
  expected_output_cols_type="float",
725
765
  )
726
- expected_output_cols = self._align_expected_output_names(
766
+ expected_output_cols, _ = self._align_expected_output(
727
767
  inference_method, dataset, expected_output_cols, output_cols_prefix
728
768
  )
729
769
  elif isinstance(dataset, pd.DataFrame):
@@ -786,7 +826,7 @@ class CategoricalNB(BaseTransformer):
786
826
  drop_input_cols=self._drop_input_cols,
787
827
  expected_output_cols_type="float",
788
828
  )
789
- expected_output_cols = self._align_expected_output_names(
829
+ expected_output_cols, _ = self._align_expected_output(
790
830
  inference_method, dataset, expected_output_cols, output_cols_prefix
791
831
  )
792
832
 
@@ -851,7 +891,7 @@ class CategoricalNB(BaseTransformer):
851
891
  drop_input_cols = self._drop_input_cols,
852
892
  expected_output_cols_type="float",
853
893
  )
854
- expected_output_cols = self._align_expected_output_names(
894
+ expected_output_cols, _ = self._align_expected_output(
855
895
  inference_method, dataset, expected_output_cols, output_cols_prefix
856
896
  )
857
897