snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -499,12 +496,23 @@ class MissingIndicator(BaseTransformer):
499
496
  autogenerated=self._autogenerated,
500
497
  subproject=_SUBPROJECT,
501
498
  )
502
- output_result, fitted_estimator = model_trainer.train_fit_predict(
503
- drop_input_cols=self._drop_input_cols,
504
- expected_output_cols_list=(
505
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
506
- ),
499
+ expected_output_cols = (
500
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
507
501
  )
502
+ if isinstance(dataset, DataFrame):
503
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
504
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
505
+ )
506
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
507
+ drop_input_cols=self._drop_input_cols,
508
+ expected_output_cols_list=expected_output_cols,
509
+ example_output_pd_df=example_output_pd_df,
510
+ )
511
+ else:
512
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
513
+ drop_input_cols=self._drop_input_cols,
514
+ expected_output_cols_list=expected_output_cols,
515
+ )
508
516
  self._sklearn_object = fitted_estimator
509
517
  self._is_fitted = True
510
518
  return output_result
@@ -529,6 +537,7 @@ class MissingIndicator(BaseTransformer):
529
537
  """
530
538
  self._infer_input_output_cols(dataset)
531
539
  super()._check_dataset_type(dataset)
540
+
532
541
  model_trainer = ModelTrainerBuilder.build_fit_transform(
533
542
  estimator=self._sklearn_object,
534
543
  dataset=dataset,
@@ -585,12 +594,41 @@ class MissingIndicator(BaseTransformer):
585
594
 
586
595
  return rv
587
596
 
588
- def _align_expected_output_names(
589
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
590
- ) -> List[str]:
597
+ def _align_expected_output(
598
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
599
+ ) -> Tuple[List[str], pd.DataFrame]:
600
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
601
+ and output dataframe with 1 line.
602
+ If the method is fit_predict, run 2 lines of data.
603
+ """
591
604
  # in case the inferred output column names dimension is different
592
605
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
593
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
606
+
607
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
608
+ # so change the minimum of number of rows to 2
609
+ num_examples = 2
610
+ statement_params = telemetry.get_function_usage_statement_params(
611
+ project=_PROJECT,
612
+ subproject=_SUBPROJECT,
613
+ function_name=telemetry.get_statement_params_full_func_name(
614
+ inspect.currentframe(), MissingIndicator.__class__.__name__
615
+ ),
616
+ api_calls=[Session.call],
617
+ custom_tags={"autogen": True} if self._autogenerated else None,
618
+ )
619
+ if output_cols_prefix == "fit_predict_":
620
+ if hasattr(self._sklearn_object, "n_clusters"):
621
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
622
+ num_examples = self._sklearn_object.n_clusters
623
+ elif hasattr(self._sklearn_object, "min_samples"):
624
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
625
+ num_examples = self._sklearn_object.min_samples
626
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
627
+ # LocalOutlierFactor expects n_neighbors <= n_samples
628
+ num_examples = self._sklearn_object.n_neighbors
629
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
630
+ else:
631
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
594
632
 
595
633
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
596
634
  # seen during the fit.
@@ -602,12 +640,14 @@ class MissingIndicator(BaseTransformer):
602
640
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
603
641
  if self.sample_weight_col:
604
642
  output_df_columns_set -= set(self.sample_weight_col)
643
+
605
644
  # if the dimension of inferred output column names is correct; use it
606
645
  if len(expected_output_cols_list) == len(output_df_columns_set):
607
- return expected_output_cols_list
646
+ return expected_output_cols_list, output_df_pd
608
647
  # otherwise, use the sklearn estimator's output
609
648
  else:
610
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
649
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
650
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
611
651
 
612
652
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
613
653
  @telemetry.send_api_usage_telemetry(
@@ -653,7 +693,7 @@ class MissingIndicator(BaseTransformer):
653
693
  drop_input_cols=self._drop_input_cols,
654
694
  expected_output_cols_type="float",
655
695
  )
656
- expected_output_cols = self._align_expected_output_names(
696
+ expected_output_cols, _ = self._align_expected_output(
657
697
  inference_method, dataset, expected_output_cols, output_cols_prefix
658
698
  )
659
699
 
@@ -719,7 +759,7 @@ class MissingIndicator(BaseTransformer):
719
759
  drop_input_cols=self._drop_input_cols,
720
760
  expected_output_cols_type="float",
721
761
  )
722
- expected_output_cols = self._align_expected_output_names(
762
+ expected_output_cols, _ = self._align_expected_output(
723
763
  inference_method, dataset, expected_output_cols, output_cols_prefix
724
764
  )
725
765
  elif isinstance(dataset, pd.DataFrame):
@@ -782,7 +822,7 @@ class MissingIndicator(BaseTransformer):
782
822
  drop_input_cols=self._drop_input_cols,
783
823
  expected_output_cols_type="float",
784
824
  )
785
- expected_output_cols = self._align_expected_output_names(
825
+ expected_output_cols, _ = self._align_expected_output(
786
826
  inference_method, dataset, expected_output_cols, output_cols_prefix
787
827
  )
788
828
 
@@ -847,7 +887,7 @@ class MissingIndicator(BaseTransformer):
847
887
  drop_input_cols = self._drop_input_cols,
848
888
  expected_output_cols_type="float",
849
889
  )
850
- expected_output_cols = self._align_expected_output_names(
890
+ expected_output_cols, _ = self._align_expected_output(
851
891
  inference_method, dataset, expected_output_cols, output_cols_prefix
852
892
  )
853
893
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -474,12 +471,23 @@ class AdditiveChi2Sampler(BaseTransformer):
474
471
  autogenerated=self._autogenerated,
475
472
  subproject=_SUBPROJECT,
476
473
  )
477
- output_result, fitted_estimator = model_trainer.train_fit_predict(
478
- drop_input_cols=self._drop_input_cols,
479
- expected_output_cols_list=(
480
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
481
- ),
474
+ expected_output_cols = (
475
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
482
476
  )
477
+ if isinstance(dataset, DataFrame):
478
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
479
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
480
+ )
481
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
482
+ drop_input_cols=self._drop_input_cols,
483
+ expected_output_cols_list=expected_output_cols,
484
+ example_output_pd_df=example_output_pd_df,
485
+ )
486
+ else:
487
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
488
+ drop_input_cols=self._drop_input_cols,
489
+ expected_output_cols_list=expected_output_cols,
490
+ )
483
491
  self._sklearn_object = fitted_estimator
484
492
  self._is_fitted = True
485
493
  return output_result
@@ -504,6 +512,7 @@ class AdditiveChi2Sampler(BaseTransformer):
504
512
  """
505
513
  self._infer_input_output_cols(dataset)
506
514
  super()._check_dataset_type(dataset)
515
+
507
516
  model_trainer = ModelTrainerBuilder.build_fit_transform(
508
517
  estimator=self._sklearn_object,
509
518
  dataset=dataset,
@@ -560,12 +569,41 @@ class AdditiveChi2Sampler(BaseTransformer):
560
569
 
561
570
  return rv
562
571
 
563
- def _align_expected_output_names(
564
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
565
- ) -> List[str]:
572
+ def _align_expected_output(
573
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
574
+ ) -> Tuple[List[str], pd.DataFrame]:
575
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
576
+ and output dataframe with 1 line.
577
+ If the method is fit_predict, run 2 lines of data.
578
+ """
566
579
  # in case the inferred output column names dimension is different
567
580
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
568
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
581
+
582
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
583
+ # so change the minimum of number of rows to 2
584
+ num_examples = 2
585
+ statement_params = telemetry.get_function_usage_statement_params(
586
+ project=_PROJECT,
587
+ subproject=_SUBPROJECT,
588
+ function_name=telemetry.get_statement_params_full_func_name(
589
+ inspect.currentframe(), AdditiveChi2Sampler.__class__.__name__
590
+ ),
591
+ api_calls=[Session.call],
592
+ custom_tags={"autogen": True} if self._autogenerated else None,
593
+ )
594
+ if output_cols_prefix == "fit_predict_":
595
+ if hasattr(self._sklearn_object, "n_clusters"):
596
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
597
+ num_examples = self._sklearn_object.n_clusters
598
+ elif hasattr(self._sklearn_object, "min_samples"):
599
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
600
+ num_examples = self._sklearn_object.min_samples
601
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
602
+ # LocalOutlierFactor expects n_neighbors <= n_samples
603
+ num_examples = self._sklearn_object.n_neighbors
604
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
605
+ else:
606
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
569
607
 
570
608
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
571
609
  # seen during the fit.
@@ -577,12 +615,14 @@ class AdditiveChi2Sampler(BaseTransformer):
577
615
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
578
616
  if self.sample_weight_col:
579
617
  output_df_columns_set -= set(self.sample_weight_col)
618
+
580
619
  # if the dimension of inferred output column names is correct; use it
581
620
  if len(expected_output_cols_list) == len(output_df_columns_set):
582
- return expected_output_cols_list
621
+ return expected_output_cols_list, output_df_pd
583
622
  # otherwise, use the sklearn estimator's output
584
623
  else:
585
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
624
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
625
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
586
626
 
587
627
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
588
628
  @telemetry.send_api_usage_telemetry(
@@ -628,7 +668,7 @@ class AdditiveChi2Sampler(BaseTransformer):
628
668
  drop_input_cols=self._drop_input_cols,
629
669
  expected_output_cols_type="float",
630
670
  )
631
- expected_output_cols = self._align_expected_output_names(
671
+ expected_output_cols, _ = self._align_expected_output(
632
672
  inference_method, dataset, expected_output_cols, output_cols_prefix
633
673
  )
634
674
 
@@ -694,7 +734,7 @@ class AdditiveChi2Sampler(BaseTransformer):
694
734
  drop_input_cols=self._drop_input_cols,
695
735
  expected_output_cols_type="float",
696
736
  )
697
- expected_output_cols = self._align_expected_output_names(
737
+ expected_output_cols, _ = self._align_expected_output(
698
738
  inference_method, dataset, expected_output_cols, output_cols_prefix
699
739
  )
700
740
  elif isinstance(dataset, pd.DataFrame):
@@ -757,7 +797,7 @@ class AdditiveChi2Sampler(BaseTransformer):
757
797
  drop_input_cols=self._drop_input_cols,
758
798
  expected_output_cols_type="float",
759
799
  )
760
- expected_output_cols = self._align_expected_output_names(
800
+ expected_output_cols, _ = self._align_expected_output(
761
801
  inference_method, dataset, expected_output_cols, output_cols_prefix
762
802
  )
763
803
 
@@ -822,7 +862,7 @@ class AdditiveChi2Sampler(BaseTransformer):
822
862
  drop_input_cols = self._drop_input_cols,
823
863
  expected_output_cols_type="float",
824
864
  )
825
- expected_output_cols = self._align_expected_output_names(
865
+ expected_output_cols, _ = self._align_expected_output(
826
866
  inference_method, dataset, expected_output_cols, output_cols_prefix
827
867
  )
828
868
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -522,12 +519,23 @@ class Nystroem(BaseTransformer):
522
519
  autogenerated=self._autogenerated,
523
520
  subproject=_SUBPROJECT,
524
521
  )
525
- output_result, fitted_estimator = model_trainer.train_fit_predict(
526
- drop_input_cols=self._drop_input_cols,
527
- expected_output_cols_list=(
528
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
529
- ),
522
+ expected_output_cols = (
523
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
530
524
  )
525
+ if isinstance(dataset, DataFrame):
526
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
527
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
528
+ )
529
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
530
+ drop_input_cols=self._drop_input_cols,
531
+ expected_output_cols_list=expected_output_cols,
532
+ example_output_pd_df=example_output_pd_df,
533
+ )
534
+ else:
535
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
536
+ drop_input_cols=self._drop_input_cols,
537
+ expected_output_cols_list=expected_output_cols,
538
+ )
531
539
  self._sklearn_object = fitted_estimator
532
540
  self._is_fitted = True
533
541
  return output_result
@@ -552,6 +560,7 @@ class Nystroem(BaseTransformer):
552
560
  """
553
561
  self._infer_input_output_cols(dataset)
554
562
  super()._check_dataset_type(dataset)
563
+
555
564
  model_trainer = ModelTrainerBuilder.build_fit_transform(
556
565
  estimator=self._sklearn_object,
557
566
  dataset=dataset,
@@ -608,12 +617,41 @@ class Nystroem(BaseTransformer):
608
617
 
609
618
  return rv
610
619
 
611
- def _align_expected_output_names(
612
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
613
- ) -> List[str]:
620
+ def _align_expected_output(
621
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
622
+ ) -> Tuple[List[str], pd.DataFrame]:
623
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
624
+ and output dataframe with 1 line.
625
+ If the method is fit_predict, run 2 lines of data.
626
+ """
614
627
  # in case the inferred output column names dimension is different
615
628
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
616
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
629
+
630
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
631
+ # so change the minimum of number of rows to 2
632
+ num_examples = 2
633
+ statement_params = telemetry.get_function_usage_statement_params(
634
+ project=_PROJECT,
635
+ subproject=_SUBPROJECT,
636
+ function_name=telemetry.get_statement_params_full_func_name(
637
+ inspect.currentframe(), Nystroem.__class__.__name__
638
+ ),
639
+ api_calls=[Session.call],
640
+ custom_tags={"autogen": True} if self._autogenerated else None,
641
+ )
642
+ if output_cols_prefix == "fit_predict_":
643
+ if hasattr(self._sklearn_object, "n_clusters"):
644
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
645
+ num_examples = self._sklearn_object.n_clusters
646
+ elif hasattr(self._sklearn_object, "min_samples"):
647
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
648
+ num_examples = self._sklearn_object.min_samples
649
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
650
+ # LocalOutlierFactor expects n_neighbors <= n_samples
651
+ num_examples = self._sklearn_object.n_neighbors
652
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
653
+ else:
654
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
617
655
 
618
656
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
619
657
  # seen during the fit.
@@ -625,12 +663,14 @@ class Nystroem(BaseTransformer):
625
663
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
626
664
  if self.sample_weight_col:
627
665
  output_df_columns_set -= set(self.sample_weight_col)
666
+
628
667
  # if the dimension of inferred output column names is correct; use it
629
668
  if len(expected_output_cols_list) == len(output_df_columns_set):
630
- return expected_output_cols_list
669
+ return expected_output_cols_list, output_df_pd
631
670
  # otherwise, use the sklearn estimator's output
632
671
  else:
633
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
672
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
673
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
634
674
 
635
675
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
636
676
  @telemetry.send_api_usage_telemetry(
@@ -676,7 +716,7 @@ class Nystroem(BaseTransformer):
676
716
  drop_input_cols=self._drop_input_cols,
677
717
  expected_output_cols_type="float",
678
718
  )
679
- expected_output_cols = self._align_expected_output_names(
719
+ expected_output_cols, _ = self._align_expected_output(
680
720
  inference_method, dataset, expected_output_cols, output_cols_prefix
681
721
  )
682
722
 
@@ -742,7 +782,7 @@ class Nystroem(BaseTransformer):
742
782
  drop_input_cols=self._drop_input_cols,
743
783
  expected_output_cols_type="float",
744
784
  )
745
- expected_output_cols = self._align_expected_output_names(
785
+ expected_output_cols, _ = self._align_expected_output(
746
786
  inference_method, dataset, expected_output_cols, output_cols_prefix
747
787
  )
748
788
  elif isinstance(dataset, pd.DataFrame):
@@ -805,7 +845,7 @@ class Nystroem(BaseTransformer):
805
845
  drop_input_cols=self._drop_input_cols,
806
846
  expected_output_cols_type="float",
807
847
  )
808
- expected_output_cols = self._align_expected_output_names(
848
+ expected_output_cols, _ = self._align_expected_output(
809
849
  inference_method, dataset, expected_output_cols, output_cols_prefix
810
850
  )
811
851
 
@@ -870,7 +910,7 @@ class Nystroem(BaseTransformer):
870
910
  drop_input_cols = self._drop_input_cols,
871
911
  expected_output_cols_type="float",
872
912
  )
873
- expected_output_cols = self._align_expected_output_names(
913
+ expected_output_cols, _ = self._align_expected_output(
874
914
  inference_method, dataset, expected_output_cols, output_cols_prefix
875
915
  )
876
916
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -498,12 +495,23 @@ class PolynomialCountSketch(BaseTransformer):
498
495
  autogenerated=self._autogenerated,
499
496
  subproject=_SUBPROJECT,
500
497
  )
501
- output_result, fitted_estimator = model_trainer.train_fit_predict(
502
- drop_input_cols=self._drop_input_cols,
503
- expected_output_cols_list=(
504
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
505
- ),
498
+ expected_output_cols = (
499
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
506
500
  )
501
+ if isinstance(dataset, DataFrame):
502
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
503
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
504
+ )
505
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
506
+ drop_input_cols=self._drop_input_cols,
507
+ expected_output_cols_list=expected_output_cols,
508
+ example_output_pd_df=example_output_pd_df,
509
+ )
510
+ else:
511
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
512
+ drop_input_cols=self._drop_input_cols,
513
+ expected_output_cols_list=expected_output_cols,
514
+ )
507
515
  self._sklearn_object = fitted_estimator
508
516
  self._is_fitted = True
509
517
  return output_result
@@ -528,6 +536,7 @@ class PolynomialCountSketch(BaseTransformer):
528
536
  """
529
537
  self._infer_input_output_cols(dataset)
530
538
  super()._check_dataset_type(dataset)
539
+
531
540
  model_trainer = ModelTrainerBuilder.build_fit_transform(
532
541
  estimator=self._sklearn_object,
533
542
  dataset=dataset,
@@ -584,12 +593,41 @@ class PolynomialCountSketch(BaseTransformer):
584
593
 
585
594
  return rv
586
595
 
587
- def _align_expected_output_names(
588
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
589
- ) -> List[str]:
596
+ def _align_expected_output(
597
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
598
+ ) -> Tuple[List[str], pd.DataFrame]:
599
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
600
+ and output dataframe with 1 line.
601
+ If the method is fit_predict, run 2 lines of data.
602
+ """
590
603
  # in case the inferred output column names dimension is different
591
604
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
592
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
605
+
606
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
607
+ # so change the minimum of number of rows to 2
608
+ num_examples = 2
609
+ statement_params = telemetry.get_function_usage_statement_params(
610
+ project=_PROJECT,
611
+ subproject=_SUBPROJECT,
612
+ function_name=telemetry.get_statement_params_full_func_name(
613
+ inspect.currentframe(), PolynomialCountSketch.__class__.__name__
614
+ ),
615
+ api_calls=[Session.call],
616
+ custom_tags={"autogen": True} if self._autogenerated else None,
617
+ )
618
+ if output_cols_prefix == "fit_predict_":
619
+ if hasattr(self._sklearn_object, "n_clusters"):
620
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
621
+ num_examples = self._sklearn_object.n_clusters
622
+ elif hasattr(self._sklearn_object, "min_samples"):
623
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
624
+ num_examples = self._sklearn_object.min_samples
625
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
626
+ # LocalOutlierFactor expects n_neighbors <= n_samples
627
+ num_examples = self._sklearn_object.n_neighbors
628
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
629
+ else:
630
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
593
631
 
594
632
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
595
633
  # seen during the fit.
@@ -601,12 +639,14 @@ class PolynomialCountSketch(BaseTransformer):
601
639
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
602
640
  if self.sample_weight_col:
603
641
  output_df_columns_set -= set(self.sample_weight_col)
642
+
604
643
  # if the dimension of inferred output column names is correct; use it
605
644
  if len(expected_output_cols_list) == len(output_df_columns_set):
606
- return expected_output_cols_list
645
+ return expected_output_cols_list, output_df_pd
607
646
  # otherwise, use the sklearn estimator's output
608
647
  else:
609
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
648
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
649
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
610
650
 
611
651
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
612
652
  @telemetry.send_api_usage_telemetry(
@@ -652,7 +692,7 @@ class PolynomialCountSketch(BaseTransformer):
652
692
  drop_input_cols=self._drop_input_cols,
653
693
  expected_output_cols_type="float",
654
694
  )
655
- expected_output_cols = self._align_expected_output_names(
695
+ expected_output_cols, _ = self._align_expected_output(
656
696
  inference_method, dataset, expected_output_cols, output_cols_prefix
657
697
  )
658
698
 
@@ -718,7 +758,7 @@ class PolynomialCountSketch(BaseTransformer):
718
758
  drop_input_cols=self._drop_input_cols,
719
759
  expected_output_cols_type="float",
720
760
  )
721
- expected_output_cols = self._align_expected_output_names(
761
+ expected_output_cols, _ = self._align_expected_output(
722
762
  inference_method, dataset, expected_output_cols, output_cols_prefix
723
763
  )
724
764
  elif isinstance(dataset, pd.DataFrame):
@@ -781,7 +821,7 @@ class PolynomialCountSketch(BaseTransformer):
781
821
  drop_input_cols=self._drop_input_cols,
782
822
  expected_output_cols_type="float",
783
823
  )
784
- expected_output_cols = self._align_expected_output_names(
824
+ expected_output_cols, _ = self._align_expected_output(
785
825
  inference_method, dataset, expected_output_cols, output_cols_prefix
786
826
  )
787
827
 
@@ -846,7 +886,7 @@ class PolynomialCountSketch(BaseTransformer):
846
886
  drop_input_cols = self._drop_input_cols,
847
887
  expected_output_cols_type="float",
848
888
  )
849
- expected_output_cols = self._align_expected_output_names(
889
+ expected_output_cols, _ = self._align_expected_output(
850
890
  inference_method, dataset, expected_output_cols, output_cols_prefix
851
891
  )
852
892