wandb 0.16.6__py3-none-any.whl → 0.17.0rc2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- package_readme.md +95 -0
- wandb/__init__.py +2 -2
- wandb/agents/pyagent.py +0 -1
- wandb/analytics/sentry.py +2 -1
- wandb/apis/importers/internals/protocols.py +30 -56
- wandb/apis/importers/mlflow.py +13 -26
- wandb/apis/importers/wandb.py +8 -14
- wandb/apis/public/api.py +1 -0
- wandb/apis/public/artifacts.py +1 -0
- wandb/apis/public/files.py +1 -0
- wandb/apis/public/history.py +1 -0
- wandb/apis/public/jobs.py +1 -0
- wandb/apis/public/projects.py +1 -0
- wandb/apis/public/reports.py +1 -0
- wandb/apis/public/runs.py +1 -0
- wandb/apis/public/sweeps.py +1 -0
- wandb/apis/public/teams.py +1 -0
- wandb/apis/public/users.py +1 -0
- wandb/apis/reports/v1/_blocks.py +3 -7
- wandb/apis/reports/v2/gql.py +1 -0
- wandb/apis/reports/v2/interface.py +3 -4
- wandb/apis/reports/v2/internal.py +5 -8
- wandb/cli/cli.py +2 -2
- wandb/data_types.py +9 -6
- wandb/docker/__init__.py +1 -1
- wandb/env.py +38 -8
- wandb/errors/__init__.py +5 -0
- wandb/integration/catboost/catboost.py +1 -1
- wandb/integration/fastai/__init__.py +1 -0
- wandb/integration/huggingface/resolver.py +2 -2
- wandb/integration/keras/__init__.py +1 -0
- wandb/integration/keras/callbacks/metrics_logger.py +1 -1
- wandb/integration/keras/keras.py +7 -7
- wandb/integration/langchain/wandb_tracer.py +1 -0
- wandb/integration/lightning/fabric/logger.py +1 -3
- wandb/integration/metaflow/metaflow.py +41 -6
- wandb/integration/openai/fine_tuning.py +3 -3
- wandb/keras/__init__.py +1 -0
- wandb/old/summary.py +1 -1
- wandb/plot/confusion_matrix.py +1 -1
- wandb/plots/precision_recall.py +1 -1
- wandb/plots/roc.py +1 -1
- wandb/proto/v3/wandb_internal_pb2.py +364 -332
- wandb/proto/v3/wandb_settings_pb2.py +1 -1
- wandb/proto/v4/wandb_internal_pb2.py +322 -316
- wandb/proto/v4/wandb_settings_pb2.py +1 -1
- wandb/proto/wandb_internal_codegen.py +0 -25
- wandb/sdk/artifacts/artifact.py +16 -4
- wandb/sdk/artifacts/artifact_download_logger.py +1 -0
- wandb/sdk/artifacts/artifact_file_cache.py +18 -4
- wandb/sdk/artifacts/artifact_instance_cache.py +1 -0
- wandb/sdk/artifacts/artifact_manifest.py +1 -0
- wandb/sdk/artifacts/artifact_manifest_entry.py +1 -0
- wandb/sdk/artifacts/artifact_manifests/artifact_manifest_v1.py +1 -0
- wandb/sdk/artifacts/artifact_saver.py +5 -2
- wandb/sdk/artifacts/artifact_state.py +1 -0
- wandb/sdk/artifacts/artifact_ttl.py +1 -0
- wandb/sdk/artifacts/exceptions.py +1 -0
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +13 -18
- wandb/sdk/artifacts/storage_handlers/http_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/local_file_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/multi_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +5 -3
- wandb/sdk/artifacts/storage_handlers/tracking_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +1 -0
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -0
- wandb/sdk/artifacts/storage_policies/wandb_storage_policy.py +1 -0
- wandb/sdk/artifacts/storage_policy.py +1 -0
- wandb/sdk/data_types/_dtypes.py +8 -8
- wandb/sdk/data_types/base_types/media.py +3 -6
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +3 -1
- wandb/sdk/data_types/image.py +1 -1
- wandb/sdk/data_types/video.py +1 -1
- wandb/sdk/integration_utils/auto_logging.py +5 -6
- wandb/sdk/integration_utils/data_logging.py +10 -6
- wandb/sdk/interface/interface.py +55 -32
- wandb/sdk/interface/interface_shared.py +7 -13
- wandb/sdk/internal/datastore.py +1 -1
- wandb/sdk/internal/handler.py +18 -2
- wandb/sdk/internal/internal.py +0 -1
- wandb/sdk/internal/internal_util.py +0 -1
- wandb/sdk/internal/job_builder.py +5 -4
- wandb/sdk/internal/profiler.py +1 -0
- wandb/sdk/internal/run.py +1 -0
- wandb/sdk/internal/sender.py +1 -1
- wandb/sdk/internal/system/assets/gpu_amd.py +44 -44
- wandb/sdk/internal/system/assets/gpu_apple.py +56 -11
- wandb/sdk/internal/system/assets/interfaces.py +6 -8
- wandb/sdk/internal/system/assets/open_metrics.py +2 -2
- wandb/sdk/internal/system/assets/trainium.py +1 -3
- wandb/sdk/launch/_project_spec.py +8 -4
- wandb/sdk/launch/agent/agent.py +2 -1
- wandb/sdk/launch/agent/config.py +72 -11
- wandb/sdk/launch/builder/abstract.py +2 -1
- wandb/sdk/launch/builder/build.py +29 -2
- wandb/sdk/launch/builder/docker_builder.py +1 -0
- wandb/sdk/launch/builder/kaniko_builder.py +2 -2
- wandb/sdk/launch/builder/noop.py +1 -0
- wandb/sdk/launch/create_job.py +18 -0
- wandb/sdk/launch/environment/abstract.py +1 -0
- wandb/sdk/launch/environment/gcp_environment.py +1 -0
- wandb/sdk/launch/environment/local_environment.py +1 -0
- wandb/sdk/launch/loader.py +1 -0
- wandb/sdk/launch/registry/abstract.py +1 -0
- wandb/sdk/launch/registry/azure_container_registry.py +1 -0
- wandb/sdk/launch/registry/elastic_container_registry.py +1 -0
- wandb/sdk/launch/registry/google_artifact_registry.py +2 -1
- wandb/sdk/launch/registry/local_registry.py +1 -0
- wandb/sdk/launch/runner/abstract.py +1 -0
- wandb/sdk/launch/runner/kubernetes_monitor.py +1 -0
- wandb/sdk/launch/runner/kubernetes_runner.py +4 -3
- wandb/sdk/launch/runner/sagemaker_runner.py +11 -10
- wandb/sdk/launch/sweeps/scheduler.py +4 -3
- wandb/sdk/launch/sweeps/scheduler_sweep.py +2 -1
- wandb/sdk/launch/sweeps/utils.py +3 -3
- wandb/sdk/launch/utils.py +3 -3
- wandb/sdk/lib/fsm.py +8 -12
- wandb/sdk/lib/gitlib.py +4 -4
- wandb/sdk/lib/import_hooks.py +1 -1
- wandb/sdk/lib/lazyloader.py +0 -1
- wandb/sdk/lib/proto_util.py +1 -1
- wandb/sdk/lib/redirect.py +19 -14
- wandb/sdk/lib/retry.py +3 -2
- wandb/sdk/lib/tracelog.py +1 -1
- wandb/sdk/service/service.py +17 -15
- wandb/sdk/verify/verify.py +2 -1
- wandb/sdk/wandb_manager.py +2 -2
- wandb/sdk/wandb_require.py +5 -0
- wandb/sdk/wandb_run.py +25 -20
- wandb/sdk/wandb_settings.py +0 -1
- wandb/sdk/wandb_setup.py +1 -1
- wandb/sklearn/__init__.py +1 -0
- wandb/sklearn/plot/__init__.py +1 -0
- wandb/sklearn/plot/classifier.py +7 -6
- wandb/sklearn/plot/clusterer.py +2 -1
- wandb/sklearn/plot/regressor.py +1 -0
- wandb/sklearn/plot/shared.py +1 -0
- wandb/sklearn/utils.py +1 -0
- wandb/testing/relay.py +4 -4
- wandb/trigger.py +1 -0
- wandb/util.py +40 -17
- wandb/wandb_controller.py +2 -3
- wandb/wandb_torch.py +1 -2
- {wandb-0.16.6.dist-info → wandb-0.17.0rc2.dist-info}/METADATA +68 -69
- {wandb-0.16.6.dist-info → wandb-0.17.0rc2.dist-info}/RECORD +149 -150
- {wandb-0.16.6.dist-info → wandb-0.17.0rc2.dist-info}/WHEEL +1 -2
- wandb/bin/apple_gpu_stats +0 -0
- wandb-0.16.6.dist-info/top_level.txt +0 -1
- {wandb-0.16.6.dist-info → wandb-0.17.0rc2.dist-info}/entry_points.txt +0 -0
- {wandb-0.16.6.dist-info → wandb-0.17.0rc2.dist-info/licenses}/LICENSE +0 -0
wandb/env.py
CHANGED
@@ -13,11 +13,10 @@ these values in many cases.
|
|
13
13
|
import json
|
14
14
|
import os
|
15
15
|
import sys
|
16
|
-
from distutils.util import strtobool
|
17
16
|
from pathlib import Path
|
18
17
|
from typing import List, MutableMapping, Optional, Union
|
19
18
|
|
20
|
-
import
|
19
|
+
import platformdirs # type: ignore
|
21
20
|
|
22
21
|
Env = Optional[MutableMapping]
|
23
22
|
|
@@ -61,6 +60,8 @@ SAVE_CODE = "WANDB_SAVE_CODE"
|
|
61
60
|
TAGS = "WANDB_TAGS"
|
62
61
|
IGNORE = "WANDB_IGNORE_GLOBS"
|
63
62
|
ERROR_REPORTING = "WANDB_ERROR_REPORTING"
|
63
|
+
CORE_ERROR_REPORTING = "WANDB_CORE_ERROR_REPORTING"
|
64
|
+
CORE_DEBUG = "WANDB_CORE_DEBUG"
|
64
65
|
DOCKER = "WANDB_DOCKER"
|
65
66
|
AGENT_REPORT_INTERVAL = "WANDB_AGENT_REPORT_INTERVAL"
|
66
67
|
AGENT_KILL_DELAY = "WANDB_AGENT_KILL_DELAY"
|
@@ -87,6 +88,7 @@ _EXECUTABLE = "WANDB_EXECUTABLE"
|
|
87
88
|
LAUNCH_QUEUE_NAME = "WANDB_LAUNCH_QUEUE_NAME"
|
88
89
|
LAUNCH_QUEUE_ENTITY = "WANDB_LAUNCH_QUEUE_ENTITY"
|
89
90
|
LAUNCH_TRACE_ID = "WANDB_LAUNCH_TRACE_ID"
|
91
|
+
_REQUIRE_CORE = "WANDB__REQUIRE_CORE"
|
90
92
|
|
91
93
|
# For testing, to be removed in future version
|
92
94
|
USE_V1_ARTIFACTS = "_WANDB_USE_V1_ARTIFACTS"
|
@@ -139,11 +141,16 @@ def _env_as_bool(
|
|
139
141
|
if env is None:
|
140
142
|
env = os.environ
|
141
143
|
val = env.get(var, default)
|
144
|
+
if not isinstance(val, str):
|
145
|
+
return False
|
142
146
|
try:
|
143
|
-
|
144
|
-
except
|
145
|
-
|
146
|
-
|
147
|
+
return strtobool(val)
|
148
|
+
except ValueError:
|
149
|
+
return False
|
150
|
+
|
151
|
+
|
152
|
+
def is_require_core(env: Optional[Env] = None) -> bool:
|
153
|
+
return _env_as_bool(_REQUIRE_CORE, default="False", env=env)
|
147
154
|
|
148
155
|
|
149
156
|
def is_debug(default: Optional[str] = None, env: Optional[Env] = None) -> bool:
|
@@ -154,6 +161,14 @@ def error_reporting_enabled() -> bool:
|
|
154
161
|
return _env_as_bool(ERROR_REPORTING, default="True")
|
155
162
|
|
156
163
|
|
164
|
+
def core_error_reporting_enabled(default: Optional[str] = None) -> bool:
|
165
|
+
return _env_as_bool(CORE_ERROR_REPORTING, default=default)
|
166
|
+
|
167
|
+
|
168
|
+
def core_debug(default: Optional[str] = None) -> bool:
|
169
|
+
return _env_as_bool(CORE_DEBUG, default=default)
|
170
|
+
|
171
|
+
|
157
172
|
def ssl_disabled() -> bool:
|
158
173
|
return _env_as_bool(DISABLE_SSL, default="False")
|
159
174
|
|
@@ -370,7 +385,7 @@ def get_magic(
|
|
370
385
|
|
371
386
|
|
372
387
|
def get_data_dir(env: Optional[Env] = None) -> str:
|
373
|
-
default_dir =
|
388
|
+
default_dir = platformdirs.user_data_dir("wandb")
|
374
389
|
if env is None:
|
375
390
|
env = os.environ
|
376
391
|
val = env.get(DATA_DIR, default_dir)
|
@@ -395,7 +410,7 @@ def get_artifact_fetch_file_url_batch_size(env: Optional[Env] = None) -> int:
|
|
395
410
|
|
396
411
|
def get_cache_dir(env: Optional[Env] = None) -> Path:
|
397
412
|
env = env or os.environ
|
398
|
-
return Path(env.get(CACHE_DIR,
|
413
|
+
return Path(env.get(CACHE_DIR, platformdirs.user_cache_dir("wandb")))
|
399
414
|
|
400
415
|
|
401
416
|
def get_use_v1_artifacts(env: Optional[Env] = None) -> bool:
|
@@ -464,3 +479,18 @@ def get_launch_trace_id(env: Optional[Env] = None) -> Optional[str]:
|
|
464
479
|
env = os.environ
|
465
480
|
val = env.get(LAUNCH_TRACE_ID, None)
|
466
481
|
return val
|
482
|
+
|
483
|
+
|
484
|
+
def strtobool(val: str) -> bool:
|
485
|
+
"""Convert a string representation of truth to true or false.
|
486
|
+
|
487
|
+
Copied from distutils. distutils was removed in Python 3.12.
|
488
|
+
"""
|
489
|
+
val = val.lower()
|
490
|
+
|
491
|
+
if val in ("y", "yes", "t", "true", "on", "1"):
|
492
|
+
return True
|
493
|
+
elif val in ("n", "no", "f", "false", "off", "0"):
|
494
|
+
return False
|
495
|
+
else:
|
496
|
+
raise ValueError(f"invalid truth value {val!r}")
|
wandb/errors/__init__.py
CHANGED
@@ -4,6 +4,7 @@ __all__ = [
|
|
4
4
|
"AuthenticationError",
|
5
5
|
"UsageError",
|
6
6
|
"UnsupportedError",
|
7
|
+
"WandbCoreNotAvailableError",
|
7
8
|
]
|
8
9
|
|
9
10
|
from typing import Optional
|
@@ -39,3 +40,7 @@ class UsageError(Error):
|
|
39
40
|
|
40
41
|
class UnsupportedError(UsageError):
|
41
42
|
"""Raised when trying to use a feature that is not supported."""
|
43
|
+
|
44
|
+
|
45
|
+
class WandbCoreNotAvailableError(Error):
|
46
|
+
"""Raised when wandb core is not available."""
|
@@ -81,7 +81,7 @@ def _checkpoint_artifact(
|
|
81
81
|
|
82
82
|
|
83
83
|
def _log_feature_importance(
|
84
|
-
model: Union[CatBoostClassifier, CatBoostRegressor]
|
84
|
+
model: Union[CatBoostClassifier, CatBoostRegressor],
|
85
85
|
) -> None:
|
86
86
|
"""Log feature importance with default settings."""
|
87
87
|
if wandb.run is None:
|
@@ -91,8 +91,8 @@ class HuggingFacePipelineRequestResponseResolver:
|
|
91
91
|
|
92
92
|
# TODO: This should have a dependency on PreTrainedModel. i.e. isinstance(PreTrainedModel)
|
93
93
|
# from transformers.modeling_utils import PreTrainedModel
|
94
|
-
# We do not want this dependency
|
95
|
-
# the structure of the pipeline which may have unintended consequences
|
94
|
+
# We do not want this dependency explicitly in our codebase so we make a very general
|
95
|
+
# assumption about the structure of the pipeline which may have unintended consequences
|
96
96
|
def _get_model(self, pipe) -> Optional[Any]:
|
97
97
|
"""Extracts model from the pipeline.
|
98
98
|
|
@@ -43,7 +43,7 @@ class WandbMetricsLogger(callbacks.Callback):
|
|
43
43
|
at the end of each epoch. If "batch", logs metrics at the end
|
44
44
|
of each batch. If an integer, logs metrics at the end of that
|
45
45
|
many batches. Defaults to "epoch".
|
46
|
-
initial_global_step: (int) Use this argument to
|
46
|
+
initial_global_step: (int) Use this argument to correctly log the
|
47
47
|
learning rate when you resume training from some `initial_epoch`,
|
48
48
|
and a learning rate scheduler is used. This can be computed as
|
49
49
|
`step_size * initial_step`. Defaults to 0.
|
wandb/integration/keras/keras.py
CHANGED
@@ -576,7 +576,7 @@ class WandbCallback(tf.keras.callbacks.Callback):
|
|
576
576
|
)
|
577
577
|
self._model_trained_since_last_eval = False
|
578
578
|
except Exception as e:
|
579
|
-
wandb.termwarn("Error
|
579
|
+
wandb.termwarn("Error during prediction logging for epoch: " + str(e))
|
580
580
|
|
581
581
|
def on_epoch_end(self, epoch, logs=None):
|
582
582
|
if logs is None:
|
@@ -616,9 +616,9 @@ class WandbCallback(tf.keras.callbacks.Callback):
|
|
616
616
|
self.current = logs.get(self.monitor)
|
617
617
|
if self.current and self.monitor_op(self.current, self.best):
|
618
618
|
if self.log_best_prefix:
|
619
|
-
wandb.run.summary[
|
620
|
-
|
621
|
-
|
619
|
+
wandb.run.summary[f"{self.log_best_prefix}{self.monitor}"] = (
|
620
|
+
self.current
|
621
|
+
)
|
622
622
|
wandb.run.summary["{}{}".format(self.log_best_prefix, "epoch")] = epoch
|
623
623
|
if self.verbose and not self.save_model:
|
624
624
|
print(
|
@@ -937,9 +937,9 @@ class WandbCallback(tf.keras.callbacks.Callback):
|
|
937
937
|
grads = self._grad_accumulator_callback.grads
|
938
938
|
metrics = {}
|
939
939
|
for weight, grad in zip(weights, grads):
|
940
|
-
metrics[
|
941
|
-
|
942
|
-
|
940
|
+
metrics["gradients/" + weight.name.split(":")[0] + ".gradient"] = (
|
941
|
+
wandb.Histogram(grad)
|
942
|
+
)
|
943
943
|
return metrics
|
944
944
|
|
945
945
|
def _log_dataframe(self):
|
@@ -14,6 +14,7 @@ integration will not break user code. The one exception to the rule is at import
|
|
14
14
|
LangChain is not installed, or the symbols are not in the same place, the appropriate error
|
15
15
|
will be raised when importing this module.
|
16
16
|
"""
|
17
|
+
|
17
18
|
from packaging import version
|
18
19
|
|
19
20
|
import wandb.util
|
@@ -401,9 +401,7 @@ class WandbLogger(Logger):
|
|
401
401
|
"*", step_metric="trainer/global_step", step_sync=True
|
402
402
|
)
|
403
403
|
|
404
|
-
self._experiment._label(
|
405
|
-
repo="lightning_fabric_logger"
|
406
|
-
) # pylint: disable=protected-access
|
404
|
+
self._experiment._label(repo="lightning_fabric_logger") # pylint: disable=protected-access
|
407
405
|
with telemetry.context(run=self._experiment) as tel:
|
408
406
|
tel.feature.lightning_fabric_logger = True
|
409
407
|
return self._experiment
|
@@ -36,7 +36,15 @@ try:
|
|
36
36
|
import pandas as pd
|
37
37
|
|
38
38
|
@typedispatch # noqa: F811
|
39
|
-
def _wandb_use(
|
39
|
+
def _wandb_use(
|
40
|
+
name: str,
|
41
|
+
data: pd.DataFrame,
|
42
|
+
datasets=False,
|
43
|
+
run=None,
|
44
|
+
testing=False,
|
45
|
+
*args,
|
46
|
+
**kwargs,
|
47
|
+
): # type: ignore
|
40
48
|
if testing:
|
41
49
|
return "datasets" if datasets else None
|
42
50
|
|
@@ -74,7 +82,15 @@ try:
|
|
74
82
|
import torch.nn as nn
|
75
83
|
|
76
84
|
@typedispatch # noqa: F811
|
77
|
-
def _wandb_use(
|
85
|
+
def _wandb_use(
|
86
|
+
name: str,
|
87
|
+
data: nn.Module,
|
88
|
+
models=False,
|
89
|
+
run=None,
|
90
|
+
testing=False,
|
91
|
+
*args,
|
92
|
+
**kwargs,
|
93
|
+
): # type: ignore
|
78
94
|
if testing:
|
79
95
|
return "models" if models else None
|
80
96
|
|
@@ -111,7 +127,15 @@ try:
|
|
111
127
|
from sklearn.base import BaseEstimator
|
112
128
|
|
113
129
|
@typedispatch # noqa: F811
|
114
|
-
def _wandb_use(
|
130
|
+
def _wandb_use(
|
131
|
+
name: str,
|
132
|
+
data: BaseEstimator,
|
133
|
+
models=False,
|
134
|
+
run=None,
|
135
|
+
testing=False,
|
136
|
+
*args,
|
137
|
+
**kwargs,
|
138
|
+
): # type: ignore
|
115
139
|
if testing:
|
116
140
|
return "models" if models else None
|
117
141
|
|
@@ -169,7 +193,14 @@ class ArtifactProxy:
|
|
169
193
|
|
170
194
|
|
171
195
|
@typedispatch # noqa: F811
|
172
|
-
def wandb_track(
|
196
|
+
def wandb_track(
|
197
|
+
name: str,
|
198
|
+
data: (dict, list, set, str, int, float, bool),
|
199
|
+
run=None,
|
200
|
+
testing=False,
|
201
|
+
*args,
|
202
|
+
**kwargs,
|
203
|
+
): # type: ignore
|
173
204
|
if testing:
|
174
205
|
return "scalar"
|
175
206
|
|
@@ -222,12 +253,16 @@ def wandb_use(name: str, data, *args, **kwargs):
|
|
222
253
|
|
223
254
|
|
224
255
|
@typedispatch # noqa: F811
|
225
|
-
def wandb_use(
|
256
|
+
def wandb_use(
|
257
|
+
name: str, data: (dict, list, set, str, int, float, bool), *args, **kwargs
|
258
|
+
): # type: ignore
|
226
259
|
pass # do nothing for these types
|
227
260
|
|
228
261
|
|
229
262
|
@typedispatch # noqa: F811
|
230
|
-
def _wandb_use(
|
263
|
+
def _wandb_use(
|
264
|
+
name: str, data: Path, datasets=False, run=None, testing=False, *args, **kwargs
|
265
|
+
): # type: ignore
|
231
266
|
if testing:
|
232
267
|
return "datasets" if datasets else None
|
233
268
|
|
@@ -310,9 +310,9 @@ class WandbLogger:
|
|
310
310
|
try:
|
311
311
|
hyperparams["n_epochs"] = hyperparameters.n_epochs
|
312
312
|
hyperparams["batch_size"] = hyperparameters.batch_size
|
313
|
-
hyperparams[
|
314
|
-
|
315
|
-
|
313
|
+
hyperparams["learning_rate_multiplier"] = (
|
314
|
+
hyperparameters.learning_rate_multiplier
|
315
|
+
)
|
316
316
|
except Exception:
|
317
317
|
# If unpacking fails, return the object to be logged as config
|
318
318
|
return None
|
wandb/keras/__init__.py
CHANGED
wandb/old/summary.py
CHANGED
@@ -61,7 +61,7 @@ class SummarySubDict:
|
|
61
61
|
This should only be implemented by the "_root" child class.
|
62
62
|
|
63
63
|
We pass the child_dict so the item can be set on it or not as
|
64
|
-
appropriate. Returning None for a
|
64
|
+
appropriate. Returning None for a nonexistent path wouldn't be
|
65
65
|
distinguishable from that path being set to the value None.
|
66
66
|
"""
|
67
67
|
raise NotImplementedError
|
wandb/plot/confusion_matrix.py
CHANGED
@@ -48,7 +48,7 @@ def confusion_matrix(
|
|
48
48
|
|
49
49
|
assert (probs is None or preds is None) and not (
|
50
50
|
probs is None and preds is None
|
51
|
-
), "Must provide
|
51
|
+
), "Must provide probabilities or predictions but not both to confusion matrix"
|
52
52
|
|
53
53
|
if probs is not None:
|
54
54
|
preds = np.argmax(probs, axis=1).tolist()
|
wandb/plots/precision_recall.py
CHANGED
@@ -25,7 +25,7 @@ def precision_recall(
|
|
25
25
|
Arguments:
|
26
26
|
y_true (arr): Test set labels.
|
27
27
|
y_probas (arr): Test set predicted probabilities.
|
28
|
-
labels (list): Named labels for target
|
28
|
+
labels (list): Named labels for target variable (y). Makes plots easier to
|
29
29
|
read by replacing target values with corresponding index.
|
30
30
|
For example labels= ['dog', 'cat', 'owl'] all 0s are
|
31
31
|
replaced by 'dog', 1s by 'cat'.
|
wandb/plots/roc.py
CHANGED
@@ -25,7 +25,7 @@ def roc(
|
|
25
25
|
Arguments:
|
26
26
|
y_true (arr): Test set labels.
|
27
27
|
y_probas (arr): Test set predicted probabilities.
|
28
|
-
labels (list): Named labels for target
|
28
|
+
labels (list): Named labels for target variable (y). Makes plots easier to
|
29
29
|
read by replacing target values with corresponding index.
|
30
30
|
For example labels= ['dog', 'cat', 'owl'] all 0s are
|
31
31
|
replaced by 'dog', 1s by 'cat'.
|