wandb 0.18.4__py3-none-any.whl → 0.18.6__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (128) hide show
  1. wandb/__init__.py +2 -2
  2. wandb/__init__.pyi +21 -19
  3. wandb/agents/pyagent.py +1 -1
  4. wandb/apis/importers/wandb.py +1 -1
  5. wandb/apis/normalize.py +2 -18
  6. wandb/apis/public/api.py +122 -62
  7. wandb/apis/public/artifacts.py +8 -3
  8. wandb/apis/public/files.py +17 -2
  9. wandb/apis/public/jobs.py +2 -2
  10. wandb/apis/public/query_generator.py +1 -1
  11. wandb/apis/public/runs.py +8 -8
  12. wandb/apis/public/teams.py +3 -3
  13. wandb/apis/public/users.py +1 -1
  14. wandb/apis/public/utils.py +68 -0
  15. wandb/bin/gpu_stats +0 -0
  16. wandb/cli/cli.py +12 -3
  17. wandb/data_types.py +1 -1
  18. wandb/docker/__init__.py +2 -1
  19. wandb/docker/auth.py +2 -3
  20. wandb/errors/links.py +73 -0
  21. wandb/errors/term.py +7 -6
  22. wandb/filesync/step_prepare.py +1 -1
  23. wandb/filesync/upload_job.py +1 -1
  24. wandb/integration/catboost/catboost.py +2 -2
  25. wandb/integration/diffusers/pipeline_resolver.py +1 -1
  26. wandb/integration/diffusers/resolvers/multimodal.py +6 -6
  27. wandb/integration/diffusers/resolvers/utils.py +1 -1
  28. wandb/integration/fastai/__init__.py +3 -2
  29. wandb/integration/keras/callbacks/metrics_logger.py +1 -1
  30. wandb/integration/keras/callbacks/model_checkpoint.py +1 -1
  31. wandb/integration/keras/keras.py +1 -1
  32. wandb/integration/kfp/kfp_patch.py +1 -1
  33. wandb/integration/lightgbm/__init__.py +2 -2
  34. wandb/integration/magic.py +2 -2
  35. wandb/integration/metaflow/metaflow.py +1 -1
  36. wandb/integration/sacred/__init__.py +1 -1
  37. wandb/integration/sagemaker/auth.py +1 -1
  38. wandb/integration/sklearn/plot/classifier.py +7 -7
  39. wandb/integration/sklearn/plot/clusterer.py +3 -3
  40. wandb/integration/sklearn/plot/regressor.py +3 -3
  41. wandb/integration/sklearn/plot/shared.py +2 -2
  42. wandb/integration/tensorboard/log.py +2 -2
  43. wandb/integration/ultralytics/callback.py +2 -2
  44. wandb/integration/xgboost/xgboost.py +1 -1
  45. wandb/jupyter.py +0 -1
  46. wandb/plot/__init__.py +17 -8
  47. wandb/plot/bar.py +53 -27
  48. wandb/plot/confusion_matrix.py +151 -70
  49. wandb/plot/custom_chart.py +124 -0
  50. wandb/plot/histogram.py +46 -20
  51. wandb/plot/line.py +57 -26
  52. wandb/plot/line_series.py +148 -60
  53. wandb/plot/pr_curve.py +89 -44
  54. wandb/plot/roc_curve.py +82 -37
  55. wandb/plot/scatter.py +53 -20
  56. wandb/plot/viz.py +20 -102
  57. wandb/sdk/artifacts/artifact.py +280 -328
  58. wandb/sdk/artifacts/artifact_manifest.py +10 -9
  59. wandb/sdk/artifacts/artifact_manifest_entry.py +1 -1
  60. wandb/sdk/artifacts/storage_handlers/azure_handler.py +9 -4
  61. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +1 -3
  62. wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
  63. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +2 -2
  64. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -1
  65. wandb/sdk/backend/backend.py +0 -1
  66. wandb/sdk/data_types/audio.py +1 -1
  67. wandb/sdk/data_types/base_types/media.py +66 -5
  68. wandb/sdk/data_types/bokeh.py +1 -1
  69. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +1 -1
  70. wandb/sdk/data_types/helper_types/image_mask.py +2 -2
  71. wandb/sdk/data_types/histogram.py +1 -1
  72. wandb/sdk/data_types/html.py +1 -1
  73. wandb/sdk/data_types/image.py +1 -1
  74. wandb/sdk/data_types/molecule.py +3 -3
  75. wandb/sdk/data_types/object_3d.py +4 -4
  76. wandb/sdk/data_types/plotly.py +1 -1
  77. wandb/sdk/data_types/saved_model.py +0 -1
  78. wandb/sdk/data_types/table.py +7 -7
  79. wandb/sdk/data_types/trace_tree.py +1 -1
  80. wandb/sdk/data_types/video.py +4 -3
  81. wandb/sdk/interface/router.py +0 -2
  82. wandb/sdk/internal/datastore.py +1 -1
  83. wandb/sdk/internal/file_pusher.py +1 -1
  84. wandb/sdk/internal/file_stream.py +4 -4
  85. wandb/sdk/internal/handler.py +3 -2
  86. wandb/sdk/internal/internal.py +1 -1
  87. wandb/sdk/internal/internal_api.py +183 -64
  88. wandb/sdk/internal/job_builder.py +4 -3
  89. wandb/sdk/internal/system/assets/__init__.py +0 -2
  90. wandb/sdk/internal/tb_watcher.py +11 -10
  91. wandb/sdk/launch/_launch.py +4 -3
  92. wandb/sdk/launch/_launch_add.py +2 -2
  93. wandb/sdk/launch/builder/kaniko_builder.py +0 -1
  94. wandb/sdk/launch/create_job.py +1 -0
  95. wandb/sdk/launch/environment/local_environment.py +0 -1
  96. wandb/sdk/launch/errors.py +0 -6
  97. wandb/sdk/launch/registry/local_registry.py +0 -2
  98. wandb/sdk/launch/runner/abstract.py +0 -5
  99. wandb/sdk/launch/sweeps/__init__.py +0 -2
  100. wandb/sdk/launch/sweeps/scheduler.py +0 -2
  101. wandb/sdk/launch/sweeps/scheduler_sweep.py +0 -1
  102. wandb/sdk/lib/apikey.py +3 -3
  103. wandb/sdk/lib/file_stream_utils.py +1 -1
  104. wandb/sdk/lib/filesystem.py +1 -1
  105. wandb/sdk/lib/ipython.py +16 -9
  106. wandb/sdk/lib/mailbox.py +0 -4
  107. wandb/sdk/lib/printer.py +44 -8
  108. wandb/sdk/lib/retry.py +1 -1
  109. wandb/sdk/service/service.py +3 -3
  110. wandb/sdk/service/streams.py +2 -4
  111. wandb/sdk/wandb_init.py +20 -20
  112. wandb/sdk/wandb_login.py +1 -1
  113. wandb/sdk/wandb_require.py +1 -4
  114. wandb/sdk/wandb_run.py +57 -69
  115. wandb/sdk/wandb_settings.py +3 -4
  116. wandb/sdk/wandb_sync.py +2 -1
  117. wandb/util.py +46 -18
  118. wandb/wandb_agent.py +3 -3
  119. wandb/wandb_controller.py +2 -2
  120. {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/METADATA +1 -1
  121. {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/RECORD +124 -125
  122. wandb/sdk/internal/system/assets/gpu_apple.py +0 -177
  123. wandb/sdk/lib/_wburls_generate.py +0 -25
  124. wandb/sdk/lib/_wburls_generated.py +0 -22
  125. wandb/sdk/lib/wburls.py +0 -46
  126. {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/WHEEL +0 -0
  127. {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/entry_points.txt +0 -0
  128. {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/licenses/LICENSE +0 -0
wandb/plot/histogram.py CHANGED
@@ -1,39 +1,65 @@
1
- from typing import TYPE_CHECKING, Optional
1
+ from __future__ import annotations
2
2
 
3
- from wandb.plot.viz import custom_chart
3
+ from typing import TYPE_CHECKING
4
+
5
+ from wandb.plot.custom_chart import plot_table
4
6
 
5
7
  if TYPE_CHECKING:
6
8
  import wandb
9
+ from wandb.plot.custom_chart import CustomChart
7
10
 
8
11
 
9
12
  def histogram(
10
- table: "wandb.Table",
13
+ table: wandb.Table,
11
14
  value: str,
12
- title: Optional[str] = None,
13
- split_table: Optional[bool] = False,
14
- ):
15
- """Construct a histogram plot.
15
+ title: str = "",
16
+ split_table: bool = False,
17
+ ) -> CustomChart:
18
+ """Constructs a histogram chart from a W&B Table.
16
19
 
17
- Arguments:
18
- table (wandb.Table): Table of data.
19
- value (string): Name of column to use as data for bucketing.
20
- title (string): Plot title.
21
- split_table (bool): If True, adds "Custom Chart Tables/" to the key of the table so that it's logged in a different section.
20
+ Args:
21
+ table (wandb.Table): The W&B Table containing the data for the histogram.
22
+ value (str): Name of the table column whose values are bucketed into bins (x-axis).
23
+ title (str): The title of the histogram plot.
24
+ split_table (bool): Whether the table should be split into a separate section
25
+ in the W&B UI. If `True`, the table will be displayed in a section named
26
+ "Custom Chart Tables". Default is `False`.
22
27
 
23
28
  Returns:
24
- A plot object, to be passed to wandb.log()
29
+ CustomChart: A custom chart object that can be logged to W&B. To log the
30
+ chart, pass it to `wandb.log()`.
25
31
 
26
32
  Example:
27
33
  ```
34
+ import math
35
+ import random
36
+ import wandb
37
+
38
+ # Generate random data
28
39
  data = [[i, random.random() + math.sin(i / 10)] for i in range(100)]
29
- table = wandb.Table(data=data, columns=["step", "height"])
30
- wandb.log({'histogram-plot1': wandb.plot.histogram(table, "height")})
40
+
41
+ # Create a W&B Table
42
+ table = wandb.Table(
43
+ data=data,
44
+ columns=["step", "height"],
45
+ )
46
+
47
+ # Create a histogram plot
48
+ histogram = wandb.plot.histogram(
49
+ table,
50
+ value="height",
51
+ title="My Histogram",
52
+ )
53
+
54
+ # Log the histogram plot to W&B
55
+ with wandb.init(...) as run:
56
+ run.log({'histogram-plot1': histogram})
31
57
  ```
32
58
  """
33
- return custom_chart(
34
- "wandb/histogram/v0",
35
- table,
36
- {"value": value},
37
- {"title": title},
59
+ return plot_table(
60
+ data_table=table,
61
+ vega_spec_name="wandb/histogram/v0",
62
+ fields={"value": value},
63
+ string_fields={"title": title},
38
64
  split_table=split_table,
39
65
  )
wandb/plot/line.py CHANGED
@@ -1,43 +1,74 @@
1
- from typing import TYPE_CHECKING, Optional
1
+ from __future__ import annotations
2
2
 
3
- from wandb.plot.viz import custom_chart
3
+ from typing import TYPE_CHECKING
4
+
5
+ from wandb.plot.custom_chart import plot_table
4
6
 
5
7
  if TYPE_CHECKING:
6
8
  import wandb
9
+ from wandb.plot.custom_chart import CustomChart
7
10
 
8
11
 
9
12
  def line(
10
- table: "wandb.Table",
13
+ table: wandb.Table,
11
14
  x: str,
12
15
  y: str,
13
- stroke: Optional[str] = None,
14
- title: Optional[str] = None,
15
- split_table: Optional[bool] = False,
16
- ):
17
- """Construct a line plot.
18
-
19
- Arguments:
20
- table (wandb.Table): Table of data.
21
- x (string): Name of column to as for x-axis values.
22
- y (string): Name of column to as for y-axis values.
23
- stroke (string): Name of column to map to the line stroke scale.
24
- title (string): Plot title.
25
- split_table (bool): If True, adds "Custom Chart Tables/" to the key of the table so that it's logged in a different section.
16
+ stroke: str | None = None,
17
+ title: str = "",
18
+ split_table: bool = False,
19
+ ) -> CustomChart:
20
+ """Constructs a customizable line chart.
21
+
22
+ Args:
23
+ table (wandb.Table): The table containing data for the chart.
24
+ x (str): Column name for the x-axis values.
25
+ y (str): Column name for the y-axis values.
26
+ stroke (str): Column name to differentiate line strokes (e.g., for
27
+ grouping lines).
28
+ title (str): Title of the chart.
29
+ split_table (bool): Whether the table should be split into a separate section
30
+ in the W&B UI. If `True`, the table will be displayed in a section named
31
+ "Custom Chart Tables". Default is `False`.
26
32
 
27
33
  Returns:
28
- A plot object, to be passed to wandb.log()
34
+ CustomChart: A custom chart object that can be logged to W&B. To log the
35
+ chart, pass it to `wandb.log()`.
29
36
 
30
37
  Example:
31
- ```
32
- data = [[i, random.random() + math.sin(i / 10)] for i in range(100)]
33
- table = wandb.Table(data=data, columns=["step", "height"])
34
- wandb.log({'line-plot1': wandb.plot.line(table, "step", "height")})
38
+ ```python
39
+ import math
40
+ import random
41
+ import wandb
42
+
43
+ # Create multiple series of data with different patterns
44
+ data = []
45
+ for i in range(100):
46
+ # Series 1: Sinusoidal pattern with random noise
47
+ data.append([i, math.sin(i / 10) + random.uniform(-0.1, 0.1), "series_1"])
48
+ # Series 2: Cosine pattern with random noise
49
+ data.append([i, math.cos(i / 10) + random.uniform(-0.1, 0.1), "series_2"])
50
+ # Series 3: Linear increase with random noise
51
+ data.append([i, i / 10 + random.uniform(-0.5, 0.5), "series_3"])
52
+
53
+ # Define the columns for the table
54
+ table = wandb.Table(data=data, columns=["step", "value", "series"])
55
+
56
+ # Initialize wandb run and log the line chart
57
+ with wandb.init(project="line_chart_example") as run:
58
+ line_chart = wandb.plot.line(
59
+ table=table,
60
+ x="step",
61
+ y="value",
62
+ stroke="series", # Group by the "series" column
63
+ title="Multi-Series Line Plot",
64
+ )
65
+ run.log({"line-chart": line_chart})
35
66
  ```
36
67
  """
37
- return custom_chart(
38
- "wandb/line/v0",
39
- table,
40
- {"x": x, "y": y, "stroke": stroke},
41
- {"title": title},
68
+ return plot_table(
69
+ data_table=table,
70
+ vega_spec_name="wandb/line/v0",
71
+ fields={"x": x, "y": y, "stroke": stroke},
72
+ string_fields={"title": title},
42
73
  split_table=split_table,
43
74
  )
wandb/plot/line_series.py CHANGED
@@ -1,88 +1,176 @@
1
- import typing as t
2
- from collections.abc import Iterable
1
+ from __future__ import annotations
2
+
3
+ from typing import TYPE_CHECKING, Any, Iterable
3
4
 
4
5
  import wandb
6
+ from wandb.plot.custom_chart import plot_table
7
+
8
+ if TYPE_CHECKING:
9
+ from wandb.plot.custom_chart import CustomChart
5
10
 
6
11
 
7
12
  def line_series(
8
- xs: t.Union[t.Iterable, t.Iterable[t.Iterable]],
9
- ys: t.Iterable[t.Iterable],
10
- keys: t.Optional[t.Iterable] = None,
11
- title: t.Optional[str] = None,
12
- xname: t.Optional[str] = None,
13
- split_table: t.Optional[bool] = False,
14
- ):
15
- """Construct a line series plot.
16
-
17
- Arguments:
18
- xs (array of arrays, or array): Array of arrays of x values
19
- ys (array of arrays): Array of y values
20
- keys (array): Array of labels for the line plots
21
- title (string): Plot title.
22
- xname: Title of x-axis
23
- split_table (bool): If True, adds "Custom Chart Tables/" to the key of the table so that it's logged in a different section.
13
+ xs: Iterable[Iterable[Any]] | Iterable[Any],
14
+ ys: Iterable[Iterable[Any]],
15
+ keys: Iterable[str] | None = None,
16
+ title: str = "",
17
+ xname: str = "x",
18
+ split_table: bool = False,
19
+ ) -> CustomChart:
20
+ """Constructs a line series chart.
21
+
22
+ Args:
23
+ xs (Iterable[Iterable] | Iterable): Sequence of x values. If a singular
24
+ array is provided, all y values are plotted against that x array. If
25
+ an array of arrays is provided, each y value is plotted against the
26
+ corresponding x array.
27
+ ys (Iterable[Iterable]): Sequence of y values, where each iterable represents
28
+ a separate line series.
29
+ keys (Iterable[str]): Sequence of keys for labeling each line series. If
30
+ not provided, keys will be automatically generated as "line_0",
31
+ "line_1", etc.
32
+ title (str): Title of the chart.
33
+ xname (str): Label for the x-axis.
34
+ split_table (bool): Whether the table should be split into a separate section
35
+ in the W&B UI. If `True`, the table will be displayed in a section named
36
+ "Custom Chart Tables". Default is `False`.
24
37
 
25
38
  Returns:
26
- A plot object, to be passed to wandb.log()
39
+ CustomChart: A custom chart object that can be logged to W&B. To log the
40
+ chart, pass it to `wandb.log()`.
27
41
 
28
- Example:
29
- When logging a singular array for xs, all ys are plotted against that xs
30
- <!--yeadoc-test:plot-line-series-single-->
31
- ```python
42
+ Examples:
43
+ 1. Logging a single x array where all y series are plotted against
44
+ the same x values:
45
+
46
+ ```
32
47
  import wandb
33
48
 
34
- run = wandb.init()
35
- xs = [i for i in range(10)]
36
- ys = [[i for i in range(10)], [i**2 for i in range(10)]]
37
- run.log(
38
- {"line-series-plot1": wandb.plot.line_series(xs, ys, title="title", xname="step")}
39
- )
40
- run.finish()
49
+ # Initialize W&B run
50
+ with wandb.init(project="line_series_example") as run:
51
+ # x values shared across all y series
52
+ xs = list(range(10))
53
+
54
+ # Multiple y series to plot
55
+ ys = [
56
+ [i for i in range(10)], # y = x
57
+ [i**2 for i in range(10)], # y = x^2
58
+ [i**3 for i in range(10)], # y = x^3
59
+ ]
60
+
61
+ # Generate and log the line series chart
62
+ line_series_chart = wandb.plot.line_series(
63
+ xs,
64
+ ys,
65
+ title="title",
66
+ xname="step",
67
+ )
68
+ run.log({"line-series-single-x": line_series_chart})
41
69
  ```
42
- xs can also contain an array of arrays for having different steps for each metric
43
- <!--yeadoc-test:plot-line-series-double-->
70
+
71
+ In this example, a single `xs` series (shared x-values) is used for all
72
+ `ys` series. This results in each y-series being plotted against the
73
+ same x-values (0-9).
74
+
75
+ 2. Logging multiple x arrays where each y series is plotted against
76
+ its corresponding x array:
77
+
44
78
  ```python
45
79
  import wandb
46
80
 
47
- run = wandb.init()
48
- xs = [[i for i in range(10)], [2 * i for i in range(10)]]
49
- ys = [[i for i in range(10)], [i**2 for i in range(10)]]
50
- run.log(
51
- {"line-series-plot2": wandb.plot.line_series(xs, ys, title="title", xname="step")}
52
- )
53
- run.finish()
81
+ # Initialize W&B run
82
+ with wandb.init(project="line_series_example") as run:
83
+ # Separate x values for each y series
84
+ xs = [
85
+ [i for i in range(10)], # x for first series
86
+ [2 * i for i in range(10)], # x for second series (stretched)
87
+ [3 * i for i in range(10)], # x for third series (stretched more)
88
+ ]
89
+
90
+ # Corresponding y series
91
+ ys = [
92
+ [i for i in range(10)], # y = x
93
+ [i**2 for i in range(10)], # y = x^2
94
+ [i**3 for i in range(10)], # y = x^3
95
+ ]
96
+
97
+ # Generate and log the line series chart
98
+ line_series_chart = wandb.plot.line_series(
99
+ xs, ys, title="Multiple X Arrays Example", xname="Step"
100
+ )
101
+ run.log({"line-series-multiple-x": line_series_chart})
54
102
  ```
55
- """
56
- if not isinstance(xs, Iterable):
57
- raise TypeError(f"Expected xs to be an array instead got {type(xs)}")
58
103
 
59
- if not isinstance(ys, Iterable):
60
- raise TypeError(f"Expected ys to be an array instead got {type(xs)}")
104
+ In this example, each y series is plotted against its own unique x series.
105
+ This allows for more flexibility when the x values are not uniform across
106
+ the data series.
107
+
108
+ 3. Customizing line labels using `keys`:
109
+
110
+ ```python
111
+ import wandb
112
+
113
+ # Initialize W&B run
114
+ with wandb.init(project="line_series_example") as run:
115
+ xs = list(range(10)) # Single x array
116
+ ys = [
117
+ [i for i in range(10)], # y = x
118
+ [i**2 for i in range(10)], # y = x^2
119
+ [i**3 for i in range(10)], # y = x^3
120
+ ]
61
121
 
62
- for y in ys:
63
- if not isinstance(y, Iterable):
64
- raise TypeError(
65
- f"Expected ys to be an array of arrays instead got {type(y)}"
122
+ # Custom labels for each line
123
+ keys = ["Linear", "Quadratic", "Cubic"]
124
+
125
+ # Generate and log the line series chart
126
+ line_series_chart = wandb.plot.line_series(
127
+ xs,
128
+ ys,
129
+ keys=keys, # Custom keys (line labels)
130
+ title="Custom Line Labels Example",
131
+ xname="Step",
66
132
  )
133
+ run.log({"line-series-custom-keys": line_series_chart})
134
+ ```
67
135
 
136
+ This example shows how to provide custom labels for the lines using
137
+ the `keys` argument. The keys will appear in the legend as "Linear",
138
+ "Quadratic", and "Cubic".
139
+
140
+ """
141
+ # If xs is a single array, repeat it for each y in ys
68
142
  if not isinstance(xs[0], Iterable) or isinstance(xs[0], (str, bytes)):
69
- xs = [xs for _ in range(len(ys))]
70
- assert len(xs) == len(ys), "Number of x-lines and y-lines must match"
143
+ xs = [xs] * len(ys)
144
+
145
+ if len(xs) != len(ys):
146
+ msg = f"Number of x-series ({len(xs)}) must match y-series ({len(ys)})."
147
+ raise ValueError(msg)
148
+
149
+ if keys is None:
150
+ keys = [f"line_{i}" for i in range(len(ys))]
151
+
152
+ if len(keys) != len(ys):
153
+ msg = f"Number of keys ({len(keys)}) must match y-series ({len(ys)})."
154
+ raise ValueError(msg)
71
155
 
72
- if keys is not None:
73
- assert len(keys) == len(ys), "Number of keys and y-lines must match"
74
156
  data = [
75
- [x, f"key_{i}" if keys is None else keys[i], y]
157
+ [x, keys[i], y]
76
158
  for i, (xx, yy) in enumerate(zip(xs, ys))
77
159
  for x, y in zip(xx, yy)
78
160
  ]
161
+ table = wandb.Table(
162
+ data=data,
163
+ columns=["step", "lineKey", "lineVal"],
164
+ )
79
165
 
80
- table = wandb.Table(data=data, columns=["step", "lineKey", "lineVal"])
81
-
82
- return wandb.plot_table(
83
- "wandb/lineseries/v0",
84
- table,
85
- {"step": "step", "lineKey": "lineKey", "lineVal": "lineVal"},
86
- {"title": title, "xname": xname or "x"},
166
+ return plot_table(
167
+ data_table=table,
168
+ vega_spec_name="wandb/lineseries/v0",
169
+ fields={
170
+ "step": "step",
171
+ "lineKey": "lineKey",
172
+ "lineVal": "lineVal",
173
+ },
174
+ string_fields={"title": title, "xname": xname},
87
175
  split_table=split_table,
88
176
  )
wandb/plot/pr_curve.py CHANGED
@@ -1,48 +1,92 @@
1
- from typing import Optional
1
+ from __future__ import annotations
2
+
3
+ import numbers
4
+ from typing import TYPE_CHECKING, Iterable, TypeVar
2
5
 
3
6
  import wandb
4
7
  from wandb import util
8
+ from wandb.plot.custom_chart import plot_table
9
+ from wandb.plot.utils import test_missing, test_types
10
+
11
+ if TYPE_CHECKING:
12
+ from wandb.plot.custom_chart import CustomChart
13
+
5
14
 
6
- from .utils import test_missing, test_types
15
+ T = TypeVar("T")
7
16
 
8
17
 
9
18
  def pr_curve(
10
- y_true=None,
11
- y_probas=None,
12
- labels=None,
13
- classes_to_plot=None,
14
- interp_size=21,
15
- title=None,
16
- split_table: Optional[bool] = False,
17
- ):
18
- """Compute the tradeoff between precision and recall for different thresholds.
19
-
20
- A high area under the curve represents both high recall and high precision, where
21
- high precision relates to a low false positive rate, and high recall relates to a
22
- low false negative rate. High scores for both show that the classifier is returning
23
- accurate results (high precision), and returning a majority of all positive results
24
- (high recall). PR curve is useful when the classes are very imbalanced.
25
-
26
- Arguments:
27
- y_true (arr): true sparse labels y_probas (arr): Target scores, can either be
28
- probability estimates, confidence values, or non-thresholded measure of
29
- decisions. shape: (*y_true.shape, num_classes)
30
- labels (list): Named labels for target variable (y). Makes plots easier to read
31
- by replacing target values with corresponding index. For example labels =
32
- ['dog', 'cat', 'owl'] all 0s are replaced by 'dog', 1s by 'cat'.
33
- classes_to_plot (list): unique values of y_true to include in the plot
34
- interp_size (int): the recall values will be fixed to `interp_size` points
35
- uniform on [0, 1] and the precision will be interpolated for these recall
36
- values.
37
- split_table (bool): If True, adds "Custom Chart Tables/" to the key of the table so that it's logged in a different section.
19
+ y_true: Iterable[T] | None = None,
20
+ y_probas: Iterable[numbers.Number] | None = None,
21
+ labels: list[str] | None = None,
22
+ classes_to_plot: list[T] | None = None,
23
+ interp_size: int = 21,
24
+ title: str = "Precision-Recall Curve",
25
+ split_table: bool = False,
26
+ ) -> CustomChart:
27
+ """Constructs a Precision-Recall (PR) curve.
28
+
29
+ The Precision-Recall curve is particularly useful for evaluating classifiers
30
+ on imbalanced datasets. A high area under the PR curve signifies both high
31
+ precision (a low false positive rate) and high recall (a low false negative
32
+ rate). The curve provides insights into the balance between false positives
33
+ and false negatives at various threshold levels, aiding in the assessment of
34
+ a model's performance.
35
+
36
+ Args:
37
+ y_true (Iterable): True class labels, one per sample. The shape should be (`num_samples`,).
38
+ y_probas (Iterable): Predicted scores or probabilities for each class.
39
+ These can be probability estimates, confidence scores, or non-thresholded
40
+ decision values. The shape should be (`num_samples`, `num_classes`).
41
+ labels (list[str] | None): Optional list of class names to replace
42
+ numeric values in `y_true` for easier plot interpretation.
43
+ For example, `labels = ['dog', 'cat', 'owl']` will replace 0 with
44
+ 'dog', 1 with 'cat', and 2 with 'owl' in the plot. If not provided,
45
+ numeric values from `y_true` will be used.
46
+ classes_to_plot (list | None): Optional list of unique class values from
47
+ y_true to be included in the plot. If not specified, all unique
48
+ classes in y_true will be plotted.
49
+ interp_size (int): Number of points to interpolate recall values. The
50
+ recall values will be fixed to `interp_size` uniformly distributed
51
+ points in the range [0, 1], and the precision will be interpolated
52
+ accordingly.
53
+ title (str): Title of the plot. Defaults to "Precision-Recall Curve".
54
+ split_table (bool): Whether the table should be split into a separate section
55
+ in the W&B UI. If `True`, the table will be displayed in a section named
56
+ "Custom Chart Tables". Default is `False`.
38
57
 
39
58
  Returns:
40
- Nothing. To see plots, go to your W&B run page then expand the 'media' tab under
41
- 'auto visualizations'.
59
+ CustomChart: A custom chart object that can be logged to W&B. To log the
60
+ chart, pass it to `wandb.log()`.
61
+
62
+ Raises:
63
+ wandb.Error: If numpy, pandas, or scikit-learn is not installed.
64
+
42
65
 
43
66
  Example:
44
67
  ```
45
- wandb.log({"pr-curve": wandb.plot.pr_curve(y_true, y_probas, labels)})
68
+ import wandb
69
+
70
+ # Example for spam detection (binary classification)
71
+ y_true = [0, 1, 1, 0, 1] # 0 = not spam, 1 = spam
72
+ y_probas = [
73
+ [0.9, 0.1], # Predicted probabilities for the first sample (not spam)
74
+ [0.2, 0.8], # Second sample (spam), and so on
75
+ [0.1, 0.9],
76
+ [0.8, 0.2],
77
+ [0.3, 0.7]
78
+ ]
79
+
80
+ labels = ['not spam', 'spam'] # Optional class names for readability
81
+
82
+ with wandb.init(project="spam-detection") as run:
83
+ pr_curve = wandb.plot.pr_curve(
84
+ y_true=y_true,
85
+ y_probas=y_probas,
86
+ labels=labels,
87
+ title="Precision-Recall Curve for Spam Detection",
88
+ )
89
+ run.log({"pr-curve": pr_curve})
46
90
  ```
47
91
  """
48
92
  np = util.get_module(
@@ -80,7 +124,7 @@ def pr_curve(
80
124
  if classes_to_plot is None:
81
125
  classes_to_plot = classes
82
126
 
83
- precision = dict()
127
+ precision = {}
84
128
  interp_recall = np.linspace(0, 1, interp_size)[::-1]
85
129
  indices_to_plot = np.where(np.isin(classes, classes_to_plot))[0]
86
130
  for i in indices_to_plot:
@@ -109,12 +153,11 @@ def pr_curve(
109
153
  "precision": np.hstack(list(precision.values())),
110
154
  "recall": np.tile(interp_recall, len(precision)),
111
155
  }
112
- )
113
- df = df.round(3)
156
+ ).round(3)
114
157
 
115
158
  if len(df) > wandb.Table.MAX_ROWS:
116
159
  wandb.termwarn(
117
- "wandb uses only %d data points to create the plots." % wandb.Table.MAX_ROWS
160
+ f"Table has a limit of {wandb.Table.MAX_ROWS} rows. Resampling to fit."
118
161
  )
119
162
  # different sampling could be applied, possibly to ensure endpoints are kept
120
163
  df = sklearn_utils.resample(
@@ -125,12 +168,14 @@ def pr_curve(
125
168
  stratify=df["class"],
126
169
  ).sort_values(["precision", "recall", "class"])
127
170
 
128
- table = wandb.Table(dataframe=df)
129
- title = title or "Precision v. Recall"
130
- return wandb.plot_table(
131
- "wandb/area-under-curve/v0",
132
- table,
133
- {"x": "recall", "y": "precision", "class": "class"},
134
- {"title": title},
171
+ return plot_table(
172
+ data_table=wandb.Table(dataframe=df),
173
+ vega_spec_name="wandb/area-under-curve/v0",
174
+ fields={
175
+ "x": "recall",
176
+ "y": "precision",
177
+ "class": "class",
178
+ },
179
+ string_fields={"title": title},
135
180
  split_table=split_table,
136
181
  )