snowflake-ml-python 1.6.1__py3-none-any.whl → 1.6.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (284) hide show
  1. snowflake/cortex/__init__.py +4 -0
  2. snowflake/cortex/_classify_text.py +2 -2
  3. snowflake/cortex/_embed_text_1024.py +37 -0
  4. snowflake/cortex/_embed_text_768.py +37 -0
  5. snowflake/cortex/_extract_answer.py +2 -2
  6. snowflake/cortex/_sentiment.py +2 -2
  7. snowflake/cortex/_summarize.py +2 -2
  8. snowflake/cortex/_translate.py +2 -2
  9. snowflake/cortex/_util.py +4 -4
  10. snowflake/ml/_internal/env_utils.py +5 -5
  11. snowflake/ml/_internal/exceptions/error_codes.py +2 -0
  12. snowflake/ml/_internal/telemetry.py +142 -20
  13. snowflake/ml/_internal/utils/db_utils.py +50 -0
  14. snowflake/ml/_internal/utils/identifier.py +48 -11
  15. snowflake/ml/_internal/utils/service_logger.py +63 -0
  16. snowflake/ml/_internal/utils/snowflake_env.py +23 -13
  17. snowflake/ml/_internal/utils/sql_identifier.py +26 -2
  18. snowflake/ml/_internal/utils/table_manager.py +19 -1
  19. snowflake/ml/data/_internal/arrow_ingestor.py +1 -11
  20. snowflake/ml/data/data_connector.py +33 -7
  21. snowflake/ml/data/ingestor_utils.py +20 -10
  22. snowflake/ml/data/torch_utils.py +68 -0
  23. snowflake/ml/dataset/dataset.py +1 -3
  24. snowflake/ml/feature_store/access_manager.py +3 -3
  25. snowflake/ml/feature_store/feature_store.py +60 -19
  26. snowflake/ml/feature_store/feature_view.py +84 -30
  27. snowflake/ml/fileset/embedded_stage_fs.py +1 -1
  28. snowflake/ml/fileset/fileset.py +1 -1
  29. snowflake/ml/fileset/sfcfs.py +9 -3
  30. snowflake/ml/fileset/stage_fs.py +2 -1
  31. snowflake/ml/lineage/lineage_node.py +7 -2
  32. snowflake/ml/model/__init__.py +1 -2
  33. snowflake/ml/model/_client/model/model_version_impl.py +96 -12
  34. snowflake/ml/model/_client/ops/model_ops.py +124 -6
  35. snowflake/ml/model/_client/ops/service_ops.py +309 -9
  36. snowflake/ml/model/_client/service/model_deployment_spec.py +8 -5
  37. snowflake/ml/model/_client/service/model_deployment_spec_schema.py +2 -2
  38. snowflake/ml/model/_client/sql/_base.py +5 -0
  39. snowflake/ml/model/_client/sql/model.py +1 -0
  40. snowflake/ml/model/_client/sql/model_version.py +9 -5
  41. snowflake/ml/model/_client/sql/service.py +121 -20
  42. snowflake/ml/model/_model_composer/model_composer.py +11 -39
  43. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +31 -11
  44. snowflake/ml/model/_packager/model_env/model_env.py +4 -38
  45. snowflake/ml/model/_packager/model_handlers/_utils.py +134 -28
  46. snowflake/ml/model/_packager/model_handlers/catboost.py +31 -30
  47. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +26 -18
  48. snowflake/ml/model/_packager/model_handlers/lightgbm.py +31 -58
  49. snowflake/ml/model/_packager/model_handlers/mlflow.py +3 -5
  50. snowflake/ml/model/_packager/model_handlers/model_objective_utils.py +169 -0
  51. snowflake/ml/model/_packager/model_handlers/sentence_transformers.py +15 -8
  52. snowflake/ml/model/_packager/model_handlers/sklearn.py +56 -60
  53. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +141 -9
  54. snowflake/ml/model/_packager/model_handlers/torchscript.py +2 -2
  55. snowflake/ml/model/_packager/model_handlers/xgboost.py +63 -48
  56. snowflake/ml/model/_packager/model_meta/model_meta.py +16 -42
  57. snowflake/ml/model/_packager/model_meta/model_meta_schema.py +1 -14
  58. snowflake/ml/model/_packager/model_packager.py +14 -8
  59. snowflake/ml/model/_packager/model_runtime/model_runtime.py +11 -0
  60. snowflake/ml/model/_signatures/pytorch_handler.py +1 -1
  61. snowflake/ml/model/_signatures/snowpark_handler.py +3 -2
  62. snowflake/ml/model/_signatures/utils.py +9 -0
  63. snowflake/ml/model/type_hints.py +12 -145
  64. snowflake/ml/modeling/_internal/constants.py +1 -0
  65. snowflake/ml/modeling/_internal/local_implementations/pandas_handlers.py +5 -5
  66. snowflake/ml/modeling/_internal/local_implementations/pandas_trainer.py +9 -6
  67. snowflake/ml/modeling/_internal/model_specifications.py +2 -0
  68. snowflake/ml/modeling/_internal/model_trainer.py +1 -0
  69. snowflake/ml/modeling/_internal/snowpark_implementations/distributed_hpo_trainer.py +2 -4
  70. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_handlers.py +5 -5
  71. snowflake/ml/modeling/_internal/snowpark_implementations/snowpark_trainer.py +130 -166
  72. snowflake/ml/modeling/_internal/snowpark_implementations/xgboost_external_memory_trainer.py +0 -1
  73. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +61 -21
  74. snowflake/ml/modeling/cluster/affinity_propagation.py +61 -21
  75. snowflake/ml/modeling/cluster/agglomerative_clustering.py +61 -21
  76. snowflake/ml/modeling/cluster/birch.py +61 -21
  77. snowflake/ml/modeling/cluster/bisecting_k_means.py +61 -21
  78. snowflake/ml/modeling/cluster/dbscan.py +61 -21
  79. snowflake/ml/modeling/cluster/feature_agglomeration.py +61 -21
  80. snowflake/ml/modeling/cluster/k_means.py +61 -21
  81. snowflake/ml/modeling/cluster/mean_shift.py +61 -21
  82. snowflake/ml/modeling/cluster/mini_batch_k_means.py +61 -21
  83. snowflake/ml/modeling/cluster/optics.py +61 -21
  84. snowflake/ml/modeling/cluster/spectral_biclustering.py +61 -21
  85. snowflake/ml/modeling/cluster/spectral_clustering.py +61 -21
  86. snowflake/ml/modeling/cluster/spectral_coclustering.py +61 -21
  87. snowflake/ml/modeling/compose/column_transformer.py +61 -21
  88. snowflake/ml/modeling/compose/transformed_target_regressor.py +61 -21
  89. snowflake/ml/modeling/covariance/elliptic_envelope.py +61 -21
  90. snowflake/ml/modeling/covariance/empirical_covariance.py +61 -21
  91. snowflake/ml/modeling/covariance/graphical_lasso.py +61 -21
  92. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +61 -21
  93. snowflake/ml/modeling/covariance/ledoit_wolf.py +61 -21
  94. snowflake/ml/modeling/covariance/min_cov_det.py +61 -21
  95. snowflake/ml/modeling/covariance/oas.py +61 -21
  96. snowflake/ml/modeling/covariance/shrunk_covariance.py +61 -21
  97. snowflake/ml/modeling/decomposition/dictionary_learning.py +61 -21
  98. snowflake/ml/modeling/decomposition/factor_analysis.py +61 -21
  99. snowflake/ml/modeling/decomposition/fast_ica.py +61 -21
  100. snowflake/ml/modeling/decomposition/incremental_pca.py +61 -21
  101. snowflake/ml/modeling/decomposition/kernel_pca.py +61 -21
  102. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +61 -21
  103. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +61 -21
  104. snowflake/ml/modeling/decomposition/pca.py +61 -21
  105. snowflake/ml/modeling/decomposition/sparse_pca.py +61 -21
  106. snowflake/ml/modeling/decomposition/truncated_svd.py +61 -21
  107. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +61 -21
  108. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +61 -21
  109. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +61 -21
  110. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +61 -21
  111. snowflake/ml/modeling/ensemble/bagging_classifier.py +61 -21
  112. snowflake/ml/modeling/ensemble/bagging_regressor.py +61 -21
  113. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +61 -21
  114. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +61 -21
  115. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +61 -21
  116. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +61 -21
  117. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +61 -21
  118. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +61 -21
  119. snowflake/ml/modeling/ensemble/isolation_forest.py +61 -21
  120. snowflake/ml/modeling/ensemble/random_forest_classifier.py +61 -21
  121. snowflake/ml/modeling/ensemble/random_forest_regressor.py +61 -21
  122. snowflake/ml/modeling/ensemble/stacking_regressor.py +61 -21
  123. snowflake/ml/modeling/ensemble/voting_classifier.py +61 -21
  124. snowflake/ml/modeling/ensemble/voting_regressor.py +61 -21
  125. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +61 -21
  126. snowflake/ml/modeling/feature_selection/select_fdr.py +61 -21
  127. snowflake/ml/modeling/feature_selection/select_fpr.py +61 -21
  128. snowflake/ml/modeling/feature_selection/select_fwe.py +61 -21
  129. snowflake/ml/modeling/feature_selection/select_k_best.py +61 -21
  130. snowflake/ml/modeling/feature_selection/select_percentile.py +61 -21
  131. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +61 -21
  132. snowflake/ml/modeling/feature_selection/variance_threshold.py +61 -21
  133. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +61 -21
  134. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +61 -21
  135. snowflake/ml/modeling/impute/iterative_imputer.py +61 -21
  136. snowflake/ml/modeling/impute/knn_imputer.py +61 -21
  137. snowflake/ml/modeling/impute/missing_indicator.py +61 -21
  138. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +61 -21
  139. snowflake/ml/modeling/kernel_approximation/nystroem.py +61 -21
  140. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +61 -21
  141. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +61 -21
  142. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +61 -21
  143. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +61 -21
  144. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +61 -21
  145. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +61 -21
  146. snowflake/ml/modeling/linear_model/ard_regression.py +61 -21
  147. snowflake/ml/modeling/linear_model/bayesian_ridge.py +61 -21
  148. snowflake/ml/modeling/linear_model/elastic_net.py +61 -21
  149. snowflake/ml/modeling/linear_model/elastic_net_cv.py +61 -21
  150. snowflake/ml/modeling/linear_model/gamma_regressor.py +61 -21
  151. snowflake/ml/modeling/linear_model/huber_regressor.py +61 -21
  152. snowflake/ml/modeling/linear_model/lars.py +61 -21
  153. snowflake/ml/modeling/linear_model/lars_cv.py +61 -21
  154. snowflake/ml/modeling/linear_model/lasso.py +61 -21
  155. snowflake/ml/modeling/linear_model/lasso_cv.py +61 -21
  156. snowflake/ml/modeling/linear_model/lasso_lars.py +61 -21
  157. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +61 -21
  158. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +61 -21
  159. snowflake/ml/modeling/linear_model/linear_regression.py +61 -21
  160. snowflake/ml/modeling/linear_model/logistic_regression.py +61 -21
  161. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +61 -21
  162. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +61 -21
  163. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +61 -21
  164. snowflake/ml/modeling/linear_model/multi_task_lasso.py +61 -21
  165. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +61 -21
  166. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +61 -21
  167. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +61 -21
  168. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +61 -21
  169. snowflake/ml/modeling/linear_model/perceptron.py +61 -21
  170. snowflake/ml/modeling/linear_model/poisson_regressor.py +61 -21
  171. snowflake/ml/modeling/linear_model/ransac_regressor.py +61 -21
  172. snowflake/ml/modeling/linear_model/ridge.py +61 -21
  173. snowflake/ml/modeling/linear_model/ridge_classifier.py +61 -21
  174. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +61 -21
  175. snowflake/ml/modeling/linear_model/ridge_cv.py +61 -21
  176. snowflake/ml/modeling/linear_model/sgd_classifier.py +61 -21
  177. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +61 -21
  178. snowflake/ml/modeling/linear_model/sgd_regressor.py +61 -21
  179. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +61 -21
  180. snowflake/ml/modeling/linear_model/tweedie_regressor.py +61 -21
  181. snowflake/ml/modeling/manifold/isomap.py +61 -21
  182. snowflake/ml/modeling/manifold/mds.py +61 -21
  183. snowflake/ml/modeling/manifold/spectral_embedding.py +61 -21
  184. snowflake/ml/modeling/manifold/tsne.py +61 -21
  185. snowflake/ml/modeling/metrics/metrics_utils.py +2 -2
  186. snowflake/ml/modeling/metrics/ranking.py +0 -3
  187. snowflake/ml/modeling/metrics/regression.py +0 -3
  188. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +61 -21
  189. snowflake/ml/modeling/mixture/gaussian_mixture.py +61 -21
  190. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +61 -21
  191. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +61 -21
  192. snowflake/ml/modeling/multiclass/output_code_classifier.py +61 -21
  193. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +61 -21
  194. snowflake/ml/modeling/naive_bayes/categorical_nb.py +61 -21
  195. snowflake/ml/modeling/naive_bayes/complement_nb.py +61 -21
  196. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +61 -21
  197. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +61 -21
  198. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +61 -21
  199. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +61 -21
  200. snowflake/ml/modeling/neighbors/kernel_density.py +61 -21
  201. snowflake/ml/modeling/neighbors/local_outlier_factor.py +61 -21
  202. snowflake/ml/modeling/neighbors/nearest_centroid.py +61 -21
  203. snowflake/ml/modeling/neighbors/nearest_neighbors.py +61 -21
  204. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +61 -21
  205. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +61 -21
  206. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +61 -21
  207. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +61 -21
  208. snowflake/ml/modeling/neural_network/mlp_classifier.py +61 -21
  209. snowflake/ml/modeling/neural_network/mlp_regressor.py +61 -21
  210. snowflake/ml/modeling/parameters/disable_model_tracer.py +5 -0
  211. snowflake/ml/modeling/pipeline/pipeline.py +1 -13
  212. snowflake/ml/modeling/preprocessing/polynomial_features.py +61 -21
  213. snowflake/ml/modeling/semi_supervised/label_propagation.py +61 -21
  214. snowflake/ml/modeling/semi_supervised/label_spreading.py +61 -21
  215. snowflake/ml/modeling/svm/linear_svc.py +61 -21
  216. snowflake/ml/modeling/svm/linear_svr.py +61 -21
  217. snowflake/ml/modeling/svm/nu_svc.py +61 -21
  218. snowflake/ml/modeling/svm/nu_svr.py +61 -21
  219. snowflake/ml/modeling/svm/svc.py +61 -21
  220. snowflake/ml/modeling/svm/svr.py +61 -21
  221. snowflake/ml/modeling/tree/decision_tree_classifier.py +61 -21
  222. snowflake/ml/modeling/tree/decision_tree_regressor.py +61 -21
  223. snowflake/ml/modeling/tree/extra_tree_classifier.py +61 -21
  224. snowflake/ml/modeling/tree/extra_tree_regressor.py +61 -21
  225. snowflake/ml/modeling/xgboost/xgb_classifier.py +64 -23
  226. snowflake/ml/modeling/xgboost/xgb_regressor.py +64 -23
  227. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +64 -23
  228. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +64 -23
  229. snowflake/ml/monitoring/_client/model_monitor.py +126 -0
  230. snowflake/ml/monitoring/_client/model_monitor_manager.py +361 -0
  231. snowflake/ml/monitoring/_client/model_monitor_version.py +1 -0
  232. snowflake/ml/monitoring/_client/monitor_sql_client.py +1335 -0
  233. snowflake/ml/monitoring/_client/queries/record_count.ssql +14 -0
  234. snowflake/ml/monitoring/_client/queries/rmse.ssql +28 -0
  235. snowflake/ml/monitoring/entities/model_monitor_config.py +28 -0
  236. snowflake/ml/monitoring/entities/model_monitor_interval.py +46 -0
  237. snowflake/ml/monitoring/entities/output_score_type.py +90 -0
  238. snowflake/ml/registry/_manager/model_manager.py +4 -0
  239. snowflake/ml/registry/registry.py +166 -8
  240. snowflake/ml/version.py +1 -1
  241. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/METADATA +43 -9
  242. snowflake_ml_python-1.6.3.dist-info/RECORD +400 -0
  243. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/WHEEL +1 -1
  244. snowflake/ml/_internal/container_services/image_registry/credential.py +0 -84
  245. snowflake/ml/_internal/container_services/image_registry/http_client.py +0 -127
  246. snowflake/ml/_internal/container_services/image_registry/imagelib.py +0 -400
  247. snowflake/ml/_internal/container_services/image_registry/registry_client.py +0 -212
  248. snowflake/ml/_internal/utils/log_stream_processor.py +0 -30
  249. snowflake/ml/_internal/utils/session_token_manager.py +0 -46
  250. snowflake/ml/_internal/utils/spcs_attribution_utils.py +0 -122
  251. snowflake/ml/_internal/utils/uri.py +0 -77
  252. snowflake/ml/data/torch_dataset.py +0 -33
  253. snowflake/ml/model/_api.py +0 -568
  254. snowflake/ml/model/_deploy_client/image_builds/base_image_builder.py +0 -12
  255. snowflake/ml/model/_deploy_client/image_builds/client_image_builder.py +0 -249
  256. snowflake/ml/model/_deploy_client/image_builds/docker_context.py +0 -130
  257. snowflake/ml/model/_deploy_client/image_builds/gunicorn_run.sh +0 -36
  258. snowflake/ml/model/_deploy_client/image_builds/inference_server/main.py +0 -268
  259. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +0 -215
  260. snowflake/ml/model/_deploy_client/image_builds/templates/dockerfile_template +0 -53
  261. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +0 -38
  262. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +0 -105
  263. snowflake/ml/model/_deploy_client/snowservice/deploy.py +0 -611
  264. snowflake/ml/model/_deploy_client/snowservice/deploy_options.py +0 -116
  265. snowflake/ml/model/_deploy_client/snowservice/instance_types.py +0 -10
  266. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template +0 -28
  267. snowflake/ml/model/_deploy_client/snowservice/templates/service_spec_template_with_model +0 -21
  268. snowflake/ml/model/_deploy_client/utils/constants.py +0 -48
  269. snowflake/ml/model/_deploy_client/utils/snowservice_client.py +0 -280
  270. snowflake/ml/model/_deploy_client/warehouse/deploy.py +0 -202
  271. snowflake/ml/model/_deploy_client/warehouse/infer_template.py +0 -99
  272. snowflake/ml/model/_packager/model_handlers/llm.py +0 -267
  273. snowflake/ml/model/_packager/model_meta/_core_requirements.py +0 -11
  274. snowflake/ml/model/deploy_platforms.py +0 -6
  275. snowflake/ml/model/models/llm.py +0 -104
  276. snowflake/ml/monitoring/monitor.py +0 -203
  277. snowflake/ml/registry/_initial_schema.py +0 -142
  278. snowflake/ml/registry/_schema.py +0 -82
  279. snowflake/ml/registry/_schema_upgrade_plans.py +0 -116
  280. snowflake/ml/registry/_schema_version_manager.py +0 -163
  281. snowflake/ml/registry/model_registry.py +0 -2048
  282. snowflake_ml_python-1.6.1.dist-info/RECORD +0 -422
  283. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/LICENSE.txt +0 -0
  284. {snowflake_ml_python-1.6.1.dist-info → snowflake_ml_python-1.6.3.dist-info}/top_level.txt +0 -0
@@ -59,7 +59,7 @@ def register_accumulator_udtf(*, session: Session, statement_params: Dict[str, A
59
59
  ]
60
60
  ),
61
61
  input_types=[T.BinaryType()],
62
- packages=["numpy", "cloudpickle"],
62
+ packages=[f"numpy=={np.__version__}", f"cloudpickle=={cloudpickle.__version__}"],
63
63
  name=accumulator,
64
64
  is_permanent=False,
65
65
  replace=True,
@@ -174,7 +174,7 @@ def register_sharded_dot_sum_computer(*, session: Session, statement_params: Dic
174
174
  ]
175
175
  ),
176
176
  input_types=[T.ArrayType(), T.IntegerType(), T.IntegerType()],
177
- packages=["numpy", "cloudpickle"],
177
+ packages=[f"numpy=={np.__version__}", f"cloudpickle=={cloudpickle.__version__}"],
178
178
  name=sharded_dot_and_sum_computer,
179
179
  is_permanent=False,
180
180
  replace=True,
@@ -102,7 +102,6 @@ def precision_recall_curve(
102
102
  ],
103
103
  statement_params=statement_params,
104
104
  anonymous=True,
105
- execute_as="caller",
106
105
  )
107
106
  def precision_recall_curve_anon_sproc(session: snowpark.Session) -> bytes:
108
107
  for query in queries[:-1]:
@@ -250,7 +249,6 @@ def roc_auc_score(
250
249
  ],
251
250
  statement_params=statement_params,
252
251
  anonymous=True,
253
- execute_as="caller",
254
252
  )
255
253
  def roc_auc_score_anon_sproc(session: snowpark.Session) -> bytes:
256
254
  for query in queries[:-1]:
@@ -354,7 +352,6 @@ def roc_curve(
354
352
  ],
355
353
  statement_params=statement_params,
356
354
  anonymous=True,
357
- execute_as="caller",
358
355
  )
359
356
  def roc_curve_anon_sproc(session: snowpark.Session) -> bytes:
360
357
  for query in queries[:-1]:
@@ -87,7 +87,6 @@ def d2_absolute_error_score(
87
87
  ],
88
88
  statement_params=statement_params,
89
89
  anonymous=True,
90
- execute_as="caller",
91
90
  )
92
91
  def d2_absolute_error_score_anon_sproc(session: snowpark.Session) -> bytes:
93
92
  for query in queries[:-1]:
@@ -185,7 +184,6 @@ def d2_pinball_score(
185
184
  ],
186
185
  statement_params=statement_params,
187
186
  anonymous=True,
188
- execute_as="caller",
189
187
  )
190
188
  def d2_pinball_score_anon_sproc(session: snowpark.Session) -> bytes:
191
189
  for query in queries[:-1]:
@@ -301,7 +299,6 @@ def explained_variance_score(
301
299
  ],
302
300
  statement_params=statement_params,
303
301
  anonymous=True,
304
- execute_as="caller",
305
302
  )
306
303
  def explained_variance_score_anon_sproc(session: snowpark.Session) -> bytes:
307
304
  for query in queries[:-1]:
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -602,12 +599,23 @@ class BayesianGaussianMixture(BaseTransformer):
602
599
  autogenerated=self._autogenerated,
603
600
  subproject=_SUBPROJECT,
604
601
  )
605
- output_result, fitted_estimator = model_trainer.train_fit_predict(
606
- drop_input_cols=self._drop_input_cols,
607
- expected_output_cols_list=(
608
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
609
- ),
602
+ expected_output_cols = (
603
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
610
604
  )
605
+ if isinstance(dataset, DataFrame):
606
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
607
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
608
+ )
609
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
610
+ drop_input_cols=self._drop_input_cols,
611
+ expected_output_cols_list=expected_output_cols,
612
+ example_output_pd_df=example_output_pd_df,
613
+ )
614
+ else:
615
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
616
+ drop_input_cols=self._drop_input_cols,
617
+ expected_output_cols_list=expected_output_cols,
618
+ )
611
619
  self._sklearn_object = fitted_estimator
612
620
  self._is_fitted = True
613
621
  return output_result
@@ -630,6 +638,7 @@ class BayesianGaussianMixture(BaseTransformer):
630
638
  """
631
639
  self._infer_input_output_cols(dataset)
632
640
  super()._check_dataset_type(dataset)
641
+
633
642
  model_trainer = ModelTrainerBuilder.build_fit_transform(
634
643
  estimator=self._sklearn_object,
635
644
  dataset=dataset,
@@ -686,12 +695,41 @@ class BayesianGaussianMixture(BaseTransformer):
686
695
 
687
696
  return rv
688
697
 
689
- def _align_expected_output_names(
690
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
691
- ) -> List[str]:
698
+ def _align_expected_output(
699
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
700
+ ) -> Tuple[List[str], pd.DataFrame]:
701
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
702
+ and output dataframe with 1 line.
703
+ If the method is fit_predict, run 2 lines of data.
704
+ """
692
705
  # in case the inferred output column names dimension is different
693
706
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
694
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
707
+
708
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
709
+ # so change the minimum of number of rows to 2
710
+ num_examples = 2
711
+ statement_params = telemetry.get_function_usage_statement_params(
712
+ project=_PROJECT,
713
+ subproject=_SUBPROJECT,
714
+ function_name=telemetry.get_statement_params_full_func_name(
715
+ inspect.currentframe(), BayesianGaussianMixture.__class__.__name__
716
+ ),
717
+ api_calls=[Session.call],
718
+ custom_tags={"autogen": True} if self._autogenerated else None,
719
+ )
720
+ if output_cols_prefix == "fit_predict_":
721
+ if hasattr(self._sklearn_object, "n_clusters"):
722
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
723
+ num_examples = self._sklearn_object.n_clusters
724
+ elif hasattr(self._sklearn_object, "min_samples"):
725
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
726
+ num_examples = self._sklearn_object.min_samples
727
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
728
+ # LocalOutlierFactor expects n_neighbors <= n_samples
729
+ num_examples = self._sklearn_object.n_neighbors
730
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
731
+ else:
732
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
695
733
 
696
734
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
697
735
  # seen during the fit.
@@ -703,12 +741,14 @@ class BayesianGaussianMixture(BaseTransformer):
703
741
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
704
742
  if self.sample_weight_col:
705
743
  output_df_columns_set -= set(self.sample_weight_col)
744
+
706
745
  # if the dimension of inferred output column names is correct; use it
707
746
  if len(expected_output_cols_list) == len(output_df_columns_set):
708
- return expected_output_cols_list
747
+ return expected_output_cols_list, output_df_pd
709
748
  # otherwise, use the sklearn estimator's output
710
749
  else:
711
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
750
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
751
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
712
752
 
713
753
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
714
754
  @telemetry.send_api_usage_telemetry(
@@ -756,7 +796,7 @@ class BayesianGaussianMixture(BaseTransformer):
756
796
  drop_input_cols=self._drop_input_cols,
757
797
  expected_output_cols_type="float",
758
798
  )
759
- expected_output_cols = self._align_expected_output_names(
799
+ expected_output_cols, _ = self._align_expected_output(
760
800
  inference_method, dataset, expected_output_cols, output_cols_prefix
761
801
  )
762
802
 
@@ -824,7 +864,7 @@ class BayesianGaussianMixture(BaseTransformer):
824
864
  drop_input_cols=self._drop_input_cols,
825
865
  expected_output_cols_type="float",
826
866
  )
827
- expected_output_cols = self._align_expected_output_names(
867
+ expected_output_cols, _ = self._align_expected_output(
828
868
  inference_method, dataset, expected_output_cols, output_cols_prefix
829
869
  )
830
870
  elif isinstance(dataset, pd.DataFrame):
@@ -887,7 +927,7 @@ class BayesianGaussianMixture(BaseTransformer):
887
927
  drop_input_cols=self._drop_input_cols,
888
928
  expected_output_cols_type="float",
889
929
  )
890
- expected_output_cols = self._align_expected_output_names(
930
+ expected_output_cols, _ = self._align_expected_output(
891
931
  inference_method, dataset, expected_output_cols, output_cols_prefix
892
932
  )
893
933
 
@@ -954,7 +994,7 @@ class BayesianGaussianMixture(BaseTransformer):
954
994
  drop_input_cols = self._drop_input_cols,
955
995
  expected_output_cols_type="float",
956
996
  )
957
- expected_output_cols = self._align_expected_output_names(
997
+ expected_output_cols, _ = self._align_expected_output(
958
998
  inference_method, dataset, expected_output_cols, output_cols_prefix
959
999
  )
960
1000
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -575,12 +572,23 @@ class GaussianMixture(BaseTransformer):
575
572
  autogenerated=self._autogenerated,
576
573
  subproject=_SUBPROJECT,
577
574
  )
578
- output_result, fitted_estimator = model_trainer.train_fit_predict(
579
- drop_input_cols=self._drop_input_cols,
580
- expected_output_cols_list=(
581
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
582
- ),
575
+ expected_output_cols = (
576
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
583
577
  )
578
+ if isinstance(dataset, DataFrame):
579
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
580
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
581
+ )
582
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
583
+ drop_input_cols=self._drop_input_cols,
584
+ expected_output_cols_list=expected_output_cols,
585
+ example_output_pd_df=example_output_pd_df,
586
+ )
587
+ else:
588
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
589
+ drop_input_cols=self._drop_input_cols,
590
+ expected_output_cols_list=expected_output_cols,
591
+ )
584
592
  self._sklearn_object = fitted_estimator
585
593
  self._is_fitted = True
586
594
  return output_result
@@ -603,6 +611,7 @@ class GaussianMixture(BaseTransformer):
603
611
  """
604
612
  self._infer_input_output_cols(dataset)
605
613
  super()._check_dataset_type(dataset)
614
+
606
615
  model_trainer = ModelTrainerBuilder.build_fit_transform(
607
616
  estimator=self._sklearn_object,
608
617
  dataset=dataset,
@@ -659,12 +668,41 @@ class GaussianMixture(BaseTransformer):
659
668
 
660
669
  return rv
661
670
 
662
- def _align_expected_output_names(
663
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
664
- ) -> List[str]:
671
+ def _align_expected_output(
672
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
673
+ ) -> Tuple[List[str], pd.DataFrame]:
674
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
675
+ and output dataframe with 1 line.
676
+ If the method is fit_predict, run 2 lines of data.
677
+ """
665
678
  # in case the inferred output column names dimension is different
666
679
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
667
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
680
+
681
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
682
+ # so change the minimum of number of rows to 2
683
+ num_examples = 2
684
+ statement_params = telemetry.get_function_usage_statement_params(
685
+ project=_PROJECT,
686
+ subproject=_SUBPROJECT,
687
+ function_name=telemetry.get_statement_params_full_func_name(
688
+ inspect.currentframe(), GaussianMixture.__class__.__name__
689
+ ),
690
+ api_calls=[Session.call],
691
+ custom_tags={"autogen": True} if self._autogenerated else None,
692
+ )
693
+ if output_cols_prefix == "fit_predict_":
694
+ if hasattr(self._sklearn_object, "n_clusters"):
695
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
696
+ num_examples = self._sklearn_object.n_clusters
697
+ elif hasattr(self._sklearn_object, "min_samples"):
698
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
699
+ num_examples = self._sklearn_object.min_samples
700
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
701
+ # LocalOutlierFactor expects n_neighbors <= n_samples
702
+ num_examples = self._sklearn_object.n_neighbors
703
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
704
+ else:
705
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
668
706
 
669
707
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
670
708
  # seen during the fit.
@@ -676,12 +714,14 @@ class GaussianMixture(BaseTransformer):
676
714
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
677
715
  if self.sample_weight_col:
678
716
  output_df_columns_set -= set(self.sample_weight_col)
717
+
679
718
  # if the dimension of inferred output column names is correct; use it
680
719
  if len(expected_output_cols_list) == len(output_df_columns_set):
681
- return expected_output_cols_list
720
+ return expected_output_cols_list, output_df_pd
682
721
  # otherwise, use the sklearn estimator's output
683
722
  else:
684
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
723
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
724
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
685
725
 
686
726
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
687
727
  @telemetry.send_api_usage_telemetry(
@@ -729,7 +769,7 @@ class GaussianMixture(BaseTransformer):
729
769
  drop_input_cols=self._drop_input_cols,
730
770
  expected_output_cols_type="float",
731
771
  )
732
- expected_output_cols = self._align_expected_output_names(
772
+ expected_output_cols, _ = self._align_expected_output(
733
773
  inference_method, dataset, expected_output_cols, output_cols_prefix
734
774
  )
735
775
 
@@ -797,7 +837,7 @@ class GaussianMixture(BaseTransformer):
797
837
  drop_input_cols=self._drop_input_cols,
798
838
  expected_output_cols_type="float",
799
839
  )
800
- expected_output_cols = self._align_expected_output_names(
840
+ expected_output_cols, _ = self._align_expected_output(
801
841
  inference_method, dataset, expected_output_cols, output_cols_prefix
802
842
  )
803
843
  elif isinstance(dataset, pd.DataFrame):
@@ -860,7 +900,7 @@ class GaussianMixture(BaseTransformer):
860
900
  drop_input_cols=self._drop_input_cols,
861
901
  expected_output_cols_type="float",
862
902
  )
863
- expected_output_cols = self._align_expected_output_names(
903
+ expected_output_cols, _ = self._align_expected_output(
864
904
  inference_method, dataset, expected_output_cols, output_cols_prefix
865
905
  )
866
906
 
@@ -927,7 +967,7 @@ class GaussianMixture(BaseTransformer):
927
967
  drop_input_cols = self._drop_input_cols,
928
968
  expected_output_cols_type="float",
929
969
  )
930
- expected_output_cols = self._align_expected_output_names(
970
+ expected_output_cols, _ = self._align_expected_output(
931
971
  inference_method, dataset, expected_output_cols, output_cols_prefix
932
972
  )
933
973
 
@@ -4,14 +4,12 @@
4
4
  #
5
5
  import inspect
6
6
  import os
7
- import posixpath
8
- from typing import Iterable, Optional, Union, List, Any, Dict, Callable, Set
9
- from typing_extensions import TypeGuard
7
+ from typing import Iterable, Optional, Union, List, Any, Dict, Set, Tuple
10
8
  from uuid import uuid4
11
9
 
12
10
  import cloudpickle as cp
13
- import pandas as pd
14
11
  import numpy as np
12
+ import pandas as pd
15
13
  from numpy import typing as npt
16
14
 
17
15
 
@@ -24,12 +22,11 @@ from snowflake.ml.modeling.framework.base import BaseTransformer, _process_cols
24
22
  from snowflake.ml._internal import telemetry
25
23
  from snowflake.ml._internal.exceptions import error_codes, exceptions, modeling_error_messages
26
24
  from snowflake.ml._internal.env_utils import SNOWML_SPROC_ENV
27
- from snowflake.ml._internal.utils import pkg_version_utils, identifier
25
+ from snowflake.ml._internal.utils import identifier
28
26
  from snowflake.snowpark import DataFrame, Session
29
27
  from snowflake.snowpark._internal.type_utils import convert_sp_to_sf_type
30
28
  from snowflake.ml.modeling._internal.model_trainer_builder import ModelTrainerBuilder
31
29
  from snowflake.ml.modeling._internal.transformer_protocols import (
32
- ModelTransformHandlers,
33
30
  BatchInferenceKwargsTypedDict,
34
31
  ScoreKwargsTypedDict
35
32
  )
@@ -483,12 +480,23 @@ class OneVsOneClassifier(BaseTransformer):
483
480
  autogenerated=self._autogenerated,
484
481
  subproject=_SUBPROJECT,
485
482
  )
486
- output_result, fitted_estimator = model_trainer.train_fit_predict(
487
- drop_input_cols=self._drop_input_cols,
488
- expected_output_cols_list=(
489
- self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
490
- ),
483
+ expected_output_cols = (
484
+ self.output_cols if self.output_cols else self._get_output_column_names(output_cols_prefix)
491
485
  )
486
+ if isinstance(dataset, DataFrame):
487
+ expected_output_cols, example_output_pd_df = self._align_expected_output(
488
+ "fit_predict", dataset, expected_output_cols, output_cols_prefix
489
+ )
490
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
491
+ drop_input_cols=self._drop_input_cols,
492
+ expected_output_cols_list=expected_output_cols,
493
+ example_output_pd_df=example_output_pd_df,
494
+ )
495
+ else:
496
+ output_result, fitted_estimator = model_trainer.train_fit_predict(
497
+ drop_input_cols=self._drop_input_cols,
498
+ expected_output_cols_list=expected_output_cols,
499
+ )
492
500
  self._sklearn_object = fitted_estimator
493
501
  self._is_fitted = True
494
502
  return output_result
@@ -511,6 +519,7 @@ class OneVsOneClassifier(BaseTransformer):
511
519
  """
512
520
  self._infer_input_output_cols(dataset)
513
521
  super()._check_dataset_type(dataset)
522
+
514
523
  model_trainer = ModelTrainerBuilder.build_fit_transform(
515
524
  estimator=self._sklearn_object,
516
525
  dataset=dataset,
@@ -567,12 +576,41 @@ class OneVsOneClassifier(BaseTransformer):
567
576
 
568
577
  return rv
569
578
 
570
- def _align_expected_output_names(
571
- self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str
572
- ) -> List[str]:
579
+ def _align_expected_output(
580
+ self, method: str, dataset: DataFrame, expected_output_cols_list: List[str], output_cols_prefix: str,
581
+ ) -> Tuple[List[str], pd.DataFrame]:
582
+ """ Run 1 line of data with the desired method, and return one tuple that consists of the output column names
583
+ and output dataframe with 1 line.
584
+ If the method is fit_predict, run 2 lines of data.
585
+ """
573
586
  # in case the inferred output column names dimension is different
574
587
  # we use one line of snowpark dataframe and put it into sklearn estimator using pandas
575
- sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas()
588
+
589
+ # For fit_predict method, a minimum of 2 is required by MinCovDet, BayesianGaussianMixture
590
+ # so change the minimum of number of rows to 2
591
+ num_examples = 2
592
+ statement_params = telemetry.get_function_usage_statement_params(
593
+ project=_PROJECT,
594
+ subproject=_SUBPROJECT,
595
+ function_name=telemetry.get_statement_params_full_func_name(
596
+ inspect.currentframe(), OneVsOneClassifier.__class__.__name__
597
+ ),
598
+ api_calls=[Session.call],
599
+ custom_tags={"autogen": True} if self._autogenerated else None,
600
+ )
601
+ if output_cols_prefix == "fit_predict_":
602
+ if hasattr(self._sklearn_object, "n_clusters"):
603
+ # cluster classes such as BisectingKMeansTest requires # of examples >= n_clusters
604
+ num_examples = self._sklearn_object.n_clusters
605
+ elif hasattr(self._sklearn_object, "min_samples"):
606
+ # OPTICS default min_samples 5, which requires at least 5 lines of data
607
+ num_examples = self._sklearn_object.min_samples
608
+ elif hasattr(self._sklearn_object, "n_neighbors") and hasattr(self._sklearn_object, "n_samples"):
609
+ # LocalOutlierFactor expects n_neighbors <= n_samples
610
+ num_examples = self._sklearn_object.n_neighbors
611
+ sample_pd_df = dataset.select(self.input_cols).limit(num_examples).to_pandas(statement_params=statement_params)
612
+ else:
613
+ sample_pd_df = dataset.select(self.input_cols).limit(1).to_pandas(statement_params=statement_params)
576
614
 
577
615
  # Rename the pandas df column names to snowflake identifiers and reorder columns to match the order
578
616
  # seen during the fit.
@@ -584,12 +622,14 @@ class OneVsOneClassifier(BaseTransformer):
584
622
  output_df_columns_set: Set[str] = set(output_df_columns) - set(dataset.columns)
585
623
  if self.sample_weight_col:
586
624
  output_df_columns_set -= set(self.sample_weight_col)
625
+
587
626
  # if the dimension of inferred output column names is correct; use it
588
627
  if len(expected_output_cols_list) == len(output_df_columns_set):
589
- return expected_output_cols_list
628
+ return expected_output_cols_list, output_df_pd
590
629
  # otherwise, use the sklearn estimator's output
591
630
  else:
592
- return sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
631
+ expected_output_cols_list = sorted(list(output_df_columns_set), key=lambda x: output_df_columns.index(x))
632
+ return expected_output_cols_list, output_df_pd[expected_output_cols_list]
593
633
 
594
634
  @available_if(original_estimator_has_callable("predict_proba")) # type: ignore[misc]
595
635
  @telemetry.send_api_usage_telemetry(
@@ -635,7 +675,7 @@ class OneVsOneClassifier(BaseTransformer):
635
675
  drop_input_cols=self._drop_input_cols,
636
676
  expected_output_cols_type="float",
637
677
  )
638
- expected_output_cols = self._align_expected_output_names(
678
+ expected_output_cols, _ = self._align_expected_output(
639
679
  inference_method, dataset, expected_output_cols, output_cols_prefix
640
680
  )
641
681
 
@@ -701,7 +741,7 @@ class OneVsOneClassifier(BaseTransformer):
701
741
  drop_input_cols=self._drop_input_cols,
702
742
  expected_output_cols_type="float",
703
743
  )
704
- expected_output_cols = self._align_expected_output_names(
744
+ expected_output_cols, _ = self._align_expected_output(
705
745
  inference_method, dataset, expected_output_cols, output_cols_prefix
706
746
  )
707
747
  elif isinstance(dataset, pd.DataFrame):
@@ -766,7 +806,7 @@ class OneVsOneClassifier(BaseTransformer):
766
806
  drop_input_cols=self._drop_input_cols,
767
807
  expected_output_cols_type="float",
768
808
  )
769
- expected_output_cols = self._align_expected_output_names(
809
+ expected_output_cols, _ = self._align_expected_output(
770
810
  inference_method, dataset, expected_output_cols, output_cols_prefix
771
811
  )
772
812
 
@@ -831,7 +871,7 @@ class OneVsOneClassifier(BaseTransformer):
831
871
  drop_input_cols = self._drop_input_cols,
832
872
  expected_output_cols_type="float",
833
873
  )
834
- expected_output_cols = self._align_expected_output_names(
874
+ expected_output_cols, _ = self._align_expected_output(
835
875
  inference_method, dataset, expected_output_cols, output_cols_prefix
836
876
  )
837
877