snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -594,12 +591,23 @@ class SpectralClustering(BaseTransformer):
594
591
  autogenerated=self._autogenerated,
595
592
  subproject=_SUBPROJECT,
596
593
  )
597
- output_result, fitted_estimator = model_trainer.train_fit_predict(
598
- drop_input_cols=self._drop_input_cols,
599
- expected_output_cols_list=(
600
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
601
- ),
594
+ expected_output_cols = (
595
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
602
596
  )
597
+ if isinstance(dataset, DataFrame):
598
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
599
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
600
+ )
601
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
602
+ drop_input_cols=self._drop_input_cols,
603
+ expected_output_cols_list=expected_output_cols,
604
+ example_output_pd_df=example_output_pd_df,
605
+ )
606
+ else:
607
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
608
+ drop_input_cols=self._drop_input_cols,
609
+ expected_output_cols_list=expected_output_cols,
610
+ )
603
611
  self._sklearn_object = fitted_estimator
604
612
  self._is_fitted = True
605
613
  return output_result
@@ -622,6 +630,7 @@ class SpectralClustering(BaseTransformer):
622
630
  """
623
631
  self._infer_input_output_cols(dataset)
624
632
  super()._check_dataset_type(dataset)
633
+
625
634
  model_trainer = ModelTrainerBuilder.build_fit_transform(
626
635
  estimator=self._sklearn_object,
627
636
  dataset=dataset,
@@ -678,12 +687,41 @@ class SpectralClustering(BaseTransformer):
678
687
 
679
688
  return rv
680
689
 
681
- def _align_expected_output_names(
682
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
683
- ) -> List[str]:
690
+ def _align_expected_output(
691
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
692
+ ) -> Tuple[List[str], pd.DataFrame]:
693
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
694
+ and output dataframe with 1 line.
695
+ If the method is fit_predict, run 2 lines of data.
696
+ """
684
697
  # in case the inferred output column names dimension is different
685
698
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
686
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
699
+
700
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
701
+ # so change the minimum of number of rows to 2
702
+ num_examples = 2
703
+ statement_params = telemetry.get_function_usage_statement_params(
704
+ project=_PROJECT,
705
+ subproject=_SUBPROJECT,
706
+ function_name=telemetry.get_statement_params_full_func_name(
707
+ inspect.currentframe(), SpectralClustering.__class__.__name__
708
+ ),
709
+ api_calls=[Session.call],
710
+ custom_tags={"autogen": True} if self._autogenerated else None,
711
+ )
712
+ if output_cols_prefix == "fit_predict_":
713
+ if hasattr(self._sklearn_object, "n_clusters"):
714
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
715
+ num_examples = self._sklearn_object.n_clusters
716
+ elif hasattr(self._sklearn_object, "min_samples"):
717
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
718
+ num_examples = self._sklearn_object.min_samples
719
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
720
+ # LocalOutlierFactor expects n_neighbors <= n_samples
721
+ num_examples = self._sklearn_object.n_neighbors
722
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
723
+ else:
724
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
687
725
 
688
726
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
689
727
  # seen during the fit.
@@ -695,12 +733,14 @@ class SpectralClustering(BaseTransformer):
695
733
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
696
734
  if self.sample_weight_col:
697
735
  output_df_columns_set -= set(self.sample_weight_col)
736
+
698
737
  # if the dimension of inferred output column names is correct; use it
699
738
  if len(expected_output_cols_list) == len(output_df_columns_set):
700
- return expected_output_cols_list
739
+ return expected_output_cols_list, output_df_pd
701
740
  # otherwise, use the sklearn estimator's output
702
741
  else:
703
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
742
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
743
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
704
744
 
705
745
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
706
746
  @telemetry.send_api_usage_telemetry(
@@ -746,7 +786,7 @@ class SpectralClustering(BaseTransformer):
746
786
  drop_input_cols=self._drop_input_cols,
747
787
  expected_output_cols_type="float",
748
788
  )
749
- expected_output_cols = self._align_expected_output_names(
789
+ expected_output_cols, _ = self._align_expected_output(
750
790
  inference_method, dataset, expected_output_cols, output_cols_prefix
751
791
  )
752
792
 
@@ -812,7 +852,7 @@ class SpectralClustering(BaseTransformer):
812
852
  drop_input_cols=self._drop_input_cols,
813
853
  expected_output_cols_type="float",
814
854
  )
815
- expected_output_cols = self._align_expected_output_names(
855
+ expected_output_cols, _ = self._align_expected_output(
816
856
  inference_method, dataset, expected_output_cols, output_cols_prefix
817
857
  )
818
858
  elif isinstance(dataset, pd.DataFrame):
@@ -875,7 +915,7 @@ class SpectralClustering(BaseTransformer):
875
915
  drop_input_cols=self._drop_input_cols,
876
916
  expected_output_cols_type="float",
877
917
  )
878
- expected_output_cols = self._align_expected_output_names(
918
+ expected_output_cols, _ = self._align_expected_output(
879
919
  inference_method, dataset, expected_output_cols, output_cols_prefix
880
920
  )
881
921
 
@@ -940,7 +980,7 @@ class SpectralClustering(BaseTransformer):
940
980
  drop_input_cols = self._drop_input_cols,
941
981
  expected_output_cols_type="float",
942
982
  )
943
- expected_output_cols = self._align_expected_output_names(
983
+ expected_output_cols, _ = self._align_expected_output(
944
984
  inference_method, dataset, expected_output_cols, output_cols_prefix
945
985
  )
946
986
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -513,12 +510,23 @@ class SpectralCoclustering(BaseTransformer):
513
510
  autogenerated=self._autogenerated,
514
511
  subproject=_SUBPROJECT,
515
512
  )
516
- output_result, fitted_estimator = model_trainer.train_fit_predict(
517
- drop_input_cols=self._drop_input_cols,
518
- expected_output_cols_list=(
519
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
520
- ),
513
+ expected_output_cols = (
514
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
521
515
  )
516
+ if isinstance(dataset, DataFrame):
517
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
518
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
519
+ )
520
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
521
+ drop_input_cols=self._drop_input_cols,
522
+ expected_output_cols_list=expected_output_cols,
523
+ example_output_pd_df=example_output_pd_df,
524
+ )
525
+ else:
526
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
527
+ drop_input_cols=self._drop_input_cols,
528
+ expected_output_cols_list=expected_output_cols,
529
+ )
522
530
  self._sklearn_object = fitted_estimator
523
531
  self._is_fitted = True
524
532
  return output_result
@@ -541,6 +549,7 @@ class SpectralCoclustering(BaseTransformer):
541
549
  """
542
550
  self._infer_input_output_cols(dataset)
543
551
  super()._check_dataset_type(dataset)
552
+
544
553
  model_trainer = ModelTrainerBuilder.build_fit_transform(
545
554
  estimator=self._sklearn_object,
546
555
  dataset=dataset,
@@ -597,12 +606,41 @@ class SpectralCoclustering(BaseTransformer):
597
606
 
598
607
  return rv
599
608
 
600
- def _align_expected_output_names(
601
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
602
- ) -> List[str]:
609
+ def _align_expected_output(
610
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
611
+ ) -> Tuple[List[str], pd.DataFrame]:
612
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
613
+ and output dataframe with 1 line.
614
+ If the method is fit_predict, run 2 lines of data.
615
+ """
603
616
  # in case the inferred output column names dimension is different
604
617
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
605
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
618
+
619
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
620
+ # so change the minimum of number of rows to 2
621
+ num_examples = 2
622
+ statement_params = telemetry.get_function_usage_statement_params(
623
+ project=_PROJECT,
624
+ subproject=_SUBPROJECT,
625
+ function_name=telemetry.get_statement_params_full_func_name(
626
+ inspect.currentframe(), SpectralCoclustering.__class__.__name__
627
+ ),
628
+ api_calls=[Session.call],
629
+ custom_tags={"autogen": True} if self._autogenerated else None,
630
+ )
631
+ if output_cols_prefix == "fit_predict_":
632
+ if hasattr(self._sklearn_object, "n_clusters"):
633
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
634
+ num_examples = self._sklearn_object.n_clusters
635
+ elif hasattr(self._sklearn_object, "min_samples"):
636
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
637
+ num_examples = self._sklearn_object.min_samples
638
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
639
+ # LocalOutlierFactor expects n_neighbors <= n_samples
640
+ num_examples = self._sklearn_object.n_neighbors
641
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
642
+ else:
643
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
606
644
 
607
645
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
608
646
  # seen during the fit.
@@ -614,12 +652,14 @@ class SpectralCoclustering(BaseTransformer):
614
652
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
615
653
  if self.sample_weight_col:
616
654
  output_df_columns_set -= set(self.sample_weight_col)
655
+
617
656
  # if the dimension of inferred output column names is correct; use it
618
657
  if len(expected_output_cols_list) == len(output_df_columns_set):
619
- return expected_output_cols_list
658
+ return expected_output_cols_list, output_df_pd
620
659
  # otherwise, use the sklearn estimator's output
621
660
  else:
622
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
661
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
662
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
623
663
 
624
664
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
625
665
  @telemetry.send_api_usage_telemetry(
@@ -665,7 +705,7 @@ class SpectralCoclustering(BaseTransformer):
665
705
  drop_input_cols=self._drop_input_cols,
666
706
  expected_output_cols_type="float",
667
707
  )
668
- expected_output_cols = self._align_expected_output_names(
708
+ expected_output_cols, _ = self._align_expected_output(
669
709
  inference_method, dataset, expected_output_cols, output_cols_prefix
670
710
  )
671
711
 
@@ -731,7 +771,7 @@ class SpectralCoclustering(BaseTransformer):
731
771
  drop_input_cols=self._drop_input_cols,
732
772
  expected_output_cols_type="float",
733
773
  )
734
- expected_output_cols = self._align_expected_output_names(
774
+ expected_output_cols, _ = self._align_expected_output(
735
775
  inference_method, dataset, expected_output_cols, output_cols_prefix
736
776
  )
737
777
  elif isinstance(dataset, pd.DataFrame):
@@ -794,7 +834,7 @@ class SpectralCoclustering(BaseTransformer):
794
834
  drop_input_cols=self._drop_input_cols,
795
835
  expected_output_cols_type="float",
796
836
  )
797
- expected_output_cols = self._align_expected_output_names(
837
+ expected_output_cols, _ = self._align_expected_output(
798
838
  inference_method, dataset, expected_output_cols, output_cols_prefix
799
839
  )
800
840
 
@@ -859,7 +899,7 @@ class SpectralCoclustering(BaseTransformer):
859
899
  drop_input_cols = self._drop_input_cols,
860
900
  expected_output_cols_type="float",
861
901
  )
862
- expected_output_cols = self._align_expected_output_names(
902
+ expected_output_cols, _ = self._align_expected_output(
863
903
  inference_method, dataset, expected_output_cols, output_cols_prefix
864
904
  )
865
905
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -545,12 +542,23 @@ class ColumnTransformer(BaseTransformer):
545
542
  autogenerated=self._autogenerated,
546
543
  subproject=_SUBPROJECT,
547
544
  )
548
- output_result, fitted_estimator = model_trainer.train_fit_predict(
549
- drop_input_cols=self._drop_input_cols,
550
- expected_output_cols_list=(
551
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
552
- ),
545
+ expected_output_cols = (
546
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
553
547
  )
548
+ if isinstance(dataset, DataFrame):
549
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
550
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
551
+ )
552
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
553
+ drop_input_cols=self._drop_input_cols,
554
+ expected_output_cols_list=expected_output_cols,
555
+ example_output_pd_df=example_output_pd_df,
556
+ )
557
+ else:
558
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
559
+ drop_input_cols=self._drop_input_cols,
560
+ expected_output_cols_list=expected_output_cols,
561
+ )
554
562
  self._sklearn_object = fitted_estimator
555
563
  self._is_fitted = True
556
564
  return output_result
@@ -575,6 +583,7 @@ class ColumnTransformer(BaseTransformer):
575
583
  """
576
584
  self._infer_input_output_cols(dataset)
577
585
  super()._check_dataset_type(dataset)
586
+
578
587
  model_trainer = ModelTrainerBuilder.build_fit_transform(
579
588
  estimator=self._sklearn_object,
580
589
  dataset=dataset,
@@ -631,12 +640,41 @@ class ColumnTransformer(BaseTransformer):
631
640
 
632
641
  return rv
633
642
 
634
- def _align_expected_output_names(
635
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
636
- ) -> List[str]:
643
+ def _align_expected_output(
644
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
645
+ ) -> Tuple[List[str], pd.DataFrame]:
646
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
647
+ and output dataframe with 1 line.
648
+ If the method is fit_predict, run 2 lines of data.
649
+ """
637
650
  # in case the inferred output column names dimension is different
638
651
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
639
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
652
+
653
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
654
+ # so change the minimum of number of rows to 2
655
+ num_examples = 2
656
+ statement_params = telemetry.get_function_usage_statement_params(
657
+ project=_PROJECT,
658
+ subproject=_SUBPROJECT,
659
+ function_name=telemetry.get_statement_params_full_func_name(
660
+ inspect.currentframe(), ColumnTransformer.__class__.__name__
661
+ ),
662
+ api_calls=[Session.call],
663
+ custom_tags={"autogen": True} if self._autogenerated else None,
664
+ )
665
+ if output_cols_prefix == "fit_predict_":
666
+ if hasattr(self._sklearn_object, "n_clusters"):
667
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
668
+ num_examples = self._sklearn_object.n_clusters
669
+ elif hasattr(self._sklearn_object, "min_samples"):
670
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
671
+ num_examples = self._sklearn_object.min_samples
672
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
673
+ # LocalOutlierFactor expects n_neighbors <= n_samples
674
+ num_examples = self._sklearn_object.n_neighbors
675
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
676
+ else:
677
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
640
678
 
641
679
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
642
680
  # seen during the fit.
@@ -648,12 +686,14 @@ class ColumnTransformer(BaseTransformer):
648
686
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
649
687
  if self.sample_weight_col:
650
688
  output_df_columns_set -= set(self.sample_weight_col)
689
+
651
690
  # if the dimension of inferred output column names is correct; use it
652
691
  if len(expected_output_cols_list) == len(output_df_columns_set):
653
- return expected_output_cols_list
692
+ return expected_output_cols_list, output_df_pd
654
693
  # otherwise, use the sklearn estimator's output
655
694
  else:
656
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
695
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
696
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
657
697
 
658
698
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
659
699
  @telemetry.send_api_usage_telemetry(
@@ -699,7 +739,7 @@ class ColumnTransformer(BaseTransformer):
699
739
  drop_input_cols=self._drop_input_cols,
700
740
  expected_output_cols_type="float",
701
741
  )
702
- expected_output_cols = self._align_expected_output_names(
742
+ expected_output_cols, _ = self._align_expected_output(
703
743
  inference_method, dataset, expected_output_cols, output_cols_prefix
704
744
  )
705
745
 
@@ -765,7 +805,7 @@ class ColumnTransformer(BaseTransformer):
765
805
  drop_input_cols=self._drop_input_cols,
766
806
  expected_output_cols_type="float",
767
807
  )
768
- expected_output_cols = self._align_expected_output_names(
808
+ expected_output_cols, _ = self._align_expected_output(
769
809
  inference_method, dataset, expected_output_cols, output_cols_prefix
770
810
  )
771
811
  elif isinstance(dataset, pd.DataFrame):
@@ -828,7 +868,7 @@ class ColumnTransformer(BaseTransformer):
828
868
  drop_input_cols=self._drop_input_cols,
829
869
  expected_output_cols_type="float",
830
870
  )
831
- expected_output_cols = self._align_expected_output_names(
871
+ expected_output_cols, _ = self._align_expected_output(
832
872
  inference_method, dataset, expected_output_cols, output_cols_prefix
833
873
  )
834
874
 
@@ -893,7 +933,7 @@ class ColumnTransformer(BaseTransformer):
893
933
  drop_input_cols = self._drop_input_cols,
894
934
  expected_output_cols_type="float",
895
935
  )
896
- expected_output_cols = self._align_expected_output_names(
936
+ expected_output_cols, _ = self._align_expected_output(
897
937
  inference_method, dataset, expected_output_cols, output_cols_prefix
898
938
  )
899
939
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -504,12 +501,23 @@ class TransformedTargetRegressor(BaseTransformer):
504
501
  autogenerated=self._autogenerated,
505
502
  subproject=_SUBPROJECT,
506
503
  )
507
- output_result, fitted_estimator = model_trainer.train_fit_predict(
508
- drop_input_cols=self._drop_input_cols,
509
- expected_output_cols_list=(
510
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
511
- ),
504
+ expected_output_cols = (
505
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
512
506
  )
507
+ if isinstance(dataset, DataFrame):
508
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
509
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
510
+ )
511
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
512
+ drop_input_cols=self._drop_input_cols,
513
+ expected_output_cols_list=expected_output_cols,
514
+ example_output_pd_df=example_output_pd_df,
515
+ )
516
+ else:
517
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
518
+ drop_input_cols=self._drop_input_cols,
519
+ expected_output_cols_list=expected_output_cols,
520
+ )
513
521
  self._sklearn_object = fitted_estimator
514
522
  self._is_fitted = True
515
523
  return output_result
@@ -532,6 +540,7 @@ class TransformedTargetRegressor(BaseTransformer):
532
540
  """
533
541
  self._infer_input_output_cols(dataset)
534
542
  super()._check_dataset_type(dataset)
543
+
535
544
  model_trainer = ModelTrainerBuilder.build_fit_transform(
536
545
  estimator=self._sklearn_object,
537
546
  dataset=dataset,
@@ -588,12 +597,41 @@ class TransformedTargetRegressor(BaseTransformer):
588
597
 
589
598
  return rv
590
599
 
591
- def _align_expected_output_names(
592
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
593
- ) -> List[str]:
600
+ def _align_expected_output(
601
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
602
+ ) -> Tuple[List[str], pd.DataFrame]:
603
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
604
+ and output dataframe with 1 line.
605
+ If the method is fit_predict, run 2 lines of data.
606
+ """
594
607
  # in case the inferred output column names dimension is different
595
608
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
596
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
609
+
610
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
611
+ # so change the minimum of number of rows to 2
612
+ num_examples = 2
613
+ statement_params = telemetry.get_function_usage_statement_params(
614
+ project=_PROJECT,
615
+ subproject=_SUBPROJECT,
616
+ function_name=telemetry.get_statement_params_full_func_name(
617
+ inspect.currentframe(), TransformedTargetRegressor.__class__.__name__
618
+ ),
619
+ api_calls=[Session.call],
620
+ custom_tags={"autogen": True} if self._autogenerated else None,
621
+ )
622
+ if output_cols_prefix == "fit_predict_":
623
+ if hasattr(self._sklearn_object, "n_clusters"):
624
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
625
+ num_examples = self._sklearn_object.n_clusters
626
+ elif hasattr(self._sklearn_object, "min_samples"):
627
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
628
+ num_examples = self._sklearn_object.min_samples
629
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
630
+ # LocalOutlierFactor expects n_neighbors <= n_samples
631
+ num_examples = self._sklearn_object.n_neighbors
632
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
633
+ else:
634
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
597
635
 
598
636
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
599
637
  # seen during the fit.
@@ -605,12 +643,14 @@ class TransformedTargetRegressor(BaseTransformer):
605
643
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
606
644
  if self.sample_weight_col:
607
645
  output_df_columns_set -= set(self.sample_weight_col)
646
+
608
647
  # if the dimension of inferred output column names is correct; use it
609
648
  if len(expected_output_cols_list) == len(output_df_columns_set):
610
- return expected_output_cols_list
649
+ return expected_output_cols_list, output_df_pd
611
650
  # otherwise, use the sklearn estimator's output
612
651
  else:
613
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
652
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
653
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
614
654
 
615
655
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
616
656
  @telemetry.send_api_usage_telemetry(
@@ -656,7 +696,7 @@ class TransformedTargetRegressor(BaseTransformer):
656
696
  drop_input_cols=self._drop_input_cols,
657
697
  expected_output_cols_type="float",
658
698
  )
659
- expected_output_cols = self._align_expected_output_names(
699
+ expected_output_cols, _ = self._align_expected_output(
660
700
  inference_method, dataset, expected_output_cols, output_cols_prefix
661
701
  )
662
702
 
@@ -722,7 +762,7 @@ class TransformedTargetRegressor(BaseTransformer):
722
762
  drop_input_cols=self._drop_input_cols,
723
763
  expected_output_cols_type="float",
724
764
  )
725
- expected_output_cols = self._align_expected_output_names(
765
+ expected_output_cols, _ = self._align_expected_output(
726
766
  inference_method, dataset, expected_output_cols, output_cols_prefix
727
767
  )
728
768
  elif isinstance(dataset, pd.DataFrame):
@@ -785,7 +825,7 @@ class TransformedTargetRegressor(BaseTransformer):
785
825
  drop_input_cols=self._drop_input_cols,
786
826
  expected_output_cols_type="float",
787
827
  )
788
- expected_output_cols = self._align_expected_output_names(
828
+ expected_output_cols, _ = self._align_expected_output(
789
829
  inference_method, dataset, expected_output_cols, output_cols_prefix
790
830
  )
791
831
 
@@ -850,7 +890,7 @@ class TransformedTargetRegressor(BaseTransformer):
850
890
  drop_input_cols = self._drop_input_cols,
851
891
  expected_output_cols_type="float",
852
892
  )
853
- expected_output_cols = self._align_expected_output_names(
893
+ expected_output_cols, _ = self._align_expected_output(
854
894
  inference_method, dataset, expected_output_cols, output_cols_prefix
855
895
  )
856
896