snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
Files changed (284)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -563,12 +560,23 @@ class SGDOneClassSVM(BaseTransformer):
563
560
  autogenerated=self._autogenerated,
564
561
  subproject=_SUBPROJECT,
565
562
  )
566
- output_result, fitted_estimator = model_trainer.train_fit_predict(
567
- drop_input_cols=self._drop_input_cols,
568
- expected_output_cols_list=(
569
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
570
- ),
563
+ expected_output_cols = (
564
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
571
565
  )
566
+ if isinstance(dataset, DataFrame):
567
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
568
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
569
+ )
570
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
571
+ drop_input_cols=self._drop_input_cols,
572
+ expected_output_cols_list=expected_output_cols,
573
+ example_output_pd_df=example_output_pd_df,
574
+ )
575
+ else:
576
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
577
+ drop_input_cols=self._drop_input_cols,
578
+ expected_output_cols_list=expected_output_cols,
579
+ )
572
580
  self._sklearn_object = fitted_estimator
573
581
  self._is_fitted = True
574
582
  return output_result
@@ -591,6 +599,7 @@ class SGDOneClassSVM(BaseTransformer):
591
599
  """
592
600
  self._infer_input_output_cols(dataset)
593
601
  super()._check_dataset_type(dataset)
602
+
594
603
  model_trainer = ModelTrainerBuilder.build_fit_transform(
595
604
  estimator=self._sklearn_object,
596
605
  dataset=dataset,
@@ -647,12 +656,41 @@ class SGDOneClassSVM(BaseTransformer):
647
656
 
648
657
  return rv
649
658
 
650
- def _align_expected_output_names(
651
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
652
- ) -> List[str]:
659
+ def _align_expected_output(
660
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
661
+ ) -> Tuple[List[str], pd.DataFrame]:
662
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
663
+ and output dataframe with 1 line.
664
+ If the method is fit_predict, run 2 lines of data.
665
+ """
653
666
  # in case the inferred output column names dimension is different
654
667
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
655
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
668
+
669
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
670
+ # so change the minimum of number of rows to 2
671
+ num_examples = 2
672
+ statement_params = telemetry.get_function_usage_statement_params(
673
+ project=_PROJECT,
674
+ subproject=_SUBPROJECT,
675
+ function_name=telemetry.get_statement_params_full_func_name(
676
+ inspect.currentframe(), SGDOneClassSVM.__class__.__name__
677
+ ),
678
+ api_calls=[Session.call],
679
+ custom_tags={"autogen": True} if self._autogenerated else None,
680
+ )
681
+ if output_cols_prefix == "fit_predict_":
682
+ if hasattr(self._sklearn_object, "n_clusters"):
683
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
684
+ num_examples = self._sklearn_object.n_clusters
685
+ elif hasattr(self._sklearn_object, "min_samples"):
686
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
687
+ num_examples = self._sklearn_object.min_samples
688
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
689
+ # LocalOutlierFactor expects n_neighbors <= n_samples
690
+ num_examples = self._sklearn_object.n_neighbors
691
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
692
+ else:
693
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
656
694
 
657
695
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
658
696
  # seen during the fit.
@@ -664,12 +702,14 @@ class SGDOneClassSVM(BaseTransformer):
664
702
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
665
703
  if self.sample_weight_col:
666
704
  output_df_columns_set -= set(self.sample_weight_col)
705
+
667
706
  # if the dimension of inferred output column names is correct; use it
668
707
  if len(expected_output_cols_list) == len(output_df_columns_set):
669
- return expected_output_cols_list
708
+ return expected_output_cols_list, output_df_pd
670
709
  # otherwise, use the sklearn estimator's output
671
710
  else:
672
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
711
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
712
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
673
713
 
674
714
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
675
715
  @telemetry.send_api_usage_telemetry(
@@ -715,7 +755,7 @@ class SGDOneClassSVM(BaseTransformer):
715
755
  drop_input_cols=self._drop_input_cols,
716
756
  expected_output_cols_type="float",
717
757
  )
718
- expected_output_cols = self._align_expected_output_names(
758
+ expected_output_cols, _ = self._align_expected_output(
719
759
  inference_method, dataset, expected_output_cols, output_cols_prefix
720
760
  )
721
761
 
@@ -781,7 +821,7 @@ class SGDOneClassSVM(BaseTransformer):
781
821
  drop_input_cols=self._drop_input_cols,
782
822
  expected_output_cols_type="float",
783
823
  )
784
- expected_output_cols = self._align_expected_output_names(
824
+ expected_output_cols, _ = self._align_expected_output(
785
825
  inference_method, dataset, expected_output_cols, output_cols_prefix
786
826
  )
787
827
  elif isinstance(dataset, pd.DataFrame):
@@ -846,7 +886,7 @@ class SGDOneClassSVM(BaseTransformer):
846
886
  drop_input_cols=self._drop_input_cols,
847
887
  expected_output_cols_type="float",
848
888
  )
849
- expected_output_cols = self._align_expected_output_names(
889
+ expected_output_cols, _ = self._align_expected_output(
850
890
  inference_method, dataset, expected_output_cols, output_cols_prefix
851
891
  )
852
892
 
@@ -913,7 +953,7 @@ class SGDOneClassSVM(BaseTransformer):
913
953
  drop_input_cols = self._drop_input_cols,
914
954
  expected_output_cols_type="float",
915
955
  )
916
- expected_output_cols = self._align_expected_output_names(
956
+ expected_output_cols, _ = self._align_expected_output(
917
957
  inference_method, dataset, expected_output_cols, output_cols_prefix
918
958
  )
919
959
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -627,12 +624,23 @@ class SGDRegressor(BaseTransformer):
627
624
  autogenerated=self._autogenerated,
628
625
  subproject=_SUBPROJECT,
629
626
  )
630
- output_result, fitted_estimator = model_trainer.train_fit_predict(
631
- drop_input_cols=self._drop_input_cols,
632
- expected_output_cols_list=(
633
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
634
- ),
627
+ expected_output_cols = (
628
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
635
629
  )
630
+ if isinstance(dataset, DataFrame):
631
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
632
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
633
+ )
634
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
635
+ drop_input_cols=self._drop_input_cols,
636
+ expected_output_cols_list=expected_output_cols,
637
+ example_output_pd_df=example_output_pd_df,
638
+ )
639
+ else:
640
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
641
+ drop_input_cols=self._drop_input_cols,
642
+ expected_output_cols_list=expected_output_cols,
643
+ )
636
644
  self._sklearn_object = fitted_estimator
637
645
  self._is_fitted = True
638
646
  return output_result
@@ -655,6 +663,7 @@ class SGDRegressor(BaseTransformer):
655
663
  """
656
664
  self._infer_input_output_cols(dataset)
657
665
  super()._check_dataset_type(dataset)
666
+
658
667
  model_trainer = ModelTrainerBuilder.build_fit_transform(
659
668
  estimator=self._sklearn_object,
660
669
  dataset=dataset,
@@ -711,12 +720,41 @@ class SGDRegressor(BaseTransformer):
711
720
 
712
721
  return rv
713
722
 
714
- def _align_expected_output_names(
715
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
716
- ) -> List[str]:
723
+ def _align_expected_output(
724
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
725
+ ) -> Tuple[List[str], pd.DataFrame]:
726
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
727
+ and output dataframe with 1 line.
728
+ If the method is fit_predict, run 2 lines of data.
729
+ """
717
730
  # in case the inferred output column names dimension is different
718
731
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
719
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
732
+
733
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
734
+ # so change the minimum of number of rows to 2
735
+ num_examples = 2
736
+ statement_params = telemetry.get_function_usage_statement_params(
737
+ project=_PROJECT,
738
+ subproject=_SUBPROJECT,
739
+ function_name=telemetry.get_statement_params_full_func_name(
740
+ inspect.currentframe(), SGDRegressor.__class__.__name__
741
+ ),
742
+ api_calls=[Session.call],
743
+ custom_tags={"autogen": True} if self._autogenerated else None,
744
+ )
745
+ if output_cols_prefix == "fit_predict_":
746
+ if hasattr(self._sklearn_object, "n_clusters"):
747
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
748
+ num_examples = self._sklearn_object.n_clusters
749
+ elif hasattr(self._sklearn_object, "min_samples"):
750
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
751
+ num_examples = self._sklearn_object.min_samples
752
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
753
+ # LocalOutlierFactor expects n_neighbors <= n_samples
754
+ num_examples = self._sklearn_object.n_neighbors
755
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
756
+ else:
757
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
720
758
 
721
759
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
722
760
  # seen during the fit.
@@ -728,12 +766,14 @@ class SGDRegressor(BaseTransformer):
728
766
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
729
767
  if self.sample_weight_col:
730
768
  output_df_columns_set -= set(self.sample_weight_col)
769
+
731
770
  # if the dimension of inferred output column names is correct; use it
732
771
  if len(expected_output_cols_list) == len(output_df_columns_set):
733
- return expected_output_cols_list
772
+ return expected_output_cols_list, output_df_pd
734
773
  # otherwise, use the sklearn estimator's output
735
774
  else:
736
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
775
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
776
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
737
777
 
738
778
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
739
779
  @telemetry.send_api_usage_telemetry(
@@ -779,7 +819,7 @@ class SGDRegressor(BaseTransformer):
779
819
  drop_input_cols=self._drop_input_cols,
780
820
  expected_output_cols_type="float",
781
821
  )
782
- expected_output_cols = self._align_expected_output_names(
822
+ expected_output_cols, _ = self._align_expected_output(
783
823
  inference_method, dataset, expected_output_cols, output_cols_prefix
784
824
  )
785
825
 
@@ -845,7 +885,7 @@ class SGDRegressor(BaseTransformer):
845
885
  drop_input_cols=self._drop_input_cols,
846
886
  expected_output_cols_type="float",
847
887
  )
848
- expected_output_cols = self._align_expected_output_names(
888
+ expected_output_cols, _ = self._align_expected_output(
849
889
  inference_method, dataset, expected_output_cols, output_cols_prefix
850
890
  )
851
891
  elif isinstance(dataset, pd.DataFrame):
@@ -908,7 +948,7 @@ class SGDRegressor(BaseTransformer):
908
948
  drop_input_cols=self._drop_input_cols,
909
949
  expected_output_cols_type="float",
910
950
  )
911
- expected_output_cols = self._align_expected_output_names(
951
+ expected_output_cols, _ = self._align_expected_output(
912
952
  inference_method, dataset, expected_output_cols, output_cols_prefix
913
953
  )
914
954
 
@@ -973,7 +1013,7 @@ class SGDRegressor(BaseTransformer):
973
1013
  drop_input_cols = self._drop_input_cols,
974
1014
  expected_output_cols_type="float",
975
1015
  )
976
- expected_output_cols = self._align_expected_output_names(
1016
+ expected_output_cols, _ = self._align_expected_output(
977
1017
  inference_method, dataset, expected_output_cols, output_cols_prefix
978
1018
  )
979
1019
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -529,12 +526,23 @@ class TheilSenRegressor(BaseTransformer):
529
526
  autogenerated=self._autogenerated,
530
527
  subproject=_SUBPROJECT,
531
528
  )
532
- output_result, fitted_estimator = model_trainer.train_fit_predict(
533
- drop_input_cols=self._drop_input_cols,
534
- expected_output_cols_list=(
535
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
536
- ),
529
+ expected_output_cols = (
530
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
537
531
  )
532
+ if isinstance(dataset, DataFrame):
533
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
534
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
535
+ )
536
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
537
+ drop_input_cols=self._drop_input_cols,
538
+ expected_output_cols_list=expected_output_cols,
539
+ example_output_pd_df=example_output_pd_df,
540
+ )
541
+ else:
542
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
543
+ drop_input_cols=self._drop_input_cols,
544
+ expected_output_cols_list=expected_output_cols,
545
+ )
538
546
  self._sklearn_object = fitted_estimator
539
547
  self._is_fitted = True
540
548
  return output_result
@@ -557,6 +565,7 @@ class TheilSenRegressor(BaseTransformer):
557
565
  """
558
566
  self._infer_input_output_cols(dataset)
559
567
  super()._check_dataset_type(dataset)
568
+
560
569
  model_trainer = ModelTrainerBuilder.build_fit_transform(
561
570
  estimator=self._sklearn_object,
562
571
  dataset=dataset,
@@ -613,12 +622,41 @@ class TheilSenRegressor(BaseTransformer):
613
622
 
614
623
  return rv
615
624
 
616
- def _align_expected_output_names(
617
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
618
- ) -> List[str]:
625
+ def _align_expected_output(
626
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
627
+ ) -> Tuple[List[str], pd.DataFrame]:
628
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
629
+ and output dataframe with 1 line.
630
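+ If the method is fit_predict, run 2 lines of data by default, or as many lines as the estimator's n_clusters, min_samples, or n_neighbors (when n_samples is also present) value when it defines one.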
631
+ """
619
632
  # in case the inferred output column names dimension is different
620
633
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
621
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
634
+
635
+ # For the fit_predict method, MinCovDet and BayesianGaussianMixture require at least 2 rows,
636
+ # so default the number of sample rows to 2.
637
+ num_examples = 2
638
+ statement_params = telemetry.get_function_usage_statement_params(
639
+ project=_PROJECT,
640
+ subproject=_SUBPROJECT,
641
+ function_name=telemetry.get_statement_params_full_func_name(
642
+ inspect.currentframe(), TheilSenRegressor.__class__.__name__
643
+ ),
644
+ api_calls=[Session.call],
645
+ custom_tags={"autogen": True} if self._autogenerated else None,
646
+ )
647
+ if output_cols_prefix == "fit_predict_":
648
+ if hasattr(self._sklearn_object, "n_clusters"):
649
+ # cluster classes such as BisectingKMeans require the number of examples to be >= n_clusters
650
+ num_examples = self._sklearn_object.n_clusters
651
+ elif hasattr(self._sklearn_object, "min_samples"):
652
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
653
+ num_examples = self._sklearn_object.min_samples
654
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
655
+ # LocalOutlierFactor expects n_neighbors <= n_samples
656
+ num_examples = self._sklearn_object.n_neighbors
657
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
658
+ else:
659
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
622
660
 
623
661
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
624
662
  # seen during the fit.
@@ -630,12 +668,14 @@ class TheilSenRegressor(BaseTransformer):
630
668
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
631
669
  if self.sample_weight_col:
632
670
  output_df_columns_set -= set(self.sample_weight_col)
671
+
633
672
  # if the dimension of inferred output column names is correct; use it
634
673
  if len(expected_output_cols_list) == len(output_df_columns_set):
635
- return expected_output_cols_list
674
+ return expected_output_cols_list, output_df_pd
636
675
  # otherwise, use the sklearn estimator's output
637
676
  else:
638
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
677
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
678
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
639
679
 
640
680
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
641
681
  @telemetry.send_api_usage_telemetry(
@@ -681,7 +721,7 @@ class TheilSenRegressor(BaseTransformer):
681
721
  drop_input_cols=self._drop_input_cols,
682
722
  expected_output_cols_type="float",
683
723
  )
684
- expected_output_cols = self._align_expected_output_names(
724
+ expected_output_cols, _ = self._align_expected_output(
685
725
  inference_method, dataset, expected_output_cols, output_cols_prefix
686
726
  )
687
727
 
@@ -747,7 +787,7 @@ class TheilSenRegressor(BaseTransformer):
747
787
  drop_input_cols=self._drop_input_cols,
748
788
  expected_output_cols_type="float",
749
789
  )
750
- expected_output_cols = self._align_expected_output_names(
790
+ expected_output_cols, _ = self._align_expected_output(
751
791
  inference_method, dataset, expected_output_cols, output_cols_prefix
752
792
  )
753
793
  elif isinstance(dataset, pd.DataFrame):
@@ -810,7 +850,7 @@ class TheilSenRegressor(BaseTransformer):
810
850
  drop_input_cols=self._drop_input_cols,
811
851
  expected_output_cols_type="float",
812
852
  )
813
- expected_output_cols = self._align_expected_output_names(
853
+ expected_output_cols, _ = self._align_expected_output(
814
854
  inference_method, dataset, expected_output_cols, output_cols_prefix
815
855
  )
816
856
 
@@ -875,7 +915,7 @@ class TheilSenRegressor(BaseTransformer):
875
915
  drop_input_cols = self._drop_input_cols,
876
916
  expected_output_cols_type="float",
877
917
  )
878
- expected_output_cols = self._align_expected_output_names(
918
+ expected_output_cols, _ = self._align_expected_output(
879
919
  inference_method, dataset, expected_output_cols, output_cols_prefix
880
920
  )
881
921
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -555,12 +552,23 @@ class TweedieRegressor(BaseTransformer):
555
552
  autogenerated=self._autogenerated,
556
553
  subproject=_SUBPROJECT,
557
554
  )
558
- output_result, fitted_estimator = model_trainer.train_fit_predict(
559
- drop_input_cols=self._drop_input_cols,
560
- expected_output_cols_list=(
561
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
562
- ),
555
+ expected_output_cols = (
556
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
563
557
  )
558
+ if isinstance(dataset, DataFrame):
559
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
560
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
561
+ )
562
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
563
+ drop_input_cols=self._drop_input_cols,
564
+ expected_output_cols_list=expected_output_cols,
565
+ example_output_pd_df=example_output_pd_df,
566
+ )
567
+ else:
568
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
569
+ drop_input_cols=self._drop_input_cols,
570
+ expected_output_cols_list=expected_output_cols,
571
+ )
564
572
  self._sklearn_object = fitted_estimator
565
573
  self._is_fitted = True
566
574
  return output_result
@@ -583,6 +591,7 @@ class TweedieRegressor(BaseTransformer):
583
591
  """
584
592
  self._infer_input_output_cols(dataset)
585
593
  super()._check_dataset_type(dataset)
594
+
586
595
  model_trainer = ModelTrainerBuilder.build_fit_transform(
587
596
  estimator=self._sklearn_object,
588
597
  dataset=dataset,
@@ -639,12 +648,41 @@ class TweedieRegressor(BaseTransformer):
639
648
 
640
649
  return rv
641
650
 
642
- def _align_expected_output_names(
643
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
644
- ) -> List[str]:
651
+ def _align_expected_output(
652
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
653
+ ) -> Tuple[List[str], pd.DataFrame]:
654
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
655
+ and output dataframe with 1 line.
656
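+ If the method is fit_predict, run 2 lines of data by default, or as many lines as the estimator's n_clusters, min_samples, or n_neighbors (when n_samples is also present) value when it defines one.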
657
+ """
645
658
  # in case the inferred output column names dimension is different
646
659
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
647
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
660
+
661
+ # For the fit_predict method, MinCovDet and BayesianGaussianMixture require at least 2 rows,
662
+ # so default the number of sample rows to 2.
663
+ num_examples = 2
664
+ statement_params = telemetry.get_function_usage_statement_params(
665
+ project=_PROJECT,
666
+ subproject=_SUBPROJECT,
667
+ function_name=telemetry.get_statement_params_full_func_name(
668
+ inspect.currentframe(), TweedieRegressor.__class__.__name__
669
+ ),
670
+ api_calls=[Session.call],
671
+ custom_tags={"autogen": True} if self._autogenerated else None,
672
+ )
673
+ if output_cols_prefix == "fit_predict_":
674
+ if hasattr(self._sklearn_object, "n_clusters"):
675
+ # cluster classes such as BisectingKMeans require the number of examples to be >= n_clusters
676
+ num_examples = self._sklearn_object.n_clusters
677
+ elif hasattr(self._sklearn_object, "min_samples"):
678
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
679
+ num_examples = self._sklearn_object.min_samples
680
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
681
+ # LocalOutlierFactor expects n_neighbors <= n_samples
682
+ num_examples = self._sklearn_object.n_neighbors
683
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
684
+ else:
685
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
648
686
 
649
687
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
650
688
  # seen during the fit.
@@ -656,12 +694,14 @@ class TweedieRegressor(BaseTransformer):
656
694
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
657
695
  if self.sample_weight_col:
658
696
  output_df_columns_set -= set(self.sample_weight_col)
697
+
659
698
  # if the dimension of inferred output column names is correct; use it
660
699
  if len(expected_output_cols_list) == len(output_df_columns_set):
661
- return expected_output_cols_list
700
+ return expected_output_cols_list, output_df_pd
662
701
  # otherwise, use the sklearn estimator's output
663
702
  else:
664
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
703
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
704
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
665
705
 
666
706
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
667
707
  @telemetry.send_api_usage_telemetry(
@@ -707,7 +747,7 @@ class TweedieRegressor(BaseTransformer):
707
747
  drop_input_cols=self._drop_input_cols,
708
748
  expected_output_cols_type="float",
709
749
  )
710
- expected_output_cols = self._align_expected_output_names(
750
+ expected_output_cols, _ = self._align_expected_output(
711
751
  inference_method, dataset, expected_output_cols, output_cols_prefix
712
752
  )
713
753
 
@@ -773,7 +813,7 @@ class TweedieRegressor(BaseTransformer):
773
813
  drop_input_cols=self._drop_input_cols,
774
814
  expected_output_cols_type="float",
775
815
  )
776
- expected_output_cols = self._align_expected_output_names(
816
+ expected_output_cols, _ = self._align_expected_output(
777
817
  inference_method, dataset, expected_output_cols, output_cols_prefix
778
818
  )
779
819
  elif isinstance(dataset, pd.DataFrame):
@@ -836,7 +876,7 @@ class TweedieRegressor(BaseTransformer):
836
876
  drop_input_cols=self._drop_input_cols,
837
877
  expected_output_cols_type="float",
838
878
  )
839
- expected_output_cols = self._align_expected_output_names(
879
+ expected_output_cols, _ = self._align_expected_output(
840
880
  inference_method, dataset, expected_output_cols, output_cols_prefix
841
881
  )
842
882
 
@@ -901,7 +941,7 @@ class TweedieRegressor(BaseTransformer):
901
941
  drop_input_cols = self._drop_input_cols,
902
942
  expected_output_cols_type="float",
903
943
  )
904
- expected_output_cols = self._align_expected_output_names(
944
+ expected_output_cols, _ = self._align_expected_output(
905
945
  inference_method, dataset, expected_output_cols, output_cols_prefix
906
946
  )
907
947