wandb 0.18.5__py3-none-win32.whl → 0.18.6__py3-none-win32.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (129)
  1. wandb/__init__.py +2 -2
  2. wandb/__init__.pyi +21 -19
  3. wandb/agents/pyagent.py +1 -1
  4. wandb/apis/importers/wandb.py +1 -1
  5. wandb/apis/normalize.py +2 -18
  6. wandb/apis/public/api.py +122 -62
  7. wandb/apis/public/artifacts.py +8 -3
  8. wandb/apis/public/files.py +17 -2
  9. wandb/apis/public/jobs.py +2 -2
  10. wandb/apis/public/query_generator.py +1 -1
  11. wandb/apis/public/runs.py +8 -8
  12. wandb/apis/public/teams.py +3 -3
  13. wandb/apis/public/users.py +1 -1
  14. wandb/apis/public/utils.py +68 -0
  15. wandb/bin/gpu_stats.exe +0 -0
  16. wandb/bin/wandb-core +0 -0
  17. wandb/cli/cli.py +12 -3
  18. wandb/data_types.py +1 -1
  19. wandb/docker/__init__.py +2 -1
  20. wandb/docker/auth.py +2 -3
  21. wandb/errors/links.py +73 -0
  22. wandb/errors/term.py +7 -6
  23. wandb/filesync/step_prepare.py +1 -1
  24. wandb/filesync/upload_job.py +1 -1
  25. wandb/integration/catboost/catboost.py +2 -2
  26. wandb/integration/diffusers/pipeline_resolver.py +1 -1
  27. wandb/integration/diffusers/resolvers/multimodal.py +6 -6
  28. wandb/integration/diffusers/resolvers/utils.py +1 -1
  29. wandb/integration/fastai/__init__.py +3 -2
  30. wandb/integration/keras/callbacks/metrics_logger.py +1 -1
  31. wandb/integration/keras/callbacks/model_checkpoint.py +1 -1
  32. wandb/integration/keras/keras.py +1 -1
  33. wandb/integration/kfp/kfp_patch.py +1 -1
  34. wandb/integration/lightgbm/__init__.py +2 -2
  35. wandb/integration/magic.py +2 -2
  36. wandb/integration/metaflow/metaflow.py +1 -1
  37. wandb/integration/sacred/__init__.py +1 -1
  38. wandb/integration/sagemaker/auth.py +1 -1
  39. wandb/integration/sklearn/plot/classifier.py +7 -7
  40. wandb/integration/sklearn/plot/clusterer.py +3 -3
  41. wandb/integration/sklearn/plot/regressor.py +3 -3
  42. wandb/integration/sklearn/plot/shared.py +2 -2
  43. wandb/integration/tensorboard/log.py +2 -2
  44. wandb/integration/ultralytics/callback.py +2 -2
  45. wandb/integration/xgboost/xgboost.py +1 -1
  46. wandb/jupyter.py +0 -1
  47. wandb/plot/__init__.py +17 -8
  48. wandb/plot/bar.py +53 -27
  49. wandb/plot/confusion_matrix.py +151 -70
  50. wandb/plot/custom_chart.py +124 -0
  51. wandb/plot/histogram.py +46 -20
  52. wandb/plot/line.py +57 -26
  53. wandb/plot/line_series.py +148 -60
  54. wandb/plot/pr_curve.py +89 -44
  55. wandb/plot/roc_curve.py +82 -37
  56. wandb/plot/scatter.py +53 -20
  57. wandb/plot/viz.py +20 -102
  58. wandb/sdk/artifacts/artifact.py +280 -328
  59. wandb/sdk/artifacts/artifact_manifest.py +10 -9
  60. wandb/sdk/artifacts/artifact_manifest_entry.py +1 -1
  61. wandb/sdk/artifacts/storage_handlers/azure_handler.py +9 -4
  62. wandb/sdk/artifacts/storage_handlers/gcs_handler.py +1 -3
  63. wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
  64. wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +2 -2
  65. wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -1
  66. wandb/sdk/backend/backend.py +0 -1
  67. wandb/sdk/data_types/audio.py +1 -1
  68. wandb/sdk/data_types/base_types/media.py +66 -5
  69. wandb/sdk/data_types/bokeh.py +1 -1
  70. wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +1 -1
  71. wandb/sdk/data_types/helper_types/image_mask.py +2 -2
  72. wandb/sdk/data_types/histogram.py +1 -1
  73. wandb/sdk/data_types/html.py +1 -1
  74. wandb/sdk/data_types/image.py +1 -1
  75. wandb/sdk/data_types/molecule.py +3 -3
  76. wandb/sdk/data_types/object_3d.py +4 -4
  77. wandb/sdk/data_types/plotly.py +1 -1
  78. wandb/sdk/data_types/saved_model.py +0 -1
  79. wandb/sdk/data_types/table.py +7 -7
  80. wandb/sdk/data_types/trace_tree.py +1 -1
  81. wandb/sdk/data_types/video.py +4 -3
  82. wandb/sdk/interface/router.py +0 -2
  83. wandb/sdk/internal/datastore.py +1 -1
  84. wandb/sdk/internal/file_pusher.py +1 -1
  85. wandb/sdk/internal/file_stream.py +4 -4
  86. wandb/sdk/internal/handler.py +3 -2
  87. wandb/sdk/internal/internal.py +1 -1
  88. wandb/sdk/internal/internal_api.py +178 -63
  89. wandb/sdk/internal/job_builder.py +4 -3
  90. wandb/sdk/internal/system/assets/__init__.py +0 -2
  91. wandb/sdk/internal/tb_watcher.py +11 -10
  92. wandb/sdk/launch/_launch.py +4 -3
  93. wandb/sdk/launch/_launch_add.py +2 -2
  94. wandb/sdk/launch/builder/kaniko_builder.py +0 -1
  95. wandb/sdk/launch/create_job.py +1 -0
  96. wandb/sdk/launch/environment/local_environment.py +0 -1
  97. wandb/sdk/launch/errors.py +0 -6
  98. wandb/sdk/launch/registry/local_registry.py +0 -2
  99. wandb/sdk/launch/runner/abstract.py +0 -5
  100. wandb/sdk/launch/sweeps/__init__.py +0 -2
  101. wandb/sdk/launch/sweeps/scheduler.py +0 -2
  102. wandb/sdk/launch/sweeps/scheduler_sweep.py +0 -1
  103. wandb/sdk/lib/apikey.py +3 -3
  104. wandb/sdk/lib/file_stream_utils.py +1 -1
  105. wandb/sdk/lib/filesystem.py +1 -1
  106. wandb/sdk/lib/ipython.py +16 -9
  107. wandb/sdk/lib/mailbox.py +0 -4
  108. wandb/sdk/lib/printer.py +44 -8
  109. wandb/sdk/lib/retry.py +1 -1
  110. wandb/sdk/service/service.py +3 -3
  111. wandb/sdk/service/streams.py +2 -4
  112. wandb/sdk/wandb_init.py +20 -20
  113. wandb/sdk/wandb_login.py +1 -1
  114. wandb/sdk/wandb_require.py +1 -4
  115. wandb/sdk/wandb_run.py +57 -69
  116. wandb/sdk/wandb_settings.py +3 -4
  117. wandb/sdk/wandb_sync.py +2 -1
  118. wandb/util.py +46 -18
  119. wandb/wandb_agent.py +3 -3
  120. wandb/wandb_controller.py +2 -2
  121. {wandb-0.18.5.dist-info → wandb-0.18.6.dist-info}/METADATA +1 -1
  122. {wandb-0.18.5.dist-info → wandb-0.18.6.dist-info}/RECORD +125 -126
  123. wandb/sdk/internal/system/assets/gpu_apple.py +0 -177
  124. wandb/sdk/lib/_wburls_generate.py +0 -25
  125. wandb/sdk/lib/_wburls_generated.py +0 -22
  126. wandb/sdk/lib/wburls.py +0 -46
  127. {wandb-0.18.5.dist-info → wandb-0.18.6.dist-info}/WHEEL +0 -0
  128. {wandb-0.18.5.dist-info → wandb-0.18.6.dist-info}/entry_points.txt +0 -0
  129. {wandb-0.18.5.dist-info → wandb-0.18.6.dist-info}/licenses/LICENSE +0 -0
@@ -21,7 +21,7 @@ def regressor(model, X_train, X_test, y_train, y_test, model_name="Regressor"):
21
21
 
22
22
  Should only be called with a fitted regressor (otherwise an error is thrown).
23
23
 
24
- Arguments:
24
+ Args:
25
25
  model: (regressor) Takes in a fitted regressor.
26
26
  X_train: (arr) Training set features.
27
27
  y_train: (arr) Training set labels.
@@ -62,7 +62,7 @@ def outlier_candidates(regressor=None, X=None, y=None): # noqa: N803
62
62
 
63
63
  Please note this function fits the model on the training set when called.
64
64
 
65
- Arguments:
65
+ Args:
66
66
  model: (regressor) Takes in a fitted regressor.
67
67
  X: (arr) Training set features.
68
68
  y: (arr) Training set labels.
@@ -96,7 +96,7 @@ def residuals(regressor=None, X=None, y=None): # noqa: N803
96
96
 
97
97
  Please note this function fits variations of the model on the training set when called.
98
98
 
99
- Arguments:
99
+ Args:
100
100
  regressor: (regressor) Takes in a fitted regressor.
101
101
  X: (arr) Training set features.
102
102
  y: (arr) Training set labels.
@@ -16,7 +16,7 @@ def summary_metrics(model=None, X=None, y=None, X_test=None, y_test=None): # no
16
16
 
17
17
  Should only be called with a fitted model (otherwise an error is thrown).
18
18
 
19
- Arguments:
19
+ Args:
20
20
  model: (clf or reg) Takes in a fitted regressor or classifier.
21
21
  X: (arr) Training set features.
22
22
  y: (arr) Training set labels.
@@ -60,7 +60,7 @@ def learning_curve(
60
60
 
61
61
  Please note this function fits the model to datasets of varying sizes when called.
62
62
 
63
- Arguments:
63
+ Args:
64
64
  model: (clf or reg) Takes in a fitted regressor or classifier.
65
65
  X: (arr) Dataset features.
66
66
  y: (arr) Dataset labels.
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
5
5
 
6
6
  import wandb
7
7
  import wandb.util
8
- from wandb.plot.viz import custom_chart
8
+ from wandb.plot import plot_table
9
9
  from wandb.sdk.lib import telemetry
10
10
 
11
11
  if TYPE_CHECKING:
@@ -200,7 +200,7 @@ def tf_summary_to_dict( # noqa: C901
200
200
  data_table = wandb.Table(data=data, columns=["recall", "precision"])
201
201
  name = namespaced_tag(value.tag, namespace)
202
202
 
203
- values[name] = custom_chart(
203
+ values[name] = plot_table(
204
204
  "wandb/line/v0",
205
205
  data_table,
206
206
  {"x": "recall", "y": "precision"},
@@ -114,7 +114,7 @@ class WandBUltralyticsCallback:
114
114
  model(["img1.jpeg", "img2.jpeg"])
115
115
  ```
116
116
 
117
- Arguments:
117
+ Args:
118
118
  model: (ultralytics.yolo.engine.model.YOLO) YOLO Model of type
119
119
  `ultralytics.yolo.engine.model.YOLO`.
120
120
  epoch_logging_interval: (int) interval to log the prediction visualizations
@@ -466,7 +466,7 @@ def add_wandb_callback(
466
466
  model(["img1.jpeg", "img2.jpeg"])
467
467
  ```
468
468
 
469
- Arguments:
469
+ Args:
470
470
  model: (ultralytics.yolo.engine.model.YOLO) YOLO Model of type
471
471
  `ultralytics.yolo.engine.model.YOLO`.
472
472
  epoch_logging_interval: (int) interval to log the prediction visualizations
@@ -55,7 +55,7 @@ def wandb_callback() -> "Callable":
55
55
  class WandbCallback(xgb.callback.TrainingCallback):
56
56
  """`WandbCallback` automatically integrates XGBoost with wandb.
57
57
 
58
- Arguments:
58
+ Args:
59
59
  log_model: (boolean) if True save and upload the model to Weights & Biases Artifacts
60
60
  log_feature_importance: (boolean) if True log a feature importance bar plot
61
61
  importance_type: (str) one of {weight, gain, cover, total_gain, total_cover} for tree model. weight for linear model.
wandb/jupyter.py CHANGED
@@ -512,4 +512,3 @@ class Notebook:
512
512
  write(nb, f, version=4)
513
513
  except (OSError, validator.NotebookValidationError) as e:
514
514
  logger.error("Unable to save ipython session history:\n%s", e)
515
- pass
wandb/plot/__init__.py CHANGED
@@ -1,11 +1,9 @@
1
- from wandb.plot.bar import bar
2
- from wandb.plot.confusion_matrix import confusion_matrix
3
- from wandb.plot.histogram import histogram
4
- from wandb.plot.line import line
5
- from wandb.plot.line_series import line_series
6
- from wandb.plot.pr_curve import pr_curve
7
- from wandb.plot.roc_curve import roc_curve
8
- from wandb.plot.scatter import scatter
1
+ """Chart Visualization Utilities
2
+
3
+ This module offers a collection of predefined chart types, along with functionality
4
+ for creating custom charts, enabling flexible visualization of your data beyond the
5
+ built-in options.
6
+ """
9
7
 
10
8
  __all__ = [
11
9
  "line",
@@ -17,3 +15,14 @@ __all__ = [
17
15
  "confusion_matrix",
18
16
  "line_series",
19
17
  ]
18
+
19
+ from wandb.plot.bar import bar
20
+ from wandb.plot.confusion_matrix import confusion_matrix
21
+ from wandb.plot.custom_chart import CustomChart, plot_table
22
+ from wandb.plot.histogram import histogram
23
+ from wandb.plot.line import line
24
+ from wandb.plot.line_series import line_series
25
+ from wandb.plot.pr_curve import pr_curve
26
+ from wandb.plot.roc_curve import roc_curve
27
+ from wandb.plot.scatter import scatter
28
+ from wandb.plot.viz import Visualize, visualize
wandb/plot/bar.py CHANGED
from __future__ import annotations

from typing import TYPE_CHECKING

from wandb.plot.custom_chart import plot_table

if TYPE_CHECKING:
    import wandb
    from wandb.plot.custom_chart import CustomChart


def bar(
    table: wandb.Table,
    label: str,
    value: str,
    title: str = "",
    split_table: bool = False,
) -> CustomChart:
    """Build a bar chart backed by a `wandb.Table`.

    Args:
        table (wandb.Table): Table holding the data, one row per bar.
        label (str): Name of the column whose values label each bar.
        value (str): Name of the column whose values give each bar's value.
        title (str): Title displayed for the chart.
        split_table (bool): If `True`, the backing table is logged under a key
            prefixed with "Custom Chart Tables/", so it appears in its own
            "Custom Chart Tables" section of the W&B UI. Default is `False`.

    Returns:
        CustomChart: A chart object; pass it to `wandb.log()` to log it.

    Example:
        ```
        import random
        import wandb

        # Generate random data for the table
        data = [
            ['car', random.uniform(0, 1)],
            ['bus', random.uniform(0, 1)],
            ['road', random.uniform(0, 1)],
            ['person', random.uniform(0, 1)],
        ]

        # Create a table with the data
        table = wandb.Table(data=data, columns=["class", "accuracy"])

        # Initialize a W&B run and log the bar plot
        with wandb.init(project="bar_chart") as run:

            # Create a bar plot from the table
            bar_plot = wandb.plot.bar(
                table=table,
                label="class",
                value="accuracy",
                title="Object Classification Accuracy",
            )

            # Log the bar chart to W&B
            run.log({'bar_plot': bar_plot})
        ```
    """
    # Map the spec's "label"/"value" fields onto the caller's column names.
    column_map = {"label": label, "value": value}
    constants = {"title": title}
    return plot_table(
        "wandb/bar/v0",
        table,
        column_map,
        constants,
        split_table=split_table,
    )
from __future__ import annotations

from typing import TYPE_CHECKING, Sequence, TypeVar

import wandb
from wandb import util
from wandb.plot.custom_chart import plot_table

if TYPE_CHECKING:
    from wandb.plot.custom_chart import CustomChart

T = TypeVar("T")


def confusion_matrix(
    probs: Sequence[Sequence[float]] | None = None,
    y_true: Sequence[T] | None = None,
    preds: Sequence[T] | None = None,
    class_names: Sequence[str] | None = None,
    title: str = "Confusion Matrix Curve",
    split_table: bool = False,
) -> CustomChart:
    """Constructs a confusion matrix from a sequence of probabilities or predictions.

    Args:
        probs (Sequence[Sequence[float]] | None): A sequence of predicted probabilities for each
            class. The sequence shape should be (N, K) where N is the number of samples
            and K is the number of classes. The predicted class of each sample is
            the argmax over K. If provided, `preds` should not be provided.
        y_true (Sequence[T] | None): A sequence of true labels. Required.
        preds (Sequence[T] | None): A sequence of predicted class labels. If provided,
            `probs` should not be provided.
        class_names (Sequence[str] | None): Sequence of class names. If provided,
            every label in `y_true` and the predictions must be an integer index
            into this sequence. If not provided, the distinct labels are sorted
            and named "Class_1", "Class_2", etc.
        title (str): Title of the confusion matrix chart.
        split_table (bool): Whether the table should be split into a separate section
            in the W&B UI. If `True`, the table will be displayed in a section named
            "Custom Chart Tables". Default is `False`.

    Returns:
        CustomChart: A custom chart object that can be logged to W&B. To log the
            chart, pass it to `wandb.log()`.

    Raises:
        ValueError: If both `probs` and `preds` are provided, if neither is
            provided, if `y_true` is not provided, or if the number of
            predictions and true labels are not equal. When `class_names` is
            provided, also raised if the unique predicted classes or unique
            true labels outnumber the class names, or if any label is not an
            index into `class_names`.
        wandb.Error: If numpy is not installed.

    Examples:
        1. Logging a confusion matrix with random probabilities for wildlife
        classification:
        ```
        import numpy as np
        import wandb

        # Define class names for wildlife
        wildlife_class_names = ["Lion", "Tiger", "Elephant", "Zebra"]

        # Generate random true labels (0 to 3 for 10 samples)
        wildlife_y_true = np.random.randint(0, 4, size=10)

        # Generate random probabilities for each class (10 samples x 4 classes)
        wildlife_probs = np.random.rand(10, 4)
        wildlife_probs = np.exp(wildlife_probs) / np.sum(
            np.exp(wildlife_probs),
            axis=1,
            keepdims=True,
        )

        # Initialize W&B run and log confusion matrix
        with wandb.init(project="wildlife_classification") as run:
            confusion_matrix = wandb.plot.confusion_matrix(
                probs=wildlife_probs,
                y_true=wildlife_y_true,
                class_names=wildlife_class_names,
                title="Wildlife Classification Confusion Matrix",
            )
            run.log({"wildlife_confusion_matrix": confusion_matrix})
        ```
        In this example, random probabilities are used to generate a confusion
        matrix.

        2. Logging a confusion matrix with simulated model predictions and 85%
        accuracy:
        ```
        import numpy as np
        import wandb

        # Define class names for wildlife
        wildlife_class_names = ["Lion", "Tiger", "Elephant", "Zebra"]

        # Simulate true labels for 200 animal images (imbalanced distribution)
        wildlife_y_true = np.random.choice(
            [0, 1, 2, 3],
            size=200,
            p=[0.2, 0.3, 0.25, 0.25],
        )

        # Simulate model predictions with 85% accuracy
        wildlife_preds = [
            y_t
            if np.random.rand() < 0.85
            else np.random.choice([x for x in range(4) if x != y_t])
            for y_t in wildlife_y_true
        ]

        # Initialize W&B run and log confusion matrix
        with wandb.init(project="wildlife_classification") as run:
            confusion_matrix = wandb.plot.confusion_matrix(
                preds=wildlife_preds,
                y_true=wildlife_y_true,
                class_names=wildlife_class_names,
                title="Simulated Wildlife Classification Confusion Matrix"
            )
            run.log({"wildlife_confusion_matrix": confusion_matrix})
        ```
        In this example, predictions are simulated with 85% accuracy to generate a
        confusion matrix.
    """
    np = util.get_module(
        "numpy",
        # Must be a single string (no trailing comma): `get_module` uses it
        # as the error message when numpy cannot be imported.
        required=(
            "numpy is required to use wandb.plot.confusion_matrix, "
            "install with `pip install numpy`"
        ),
    )

    if probs is not None and preds is not None:
        raise ValueError("Only one of `probs` or `preds` should be provided, not both.")
    if probs is None and preds is None:
        raise ValueError("One of `probs` or `preds` must be provided.")
    if y_true is None:
        raise ValueError("`y_true` must be provided.")

    if probs is not None:
        # Predicted class per sample = index of the highest probability.
        preds = np.argmax(probs, axis=1).tolist()

    if len(preds) != len(y_true):
        raise ValueError("The number of predictions and true labels must be equal.")

    if class_names is not None:
        n_classes = len(class_names)
        class_idx = list(range(n_classes))
        if len(set(preds)) > len(class_names):
            raise ValueError(
                "The number of unique predicted classes exceeds the number of class names."
            )

        if len(set(y_true)) > len(class_names):
            raise ValueError(
                "The number of unique true labels exceeds the number of class names."
            )

        # Labels are used as indices into `class_names`; an out-of-range label
        # would otherwise surface as an opaque KeyError in the counting loop.
        valid_labels = set(class_idx)
        if any(label not in valid_labels for label in preds) or any(
            label not in valid_labels for label in y_true
        ):
            raise ValueError(
                "When `class_names` is provided, all labels must be integer "
                "indices in the range [0, len(class_names))."
            )
    else:
        class_idx = set(preds).union(set(y_true))
        n_classes = len(class_idx)
        class_names = [f"Class_{i+1}" for i in range(n_classes)]

    # Map each label value to its row/column index in the matrix; sorting
    # keeps the order stable for arbitrary (non-contiguous) label values.
    class_mapping = {val: i for i, val in enumerate(sorted(class_idx))}

    # Rows are actual classes, columns are predicted classes.
    counts = np.zeros((n_classes, n_classes))
    for actual, predicted in zip(y_true, preds):
        counts[class_mapping[actual], class_mapping[predicted]] += 1

    data = [
        [class_names[i], class_names[j], counts[i, j]]
        for i in range(n_classes)
        for j in range(n_classes)
    ]

    return plot_table(
        data_table=wandb.Table(
            columns=["Actual", "Predicted", "nPredictions"],
            data=data,
        ),
        vega_spec_name="wandb/confusion_matrix/v1",
        fields={
            "Actual": "Actual",
            "Predicted": "Predicted",
            "nPredictions": "nPredictions",
        },
        string_fields={"title": title},
        split_table=split_table,
    )
from __future__ import annotations

from dataclasses import dataclass
from typing import Any

import wandb


@dataclass
class CustomChartSpec:
    """Describes how a custom chart panel is built from a logged table.

    `table_key`, `config_value` and `config_key` are properties computed from
    the fields on every access, so changing `key` is reflected immediately.
    """

    # Name of the Vega spec that defines the visualization (e.g. "wandb/bar/v0").
    spec_name: str
    # Mapping from spec field names to table column names.
    fields: dict[str, Any]
    # Values for string constants used by the spec (e.g. {"title": ...}).
    string_fields: dict[str, Any]
    # Key the chart is associated with; empty until set.
    key: str = ""
    panel_type: str = "Vega2"
    # If True, the table key is prefixed with "Custom Chart Tables/".
    split_table: bool = False

    @property
    def table_key(self) -> str:
        """Key for the chart's backing table, derived from `key`.

        Raises:
            wandb.Error: If `key` has not been set.
        """
        if not self.key:
            raise wandb.Error("Key for the custom chart spec is not set.")
        if self.split_table:
            return f"Custom Chart Tables/{self.key}_table"
        return f"{self.key}_table"

    @property
    def config_value(self) -> dict[str, Any]:
        """Panel configuration referencing the spec, its settings and the table key.

        Raises:
            wandb.Error: If `key` has not been set (via `table_key`).
        """
        return {
            "panel_type": self.panel_type,
            "panel_config": {
                "panelDefId": self.spec_name,
                "fieldSettings": self.fields,
                "stringSettings": self.string_fields,
                "transform": {"name": "tableWithLeafColNames"},
                "userQuery": {
                    "queryFields": [
                        {
                            "name": "runSets",
                            "args": [{"name": "runSets", "value": "${runSets}"}],
                            "fields": [
                                {"name": "id", "fields": []},
                                {"name": "name", "fields": []},
                                {"name": "_defaultColorIndex", "fields": []},
                                {
                                    "name": "summaryTable",
                                    "args": [
                                        {
                                            "name": "tableKey",
                                            "value": self.table_key,
                                        }
                                    ],
                                    "fields": [],
                                },
                            ],
                        }
                    ],
                },
            },
        }

    @property
    def config_key(self) -> tuple[str, str, str]:
        """Nested config path ("_wandb", "visualize", key) for this chart."""
        return ("_wandb", "visualize", self.key)


@dataclass
class CustomChart:
    """A `wandb.Table` paired with the spec describing how to chart it."""

    table: wandb.Table
    spec: CustomChartSpec

    def set_key(self, key: str) -> None:
        """Sets the key for the spec.

        The spec's `table_key`, `config_value` and `config_key` properties are
        computed from this key when accessed.
        """
        self.spec.key = key


def plot_table(
    vega_spec_name: str,
    data_table: wandb.Table,
    fields: dict[str, Any],
    string_fields: dict[str, Any] | None = None,
    split_table: bool = False,
) -> CustomChart:
    """Creates a custom chart using a Vega-Lite specification and a `wandb.Table`.

    This function creates a custom chart based on a Vega-Lite specification and
    a data table represented by a `wandb.Table` object. The specification needs
    to be predefined and stored in the W&B backend. The function returns a custom
    chart object that can be logged to W&B using `wandb.log()`.

    Args:
        vega_spec_name (str): The name or identifier of the Vega-Lite spec
            that defines the visualization structure.
        data_table (wandb.Table): A `wandb.Table` object containing the data to be
            visualized.
        fields (dict[str, Any]): A mapping between the fields in the Vega-Lite spec and the
            corresponding columns in the data table to be visualized.
        string_fields (dict[str, Any] | None): A dictionary for providing values for any string constants
            required by the custom visualization. Defaults to an empty dict.
        split_table (bool): Whether the table should be split into a separate section
            in the W&B UI. If `True`, the table will be displayed in a section named
            "Custom Chart Tables". Default is `False`.

    Returns:
        CustomChart: A custom chart object that can be logged to W&B. To log the
            chart, pass it to `wandb.log()`.

    Raises:
        wandb.Error: If `data_table` is not a `wandb.Table` object.
    """

    if not isinstance(data_table, wandb.Table):
        raise wandb.Error(
            f"Expected `data_table` to be `wandb.Table` type, instead got {type(data_table).__name__}"
        )

    return CustomChart(
        table=data_table,
        spec=CustomChartSpec(
            spec_name=vega_spec_name,
            fields=fields,
            string_fields=string_fields or {},
            split_table=split_table,
        ),
    )