wandb 0.18.5__py3-none-any.whl → 0.18.7__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- wandb/__init__.py +2 -2
- wandb/__init__.pyi +22 -20
- wandb/agents/pyagent.py +1 -1
- wandb/apis/importers/wandb.py +1 -1
- wandb/apis/normalize.py +2 -18
- wandb/apis/public/api.py +126 -62
- wandb/apis/public/artifacts.py +8 -3
- wandb/apis/public/files.py +17 -2
- wandb/apis/public/jobs.py +2 -2
- wandb/apis/public/query_generator.py +1 -1
- wandb/apis/public/runs.py +8 -8
- wandb/apis/public/teams.py +3 -3
- wandb/apis/public/users.py +1 -1
- wandb/apis/public/utils.py +68 -0
- wandb/bin/gpu_stats +0 -0
- wandb/cli/cli.py +12 -3
- wandb/data_types.py +1 -1
- wandb/docker/__init__.py +2 -1
- wandb/docker/auth.py +2 -3
- wandb/errors/links.py +73 -0
- wandb/errors/term.py +7 -6
- wandb/filesync/step_prepare.py +1 -1
- wandb/filesync/upload_job.py +1 -1
- wandb/integration/catboost/catboost.py +2 -2
- wandb/integration/diffusers/pipeline_resolver.py +1 -1
- wandb/integration/diffusers/resolvers/multimodal.py +6 -6
- wandb/integration/diffusers/resolvers/utils.py +1 -1
- wandb/integration/fastai/__init__.py +3 -2
- wandb/integration/keras/callbacks/metrics_logger.py +1 -1
- wandb/integration/keras/callbacks/model_checkpoint.py +1 -1
- wandb/integration/keras/keras.py +1 -1
- wandb/integration/kfp/kfp_patch.py +1 -1
- wandb/integration/lightgbm/__init__.py +2 -2
- wandb/integration/magic.py +2 -2
- wandb/integration/metaflow/metaflow.py +1 -1
- wandb/integration/sacred/__init__.py +1 -1
- wandb/integration/sagemaker/auth.py +1 -1
- wandb/integration/sklearn/plot/classifier.py +7 -7
- wandb/integration/sklearn/plot/clusterer.py +3 -3
- wandb/integration/sklearn/plot/regressor.py +3 -3
- wandb/integration/sklearn/plot/shared.py +2 -2
- wandb/integration/tensorboard/log.py +2 -2
- wandb/integration/ultralytics/callback.py +2 -2
- wandb/integration/xgboost/xgboost.py +1 -1
- wandb/jupyter.py +0 -1
- wandb/plot/__init__.py +17 -8
- wandb/plot/bar.py +53 -27
- wandb/plot/confusion_matrix.py +151 -70
- wandb/plot/custom_chart.py +124 -0
- wandb/plot/histogram.py +46 -20
- wandb/plot/line.py +57 -26
- wandb/plot/line_series.py +148 -60
- wandb/plot/pr_curve.py +89 -44
- wandb/plot/roc_curve.py +82 -37
- wandb/plot/scatter.py +53 -20
- wandb/plot/viz.py +20 -102
- wandb/proto/v3/wandb_settings_pb2.py +2 -2
- wandb/proto/v3/wandb_telemetry_pb2.py +4 -4
- wandb/proto/v4/wandb_settings_pb2.py +2 -2
- wandb/proto/v4/wandb_telemetry_pb2.py +4 -4
- wandb/proto/v5/wandb_settings_pb2.py +2 -2
- wandb/proto/v5/wandb_telemetry_pb2.py +4 -4
- wandb/proto/wandb_deprecated.py +2 -0
- wandb/sdk/artifacts/artifact.py +281 -329
- wandb/sdk/artifacts/artifact_manifest.py +10 -9
- wandb/sdk/artifacts/artifact_manifest_entry.py +1 -1
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +9 -4
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +1 -3
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -1
- wandb/sdk/backend/backend.py +0 -1
- wandb/sdk/data_types/audio.py +1 -1
- wandb/sdk/data_types/base_types/media.py +66 -5
- wandb/sdk/data_types/base_types/wb_value.py +20 -10
- wandb/sdk/data_types/bokeh.py +1 -1
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +1 -1
- wandb/sdk/data_types/helper_types/image_mask.py +2 -2
- wandb/sdk/data_types/histogram.py +1 -1
- wandb/sdk/data_types/html.py +1 -1
- wandb/sdk/data_types/image.py +1 -1
- wandb/sdk/data_types/molecule.py +3 -3
- wandb/sdk/data_types/object_3d.py +4 -4
- wandb/sdk/data_types/plotly.py +1 -1
- wandb/sdk/data_types/saved_model.py +0 -1
- wandb/sdk/data_types/table.py +7 -7
- wandb/sdk/data_types/trace_tree.py +1 -1
- wandb/sdk/data_types/video.py +4 -3
- wandb/sdk/interface/interface_queue.py +0 -6
- wandb/sdk/interface/router.py +1 -4
- wandb/sdk/interface/router_queue.py +0 -3
- wandb/sdk/interface/router_relay.py +0 -2
- wandb/sdk/internal/datastore.py +1 -1
- wandb/sdk/internal/file_pusher.py +1 -1
- wandb/sdk/internal/file_stream.py +4 -4
- wandb/sdk/internal/handler.py +3 -4
- wandb/sdk/internal/internal.py +1 -15
- wandb/sdk/internal/internal_api.py +178 -63
- wandb/sdk/internal/internal_util.py +0 -3
- wandb/sdk/internal/job_builder.py +4 -3
- wandb/sdk/internal/sender.py +0 -2
- wandb/sdk/internal/system/assets/__init__.py +0 -2
- wandb/sdk/internal/tb_watcher.py +11 -10
- wandb/sdk/internal/writer.py +1 -3
- wandb/sdk/launch/__init__.py +2 -1
- wandb/sdk/launch/_launch.py +4 -3
- wandb/sdk/launch/_launch_add.py +2 -2
- wandb/sdk/launch/builder/kaniko_builder.py +0 -1
- wandb/sdk/launch/create_job.py +1 -0
- wandb/sdk/launch/environment/local_environment.py +0 -1
- wandb/sdk/launch/errors.py +0 -6
- wandb/sdk/launch/registry/local_registry.py +0 -2
- wandb/sdk/launch/runner/abstract.py +0 -5
- wandb/sdk/launch/sweeps/__init__.py +0 -2
- wandb/sdk/launch/sweeps/scheduler.py +0 -2
- wandb/sdk/launch/sweeps/scheduler_sweep.py +0 -1
- wandb/sdk/lib/_settings_toposort_generated.py +1 -0
- wandb/sdk/lib/apikey.py +3 -3
- wandb/sdk/lib/file_stream_utils.py +1 -1
- wandb/sdk/lib/filesystem.py +1 -1
- wandb/sdk/lib/ipython.py +16 -9
- wandb/sdk/lib/mailbox.py +0 -4
- wandb/sdk/lib/printer.py +44 -8
- wandb/sdk/lib/retry.py +1 -1
- wandb/sdk/lib/sock_client.py +0 -5
- wandb/sdk/service/server.py +2 -11
- wandb/sdk/service/server_sock.py +0 -2
- wandb/sdk/service/service.py +3 -3
- wandb/sdk/service/streams.py +2 -4
- wandb/sdk/wandb_init.py +20 -20
- wandb/sdk/wandb_login.py +1 -1
- wandb/sdk/wandb_require.py +1 -4
- wandb/sdk/wandb_run.py +97 -115
- wandb/sdk/wandb_settings.py +23 -6
- wandb/sdk/wandb_setup.py +1 -5
- wandb/sdk/wandb_sync.py +2 -1
- wandb/util.py +49 -21
- wandb/wandb_agent.py +3 -3
- wandb/wandb_controller.py +2 -2
- {wandb-0.18.5.dist-info → wandb-0.18.7.dist-info}/METADATA +1 -2
- {wandb-0.18.5.dist-info → wandb-0.18.7.dist-info}/RECORD +144 -146
- {wandb-0.18.5.dist-info → wandb-0.18.7.dist-info}/WHEEL +1 -1
- wandb/sdk/internal/system/assets/gpu_apple.py +0 -177
- wandb/sdk/lib/_wburls_generate.py +0 -25
- wandb/sdk/lib/_wburls_generated.py +0 -22
- wandb/sdk/lib/tracelog.py +0 -255
- wandb/sdk/lib/wburls.py +0 -46
- {wandb-0.18.5.dist-info → wandb-0.18.7.dist-info}/entry_points.txt +0 -0
- {wandb-0.18.5.dist-info → wandb-0.18.7.dist-info}/licenses/LICENSE +0 -0
@@ -21,7 +21,7 @@ def regressor(model, X_train, X_test, y_train, y_test, model_name="Regressor"):
|
|
21
21
|
|
22
22
|
Should only be called with a fitted regressor (otherwise an error is thrown).
|
23
23
|
|
24
|
-
|
24
|
+
Args:
|
25
25
|
model: (regressor) Takes in a fitted regressor.
|
26
26
|
X_train: (arr) Training set features.
|
27
27
|
y_train: (arr) Training set labels.
|
@@ -62,7 +62,7 @@ def outlier_candidates(regressor=None, X=None, y=None): # noqa: N803
|
|
62
62
|
|
63
63
|
Please note this function fits the model on the training set when called.
|
64
64
|
|
65
|
-
|
65
|
+
Args:
|
66
66
|
regressor: (regressor) Takes in a fitted regressor.
|
67
67
|
X: (arr) Training set features.
|
68
68
|
y: (arr) Training set labels.
|
@@ -96,7 +96,7 @@ def residuals(regressor=None, X=None, y=None): # noqa: N803
|
|
96
96
|
|
97
97
|
Please note this function fits variations of the model on the training set when called.
|
98
98
|
|
99
|
-
|
99
|
+
Args:
|
100
100
|
regressor: (regressor) Takes in a fitted regressor.
|
101
101
|
X: (arr) Training set features.
|
102
102
|
y: (arr) Training set labels.
|
@@ -16,7 +16,7 @@ def summary_metrics(model=None, X=None, y=None, X_test=None, y_test=None): # no
|
|
16
16
|
|
17
17
|
Should only be called with a fitted model (otherwise an error is thrown).
|
18
18
|
|
19
|
-
|
19
|
+
Args:
|
20
20
|
model: (clf or reg) Takes in a fitted regressor or classifier.
|
21
21
|
X: (arr) Training set features.
|
22
22
|
y: (arr) Training set labels.
|
@@ -60,7 +60,7 @@ def learning_curve(
|
|
60
60
|
|
61
61
|
Please note this function fits the model to datasets of varying sizes when called.
|
62
62
|
|
63
|
-
|
63
|
+
Args:
|
64
64
|
model: (clf or reg) Takes in a fitted regressor or classifier.
|
65
65
|
X: (arr) Dataset features.
|
66
66
|
y: (arr) Dataset labels.
|
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
|
5
5
|
|
6
6
|
import wandb
|
7
7
|
import wandb.util
|
8
|
-
from wandb.plot
|
8
|
+
from wandb.plot import plot_table
|
9
9
|
from wandb.sdk.lib import telemetry
|
10
10
|
|
11
11
|
if TYPE_CHECKING:
|
@@ -200,7 +200,7 @@ def tf_summary_to_dict( # noqa: C901
|
|
200
200
|
data_table = wandb.Table(data=data, columns=["recall", "precision"])
|
201
201
|
name = namespaced_tag(value.tag, namespace)
|
202
202
|
|
203
|
-
values[name] =
|
203
|
+
values[name] = plot_table(
|
204
204
|
"wandb/line/v0",
|
205
205
|
data_table,
|
206
206
|
{"x": "recall", "y": "precision"},
|
@@ -114,7 +114,7 @@ class WandBUltralyticsCallback:
|
|
114
114
|
model(["img1.jpeg", "img2.jpeg"])
|
115
115
|
```
|
116
116
|
|
117
|
-
|
117
|
+
Args:
|
118
118
|
model: (ultralytics.yolo.engine.model.YOLO) YOLO Model of type
|
119
119
|
`ultralytics.yolo.engine.model.YOLO`.
|
120
120
|
epoch_logging_interval: (int) interval to log the prediction visualizations
|
@@ -466,7 +466,7 @@ def add_wandb_callback(
|
|
466
466
|
model(["img1.jpeg", "img2.jpeg"])
|
467
467
|
```
|
468
468
|
|
469
|
-
|
469
|
+
Args:
|
470
470
|
model: (ultralytics.yolo.engine.model.YOLO) YOLO Model of type
|
471
471
|
`ultralytics.yolo.engine.model.YOLO`.
|
472
472
|
epoch_logging_interval: (int) interval to log the prediction visualizations
|
@@ -55,7 +55,7 @@ def wandb_callback() -> "Callable":
|
|
55
55
|
class WandbCallback(xgb.callback.TrainingCallback):
|
56
56
|
"""`WandbCallback` automatically integrates XGBoost with wandb.
|
57
57
|
|
58
|
-
|
58
|
+
Args:
|
59
59
|
log_model: (boolean) if True save and upload the model to Weights & Biases Artifacts
|
60
60
|
log_feature_importance: (boolean) if True log a feature importance bar plot
|
61
61
|
importance_type: (str) one of {weight, gain, cover, total_gain, total_cover} for tree model. weight for linear model.
|
wandb/jupyter.py
CHANGED
wandb/plot/__init__.py
CHANGED
@@ -1,11 +1,9 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
from wandb.plot.roc_curve import roc_curve
|
8
|
-
from wandb.plot.scatter import scatter
|
1
|
+
"""Chart Visualization Utilities
|
2
|
+
|
3
|
+
This module offers a collection of predefined chart types, along with functionality
|
4
|
+
for creating custom charts, enabling flexible visualization of your data beyond the
|
5
|
+
built-in options.
|
6
|
+
"""
|
9
7
|
|
10
8
|
__all__ = [
|
11
9
|
"line",
|
@@ -17,3 +15,14 @@ __all__ = [
|
|
17
15
|
"confusion_matrix",
|
18
16
|
"line_series",
|
19
17
|
]
|
18
|
+
|
19
|
+
from wandb.plot.bar import bar
|
20
|
+
from wandb.plot.confusion_matrix import confusion_matrix
|
21
|
+
from wandb.plot.custom_chart import CustomChart, plot_table
|
22
|
+
from wandb.plot.histogram import histogram
|
23
|
+
from wandb.plot.line import line
|
24
|
+
from wandb.plot.line_series import line_series
|
25
|
+
from wandb.plot.pr_curve import pr_curve
|
26
|
+
from wandb.plot.roc_curve import roc_curve
|
27
|
+
from wandb.plot.scatter import scatter
|
28
|
+
from wandb.plot.viz import Visualize, visualize
|
wandb/plot/bar.py
CHANGED
@@ -1,45 +1,71 @@
|
|
1
|
-
from
|
1
|
+
from __future__ import annotations
|
2
2
|
|
3
|
-
from
|
3
|
+
from typing import TYPE_CHECKING
|
4
|
+
|
5
|
+
from wandb.plot.custom_chart import plot_table
|
4
6
|
|
5
7
|
if TYPE_CHECKING:
|
6
8
|
import wandb
|
9
|
+
from wandb.plot.custom_chart import CustomChart
|
7
10
|
|
8
11
|
|
9
12
|
def bar(
|
10
|
-
table:
|
13
|
+
table: wandb.Table,
|
11
14
|
label: str,
|
12
15
|
value: str,
|
13
|
-
title:
|
14
|
-
split_table:
|
15
|
-
):
|
16
|
-
"""
|
17
|
-
|
18
|
-
|
19
|
-
table (wandb.Table):
|
20
|
-
label (
|
21
|
-
value (
|
22
|
-
title (
|
23
|
-
split_table (bool):
|
16
|
+
title: str = "",
|
17
|
+
split_table: bool = False,
|
18
|
+
) -> CustomChart:
|
19
|
+
"""Constructs a bar chart from a wandb.Table of data.
|
20
|
+
|
21
|
+
Args:
|
22
|
+
table (wandb.Table): A table containing the data for the bar chart.
|
23
|
+
label (str): The name of the column to use for the labels of each bar.
|
24
|
+
value (str): The name of the column to use for the values of each bar.
|
25
|
+
title (str): The title of the bar chart.
|
26
|
+
split_table (bool): Whether the table should be split into a separate section
|
27
|
+
in the W&B UI. If `True`, the table will be displayed in a section named
|
28
|
+
"Custom Chart Tables". Default is `False`.
|
24
29
|
|
25
30
|
Returns:
|
26
|
-
A
|
31
|
+
CustomChart: A custom chart object that can be logged to W&B. To log the
|
32
|
+
chart, pass it to `wandb.log()`.
|
27
33
|
|
28
34
|
Example:
|
29
35
|
```
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
36
|
+
import random
|
37
|
+
import wandb
|
38
|
+
|
39
|
+
# Generate random data for the table
|
40
|
+
data = [
|
41
|
+
['car', random.uniform(0, 1)],
|
42
|
+
['bus', random.uniform(0, 1)],
|
43
|
+
['road', random.uniform(0, 1)],
|
44
|
+
['person', random.uniform(0, 1)],
|
45
|
+
]
|
46
|
+
|
47
|
+
# Create a table with the data
|
48
|
+
table = wandb.Table(data=data, columns=["class", "accuracy"])
|
49
|
+
|
50
|
+
# Initialize a W&B run and log the bar plot
|
51
|
+
with wandb.init(project="bar_chart") as run:
|
52
|
+
|
53
|
+
# Create a bar plot from the table
|
54
|
+
bar_plot = wandb.plot.bar(
|
55
|
+
table=table,
|
56
|
+
label="class",
|
57
|
+
value="accuracy",
|
58
|
+
title="Object Classification Accuracy",
|
59
|
+
)
|
60
|
+
|
61
|
+
# Log the bar chart to W&B
|
62
|
+
run.log({'bar_plot': bar_plot})
|
37
63
|
```
|
38
64
|
"""
|
39
|
-
return
|
40
|
-
|
41
|
-
|
42
|
-
{"label": label, "value": value},
|
43
|
-
{"title": title},
|
65
|
+
return plot_table(
|
66
|
+
data_table=table,
|
67
|
+
vega_spec_name="wandb/bar/v0",
|
68
|
+
fields={"label": label, "value": value},
|
69
|
+
string_fields={"title": title},
|
44
70
|
split_table=split_table,
|
45
71
|
)
|
wandb/plot/confusion_matrix.py
CHANGED
@@ -1,100 +1,181 @@
|
|
1
|
-
from
|
1
|
+
from __future__ import annotations
|
2
2
|
|
3
|
+
from typing import TYPE_CHECKING, Sequence, TypeVar
|
4
|
+
|
5
|
+
import wandb
|
3
6
|
from wandb import util
|
4
|
-
from wandb.
|
5
|
-
|
7
|
+
from wandb.plot.custom_chart import plot_table
|
8
|
+
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
from wandb.plot.custom_chart import CustomChart
|
6
11
|
|
7
|
-
|
12
|
+
T = TypeVar("T")
|
8
13
|
|
9
14
|
|
10
15
|
def confusion_matrix(
|
11
|
-
probs:
|
12
|
-
y_true:
|
13
|
-
preds:
|
14
|
-
class_names:
|
15
|
-
title:
|
16
|
-
split_table:
|
17
|
-
):
|
18
|
-
"""
|
19
|
-
|
20
|
-
|
21
|
-
probs (
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
16
|
+
probs: Sequence[Sequence[float]] | None = None,
|
17
|
+
y_true: Sequence[T] | None = None,
|
18
|
+
preds: Sequence[T] | None = None,
|
19
|
+
class_names: Sequence[str] | None = None,
|
20
|
+
title: str = "Confusion Matrix Curve",
|
21
|
+
split_table: bool = False,
|
22
|
+
) -> CustomChart:
|
23
|
+
"""Constructs a confusion matrix from a sequence of probabilities or predictions.
|
24
|
+
|
25
|
+
Args:
|
26
|
+
probs (Sequence[Sequence[float]] | None): A sequence of predicted probabilities for each
|
27
|
+
class. The sequence shape should be (N, K) where N is the number of samples
|
28
|
+
and K is the number of classes. If provided, `preds` should not be provided.
|
29
|
+
y_true (Sequence[T] | None): A sequence of true labels.
|
30
|
+
preds (Sequence[T] | None): A sequence of predicted class labels. If provided,
|
31
|
+
`probs` should not be provided.
|
32
|
+
class_names (Sequence[str] | None): Sequence of class names. If not
|
33
|
+
provided, class names will be defined as "Class_1", "Class_2", etc.
|
34
|
+
title (str): Title of the confusion matrix chart.
|
35
|
+
split_table (bool): Whether the table should be split into a separate section
|
36
|
+
in the W&B UI. If `True`, the table will be displayed in a section named
|
37
|
+
"Custom Chart Tables". Default is `False`.
|
26
38
|
|
27
39
|
Returns:
|
28
|
-
|
29
|
-
|
40
|
+
CustomChart: A custom chart object that can be logged to W&B. To log the
|
41
|
+
chart, pass it to `wandb.log()`.
|
42
|
+
|
43
|
+
Raises:
|
44
|
+
ValueError: If both `probs` and `preds` are provided or if the number of
|
45
|
+
predictions and true labels are not equal. If the number of unique
|
46
|
+
predicted classes exceeds the number of class names or if the number of
|
47
|
+
unique true labels exceeds the number of class names.
|
48
|
+
wandb.Error: If numpy is not installed.
|
49
|
+
|
50
|
+
Examples:
|
51
|
+
1. Logging a confusion matrix with random probabilities for wildlife
|
52
|
+
classification:
|
53
|
+
```
|
54
|
+
import numpy as np
|
55
|
+
import wandb
|
56
|
+
|
57
|
+
# Define class names for wildlife
|
58
|
+
wildlife_class_names = ["Lion", "Tiger", "Elephant", "Zebra"]
|
59
|
+
|
60
|
+
# Generate random true labels (0 to 3 for 10 samples)
|
61
|
+
wildlife_y_true = np.random.randint(0, 4, size=10)
|
62
|
+
|
63
|
+
# Generate random probabilities for each class (10 samples x 4 classes)
|
64
|
+
wildlife_probs = np.random.rand(10, 4)
|
65
|
+
wildlife_probs = np.exp(wildlife_probs) / np.sum(
|
66
|
+
np.exp(wildlife_probs),
|
67
|
+
axis=1,
|
68
|
+
keepdims=True,
|
69
|
+
)
|
70
|
+
|
71
|
+
# Initialize W&B run and log confusion matrix
|
72
|
+
with wandb.init(project="wildlife_classification") as run:
|
73
|
+
confusion_matrix = wandb.plot.confusion_matrix(
|
74
|
+
probs=wildlife_probs,
|
75
|
+
y_true=wildlife_y_true,
|
76
|
+
class_names=wildlife_class_names,
|
77
|
+
title="Wildlife Classification Confusion Matrix",
|
78
|
+
)
|
79
|
+
run.log({"wildlife_confusion_matrix": confusion_matrix})
|
80
|
+
```
|
81
|
+
In this example, random probabilities are used to generate a confusion
|
82
|
+
matrix.
|
30
83
|
|
31
|
-
|
84
|
+
2. Logging a confusion matrix with simulated model predictions and 85%
|
85
|
+
accuracy:
|
32
86
|
```
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
87
|
+
import numpy as np
|
88
|
+
import wandb
|
89
|
+
|
90
|
+
# Define class names for wildlife
|
91
|
+
wildlife_class_names = ["Lion", "Tiger", "Elephant", "Zebra"]
|
92
|
+
|
93
|
+
# Simulate true labels for 200 animal images (imbalanced distribution)
|
94
|
+
wildlife_y_true = np.random.choice(
|
95
|
+
[0, 1, 2, 3],
|
96
|
+
size=200,
|
97
|
+
p=[0.2, 0.3, 0.25, 0.25],
|
98
|
+
)
|
99
|
+
|
100
|
+
# Simulate model predictions with 85% accuracy
|
101
|
+
wildlife_preds = [
|
102
|
+
y_t
|
103
|
+
if np.random.rand() < 0.85
|
104
|
+
else np.random.choice([x for x in range(4) if x != y_t])
|
105
|
+
for y_t in wildlife_y_true
|
106
|
+
]
|
107
|
+
|
108
|
+
# Initialize W&B run and log confusion matrix
|
109
|
+
with wandb.init(project="wildlife_classification") as run:
|
110
|
+
confusion_matrix = wandb.plot.confusion_matrix(
|
111
|
+
preds=wildlife_preds,
|
112
|
+
y_true=wildlife_y_true,
|
113
|
+
class_names=wildlife_class_names,
|
114
|
+
title="Simulated Wildlife Classification Confusion Matrix"
|
115
|
+
)
|
116
|
+
run.log({"wildlife_confusion_matrix": confusion_matrix})
|
38
117
|
```
|
118
|
+
In this example, predictions are simulated with 85% accuracy to generate a
|
119
|
+
confusion matrix.
|
39
120
|
"""
|
40
121
|
np = util.get_module(
|
41
122
|
"numpy",
|
42
|
-
required=
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
"confusion_matrix has been updated to accept"
|
47
|
-
" probabilities as the default first argument. Use preds=..."
|
123
|
+
required=(
|
124
|
+
"numpy is required to use wandb.plot.confusion_matrix, "
|
125
|
+
"install with `pip install numpy`",
|
126
|
+
),
|
48
127
|
)
|
49
128
|
|
50
|
-
|
51
|
-
|
52
|
-
), "Must provide probabilities or predictions but not both to confusion matrix"
|
129
|
+
if probs is not None and preds is not None:
|
130
|
+
raise ValueError("Only one of `probs` or `preds` should be provided, not both.")
|
53
131
|
|
54
132
|
if probs is not None:
|
55
133
|
preds = np.argmax(probs, axis=1).tolist()
|
56
134
|
|
57
|
-
|
58
|
-
|
59
|
-
), "Number of predictions and label indices must match"
|
135
|
+
if len(preds) != len(y_true):
|
136
|
+
raise ValueError("The number of predictions and true labels must be equal.")
|
60
137
|
|
61
138
|
if class_names is not None:
|
62
139
|
n_classes = len(class_names)
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
)
|
140
|
+
class_idx = list(range(n_classes))
|
141
|
+
if len(set(preds)) > len(class_names):
|
142
|
+
raise ValueError(
|
143
|
+
"The number of unique predicted classes exceeds the number of class names."
|
144
|
+
)
|
145
|
+
|
146
|
+
if len(set(y_true)) > len(class_names):
|
147
|
+
raise ValueError(
|
148
|
+
"The number of unique true labels exceeds the number of class names."
|
149
|
+
)
|
70
150
|
else:
|
71
|
-
|
72
|
-
n_classes = len(
|
73
|
-
class_names = [f"Class_{i}" for i in range(
|
74
|
-
|
75
|
-
#
|
76
|
-
class_mapping = {}
|
77
|
-
|
78
|
-
class_mapping[val] = i
|
151
|
+
class_idx = set(preds).union(set(y_true))
|
152
|
+
n_classes = len(class_idx)
|
153
|
+
class_names = [f"Class_{i+1}" for i in range(n_classes)]
|
154
|
+
|
155
|
+
# Create a mapping from class label (as found in preds/y_true) to matrix index
|
156
|
+
class_mapping = {val: i for i, val in enumerate(sorted(list(class_idx)))}
|
157
|
+
|
79
158
|
counts = np.zeros((n_classes, n_classes))
|
80
159
|
for i in range(len(preds)):
|
81
160
|
counts[class_mapping[y_true[i]], class_mapping[preds[i]]] += 1
|
82
161
|
|
83
|
-
data = [
|
84
|
-
|
85
|
-
for
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
162
|
+
data = [
|
163
|
+
[class_names[i], class_names[j], counts[i, j]]
|
164
|
+
for i in range(n_classes)
|
165
|
+
for j in range(n_classes)
|
166
|
+
]
|
167
|
+
|
168
|
+
return plot_table(
|
169
|
+
data_table=wandb.Table(
|
170
|
+
columns=["Actual", "Predicted", "nPredictions"],
|
171
|
+
data=data,
|
172
|
+
),
|
173
|
+
vega_spec_name="wandb/confusion_matrix/v1",
|
174
|
+
fields={
|
175
|
+
"Actual": "Actual",
|
176
|
+
"Predicted": "Predicted",
|
177
|
+
"nPredictions": "nPredictions",
|
178
|
+
},
|
179
|
+
string_fields={"title": title},
|
99
180
|
split_table=split_table,
|
100
181
|
)
|
@@ -0,0 +1,124 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from typing import Any
|
5
|
+
|
6
|
+
import wandb
|
7
|
+
|
8
|
+
|
9
|
+
@dataclass
|
10
|
+
class CustomChartSpec:
|
11
|
+
spec_name: str
|
12
|
+
fields: dict[str, Any]
|
13
|
+
string_fields: dict[str, Any]
|
14
|
+
key: str = ""
|
15
|
+
panel_type: str = "Vega2"
|
16
|
+
split_table: bool = False
|
17
|
+
|
18
|
+
@property
|
19
|
+
def table_key(self) -> str:
|
20
|
+
if not self.key:
|
21
|
+
raise wandb.Error("Key for the custom chart spec is not set.")
|
22
|
+
if self.split_table:
|
23
|
+
return f"Custom Chart Tables/{self.key}_table"
|
24
|
+
return f"{self.key}_table"
|
25
|
+
|
26
|
+
@property
|
27
|
+
def config_value(self) -> dict[str, Any]:
|
28
|
+
return {
|
29
|
+
"panel_type": self.panel_type,
|
30
|
+
"panel_config": {
|
31
|
+
"panelDefId": self.spec_name,
|
32
|
+
"fieldSettings": self.fields,
|
33
|
+
"stringSettings": self.string_fields,
|
34
|
+
"transform": {"name": "tableWithLeafColNames"},
|
35
|
+
"userQuery": {
|
36
|
+
"queryFields": [
|
37
|
+
{
|
38
|
+
"name": "runSets",
|
39
|
+
"args": [{"name": "runSets", "value": "${runSets}"}],
|
40
|
+
"fields": [
|
41
|
+
{"name": "id", "fields": []},
|
42
|
+
{"name": "name", "fields": []},
|
43
|
+
{"name": "_defaultColorIndex", "fields": []},
|
44
|
+
{
|
45
|
+
"name": "summaryTable",
|
46
|
+
"args": [
|
47
|
+
{
|
48
|
+
"name": "tableKey",
|
49
|
+
"value": self.table_key,
|
50
|
+
}
|
51
|
+
],
|
52
|
+
"fields": [],
|
53
|
+
},
|
54
|
+
],
|
55
|
+
}
|
56
|
+
],
|
57
|
+
},
|
58
|
+
},
|
59
|
+
}
|
60
|
+
|
61
|
+
@property
|
62
|
+
def config_key(self) -> tuple[str, str, str]:
|
63
|
+
return ("_wandb", "visualize", self.key)
|
64
|
+
|
65
|
+
|
66
|
+
@dataclass
|
67
|
+
class CustomChart:
|
68
|
+
table: wandb.Table
|
69
|
+
spec: CustomChartSpec
|
70
|
+
|
71
|
+
def set_key(self, key: str):
|
72
|
+
"""Sets the key for the spec and updates dependent configurations."""
|
73
|
+
self.spec.key = key
|
74
|
+
|
75
|
+
|
76
|
+
def plot_table(
|
77
|
+
vega_spec_name: str,
|
78
|
+
data_table: wandb.Table,
|
79
|
+
fields: dict[str, Any],
|
80
|
+
string_fields: dict[str, Any] | None = None,
|
81
|
+
split_table: bool = False,
|
82
|
+
) -> CustomChart:
|
83
|
+
"""Creates a custom chart using a Vega-Lite specification and a `wandb.Table`.
|
84
|
+
|
85
|
+
This function creates a custom chart based on a Vega-Lite specification and
|
86
|
+
a data table represented by a `wandb.Table` object. The specification needs
|
87
|
+
to be predefined and stored in the W&B backend. The function returns a custom
|
88
|
+
chart object that can be logged to W&B using `wandb.log()`.
|
89
|
+
|
90
|
+
Args:
|
91
|
+
vega_spec_name (str): The name or identifier of the Vega-Lite spec
|
92
|
+
that defines the visualization structure.
|
93
|
+
data_table (wandb.Table): A `wandb.Table` object containing the data to be
|
94
|
+
visualized.
|
95
|
+
fields (dict[str, Any]): A mapping between the fields in the Vega-Lite spec and the
|
96
|
+
corresponding columns in the data table to be visualized.
|
97
|
+
string_fields (dict[str, Any] | None): A dictionary for providing values for any string constants
|
98
|
+
required by the custom visualization.
|
99
|
+
split_table (bool): Whether the table should be split into a separate section
|
100
|
+
in the W&B UI. If `True`, the table will be displayed in a section named
|
101
|
+
"Custom Chart Tables". Default is `False`.
|
102
|
+
|
103
|
+
Returns:
|
104
|
+
CustomChart: A custom chart object that can be logged to W&B. To log the
|
105
|
+
chart, pass it to `wandb.log()`.
|
106
|
+
|
107
|
+
Raises:
|
108
|
+
wandb.Error: If `data_table` is not a `wandb.Table` object.
|
109
|
+
"""
|
110
|
+
|
111
|
+
if not isinstance(data_table, wandb.Table):
|
112
|
+
raise wandb.Error(
|
113
|
+
f"Expected `data_table` to be `wandb.Table` type, instead got {type(data_table).__name__}"
|
114
|
+
)
|
115
|
+
|
116
|
+
return CustomChart(
|
117
|
+
table=data_table,
|
118
|
+
spec=CustomChartSpec(
|
119
|
+
spec_name=vega_spec_name,
|
120
|
+
fields=fields,
|
121
|
+
string_fields=string_fields or {},
|
122
|
+
split_table=split_table,
|
123
|
+
),
|
124
|
+
)
|