snowflake-ml-python 1.1.1__py3-none-any.whl → 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (224) hide show
  1. snowflake/cortex/_complete.py +1 -1
  2. snowflake/cortex/_extract_answer.py +1 -1
  3. snowflake/cortex/_sentiment.py +1 -1
  4. snowflake/cortex/_summarize.py +1 -1
  5. snowflake/cortex/_translate.py +1 -1
  6. snowflake/ml/_internal/env_utils.py +68 -6
  7. snowflake/ml/_internal/file_utils.py +34 -4
  8. snowflake/ml/_internal/telemetry.py +79 -91
  9. snowflake/ml/_internal/utils/retryable_http.py +16 -4
  10. snowflake/ml/_internal/utils/spcs_attribution_utils.py +122 -0
  11. snowflake/ml/dataset/dataset.py +1 -1
  12. snowflake/ml/model/_api.py +21 -14
  13. snowflake/ml/model/_client/model/model_impl.py +176 -0
  14. snowflake/ml/model/_client/model/model_method_info.py +19 -0
  15. snowflake/ml/model/_client/model/model_version_impl.py +291 -0
  16. snowflake/ml/model/_client/ops/metadata_ops.py +107 -0
  17. snowflake/ml/model/_client/ops/model_ops.py +308 -0
  18. snowflake/ml/model/_client/sql/model.py +75 -0
  19. snowflake/ml/model/_client/sql/model_version.py +213 -0
  20. snowflake/ml/model/_client/sql/stage.py +40 -0
  21. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -4
  22. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +24 -8
  23. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +23 -0
  24. snowflake/ml/model/_deploy_client/snowservice/deploy.py +14 -2
  25. snowflake/ml/model/_deploy_client/utils/constants.py +1 -0
  26. snowflake/ml/model/_deploy_client/warehouse/deploy.py +2 -2
  27. snowflake/ml/model/_model_composer/model_composer.py +31 -9
  28. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +25 -10
  29. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -2
  30. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
  31. snowflake/ml/model/_model_composer/model_method/model_method.py +34 -3
  32. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +1 -1
  33. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -1
  34. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +10 -28
  35. snowflake/ml/model/_packager/model_meta/model_meta.py +18 -16
  36. snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
  37. snowflake/ml/model/model_signature.py +108 -53
  38. snowflake/ml/model/type_hints.py +1 -0
  39. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +554 -0
  40. snowflake/ml/modeling/_internal/estimator_protocols.py +1 -60
  41. snowflake/ml/modeling/_internal/model_specifications.py +146 -0
  42. snowflake/ml/modeling/_internal/model_trainer.py +13 -0
  43. snowflake/ml/modeling/_internal/model_trainer_builder.py +78 -0
  44. snowflake/ml/modeling/_internal/pandas_trainer.py +54 -0
  45. snowflake/ml/modeling/_internal/snowpark_handlers.py +6 -760
  46. snowflake/ml/modeling/_internal/snowpark_trainer.py +331 -0
  47. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +96 -124
  48. snowflake/ml/modeling/cluster/affinity_propagation.py +94 -124
  49. snowflake/ml/modeling/cluster/agglomerative_clustering.py +94 -124
  50. snowflake/ml/modeling/cluster/birch.py +94 -124
  51. snowflake/ml/modeling/cluster/bisecting_k_means.py +94 -124
  52. snowflake/ml/modeling/cluster/dbscan.py +94 -124
  53. snowflake/ml/modeling/cluster/feature_agglomeration.py +94 -124
  54. snowflake/ml/modeling/cluster/k_means.py +93 -124
  55. snowflake/ml/modeling/cluster/mean_shift.py +94 -124
  56. snowflake/ml/modeling/cluster/mini_batch_k_means.py +93 -124
  57. snowflake/ml/modeling/cluster/optics.py +94 -124
  58. snowflake/ml/modeling/cluster/spectral_biclustering.py +94 -124
  59. snowflake/ml/modeling/cluster/spectral_clustering.py +94 -124
  60. snowflake/ml/modeling/cluster/spectral_coclustering.py +94 -124
  61. snowflake/ml/modeling/compose/column_transformer.py +94 -124
  62. snowflake/ml/modeling/compose/transformed_target_regressor.py +96 -124
  63. snowflake/ml/modeling/covariance/elliptic_envelope.py +94 -124
  64. snowflake/ml/modeling/covariance/empirical_covariance.py +80 -110
  65. snowflake/ml/modeling/covariance/graphical_lasso.py +94 -124
  66. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +94 -124
  67. snowflake/ml/modeling/covariance/ledoit_wolf.py +85 -115
  68. snowflake/ml/modeling/covariance/min_cov_det.py +94 -124
  69. snowflake/ml/modeling/covariance/oas.py +80 -110
  70. snowflake/ml/modeling/covariance/shrunk_covariance.py +84 -114
  71. snowflake/ml/modeling/decomposition/dictionary_learning.py +94 -124
  72. snowflake/ml/modeling/decomposition/factor_analysis.py +94 -124
  73. snowflake/ml/modeling/decomposition/fast_ica.py +94 -124
  74. snowflake/ml/modeling/decomposition/incremental_pca.py +94 -124
  75. snowflake/ml/modeling/decomposition/kernel_pca.py +94 -124
  76. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +94 -124
  77. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +94 -124
  78. snowflake/ml/modeling/decomposition/pca.py +94 -124
  79. snowflake/ml/modeling/decomposition/sparse_pca.py +94 -124
  80. snowflake/ml/modeling/decomposition/truncated_svd.py +94 -124
  81. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +96 -124
  82. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +91 -119
  83. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +96 -124
  84. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +96 -124
  85. snowflake/ml/modeling/ensemble/bagging_classifier.py +96 -124
  86. snowflake/ml/modeling/ensemble/bagging_regressor.py +96 -124
  87. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +96 -124
  88. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +96 -124
  89. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +96 -124
  90. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +96 -124
  91. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +96 -124
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +96 -124
  93. snowflake/ml/modeling/ensemble/isolation_forest.py +94 -124
  94. snowflake/ml/modeling/ensemble/random_forest_classifier.py +96 -124
  95. snowflake/ml/modeling/ensemble/random_forest_regressor.py +96 -124
  96. snowflake/ml/modeling/ensemble/stacking_regressor.py +96 -124
  97. snowflake/ml/modeling/ensemble/voting_classifier.py +96 -124
  98. snowflake/ml/modeling/ensemble/voting_regressor.py +91 -119
  99. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +82 -110
  100. snowflake/ml/modeling/feature_selection/select_fdr.py +80 -108
  101. snowflake/ml/modeling/feature_selection/select_fpr.py +80 -108
  102. snowflake/ml/modeling/feature_selection/select_fwe.py +80 -108
  103. snowflake/ml/modeling/feature_selection/select_k_best.py +81 -109
  104. snowflake/ml/modeling/feature_selection/select_percentile.py +80 -108
  105. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +94 -124
  106. snowflake/ml/modeling/feature_selection/variance_threshold.py +76 -106
  107. snowflake/ml/modeling/framework/base.py +2 -2
  108. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +96 -124
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +96 -124
  110. snowflake/ml/modeling/impute/iterative_imputer.py +94 -124
  111. snowflake/ml/modeling/impute/knn_imputer.py +94 -124
  112. snowflake/ml/modeling/impute/missing_indicator.py +94 -124
  113. snowflake/ml/modeling/impute/simple_imputer.py +1 -1
  114. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +77 -107
  115. snowflake/ml/modeling/kernel_approximation/nystroem.py +94 -124
  116. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +94 -124
  117. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +86 -116
  118. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +84 -114
  119. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +96 -124
  120. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +71 -100
  121. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +71 -100
  122. snowflake/ml/modeling/linear_model/ard_regression.py +96 -124
  123. snowflake/ml/modeling/linear_model/bayesian_ridge.py +96 -124
  124. snowflake/ml/modeling/linear_model/elastic_net.py +96 -124
  125. snowflake/ml/modeling/linear_model/elastic_net_cv.py +96 -124
  126. snowflake/ml/modeling/linear_model/gamma_regressor.py +96 -124
  127. snowflake/ml/modeling/linear_model/huber_regressor.py +96 -124
  128. snowflake/ml/modeling/linear_model/lars.py +96 -124
  129. snowflake/ml/modeling/linear_model/lars_cv.py +96 -124
  130. snowflake/ml/modeling/linear_model/lasso.py +96 -124
  131. snowflake/ml/modeling/linear_model/lasso_cv.py +96 -124
  132. snowflake/ml/modeling/linear_model/lasso_lars.py +96 -124
  133. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +96 -124
  134. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +96 -124
  135. snowflake/ml/modeling/linear_model/linear_regression.py +91 -119
  136. snowflake/ml/modeling/linear_model/logistic_regression.py +96 -124
  137. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +96 -124
  138. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +96 -124
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +96 -124
  140. snowflake/ml/modeling/linear_model/multi_task_lasso.py +96 -124
  141. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +96 -124
  142. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +96 -124
  143. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +96 -124
  144. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +95 -124
  145. snowflake/ml/modeling/linear_model/perceptron.py +95 -124
  146. snowflake/ml/modeling/linear_model/poisson_regressor.py +96 -124
  147. snowflake/ml/modeling/linear_model/ransac_regressor.py +96 -124
  148. snowflake/ml/modeling/linear_model/ridge.py +96 -124
  149. snowflake/ml/modeling/linear_model/ridge_classifier.py +96 -124
  150. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +96 -124
  151. snowflake/ml/modeling/linear_model/ridge_cv.py +96 -124
  152. snowflake/ml/modeling/linear_model/sgd_classifier.py +96 -124
  153. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +94 -124
  154. snowflake/ml/modeling/linear_model/sgd_regressor.py +96 -124
  155. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +96 -124
  156. snowflake/ml/modeling/linear_model/tweedie_regressor.py +96 -124
  157. snowflake/ml/modeling/manifold/isomap.py +94 -124
  158. snowflake/ml/modeling/manifold/mds.py +94 -124
  159. snowflake/ml/modeling/manifold/spectral_embedding.py +94 -124
  160. snowflake/ml/modeling/manifold/tsne.py +94 -124
  161. snowflake/ml/modeling/metrics/classification.py +187 -52
  162. snowflake/ml/modeling/metrics/correlation.py +4 -2
  163. snowflake/ml/modeling/metrics/covariance.py +7 -4
  164. snowflake/ml/modeling/metrics/ranking.py +32 -16
  165. snowflake/ml/modeling/metrics/regression.py +60 -32
  166. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +94 -124
  167. snowflake/ml/modeling/mixture/gaussian_mixture.py +94 -124
  168. snowflake/ml/modeling/model_selection/grid_search_cv.py +88 -138
  169. snowflake/ml/modeling/model_selection/randomized_search_cv.py +90 -144
  170. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +86 -114
  171. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +93 -121
  172. snowflake/ml/modeling/multiclass/output_code_classifier.py +94 -122
  173. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +92 -120
  174. snowflake/ml/modeling/naive_bayes/categorical_nb.py +96 -124
  175. snowflake/ml/modeling/naive_bayes/complement_nb.py +92 -120
  176. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +79 -107
  177. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +88 -116
  178. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +96 -124
  179. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +96 -124
  180. snowflake/ml/modeling/neighbors/kernel_density.py +94 -124
  181. snowflake/ml/modeling/neighbors/local_outlier_factor.py +94 -124
  182. snowflake/ml/modeling/neighbors/nearest_centroid.py +89 -117
  183. snowflake/ml/modeling/neighbors/nearest_neighbors.py +94 -124
  184. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +96 -124
  185. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +96 -124
  186. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +96 -124
  187. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +94 -124
  188. snowflake/ml/modeling/neural_network/mlp_classifier.py +96 -124
  189. snowflake/ml/modeling/neural_network/mlp_regressor.py +96 -124
  190. snowflake/ml/modeling/parameters/disable_distributed_hpo.py +2 -6
  191. snowflake/ml/modeling/preprocessing/binarizer.py +14 -9
  192. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +0 -4
  193. snowflake/ml/modeling/preprocessing/label_encoder.py +21 -13
  194. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +20 -14
  195. snowflake/ml/modeling/preprocessing/min_max_scaler.py +35 -19
  196. snowflake/ml/modeling/preprocessing/normalizer.py +6 -9
  197. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +20 -13
  198. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +25 -13
  199. snowflake/ml/modeling/preprocessing/polynomial_features.py +94 -124
  200. snowflake/ml/modeling/preprocessing/robust_scaler.py +28 -14
  201. snowflake/ml/modeling/preprocessing/standard_scaler.py +25 -13
  202. snowflake/ml/modeling/semi_supervised/label_propagation.py +96 -124
  203. snowflake/ml/modeling/semi_supervised/label_spreading.py +96 -124
  204. snowflake/ml/modeling/svm/linear_svc.py +96 -124
  205. snowflake/ml/modeling/svm/linear_svr.py +96 -124
  206. snowflake/ml/modeling/svm/nu_svc.py +96 -124
  207. snowflake/ml/modeling/svm/nu_svr.py +96 -124
  208. snowflake/ml/modeling/svm/svc.py +96 -124
  209. snowflake/ml/modeling/svm/svr.py +96 -124
  210. snowflake/ml/modeling/tree/decision_tree_classifier.py +96 -124
  211. snowflake/ml/modeling/tree/decision_tree_regressor.py +96 -124
  212. snowflake/ml/modeling/tree/extra_tree_classifier.py +96 -124
  213. snowflake/ml/modeling/tree/extra_tree_regressor.py +96 -124
  214. snowflake/ml/modeling/xgboost/xgb_classifier.py +96 -125
  215. snowflake/ml/modeling/xgboost/xgb_regressor.py +96 -125
  216. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +96 -125
  217. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +96 -125
  218. snowflake/ml/registry/model_registry.py +2 -0
  219. snowflake/ml/registry/registry.py +215 -0
  220. snowflake/ml/version.py +1 -1
  221. {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/METADATA +21 -3
  222. snowflake_ml_python-1.1.2.dist-info/RECORD +347 -0
  223. snowflake_ml_python-1.1.1.dist-info/RECORD +0 -331
  224. {snowflake_ml_python-1.1.1.dist-info → snowflake_ml_python-1.1.2.dist-info}/WHEEL +0 -0
@@ -5,7 +5,6 @@ import warnings
5
5
  from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
6
6
 
7
7
  import cloudpickle
8
- import numpy
9
8
  import numpy as np
10
9
  import numpy.typing as npt
11
10
  from sklearn import exceptions, metrics
@@ -43,12 +42,17 @@ def accuracy_score(
43
42
  corresponding set of labels in the y true columns.
44
43
 
45
44
  Args:
46
- df: Input dataframe.
47
- y_true_col_names: Column name(s) representing actual values.
48
- y_pred_col_names: Column name(s) representing predicted values.
49
- normalize: If ``False``, return the number of correctly classified samples.
45
+ df: snowpark.DataFrame
46
+ Input dataframe.
47
+ y_true_col_names: string or list of strings
48
+ Column name(s) representing actual values.
49
+ y_pred_col_names: string or list of strings
50
+ Column name(s) representing predicted values.
51
+ normalize: boolean, default=True
52
+ If ``False``, return the number of correctly classified samples.
50
53
  Otherwise, return the fraction of correctly classified samples.
51
- sample_weight_col_name: Column name representing sample weights.
54
+ sample_weight_col_name: string, default=None
55
+ Column name representing sample weights.
52
56
 
53
57
  Returns:
54
58
  If ``normalize == True``, return the fraction of correctly
@@ -102,14 +106,19 @@ def confusion_matrix(
102
106
  :math:`C_{1,1}` and false positives is :math:`C_{0,1}`.
103
107
 
104
108
  Args:
105
- df: Input dataframe.
106
- y_true_col_name: Column name representing actual values.
107
- y_pred_col_name: Column name representing predicted values.
108
- labels: List of labels to index the matrix. This may be used to
109
+ df: snowpark.DataFrame
110
+ Input dataframe.
111
+ y_true_col_name: string or list of strings
112
+ Column name representing actual values.
113
+ y_pred_col_name: string or list of strings
114
+ Column name representing predicted values.
115
+ labels: list of labels, default=None
116
+ List of labels to index the matrix. This may be used to
109
117
  reorder or select a subset of labels.
110
118
  If ``None`` is given, those that appear at least once in the
111
119
  y true or y pred column are used in sorted order.
112
- sample_weight_col_name: Column name representing sample weights.
120
+ sample_weight_col_name: string, default=None
121
+ Column name representing sample weights.
113
122
  normalize: {'true', 'pred', 'all'}, default=None
114
123
  Normalizes confusion matrix over the true (rows), predicted (columns)
115
124
  conditions or all the population. If None, confusion matrix will not be
@@ -124,7 +133,9 @@ def confusion_matrix(
124
133
 
125
134
  Raises:
126
135
  ValueError: The given ``labels`` is empty.
136
+
127
137
  ValueError: No label specified in the given ``labels`` is in the y true column.
138
+
128
139
  ValueError: ``normalize`` is not one of {'true', 'pred', 'all', None}.
129
140
  """
130
141
  assert df._session is not None
@@ -323,17 +334,22 @@ def f1_score(
323
334
  parameter.
324
335
 
325
336
  Args:
326
- df: Input dataframe.
327
- y_true_col_names: Column name(s) representing actual values.
328
- y_pred_col_names: Column name(s) representing predicted values.
329
- labels: The set of labels to include when ``average != 'binary'``, and
337
+ df: snowpark.DataFrame
338
+ Input dataframe.
339
+ y_true_col_names: string or list of strings
340
+ Column name(s) representing actual values.
341
+ y_pred_col_names: string or list of strings
342
+ Column name(s) representing predicted values.
343
+ labels: list of labels, default=None
344
+ The set of labels to include when ``average != 'binary'``, and
330
345
  their order if ``average is None``. Labels present in the data can be
331
346
  excluded, for example to calculate a multiclass average ignoring a
332
347
  majority negative class, while labels not present in the data will
333
348
  result in 0 components in a macro average. For multilabel targets,
334
349
  labels are column indices. By default, all labels in the y true and
335
350
  y pred columns are used in sorted order.
336
- pos_label: The class to report if ``average='binary'`` and the data is
351
+ pos_label: string or integer, default=1
352
+ The class to report if ``average='binary'`` and the data is
337
353
  binary. If the data are multiclass or multilabel, this will be ignored;
338
354
  setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
339
355
  scores for that label only.
@@ -359,7 +375,8 @@ def f1_score(
359
375
  Calculate metrics for each instance, and find their average (only
360
376
  meaningful for multilabel classification where this differs from
361
377
  func`accuracy_score`).
362
- sample_weight_col_name: Column name representing sample weights.
378
+ sample_weight_col_name: string, default=None
379
+ Column name representing sample weights.
363
380
  zero_division: "warn", 0 or 1, default="warn"
364
381
  Sets the value to return when there is a zero division, i.e. when all
365
382
  predictions and labels are negative. If set to "warn", this acts as 0,
@@ -408,18 +425,24 @@ def fbeta_score(
408
425
  only recall).
409
426
 
410
427
  Args:
411
- df: Input dataframe.
412
- y_true_col_names: Column name(s) representing actual values.
413
- y_pred_col_names: Column name(s) representing predicted values.
414
- beta: Determines the weight of recall in the combined score.
415
- labels: The set of labels to include when ``average != 'binary'``, and
428
+ df: snowpark.DataFrame
429
+ Input dataframe.
430
+ y_true_col_names: string or list of strings
431
+ Column name(s) representing actual values.
432
+ y_pred_col_names: string or list of strings
433
+ Column name(s) representing predicted values.
434
+ beta: float
435
+ Determines the weight of recall in the combined score.
436
+ labels: list of labels, default=None
437
+ The set of labels to include when ``average != 'binary'``, and
416
438
  their order if ``average is None``. Labels present in the data can be
417
439
  excluded, for example to calculate a multiclass average ignoring a
418
440
  majority negative class, while labels not present in the data will
419
441
  result in 0 components in a macro average. For multilabel targets,
420
442
  labels are column indices. By default, all labels in the y true and
421
443
  y pred columns are used in sorted order.
422
- pos_label: The class to report if ``average='binary'`` and the data is
444
+ pos_label: string or integer, default=1
445
+ The class to report if ``average='binary'`` and the data is
423
446
  binary. If the data are multiclass or multilabel, this will be ignored;
424
447
  setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
425
448
  scores for that label only.
@@ -445,7 +468,8 @@ def fbeta_score(
445
468
  Calculate metrics for each instance, and find their average (only
446
469
  meaningful for multilabel classification where this differs from
447
470
  func`accuracy_score`).
448
- sample_weight_col_name: Column name representing sample weights.
471
+ sample_weight_col_name: string, default=None
472
+ Column name representing sample weights.
449
473
  zero_division: "warn", 0 or 1, default="warn"
450
474
  Sets the value to return when there is a zero division, i.e. when all
451
475
  predictions and labels are negative. If set to "warn", this acts as 0,
@@ -498,9 +522,12 @@ def log_loss(
498
522
  L_{\log}(y, p) = -(y \log (p) + (1 - y) \log (1 - p))
499
523
 
500
524
  Args:
501
- df: Input dataframe.
502
- y_true_col_names: Column name(s) representing actual values.
503
- y_pred_col_names: Column name(s) representing predicted probabilities,
525
+ df: snowpark.DataFrame
526
+ Input dataframe.
527
+ y_true_col_names: string or list of strings
528
+ Column name(s) representing actual values.
529
+ y_pred_col_names: string or list of strings
530
+ Column name(s) representing predicted probabilities,
504
531
  as returned by a classifier's predict_proba method.
505
532
  If ``y_pred.shape = (n_samples,)`` the probabilities provided are
506
533
  assumed to be that of the positive class. The labels in ``y_pred``
@@ -509,10 +536,13 @@ def log_loss(
509
536
  Log loss is undefined for p=0 or p=1, so probabilities are
510
537
  clipped to `max(eps, min(1 - eps, p))`. The default will depend on the
511
538
  data type of `y_pred` and is set to `np.finfo(y_pred.dtype).eps`.
512
- normalize: If true, return the mean loss per sample.
539
+ normalize: boolean, default=True
540
+ If true, return the mean loss per sample.
513
541
  Otherwise, return the sum of the per-sample losses.
514
- sample_weight_col_name: Column name representing sample weights.
515
- labels: If not provided, labels will be inferred from y_true. If ``labels``
542
+ sample_weight_col_name: string, default=None
543
+ Column name representing sample weights.
544
+ labels: list of labels, default=None
545
+ If not provided, labels will be inferred from y_true. If ``labels``
516
546
  is ``None`` and ``y_pred`` has shape (n_samples,) the labels are
517
547
  assumed to be binary and are inferred from ``y_true``.
518
548
 
@@ -697,18 +727,24 @@ def precision_recall_fscore_support(
697
727
  is one of ``'micro'``, ``'macro'``, ``'weighted'`` or ``'samples'``.
698
728
 
699
729
  Args:
700
- df: Input dataframe.
701
- y_true_col_names: Column name(s) representing actual values.
702
- y_pred_col_names: Column name(s) representing predicted values.
703
- beta: The strength of recall versus precision in the F-score.
704
- labels: The set of labels to include when ``average != 'binary'``, and
730
+ df: snowpark.DataFrame
731
+ Input dataframe.
732
+ y_true_col_names: string or list of strings
733
+ Column name(s) representing actual values.
734
+ y_pred_col_names: string or list of strings
735
+ Column name(s) representing predicted values.
736
+ beta: float, default=1.0
737
+ The strength of recall versus precision in the F-score.
738
+ labels: list of labels, default=None
739
+ The set of labels to include when ``average != 'binary'``, and
705
740
  their order if ``average is None``. Labels present in the data can be
706
741
  excluded, for example to calculate a multiclass average ignoring a
707
742
  majority negative class, while labels not present in the data will
708
743
  result in 0 components in a macro average. For multilabel targets,
709
744
  labels are column indices. By default, all labels in the y true and
710
745
  y pred columns are used in sorted order.
711
- pos_label: The class to report if ``average='binary'`` and the data is
746
+ pos_label: string or integer, default=1
747
+ The class to report if ``average='binary'`` and the data is
712
748
  binary. If the data are multiclass or multilabel, this will be ignored;
713
749
  setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
714
750
  scores for that label only.
@@ -733,9 +769,11 @@ def precision_recall_fscore_support(
733
769
  Calculate metrics for each instance, and find their average (only
734
770
  meaningful for multilabel classification where this differs from
735
771
  :func:`accuracy_score`).
736
- warn_for: This determines which warnings will be made in the case that this
772
+ warn_for: tuple or set containing "precision", "recall", or "f-score"
773
+ This determines which warnings will be made in the case that this
737
774
  function is being used to return only one of its metrics.
738
- sample_weight_col_name: Column name representing sample weights.
775
+ sample_weight_col_name: string, default=None
776
+ Column name representing sample weights.
739
777
  zero_division: "warn", 0 or 1, default="warn"
740
778
  Sets the value to return when there is a zero division:
741
779
  * recall - when there are no positive labels
@@ -980,6 +1018,78 @@ def _register_multilabel_confusion_matrix_computer(
980
1018
  return multilabel_confusion_matrix_computer
981
1019
 
982
1020
 
1021
+ def _binary_precision_score(
1022
+ *,
1023
+ df: snowpark.DataFrame,
1024
+ y_true_col_names: Union[str, List[str]],
1025
+ y_pred_col_names: Union[str, List[str]],
1026
+ pos_label: Union[str, int] = 1,
1027
+ sample_weight_col_name: Optional[str] = None,
1028
+ zero_division: Union[str, int] = "warn",
1029
+ ) -> Union[float, npt.NDArray[np.float_]]:
1030
+
1031
+ statement_params = telemetry.get_statement_params(_PROJECT, _SUBPROJECT)
1032
+
1033
+ if isinstance(y_true_col_names, str):
1034
+ y_true_col_names = [y_true_col_names]
1035
+ if isinstance(y_pred_col_names, str):
1036
+ y_pred_col_names = [y_pred_col_names]
1037
+
1038
+ if len(y_pred_col_names) != len(y_true_col_names):
1039
+ raise ValueError(
1040
+ "precision_score: `y_true_col_names` and `y_pred_column_names` must be lists of the same length "
1041
+ "or both strings."
1042
+ )
1043
+
1044
+ # Confirm that the data is binary.
1045
+ labels_set = set()
1046
+ columns = y_true_col_names + y_pred_col_names
1047
+ column_labels_lists = df.select(*[F.array_unique_agg(col) for col in columns]).collect(
1048
+ statement_params=statement_params
1049
+ )[0]
1050
+ for column_labels_list in column_labels_lists:
1051
+ for column_label in json.loads(column_labels_list):
1052
+ labels_set.add(column_label)
1053
+ labels = sorted(list(labels_set))
1054
+ _ = _check_binary_labels(labels, pos_label=pos_label)
1055
+
1056
+ sample_weight_column = df[sample_weight_col_name] if sample_weight_col_name else None
1057
+
1058
+ scores = []
1059
+ for y_true, y_pred in zip(y_true_col_names, y_pred_col_names):
1060
+ tp_col = F.iff((F.col(y_true) == pos_label) & (F.col(y_pred) == pos_label), 1, 0)
1061
+ fp_col = F.iff((F.col(y_true) != pos_label) & (F.col(y_pred) == pos_label), 1, 0)
1062
+ tp = metrics_utils.weighted_sum(
1063
+ df=df,
1064
+ sample_score_column=tp_col,
1065
+ sample_weight_column=sample_weight_column,
1066
+ statement_params=statement_params,
1067
+ )
1068
+ fp = metrics_utils.weighted_sum(
1069
+ df=df,
1070
+ sample_score_column=fp_col,
1071
+ sample_weight_column=sample_weight_column,
1072
+ statement_params=statement_params,
1073
+ )
1074
+
1075
+ try:
1076
+ score = tp / (tp + fp)
1077
+ except ZeroDivisionError:
1078
+ if zero_division == "warn":
1079
+ msg = "precision_score: division by zero: score value will be 0."
1080
+ warnings.warn(msg, exceptions.UndefinedMetricWarning, stacklevel=2)
1081
+ score = 0.0
1082
+ else:
1083
+ score = float(zero_division)
1084
+
1085
+ scores.append(score)
1086
+
1087
+ if len(scores) == 1:
1088
+ return scores[0]
1089
+
1090
+ return np.array(scores)
1091
+
1092
+
983
1093
  @telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT)
984
1094
  def precision_score(
985
1095
  *,
@@ -1003,17 +1113,22 @@ def precision_score(
1003
1113
  The best value is 1 and the worst value is 0.
1004
1114
 
1005
1115
  Args:
1006
- df: Input dataframe.
1007
- y_true_col_names: Column name(s) representing actual values.
1008
- y_pred_col_names: Column name(s) representing predicted values.
1009
- labels: The set of labels to include when ``average != 'binary'``, and
1116
+ df: snowpark.DataFrame
1117
+ Input dataframe.
1118
+ y_true_col_names: string or list of strings
1119
+ Column name(s) representing actual values.
1120
+ y_pred_col_names: string or list of strings
1121
+ Column name(s) representing predicted values.
1122
+ labels: list of labels, default=None
1123
+ The set of labels to include when ``average != 'binary'``, and
1010
1124
  their order if ``average is None``. Labels present in the data can be
1011
1125
  excluded, for example to calculate a multiclass average ignoring a
1012
1126
  majority negative class, while labels not present in the data will
1013
1127
  result in 0 components in a macro average. For multilabel targets,
1014
1128
  labels are column indices. By default, all labels in the y true and
1015
1129
  y pred columns are used in sorted order.
1016
- pos_label: The class to report if ``average='binary'`` and the data is
1130
+ pos_label: string or integer, default=1
1131
+ The class to report if ``average='binary'`` and the data is
1017
1132
  binary. If the data are multiclass or multilabel, this will be ignored;
1018
1133
  setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
1019
1134
  scores for that label only.
@@ -1038,7 +1153,8 @@ def precision_score(
1038
1153
  Calculate metrics for each instance, and find their average (only
1039
1154
  meaningful for multilabel classification where this differs from
1040
1155
  :func:`accuracy_score`).
1041
- sample_weight_col_name: Column name representing sample weights.
1156
+ sample_weight_col_name: string, default=None
1157
+ Column name representing sample weights.
1042
1158
  zero_division: "warn", 0 or 1, default="warn"
1043
1159
  Sets the value to return when there is a zero division. If set to
1044
1160
  "warn", this acts as 0, but warnings are also raised.
@@ -1048,6 +1164,16 @@ def precision_score(
1048
1164
  Precision of the positive class in binary classification or weighted
1049
1165
  average of the precision of each class for the multiclass task.
1050
1166
  """
1167
+ if average == "binary":
1168
+ return _binary_precision_score(
1169
+ df=df,
1170
+ y_true_col_names=y_true_col_names,
1171
+ y_pred_col_names=y_pred_col_names,
1172
+ pos_label=pos_label,
1173
+ sample_weight_col_name=sample_weight_col_name,
1174
+ zero_division=zero_division,
1175
+ )
1176
+
1051
1177
  p, _, _, _ = precision_recall_fscore_support(
1052
1178
  df=df,
1053
1179
  y_true_col_names=y_true_col_names,
@@ -1084,17 +1210,22 @@ def recall_score(
1084
1210
  The best value is 1 and the worst value is 0.
1085
1211
 
1086
1212
  Args:
1087
- df: Input dataframe.
1088
- y_true_col_names: Column name(s) representing actual values.
1089
- y_pred_col_names: Column name(s) representing predicted values.
1090
- labels: The set of labels to include when ``average != 'binary'``, and
1213
+ df: snowpark.DataFrame
1214
+ Input dataframe.
1215
+ y_true_col_names: string or list of strings
1216
+ Column name(s) representing actual values.
1217
+ y_pred_col_names: string or list of strings
1218
+ Column name(s) representing predicted values.
1219
+ labels: list of labels, default=None
1220
+ The set of labels to include when ``average != 'binary'``, and
1091
1221
  their order if ``average is None``. Labels present in the data can be
1092
1222
  excluded, for example to calculate a multiclass average ignoring a
1093
1223
  majority negative class, while labels not present in the data will
1094
1224
  result in 0 components in a macro average. For multilabel targets,
1095
1225
  labels are column indices. By default, all labels in the y true and
1096
1226
  y pred columns are used in sorted order.
1097
- pos_label: The class to report if ``average='binary'`` and the data is
1227
+ pos_label: string or integer, default=1
1228
+ The class to report if ``average='binary'`` and the data is
1098
1229
  binary. If the data are multiclass or multilabel, this will be ignored;
1099
1230
  setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
1100
1231
  scores for that label only.
@@ -1121,7 +1252,8 @@ def recall_score(
1121
1252
  Calculate metrics for each instance, and find their average (only
1122
1253
  meaningful for multilabel classification where this differs from
1123
1254
  :func:`accuracy_score`).
1124
- sample_weight_col_name: Column name representing sample weights.
1255
+ sample_weight_col_name: string, default=None
1256
+ Column name representing sample weights.
1125
1257
  zero_division: "warn", 0 or 1, default="warn"
1126
1258
  Sets the value to return when there is a zero division. If set to
1127
1259
  "warn", this acts as 0, but warnings are also raised.
@@ -1190,10 +1322,13 @@ def _check_binary_labels(
1190
1322
  """
1191
1323
  if len(labels) <= 2:
1192
1324
  if len(labels) == 2 and pos_label not in labels:
1193
- raise ValueError(f"pos_label={pos_label} is not a valid label. It should be one of {labels}")
1325
+ raise ValueError(f"pos_label={pos_label} is not a valid label. It must be one of {labels}")
1194
1326
  labels = [pos_label]
1195
1327
  else:
1196
- raise ValueError("Please choose another average setting.")
1328
+ raise ValueError(
1329
+ "Cannot compute precision score with binary average: there are more than two labels present."
1330
+ "Please choose another average setting."
1331
+ )
1197
1332
 
1198
1333
  return labels
1199
1334
 
@@ -36,8 +36,10 @@ def correlation(*, df: snowpark.DataFrame, columns: Optional[Collection[str]] =
36
36
  as a post-processing step.
37
37
 
38
38
  Args:
39
- df (snowpark.DataFrame): Snowpark Dataframe for which correlation matrix has to be computed.
40
- columns (Optional[Collection[str]]): List of column names for which the correlation matrix has to be computed.
39
+ df: snowpark.DataFrame
40
+ Snowpark Dataframe for which correlation matrix has to be computed.
41
+ columns: list of strings, default=None
42
+ List of column names for which the correlation matrix has to be computed.
41
43
  If None, correlation matrix is computed for all numeric columns in the snowpark dataframe.
42
44
 
43
45
  Returns:
@@ -36,11 +36,14 @@ def covariance(*, df: DataFrame, columns: Optional[Collection[str]] = None, ddof
36
36
  as a post-processing step.
37
37
 
38
38
  Args:
39
- df (DataFrame): Snowpark Dataframe for which covariance matrix has to be computed.
40
- columns (Optional[Collection[str]]): List of column names for which the covariance matrix has to be computed.
39
+ df: snowpark.DataFrame
40
+ Snowpark Dataframe for which covariance matrix has to be computed.
41
+ columns: list of strings, default=None
42
+ List of column names for which the covariance matrix has to be computed.
41
43
  If None, covariance matrix is computed for all numeric columns in the snowpark dataframe.
42
- ddof (int): default 1. Delta degrees of freedom.
43
- The divisor used in calculations is N - ddof, where N represents the number of rows.
44
+ ddof: int, default=1
45
+ Delta degrees of freedom. The divisor used in calculations is N - ddof, where N represents the
46
+ number of rows.
44
47
 
45
48
  Returns:
46
49
  Covariance matrix in pandas.DataFrame format.
@@ -49,18 +49,23 @@ def precision_recall_curve(
49
49
  which corresponds to a classifier that always predicts the positive class.
50
50
 
51
51
  Args:
52
- df: Input dataframe.
53
- y_true_col_name: Column name representing true binary labels.
52
+ df: snowpark.DataFrame
53
+ Input dataframe.
54
+ y_true_col_name: string
55
+ Column name representing true binary labels.
54
56
  If labels are not either {-1, 1} or {0, 1}, then pos_label should be
55
57
  explicitly given.
56
- probas_pred_col_name: Column name representing target scores.
58
+ probas_pred_col_name: string
59
+ Column name representing target scores.
57
60
  Can either be probability estimates of the positive
58
61
  class, or non-thresholded measure of decisions (as returned by
59
62
  `decision_function` on some classifiers).
60
- pos_label: The label of the positive class.
63
+ pos_label: string or int, default=None
64
+ The label of the positive class.
61
65
  When ``pos_label=None``, if y_true is in {-1, 1} or {0, 1},
62
66
  ``pos_label`` is set to 1, otherwise an error will be raised.
63
- sample_weight_col_name: Column name representing sample weights.
67
+ sample_weight_col_name: string, default=None
68
+ Column name representing sample weights.
64
69
 
65
70
  Returns:
66
71
  Tuple containing following items
@@ -142,12 +147,15 @@ def roc_auc_score(
142
147
  multilabel classification, but some restrictions apply.
143
148
 
144
149
  Args:
145
- df: Input dataframe.
146
- y_true_col_names: Column name(s) representing true labels or binary label indicators.
150
+ df: snowpark.DataFrame
151
+ Input dataframe.
152
+ y_true_col_names: string or list of strings
153
+ Column name(s) representing true labels or binary label indicators.
147
154
  The binary and multiclass cases expect labels with shape (n_samples,)
148
155
  while the multilabel case expects binary label indicators with shape
149
156
  (n_samples, n_classes).
150
- y_score_col_names: Column name(s) representing target scores.
157
+ y_score_col_names: string or list of strings
158
+ Column name(s) representing target scores.
151
159
  * In the binary case, it corresponds to an array of shape
152
160
  `(n_samples,)`. Both probability estimates and non-thresholded
153
161
  decision values can be provided. The probability estimates correspond
@@ -186,7 +194,8 @@ def roc_auc_score(
186
194
  ``'samples'``
187
195
  Calculate metrics for each instance, and find their average.
188
196
  Will be ignored when ``y_true`` is binary.
189
- sample_weight_col_name: Column name representing sample weights.
197
+ sample_weight_col_name: string, default=None
198
+ Column name representing sample weights.
190
199
  max_fpr: float > 0 and <= 1, default=None
191
200
  If not ``None``, the standardized partial AUC [2]_ over the range
192
201
  [0, max_fpr] is returned. For the multiclass case, ``max_fpr``,
@@ -208,7 +217,8 @@ def roc_auc_score(
208
217
  possible pairwise combinations of classes [5]_.
209
218
  Insensitive to class imbalance when
210
219
  ``average == 'macro'``.
211
- labels: Only used for multiclass targets. List of labels that index the
220
+ labels: list of labels, default=None
221
+ Only used for multiclass targets. List of labels that index the
212
222
  classes in ``y_score``. If ``None``, the numerical or lexicographical
213
223
  order of the labels in ``y_true`` is used.
214
224
 
@@ -282,19 +292,25 @@ def roc_curve(
282
292
  Note: this implementation is restricted to the binary classification task.
283
293
 
284
294
  Args:
285
- df: Input dataframe.
286
- y_true_col_name: Column name representing true binary labels.
295
+ df: snowpark.DataFrame
296
+ Input dataframe.
297
+ y_true_col_name: string
298
+ Column name representing true binary labels.
287
299
  If labels are not either {-1, 1} or {0, 1}, then pos_label should be
288
300
  explicitly given.
289
- y_score_col_name: Column name representing target scores, can either
301
+ y_score_col_name: string
302
+ Column name representing target scores, can either
290
303
  be probability estimates of the positive class, confidence values,
291
304
  or non-thresholded measure of decisions (as returned by
292
305
  "decision_function" on some classifiers).
293
- pos_label: The label of the positive class.
306
+ pos_label: string or int, default=None
307
+ The label of the positive class.
294
308
  When ``pos_label=None``, if `y_true` is in {-1, 1} or {0, 1},
295
309
  ``pos_label`` is set to 1, otherwise an error will be raised.
296
- sample_weight_col_name: Column name representing sample weights.
297
- drop_intermediate: Whether to drop some suboptimal thresholds which would
310
+ sample_weight_col_name: string, default=None
311
+ Column name representing sample weights.
312
+ drop_intermediate: boolean, default=True
313
+ Whether to drop some suboptimal thresholds which would
298
314
  not appear on a plotted ROC curve. This is useful in order to create
299
315
  lighter ROC curves.
300
316