wandb 0.18.4__py3-none-any.whl → 0.18.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (128) hide show
  1. wandb/__init__.py +2 -2
  2. wandb/__init__.pyi +21 -19
  3. wandb/agents/pyagent.py +1 -1
  4. wandb/apis/importers/wandb.py +1 -1
  5. wandb/apis/normalize.py +2 -18
  6. wandb/apis/public/api.py +122 -62
  7. wandb/apis/public/artifacts.py +8 -3
  8. wandb/apis/public/files.py +17 -2
  9. wandb/apis/public/jobs.py +2 -2
  10. wandb/apis/public/query_generator.py +1 -1
  11. wandb/apis/public/runs.py +8 -8
  12. wandb/apis/public/teams.py +3 -3
  13. wandb/apis/public/users.py +1 -1
  14. wandb/apis/public/utils.py +68 -0
  15. wandb/bin/gpu_stats +0 -0
  16. wandb/cli/cli.py +12 -3
  17. wandb/data_types.py +1 -1
  18. wandb/docker/__init__.py +2 -1
  19. wandb/docker/auth.py +2 -3
  20. wandb/errors/links.py +73 -0
  21. wandb/errors/term.py +7 -6
  22. wandb/filesync/step_prepare.py +1 -1
  23. wandb/filesync/upload_job.py +1 -1
  24. wandb/integration/catboost/catboost.py +2 -2
  25. wandb/integration/diffusers/pipeline_resolver.py +1 -1
  26. wandb/integration/diffusers/resolvers/multimodal.py +6 -6
  27. wandb/integration/diffusers/resolvers/utils.py +1 -1
  28. wandb/integration/fastai/__init__.py +3 -2
  29. wandb/integration/keras/callbacks/metrics_logger.py +1 -1
  30. wandb/integration/keras/callbacks/model_checkpoint.py +1 -1
  31. wandb/integration/keras/keras.py +1 -1
  32. wandb/integration/kfp/kfp_patch.py +1 -1
  33. wandb/integration/lightgbm/__init__.py +2 -2
  34. wandb/integration/magic.py +2 -2
  35. wandb/integration/metaflow/metaflow.py +1 -1
  36. wandb/integration/sacred/__init__.py +1 -1
  37. wandb/integration/sagemaker/auth.py +1 -1
  38. wandb/integration/sklearn/plot/classifier.py +7 -7
  39. wandb/integration/sklearn/plot/clusterer.py +3 -3
  40. wandb/integration/sklearn/plot/regressor.py +3 -3
  41. wandb/integration/sklearn/plot/shared.py +2 -2
  42. wandb/integration/tensorboard/log.py +2 -2
  43. wandb/integration/ultralytics/callback.py +2 -2
  44. wandb/integration/xgboost/xgboost.py +1 -1
  45. wandb/jupyter.py +0 -1
  46. wandb/plot/__init__.py +17 -8
  47. wandb/plot/bar.py +53 -27
  48. wandb/plot/confusion_matrix.py +151 -70
  49. wandb/plot/custom_chart.py +124 -0
  50. wandb/plot/histogram.py +46 -20
  51. wandb/plot/line.py +57 -26
  52. wandb/plot/line_series.py +148 -60
  53. wandb/plot/pr_curve.py +89 -44
  54. wandb/plot/roc_curve.py +82 -37
  55. wandb/plot/scatter.py +53 -20
  56. wandb/plot/viz.py +20 -102
  57. wandb/sdk/artifacts/artifact.py +280 -328
  58. wandb/sdk/artifacts/artifact_manifest.py +10 -9
  59. wandb/sdk/artifacts/artifact_manifest_entry.py +1 -1
  60. wandb/sdk/artifacts/storage_handlers/azure_handler.py +9 -4
  61. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +1 -3
  62. wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
  63. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +2 -2
  64. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -1
  65. wandb/sdk/backend/backend.py +0 -1
  66. wandb/sdk/data_types/audio.py +1 -1
  67. wandb/sdk/data_types/base_types/media.py +66 -5
  68. wandb/sdk/data_types/bokeh.py +1 -1
  69. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +1 -1
  70. wandb/sdk/data_types/helper_types/image_mask.py +2 -2
  71. wandb/sdk/data_types/histogram.py +1 -1
  72. wandb/sdk/data_types/html.py +1 -1
  73. wandb/sdk/data_types/image.py +1 -1
  74. wandb/sdk/data_types/molecule.py +3 -3
  75. wandb/sdk/data_types/object_3d.py +4 -4
  76. wandb/sdk/data_types/plotly.py +1 -1
  77. wandb/sdk/data_types/saved_model.py +0 -1
  78. wandb/sdk/data_types/table.py +7 -7
  79. wandb/sdk/data_types/trace_tree.py +1 -1
  80. wandb/sdk/data_types/video.py +4 -3
  81. wandb/sdk/interface/router.py +0 -2
  82. wandb/sdk/internal/datastore.py +1 -1
  83. wandb/sdk/internal/file_pusher.py +1 -1
  84. wandb/sdk/internal/file_stream.py +4 -4
  85. wandb/sdk/internal/handler.py +3 -2
  86. wandb/sdk/internal/internal.py +1 -1
  87. wandb/sdk/internal/internal_api.py +183 -64
  88. wandb/sdk/internal/job_builder.py +4 -3
  89. wandb/sdk/internal/system/assets/__init__.py +0 -2
  90. wandb/sdk/internal/tb_watcher.py +11 -10
  91. wandb/sdk/launch/_launch.py +4 -3
  92. wandb/sdk/launch/_launch_add.py +2 -2
  93. wandb/sdk/launch/builder/kaniko_builder.py +0 -1
  94. wandb/sdk/launch/create_job.py +1 -0
  95. wandb/sdk/launch/environment/local_environment.py +0 -1
  96. wandb/sdk/launch/errors.py +0 -6
  97. wandb/sdk/launch/registry/local_registry.py +0 -2
  98. wandb/sdk/launch/runner/abstract.py +0 -5
  99. wandb/sdk/launch/sweeps/__init__.py +0 -2
  100. wandb/sdk/launch/sweeps/scheduler.py +0 -2
  101. wandb/sdk/launch/sweeps/scheduler_sweep.py +0 -1
  102. wandb/sdk/lib/apikey.py +3 -3
  103. wandb/sdk/lib/file_stream_utils.py +1 -1
  104. wandb/sdk/lib/filesystem.py +1 -1
  105. wandb/sdk/lib/ipython.py +16 -9
  106. wandb/sdk/lib/mailbox.py +0 -4
  107. wandb/sdk/lib/printer.py +44 -8
  108. wandb/sdk/lib/retry.py +1 -1
  109. wandb/sdk/service/service.py +3 -3
  110. wandb/sdk/service/streams.py +2 -4
  111. wandb/sdk/wandb_init.py +20 -20
  112. wandb/sdk/wandb_login.py +1 -1
  113. wandb/sdk/wandb_require.py +1 -4
  114. wandb/sdk/wandb_run.py +57 -69
  115. wandb/sdk/wandb_settings.py +3 -4
  116. wandb/sdk/wandb_sync.py +2 -1
  117. wandb/util.py +46 -18
  118. wandb/wandb_agent.py +3 -3
  119. wandb/wandb_controller.py +2 -2
  120. {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/METADATA +1 -1
  121. {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/RECORD +124 -125
  122. wandb/sdk/internal/system/assets/gpu_apple.py +0 -177
  123. wandb/sdk/lib/_wburls_generate.py +0 -25
  124. wandb/sdk/lib/_wburls_generated.py +0 -22
  125. wandb/sdk/lib/wburls.py +0 -46
  126. {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/WHEEL +0 -0
  127. {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/entry_points.txt +0 -0
  128. {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/licenses/LICENSE +0 -0
wandb/plot/histogram.py CHANGED
@@ -1,39 +1,65 @@
1
- from typing import TYPE_CHECKING, Optional
1
+ from __future__ import annotations
2
2
 
3
- from wandb.plot.viz import custom_chart
3
+ from typing import TYPE_CHECKING
4
+
5
+ from wandb.plot.custom_chart import plot_table
4
6
 
5
7
  if TYPE_CHECKING:
6
8
  import wandb
9
+ from wandb.plot.custom_chart import CustomChart
7
10
 
8
11
 
9
12
  def histogram(
10
- table: "wandb.Table",
13
+ table: wandb.Table,
11
14
  value: str,
12
- title: Optional[str] = None,
13
- split_table: Optional[bool] = False,
14
- ):
15
- """Construct a histogram plot.
15
+ title: str = "",
16
+ split_table: bool = False,
17
+ ) -> CustomChart:
18
+ """Constructs a histogram chart from a W&B Table.
16
19
 
17
- Arguments:
18
- table (wandb.Table): Table of data.
19
- value (string): Name of column to use as data for bucketing.
20
- title (string): Plot title.
21
- split_table (bool): If True, adds "Custom Chart Tables/" to the key of the table so that it's logged in a different section.
20
+ Args:
21
+ table (wandb.Table): The W&B Table containing the data for the histogram.
22
+ value (str): Name of the table column whose values are bucketed into bins (x-axis).
23
+ title (str): The title of the histogram plot.
24
+ split_table (bool): Whether the table should be split into a separate section
25
+ in the W&B UI. If `True`, the table will be displayed in a section named
26
+ "Custom Chart Tables". Default is `False`.
22
27
 
23
28
  Returns:
24
- A plot object, to be passed to wandb.log()
29
+ CustomChart: A custom chart object that can be logged to W&B. To log the
30
+ chart, pass it to `wandb.log()`.
25
31
 
26
32
  Example:
27
33
  ```
34
+ import math
35
+ import random
36
+ import wandb
37
+
38
+ # Generate random data
28
39
  data = [[i, random.random() + math.sin(i / 10)] for i in range(100)]
29
- table = wandb.Table(data=data, columns=["step", "height"])
30
- wandb.log({'histogram-plot1': wandb.plot.histogram(table, "height")})
40
+
41
+ # Create a W&B Table
42
+ table = wandb.Table(
43
+ data=data,
44
+ columns=["step", "height"],
45
+ )
46
+
47
+ # Create a histogram plot
48
+ histogram = wandb.plot.histogram(
49
+ table,
50
+ value="height",
51
+ title="My Histogram",
52
+ )
53
+
54
+ # Log the histogram plot to W&B
55
+ with wandb.init(...) as run:
56
+ run.log({'histogram-plot1': histogram})
31
57
  ```
32
58
  """
33
- return custom_chart(
34
- "wandb/histogram/v0",
35
- table,
36
- {"value": value},
37
- {"title": title},
59
+ return plot_table(
60
+ data_table=table,
61
+ vega_spec_name="wandb/histogram/v0",
62
+ fields={"value": value},
63
+ string_fields={"title": title},
38
64
  split_table=split_table,
39
65
  )
wandb/plot/line.py CHANGED
@@ -1,43 +1,74 @@
1
- from typing import TYPE_CHECKING, Optional
1
+ from __future__ import annotations
2
2
 
3
- from wandb.plot.viz import custom_chart
3
+ from typing import TYPE_CHECKING
4
+
5
+ from wandb.plot.custom_chart import plot_table
4
6
 
5
7
  if TYPE_CHECKING:
6
8
  import wandb
9
+ from wandb.plot.custom_chart import CustomChart
7
10
 
8
11
 
9
12
  def line(
10
- table: "wandb.Table",
13
+ table: wandb.Table,
11
14
  x: str,
12
15
  y: str,
13
- stroke: Optional[str] = None,
14
- title: Optional[str] = None,
15
- split_table: Optional[bool] = False,
16
- ):
17
- """Construct a line plot.
18
-
19
- Arguments:
20
- table (wandb.Table): Table of data.
21
- x (string): Name of column to as for x-axis values.
22
- y (string): Name of column to as for y-axis values.
23
- stroke (string): Name of column to map to the line stroke scale.
24
- title (string): Plot title.
25
- split_table (bool): If True, adds "Custom Chart Tables/" to the key of the table so that it's logged in a different section.
16
+ stroke: str | None = None,
17
+ title: str = "",
18
+ split_table: bool = False,
19
+ ) -> CustomChart:
20
+ """Constructs a customizable line chart.
21
+
22
+ Args:
23
+ table (wandb.Table): The table containing data for the chart.
24
+ x (str): Column name for the x-axis values.
25
+ y (str): Column name for the y-axis values.
26
+ stroke (str): Column name to differentiate line strokes (e.g., for
27
+ grouping lines).
28
+ title (str): Title of the chart.
29
+ split_table (bool): Whether the table should be split into a separate section
30
+ in the W&B UI. If `True`, the table will be displayed in a section named
31
+ "Custom Chart Tables". Default is `False`.
26
32
 
27
33
  Returns:
28
- A plot object, to be passed to wandb.log()
34
+ CustomChart: A custom chart object that can be logged to W&B. To log the
35
+ chart, pass it to `wandb.log()`.
29
36
 
30
37
  Example:
31
- ```
32
- data = [[i, random.random() + math.sin(i / 10)] for i in range(100)]
33
- table = wandb.Table(data=data, columns=["step", "height"])
34
- wandb.log({'line-plot1': wandb.plot.line(table, "step", "height")})
38
+ ```python
39
+ import math
40
+ import random
41
+ import wandb
42
+
43
+ # Create multiple series of data with different patterns
44
+ data = []
45
+ for i in range(100):
46
+ # Series 1: Sinusoidal pattern with random noise
47
+ data.append([i, math.sin(i / 10) + random.uniform(-0.1, 0.1), "series_1"])
48
+ # Series 2: Cosine pattern with random noise
49
+ data.append([i, math.cos(i / 10) + random.uniform(-0.1, 0.1), "series_2"])
50
+ # Series 3: Linear increase with random noise
51
+ data.append([i, i / 10 + random.uniform(-0.5, 0.5), "series_3"])
52
+
53
+ # Define the columns for the table
54
+ table = wandb.Table(data=data, columns=["step", "value", "series"])
55
+
56
+ # Initialize wandb run and log the line chart
57
+ with wandb.init(project="line_chart_example") as run:
58
+ line_chart = wandb.plot.line(
59
+ table=table,
60
+ x="step",
61
+ y="value",
62
+ stroke="series", # Group by the "series" column
63
+ title="Multi-Series Line Plot",
64
+ )
65
+ run.log({"line-chart": line_chart})
35
66
  ```
36
67
  """
37
- return custom_chart(
38
- "wandb/line/v0",
39
- table,
40
- {"x": x, "y": y, "stroke": stroke},
41
- {"title": title},
68
+ return plot_table(
69
+ data_table=table,
70
+ vega_spec_name="wandb/line/v0",
71
+ fields={"x": x, "y": y, "stroke": stroke},
72
+ string_fields={"title": title},
42
73
  split_table=split_table,
43
74
  )
wandb/plot/line_series.py CHANGED
@@ -1,88 +1,176 @@
1
- import typing as t
2
- from collections.abc import Iterable
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any, Iterable
3
4
 
4
5
  import wandb
6
+ from wandb.plot.custom_chart import plot_table
7
+
8
+ if TYPE_CHECKING:
9
+ from wandb.plot.custom_chart import CustomChart
5
10
 
6
11
 
7
12
  def line_series(
8
- xs: t.Union[t.Iterable, t.Iterable[t.Iterable]],
9
- ys: t.Iterable[t.Iterable],
10
- keys: t.Optional[t.Iterable] = None,
11
- title: t.Optional[str] = None,
12
- xname: t.Optional[str] = None,
13
- split_table: t.Optional[bool] = False,
14
- ):
15
- """Construct a line series plot.
16
-
17
- Arguments:
18
- xs (array of arrays, or array): Array of arrays of x values
19
- ys (array of arrays): Array of y values
20
- keys (array): Array of labels for the line plots
21
- title (string): Plot title.
22
- xname: Title of x-axis
23
- split_table (bool): If True, adds "Custom Chart Tables/" to the key of the table so that it's logged in a different section.
13
+ xs: Iterable[Iterable[Any]] | Iterable[Any],
14
+ ys: Iterable[Iterable[Any]],
15
+ keys: Iterable[str] | None = None,
16
+ title: str = "",
17
+ xname: str = "x",
18
+ split_table: bool = False,
19
+ ) -> CustomChart:
20
+ """Constructs a line series chart.
21
+
22
+ Args:
23
+ xs (Iterable[Iterable] | Iterable): Sequence of x values. If a singular
24
+ array is provided, all y values are plotted against that x array. If
25
+ an array of arrays is provided, each y value is plotted against the
26
+ corresponding x array.
27
+ ys (Iterable[Iterable]): Sequence of y values, where each iterable represents
28
+ a separate line series.
29
+ keys (Iterable[str]): Sequence of keys for labeling each line series. If
30
+ not provided, keys will be automatically generated as "line_0",
31
+ "line_1", etc.
32
+ title (str): Title of the chart.
33
+ xname (str): Label for the x-axis.
34
+ split_table (bool): Whether the table should be split into a separate section
35
+ in the W&B UI. If `True`, the table will be displayed in a section named
36
+ "Custom Chart Tables". Default is `False`.
24
37
 
25
38
  Returns:
26
- A plot object, to be passed to wandb.log()
39
+ CustomChart: A custom chart object that can be logged to W&B. To log the
40
+ chart, pass it to `wandb.log()`.
27
41
 
28
- Example:
29
- When logging a singular array for xs, all ys are plotted against that xs
30
- <!--yeadoc-test:plot-line-series-single-->
31
- ```python
42
+ Examples:
43
+ 1. Logging a single x array where all y series are plotted against
44
+ the same x values:
45
+
46
+ ```
32
47
  import wandb
33
48
 
34
- run = wandb.init()
35
- xs = [i for i in range(10)]
36
- ys = [[i for i in range(10)], [i**2 for i in range(10)]]
37
- run.log(
38
- {"line-series-plot1": wandb.plot.line_series(xs, ys, title="title", xname="step")}
39
- )
40
- run.finish()
49
+ # Initialize W&B run
50
+ with wandb.init(project="line_series_example") as run:
51
+ # x values shared across all y series
52
+ xs = list(range(10))
53
+
54
+ # Multiple y series to plot
55
+ ys = [
56
+ [i for i in range(10)], # y = x
57
+ [i**2 for i in range(10)], # y = x^2
58
+ [i**3 for i in range(10)], # y = x^3
59
+ ]
60
+
61
+ # Generate and log the line series chart
62
+ line_series_chart = wandb.plot.line_series(
63
+ xs,
64
+ ys,
65
+ title="title",
66
+ xname="step",
67
+ )
68
+ run.log({"line-series-single-x": line_series_chart})
41
69
  ```
42
- xs can also contain an array of arrays for having different steps for each metric
43
- <!--yeadoc-test:plot-line-series-double-->
70
+
71
+ In this example, a single `xs` series (shared x-values) is used for all
72
+ `ys` series. This results in each y-series being plotted against the
73
+ same x-values (0-9).
74
+
75
+ 2. Logging multiple x arrays where each y series is plotted against
76
+ its corresponding x array:
77
+
44
78
  ```python
45
79
  import wandb
46
80
 
47
- run = wandb.init()
48
- xs = [[i for i in range(10)], [2 * i for i in range(10)]]
49
- ys = [[i for i in range(10)], [i**2 for i in range(10)]]
50
- run.log(
51
- {"line-series-plot2": wandb.plot.line_series(xs, ys, title="title", xname="step")}
52
- )
53
- run.finish()
81
+ # Initialize W&B run
82
+ with wandb.init(project="line_series_example") as run:
83
+ # Separate x values for each y series
84
+ xs = [
85
+ [i for i in range(10)], # x for first series
86
+ [2 * i for i in range(10)], # x for second series (stretched)
87
+ [3 * i for i in range(10)], # x for third series (stretched more)
88
+ ]
89
+
90
+ # Corresponding y series
91
+ ys = [
92
+ [i for i in range(10)], # y = x
93
+ [i**2 for i in range(10)], # y = x^2
94
+ [i**3 for i in range(10)], # y = x^3
95
+ ]
96
+
97
+ # Generate and log the line series chart
98
+ line_series_chart = wandb.plot.line_series(
99
+ xs, ys, title="Multiple X Arrays Example", xname="Step"
100
+ )
101
+ run.log({"line-series-multiple-x": line_series_chart})
54
102
  ```
55
- """
56
- if not isinstance(xs, Iterable):
57
- raise TypeError(f"Expected xs to be an array instead got {type(xs)}")
58
103
 
59
- if not isinstance(ys, Iterable):
60
- raise TypeError(f"Expected ys to be an array instead got {type(xs)}")
104
+ In this example, each y series is plotted against its own unique x series.
105
+ This allows for more flexibility when the x values are not uniform across
106
+ the data series.
107
+
108
+ 3. Customizing line labels using `keys`:
109
+
110
+ ```python
111
+ import wandb
112
+
113
+ # Initialize W&B run
114
+ with wandb.init(project="line_series_example") as run:
115
+ xs = list(range(10)) # Single x array
116
+ ys = [
117
+ [i for i in range(10)], # y = x
118
+ [i**2 for i in range(10)], # y = x^2
119
+ [i**3 for i in range(10)], # y = x^3
120
+ ]
61
121
 
62
- for y in ys:
63
- if not isinstance(y, Iterable):
64
- raise TypeError(
65
- f"Expected ys to be an array of arrays instead got {type(y)}"
122
+ # Custom labels for each line
123
+ keys = ["Linear", "Quadratic", "Cubic"]
124
+
125
+ # Generate and log the line series chart
126
+ line_series_chart = wandb.plot.line_series(
127
+ xs,
128
+ ys,
129
+ keys=keys, # Custom keys (line labels)
130
+ title="Custom Line Labels Example",
131
+ xname="Step",
66
132
  )
133
+ run.log({"line-series-custom-keys": line_series_chart})
134
+ ```
67
135
 
136
+ This example shows how to provide custom labels for the lines using
137
+ the `keys` argument. The keys will appear in the legend as "Linear",
138
+ "Quadratic", and "Cubic".
139
+
140
+ """
141
+ # If xs is a single array, repeat it for each y in ys
68
142
  if not isinstance(xs[0], Iterable) or isinstance(xs[0], (str, bytes)):
69
- xs = [xs for _ in range(len(ys))]
70
- assert len(xs) == len(ys), "Number of x-lines and y-lines must match"
143
+ xs = [xs] * len(ys)
144
+
145
+ if len(xs) != len(ys):
146
+ msg = f"Number of x-series ({len(xs)}) must match y-series ({len(ys)})."
147
+ raise ValueError(msg)
148
+
149
+ if keys is None:
150
+ keys = [f"line_{i}" for i in range(len(ys))]
151
+
152
+ if len(keys) != len(ys):
153
+ msg = f"Number of keys ({len(keys)}) must match y-series ({len(ys)})."
154
+ raise ValueError(msg)
71
155
 
72
- if keys is not None:
73
- assert len(keys) == len(ys), "Number of keys and y-lines must match"
74
156
  data = [
75
- [x, f"key_{i}" if keys is None else keys[i], y]
157
+ [x, keys[i], y]
76
158
  for i, (xx, yy) in enumerate(zip(xs, ys))
77
159
  for x, y in zip(xx, yy)
78
160
  ]
161
+ table = wandb.Table(
162
+ data=data,
163
+ columns=["step", "lineKey", "lineVal"],
164
+ )
79
165
 
80
- table = wandb.Table(data=data, columns=["step", "lineKey", "lineVal"])
81
-
82
- return wandb.plot_table(
83
- "wandb/lineseries/v0",
84
- table,
85
- {"step": "step", "lineKey": "lineKey", "lineVal": "lineVal"},
86
- {"title": title, "xname": xname or "x"},
166
+ return plot_table(
167
+ data_table=table,
168
+ vega_spec_name="wandb/lineseries/v0",
169
+ fields={
170
+ "step": "step",
171
+ "lineKey": "lineKey",
172
+ "lineVal": "lineVal",
173
+ },
174
+ string_fields={"title": title, "xname": xname},
87
175
  split_table=split_table,
88
176
  )
wandb/plot/pr_curve.py CHANGED
@@ -1,48 +1,92 @@
1
- from typing import Optional
1
+ from __future__ import annotations
2
+
3
+ import numbers
4
+ from typing import TYPE_CHECKING, Iterable, TypeVar
2
5
 
3
6
  import wandb
4
7
  from wandb import util
8
+ from wandb.plot.custom_chart import plot_table
9
+ from wandb.plot.utils import test_missing, test_types
10
+
11
+ if TYPE_CHECKING:
12
+ from wandb.plot.custom_chart import CustomChart
13
+
5
14
 
6
- from .utils import test_missing, test_types
15
+ T = TypeVar("T")
7
16
 
8
17
 
9
18
  def pr_curve(
10
- y_true=None,
11
- y_probas=None,
12
- labels=None,
13
- classes_to_plot=None,
14
- interp_size=21,
15
- title=None,
16
- split_table: Optional[bool] = False,
17
- ):
18
- """Compute the tradeoff between precision and recall for different thresholds.
19
-
20
- A high area under the curve represents both high recall and high precision, where
21
- high precision relates to a low false positive rate, and high recall relates to a
22
- low false negative rate. High scores for both show that the classifier is returning
23
- accurate results (high precision), and returning a majority of all positive results
24
- (high recall). PR curve is useful when the classes are very imbalanced.
25
-
26
- Arguments:
27
- y_true (arr): true sparse labels y_probas (arr): Target scores, can either be
28
- probability estimates, confidence values, or non-thresholded measure of
29
- decisions. shape: (*y_true.shape, num_classes)
30
- labels (list): Named labels for target variable (y). Makes plots easier to read
31
- by replacing target values with corresponding index. For example labels =
32
- ['dog', 'cat', 'owl'] all 0s are replaced by 'dog', 1s by 'cat'.
33
- classes_to_plot (list): unique values of y_true to include in the plot
34
- interp_size (int): the recall values will be fixed to `interp_size` points
35
- uniform on [0, 1] and the precision will be interpolated for these recall
36
- values.
37
- split_table (bool): If True, adds "Custom Chart Tables/" to the key of the table so that it's logged in a different section.
19
+ y_true: Iterable[T] | None = None,
20
+ y_probas: Iterable[numbers.Number] | None = None,
21
+ labels: list[str] | None = None,
22
+ classes_to_plot: list[T] | None = None,
23
+ interp_size: int = 21,
24
+ title: str = "Precision-Recall Curve",
25
+ split_table: bool = False,
26
+ ) -> CustomChart:
27
+ """Constructs a Precision-Recall (PR) curve.
28
+
29
+ The Precision-Recall curve is particularly useful for evaluating classifiers
30
+ on imbalanced datasets. A high area under the PR curve signifies both high
31
+ precision (a low false positive rate) and high recall (a low false negative
32
+ rate). The curve provides insights into the balance between false positives
33
+ and false negatives at various threshold levels, aiding in the assessment of
34
+ a model's performance.
35
+
36
+ Args:
37
+ y_true (Iterable): True class labels. The shape should be (`num_samples`,).
38
+ y_probas (Iterable): Predicted scores or probabilities for each class.
39
+ These can be probability estimates, confidence scores, or non-thresholded
40
+ decision values. The shape should be (`num_samples`, `num_classes`).
41
+ labels (list[str] | None): Optional list of class names to replace
42
+ numeric values in `y_true` for easier plot interpretation.
43
+ For example, `labels = ['dog', 'cat', 'owl']` will replace 0 with
44
+ 'dog', 1 with 'cat', and 2 with 'owl' in the plot. If not provided,
45
+ numeric values from `y_true` will be used.
46
+ classes_to_plot (list | None): Optional list of unique class values from
47
+ y_true to be included in the plot. If not specified, all unique
48
+ classes in y_true will be plotted.
49
+ interp_size (int): Number of points to interpolate recall values. The
50
+ recall values will be fixed to `interp_size` uniformly distributed
51
+ points in the range [0, 1], and the precision will be interpolated
52
+ accordingly.
53
+ title (str): Title of the plot. Defaults to "Precision-Recall Curve".
54
+ split_table (bool): Whether the table should be split into a separate section
55
+ in the W&B UI. If `True`, the table will be displayed in a section named
56
+ "Custom Chart Tables". Default is `False`.
38
57
 
39
58
  Returns:
40
- Nothing. To see plots, go to your W&B run page then expand the 'media' tab under
41
- 'auto visualizations'.
59
+ CustomChart: A custom chart object that can be logged to W&B. To log the
60
+ chart, pass it to `wandb.log()`.
61
+
62
+ Raises:
63
+ wandb.Error: If numpy, pandas, or scikit-learn is not installed.
64
+
42
65
 
43
66
  Example:
44
67
  ```
45
- wandb.log({"pr-curve": wandb.plot.pr_curve(y_true, y_probas, labels)})
68
+ import wandb
69
+
70
+ # Example for spam detection (binary classification)
71
+ y_true = [0, 1, 1, 0, 1] # 0 = not spam, 1 = spam
72
+ y_probas = [
73
+ [0.9, 0.1], # Predicted probabilities for the first sample (not spam)
74
+ [0.2, 0.8], # Second sample (spam), and so on
75
+ [0.1, 0.9],
76
+ [0.8, 0.2],
77
+ [0.3, 0.7]
78
+ ]
79
+
80
+ labels = ['not spam', 'spam'] # Optional class names for readability
81
+
82
+ with wandb.init(project="spam-detection") as run:
83
+ pr_curve = wandb.plot.pr_curve(
84
+ y_true=y_true,
85
+ y_probas=y_probas,
86
+ labels=labels,
87
+ title="Precision-Recall Curve for Spam Detection",
88
+ )
89
+ run.log({"pr-curve": pr_curve})
46
90
  ```
47
91
  """
48
92
  np = util.get_module(
@@ -80,7 +124,7 @@ def pr_curve(
80
124
  if classes_to_plot is None:
81
125
  classes_to_plot = classes
82
126
 
83
- precision = dict()
127
+ precision = {}
84
128
  interp_recall = np.linspace(0, 1, interp_size)[::-1]
85
129
  indices_to_plot = np.where(np.isin(classes, classes_to_plot))[0]
86
130
  for i in indices_to_plot:
@@ -109,12 +153,11 @@ def pr_curve(
109
153
  "precision": np.hstack(list(precision.values())),
110
154
  "recall": np.tile(interp_recall, len(precision)),
111
155
  }
112
- )
113
- df = df.round(3)
156
+ ).round(3)
114
157
 
115
158
  if len(df) > wandb.Table.MAX_ROWS:
116
159
  wandb.termwarn(
117
- "wandb uses only %d data points to create the plots." % wandb.Table.MAX_ROWS
160
+ f"Table has a limit of {wandb.Table.MAX_ROWS} rows. Resampling to fit."
118
161
  )
119
162
  # different sampling could be applied, possibly to ensure endpoints are kept
120
163
  df = sklearn_utils.resample(
@@ -125,12 +168,14 @@ def pr_curve(
125
168
  stratify=df["class"],
126
169
  ).sort_values(["precision", "recall", "class"])
127
170
 
128
- table = wandb.Table(dataframe=df)
129
- title = title or "Precision v. Recall"
130
- return wandb.plot_table(
131
- "wandb/area-under-curve/v0",
132
- table,
133
- {"x": "recall", "y": "precision", "class": "class"},
134
- {"title": title},
171
+ return plot_table(
172
+ data_table=wandb.Table(dataframe=df),
173
+ vega_spec_name="wandb/area-under-curve/v0",
174
+ fields={
175
+ "x": "recall",
176
+ "y": "precision",
177
+ "class": "class",
178
+ },
179
+ string_fields={"title": title},
135
180
  split_table=split_table,
136
181
  )