snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284) hide show
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -503,12 +500,23 @@ class EllipticEnvelope(BaseTransformer):
503
500
  autogenerated=self._autogenerated,
504
501
  subproject=_SUBPROJECT,
505
502
  )
506
- output_result, fitted_estimator = model_trainer.train_fit_predict(
507
- drop_input_cols=self._drop_input_cols,
508
- expected_output_cols_list=(
509
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
510
- ),
503
+ expected_output_cols = (
504
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
511
505
  )
506
+ if isinstance(dataset, DataFrame):
507
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
508
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
509
+ )
510
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
511
+ drop_input_cols=self._drop_input_cols,
512
+ expected_output_cols_list=expected_output_cols,
513
+ example_output_pd_df=example_output_pd_df,
514
+ )
515
+ else:
516
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
517
+ drop_input_cols=self._drop_input_cols,
518
+ expected_output_cols_list=expected_output_cols,
519
+ )
512
520
  self._sklearn_object = fitted_estimator
513
521
  self._is_fitted = True
514
522
  return output_result
@@ -531,6 +539,7 @@ class EllipticEnvelope(BaseTransformer):
531
539
  """
532
540
  self._infer_input_output_cols(dataset)
533
541
  super()._check_dataset_type(dataset)
542
+
534
543
  model_trainer = ModelTrainerBuilder.build_fit_transform(
535
544
  estimator=self._sklearn_object,
536
545
  dataset=dataset,
@@ -587,12 +596,41 @@ class EllipticEnvelope(BaseTransformer):
587
596
 
588
597
  return rv
589
598
 
590
- def _align_expected_output_names(
591
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
592
- ) -> List[str]:
599
+ def _align_expected_output(
600
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
601
+ ) -> Tuple[List[str], pd.DataFrame]:
602
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
603
+ and output dataframe with 1 line.
604
+ If the method is fit_predict, run 2 lines of data.
605
+ """
593
606
  # in case the inferred output column names dimension is different
594
607
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
595
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
608
+
609
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
610
+ # so change the minimum of number of rows to 2
611
+ num_examples = 2
612
+ statement_params = telemetry.get_function_usage_statement_params(
613
+ project=_PROJECT,
614
+ subproject=_SUBPROJECT,
615
+ function_name=telemetry.get_statement_params_full_func_name(
616
+ inspect.currentframe(), EllipticEnvelope.__class__.__name__
617
+ ),
618
+ api_calls=[Session.call],
619
+ custom_tags={"autogen": True} if self._autogenerated else None,
620
+ )
621
+ if output_cols_prefix == "fit_predict_":
622
+ if hasattr(self._sklearn_object, "n_clusters"):
623
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
624
+ num_examples = self._sklearn_object.n_clusters
625
+ elif hasattr(self._sklearn_object, "min_samples"):
626
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
627
+ num_examples = self._sklearn_object.min_samples
628
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
629
+ # LocalOutlierFactor expects n_neighbors <= n_samples
630
+ num_examples = self._sklearn_object.n_neighbors
631
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
632
+ else:
633
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
596
634
 
597
635
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
598
636
  # seen during the fit.
@@ -604,12 +642,14 @@ class EllipticEnvelope(BaseTransformer):
604
642
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
605
643
  if self.sample_weight_col:
606
644
  output_df_columns_set -= set(self.sample_weight_col)
645
+
607
646
  # if the dimension of inferred output column names is correct; use it
608
647
  if len(expected_output_cols_list) == len(output_df_columns_set):
609
- return expected_output_cols_list
648
+ return expected_output_cols_list, output_df_pd
610
649
  # otherwise, use the sklearn estimator's output
611
650
  else:
612
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
651
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
652
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
613
653
 
614
654
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
615
655
  @telemetry.send_api_usage_telemetry(
@@ -655,7 +695,7 @@ class EllipticEnvelope(BaseTransformer):
655
695
  drop_input_cols=self._drop_input_cols,
656
696
  expected_output_cols_type="float",
657
697
  )
658
- expected_output_cols = self._align_expected_output_names(
698
+ expected_output_cols, _ = self._align_expected_output(
659
699
  inference_method, dataset, expected_output_cols, output_cols_prefix
660
700
  )
661
701
 
@@ -721,7 +761,7 @@ class EllipticEnvelope(BaseTransformer):
721
761
  drop_input_cols=self._drop_input_cols,
722
762
  expected_output_cols_type="float",
723
763
  )
724
- expected_output_cols = self._align_expected_output_names(
764
+ expected_output_cols, _ = self._align_expected_output(
725
765
  inference_method, dataset, expected_output_cols, output_cols_prefix
726
766
  )
727
767
  elif isinstance(dataset, pd.DataFrame):
@@ -786,7 +826,7 @@ class EllipticEnvelope(BaseTransformer):
786
826
  drop_input_cols=self._drop_input_cols,
787
827
  expected_output_cols_type="float",
788
828
  )
789
- expected_output_cols = self._align_expected_output_names(
829
+ expected_output_cols, _ = self._align_expected_output(
790
830
  inference_method, dataset, expected_output_cols, output_cols_prefix
791
831
  )
792
832
 
@@ -853,7 +893,7 @@ class EllipticEnvelope(BaseTransformer):
853
893
  drop_input_cols = self._drop_input_cols,
854
894
  expected_output_cols_type="float",
855
895
  )
856
- expected_output_cols = self._align_expected_output_names(
896
+ expected_output_cols, _ = self._align_expected_output(
857
897
  inference_method, dataset, expected_output_cols, output_cols_prefix
858
898
  )
859
899
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -475,12 +472,23 @@ class EmpiricalCovariance(BaseTransformer):
475
472
  autogenerated=self._autogenerated,
476
473
  subproject=_SUBPROJECT,
477
474
  )
478
- output_result, fitted_estimator = model_trainer.train_fit_predict(
479
- drop_input_cols=self._drop_input_cols,
480
- expected_output_cols_list=(
481
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
482
- ),
475
+ expected_output_cols = (
476
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
483
477
  )
478
+ if isinstance(dataset, DataFrame):
479
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
480
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
481
+ )
482
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
483
+ drop_input_cols=self._drop_input_cols,
484
+ expected_output_cols_list=expected_output_cols,
485
+ example_output_pd_df=example_output_pd_df,
486
+ )
487
+ else:
488
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
489
+ drop_input_cols=self._drop_input_cols,
490
+ expected_output_cols_list=expected_output_cols,
491
+ )
484
492
  self._sklearn_object = fitted_estimator
485
493
  self._is_fitted = True
486
494
  return output_result
@@ -503,6 +511,7 @@ class EmpiricalCovariance(BaseTransformer):
503
511
  """
504
512
  self._infer_input_output_cols(dataset)
505
513
  super()._check_dataset_type(dataset)
514
+
506
515
  model_trainer = ModelTrainerBuilder.build_fit_transform(
507
516
  estimator=self._sklearn_object,
508
517
  dataset=dataset,
@@ -559,12 +568,41 @@ class EmpiricalCovariance(BaseTransformer):
559
568
 
560
569
  return rv
561
570
 
562
- def _align_expected_output_names(
563
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
564
- ) -> List[str]:
571
+ def _align_expected_output(
572
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
573
+ ) -> Tuple[List[str], pd.DataFrame]:
574
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
575
+ and output dataframe with 1 line.
576
+ If the method is fit_predict, run 2 lines of data.
577
+ """
565
578
  # in case the inferred output column names dimension is different
566
579
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
567
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
580
+
581
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
582
+ # so change the minimum of number of rows to 2
583
+ num_examples = 2
584
+ statement_params = telemetry.get_function_usage_statement_params(
585
+ project=_PROJECT,
586
+ subproject=_SUBPROJECT,
587
+ function_name=telemetry.get_statement_params_full_func_name(
588
+ inspect.currentframe(), EmpiricalCovariance.__class__.__name__
589
+ ),
590
+ api_calls=[Session.call],
591
+ custom_tags={"autogen": True} if self._autogenerated else None,
592
+ )
593
+ if output_cols_prefix == "fit_predict_":
594
+ if hasattr(self._sklearn_object, "n_clusters"):
595
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
596
+ num_examples = self._sklearn_object.n_clusters
597
+ elif hasattr(self._sklearn_object, "min_samples"):
598
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
599
+ num_examples = self._sklearn_object.min_samples
600
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
601
+ # LocalOutlierFactor expects n_neighbors <= n_samples
602
+ num_examples = self._sklearn_object.n_neighbors
603
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
604
+ else:
605
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
568
606
 
569
607
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
570
608
  # seen during the fit.
@@ -576,12 +614,14 @@ class EmpiricalCovariance(BaseTransformer):
576
614
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
577
615
  if self.sample_weight_col:
578
616
  output_df_columns_set -= set(self.sample_weight_col)
617
+
579
618
  # if the dimension of inferred output column names is correct; use it
580
619
  if len(expected_output_cols_list) == len(output_df_columns_set):
581
- return expected_output_cols_list
620
+ return expected_output_cols_list, output_df_pd
582
621
  # otherwise, use the sklearn estimator's output
583
622
  else:
584
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
623
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
624
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
585
625
 
586
626
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
587
627
  @telemetry.send_api_usage_telemetry(
@@ -627,7 +667,7 @@ class EmpiricalCovariance(BaseTransformer):
627
667
  drop_input_cols=self._drop_input_cols,
628
668
  expected_output_cols_type="float",
629
669
  )
630
- expected_output_cols = self._align_expected_output_names(
670
+ expected_output_cols, _ = self._align_expected_output(
631
671
  inference_method, dataset, expected_output_cols, output_cols_prefix
632
672
  )
633
673
 
@@ -693,7 +733,7 @@ class EmpiricalCovariance(BaseTransformer):
693
733
  drop_input_cols=self._drop_input_cols,
694
734
  expected_output_cols_type="float",
695
735
  )
696
- expected_output_cols = self._align_expected_output_names(
736
+ expected_output_cols, _ = self._align_expected_output(
697
737
  inference_method, dataset, expected_output_cols, output_cols_prefix
698
738
  )
699
739
  elif isinstance(dataset, pd.DataFrame):
@@ -756,7 +796,7 @@ class EmpiricalCovariance(BaseTransformer):
756
796
  drop_input_cols=self._drop_input_cols,
757
797
  expected_output_cols_type="float",
758
798
  )
759
- expected_output_cols = self._align_expected_output_names(
799
+ expected_output_cols, _ = self._align_expected_output(
760
800
  inference_method, dataset, expected_output_cols, output_cols_prefix
761
801
  )
762
802
 
@@ -821,7 +861,7 @@ class EmpiricalCovariance(BaseTransformer):
821
861
  drop_input_cols = self._drop_input_cols,
822
862
  expected_output_cols_type="float",
823
863
  )
824
- expected_output_cols = self._align_expected_output_names(
864
+ expected_output_cols, _ = self._align_expected_output(
825
865
  inference_method, dataset, expected_output_cols, output_cols_prefix
826
866
  )
827
867
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -523,12 +520,23 @@ class GraphicalLasso(BaseTransformer):
523
520
  autogenerated=self._autogenerated,
524
521
  subproject=_SUBPROJECT,
525
522
  )
526
- output_result, fitted_estimator = model_trainer.train_fit_predict(
527
- drop_input_cols=self._drop_input_cols,
528
- expected_output_cols_list=(
529
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
530
- ),
523
+ expected_output_cols = (
524
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
531
525
  )
526
+ if isinstance(dataset, DataFrame):
527
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
528
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
529
+ )
530
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
531
+ drop_input_cols=self._drop_input_cols,
532
+ expected_output_cols_list=expected_output_cols,
533
+ example_output_pd_df=example_output_pd_df,
534
+ )
535
+ else:
536
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
537
+ drop_input_cols=self._drop_input_cols,
538
+ expected_output_cols_list=expected_output_cols,
539
+ )
532
540
  self._sklearn_object = fitted_estimator
533
541
  self._is_fitted = True
534
542
  return output_result
@@ -551,6 +559,7 @@ class GraphicalLasso(BaseTransformer):
551
559
  """
552
560
  self._infer_input_output_cols(dataset)
553
561
  super()._check_dataset_type(dataset)
562
+
554
563
  model_trainer = ModelTrainerBuilder.build_fit_transform(
555
564
  estimator=self._sklearn_object,
556
565
  dataset=dataset,
@@ -607,12 +616,41 @@ class GraphicalLasso(BaseTransformer):
607
616
 
608
617
  return rv
609
618
 
610
- def _align_expected_output_names(
611
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
612
- ) -> List[str]:
619
+ def _align_expected_output(
620
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
621
+ ) -> Tuple[List[str], pd.DataFrame]:
622
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
623
+ and output dataframe with 1 line.
624
+ If the method is fit_predict, run 2 lines of data.
625
+ """
613
626
  # in case the inferred output column names dimension is different
614
627
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
615
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
628
+
629
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
630
+ # so change the minimum of number of rows to 2
631
+ num_examples = 2
632
+ statement_params = telemetry.get_function_usage_statement_params(
633
+ project=_PROJECT,
634
+ subproject=_SUBPROJECT,
635
+ function_name=telemetry.get_statement_params_full_func_name(
636
+ inspect.currentframe(), GraphicalLasso.__class__.__name__
637
+ ),
638
+ api_calls=[Session.call],
639
+ custom_tags={"autogen": True} if self._autogenerated else None,
640
+ )
641
+ if output_cols_prefix == "fit_predict_":
642
+ if hasattr(self._sklearn_object, "n_clusters"):
643
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
644
+ num_examples = self._sklearn_object.n_clusters
645
+ elif hasattr(self._sklearn_object, "min_samples"):
646
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
647
+ num_examples = self._sklearn_object.min_samples
648
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
649
+ # LocalOutlierFactor expects n_neighbors <= n_samples
650
+ num_examples = self._sklearn_object.n_neighbors
651
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
652
+ else:
653
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
616
654
 
617
655
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
618
656
  # seen during the fit.
@@ -624,12 +662,14 @@ class GraphicalLasso(BaseTransformer):
624
662
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
625
663
  if self.sample_weight_col:
626
664
  output_df_columns_set -= set(self.sample_weight_col)
665
+
627
666
  # if the dimension of inferred output column names is correct; use it
628
667
  if len(expected_output_cols_list) == len(output_df_columns_set):
629
- return expected_output_cols_list
668
+ return expected_output_cols_list, output_df_pd
630
669
  # otherwise, use the sklearn estimator's output
631
670
  else:
632
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
671
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
672
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
633
673
 
634
674
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
635
675
  @telemetry.send_api_usage_telemetry(
@@ -675,7 +715,7 @@ class GraphicalLasso(BaseTransformer):
675
715
  drop_input_cols=self._drop_input_cols,
676
716
  expected_output_cols_type="float",
677
717
  )
678
- expected_output_cols = self._align_expected_output_names(
718
+ expected_output_cols, _ = self._align_expected_output(
679
719
  inference_method, dataset, expected_output_cols, output_cols_prefix
680
720
  )
681
721
 
@@ -741,7 +781,7 @@ class GraphicalLasso(BaseTransformer):
741
781
  drop_input_cols=self._drop_input_cols,
742
782
  expected_output_cols_type="float",
743
783
  )
744
- expected_output_cols = self._align_expected_output_names(
784
+ expected_output_cols, _ = self._align_expected_output(
745
785
  inference_method, dataset, expected_output_cols, output_cols_prefix
746
786
  )
747
787
  elif isinstance(dataset, pd.DataFrame):
@@ -804,7 +844,7 @@ class GraphicalLasso(BaseTransformer):
804
844
  drop_input_cols=self._drop_input_cols,
805
845
  expected_output_cols_type="float",
806
846
  )
807
- expected_output_cols = self._align_expected_output_names(
847
+ expected_output_cols, _ = self._align_expected_output(
808
848
  inference_method, dataset, expected_output_cols, output_cols_prefix
809
849
  )
810
850
 
@@ -869,7 +909,7 @@ class GraphicalLasso(BaseTransformer):
869
909
  drop_input_cols = self._drop_input_cols,
870
910
  expected_output_cols_type="float",
871
911
  )
872
- expected_output_cols = self._align_expected_output_names(
912
+ expected_output_cols, _ = self._align_expected_output(
873
913
  inference_method, dataset, expected_output_cols, output_cols_prefix
874
914
  )
875
915
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -549,12 +546,23 @@ class GraphicalLassoCV(BaseTransformer):
549
546
  autogenerated=self._autogenerated,
550
547
  subproject=_SUBPROJECT,
551
548
  )
552
- output_result, fitted_estimator = model_trainer.train_fit_predict(
553
- drop_input_cols=self._drop_input_cols,
554
- expected_output_cols_list=(
555
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
556
- ),
549
+ expected_output_cols = (
550
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
557
551
  )
552
+ if isinstance(dataset, DataFrame):
553
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
554
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
555
+ )
556
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
557
+ drop_input_cols=self._drop_input_cols,
558
+ expected_output_cols_list=expected_output_cols,
559
+ example_output_pd_df=example_output_pd_df,
560
+ )
561
+ else:
562
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
563
+ drop_input_cols=self._drop_input_cols,
564
+ expected_output_cols_list=expected_output_cols,
565
+ )
558
566
  self._sklearn_object = fitted_estimator
559
567
  self._is_fitted = True
560
568
  return output_result
@@ -577,6 +585,7 @@ class GraphicalLassoCV(BaseTransformer):
577
585
  """
578
586
  self._infer_input_output_cols(dataset)
579
587
  super()._check_dataset_type(dataset)
588
+
580
589
  model_trainer = ModelTrainerBuilder.build_fit_transform(
581
590
  estimator=self._sklearn_object,
582
591
  dataset=dataset,
@@ -633,12 +642,41 @@ class GraphicalLassoCV(BaseTransformer):
633
642
 
634
643
  return rv
635
644
 
636
- def _align_expected_output_names(
637
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
638
- ) -> List[str]:
645
+ def _align_expected_output(
646
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
647
+ ) -> Tuple[List[str], pd.DataFrame]:
648
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
649
+ and output dataframe with 1 line.
650
+ If the method is fit_predict, run 2 lines of data.
651
+ """
639
652
  # in case the inferred output column names dimension is different
640
653
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
641
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
654
+
655
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
656
+ # so change the minimum of number of rows to 2
657
+ num_examples = 2
658
+ statement_params = telemetry.get_function_usage_statement_params(
659
+ project=_PROJECT,
660
+ subproject=_SUBPROJECT,
661
+ function_name=telemetry.get_statement_params_full_func_name(
662
+ inspect.currentframe(), GraphicalLassoCV.__class__.__name__
663
+ ),
664
+ api_calls=[Session.call],
665
+ custom_tags={"autogen": True} if self._autogenerated else None,
666
+ )
667
+ if output_cols_prefix == "fit_predict_":
668
+ if hasattr(self._sklearn_object, "n_clusters"):
669
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
670
+ num_examples = self._sklearn_object.n_clusters
671
+ elif hasattr(self._sklearn_object, "min_samples"):
672
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
673
+ num_examples = self._sklearn_object.min_samples
674
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
675
+ # LocalOutlierFactor expects n_neighbors <= n_samples
676
+ num_examples = self._sklearn_object.n_neighbors
677
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
678
+ else:
679
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
642
680
 
643
681
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
644
682
  # seen during the fit.
@@ -650,12 +688,14 @@ class GraphicalLassoCV(BaseTransformer):
650
688
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
651
689
  if self.sample_weight_col:
652
690
  output_df_columns_set -= set(self.sample_weight_col)
691
+
653
692
  # if the dimension of inferred output column names is correct; use it
654
693
  if len(expected_output_cols_list) == len(output_df_columns_set):
655
- return expected_output_cols_list
694
+ return expected_output_cols_list, output_df_pd
656
695
  # otherwise, use the sklearn estimator's output
657
696
  else:
658
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
697
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
698
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
659
699
 
660
700
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
661
701
  @telemetry.send_api_usage_telemetry(
@@ -701,7 +741,7 @@ class GraphicalLassoCV(BaseTransformer):
701
741
  drop_input_cols=self._drop_input_cols,
702
742
  expected_output_cols_type="float",
703
743
  )
704
- expected_output_cols = self._align_expected_output_names(
744
+ expected_output_cols, _ = self._align_expected_output(
705
745
  inference_method, dataset, expected_output_cols, output_cols_prefix
706
746
  )
707
747
 
@@ -767,7 +807,7 @@ class GraphicalLassoCV(BaseTransformer):
767
807
  drop_input_cols=self._drop_input_cols,
768
808
  expected_output_cols_type="float",
769
809
  )
770
- expected_output_cols = self._align_expected_output_names(
810
+ expected_output_cols, _ = self._align_expected_output(
771
811
  inference_method, dataset, expected_output_cols, output_cols_prefix
772
812
  )
773
813
  elif isinstance(dataset, pd.DataFrame):
@@ -830,7 +870,7 @@ class GraphicalLassoCV(BaseTransformer):
830
870
  drop_input_cols=self._drop_input_cols,
831
871
  expected_output_cols_type="float",
832
872
  )
833
- expected_output_cols = self._align_expected_output_names(
873
+ expected_output_cols, _ = self._align_expected_output(
834
874
  inference_method, dataset, expected_output_cols, output_cols_prefix
835
875
  )
836
876
 
@@ -895,7 +935,7 @@ class GraphicalLassoCV(BaseTransformer):
895
935
  drop_input_cols = self._drop_input_cols,
896
936
  expected_output_cols_type="float",
897
937
  )
898
- expected_output_cols = self._align_expected_output_names(
938
+ expected_output_cols, _ = self._align_expected_output(
899
939
  inference_method, dataset, expected_output_cols, output_cols_prefix
900
940
  )
901
941