wandb 0.18.4__py3-none-any.whl → 0.18.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +2 -2
- wandb/__init__.pyi +21 -19
- wandb/agents/pyagent.py +1 -1
- wandb/apis/importers/wandb.py +1 -1
- wandb/apis/normalize.py +2 -18
- wandb/apis/public/api.py +122 -62
- wandb/apis/public/artifacts.py +8 -3
- wandb/apis/public/files.py +17 -2
- wandb/apis/public/jobs.py +2 -2
- wandb/apis/public/query_generator.py +1 -1
- wandb/apis/public/runs.py +8 -8
- wandb/apis/public/teams.py +3 -3
- wandb/apis/public/users.py +1 -1
- wandb/apis/public/utils.py +68 -0
- wandb/bin/gpu_stats +0 -0
- wandb/cli/cli.py +12 -3
- wandb/data_types.py +1 -1
- wandb/docker/__init__.py +2 -1
- wandb/docker/auth.py +2 -3
- wandb/errors/links.py +73 -0
- wandb/errors/term.py +7 -6
- wandb/filesync/step_prepare.py +1 -1
- wandb/filesync/upload_job.py +1 -1
- wandb/integration/catboost/catboost.py +2 -2
- wandb/integration/diffusers/pipeline_resolver.py +1 -1
- wandb/integration/diffusers/resolvers/multimodal.py +6 -6
- wandb/integration/diffusers/resolvers/utils.py +1 -1
- wandb/integration/fastai/__init__.py +3 -2
- wandb/integration/keras/callbacks/metrics_logger.py +1 -1
- wandb/integration/keras/callbacks/model_checkpoint.py +1 -1
- wandb/integration/keras/keras.py +1 -1
- wandb/integration/kfp/kfp_patch.py +1 -1
- wandb/integration/lightgbm/__init__.py +2 -2
- wandb/integration/magic.py +2 -2
- wandb/integration/metaflow/metaflow.py +1 -1
- wandb/integration/sacred/__init__.py +1 -1
- wandb/integration/sagemaker/auth.py +1 -1
- wandb/integration/sklearn/plot/classifier.py +7 -7
- wandb/integration/sklearn/plot/clusterer.py +3 -3
- wandb/integration/sklearn/plot/regressor.py +3 -3
- wandb/integration/sklearn/plot/shared.py +2 -2
- wandb/integration/tensorboard/log.py +2 -2
- wandb/integration/ultralytics/callback.py +2 -2
- wandb/integration/xgboost/xgboost.py +1 -1
- wandb/jupyter.py +0 -1
- wandb/plot/__init__.py +17 -8
- wandb/plot/bar.py +53 -27
- wandb/plot/confusion_matrix.py +151 -70
- wandb/plot/custom_chart.py +124 -0
- wandb/plot/histogram.py +46 -20
- wandb/plot/line.py +57 -26
- wandb/plot/line_series.py +148 -60
- wandb/plot/pr_curve.py +89 -44
- wandb/plot/roc_curve.py +82 -37
- wandb/plot/scatter.py +53 -20
- wandb/plot/viz.py +20 -102
- wandb/sdk/artifacts/artifact.py +280 -328
- wandb/sdk/artifacts/artifact_manifest.py +10 -9
- wandb/sdk/artifacts/artifact_manifest_entry.py +1 -1
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +9 -4
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +1 -3
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -1
- wandb/sdk/backend/backend.py +0 -1
- wandb/sdk/data_types/audio.py +1 -1
- wandb/sdk/data_types/base_types/media.py +66 -5
- wandb/sdk/data_types/bokeh.py +1 -1
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +1 -1
- wandb/sdk/data_types/helper_types/image_mask.py +2 -2
- wandb/sdk/data_types/histogram.py +1 -1
- wandb/sdk/data_types/html.py +1 -1
- wandb/sdk/data_types/image.py +1 -1
- wandb/sdk/data_types/molecule.py +3 -3
- wandb/sdk/data_types/object_3d.py +4 -4
- wandb/sdk/data_types/plotly.py +1 -1
- wandb/sdk/data_types/saved_model.py +0 -1
- wandb/sdk/data_types/table.py +7 -7
- wandb/sdk/data_types/trace_tree.py +1 -1
- wandb/sdk/data_types/video.py +4 -3
- wandb/sdk/interface/router.py +0 -2
- wandb/sdk/internal/datastore.py +1 -1
- wandb/sdk/internal/file_pusher.py +1 -1
- wandb/sdk/internal/file_stream.py +4 -4
- wandb/sdk/internal/handler.py +3 -2
- wandb/sdk/internal/internal.py +1 -1
- wandb/sdk/internal/internal_api.py +183 -64
- wandb/sdk/internal/job_builder.py +4 -3
- wandb/sdk/internal/system/assets/__init__.py +0 -2
- wandb/sdk/internal/tb_watcher.py +11 -10
- wandb/sdk/launch/_launch.py +4 -3
- wandb/sdk/launch/_launch_add.py +2 -2
- wandb/sdk/launch/builder/kaniko_builder.py +0 -1
- wandb/sdk/launch/create_job.py +1 -0
- wandb/sdk/launch/environment/local_environment.py +0 -1
- wandb/sdk/launch/errors.py +0 -6
- wandb/sdk/launch/registry/local_registry.py +0 -2
- wandb/sdk/launch/runner/abstract.py +0 -5
- wandb/sdk/launch/sweeps/__init__.py +0 -2
- wandb/sdk/launch/sweeps/scheduler.py +0 -2
- wandb/sdk/launch/sweeps/scheduler_sweep.py +0 -1
- wandb/sdk/lib/apikey.py +3 -3
- wandb/sdk/lib/file_stream_utils.py +1 -1
- wandb/sdk/lib/filesystem.py +1 -1
- wandb/sdk/lib/ipython.py +16 -9
- wandb/sdk/lib/mailbox.py +0 -4
- wandb/sdk/lib/printer.py +44 -8
- wandb/sdk/lib/retry.py +1 -1
- wandb/sdk/service/service.py +3 -3
- wandb/sdk/service/streams.py +2 -4
- wandb/sdk/wandb_init.py +20 -20
- wandb/sdk/wandb_login.py +1 -1
- wandb/sdk/wandb_require.py +1 -4
- wandb/sdk/wandb_run.py +57 -69
- wandb/sdk/wandb_settings.py +3 -4
- wandb/sdk/wandb_sync.py +2 -1
- wandb/util.py +46 -18
- wandb/wandb_agent.py +3 -3
- wandb/wandb_controller.py +2 -2
- {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/METADATA +1 -1
- {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/RECORD +124 -125
- wandb/sdk/internal/system/assets/gpu_apple.py +0 -177
- wandb/sdk/lib/_wburls_generate.py +0 -25
- wandb/sdk/lib/_wburls_generated.py +0 -22
- wandb/sdk/lib/wburls.py +0 -46
- {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/WHEEL +0 -0
- {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/entry_points.txt +0 -0
- {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/licenses/LICENSE +0 -0
wandb/plot/histogram.py
CHANGED
@@ -1,39 +1,65 @@
|
|
1
|
-
from
|
1
|
+
from __future__ import annotations
|
2
2
|
|
3
|
-
from
|
3
|
+
from typing import TYPE_CHECKING
|
4
|
+
|
5
|
+
from wandb.plot.custom_chart import plot_table
|
4
6
|
|
5
7
|
if TYPE_CHECKING:
|
6
8
|
import wandb
|
9
|
+
from wandb.plot.custom_chart import CustomChart
|
7
10
|
|
8
11
|
|
9
12
|
def histogram(
|
10
|
-
table:
|
13
|
+
table: wandb.Table,
|
11
14
|
value: str,
|
12
|
-
title:
|
13
|
-
split_table:
|
14
|
-
):
|
15
|
-
"""
|
15
|
+
title: str = "",
|
16
|
+
split_table: bool = False,
|
17
|
+
) -> CustomChart:
|
18
|
+
"""Constructs a histogram chart from a W&B Table.
|
16
19
|
|
17
|
-
|
18
|
-
table (wandb.Table): Table
|
19
|
-
value (
|
20
|
-
title (
|
21
|
-
split_table (bool):
|
20
|
+
Args:
|
21
|
+
table (wandb.Table): The W&B Table containing the data for the histogram.
|
22
|
+
value (str): The label for the bin axis (x-axis).
|
23
|
+
title (str): The title of the histogram plot.
|
24
|
+
split_table (bool): Whether the table should be split into a separate section
|
25
|
+
in the W&B UI. If `True`, the table will be displayed in a section named
|
26
|
+
"Custom Chart Tables". Default is `False`.
|
22
27
|
|
23
28
|
Returns:
|
24
|
-
A
|
29
|
+
CustomChart: A custom chart object that can be logged to W&B. To log the
|
30
|
+
chart, pass it to `wandb.log()`.
|
25
31
|
|
26
32
|
Example:
|
27
33
|
```
|
34
|
+
import math
|
35
|
+
import random
|
36
|
+
import wandb
|
37
|
+
|
38
|
+
# Generate random data
|
28
39
|
data = [[i, random.random() + math.sin(i / 10)] for i in range(100)]
|
29
|
-
|
30
|
-
|
40
|
+
|
41
|
+
# Create a W&B Table
|
42
|
+
table = wandb.Table(
|
43
|
+
data=data,
|
44
|
+
columns=["step", "height"],
|
45
|
+
)
|
46
|
+
|
47
|
+
# Create a histogram plot
|
48
|
+
histogram = wandb.plot.histogram(
|
49
|
+
table,
|
50
|
+
value="height",
|
51
|
+
title="My Histogram",
|
52
|
+
)
|
53
|
+
|
54
|
+
# Log the histogram plot to W&B
|
55
|
+
with wandb.init(...) as run:
|
56
|
+
run.log({'histogram-plot1': histogram})
|
31
57
|
```
|
32
58
|
"""
|
33
|
-
return
|
34
|
-
|
35
|
-
|
36
|
-
{"value": value},
|
37
|
-
{"title": title},
|
59
|
+
return plot_table(
|
60
|
+
data_table=table,
|
61
|
+
vega_spec_name="wandb/histogram/v0",
|
62
|
+
fields={"value": value},
|
63
|
+
string_fields={"title": title},
|
38
64
|
split_table=split_table,
|
39
65
|
)
|
wandb/plot/line.py
CHANGED
@@ -1,43 +1,74 @@
|
|
1
|
-
from
|
1
|
+
from __future__ import annotations
|
2
2
|
|
3
|
-
from
|
3
|
+
from typing import TYPE_CHECKING
|
4
|
+
|
5
|
+
from wandb.plot.custom_chart import plot_table
|
4
6
|
|
5
7
|
if TYPE_CHECKING:
|
6
8
|
import wandb
|
9
|
+
from wandb.plot.custom_chart import CustomChart
|
7
10
|
|
8
11
|
|
9
12
|
def line(
|
10
|
-
table:
|
13
|
+
table: wandb.Table,
|
11
14
|
x: str,
|
12
15
|
y: str,
|
13
|
-
stroke:
|
14
|
-
title:
|
15
|
-
split_table:
|
16
|
-
):
|
17
|
-
"""
|
18
|
-
|
19
|
-
|
20
|
-
table (wandb.Table):
|
21
|
-
x (
|
22
|
-
y (
|
23
|
-
stroke (
|
24
|
-
|
25
|
-
|
16
|
+
stroke: str | None = None,
|
17
|
+
title: str = "",
|
18
|
+
split_table: bool = False,
|
19
|
+
) -> CustomChart:
|
20
|
+
"""Constructs a customizable line chart.
|
21
|
+
|
22
|
+
Args:
|
23
|
+
table (wandb.Table): The table containing data for the chart.
|
24
|
+
x (str): Column name for the x-axis values.
|
25
|
+
y (str): Column name for the y-axis values.
|
26
|
+
stroke (str): Column name to differentiate line strokes (e.g., for
|
27
|
+
grouping lines).
|
28
|
+
title (str): Title of the chart.
|
29
|
+
split_table (bool): Whether the table should be split into a separate section
|
30
|
+
in the W&B UI. If `True`, the table will be displayed in a section named
|
31
|
+
"Custom Chart Tables". Default is `False`.
|
26
32
|
|
27
33
|
Returns:
|
28
|
-
|
34
|
+
CustomChart: A custom chart object that can be logged to W&B. To log the
|
35
|
+
chart, pass it to `wandb.log()`.
|
29
36
|
|
30
37
|
Example:
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
38
|
+
```python
|
39
|
+
import math
|
40
|
+
import random
|
41
|
+
import wandb
|
42
|
+
|
43
|
+
# Create multiple series of data with different patterns
|
44
|
+
data = []
|
45
|
+
for i in range(100):
|
46
|
+
# Series 1: Sinusoidal pattern with random noise
|
47
|
+
data.append([i, math.sin(i / 10) + random.uniform(-0.1, 0.1), "series_1"])
|
48
|
+
# Series 2: Cosine pattern with random noise
|
49
|
+
data.append([i, math.cos(i / 10) + random.uniform(-0.1, 0.1), "series_2"])
|
50
|
+
# Series 3: Linear increase with random noise
|
51
|
+
data.append([i, i / 10 + random.uniform(-0.5, 0.5), "series_3"])
|
52
|
+
|
53
|
+
# Define the columns for the table
|
54
|
+
table = wandb.Table(data=data, columns=["step", "value", "series"])
|
55
|
+
|
56
|
+
# Initialize wandb run and log the line chart
|
57
|
+
with wandb.init(project="line_chart_example") as run:
|
58
|
+
line_chart = wandb.plot.line(
|
59
|
+
table=table,
|
60
|
+
x="step",
|
61
|
+
y="value",
|
62
|
+
stroke="series", # Group by the "series" column
|
63
|
+
title="Multi-Series Line Plot",
|
64
|
+
)
|
65
|
+
run.log({"line-chart": line_chart})
|
35
66
|
```
|
36
67
|
"""
|
37
|
-
return
|
38
|
-
|
39
|
-
|
40
|
-
{"x": x, "y": y, "stroke": stroke},
|
41
|
-
{"title": title},
|
68
|
+
return plot_table(
|
69
|
+
data_table=table,
|
70
|
+
vega_spec_name="wandb/line/v0",
|
71
|
+
fields={"x": x, "y": y, "stroke": stroke},
|
72
|
+
string_fields={"title": title},
|
42
73
|
split_table=split_table,
|
43
74
|
)
|
wandb/plot/line_series.py
CHANGED
@@ -1,88 +1,176 @@
|
|
1
|
-
|
2
|
-
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING, Any, Iterable
|
3
4
|
|
4
5
|
import wandb
|
6
|
+
from wandb.plot.custom_chart import plot_table
|
7
|
+
|
8
|
+
if TYPE_CHECKING:
|
9
|
+
from wandb.plot.custom_chart import CustomChart
|
5
10
|
|
6
11
|
|
7
12
|
def line_series(
|
8
|
-
xs:
|
9
|
-
ys:
|
10
|
-
keys:
|
11
|
-
title:
|
12
|
-
xname:
|
13
|
-
split_table:
|
14
|
-
):
|
15
|
-
"""
|
16
|
-
|
17
|
-
|
18
|
-
xs (
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
13
|
+
xs: Iterable[Iterable[Any]] | Iterable[Any],
|
14
|
+
ys: Iterable[Iterable[Any]],
|
15
|
+
keys: Iterable[str] | None = None,
|
16
|
+
title: str = "",
|
17
|
+
xname: str = "x",
|
18
|
+
split_table: bool = False,
|
19
|
+
) -> CustomChart:
|
20
|
+
"""Constructs a line series chart.
|
21
|
+
|
22
|
+
Args:
|
23
|
+
xs (Iterable[Iterable] | Iterable): Sequence of x values. If a single
|
24
|
+
array is provided, all y values are plotted against that x array. If
|
25
|
+
an array of arrays is provided, each y value is plotted against the
|
26
|
+
corresponding x array.
|
27
|
+
ys (Iterable[Iterable]): Sequence of y values, where each iterable represents
|
28
|
+
a separate line series.
|
29
|
+
keys (Iterable[str]): Sequence of keys for labeling each line series. If
|
30
|
+
not provided, keys will be automatically generated as "line_0",
|
31
|
+
"line_1", etc.
|
32
|
+
title (str): Title of the chart.
|
33
|
+
xname (str): Label for the x-axis.
|
34
|
+
split_table (bool): Whether the table should be split into a separate section
|
35
|
+
in the W&B UI. If `True`, the table will be displayed in a section named
|
36
|
+
"Custom Chart Tables". Default is `False`.
|
24
37
|
|
25
38
|
Returns:
|
26
|
-
A
|
39
|
+
CustomChart: A custom chart object that can be logged to W&B. To log the
|
40
|
+
chart, pass it to `wandb.log()`.
|
27
41
|
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
42
|
+
Examples:
|
43
|
+
1. Logging a single x array where all y series are plotted against
|
44
|
+
the same x values:
|
45
|
+
|
46
|
+
```
|
32
47
|
import wandb
|
33
48
|
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
49
|
+
# Initialize W&B run
|
50
|
+
with wandb.init(project="line_series_example") as run:
|
51
|
+
# x values shared across all y series
|
52
|
+
xs = list(range(10))
|
53
|
+
|
54
|
+
# Multiple y series to plot
|
55
|
+
ys = [
|
56
|
+
[i for i in range(10)], # y = x
|
57
|
+
[i**2 for i in range(10)], # y = x^2
|
58
|
+
[i**3 for i in range(10)], # y = x^3
|
59
|
+
]
|
60
|
+
|
61
|
+
# Generate and log the line series chart
|
62
|
+
line_series_chart = wandb.plot.line_series(
|
63
|
+
xs,
|
64
|
+
ys,
|
65
|
+
title="title",
|
66
|
+
xname="step",
|
67
|
+
)
|
68
|
+
run.log({"line-series-single-x": line_series_chart})
|
41
69
|
```
|
42
|
-
|
43
|
-
|
70
|
+
|
71
|
+
In this example, a single `xs` series (shared x-values) is used for all
|
72
|
+
`ys` series. This results in each y-series being plotted against the
|
73
|
+
same x-values (0-9).
|
74
|
+
|
75
|
+
2. Logging multiple x arrays where each y series is plotted against
|
76
|
+
its corresponding x array:
|
77
|
+
|
44
78
|
```python
|
45
79
|
import wandb
|
46
80
|
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
81
|
+
# Initialize W&B run
|
82
|
+
with wandb.init(project="line_series_example") as run:
|
83
|
+
# Separate x values for each y series
|
84
|
+
xs = [
|
85
|
+
[i for i in range(10)], # x for first series
|
86
|
+
[2 * i for i in range(10)], # x for second series (stretched)
|
87
|
+
[3 * i for i in range(10)], # x for third series (stretched more)
|
88
|
+
]
|
89
|
+
|
90
|
+
# Corresponding y series
|
91
|
+
ys = [
|
92
|
+
[i for i in range(10)], # y = x
|
93
|
+
[i**2 for i in range(10)], # y = x^2
|
94
|
+
[i**3 for i in range(10)], # y = x^3
|
95
|
+
]
|
96
|
+
|
97
|
+
# Generate and log the line series chart
|
98
|
+
line_series_chart = wandb.plot.line_series(
|
99
|
+
xs, ys, title="Multiple X Arrays Example", xname="Step"
|
100
|
+
)
|
101
|
+
run.log({"line-series-multiple-x": line_series_chart})
|
54
102
|
```
|
55
|
-
"""
|
56
|
-
if not isinstance(xs, Iterable):
|
57
|
-
raise TypeError(f"Expected xs to be an array instead got {type(xs)}")
|
58
103
|
|
59
|
-
|
60
|
-
|
104
|
+
In this example, each y series is plotted against its own unique x series.
|
105
|
+
This allows for more flexibility when the x values are not uniform across
|
106
|
+
the data series.
|
107
|
+
|
108
|
+
3. Customizing line labels using `keys`:
|
109
|
+
|
110
|
+
```python
|
111
|
+
import wandb
|
112
|
+
|
113
|
+
# Initialize W&B run
|
114
|
+
with wandb.init(project="line_series_example") as run:
|
115
|
+
xs = list(range(10)) # Single x array
|
116
|
+
ys = [
|
117
|
+
[i for i in range(10)], # y = x
|
118
|
+
[i**2 for i in range(10)], # y = x^2
|
119
|
+
[i**3 for i in range(10)], # y = x^3
|
120
|
+
]
|
61
121
|
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
122
|
+
# Custom labels for each line
|
123
|
+
keys = ["Linear", "Quadratic", "Cubic"]
|
124
|
+
|
125
|
+
# Generate and log the line series chart
|
126
|
+
line_series_chart = wandb.plot.line_series(
|
127
|
+
xs,
|
128
|
+
ys,
|
129
|
+
keys=keys, # Custom keys (line labels)
|
130
|
+
title="Custom Line Labels Example",
|
131
|
+
xname="Step",
|
66
132
|
)
|
133
|
+
run.log({"line-series-custom-keys": line_series_chart})
|
134
|
+
```
|
67
135
|
|
136
|
+
This example shows how to provide custom labels for the lines using
|
137
|
+
the `keys` argument. The keys will appear in the legend as "Linear",
|
138
|
+
"Quadratic", and "Cubic".
|
139
|
+
|
140
|
+
"""
|
141
|
+
# If xs is a single array, repeat it for each y in ys
|
68
142
|
if not isinstance(xs[0], Iterable) or isinstance(xs[0], (str, bytes)):
|
69
|
-
xs = [xs
|
70
|
-
|
143
|
+
xs = [xs] * len(ys)
|
144
|
+
|
145
|
+
if len(xs) != len(ys):
|
146
|
+
msg = f"Number of x-series ({len(xs)}) must match y-series ({len(ys)})."
|
147
|
+
raise ValueError(msg)
|
148
|
+
|
149
|
+
if keys is None:
|
150
|
+
keys = [f"line_{i}" for i in range(len(ys))]
|
151
|
+
|
152
|
+
if len(keys) != len(ys):
|
153
|
+
msg = f"Number of keys ({len(keys)}) must match y-series ({len(ys)})."
|
154
|
+
raise ValueError(msg)
|
71
155
|
|
72
|
-
if keys is not None:
|
73
|
-
assert len(keys) == len(ys), "Number of keys and y-lines must match"
|
74
156
|
data = [
|
75
|
-
[x,
|
157
|
+
[x, keys[i], y]
|
76
158
|
for i, (xx, yy) in enumerate(zip(xs, ys))
|
77
159
|
for x, y in zip(xx, yy)
|
78
160
|
]
|
161
|
+
table = wandb.Table(
|
162
|
+
data=data,
|
163
|
+
columns=["step", "lineKey", "lineVal"],
|
164
|
+
)
|
79
165
|
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
166
|
+
return plot_table(
|
167
|
+
data_table=table,
|
168
|
+
vega_spec_name="wandb/lineseries/v0",
|
169
|
+
fields={
|
170
|
+
"step": "step",
|
171
|
+
"lineKey": "lineKey",
|
172
|
+
"lineVal": "lineVal",
|
173
|
+
},
|
174
|
+
string_fields={"title": title, "xname": xname},
|
87
175
|
split_table=split_table,
|
88
176
|
)
|
wandb/plot/pr_curve.py
CHANGED
@@ -1,48 +1,92 @@
|
|
1
|
-
from
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import numbers
|
4
|
+
from typing import TYPE_CHECKING, Iterable, TypeVar
|
2
5
|
|
3
6
|
import wandb
|
4
7
|
from wandb import util
|
8
|
+
from wandb.plot.custom_chart import plot_table
|
9
|
+
from wandb.plot.utils import test_missing, test_types
|
10
|
+
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
from wandb.plot.custom_chart import CustomChart
|
13
|
+
|
5
14
|
|
6
|
-
|
15
|
+
T = TypeVar("T")
|
7
16
|
|
8
17
|
|
9
18
|
def pr_curve(
|
10
|
-
y_true=None,
|
11
|
-
y_probas=None,
|
12
|
-
labels=None,
|
13
|
-
classes_to_plot=None,
|
14
|
-
interp_size=21,
|
15
|
-
title=
|
16
|
-
split_table:
|
17
|
-
):
|
18
|
-
"""
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
low false
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
values.
|
37
|
-
|
19
|
+
y_true: Iterable[T] | None = None,
|
20
|
+
y_probas: Iterable[numbers.Number] | None = None,
|
21
|
+
labels: list[str] | None = None,
|
22
|
+
classes_to_plot: list[T] | None = None,
|
23
|
+
interp_size: int = 21,
|
24
|
+
title: str = "Precision-Recall Curve",
|
25
|
+
split_table: bool = False,
|
26
|
+
) -> CustomChart:
|
27
|
+
"""Constructs a Precision-Recall (PR) curve.
|
28
|
+
|
29
|
+
The Precision-Recall curve is particularly useful for evaluating classifiers
|
30
|
+
on imbalanced datasets. A high area under the PR curve signifies both high
|
31
|
+
precision (a low false positive rate) and high recall (a low false negative
|
32
|
+
rate). The curve provides insights into the balance between false positives
|
33
|
+
and false negatives at various threshold levels, aiding in the assessment of
|
34
|
+
a model's performance.
|
35
|
+
|
36
|
+
Args:
|
37
|
+
y_true (Iterable): True binary labels. The shape should be (`num_samples`,).
|
38
|
+
y_probas (Iterable): Predicted scores or probabilities for each class.
|
39
|
+
These can be probability estimates, confidence scores, or non-thresholded
|
40
|
+
decision values. The shape should be (`num_samples`, `num_classes`).
|
41
|
+
labels (list[str] | None): Optional list of class names to replace
|
42
|
+
numeric values in `y_true` for easier plot interpretation.
|
43
|
+
For example, `labels = ['dog', 'cat', 'owl']` will replace 0 with
|
44
|
+
'dog', 1 with 'cat', and 2 with 'owl' in the plot. If not provided,
|
45
|
+
numeric values from `y_true` will be used.
|
46
|
+
classes_to_plot (list | None): Optional list of unique class values from
|
47
|
+
y_true to be included in the plot. If not specified, all unique
|
48
|
+
classes in y_true will be plotted.
|
49
|
+
interp_size (int): Number of points to interpolate recall values. The
|
50
|
+
recall values will be fixed to `interp_size` uniformly distributed
|
51
|
+
points in the range [0, 1], and the precision will be interpolated
|
52
|
+
accordingly.
|
53
|
+
title (str): Title of the plot. Defaults to "Precision-Recall Curve".
|
54
|
+
split_table (bool): Whether the table should be split into a separate section
|
55
|
+
in the W&B UI. If `True`, the table will be displayed in a section named
|
56
|
+
"Custom Chart Tables". Default is `False`.
|
38
57
|
|
39
58
|
Returns:
|
40
|
-
|
41
|
-
|
59
|
+
CustomChart: A custom chart object that can be logged to W&B. To log the
|
60
|
+
chart, pass it to `wandb.log()`.
|
61
|
+
|
62
|
+
Raises:
|
63
|
+
wandb.Error: If numpy, pandas, or scikit-learn is not installed.
|
64
|
+
|
42
65
|
|
43
66
|
Example:
|
44
67
|
```
|
45
|
-
|
68
|
+
import wandb
|
69
|
+
|
70
|
+
# Example for spam detection (binary classification)
|
71
|
+
y_true = [0, 1, 1, 0, 1] # 0 = not spam, 1 = spam
|
72
|
+
y_probas = [
|
73
|
+
[0.9, 0.1], # Predicted probabilities for the first sample (not spam)
|
74
|
+
[0.2, 0.8], # Second sample (spam), and so on
|
75
|
+
[0.1, 0.9],
|
76
|
+
[0.8, 0.2],
|
77
|
+
[0.3, 0.7]
|
78
|
+
]
|
79
|
+
|
80
|
+
labels = ['not spam', 'spam'] # Optional class names for readability
|
81
|
+
|
82
|
+
with wandb.init(project="spam-detection") as run:
|
83
|
+
pr_curve = wandb.plot.pr_curve(
|
84
|
+
y_true=y_true,
|
85
|
+
y_probas=y_probas,
|
86
|
+
labels=labels,
|
87
|
+
title="Precision-Recall Curve for Spam Detection",
|
88
|
+
)
|
89
|
+
run.log({"pr-curve": pr_curve})
|
46
90
|
```
|
47
91
|
"""
|
48
92
|
np = util.get_module(
|
@@ -80,7 +124,7 @@ def pr_curve(
|
|
80
124
|
if classes_to_plot is None:
|
81
125
|
classes_to_plot = classes
|
82
126
|
|
83
|
-
precision =
|
127
|
+
precision = {}
|
84
128
|
interp_recall = np.linspace(0, 1, interp_size)[::-1]
|
85
129
|
indices_to_plot = np.where(np.isin(classes, classes_to_plot))[0]
|
86
130
|
for i in indices_to_plot:
|
@@ -109,12 +153,11 @@ def pr_curve(
|
|
109
153
|
"precision": np.hstack(list(precision.values())),
|
110
154
|
"recall": np.tile(interp_recall, len(precision)),
|
111
155
|
}
|
112
|
-
)
|
113
|
-
df = df.round(3)
|
156
|
+
).round(3)
|
114
157
|
|
115
158
|
if len(df) > wandb.Table.MAX_ROWS:
|
116
159
|
wandb.termwarn(
|
117
|
-
"
|
160
|
+
f"Table has a limit of {wandb.Table.MAX_ROWS} rows. Resampling to fit."
|
118
161
|
)
|
119
162
|
# different sampling could be applied, possibly to ensure endpoints are kept
|
120
163
|
df = sklearn_utils.resample(
|
@@ -125,12 +168,14 @@ def pr_curve(
|
|
125
168
|
stratify=df["class"],
|
126
169
|
).sort_values(["precision", "recall", "class"])
|
127
170
|
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
171
|
+
return plot_table(
|
172
|
+
data_table=wandb.Table(dataframe=df),
|
173
|
+
vega_spec_name="wandb/area-under-curve/v0",
|
174
|
+
fields={
|
175
|
+
"x": "recall",
|
176
|
+
"y": "precision",
|
177
|
+
"class": "class",
|
178
|
+
},
|
179
|
+
string_fields={"title": title},
|
135
180
|
split_table=split_table,
|
136
181
|
)
|