snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284)
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -552,12 +549,23 @@ class MultiTaskLassoCV(BaseTransformer):
552
549
  autogenerated=self._autogenerated,
553
550
  subproject=_SUBPROJECT,
554
551
  )
555
- output_result, fitted_estimator = model_trainer.train_fit_predict(
556
- drop_input_cols=self._drop_input_cols,
557
- expected_output_cols_list=(
558
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
559
- ),
552
+ expected_output_cols = (
553
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
560
554
  )
555
+ if isinstance(dataset, DataFrame):
556
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
557
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
558
+ )
559
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
560
+ drop_input_cols=self._drop_input_cols,
561
+ expected_output_cols_list=expected_output_cols,
562
+ example_output_pd_df=example_output_pd_df,
563
+ )
564
+ else:
565
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
566
+ drop_input_cols=self._drop_input_cols,
567
+ expected_output_cols_list=expected_output_cols,
568
+ )
561
569
  self._sklearn_object = fitted_estimator
562
570
  self._is_fitted = True
563
571
  return output_result
@@ -580,6 +588,7 @@ class MultiTaskLassoCV(BaseTransformer):
580
588
  """
581
589
  self._infer_input_output_cols(dataset)
582
590
  super()._check_dataset_type(dataset)
591
+
583
592
  model_trainer = ModelTrainerBuilder.build_fit_transform(
584
593
  estimator=self._sklearn_object,
585
594
  dataset=dataset,
@@ -636,12 +645,41 @@ class MultiTaskLassoCV(BaseTransformer):
636
645
 
637
646
  return rv
638
647
 
639
- def _align_expected_output_names(
640
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
641
- ) -> List[str]:
648
+ def _align_expected_output(
649
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
650
+ ) -> Tuple[List[str], pd.DataFrame]:
651
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
652
+ and output dataframe with 1 line.
653
+ If the method is fit_predict, run 2 lines of data.
654
+ """
642
655
  # in case the inferred output column names dimension is different
643
656
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
644
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
657
+
658
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
659
+ # so change the minimum of number of rows to 2
660
+ num_examples = 2
661
+ statement_params = telemetry.get_function_usage_statement_params(
662
+ project=_PROJECT,
663
+ subproject=_SUBPROJECT,
664
+ function_name=telemetry.get_statement_params_full_func_name(
665
+ inspect.currentframe(), MultiTaskLassoCV.__class__.__name__
666
+ ),
667
+ api_calls=[Session.call],
668
+ custom_tags={"autogen": True} if self._autogenerated else None,
669
+ )
670
+ if output_cols_prefix == "fit_predict_":
671
+ if hasattr(self._sklearn_object, "n_clusters"):
672
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
673
+ num_examples = self._sklearn_object.n_clusters
674
+ elif hasattr(self._sklearn_object, "min_samples"):
675
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
676
+ num_examples = self._sklearn_object.min_samples
677
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
678
+ # LocalOutlierFactor expects n_neighbors <= n_samples
679
+ num_examples = self._sklearn_object.n_neighbors
680
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
681
+ else:
682
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
645
683
 
646
684
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
647
685
  # seen during the fit.
@@ -653,12 +691,14 @@ class MultiTaskLassoCV(BaseTransformer):
653
691
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
654
692
  if self.sample_weight_col:
655
693
  output_df_columns_set -= set(self.sample_weight_col)
694
+
656
695
  # if the dimension of inferred output column names is correct; use it
657
696
  if len(expected_output_cols_list) == len(output_df_columns_set):
658
- return expected_output_cols_list
697
+ return expected_output_cols_list, output_df_pd
659
698
  # otherwise, use the sklearn estimator's output
660
699
  else:
661
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
700
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
701
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
662
702
 
663
703
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
664
704
  @telemetry.send_api_usage_telemetry(
@@ -704,7 +744,7 @@ class MultiTaskLassoCV(BaseTransformer):
704
744
  drop_input_cols=self._drop_input_cols,
705
745
  expected_output_cols_type="float",
706
746
  )
707
- expected_output_cols = self._align_expected_output_names(
747
+ expected_output_cols, _ = self._align_expected_output(
708
748
  inference_method, dataset, expected_output_cols, output_cols_prefix
709
749
  )
710
750
 
@@ -770,7 +810,7 @@ class MultiTaskLassoCV(BaseTransformer):
770
810
  drop_input_cols=self._drop_input_cols,
771
811
  expected_output_cols_type="float",
772
812
  )
773
- expected_output_cols = self._align_expected_output_names(
813
+ expected_output_cols, _ = self._align_expected_output(
774
814
  inference_method, dataset, expected_output_cols, output_cols_prefix
775
815
  )
776
816
  elif isinstance(dataset, pd.DataFrame):
@@ -833,7 +873,7 @@ class MultiTaskLassoCV(BaseTransformer):
833
873
  drop_input_cols=self._drop_input_cols,
834
874
  expected_output_cols_type="float",
835
875
  )
836
- expected_output_cols = self._align_expected_output_names(
876
+ expected_output_cols, _ = self._align_expected_output(
837
877
  inference_method, dataset, expected_output_cols, output_cols_prefix
838
878
  )
839
879
 
@@ -898,7 +938,7 @@ class MultiTaskLassoCV(BaseTransformer):
898
938
  drop_input_cols = self._drop_input_cols,
899
939
  expected_output_cols_type="float",
900
940
  )
901
- expected_output_cols = self._align_expected_output_names(
941
+ expected_output_cols, _ = self._align_expected_output(
902
942
  inference_method, dataset, expected_output_cols, output_cols_prefix
903
943
  )
904
944
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -500,12 +497,23 @@ class OrthogonalMatchingPursuit(BaseTransformer):
500
497
  autogenerated=self._autogenerated,
501
498
  subproject=_SUBPROJECT,
502
499
  )
503
- output_result, fitted_estimator = model_trainer.train_fit_predict(
504
- drop_input_cols=self._drop_input_cols,
505
- expected_output_cols_list=(
506
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
507
- ),
500
+ expected_output_cols = (
501
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
508
502
  )
503
+ if isinstance(dataset, DataFrame):
504
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
505
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
506
+ )
507
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
508
+ drop_input_cols=self._drop_input_cols,
509
+ expected_output_cols_list=expected_output_cols,
510
+ example_output_pd_df=example_output_pd_df,
511
+ )
512
+ else:
513
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
514
+ drop_input_cols=self._drop_input_cols,
515
+ expected_output_cols_list=expected_output_cols,
516
+ )
509
517
  self._sklearn_object = fitted_estimator
510
518
  self._is_fitted = True
511
519
  return output_result
@@ -528,6 +536,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
528
536
  """
529
537
  self._infer_input_output_cols(dataset)
530
538
  super()._check_dataset_type(dataset)
539
+
531
540
  model_trainer = ModelTrainerBuilder.build_fit_transform(
532
541
  estimator=self._sklearn_object,
533
542
  dataset=dataset,
@@ -584,12 +593,41 @@ class OrthogonalMatchingPursuit(BaseTransformer):
584
593
 
585
594
  return rv
586
595
 
587
- def _align_expected_output_names(
588
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
589
- ) -> List[str]:
596
+ def _align_expected_output(
597
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
598
+ ) -> Tuple[List[str], pd.DataFrame]:
599
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
600
+ and output dataframe with 1 line.
601
+ If the method is fit_predict, run 2 lines of data.
602
+ """
590
603
  # in case the inferred output column names dimension is different
591
604
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
592
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
605
+
606
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
607
+ # so change the minimum of number of rows to 2
608
+ num_examples = 2
609
+ statement_params = telemetry.get_function_usage_statement_params(
610
+ project=_PROJECT,
611
+ subproject=_SUBPROJECT,
612
+ function_name=telemetry.get_statement_params_full_func_name(
613
+ inspect.currentframe(), OrthogonalMatchingPursuit.__class__.__name__
614
+ ),
615
+ api_calls=[Session.call],
616
+ custom_tags={"autogen": True} if self._autogenerated else None,
617
+ )
618
+ if output_cols_prefix == "fit_predict_":
619
+ if hasattr(self._sklearn_object, "n_clusters"):
620
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
621
+ num_examples = self._sklearn_object.n_clusters
622
+ elif hasattr(self._sklearn_object, "min_samples"):
623
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
624
+ num_examples = self._sklearn_object.min_samples
625
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
626
+ # LocalOutlierFactor expects n_neighbors <= n_samples
627
+ num_examples = self._sklearn_object.n_neighbors
628
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
629
+ else:
630
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
593
631
 
594
632
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
595
633
  # seen during the fit.
@@ -601,12 +639,14 @@ class OrthogonalMatchingPursuit(BaseTransformer):
601
639
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
602
640
  if self.sample_weight_col:
603
641
  output_df_columns_set -= set(self.sample_weight_col)
642
+
604
643
  # if the dimension of inferred output column names is correct; use it
605
644
  if len(expected_output_cols_list) == len(output_df_columns_set):
606
- return expected_output_cols_list
645
+ return expected_output_cols_list, output_df_pd
607
646
  # otherwise, use the sklearn estimator's output
608
647
  else:
609
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
648
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
649
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
610
650
 
611
651
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
612
652
  @telemetry.send_api_usage_telemetry(
@@ -652,7 +692,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
652
692
  drop_input_cols=self._drop_input_cols,
653
693
  expected_output_cols_type="float",
654
694
  )
655
- expected_output_cols = self._align_expected_output_names(
695
+ expected_output_cols, _ = self._align_expected_output(
656
696
  inference_method, dataset, expected_output_cols, output_cols_prefix
657
697
  )
658
698
 
@@ -718,7 +758,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
718
758
  drop_input_cols=self._drop_input_cols,
719
759
  expected_output_cols_type="float",
720
760
  )
721
- expected_output_cols = self._align_expected_output_names(
761
+ expected_output_cols, _ = self._align_expected_output(
722
762
  inference_method, dataset, expected_output_cols, output_cols_prefix
723
763
  )
724
764
  elif isinstance(dataset, pd.DataFrame):
@@ -781,7 +821,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
781
821
  drop_input_cols=self._drop_input_cols,
782
822
  expected_output_cols_type="float",
783
823
  )
784
- expected_output_cols = self._align_expected_output_names(
824
+ expected_output_cols, _ = self._align_expected_output(
785
825
  inference_method, dataset, expected_output_cols, output_cols_prefix
786
826
  )
787
827
 
@@ -846,7 +886,7 @@ class OrthogonalMatchingPursuit(BaseTransformer):
846
886
  drop_input_cols = self._drop_input_cols,
847
887
  expected_output_cols_type="float",
848
888
  )
849
- expected_output_cols = self._align_expected_output_names(
889
+ expected_output_cols, _ = self._align_expected_output(
850
890
  inference_method, dataset, expected_output_cols, output_cols_prefix
851
891
  )
852
892
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -574,12 +571,23 @@ class PassiveAggressiveClassifier(BaseTransformer):
574
571
  autogenerated=self._autogenerated,
575
572
  subproject=_SUBPROJECT,
576
573
  )
577
- output_result, fitted_estimator = model_trainer.train_fit_predict(
578
- drop_input_cols=self._drop_input_cols,
579
- expected_output_cols_list=(
580
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
581
- ),
574
+ expected_output_cols = (
575
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
582
576
  )
577
+ if isinstance(dataset, DataFrame):
578
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
579
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
580
+ )
581
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
582
+ drop_input_cols=self._drop_input_cols,
583
+ expected_output_cols_list=expected_output_cols,
584
+ example_output_pd_df=example_output_pd_df,
585
+ )
586
+ else:
587
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
588
+ drop_input_cols=self._drop_input_cols,
589
+ expected_output_cols_list=expected_output_cols,
590
+ )
583
591
  self._sklearn_object = fitted_estimator
584
592
  self._is_fitted = True
585
593
  return output_result
@@ -602,6 +610,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
602
610
  """
603
611
  self._infer_input_output_cols(dataset)
604
612
  super()._check_dataset_type(dataset)
613
+
605
614
  model_trainer = ModelTrainerBuilder.build_fit_transform(
606
615
  estimator=self._sklearn_object,
607
616
  dataset=dataset,
@@ -658,12 +667,41 @@ class PassiveAggressiveClassifier(BaseTransformer):
658
667
 
659
668
  return rv
660
669
 
661
- def _align_expected_output_names(
662
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
663
- ) -> List[str]:
670
+ def _align_expected_output(
671
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
672
+ ) -> Tuple[List[str], pd.DataFrame]:
673
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
674
+ and output dataframe with 1 line.
675
+ If the method is fit_predict, run 2 lines of data.
676
+ """
664
677
  # in case the inferred output column names dimension is different
665
678
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
666
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
679
+
680
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
681
+ # so change the minimum of number of rows to 2
682
+ num_examples = 2
683
+ statement_params = telemetry.get_function_usage_statement_params(
684
+ project=_PROJECT,
685
+ subproject=_SUBPROJECT,
686
+ function_name=telemetry.get_statement_params_full_func_name(
687
+ inspect.currentframe(), PassiveAggressiveClassifier.__class__.__name__
688
+ ),
689
+ api_calls=[Session.call],
690
+ custom_tags={"autogen": True} if self._autogenerated else None,
691
+ )
692
+ if output_cols_prefix == "fit_predict_":
693
+ if hasattr(self._sklearn_object, "n_clusters"):
694
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
695
+ num_examples = self._sklearn_object.n_clusters
696
+ elif hasattr(self._sklearn_object, "min_samples"):
697
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
698
+ num_examples = self._sklearn_object.min_samples
699
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
700
+ # LocalOutlierFactor expects n_neighbors <= n_samples
701
+ num_examples = self._sklearn_object.n_neighbors
702
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
703
+ else:
704
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
667
705
 
668
706
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
669
707
  # seen during the fit.
@@ -675,12 +713,14 @@ class PassiveAggressiveClassifier(BaseTransformer):
675
713
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
676
714
  if self.sample_weight_col:
677
715
  output_df_columns_set -= set(self.sample_weight_col)
716
+
678
717
  # if the dimension of inferred output column names is correct; use it
679
718
  if len(expected_output_cols_list) == len(output_df_columns_set):
680
- return expected_output_cols_list
719
+ return expected_output_cols_list, output_df_pd
681
720
  # otherwise, use the sklearn estimator's output
682
721
  else:
683
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
722
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
723
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
684
724
 
685
725
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
686
726
  @telemetry.send_api_usage_telemetry(
@@ -726,7 +766,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
726
766
  drop_input_cols=self._drop_input_cols,
727
767
  expected_output_cols_type="float",
728
768
  )
729
- expected_output_cols = self._align_expected_output_names(
769
+ expected_output_cols, _ = self._align_expected_output(
730
770
  inference_method, dataset, expected_output_cols, output_cols_prefix
731
771
  )
732
772
 
@@ -792,7 +832,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
792
832
  drop_input_cols=self._drop_input_cols,
793
833
  expected_output_cols_type="float",
794
834
  )
795
- expected_output_cols = self._align_expected_output_names(
835
+ expected_output_cols, _ = self._align_expected_output(
796
836
  inference_method, dataset, expected_output_cols, output_cols_prefix
797
837
  )
798
838
  elif isinstance(dataset, pd.DataFrame):
@@ -857,7 +897,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
857
897
  drop_input_cols=self._drop_input_cols,
858
898
  expected_output_cols_type="float",
859
899
  )
860
- expected_output_cols = self._align_expected_output_names(
900
+ expected_output_cols, _ = self._align_expected_output(
861
901
  inference_method, dataset, expected_output_cols, output_cols_prefix
862
902
  )
863
903
 
@@ -922,7 +962,7 @@ class PassiveAggressiveClassifier(BaseTransformer):
922
962
  drop_input_cols = self._drop_input_cols,
923
963
  expected_output_cols_type="float",
924
964
  )
925
- expected_output_cols = self._align_expected_output_names(
965
+ expected_output_cols, _ = self._align_expected_output(
926
966
  inference_method, dataset, expected_output_cols, output_cols_prefix
927
967
  )
928
968
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -560,12 +557,23 @@ class PassiveAggressiveRegressor(BaseTransformer):
560
557
  autogenerated=self._autogenerated,
561
558
  subproject=_SUBPROJECT,
562
559
  )
563
- output_result, fitted_estimator = model_trainer.train_fit_predict(
564
- drop_input_cols=self._drop_input_cols,
565
- expected_output_cols_list=(
566
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
567
- ),
560
+ expected_output_cols = (
561
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
568
562
  )
563
+ if isinstance(dataset, DataFrame):
564
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
565
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
566
+ )
567
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
568
+ drop_input_cols=self._drop_input_cols,
569
+ expected_output_cols_list=expected_output_cols,
570
+ example_output_pd_df=example_output_pd_df,
571
+ )
572
+ else:
573
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
574
+ drop_input_cols=self._drop_input_cols,
575
+ expected_output_cols_list=expected_output_cols,
576
+ )
569
577
  self._sklearn_object = fitted_estimator
570
578
  self._is_fitted = True
571
579
  return output_result
@@ -588,6 +596,7 @@ class PassiveAggressiveRegressor(BaseTransformer):
588
596
  """
589
597
  self._infer_input_output_cols(dataset)
590
598
  super()._check_dataset_type(dataset)
599
+
591
600
  model_trainer = ModelTrainerBuilder.build_fit_transform(
592
601
  estimator=self._sklearn_object,
593
602
  dataset=dataset,
@@ -644,12 +653,41 @@ class PassiveAggressiveRegressor(BaseTransformer):
644
653
 
645
654
  return rv
646
655
 
647
- def _align_expected_output_names(
648
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
649
- ) -> List[str]:
656
+ def _align_expected_output(
657
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
658
+ ) -> Tuple[List[str], pd.DataFrame]:
659
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
660
+ and output dataframe with 1 line.
661
+ If the method is fit_predict, run 2 lines of data.
662
+ """
650
663
  # in case the inferred output column names dimension is different
651
664
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
652
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
665
+
666
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
667
+ # so change the minimum of number of rows to 2
668
+ num_examples = 2
669
+ statement_params = telemetry.get_function_usage_statement_params(
670
+ project=_PROJECT,
671
+ subproject=_SUBPROJECT,
672
+ function_name=telemetry.get_statement_params_full_func_name(
673
+ inspect.currentframe(), PassiveAggressiveRegressor.__class__.__name__
674
+ ),
675
+ api_calls=[Session.call],
676
+ custom_tags={"autogen": True} if self._autogenerated else None,
677
+ )
678
+ if output_cols_prefix == "fit_predict_":
679
+ if hasattr(self._sklearn_object, "n_clusters"):
680
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
681
+ num_examples = self._sklearn_object.n_clusters
682
+ elif hasattr(self._sklearn_object, "min_samples"):
683
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
684
+ num_examples = self._sklearn_object.min_samples
685
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
686
+ # LocalOutlierFactor expects n_neighbors <= n_samples
687
+ num_examples = self._sklearn_object.n_neighbors
688
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
689
+ else:
690
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
653
691
 
654
692
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
655
693
  # seen during the fit.
@@ -661,12 +699,14 @@ class PassiveAggressiveRegressor(BaseTransformer):
661
699
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
662
700
  if self.sample_weight_col:
663
701
  output_df_columns_set -= set(self.sample_weight_col)
702
+
664
703
  # if the dimension of inferred output column names is correct; use it
665
704
  if len(expected_output_cols_list) == len(output_df_columns_set):
666
- return expected_output_cols_list
705
+ return expected_output_cols_list, output_df_pd
667
706
  # otherwise, use the sklearn estimator's output
668
707
  else:
669
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
708
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
709
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
670
710
 
671
711
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
672
712
  @telemetry.send_api_usage_telemetry(
@@ -712,7 +752,7 @@ class PassiveAggressiveRegressor(BaseTransformer):
712
752
  drop_input_cols=self._drop_input_cols,
713
753
  expected_output_cols_type="float",
714
754
  )
715
- expected_output_cols = self._align_expected_output_names(
755
+ expected_output_cols, _ = self._align_expected_output(
716
756
  inference_method, dataset, expected_output_cols, output_cols_prefix
717
757
  )
718
758
 
@@ -778,7 +818,7 @@ class PassiveAggressiveRegressor(BaseTransformer):
778
818
  drop_input_cols=self._drop_input_cols,
779
819
  expected_output_cols_type="float",
780
820
  )
781
- expected_output_cols = self._align_expected_output_names(
821
+ expected_output_cols, _ = self._align_expected_output(
782
822
  inference_method, dataset, expected_output_cols, output_cols_prefix
783
823
  )
784
824
  elif isinstance(dataset, pd.DataFrame):
@@ -841,7 +881,7 @@ class PassiveAggressiveRegressor(BaseTransformer):
841
881
  drop_input_cols=self._drop_input_cols,
842
882
  expected_output_cols_type="float",
843
883
  )
844
- expected_output_cols = self._align_expected_output_names(
884
+ expected_output_cols, _ = self._align_expected_output(
845
885
  inference_method, dataset, expected_output_cols, output_cols_prefix
846
886
  )
847
887
 
@@ -906,7 +946,7 @@ class PassiveAggressiveRegressor(BaseTransformer):
906
946
  drop_input_cols = self._drop_input_cols,
907
947
  expected_output_cols_type="float",
908
948
  )
909
- expected_output_cols = self._align_expected_output_names(
949
+ expected_output_cols, _ = self._align_expected_output(
910
950
  inference_method, dataset, expected_output_cols, output_cols_prefix
911
951
  )
912
952