snowflake-ml-python 1.1.0__py3-none-any.whl → 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (225)
  1. snowflake/cortex/_complete.py +1 -1
  2. snowflake/cortex/_extract_answer.py +1 -1
  3. snowflake/cortex/_sentiment.py +1 -1
  4. snowflake/cortex/_summarize.py +1 -1
  5. snowflake/cortex/_translate.py +1 -1
  6. snowflake/ml/_internal/env_utils.py +68 -6
  7. snowflake/ml/_internal/file_utils.py +34 -4
  8. snowflake/ml/_internal/telemetry.py +79 -91
  9. snowflake/ml/_internal/utils/identifier.py +78 -72
  10. snowflake/ml/_internal/utils/retryable_http.py +16 -4
  11. snowflake/ml/_internal/utils/spcs_attribution_utils.py +122 -0
  12. snowflake/ml/dataset/dataset.py +1 -1
  13. snowflake/ml/model/_api.py +21 -14
  14. snowflake/ml/model/_client/model/model_impl.py +176 -0
  15. snowflake/ml/model/_client/model/model_method_info.py +19 -0
  16. snowflake/ml/model/_client/model/model_version_impl.py +291 -0
  17. snowflake/ml/model/_client/ops/metadata_ops.py +107 -0
  18. snowflake/ml/model/_client/ops/model_ops.py +308 -0
  19. snowflake/ml/model/_client/sql/model.py +75 -0
  20. snowflake/ml/model/_client/sql/model_version.py +213 -0
  21. snowflake/ml/model/_client/sql/stage.py +40 -0
  22. snowflake/ml/model/_deploy_client/image_builds/server_image_builder.py +3 -4
  23. snowflake/ml/model/_deploy_client/image_builds/templates/image_build_job_spec_template +24 -8
  24. snowflake/ml/model/_deploy_client/image_builds/templates/kaniko_shell_script_template +23 -0
  25. snowflake/ml/model/_deploy_client/snowservice/deploy.py +14 -2
  26. snowflake/ml/model/_deploy_client/utils/constants.py +1 -0
  27. snowflake/ml/model/_deploy_client/warehouse/deploy.py +2 -2
  28. snowflake/ml/model/_model_composer/model_composer.py +31 -9
  29. snowflake/ml/model/_model_composer/model_manifest/model_manifest.py +25 -10
  30. snowflake/ml/model/_model_composer/model_manifest/model_manifest_schema.py +2 -2
  31. snowflake/ml/model/_model_composer/model_method/infer_function.py_template +2 -1
  32. snowflake/ml/model/_model_composer/model_method/model_method.py +34 -3
  33. snowflake/ml/model/_model_composer/model_runtime/model_runtime.py +1 -1
  34. snowflake/ml/model/_packager/model_handlers/huggingface_pipeline.py +3 -1
  35. snowflake/ml/model/_packager/model_handlers/snowmlmodel.py +10 -28
  36. snowflake/ml/model/_packager/model_meta/model_meta.py +18 -16
  37. snowflake/ml/model/_signatures/snowpark_handler.py +1 -1
  38. snowflake/ml/model/model_signature.py +108 -53
  39. snowflake/ml/model/type_hints.py +1 -0
  40. snowflake/ml/modeling/_internal/distributed_hpo_trainer.py +554 -0
  41. snowflake/ml/modeling/_internal/estimator_protocols.py +1 -60
  42. snowflake/ml/modeling/_internal/model_specifications.py +146 -0
  43. snowflake/ml/modeling/_internal/model_trainer.py +13 -0
  44. snowflake/ml/modeling/_internal/model_trainer_builder.py +78 -0
  45. snowflake/ml/modeling/_internal/pandas_trainer.py +54 -0
  46. snowflake/ml/modeling/_internal/snowpark_handlers.py +6 -760
  47. snowflake/ml/modeling/_internal/snowpark_trainer.py +331 -0
  48. snowflake/ml/modeling/calibration/calibrated_classifier_cv.py +108 -135
  49. snowflake/ml/modeling/cluster/affinity_propagation.py +106 -135
  50. snowflake/ml/modeling/cluster/agglomerative_clustering.py +106 -135
  51. snowflake/ml/modeling/cluster/birch.py +106 -135
  52. snowflake/ml/modeling/cluster/bisecting_k_means.py +106 -135
  53. snowflake/ml/modeling/cluster/dbscan.py +106 -135
  54. snowflake/ml/modeling/cluster/feature_agglomeration.py +106 -135
  55. snowflake/ml/modeling/cluster/k_means.py +105 -135
  56. snowflake/ml/modeling/cluster/mean_shift.py +106 -135
  57. snowflake/ml/modeling/cluster/mini_batch_k_means.py +105 -135
  58. snowflake/ml/modeling/cluster/optics.py +106 -135
  59. snowflake/ml/modeling/cluster/spectral_biclustering.py +106 -135
  60. snowflake/ml/modeling/cluster/spectral_clustering.py +106 -135
  61. snowflake/ml/modeling/cluster/spectral_coclustering.py +106 -135
  62. snowflake/ml/modeling/compose/column_transformer.py +106 -135
  63. snowflake/ml/modeling/compose/transformed_target_regressor.py +108 -135
  64. snowflake/ml/modeling/covariance/elliptic_envelope.py +106 -135
  65. snowflake/ml/modeling/covariance/empirical_covariance.py +99 -128
  66. snowflake/ml/modeling/covariance/graphical_lasso.py +106 -135
  67. snowflake/ml/modeling/covariance/graphical_lasso_cv.py +106 -135
  68. snowflake/ml/modeling/covariance/ledoit_wolf.py +104 -133
  69. snowflake/ml/modeling/covariance/min_cov_det.py +106 -135
  70. snowflake/ml/modeling/covariance/oas.py +99 -128
  71. snowflake/ml/modeling/covariance/shrunk_covariance.py +103 -132
  72. snowflake/ml/modeling/decomposition/dictionary_learning.py +106 -135
  73. snowflake/ml/modeling/decomposition/factor_analysis.py +106 -135
  74. snowflake/ml/modeling/decomposition/fast_ica.py +106 -135
  75. snowflake/ml/modeling/decomposition/incremental_pca.py +106 -135
  76. snowflake/ml/modeling/decomposition/kernel_pca.py +106 -135
  77. snowflake/ml/modeling/decomposition/mini_batch_dictionary_learning.py +106 -135
  78. snowflake/ml/modeling/decomposition/mini_batch_sparse_pca.py +106 -135
  79. snowflake/ml/modeling/decomposition/pca.py +106 -135
  80. snowflake/ml/modeling/decomposition/sparse_pca.py +106 -135
  81. snowflake/ml/modeling/decomposition/truncated_svd.py +106 -135
  82. snowflake/ml/modeling/discriminant_analysis/linear_discriminant_analysis.py +108 -135
  83. snowflake/ml/modeling/discriminant_analysis/quadratic_discriminant_analysis.py +108 -135
  84. snowflake/ml/modeling/ensemble/ada_boost_classifier.py +108 -135
  85. snowflake/ml/modeling/ensemble/ada_boost_regressor.py +108 -135
  86. snowflake/ml/modeling/ensemble/bagging_classifier.py +108 -135
  87. snowflake/ml/modeling/ensemble/bagging_regressor.py +108 -135
  88. snowflake/ml/modeling/ensemble/extra_trees_classifier.py +108 -135
  89. snowflake/ml/modeling/ensemble/extra_trees_regressor.py +108 -135
  90. snowflake/ml/modeling/ensemble/gradient_boosting_classifier.py +108 -135
  91. snowflake/ml/modeling/ensemble/gradient_boosting_regressor.py +108 -135
  92. snowflake/ml/modeling/ensemble/hist_gradient_boosting_classifier.py +108 -135
  93. snowflake/ml/modeling/ensemble/hist_gradient_boosting_regressor.py +108 -135
  94. snowflake/ml/modeling/ensemble/isolation_forest.py +106 -135
  95. snowflake/ml/modeling/ensemble/random_forest_classifier.py +108 -135
  96. snowflake/ml/modeling/ensemble/random_forest_regressor.py +108 -135
  97. snowflake/ml/modeling/ensemble/stacking_regressor.py +108 -135
  98. snowflake/ml/modeling/ensemble/voting_classifier.py +108 -135
  99. snowflake/ml/modeling/ensemble/voting_regressor.py +108 -135
  100. snowflake/ml/modeling/feature_selection/generic_univariate_select.py +101 -128
  101. snowflake/ml/modeling/feature_selection/select_fdr.py +99 -126
  102. snowflake/ml/modeling/feature_selection/select_fpr.py +99 -126
  103. snowflake/ml/modeling/feature_selection/select_fwe.py +99 -126
  104. snowflake/ml/modeling/feature_selection/select_k_best.py +100 -127
  105. snowflake/ml/modeling/feature_selection/select_percentile.py +99 -126
  106. snowflake/ml/modeling/feature_selection/sequential_feature_selector.py +106 -135
  107. snowflake/ml/modeling/feature_selection/variance_threshold.py +95 -124
  108. snowflake/ml/modeling/framework/base.py +83 -1
  109. snowflake/ml/modeling/gaussian_process/gaussian_process_classifier.py +108 -135
  110. snowflake/ml/modeling/gaussian_process/gaussian_process_regressor.py +108 -135
  111. snowflake/ml/modeling/impute/iterative_imputer.py +106 -135
  112. snowflake/ml/modeling/impute/knn_imputer.py +106 -135
  113. snowflake/ml/modeling/impute/missing_indicator.py +106 -135
  114. snowflake/ml/modeling/impute/simple_imputer.py +9 -1
  115. snowflake/ml/modeling/kernel_approximation/additive_chi2_sampler.py +96 -125
  116. snowflake/ml/modeling/kernel_approximation/nystroem.py +106 -135
  117. snowflake/ml/modeling/kernel_approximation/polynomial_count_sketch.py +106 -135
  118. snowflake/ml/modeling/kernel_approximation/rbf_sampler.py +105 -134
  119. snowflake/ml/modeling/kernel_approximation/skewed_chi2_sampler.py +103 -132
  120. snowflake/ml/modeling/kernel_ridge/kernel_ridge.py +108 -135
  121. snowflake/ml/modeling/lightgbm/lgbm_classifier.py +90 -118
  122. snowflake/ml/modeling/lightgbm/lgbm_regressor.py +90 -118
  123. snowflake/ml/modeling/linear_model/ard_regression.py +108 -135
  124. snowflake/ml/modeling/linear_model/bayesian_ridge.py +108 -135
  125. snowflake/ml/modeling/linear_model/elastic_net.py +108 -135
  126. snowflake/ml/modeling/linear_model/elastic_net_cv.py +108 -135
  127. snowflake/ml/modeling/linear_model/gamma_regressor.py +108 -135
  128. snowflake/ml/modeling/linear_model/huber_regressor.py +108 -135
  129. snowflake/ml/modeling/linear_model/lars.py +108 -135
  130. snowflake/ml/modeling/linear_model/lars_cv.py +108 -135
  131. snowflake/ml/modeling/linear_model/lasso.py +108 -135
  132. snowflake/ml/modeling/linear_model/lasso_cv.py +108 -135
  133. snowflake/ml/modeling/linear_model/lasso_lars.py +108 -135
  134. snowflake/ml/modeling/linear_model/lasso_lars_cv.py +108 -135
  135. snowflake/ml/modeling/linear_model/lasso_lars_ic.py +108 -135
  136. snowflake/ml/modeling/linear_model/linear_regression.py +108 -135
  137. snowflake/ml/modeling/linear_model/logistic_regression.py +108 -135
  138. snowflake/ml/modeling/linear_model/logistic_regression_cv.py +108 -135
  139. snowflake/ml/modeling/linear_model/multi_task_elastic_net.py +108 -135
  140. snowflake/ml/modeling/linear_model/multi_task_elastic_net_cv.py +108 -135
  141. snowflake/ml/modeling/linear_model/multi_task_lasso.py +108 -135
  142. snowflake/ml/modeling/linear_model/multi_task_lasso_cv.py +108 -135
  143. snowflake/ml/modeling/linear_model/orthogonal_matching_pursuit.py +108 -135
  144. snowflake/ml/modeling/linear_model/passive_aggressive_classifier.py +108 -135
  145. snowflake/ml/modeling/linear_model/passive_aggressive_regressor.py +107 -135
  146. snowflake/ml/modeling/linear_model/perceptron.py +107 -135
  147. snowflake/ml/modeling/linear_model/poisson_regressor.py +108 -135
  148. snowflake/ml/modeling/linear_model/ransac_regressor.py +108 -135
  149. snowflake/ml/modeling/linear_model/ridge.py +108 -135
  150. snowflake/ml/modeling/linear_model/ridge_classifier.py +108 -135
  151. snowflake/ml/modeling/linear_model/ridge_classifier_cv.py +108 -135
  152. snowflake/ml/modeling/linear_model/ridge_cv.py +108 -135
  153. snowflake/ml/modeling/linear_model/sgd_classifier.py +108 -135
  154. snowflake/ml/modeling/linear_model/sgd_one_class_svm.py +106 -135
  155. snowflake/ml/modeling/linear_model/sgd_regressor.py +108 -135
  156. snowflake/ml/modeling/linear_model/theil_sen_regressor.py +108 -135
  157. snowflake/ml/modeling/linear_model/tweedie_regressor.py +108 -135
  158. snowflake/ml/modeling/manifold/isomap.py +106 -135
  159. snowflake/ml/modeling/manifold/mds.py +106 -135
  160. snowflake/ml/modeling/manifold/spectral_embedding.py +106 -135
  161. snowflake/ml/modeling/manifold/tsne.py +106 -135
  162. snowflake/ml/modeling/metrics/classification.py +196 -55
  163. snowflake/ml/modeling/metrics/correlation.py +4 -2
  164. snowflake/ml/modeling/metrics/covariance.py +7 -4
  165. snowflake/ml/modeling/metrics/ranking.py +32 -16
  166. snowflake/ml/modeling/metrics/regression.py +60 -32
  167. snowflake/ml/modeling/mixture/bayesian_gaussian_mixture.py +106 -135
  168. snowflake/ml/modeling/mixture/gaussian_mixture.py +106 -135
  169. snowflake/ml/modeling/model_selection/grid_search_cv.py +91 -148
  170. snowflake/ml/modeling/model_selection/randomized_search_cv.py +93 -154
  171. snowflake/ml/modeling/multiclass/one_vs_one_classifier.py +105 -132
  172. snowflake/ml/modeling/multiclass/one_vs_rest_classifier.py +108 -135
  173. snowflake/ml/modeling/multiclass/output_code_classifier.py +108 -135
  174. snowflake/ml/modeling/naive_bayes/bernoulli_nb.py +108 -135
  175. snowflake/ml/modeling/naive_bayes/categorical_nb.py +108 -135
  176. snowflake/ml/modeling/naive_bayes/complement_nb.py +108 -135
  177. snowflake/ml/modeling/naive_bayes/gaussian_nb.py +98 -125
  178. snowflake/ml/modeling/naive_bayes/multinomial_nb.py +107 -134
  179. snowflake/ml/modeling/neighbors/k_neighbors_classifier.py +108 -135
  180. snowflake/ml/modeling/neighbors/k_neighbors_regressor.py +108 -135
  181. snowflake/ml/modeling/neighbors/kernel_density.py +106 -135
  182. snowflake/ml/modeling/neighbors/local_outlier_factor.py +106 -135
  183. snowflake/ml/modeling/neighbors/nearest_centroid.py +108 -135
  184. snowflake/ml/modeling/neighbors/nearest_neighbors.py +106 -135
  185. snowflake/ml/modeling/neighbors/neighborhood_components_analysis.py +108 -135
  186. snowflake/ml/modeling/neighbors/radius_neighbors_classifier.py +108 -135
  187. snowflake/ml/modeling/neighbors/radius_neighbors_regressor.py +108 -135
  188. snowflake/ml/modeling/neural_network/bernoulli_rbm.py +106 -135
  189. snowflake/ml/modeling/neural_network/mlp_classifier.py +108 -135
  190. snowflake/ml/modeling/neural_network/mlp_regressor.py +108 -135
  191. snowflake/ml/modeling/parameters/disable_distributed_hpo.py +2 -6
  192. snowflake/ml/modeling/preprocessing/binarizer.py +25 -8
  193. snowflake/ml/modeling/preprocessing/k_bins_discretizer.py +9 -4
  194. snowflake/ml/modeling/preprocessing/label_encoder.py +31 -11
  195. snowflake/ml/modeling/preprocessing/max_abs_scaler.py +27 -9
  196. snowflake/ml/modeling/preprocessing/min_max_scaler.py +42 -14
  197. snowflake/ml/modeling/preprocessing/normalizer.py +9 -4
  198. snowflake/ml/modeling/preprocessing/one_hot_encoder.py +26 -10
  199. snowflake/ml/modeling/preprocessing/ordinal_encoder.py +37 -13
  200. snowflake/ml/modeling/preprocessing/polynomial_features.py +106 -135
  201. snowflake/ml/modeling/preprocessing/robust_scaler.py +39 -13
  202. snowflake/ml/modeling/preprocessing/standard_scaler.py +36 -12
  203. snowflake/ml/modeling/semi_supervised/label_propagation.py +108 -135
  204. snowflake/ml/modeling/semi_supervised/label_spreading.py +108 -135
  205. snowflake/ml/modeling/svm/linear_svc.py +108 -135
  206. snowflake/ml/modeling/svm/linear_svr.py +108 -135
  207. snowflake/ml/modeling/svm/nu_svc.py +108 -135
  208. snowflake/ml/modeling/svm/nu_svr.py +108 -135
  209. snowflake/ml/modeling/svm/svc.py +108 -135
  210. snowflake/ml/modeling/svm/svr.py +108 -135
  211. snowflake/ml/modeling/tree/decision_tree_classifier.py +108 -135
  212. snowflake/ml/modeling/tree/decision_tree_regressor.py +108 -135
  213. snowflake/ml/modeling/tree/extra_tree_classifier.py +108 -135
  214. snowflake/ml/modeling/tree/extra_tree_regressor.py +108 -135
  215. snowflake/ml/modeling/xgboost/xgb_classifier.py +108 -136
  216. snowflake/ml/modeling/xgboost/xgb_regressor.py +108 -136
  217. snowflake/ml/modeling/xgboost/xgbrf_classifier.py +108 -136
  218. snowflake/ml/modeling/xgboost/xgbrf_regressor.py +108 -136
  219. snowflake/ml/registry/model_registry.py +2 -0
  220. snowflake/ml/registry/registry.py +215 -0
  221. snowflake/ml/version.py +1 -1
  222. {snowflake_ml_python-1.1.0.dist-info → snowflake_ml_python-1.1.2.dist-info}/METADATA +34 -1
  223. snowflake_ml_python-1.1.2.dist-info/RECORD +347 -0
  224. snowflake_ml_python-1.1.0.dist-info/RECORD +0 -331
  225. {snowflake_ml_python-1.1.0.dist-info → snowflake_ml_python-1.1.2.dist-info}/WHEEL +0 -0
@@ -5,7 +5,6 @@ import warnings
5
5
  from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
6
6
 
7
7
  import cloudpickle
8
- import numpy
9
8
  import numpy as np
10
9
  import numpy.typing as npt
11
10
  from sklearn import exceptions, metrics
@@ -43,12 +42,17 @@ def accuracy_score(
43
42
  corresponding set of labels in the y true columns.
44
43
 
45
44
  Args:
46
- df: Input dataframe.
47
- y_true_col_names: Column name(s) representing actual values.
48
- y_pred_col_names: Column name(s) representing predicted values.
49
- normalize: If ``False``, return the number of correctly classified samples.
45
+ df: snowpark.DataFrame
46
+ Input dataframe.
47
+ y_true_col_names: string or list of strings
48
+ Column name(s) representing actual values.
49
+ y_pred_col_names: string or list of strings
50
+ Column name(s) representing predicted values.
51
+ normalize: boolean, default=True
52
+ If ``False``, return the number of correctly classified samples.
50
53
  Otherwise, return the fraction of correctly classified samples.
51
- sample_weight_col_name: Column name representing sample weights.
54
+ sample_weight_col_name: string, default=None
55
+ Column name representing sample weights.
52
56
 
53
57
  Returns:
54
58
  If ``normalize == True``, return the fraction of correctly
@@ -102,14 +106,19 @@ def confusion_matrix(
102
106
  :math:`C_{1,1}` and false positives is :math:`C_{0,1}`.
103
107
 
104
108
  Args:
105
- df: Input dataframe.
106
- y_true_col_name: Column name representing actual values.
107
- y_pred_col_name: Column name representing predicted values.
108
- labels: List of labels to index the matrix. This may be used to
109
+ df: snowpark.DataFrame
110
+ Input dataframe.
111
+ y_true_col_name: string or list of strings
112
+ Column name representing actual values.
113
+ y_pred_col_name: string or list of strings
114
+ Column name representing predicted values.
115
+ labels: list of labels, default=None
116
+ List of labels to index the matrix. This may be used to
109
117
  reorder or select a subset of labels.
110
118
  If ``None`` is given, those that appear at least once in the
111
119
  y true or y pred column are used in sorted order.
112
- sample_weight_col_name: Column name representing sample weights.
120
+ sample_weight_col_name: string, default=None
121
+ Column name representing sample weights.
113
122
  normalize: {'true', 'pred', 'all'}, default=None
114
123
  Normalizes confusion matrix over the true (rows), predicted (columns)
115
124
  conditions or all the population. If None, confusion matrix will not be
@@ -124,7 +133,9 @@ def confusion_matrix(
124
133
 
125
134
  Raises:
126
135
  ValueError: The given ``labels`` is empty.
136
+
127
137
  ValueError: No label specified in the given ``labels`` is in the y true column.
138
+
128
139
  ValueError: ``normalize`` is not one of {'true', 'pred', 'all', None}.
129
140
  """
130
141
  assert df._session is not None
@@ -252,7 +263,7 @@ def _register_confusion_matrix_computer(*, session: snowpark.Session, statement_
252
263
  self._batched_rows[self._cur_count, :] = input_row
253
264
  self._cur_count += 1
254
265
 
255
- # 2. Compute incremental sum and dot_prod for the batch.
266
+ # 2. Compute incremental confusion matrix for the batch.
256
267
  if self._cur_count >= self.BATCH_SIZE:
257
268
  self.update_confusion_matrix()
258
269
  self._cur_count = 0
@@ -265,10 +276,16 @@ def _register_confusion_matrix_computer(*, session: snowpark.Session, statement_
265
276
  yield cloudpickle.dumps(self._confusion_matrix[i, :]), "row_" + str(i)
266
277
 
267
278
  def update_confusion_matrix(self) -> None:
279
+ # Update the confusion matrix by adding values from the 1st column of the batched rows to specific
280
+ # locations in the confusion matrix determined by row and column indices from the 2nd and 3rd columns of
281
+ # the batched rows.
268
282
  np.add.at(
269
283
  self._confusion_matrix,
270
- (self._batched_rows[:, 1].astype(int), self._batched_rows[:, 2].astype(int)),
271
- self._batched_rows[:, 0],
284
+ (
285
+ self._batched_rows[: self._cur_count][:, 1].astype(int),
286
+ self._batched_rows[: self._cur_count][:, 2].astype(int),
287
+ ),
288
+ self._batched_rows[: self._cur_count][:, 0],
272
289
  )
273
290
 
274
291
  confusion_matrix_computer = snowpark_utils.random_name_for_temp_object(snowpark_utils.TempObjectType.TABLE_FUNCTION)
@@ -317,17 +334,22 @@ def f1_score(
317
334
  parameter.
318
335
 
319
336
  Args:
320
- df: Input dataframe.
321
- y_true_col_names: Column name(s) representing actual values.
322
- y_pred_col_names: Column name(s) representing predicted values.
323
- labels: The set of labels to include when ``average != 'binary'``, and
337
+ df: snowpark.DataFrame
338
+ Input dataframe.
339
+ y_true_col_names: string or list of strings
340
+ Column name(s) representing actual values.
341
+ y_pred_col_names: string or list of strings
342
+ Column name(s) representing predicted values.
343
+ labels: list of labels, default=None
344
+ The set of labels to include when ``average != 'binary'``, and
324
345
  their order if ``average is None``. Labels present in the data can be
325
346
  excluded, for example to calculate a multiclass average ignoring a
326
347
  majority negative class, while labels not present in the data will
327
348
  result in 0 components in a macro average. For multilabel targets,
328
349
  labels are column indices. By default, all labels in the y true and
329
350
  y pred columns are used in sorted order.
330
- pos_label: The class to report if ``average='binary'`` and the data is
351
+ pos_label: string or integer, default=1
352
+ The class to report if ``average='binary'`` and the data is
331
353
  binary. If the data are multiclass or multilabel, this will be ignored;
332
354
  setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
333
355
  scores for that label only.
@@ -353,7 +375,8 @@ def f1_score(
353
375
  Calculate metrics for each instance, and find their average (only
354
376
  meaningful for multilabel classification where this differs from
355
377
  :func:`accuracy_score`).
356
- sample_weight_col_name: Column name representing sample weights.
378
+ sample_weight_col_name: string, default=None
379
+ Column name representing sample weights.
357
380
  zero_division: "warn", 0 or 1, default="warn"
358
381
  Sets the value to return when there is a zero division, i.e. when all
359
382
  predictions and labels are negative. If set to "warn", this acts as 0,
@@ -402,18 +425,24 @@ def fbeta_score(
402
425
  only recall).
403
426
 
404
427
  Args:
405
- df: Input dataframe.
406
- y_true_col_names: Column name(s) representing actual values.
407
- y_pred_col_names: Column name(s) representing predicted values.
408
- beta: Determines the weight of recall in the combined score.
409
- labels: The set of labels to include when ``average != 'binary'``, and
428
+ df: snowpark.DataFrame
429
+ Input dataframe.
430
+ y_true_col_names: string or list of strings
431
+ Column name(s) representing actual values.
432
+ y_pred_col_names: string or list of strings
433
+ Column name(s) representing predicted values.
434
+ beta: float
435
+ Determines the weight of recall in the combined score.
436
+ labels: list of labels, default=None
437
+ The set of labels to include when ``average != 'binary'``, and
410
438
  their order if ``average is None``. Labels present in the data can be
411
439
  excluded, for example to calculate a multiclass average ignoring a
412
440
  majority negative class, while labels not present in the data will
413
441
  result in 0 components in a macro average. For multilabel targets,
414
442
  labels are column indices. By default, all labels in the y true and
415
443
  y pred columns are used in sorted order.
416
- pos_label: The class to report if ``average='binary'`` and the data is
444
+ pos_label: string or integer, default=1
445
+ The class to report if ``average='binary'`` and the data is
417
446
  binary. If the data are multiclass or multilabel, this will be ignored;
418
447
  setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
419
448
  scores for that label only.
@@ -439,7 +468,8 @@ def fbeta_score(
439
468
  Calculate metrics for each instance, and find their average (only
440
469
  meaningful for multilabel classification where this differs from
441
470
  :func:`accuracy_score`).
442
- sample_weight_col_name: Column name representing sample weights.
471
+ sample_weight_col_name: string, default=None
472
+ Column name representing sample weights.
443
473
  zero_division: "warn", 0 or 1, default="warn"
444
474
  Sets the value to return when there is a zero division, i.e. when all
445
475
  predictions and labels are negative. If set to "warn", this acts as 0,
@@ -492,9 +522,12 @@ def log_loss(
492
522
  L_{\log}(y, p) = -(y \log (p) + (1 - y) \log (1 - p))
493
523
 
494
524
  Args:
495
- df: Input dataframe.
496
- y_true_col_names: Column name(s) representing actual values.
497
- y_pred_col_names: Column name(s) representing predicted probabilities,
525
+ df: snowpark.DataFrame
526
+ Input dataframe.
527
+ y_true_col_names: string or list of strings
528
+ Column name(s) representing actual values.
529
+ y_pred_col_names: string or list of strings
530
+ Column name(s) representing predicted probabilities,
498
531
  as returned by a classifier's predict_proba method.
499
532
  If ``y_pred.shape = (n_samples,)`` the probabilities provided are
500
533
  assumed to be that of the positive class. The labels in ``y_pred``
@@ -503,10 +536,13 @@ def log_loss(
503
536
  Log loss is undefined for p=0 or p=1, so probabilities are
504
537
  clipped to `max(eps, min(1 - eps, p))`. The default will depend on the
505
538
  data type of `y_pred` and is set to `np.finfo(y_pred.dtype).eps`.
506
- normalize: If true, return the mean loss per sample.
539
+ normalize: boolean, default=True
540
+ If true, return the mean loss per sample.
507
541
  Otherwise, return the sum of the per-sample losses.
508
- sample_weight_col_name: Column name representing sample weights.
509
- labels: If not provided, labels will be inferred from y_true. If ``labels``
542
+ sample_weight_col_name: string, default=None
543
+ Column name representing sample weights.
544
+ labels: list of labels, default=None
545
+ If not provided, labels will be inferred from y_true. If ``labels``
510
546
  is ``None`` and ``y_pred`` has shape (n_samples,) the labels are
511
547
  assumed to be binary and are inferred from ``y_true``.
512
548
 
@@ -691,18 +727,24 @@ def precision_recall_fscore_support(
691
727
  is one of ``'micro'``, ``'macro'``, ``'weighted'`` or ``'samples'``.
692
728
 
693
729
  Args:
694
- df: Input dataframe.
695
- y_true_col_names: Column name(s) representing actual values.
696
- y_pred_col_names: Column name(s) representing predicted values.
697
- beta: The strength of recall versus precision in the F-score.
698
- labels: The set of labels to include when ``average != 'binary'``, and
730
+ df: snowpark.DataFrame
731
+ Input dataframe.
732
+ y_true_col_names: string or list of strings
733
+ Column name(s) representing actual values.
734
+ y_pred_col_names: string or list of strings
735
+ Column name(s) representing predicted values.
736
+ beta: float, default=1.0
737
+ The strength of recall versus precision in the F-score.
738
+ labels: list of labels, default=None
739
+ The set of labels to include when ``average != 'binary'``, and
699
740
  their order if ``average is None``. Labels present in the data can be
700
741
  excluded, for example to calculate a multiclass average ignoring a
701
742
  majority negative class, while labels not present in the data will
702
743
  result in 0 components in a macro average. For multilabel targets,
703
744
  labels are column indices. By default, all labels in the y true and
704
745
  y pred columns are used in sorted order.
705
- pos_label: The class to report if ``average='binary'`` and the data is
746
+ pos_label: string or integer, default=1
747
+ The class to report if ``average='binary'`` and the data is
706
748
  binary. If the data are multiclass or multilabel, this will be ignored;
707
749
  setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
708
750
  scores for that label only.
@@ -727,9 +769,11 @@ def precision_recall_fscore_support(
727
769
  Calculate metrics for each instance, and find their average (only
728
770
  meaningful for multilabel classification where this differs from
729
771
  :func:`accuracy_score`).
730
- warn_for: This determines which warnings will be made in the case that this
772
+ warn_for: tuple or set containing "precision", "recall", or "f-score"
773
+ This determines which warnings will be made in the case that this
731
774
  function is being used to return only one of its metrics.
732
- sample_weight_col_name: Column name representing sample weights.
775
+ sample_weight_col_name: string, default=None
776
+ Column name representing sample weights.
733
777
  zero_division: "warn", 0 or 1, default="warn"
734
778
  Sets the value to return when there is a zero division:
735
779
  * recall - when there are no positive labels
@@ -974,6 +1018,78 @@ def _register_multilabel_confusion_matrix_computer(
974
1018
  return multilabel_confusion_matrix_computer
975
1019
 
976
1020
 
1021
+ def _binary_precision_score(
1022
+ *,
1023
+ df: snowpark.DataFrame,
1024
+ y_true_col_names: Union[str, List[str]],
1025
+ y_pred_col_names: Union[str, List[str]],
1026
+ pos_label: Union[str, int] = 1,
1027
+ sample_weight_col_name: Optional[str] = None,
1028
+ zero_division: Union[str, int] = "warn",
1029
+ ) -> Union[float, npt.NDArray[np.float_]]:
1030
+
1031
+ statement_params = telemetry.get_statement_params(_PROJECT, _SUBPROJECT)
1032
+
1033
+ if isinstance(y_true_col_names, str):
1034
+ y_true_col_names = [y_true_col_names]
1035
+ if isinstance(y_pred_col_names, str):
1036
+ y_pred_col_names = [y_pred_col_names]
1037
+
1038
+ if len(y_pred_col_names) != len(y_true_col_names):
1039
+ raise ValueError(
1040
+ "precision_score: `y_true_col_names` and `y_pred_column_names` must be lists of the same length "
1041
+ "or both strings."
1042
+ )
1043
+
1044
+ # Confirm that the data is binary.
1045
+ labels_set = set()
1046
+ columns = y_true_col_names + y_pred_col_names
1047
+ column_labels_lists = df.select(*[F.array_unique_agg(col) for col in columns]).collect(
1048
+ statement_params=statement_params
1049
+ )[0]
1050
+ for column_labels_list in column_labels_lists:
1051
+ for column_label in json.loads(column_labels_list):
1052
+ labels_set.add(column_label)
1053
+ labels = sorted(list(labels_set))
1054
+ _ = _check_binary_labels(labels, pos_label=pos_label)
1055
+
1056
+ sample_weight_column = df[sample_weight_col_name] if sample_weight_col_name else None
1057
+
1058
+ scores = []
1059
+ for y_true, y_pred in zip(y_true_col_names, y_pred_col_names):
1060
+ tp_col = F.iff((F.col(y_true) == pos_label) & (F.col(y_pred) == pos_label), 1, 0)
1061
+ fp_col = F.iff((F.col(y_true) != pos_label) & (F.col(y_pred) == pos_label), 1, 0)
1062
+ tp = metrics_utils.weighted_sum(
1063
+ df=df,
1064
+ sample_score_column=tp_col,
1065
+ sample_weight_column=sample_weight_column,
1066
+ statement_params=statement_params,
1067
+ )
1068
+ fp = metrics_utils.weighted_sum(
1069
+ df=df,
1070
+ sample_score_column=fp_col,
1071
+ sample_weight_column=sample_weight_column,
1072
+ statement_params=statement_params,
1073
+ )
1074
+
1075
+ try:
1076
+ score = tp / (tp + fp)
1077
+ except ZeroDivisionError:
1078
+ if zero_division == "warn":
1079
+ msg = "precision_score: division by zero: score value will be 0."
1080
+ warnings.warn(msg, exceptions.UndefinedMetricWarning, stacklevel=2)
1081
+ score = 0.0
1082
+ else:
1083
+ score = float(zero_division)
1084
+
1085
+ scores.append(score)
1086
+
1087
+ if len(scores) == 1:
1088
+ return scores[0]
1089
+
1090
+ return np.array(scores)
1091
+
1092
+
977
1093
  @telemetry.send_api_usage_telemetry(project=_PROJECT, subproject=_SUBPROJECT)
978
1094
  def precision_score(
979
1095
  *,
@@ -997,17 +1113,22 @@ def precision_score(
997
1113
  The best value is 1 and the worst value is 0.
998
1114
 
999
1115
  Args:
1000
- df: Input dataframe.
1001
- y_true_col_names: Column name(s) representing actual values.
1002
- y_pred_col_names: Column name(s) representing predicted values.
1003
- labels: The set of labels to include when ``average != 'binary'``, and
1116
+ df: snowpark.DataFrame
1117
+ Input dataframe.
1118
+ y_true_col_names: string or list of strings
1119
+ Column name(s) representing actual values.
1120
+ y_pred_col_names: string or list of strings
1121
+ Column name(s) representing predicted values.
1122
+ labels: list of labels, default=None
1123
+ The set of labels to include when ``average != 'binary'``, and
1004
1124
  their order if ``average is None``. Labels present in the data can be
1005
1125
  excluded, for example to calculate a multiclass average ignoring a
1006
1126
  majority negative class, while labels not present in the data will
1007
1127
  result in 0 components in a macro average. For multilabel targets,
1008
1128
  labels are column indices. By default, all labels in the y true and
1009
1129
  y pred columns are used in sorted order.
1010
- pos_label: The class to report if ``average='binary'`` and the data is
1130
+ pos_label: string or integer, default=1
1131
+ The class to report if ``average='binary'`` and the data is
1011
1132
  binary. If the data are multiclass or multilabel, this will be ignored;
1012
1133
  setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
1013
1134
  scores for that label only.
@@ -1032,7 +1153,8 @@ def precision_score(
1032
1153
  Calculate metrics for each instance, and find their average (only
1033
1154
  meaningful for multilabel classification where this differs from
1034
1155
  :func:`accuracy_score`).
1035
- sample_weight_col_name: Column name representing sample weights.
1156
+ sample_weight_col_name: string, default=None
1157
+ Column name representing sample weights.
1036
1158
  zero_division: "warn", 0 or 1, default="warn"
1037
1159
  Sets the value to return when there is a zero division. If set to
1038
1160
  "warn", this acts as 0, but warnings are also raised.
@@ -1042,6 +1164,16 @@ def precision_score(
1042
1164
  Precision of the positive class in binary classification or weighted
1043
1165
  average of the precision of each class for the multiclass task.
1044
1166
  """
1167
+ if average == "binary":
1168
+ return _binary_precision_score(
1169
+ df=df,
1170
+ y_true_col_names=y_true_col_names,
1171
+ y_pred_col_names=y_pred_col_names,
1172
+ pos_label=pos_label,
1173
+ sample_weight_col_name=sample_weight_col_name,
1174
+ zero_division=zero_division,
1175
+ )
1176
+
1045
1177
  p, _, _, _ = precision_recall_fscore_support(
1046
1178
  df=df,
1047
1179
  y_true_col_names=y_true_col_names,
@@ -1078,17 +1210,22 @@ def recall_score(
1078
1210
  The best value is 1 and the worst value is 0.
1079
1211
 
1080
1212
  Args:
1081
- df: Input dataframe.
1082
- y_true_col_names: Column name(s) representing actual values.
1083
- y_pred_col_names: Column name(s) representing predicted values.
1084
- labels: The set of labels to include when ``average != 'binary'``, and
1213
+ df: snowpark.DataFrame
1214
+ Input dataframe.
1215
+ y_true_col_names: string or list of strings
1216
+ Column name(s) representing actual values.
1217
+ y_pred_col_names: string or list of strings
1218
+ Column name(s) representing predicted values.
1219
+ labels: list of labels, default=None
1220
+ The set of labels to include when ``average != 'binary'``, and
1085
1221
  their order if ``average is None``. Labels present in the data can be
1086
1222
  excluded, for example to calculate a multiclass average ignoring a
1087
1223
  majority negative class, while labels not present in the data will
1088
1224
  result in 0 components in a macro average. For multilabel targets,
1089
1225
  labels are column indices. By default, all labels in the y true and
1090
1226
  y pred columns are used in sorted order.
1091
- pos_label: The class to report if ``average='binary'`` and the data is
1227
+ pos_label: string or integer, default=1
1228
+ The class to report if ``average='binary'`` and the data is
1092
1229
  binary. If the data are multiclass or multilabel, this will be ignored;
1093
1230
  setting ``labels=[pos_label]`` and ``average != 'binary'`` will report
1094
1231
  scores for that label only.
@@ -1115,7 +1252,8 @@ def recall_score(
1115
1252
  Calculate metrics for each instance, and find their average (only
1116
1253
  meaningful for multilabel classification where this differs from
1117
1254
  :func:`accuracy_score`).
1118
- sample_weight_col_name: Column name representing sample weights.
1255
+ sample_weight_col_name: string, default=None
1256
+ Column name representing sample weights.
1119
1257
  zero_division: "warn", 0 or 1, default="warn"
1120
1258
  Sets the value to return when there is a zero division. If set to
1121
1259
  "warn", this acts as 0, but warnings are also raised.
@@ -1184,10 +1322,13 @@ def _check_binary_labels(
1184
1322
  """
1185
1323
  if len(labels) <= 2:
1186
1324
  if len(labels) == 2 and pos_label not in labels:
1187
- raise ValueError(f"pos_label={pos_label} is not a valid label. It should be one of {labels}")
1325
+ raise ValueError(f"pos_label={pos_label} is not a valid label. It must be one of {labels}")
1188
1326
  labels = [pos_label]
1189
1327
  else:
1190
- raise ValueError("Please choose another average setting.")
1328
+ raise ValueError(
1329
+ "Cannot compute precision score with binary average: there are more than two labels present."
1330
+ "Please choose another average setting."
1331
+ )
1191
1332
 
1192
1333
  return labels
1193
1334
 
@@ -36,8 +36,10 @@ def correlation(*, df: snowpark.DataFrame, columns: Optional[Collection[str]] =
36
36
  as a post-processing step.
37
37
 
38
38
  Args:
39
- df (snowpark.DataFrame): Snowpark Dataframe for which correlation matrix has to be computed.
40
- columns (Optional[Collection[str]]): List of column names for which the correlation matrix has to be computed.
39
+ df: snowpark.DataFrame
40
+ Snowpark Dataframe for which correlation matrix has to be computed.
41
+ columns: list of strings, default=None
42
+ List of column names for which the correlation matrix has to be computed.
41
43
  If None, correlation matrix is computed for all numeric columns in the snowpark dataframe.
42
44
 
43
45
  Returns:
@@ -36,11 +36,14 @@ def covariance(*, df: DataFrame, columns: Optional[Collection[str]] = None, ddof
36
36
  as a post-processing step.
37
37
 
38
38
  Args:
39
- df (DataFrame): Snowpark Dataframe for which covariance matrix has to be computed.
40
- columns (Optional[Collection[str]]): List of column names for which the covariance matrix has to be computed.
39
+ df: snowpark.DataFrame
40
+ Snowpark Dataframe for which covariance matrix has to be computed.
41
+ columns: list of strings, default=None
42
+ List of column names for which the covariance matrix has to be computed.
41
43
  If None, covariance matrix is computed for all numeric columns in the snowpark dataframe.
42
- ddof (int): default 1. Delta degrees of freedom.
43
- The divisor used in calculations is N - ddof, where N represents the number of rows.
44
+ ddof: int, default=1
45
+ Delta degrees of freedom. The divisor used in calculations is N - ddof, where N represents the
46
+ number of rows.
44
47
 
45
48
  Returns:
46
49
  Covariance matrix in pandas.DataFrame format.
@@ -49,18 +49,23 @@ def precision_recall_curve(
49
49
  which corresponds to a classifier that always predicts the positive class.
50
50
 
51
51
  Args:
52
- df: Input dataframe.
53
- y_true_col_name: Column name representing true binary labels.
52
+ df: snowpark.DataFrame
53
+ Input dataframe.
54
+ y_true_col_name: string
55
+ Column name representing true binary labels.
54
56
  If labels are not either {-1, 1} or {0, 1}, then pos_label should be
55
57
  explicitly given.
56
- probas_pred_col_name: Column name representing target scores.
58
+ probas_pred_col_name: string
59
+ Column name representing target scores.
57
60
  Can either be probability estimates of the positive
58
61
  class, or non-thresholded measure of decisions (as returned by
59
62
  `decision_function` on some classifiers).
60
- pos_label: The label of the positive class.
63
+ pos_label: string or int, default=None
64
+ The label of the positive class.
61
65
  When ``pos_label=None``, if y_true is in {-1, 1} or {0, 1},
62
66
  ``pos_label`` is set to 1, otherwise an error will be raised.
63
- sample_weight_col_name: Column name representing sample weights.
67
+ sample_weight_col_name: string, default=None
68
+ Column name representing sample weights.
64
69
 
65
70
  Returns:
66
71
  Tuple containing following items
@@ -142,12 +147,15 @@ def roc_auc_score(
142
147
  multilabel classification, but some restrictions apply.
143
148
 
144
149
  Args:
145
- df: Input dataframe.
146
- y_true_col_names: Column name(s) representing true labels or binary label indicators.
150
+ df: snowpark.DataFrame
151
+ Input dataframe.
152
+ y_true_col_names: string or list of strings
153
+ Column name(s) representing true labels or binary label indicators.
147
154
  The binary and multiclass cases expect labels with shape (n_samples,)
148
155
  while the multilabel case expects binary label indicators with shape
149
156
  (n_samples, n_classes).
150
- y_score_col_names: Column name(s) representing target scores.
157
+ y_score_col_names: string or list of strings
158
+ Column name(s) representing target scores.
151
159
  * In the binary case, it corresponds to an array of shape
152
160
  `(n_samples,)`. Both probability estimates and non-thresholded
153
161
  decision values can be provided. The probability estimates correspond
@@ -186,7 +194,8 @@ def roc_auc_score(
186
194
  ``'samples'``
187
195
  Calculate metrics for each instance, and find their average.
188
196
  Will be ignored when ``y_true`` is binary.
189
- sample_weight_col_name: Column name representing sample weights.
197
+ sample_weight_col_name: string, default=None
198
+ Column name representing sample weights.
190
199
  max_fpr: float > 0 and <= 1, default=None
191
200
  If not ``None``, the standardized partial AUC [2]_ over the range
192
201
  [0, max_fpr] is returned. For the multiclass case, ``max_fpr``,
@@ -208,7 +217,8 @@ def roc_auc_score(
208
217
  possible pairwise combinations of classes [5]_.
209
218
  Insensitive to class imbalance when
210
219
  ``average == 'macro'``.
211
- labels: Only used for multiclass targets. List of labels that index the
220
+ labels: list of labels, default=None
221
+ Only used for multiclass targets. List of labels that index the
212
222
  classes in ``y_score``. If ``None``, the numerical or lexicographical
213
223
  order of the labels in ``y_true`` is used.
214
224
 
@@ -282,19 +292,25 @@ def roc_curve(
282
292
  Note: this implementation is restricted to the binary classification task.
283
293
 
284
294
  Args:
285
- df: Input dataframe.
286
- y_true_col_name: Column name representing true binary labels.
295
+ df: snowpark.DataFrame
296
+ Input dataframe.
297
+ y_true_col_name: string
298
+ Column name representing true binary labels.
287
299
  If labels are not either {-1, 1} or {0, 1}, then pos_label should be
288
300
  explicitly given.
289
- y_score_col_name: Column name representing target scores, can either
301
+ y_score_col_name: string
302
+ Column name representing target scores, can either
290
303
  be probability estimates of the positive class, confidence values,
291
304
  or non-thresholded measure of decisions (as returned by
292
305
  "decision_function" on some classifiers).
293
- pos_label: The label of the positive class.
306
+ pos_label: string or int, default=None
307
+ The label of the positive class.
294
308
  When ``pos_label=None``, if `y_true` is in {-1, 1} or {0, 1},
295
309
  ``pos_label`` is set to 1, otherwise an error will be raised.
296
- sample_weight_col_name: Column name representing sample weights.
297
- drop_intermediate: Whether to drop some suboptimal thresholds which would
310
+ sample_weight_col_name: string, default=None
311
+ Column name representing sample weights.
312
+ drop_intermediate: boolean, default=True
313
+ Whether to drop some suboptimal thresholds which would
298
314
  not appear on a plotted ROC curve. This is useful in order to create
299
315
  lighter ROC curves.
300
316