snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284) hide show
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -553,12 +550,23 @@ class Isomap(BaseTransformer):
553
550
  autogenerated=self._autogenerated,
554
551
  subproject=_SUBPROJECT,
555
552
  )
556
- output_result, fitted_estimator = model_trainer.train_fit_predict(
557
- drop_input_cols=self._drop_input_cols,
558
- expected_output_cols_list=(
559
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
560
- ),
553
+ expected_output_cols = (
554
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
561
555
  )
556
+ if isinstance(dataset, DataFrame):
557
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
558
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
559
+ )
560
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
561
+ drop_input_cols=self._drop_input_cols,
562
+ expected_output_cols_list=expected_output_cols,
563
+ example_output_pd_df=example_output_pd_df,
564
+ )
565
+ else:
566
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
567
+ drop_input_cols=self._drop_input_cols,
568
+ expected_output_cols_list=expected_output_cols,
569
+ )
562
570
  self._sklearn_object = fitted_estimator
563
571
  self._is_fitted = True
564
572
  return output_result
@@ -583,6 +591,7 @@ class Isomap(BaseTransformer):
583
591
  """
584
592
  self._infer_input_output_cols(dataset)
585
593
  super()._check_dataset_type(dataset)
594
+
586
595
  model_trainer = ModelTrainerBuilder.build_fit_transform(
587
596
  estimator=self._sklearn_object,
588
597
  dataset=dataset,
@@ -639,12 +648,41 @@ class Isomap(BaseTransformer):
639
648
 
640
649
  return rv
641
650
 
642
- def _align_expected_output_names(
643
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
644
- ) -> List[str]:
651
+ def _align_expected_output(
652
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
653
+ ) -> Tuple[List[str], pd.DataFrame]:
654
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
655
+ and output dataframe with 1 line.
656
+ If the method is fit_predict, run 2 lines of data.
657
+ """
645
658
  # in case the inferred output column names dimension is different
646
659
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
647
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
660
+
661
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
662
+ # so change the minimum of number of rows to 2
663
+ num_examples = 2
664
+ statement_params = telemetry.get_function_usage_statement_params(
665
+ project=_PROJECT,
666
+ subproject=_SUBPROJECT,
667
+ function_name=telemetry.get_statement_params_full_func_name(
668
+ inspect.currentframe(), Isomap.__class__.__name__
669
+ ),
670
+ api_calls=[Session.call],
671
+ custom_tags={"autogen": True} if self._autogenerated else None,
672
+ )
673
+ if output_cols_prefix == "fit_predict_":
674
+ if hasattr(self._sklearn_object, "n_clusters"):
675
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
676
+ num_examples = self._sklearn_object.n_clusters
677
+ elif hasattr(self._sklearn_object, "min_samples"):
678
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
679
+ num_examples = self._sklearn_object.min_samples
680
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
681
+ # LocalOutlierFactor expects n_neighbors <= n_samples
682
+ num_examples = self._sklearn_object.n_neighbors
683
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
684
+ else:
685
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
648
686
 
649
687
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
650
688
  # seen during the fit.
@@ -656,12 +694,14 @@ class Isomap(BaseTransformer):
656
694
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
657
695
  if self.sample_weight_col:
658
696
  output_df_columns_set -= set(self.sample_weight_col)
697
+
659
698
  # if the dimension of inferred output column names is correct; use it
660
699
  if len(expected_output_cols_list) == len(output_df_columns_set):
661
- return expected_output_cols_list
700
+ return expected_output_cols_list, output_df_pd
662
701
  # otherwise, use the sklearn estimator's output
663
702
  else:
664
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
703
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
704
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
665
705
 
666
706
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
667
707
  @telemetry.send_api_usage_telemetry(
@@ -707,7 +747,7 @@ class Isomap(BaseTransformer):
707
747
  drop_input_cols=self._drop_input_cols,
708
748
  expected_output_cols_type="float",
709
749
  )
710
- expected_output_cols = self._align_expected_output_names(
750
+ expected_output_cols, _ = self._align_expected_output(
711
751
  inference_method, dataset, expected_output_cols, output_cols_prefix
712
752
  )
713
753
 
@@ -773,7 +813,7 @@ class Isomap(BaseTransformer):
773
813
  drop_input_cols=self._drop_input_cols,
774
814
  expected_output_cols_type="float",
775
815
  )
776
- expected_output_cols = self._align_expected_output_names(
816
+ expected_output_cols, _ = self._align_expected_output(
777
817
  inference_method, dataset, expected_output_cols, output_cols_prefix
778
818
  )
779
819
  elif isinstance(dataset, pd.DataFrame):
@@ -836,7 +876,7 @@ class Isomap(BaseTransformer):
836
876
  drop_input_cols=self._drop_input_cols,
837
877
  expected_output_cols_type="float",
838
878
  )
839
- expected_output_cols = self._align_expected_output_names(
879
+ expected_output_cols, _ = self._align_expected_output(
840
880
  inference_method, dataset, expected_output_cols, output_cols_prefix
841
881
  )
842
882
 
@@ -901,7 +941,7 @@ class Isomap(BaseTransformer):
901
941
  drop_input_cols = self._drop_input_cols,
902
942
  expected_output_cols_type="float",
903
943
  )
904
- expected_output_cols = self._align_expected_output_names(
944
+ expected_output_cols, _ = self._align_expected_output(
905
945
  inference_method, dataset, expected_output_cols, output_cols_prefix
906
946
  )
907
947
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -534,12 +531,23 @@ class MDS(BaseTransformer):
534
531
  autogenerated=self._autogenerated,
535
532
  subproject=_SUBPROJECT,
536
533
  )
537
- output_result, fitted_estimator = model_trainer.train_fit_predict(
538
- drop_input_cols=self._drop_input_cols,
539
- expected_output_cols_list=(
540
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
541
- ),
534
+ expected_output_cols = (
535
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
542
536
  )
537
+ if isinstance(dataset, DataFrame):
538
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
539
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
540
+ )
541
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
542
+ drop_input_cols=self._drop_input_cols,
543
+ expected_output_cols_list=expected_output_cols,
544
+ example_output_pd_df=example_output_pd_df,
545
+ )
546
+ else:
547
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
548
+ drop_input_cols=self._drop_input_cols,
549
+ expected_output_cols_list=expected_output_cols,
550
+ )
543
551
  self._sklearn_object = fitted_estimator
544
552
  self._is_fitted = True
545
553
  return output_result
@@ -564,6 +572,7 @@ class MDS(BaseTransformer):
564
572
  """
565
573
  self._infer_input_output_cols(dataset)
566
574
  super()._check_dataset_type(dataset)
575
+
567
576
  model_trainer = ModelTrainerBuilder.build_fit_transform(
568
577
  estimator=self._sklearn_object,
569
578
  dataset=dataset,
@@ -620,12 +629,41 @@ class MDS(BaseTransformer):
620
629
 
621
630
  return rv
622
631
 
623
- def _align_expected_output_names(
624
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
625
- ) -> List[str]:
632
+ def _align_expected_output(
633
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
634
+ ) -> Tuple[List[str], pd.DataFrame]:
635
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
636
+ and output dataframe with 1 line.
637
+ If the method is fit_predict, run 2 lines of data.
638
+ """
626
639
  # in case the inferred output column names dimension is different
627
640
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
628
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
641
+
642
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
643
+ # so change the minimum of number of rows to 2
644
+ num_examples = 2
645
+ statement_params = telemetry.get_function_usage_statement_params(
646
+ project=_PROJECT,
647
+ subproject=_SUBPROJECT,
648
+ function_name=telemetry.get_statement_params_full_func_name(
649
+ inspect.currentframe(), MDS.__class__.__name__
650
+ ),
651
+ api_calls=[Session.call],
652
+ custom_tags={"autogen": True} if self._autogenerated else None,
653
+ )
654
+ if output_cols_prefix == "fit_predict_":
655
+ if hasattr(self._sklearn_object, "n_clusters"):
656
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
657
+ num_examples = self._sklearn_object.n_clusters
658
+ elif hasattr(self._sklearn_object, "min_samples"):
659
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
660
+ num_examples = self._sklearn_object.min_samples
661
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
662
+ # LocalOutlierFactor expects n_neighbors <= n_samples
663
+ num_examples = self._sklearn_object.n_neighbors
664
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
665
+ else:
666
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
629
667
 
630
668
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
631
669
  # seen during the fit.
@@ -637,12 +675,14 @@ class MDS(BaseTransformer):
637
675
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
638
676
  if self.sample_weight_col:
639
677
  output_df_columns_set -= set(self.sample_weight_col)
678
+
640
679
  # if the dimension of inferred output column names is correct; use it
641
680
  if len(expected_output_cols_list) == len(output_df_columns_set):
642
- return expected_output_cols_list
681
+ return expected_output_cols_list, output_df_pd
643
682
  # otherwise, use the sklearn estimator's output
644
683
  else:
645
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
684
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
685
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
646
686
 
647
687
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
648
688
  @telemetry.send_api_usage_telemetry(
@@ -688,7 +728,7 @@ class MDS(BaseTransformer):
688
728
  drop_input_cols=self._drop_input_cols,
689
729
  expected_output_cols_type="float",
690
730
  )
691
- expected_output_cols = self._align_expected_output_names(
731
+ expected_output_cols, _ = self._align_expected_output(
692
732
  inference_method, dataset, expected_output_cols, output_cols_prefix
693
733
  )
694
734
 
@@ -754,7 +794,7 @@ class MDS(BaseTransformer):
754
794
  drop_input_cols=self._drop_input_cols,
755
795
  expected_output_cols_type="float",
756
796
  )
757
- expected_output_cols = self._align_expected_output_names(
797
+ expected_output_cols, _ = self._align_expected_output(
758
798
  inference_method, dataset, expected_output_cols, output_cols_prefix
759
799
  )
760
800
  elif isinstance(dataset, pd.DataFrame):
@@ -817,7 +857,7 @@ class MDS(BaseTransformer):
817
857
  drop_input_cols=self._drop_input_cols,
818
858
  expected_output_cols_type="float",
819
859
  )
820
- expected_output_cols = self._align_expected_output_names(
860
+ expected_output_cols, _ = self._align_expected_output(
821
861
  inference_method, dataset, expected_output_cols, output_cols_prefix
822
862
  )
823
863
 
@@ -882,7 +922,7 @@ class MDS(BaseTransformer):
882
922
  drop_input_cols = self._drop_input_cols,
883
923
  expected_output_cols_type="float",
884
924
  )
885
- expected_output_cols = self._align_expected_output_names(
925
+ expected_output_cols, _ = self._align_expected_output(
886
926
  inference_method, dataset, expected_output_cols, output_cols_prefix
887
927
  )
888
928
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -536,12 +533,23 @@ class SpectralEmbedding(BaseTransformer):
536
533
  autogenerated=self._autogenerated,
537
534
  subproject=_SUBPROJECT,
538
535
  )
539
- output_result, fitted_estimator = model_trainer.train_fit_predict(
540
- drop_input_cols=self._drop_input_cols,
541
- expected_output_cols_list=(
542
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
543
- ),
536
+ expected_output_cols = (
537
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
544
538
  )
539
+ if isinstance(dataset, DataFrame):
540
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
541
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
542
+ )
543
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
544
+ drop_input_cols=self._drop_input_cols,
545
+ expected_output_cols_list=expected_output_cols,
546
+ example_output_pd_df=example_output_pd_df,
547
+ )
548
+ else:
549
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
550
+ drop_input_cols=self._drop_input_cols,
551
+ expected_output_cols_list=expected_output_cols,
552
+ )
545
553
  self._sklearn_object = fitted_estimator
546
554
  self._is_fitted = True
547
555
  return output_result
@@ -566,6 +574,7 @@ class SpectralEmbedding(BaseTransformer):
566
574
  """
567
575
  self._infer_input_output_cols(dataset)
568
576
  super()._check_dataset_type(dataset)
577
+
569
578
  model_trainer = ModelTrainerBuilder.build_fit_transform(
570
579
  estimator=self._sklearn_object,
571
580
  dataset=dataset,
@@ -622,12 +631,41 @@ class SpectralEmbedding(BaseTransformer):
622
631
 
623
632
  return rv
624
633
 
625
- def _align_expected_output_names(
626
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
627
- ) -> List[str]:
634
+ def _align_expected_output(
635
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
636
+ ) -> Tuple[List[str], pd.DataFrame]:
637
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
638
+ and output dataframe with 1 line.
639
+ If the method is fit_predict, run 2 lines of data.
640
+ """
628
641
  # in case the inferred output column names dimension is different
629
642
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
630
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
643
+
644
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
645
+ # so change the minimum of number of rows to 2
646
+ num_examples = 2
647
+ statement_params = telemetry.get_function_usage_statement_params(
648
+ project=_PROJECT,
649
+ subproject=_SUBPROJECT,
650
+ function_name=telemetry.get_statement_params_full_func_name(
651
+ inspect.currentframe(), SpectralEmbedding.__class__.__name__
652
+ ),
653
+ api_calls=[Session.call],
654
+ custom_tags={"autogen": True} if self._autogenerated else None,
655
+ )
656
+ if output_cols_prefix == "fit_predict_":
657
+ if hasattr(self._sklearn_object, "n_clusters"):
658
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
659
+ num_examples = self._sklearn_object.n_clusters
660
+ elif hasattr(self._sklearn_object, "min_samples"):
661
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
662
+ num_examples = self._sklearn_object.min_samples
663
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
664
+ # LocalOutlierFactor expects n_neighbors <= n_samples
665
+ num_examples = self._sklearn_object.n_neighbors
666
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
667
+ else:
668
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
631
669
 
632
670
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
633
671
  # seen during the fit.
@@ -639,12 +677,14 @@ class SpectralEmbedding(BaseTransformer):
639
677
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
640
678
  if self.sample_weight_col:
641
679
  output_df_columns_set -= set(self.sample_weight_col)
680
+
642
681
  # if the dimension of inferred output column names is correct; use it
643
682
  if len(expected_output_cols_list) == len(output_df_columns_set):
644
- return expected_output_cols_list
683
+ return expected_output_cols_list, output_df_pd
645
684
  # otherwise, use the sklearn estimator's output
646
685
  else:
647
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
686
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
687
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
648
688
 
649
689
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
650
690
  @telemetry.send_api_usage_telemetry(
@@ -690,7 +730,7 @@ class SpectralEmbedding(BaseTransformer):
690
730
  drop_input_cols=self._drop_input_cols,
691
731
  expected_output_cols_type="float",
692
732
  )
693
- expected_output_cols = self._align_expected_output_names(
733
+ expected_output_cols, _ = self._align_expected_output(
694
734
  inference_method, dataset, expected_output_cols, output_cols_prefix
695
735
  )
696
736
 
@@ -756,7 +796,7 @@ class SpectralEmbedding(BaseTransformer):
756
796
  drop_input_cols=self._drop_input_cols,
757
797
  expected_output_cols_type="float",
758
798
  )
759
- expected_output_cols = self._align_expected_output_names(
799
+ expected_output_cols, _ = self._align_expected_output(
760
800
  inference_method, dataset, expected_output_cols, output_cols_prefix
761
801
  )
762
802
  elif isinstance(dataset, pd.DataFrame):
@@ -819,7 +859,7 @@ class SpectralEmbedding(BaseTransformer):
819
859
  drop_input_cols=self._drop_input_cols,
820
860
  expected_output_cols_type="float",
821
861
  )
822
- expected_output_cols = self._align_expected_output_names(
862
+ expected_output_cols, _ = self._align_expected_output(
823
863
  inference_method, dataset, expected_output_cols, output_cols_prefix
824
864
  )
825
865
 
@@ -884,7 +924,7 @@ class SpectralEmbedding(BaseTransformer):
884
924
  drop_input_cols = self._drop_input_cols,
885
925
  expected_output_cols_type="float",
886
926
  )
887
- expected_output_cols = self._align_expected_output_names(
927
+ expected_output_cols, _ = self._align_expected_output(
888
928
  inference_method, dataset, expected_output_cols, output_cols_prefix
889
929
  )
890
930
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -595,12 +592,23 @@ class TSNE(BaseTransformer):
595
592
  autogenerated=self._autogenerated,
596
593
  subproject=_SUBPROJECT,
597
594
  )
598
- output_result, fitted_estimator = model_trainer.train_fit_predict(
599
- drop_input_cols=self._drop_input_cols,
600
- expected_output_cols_list=(
601
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
602
- ),
595
+ expected_output_cols = (
596
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
603
597
  )
598
+ if isinstance(dataset, DataFrame):
599
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
600
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
601
+ )
602
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
603
+ drop_input_cols=self._drop_input_cols,
604
+ expected_output_cols_list=expected_output_cols,
605
+ example_output_pd_df=example_output_pd_df,
606
+ )
607
+ else:
608
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
609
+ drop_input_cols=self._drop_input_cols,
610
+ expected_output_cols_list=expected_output_cols,
611
+ )
604
612
  self._sklearn_object = fitted_estimator
605
613
  self._is_fitted = True
606
614
  return output_result
@@ -625,6 +633,7 @@ class TSNE(BaseTransformer):
625
633
  """
626
634
  self._infer_input_output_cols(dataset)
627
635
  super()._check_dataset_type(dataset)
636
+
628
637
  model_trainer = ModelTrainerBuilder.build_fit_transform(
629
638
  estimator=self._sklearn_object,
630
639
  dataset=dataset,
@@ -681,12 +690,41 @@ class TSNE(BaseTransformer):
681
690
 
682
691
  return rv
683
692
 
684
- def _align_expected_output_names(
685
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
686
- ) -> List[str]:
693
+ def _align_expected_output(
694
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
695
+ ) -> Tuple[List[str], pd.DataFrame]:
696
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
697
+ and output dataframe with 1 line.
698
+ If the method is fit_predict, run 2 lines of data.
699
+ """
687
700
  # in case the inferred output column names dimension is different
688
701
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
689
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
702
+
703
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
704
+ # so change the minimum of number of rows to 2
705
+ num_examples = 2
706
+ statement_params = telemetry.get_function_usage_statement_params(
707
+ project=_PROJECT,
708
+ subproject=_SUBPROJECT,
709
+ function_name=telemetry.get_statement_params_full_func_name(
710
+ inspect.currentframe(), TSNE.__class__.__name__
711
+ ),
712
+ api_calls=[Session.call],
713
+ custom_tags={"autogen": True} if self._autogenerated else None,
714
+ )
715
+ if output_cols_prefix == "fit_predict_":
716
+ if hasattr(self._sklearn_object, "n_clusters"):
717
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
718
+ num_examples = self._sklearn_object.n_clusters
719
+ elif hasattr(self._sklearn_object, "min_samples"):
720
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
721
+ num_examples = self._sklearn_object.min_samples
722
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
723
+ # LocalOutlierFactor expects n_neighbors <= n_samples
724
+ num_examples = self._sklearn_object.n_neighbors
725
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
726
+ else:
727
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
690
728
 
691
729
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
692
730
  # seen during the fit.
@@ -698,12 +736,14 @@ class TSNE(BaseTransformer):
698
736
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
699
737
  if self.sample_weight_col:
700
738
  output_df_columns_set -= set(self.sample_weight_col)
739
+
701
740
  # if the dimension of inferred output column names is correct; use it
702
741
  if len(expected_output_cols_list) == len(output_df_columns_set):
703
- return expected_output_cols_list
742
+ return expected_output_cols_list, output_df_pd
704
743
  # otherwise, use the sklearn estimator's output
705
744
  else:
706
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
745
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
746
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
707
747
 
708
748
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
709
749
  @telemetry.send_api_usage_telemetry(
@@ -749,7 +789,7 @@ class TSNE(BaseTransformer):
749
789
  drop_input_cols=self._drop_input_cols,
750
790
  expected_output_cols_type="float",
751
791
  )
752
- expected_output_cols = self._align_expected_output_names(
792
+ expected_output_cols, _ = self._align_expected_output(
753
793
  inference_method, dataset, expected_output_cols, output_cols_prefix
754
794
  )
755
795
 
@@ -815,7 +855,7 @@ class TSNE(BaseTransformer):
815
855
  drop_input_cols=self._drop_input_cols,
816
856
  expected_output_cols_type="float",
817
857
  )
818
- expected_output_cols = self._align_expected_output_names(
858
+ expected_output_cols, _ = self._align_expected_output(
819
859
  inference_method, dataset, expected_output_cols, output_cols_prefix
820
860
  )
821
861
  elif isinstance(dataset, pd.DataFrame):
@@ -878,7 +918,7 @@ class TSNE(BaseTransformer):
878
918
  drop_input_cols=self._drop_input_cols,
879
919
  expected_output_cols_type="float",
880
920
  )
881
- expected_output_cols = self._align_expected_output_names(
921
+ expected_output_cols, _ = self._align_expected_output(
882
922
  inference_method, dataset, expected_output_cols, output_cols_prefix
883
923
  )
884
924
 
@@ -943,7 +983,7 @@ class TSNE(BaseTransformer):
943
983
  drop_input_cols = self._drop_input_cols,
944
984
  expected_output_cols_type="float",
945
985
  )
946
- expected_output_cols = self._align_expected_output_names(
986
+ expected_output_cols, _ = self._align_expected_output(
947
987
  inference_method, dataset, expected_output_cols, output_cols_prefix
948
988
  )
949
989