wandb 0.18.5__py3-none-win32.whl → 0.18.6__py3-none-win32.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +2 -2
- wandb/__init__.pyi +21 -19
- wandb/agents/pyagent.py +1 -1
- wandb/apis/importers/wandb.py +1 -1
- wandb/apis/normalize.py +2 -18
- wandb/apis/public/api.py +122 -62
- wandb/apis/public/artifacts.py +8 -3
- wandb/apis/public/files.py +17 -2
- wandb/apis/public/jobs.py +2 -2
- wandb/apis/public/query_generator.py +1 -1
- wandb/apis/public/runs.py +8 -8
- wandb/apis/public/teams.py +3 -3
- wandb/apis/public/users.py +1 -1
- wandb/apis/public/utils.py +68 -0
- wandb/bin/gpu_stats.exe +0 -0
- wandb/bin/wandb-core +0 -0
- wandb/cli/cli.py +12 -3
- wandb/data_types.py +1 -1
- wandb/docker/__init__.py +2 -1
- wandb/docker/auth.py +2 -3
- wandb/errors/links.py +73 -0
- wandb/errors/term.py +7 -6
- wandb/filesync/step_prepare.py +1 -1
- wandb/filesync/upload_job.py +1 -1
- wandb/integration/catboost/catboost.py +2 -2
- wandb/integration/diffusers/pipeline_resolver.py +1 -1
- wandb/integration/diffusers/resolvers/multimodal.py +6 -6
- wandb/integration/diffusers/resolvers/utils.py +1 -1
- wandb/integration/fastai/__init__.py +3 -2
- wandb/integration/keras/callbacks/metrics_logger.py +1 -1
- wandb/integration/keras/callbacks/model_checkpoint.py +1 -1
- wandb/integration/keras/keras.py +1 -1
- wandb/integration/kfp/kfp_patch.py +1 -1
- wandb/integration/lightgbm/__init__.py +2 -2
- wandb/integration/magic.py +2 -2
- wandb/integration/metaflow/metaflow.py +1 -1
- wandb/integration/sacred/__init__.py +1 -1
- wandb/integration/sagemaker/auth.py +1 -1
- wandb/integration/sklearn/plot/classifier.py +7 -7
- wandb/integration/sklearn/plot/clusterer.py +3 -3
- wandb/integration/sklearn/plot/regressor.py +3 -3
- wandb/integration/sklearn/plot/shared.py +2 -2
- wandb/integration/tensorboard/log.py +2 -2
- wandb/integration/ultralytics/callback.py +2 -2
- wandb/integration/xgboost/xgboost.py +1 -1
- wandb/jupyter.py +0 -1
- wandb/plot/__init__.py +17 -8
- wandb/plot/bar.py +53 -27
- wandb/plot/confusion_matrix.py +151 -70
- wandb/plot/custom_chart.py +124 -0
- wandb/plot/histogram.py +46 -20
- wandb/plot/line.py +57 -26
- wandb/plot/line_series.py +148 -60
- wandb/plot/pr_curve.py +89 -44
- wandb/plot/roc_curve.py +82 -37
- wandb/plot/scatter.py +53 -20
- wandb/plot/viz.py +20 -102
- wandb/sdk/artifacts/artifact.py +280 -328
- wandb/sdk/artifacts/artifact_manifest.py +10 -9
- wandb/sdk/artifacts/artifact_manifest_entry.py +1 -1
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +9 -4
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +1 -3
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -1
- wandb/sdk/backend/backend.py +0 -1
- wandb/sdk/data_types/audio.py +1 -1
- wandb/sdk/data_types/base_types/media.py +66 -5
- wandb/sdk/data_types/bokeh.py +1 -1
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +1 -1
- wandb/sdk/data_types/helper_types/image_mask.py +2 -2
- wandb/sdk/data_types/histogram.py +1 -1
- wandb/sdk/data_types/html.py +1 -1
- wandb/sdk/data_types/image.py +1 -1
- wandb/sdk/data_types/molecule.py +3 -3
- wandb/sdk/data_types/object_3d.py +4 -4
- wandb/sdk/data_types/plotly.py +1 -1
- wandb/sdk/data_types/saved_model.py +0 -1
- wandb/sdk/data_types/table.py +7 -7
- wandb/sdk/data_types/trace_tree.py +1 -1
- wandb/sdk/data_types/video.py +4 -3
- wandb/sdk/interface/router.py +0 -2
- wandb/sdk/internal/datastore.py +1 -1
- wandb/sdk/internal/file_pusher.py +1 -1
- wandb/sdk/internal/file_stream.py +4 -4
- wandb/sdk/internal/handler.py +3 -2
- wandb/sdk/internal/internal.py +1 -1
- wandb/sdk/internal/internal_api.py +178 -63
- wandb/sdk/internal/job_builder.py +4 -3
- wandb/sdk/internal/system/assets/__init__.py +0 -2
- wandb/sdk/internal/tb_watcher.py +11 -10
- wandb/sdk/launch/_launch.py +4 -3
- wandb/sdk/launch/_launch_add.py +2 -2
- wandb/sdk/launch/builder/kaniko_builder.py +0 -1
- wandb/sdk/launch/create_job.py +1 -0
- wandb/sdk/launch/environment/local_environment.py +0 -1
- wandb/sdk/launch/errors.py +0 -6
- wandb/sdk/launch/registry/local_registry.py +0 -2
- wandb/sdk/launch/runner/abstract.py +0 -5
- wandb/sdk/launch/sweeps/__init__.py +0 -2
- wandb/sdk/launch/sweeps/scheduler.py +0 -2
- wandb/sdk/launch/sweeps/scheduler_sweep.py +0 -1
- wandb/sdk/lib/apikey.py +3 -3
- wandb/sdk/lib/file_stream_utils.py +1 -1
- wandb/sdk/lib/filesystem.py +1 -1
- wandb/sdk/lib/ipython.py +16 -9
- wandb/sdk/lib/mailbox.py +0 -4
- wandb/sdk/lib/printer.py +44 -8
- wandb/sdk/lib/retry.py +1 -1
- wandb/sdk/service/service.py +3 -3
- wandb/sdk/service/streams.py +2 -4
- wandb/sdk/wandb_init.py +20 -20
- wandb/sdk/wandb_login.py +1 -1
- wandb/sdk/wandb_require.py +1 -4
- wandb/sdk/wandb_run.py +57 -69
- wandb/sdk/wandb_settings.py +3 -4
- wandb/sdk/wandb_sync.py +2 -1
- wandb/util.py +46 -18
- wandb/wandb_agent.py +3 -3
- wandb/wandb_controller.py +2 -2
- {wandb-0.18.5.dist-info → wandb-0.18.6.dist-info}/METADATA +1 -1
- {wandb-0.18.5.dist-info → wandb-0.18.6.dist-info}/RECORD +125 -126
- wandb/sdk/internal/system/assets/gpu_apple.py +0 -177
- wandb/sdk/lib/_wburls_generate.py +0 -25
- wandb/sdk/lib/_wburls_generated.py +0 -22
- wandb/sdk/lib/wburls.py +0 -46
- {wandb-0.18.5.dist-info → wandb-0.18.6.dist-info}/WHEEL +0 -0
- {wandb-0.18.5.dist-info → wandb-0.18.6.dist-info}/entry_points.txt +0 -0
- {wandb-0.18.5.dist-info → wandb-0.18.6.dist-info}/licenses/LICENSE +0 -0
wandb/plot/histogram.py
CHANGED
@@ -1,39 +1,65 @@
|
|
1
|
-
from
|
1
|
+
from __future__ import annotations
|
2
2
|
|
3
|
-
from
|
3
|
+
from typing import TYPE_CHECKING
|
4
|
+
|
5
|
+
from wandb.plot.custom_chart import plot_table
|
4
6
|
|
5
7
|
if TYPE_CHECKING:
|
6
8
|
import wandb
|
9
|
+
from wandb.plot.custom_chart import CustomChart
|
7
10
|
|
8
11
|
|
9
12
|
def histogram(
|
10
|
-
table:
|
13
|
+
table: wandb.Table,
|
11
14
|
value: str,
|
12
|
-
title:
|
13
|
-
split_table:
|
14
|
-
):
|
15
|
-
"""
|
15
|
+
title: str = "",
|
16
|
+
split_table: bool = False,
|
17
|
+
) -> CustomChart:
|
18
|
+
"""Constructs a histogram chart from a W&B Table.
|
16
19
|
|
17
|
-
|
18
|
-
table (wandb.Table): Table
|
19
|
-
value (
|
20
|
-
title (
|
21
|
-
split_table (bool):
|
20
|
+
Args:
|
21
|
+
table (wandb.Table): The W&B Table containing the data for the histogram.
|
22
|
+
value (str): The label for the bin axis (x-axis).
|
23
|
+
title (str): The title of the histogram plot.
|
24
|
+
split_table (bool): Whether the table should be split into a separate section
|
25
|
+
in the W&B UI. If `True`, the table will be displayed in a section named
|
26
|
+
"Custom Chart Tables". Default is `False`.
|
22
27
|
|
23
28
|
Returns:
|
24
|
-
A
|
29
|
+
CustomChart: A custom chart object that can be logged to W&B. To log the
|
30
|
+
chart, pass it to `wandb.log()`.
|
25
31
|
|
26
32
|
Example:
|
27
33
|
```
|
34
|
+
import math
|
35
|
+
import random
|
36
|
+
import wandb
|
37
|
+
|
38
|
+
# Generate random data
|
28
39
|
data = [[i, random.random() + math.sin(i / 10)] for i in range(100)]
|
29
|
-
|
30
|
-
|
40
|
+
|
41
|
+
# Create a W&B Table
|
42
|
+
table = wandb.Table(
|
43
|
+
data=data,
|
44
|
+
columns=["step", "height"],
|
45
|
+
)
|
46
|
+
|
47
|
+
# Create a histogram plot
|
48
|
+
histogram = wandb.plot.histogram(
|
49
|
+
table,
|
50
|
+
value="height",
|
51
|
+
title="My Histogram",
|
52
|
+
)
|
53
|
+
|
54
|
+
# Log the histogram plot to W&B
|
55
|
+
with wandb.init(...) as run:
|
56
|
+
run.log({'histogram-plot1': histogram})
|
31
57
|
```
|
32
58
|
"""
|
33
|
-
return
|
34
|
-
|
35
|
-
|
36
|
-
{"value": value},
|
37
|
-
{"title": title},
|
59
|
+
return plot_table(
|
60
|
+
data_table=table,
|
61
|
+
vega_spec_name="wandb/histogram/v0",
|
62
|
+
fields={"value": value},
|
63
|
+
string_fields={"title": title},
|
38
64
|
split_table=split_table,
|
39
65
|
)
|
wandb/plot/line.py
CHANGED
@@ -1,43 +1,74 @@
|
|
1
|
-
from
|
1
|
+
from __future__ import annotations
|
2
2
|
|
3
|
-
from
|
3
|
+
from typing import TYPE_CHECKING
|
4
|
+
|
5
|
+
from wandb.plot.custom_chart import plot_table
|
4
6
|
|
5
7
|
if TYPE_CHECKING:
|
6
8
|
import wandb
|
9
|
+
from wandb.plot.custom_chart import CustomChart
|
7
10
|
|
8
11
|
|
9
12
|
def line(
|
10
|
-
table:
|
13
|
+
table: wandb.Table,
|
11
14
|
x: str,
|
12
15
|
y: str,
|
13
|
-
stroke:
|
14
|
-
title:
|
15
|
-
split_table:
|
16
|
-
):
|
17
|
-
"""
|
18
|
-
|
19
|
-
|
20
|
-
table (wandb.Table):
|
21
|
-
x (
|
22
|
-
y (
|
23
|
-
stroke (
|
24
|
-
|
25
|
-
|
16
|
+
stroke: str | None = None,
|
17
|
+
title: str = "",
|
18
|
+
split_table: bool = False,
|
19
|
+
) -> CustomChart:
|
20
|
+
"""Constructs a customizable line chart.
|
21
|
+
|
22
|
+
Args:
|
23
|
+
table (wandb.Table): The table containing data for the chart.
|
24
|
+
x (str): Column name for the x-axis values.
|
25
|
+
y (str): Column name for the y-axis values.
|
26
|
+
stroke (str):Column name to differentiate line strokes (e.g., for
|
27
|
+
grouping lines).
|
28
|
+
title (str):Title of the chart.
|
29
|
+
split_table (bool): Whether the table should be split into a separate section
|
30
|
+
in the W&B UI. If `True`, the table will be displayed in a section named
|
31
|
+
"Custom Chart Tables". Default is `False`.
|
26
32
|
|
27
33
|
Returns:
|
28
|
-
|
34
|
+
CustomChart: A custom chart object that can be logged to W&B. To log the
|
35
|
+
chart, pass it to `wandb.log()`.
|
29
36
|
|
30
37
|
Example:
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
38
|
+
```python
|
39
|
+
import math
|
40
|
+
import random
|
41
|
+
import wandb
|
42
|
+
|
43
|
+
# Create multiple series of data with different patterns
|
44
|
+
data = []
|
45
|
+
for i in range(100):
|
46
|
+
# Series 1: Sinusoidal pattern with random noise
|
47
|
+
data.append([i, math.sin(i / 10) + random.uniform(-0.1, 0.1), "series_1"])
|
48
|
+
# Series 2: Cosine pattern with random noise
|
49
|
+
data.append([i, math.cos(i / 10) + random.uniform(-0.1, 0.1), "series_2"])
|
50
|
+
# Series 3: Linear increase with random noise
|
51
|
+
data.append([i, i / 10 + random.uniform(-0.5, 0.5), "series_3"])
|
52
|
+
|
53
|
+
# Define the columns for the table
|
54
|
+
table = wandb.Table(data=data, columns=["step", "value", "series"])
|
55
|
+
|
56
|
+
# Initialize wandb run and log the line chart
|
57
|
+
with wandb.init(project="line_chart_example") as run:
|
58
|
+
line_chart = wandb.plot.line(
|
59
|
+
table=table,
|
60
|
+
x="step",
|
61
|
+
y="value",
|
62
|
+
stroke="series", # Group by the "series" column
|
63
|
+
title="Multi-Series Line Plot",
|
64
|
+
)
|
65
|
+
run.log({"line-chart": line_chart})
|
35
66
|
```
|
36
67
|
"""
|
37
|
-
return
|
38
|
-
|
39
|
-
|
40
|
-
{"x": x, "y": y, "stroke": stroke},
|
41
|
-
{"title": title},
|
68
|
+
return plot_table(
|
69
|
+
data_table=table,
|
70
|
+
vega_spec_name="wandb/line/v0",
|
71
|
+
fields={"x": x, "y": y, "stroke": stroke},
|
72
|
+
string_fields={"title": title},
|
42
73
|
split_table=split_table,
|
43
74
|
)
|
wandb/plot/line_series.py
CHANGED
@@ -1,88 +1,176 @@
|
|
1
|
-
|
2
|
-
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING, Any, Iterable
|
3
4
|
|
4
5
|
import wandb
|
6
|
+
from wandb.plot.custom_chart import plot_table
|
7
|
+
|
8
|
+
if TYPE_CHECKING:
|
9
|
+
from wandb.plot.custom_chart import CustomChart
|
5
10
|
|
6
11
|
|
7
12
|
def line_series(
|
8
|
-
xs:
|
9
|
-
ys:
|
10
|
-
keys:
|
11
|
-
title:
|
12
|
-
xname:
|
13
|
-
split_table:
|
14
|
-
):
|
15
|
-
"""
|
16
|
-
|
17
|
-
|
18
|
-
xs (
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
13
|
+
xs: Iterable[Iterable[Any]] | Iterable[Any],
|
14
|
+
ys: Iterable[Iterable[Any]],
|
15
|
+
keys: Iterable[str] | None = None,
|
16
|
+
title: str = "",
|
17
|
+
xname: str = "x",
|
18
|
+
split_table: bool = False,
|
19
|
+
) -> CustomChart:
|
20
|
+
"""Constructs a line series chart.
|
21
|
+
|
22
|
+
Args:
|
23
|
+
xs (Iterable[Iterable] | Iterable): Sequence of x values. If a singular
|
24
|
+
array is provided, all y values are plotted against that x array. If
|
25
|
+
an array of arrays is provided, each y value is plotted against the
|
26
|
+
corresponding x array.
|
27
|
+
ys (Iterable[Iterable]): Sequence of y values, where each iterable represents
|
28
|
+
a separate line series.
|
29
|
+
keys (Iterable[str]): Sequence of keys for labeling each line series. If
|
30
|
+
not provided, keys will be automatically generated as "line_1",
|
31
|
+
"line_2", etc.
|
32
|
+
title (str): Title of the chart.
|
33
|
+
xname (str): Label for the x-axis.
|
34
|
+
split_table (bool): Whether the table should be split into a separate section
|
35
|
+
in the W&B UI. If `True`, the table will be displayed in a section named
|
36
|
+
"Custom Chart Tables". Default is `False`.
|
24
37
|
|
25
38
|
Returns:
|
26
|
-
A
|
39
|
+
CustomChart: A custom chart object that can be logged to W&B. To log the
|
40
|
+
chart, pass it to `wandb.log()`.
|
27
41
|
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
42
|
+
Examples:
|
43
|
+
1. Logging a single x array where all y series are plotted against
|
44
|
+
the same x values:
|
45
|
+
|
46
|
+
```
|
32
47
|
import wandb
|
33
48
|
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
49
|
+
# Initialize W&B run
|
50
|
+
with wandb.init(project="line_series_example") as run:
|
51
|
+
# x values shared across all y series
|
52
|
+
xs = list(range(10))
|
53
|
+
|
54
|
+
# Multiple y series to plot
|
55
|
+
ys = [
|
56
|
+
[i for i in range(10)], # y = x
|
57
|
+
[i**2 for i in range(10)], # y = x^2
|
58
|
+
[i**3 for i in range(10)], # y = x^3
|
59
|
+
]
|
60
|
+
|
61
|
+
# Generate and log the line series chart
|
62
|
+
line_series_chart = wandb.plot.line_series(
|
63
|
+
xs,
|
64
|
+
ys,
|
65
|
+
title="title",
|
66
|
+
xname="step",
|
67
|
+
)
|
68
|
+
run.log({"line-series-single-x": line_series_chart})
|
41
69
|
```
|
42
|
-
|
43
|
-
|
70
|
+
|
71
|
+
In this example, a single `xs` series (shared x-values) is used for all
|
72
|
+
`ys` series. This results in each y-series being plotted against the
|
73
|
+
same x-values (0-9).
|
74
|
+
|
75
|
+
2. Logging multiple x arrays where each y series is plotted against
|
76
|
+
its corresponding x array:
|
77
|
+
|
44
78
|
```python
|
45
79
|
import wandb
|
46
80
|
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
81
|
+
# Initialize W&B run
|
82
|
+
with wandb.init(project="line_series_example") as run:
|
83
|
+
# Separate x values for each y series
|
84
|
+
xs = [
|
85
|
+
[i for i in range(10)], # x for first series
|
86
|
+
[2 * i for i in range(10)], # x for second series (stretched)
|
87
|
+
[3 * i for i in range(10)], # x for third series (stretched more)
|
88
|
+
]
|
89
|
+
|
90
|
+
# Corresponding y series
|
91
|
+
ys = [
|
92
|
+
[i for i in range(10)], # y = x
|
93
|
+
[i**2 for i in range(10)], # y = x^2
|
94
|
+
[i**3 for i in range(10)], # y = x^3
|
95
|
+
]
|
96
|
+
|
97
|
+
# Generate and log the line series chart
|
98
|
+
line_series_chart = wandb.plot.line_series(
|
99
|
+
xs, ys, title="Multiple X Arrays Example", xname="Step"
|
100
|
+
)
|
101
|
+
run.log({"line-series-multiple-x": line_series_chart})
|
54
102
|
```
|
55
|
-
"""
|
56
|
-
if not isinstance(xs, Iterable):
|
57
|
-
raise TypeError(f"Expected xs to be an array instead got {type(xs)}")
|
58
103
|
|
59
|
-
|
60
|
-
|
104
|
+
In this example, each y series is plotted against its own unique x series.
|
105
|
+
This allows for more flexibility when the x values are not uniform across
|
106
|
+
the data series.
|
107
|
+
|
108
|
+
3. Customizing line labels using `keys`:
|
109
|
+
|
110
|
+
```python
|
111
|
+
import wandb
|
112
|
+
|
113
|
+
# Initialize W&B run
|
114
|
+
with wandb.init(project="line_series_example") as run:
|
115
|
+
xs = list(range(10)) # Single x array
|
116
|
+
ys = [
|
117
|
+
[i for i in range(10)], # y = x
|
118
|
+
[i**2 for i in range(10)], # y = x^2
|
119
|
+
[i**3 for i in range(10)], # y = x^3
|
120
|
+
]
|
61
121
|
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
122
|
+
# Custom labels for each line
|
123
|
+
keys = ["Linear", "Quadratic", "Cubic"]
|
124
|
+
|
125
|
+
# Generate and log the line series chart
|
126
|
+
line_series_chart = wandb.plot.line_series(
|
127
|
+
xs,
|
128
|
+
ys,
|
129
|
+
keys=keys, # Custom keys (line labels)
|
130
|
+
title="Custom Line Labels Example",
|
131
|
+
xname="Step",
|
66
132
|
)
|
133
|
+
run.log({"line-series-custom-keys": line_series_chart})
|
134
|
+
```
|
67
135
|
|
136
|
+
This example shows how to provide custom labels for the lines using
|
137
|
+
the `keys` argument. The keys will appear in the legend as "Linear",
|
138
|
+
"Quadratic", and "Cubic".
|
139
|
+
|
140
|
+
"""
|
141
|
+
# If xs is a single array, repeat it for each y in ys
|
68
142
|
if not isinstance(xs[0], Iterable) or isinstance(xs[0], (str, bytes)):
|
69
|
-
xs = [xs
|
70
|
-
|
143
|
+
xs = [xs] * len(ys)
|
144
|
+
|
145
|
+
if len(xs) != len(ys):
|
146
|
+
msg = f"Number of x-series ({len(xs)}) must match y-series ({len(ys)})."
|
147
|
+
raise ValueError(msg)
|
148
|
+
|
149
|
+
if keys is None:
|
150
|
+
keys = [f"line_{i}" for i in range(len(ys))]
|
151
|
+
|
152
|
+
if len(keys) != len(ys):
|
153
|
+
msg = f"Number of keys ({len(keys)}) must match y-series ({len(ys)})."
|
154
|
+
raise ValueError(msg)
|
71
155
|
|
72
|
-
if keys is not None:
|
73
|
-
assert len(keys) == len(ys), "Number of keys and y-lines must match"
|
74
156
|
data = [
|
75
|
-
[x,
|
157
|
+
[x, keys[i], y]
|
76
158
|
for i, (xx, yy) in enumerate(zip(xs, ys))
|
77
159
|
for x, y in zip(xx, yy)
|
78
160
|
]
|
161
|
+
table = wandb.Table(
|
162
|
+
data=data,
|
163
|
+
columns=["step", "lineKey", "lineVal"],
|
164
|
+
)
|
79
165
|
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
166
|
+
return plot_table(
|
167
|
+
data_table=table,
|
168
|
+
vega_spec_name="wandb/lineseries/v0",
|
169
|
+
fields={
|
170
|
+
"step": "step",
|
171
|
+
"lineKey": "lineKey",
|
172
|
+
"lineVal": "lineVal",
|
173
|
+
},
|
174
|
+
string_fields={"title": title, "xname": xname},
|
87
175
|
split_table=split_table,
|
88
176
|
)
|
wandb/plot/pr_curve.py
CHANGED
@@ -1,48 +1,92 @@
|
|
1
|
-
from
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import numbers
|
4
|
+
from typing import TYPE_CHECKING, Iterable, TypeVar
|
2
5
|
|
3
6
|
import wandb
|
4
7
|
from wandb import util
|
8
|
+
from wandb.plot.custom_chart import plot_table
|
9
|
+
from wandb.plot.utils import test_missing, test_types
|
10
|
+
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
from wandb.plot.custom_chart import CustomChart
|
13
|
+
|
5
14
|
|
6
|
-
|
15
|
+
T = TypeVar("T")
|
7
16
|
|
8
17
|
|
9
18
|
def pr_curve(
|
10
|
-
y_true=None,
|
11
|
-
y_probas=None,
|
12
|
-
labels=None,
|
13
|
-
classes_to_plot=None,
|
14
|
-
interp_size=21,
|
15
|
-
title=
|
16
|
-
split_table:
|
17
|
-
):
|
18
|
-
"""
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
low false
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
values.
|
37
|
-
|
19
|
+
y_true: Iterable[T] | None = None,
|
20
|
+
y_probas: Iterable[numbers.Number] | None = None,
|
21
|
+
labels: list[str] | None = None,
|
22
|
+
classes_to_plot: list[T] | None = None,
|
23
|
+
interp_size: int = 21,
|
24
|
+
title: str = "Precision-Recall Curve",
|
25
|
+
split_table: bool = False,
|
26
|
+
) -> CustomChart:
|
27
|
+
"""Constructs a Precision-Recall (PR) curve.
|
28
|
+
|
29
|
+
The Precision-Recall curve is particularly useful for evaluating classifiers
|
30
|
+
on imbalanced datasets. A high area under the PR curve signifies both high
|
31
|
+
precision (a low false positive rate) and high recall (a low false negative
|
32
|
+
rate). The curve provides insights into the balance between false positives
|
33
|
+
and false negatives at various threshold levels, aiding in the assessment of
|
34
|
+
a model's performance.
|
35
|
+
|
36
|
+
Args:
|
37
|
+
y_true (Iterable): True binary labels. The shape should be (`num_samples`,).
|
38
|
+
y_probas (Iterable): Predicted scores or probabilities for each class.
|
39
|
+
These can be probability estimates, confidence scores, or non-thresholded
|
40
|
+
decision values. The shape should be (`num_samples`, `num_classes`).
|
41
|
+
labels (list[str] | None): Optional list of class names to replace
|
42
|
+
numeric values in `y_true` for easier plot interpretation.
|
43
|
+
For example, `labels = ['dog', 'cat', 'owl']` will replace 0 with
|
44
|
+
'dog', 1 with 'cat', and 2 with 'owl' in the plot. If not provided,
|
45
|
+
numeric values from `y_true` will be used.
|
46
|
+
classes_to_plot (list | None): Optional list of unique class values from
|
47
|
+
y_true to be included in the plot. If not specified, all unique
|
48
|
+
classes in y_true will be plotted.
|
49
|
+
interp_size (int): Number of points to interpolate recall values. The
|
50
|
+
recall values will be fixed to `interp_size` uniformly distributed
|
51
|
+
points in the range [0, 1], and the precision will be interpolated
|
52
|
+
accordingly.
|
53
|
+
title (str): Title of the plot. Defaults to "Precision-Recall Curve".
|
54
|
+
split_table (bool): Whether the table should be split into a separate section
|
55
|
+
in the W&B UI. If `True`, the table will be displayed in a section named
|
56
|
+
"Custom Chart Tables". Default is `False`.
|
38
57
|
|
39
58
|
Returns:
|
40
|
-
|
41
|
-
|
59
|
+
CustomChart: A custom chart object that can be logged to W&B. To log the
|
60
|
+
chart, pass it to `wandb.log()`.
|
61
|
+
|
62
|
+
Raises:
|
63
|
+
wandb.Error: If numpy, pandas, or scikit-learn is not installed.
|
64
|
+
|
42
65
|
|
43
66
|
Example:
|
44
67
|
```
|
45
|
-
|
68
|
+
import wandb
|
69
|
+
|
70
|
+
# Example for spam detection (binary classification)
|
71
|
+
y_true = [0, 1, 1, 0, 1] # 0 = not spam, 1 = spam
|
72
|
+
y_probas = [
|
73
|
+
[0.9, 0.1], # Predicted probabilities for the first sample (not spam)
|
74
|
+
[0.2, 0.8], # Second sample (spam), and so on
|
75
|
+
[0.1, 0.9],
|
76
|
+
[0.8, 0.2],
|
77
|
+
[0.3, 0.7]
|
78
|
+
]
|
79
|
+
|
80
|
+
labels = ['not spam', 'spam'] # Optional class names for readability
|
81
|
+
|
82
|
+
with wandb.init(project="spam-detection") as run:
|
83
|
+
pr_curve = wandb.plot.pr_curve(
|
84
|
+
y_true=y_true,
|
85
|
+
y_probas=y_probas,
|
86
|
+
labels=labels,
|
87
|
+
title="Precision-Recall Curve for Spam Detection",
|
88
|
+
)
|
89
|
+
run.log({"pr-curve": pr_curve})
|
46
90
|
```
|
47
91
|
"""
|
48
92
|
np = util.get_module(
|
@@ -80,7 +124,7 @@ def pr_curve(
|
|
80
124
|
if classes_to_plot is None:
|
81
125
|
classes_to_plot = classes
|
82
126
|
|
83
|
-
precision =
|
127
|
+
precision = {}
|
84
128
|
interp_recall = np.linspace(0, 1, interp_size)[::-1]
|
85
129
|
indices_to_plot = np.where(np.isin(classes, classes_to_plot))[0]
|
86
130
|
for i in indices_to_plot:
|
@@ -109,12 +153,11 @@ def pr_curve(
|
|
109
153
|
"precision": np.hstack(list(precision.values())),
|
110
154
|
"recall": np.tile(interp_recall, len(precision)),
|
111
155
|
}
|
112
|
-
)
|
113
|
-
df = df.round(3)
|
156
|
+
).round(3)
|
114
157
|
|
115
158
|
if len(df) > wandb.Table.MAX_ROWS:
|
116
159
|
wandb.termwarn(
|
117
|
-
"
|
160
|
+
f"Table has a limit of {wandb.Table.MAX_ROWS} rows. Resampling to fit."
|
118
161
|
)
|
119
162
|
# different sampling could be applied, possibly to ensure endpoints are kept
|
120
163
|
df = sklearn_utils.resample(
|
@@ -125,12 +168,14 @@ def pr_curve(
|
|
125
168
|
stratify=df["class"],
|
126
169
|
).sort_values(["precision", "recall", "class"])
|
127
170
|
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
171
|
+
return plot_table(
|
172
|
+
data_table=wandb.Table(dataframe=df),
|
173
|
+
vega_spec_name="wandb/area-under-curve/v0",
|
174
|
+
fields={
|
175
|
+
"x": "recall",
|
176
|
+
"y": "precision",
|
177
|
+
"class": "class",
|
178
|
+
},
|
179
|
+
string_fields={"title": title},
|
135
180
|
split_table=split_table,
|
136
181
|
)
|