wandb 0.18.4__py3-none-any.whl → 0.18.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb/__init__.py +2 -2
- wandb/__init__.pyi +21 -19
- wandb/agents/pyagent.py +1 -1
- wandb/apis/importers/wandb.py +1 -1
- wandb/apis/normalize.py +2 -18
- wandb/apis/public/api.py +122 -62
- wandb/apis/public/artifacts.py +8 -3
- wandb/apis/public/files.py +17 -2
- wandb/apis/public/jobs.py +2 -2
- wandb/apis/public/query_generator.py +1 -1
- wandb/apis/public/runs.py +8 -8
- wandb/apis/public/teams.py +3 -3
- wandb/apis/public/users.py +1 -1
- wandb/apis/public/utils.py +68 -0
- wandb/bin/gpu_stats +0 -0
- wandb/cli/cli.py +12 -3
- wandb/data_types.py +1 -1
- wandb/docker/__init__.py +2 -1
- wandb/docker/auth.py +2 -3
- wandb/errors/links.py +73 -0
- wandb/errors/term.py +7 -6
- wandb/filesync/step_prepare.py +1 -1
- wandb/filesync/upload_job.py +1 -1
- wandb/integration/catboost/catboost.py +2 -2
- wandb/integration/diffusers/pipeline_resolver.py +1 -1
- wandb/integration/diffusers/resolvers/multimodal.py +6 -6
- wandb/integration/diffusers/resolvers/utils.py +1 -1
- wandb/integration/fastai/__init__.py +3 -2
- wandb/integration/keras/callbacks/metrics_logger.py +1 -1
- wandb/integration/keras/callbacks/model_checkpoint.py +1 -1
- wandb/integration/keras/keras.py +1 -1
- wandb/integration/kfp/kfp_patch.py +1 -1
- wandb/integration/lightgbm/__init__.py +2 -2
- wandb/integration/magic.py +2 -2
- wandb/integration/metaflow/metaflow.py +1 -1
- wandb/integration/sacred/__init__.py +1 -1
- wandb/integration/sagemaker/auth.py +1 -1
- wandb/integration/sklearn/plot/classifier.py +7 -7
- wandb/integration/sklearn/plot/clusterer.py +3 -3
- wandb/integration/sklearn/plot/regressor.py +3 -3
- wandb/integration/sklearn/plot/shared.py +2 -2
- wandb/integration/tensorboard/log.py +2 -2
- wandb/integration/ultralytics/callback.py +2 -2
- wandb/integration/xgboost/xgboost.py +1 -1
- wandb/jupyter.py +0 -1
- wandb/plot/__init__.py +17 -8
- wandb/plot/bar.py +53 -27
- wandb/plot/confusion_matrix.py +151 -70
- wandb/plot/custom_chart.py +124 -0
- wandb/plot/histogram.py +46 -20
- wandb/plot/line.py +57 -26
- wandb/plot/line_series.py +148 -60
- wandb/plot/pr_curve.py +89 -44
- wandb/plot/roc_curve.py +82 -37
- wandb/plot/scatter.py +53 -20
- wandb/plot/viz.py +20 -102
- wandb/sdk/artifacts/artifact.py +280 -328
- wandb/sdk/artifacts/artifact_manifest.py +10 -9
- wandb/sdk/artifacts/artifact_manifest_entry.py +1 -1
- wandb/sdk/artifacts/storage_handlers/azure_handler.py +9 -4
- wandb/sdk/artifacts/storage_handlers/gcs_handler.py +1 -3
- wandb/sdk/artifacts/storage_handlers/s3_handler.py +1 -1
- wandb/sdk/artifacts/storage_handlers/wb_artifact_handler.py +2 -2
- wandb/sdk/artifacts/storage_handlers/wb_local_artifact_handler.py +1 -1
- wandb/sdk/backend/backend.py +0 -1
- wandb/sdk/data_types/audio.py +1 -1
- wandb/sdk/data_types/base_types/media.py +66 -5
- wandb/sdk/data_types/bokeh.py +1 -1
- wandb/sdk/data_types/helper_types/bounding_boxes_2d.py +1 -1
- wandb/sdk/data_types/helper_types/image_mask.py +2 -2
- wandb/sdk/data_types/histogram.py +1 -1
- wandb/sdk/data_types/html.py +1 -1
- wandb/sdk/data_types/image.py +1 -1
- wandb/sdk/data_types/molecule.py +3 -3
- wandb/sdk/data_types/object_3d.py +4 -4
- wandb/sdk/data_types/plotly.py +1 -1
- wandb/sdk/data_types/saved_model.py +0 -1
- wandb/sdk/data_types/table.py +7 -7
- wandb/sdk/data_types/trace_tree.py +1 -1
- wandb/sdk/data_types/video.py +4 -3
- wandb/sdk/interface/router.py +0 -2
- wandb/sdk/internal/datastore.py +1 -1
- wandb/sdk/internal/file_pusher.py +1 -1
- wandb/sdk/internal/file_stream.py +4 -4
- wandb/sdk/internal/handler.py +3 -2
- wandb/sdk/internal/internal.py +1 -1
- wandb/sdk/internal/internal_api.py +183 -64
- wandb/sdk/internal/job_builder.py +4 -3
- wandb/sdk/internal/system/assets/__init__.py +0 -2
- wandb/sdk/internal/tb_watcher.py +11 -10
- wandb/sdk/launch/_launch.py +4 -3
- wandb/sdk/launch/_launch_add.py +2 -2
- wandb/sdk/launch/builder/kaniko_builder.py +0 -1
- wandb/sdk/launch/create_job.py +1 -0
- wandb/sdk/launch/environment/local_environment.py +0 -1
- wandb/sdk/launch/errors.py +0 -6
- wandb/sdk/launch/registry/local_registry.py +0 -2
- wandb/sdk/launch/runner/abstract.py +0 -5
- wandb/sdk/launch/sweeps/__init__.py +0 -2
- wandb/sdk/launch/sweeps/scheduler.py +0 -2
- wandb/sdk/launch/sweeps/scheduler_sweep.py +0 -1
- wandb/sdk/lib/apikey.py +3 -3
- wandb/sdk/lib/file_stream_utils.py +1 -1
- wandb/sdk/lib/filesystem.py +1 -1
- wandb/sdk/lib/ipython.py +16 -9
- wandb/sdk/lib/mailbox.py +0 -4
- wandb/sdk/lib/printer.py +44 -8
- wandb/sdk/lib/retry.py +1 -1
- wandb/sdk/service/service.py +3 -3
- wandb/sdk/service/streams.py +2 -4
- wandb/sdk/wandb_init.py +20 -20
- wandb/sdk/wandb_login.py +1 -1
- wandb/sdk/wandb_require.py +1 -4
- wandb/sdk/wandb_run.py +57 -69
- wandb/sdk/wandb_settings.py +3 -4
- wandb/sdk/wandb_sync.py +2 -1
- wandb/util.py +46 -18
- wandb/wandb_agent.py +3 -3
- wandb/wandb_controller.py +2 -2
- {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/METADATA +1 -1
- {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/RECORD +124 -125
- wandb/sdk/internal/system/assets/gpu_apple.py +0 -177
- wandb/sdk/lib/_wburls_generate.py +0 -25
- wandb/sdk/lib/_wburls_generated.py +0 -22
- wandb/sdk/lib/wburls.py +0 -46
- {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/WHEEL +0 -0
- {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/entry_points.txt +0 -0
- {wandb-0.18.4.dist-info → wandb-0.18.6.dist-info}/licenses/LICENSE +0 -0
wandb/apis/public/teams.py
CHANGED
@@ -127,7 +127,7 @@ class Team(Attrs):
|
|
127
127
|
def create(cls, api, team, admin_username=None):
|
128
128
|
"""Create a new team.
|
129
129
|
|
130
|
-
|
130
|
+
Args:
|
131
131
|
api: (`Api`) The api instance to use
|
132
132
|
team: (str) The name of the team
|
133
133
|
admin_username: (str) optional username of the admin user of the team, defaults to the current user.
|
@@ -147,7 +147,7 @@ class Team(Attrs):
|
|
147
147
|
def invite(self, username_or_email, admin=False):
|
148
148
|
"""Invite a user to a team.
|
149
149
|
|
150
|
-
|
150
|
+
Args:
|
151
151
|
username_or_email: (str) The username or email address of the user you want to invite
|
152
152
|
admin: (bool) Whether to make this user a team admin, defaults to False
|
153
153
|
|
@@ -168,7 +168,7 @@ class Team(Attrs):
|
|
168
168
|
def create_service_account(self, description):
|
169
169
|
"""Create a service account for the team.
|
170
170
|
|
171
|
-
|
171
|
+
Args:
|
172
172
|
description: (str) A description for this service account
|
173
173
|
|
174
174
|
Returns:
|
wandb/apis/public/users.py
CHANGED
@@ -62,7 +62,7 @@ class User(Attrs):
|
|
62
62
|
def create(cls, api, email, admin=False):
|
63
63
|
"""Create a new user.
|
64
64
|
|
65
|
-
|
65
|
+
Args:
|
66
66
|
api: (`Api`) The api instance to use
|
67
67
|
email: (str) The email address of the user
|
68
68
|
admin: (bool) Whether this user should be a global instance admin
|
@@ -0,0 +1,68 @@
|
|
1
|
+
import re
|
2
|
+
from enum import Enum
|
3
|
+
from urllib.parse import urlparse
|
4
|
+
|
5
|
+
from wandb.sdk.artifacts._validators import is_artifact_registry_project
|
6
|
+
|
7
|
+
|
8
|
+
# Hostname of an S3 endpoint. The optional leading ``bucket`` group is present
# for virtual-hosted-style URLs (``<bucket>.s3[.<region>|-<region>].amazonaws.com``)
# and absent for path-style URLs (``s3[.<region>].amazonaws.com/<bucket>/<key>``).
# Bucket names may contain dots, so the bucket is everything before the ``s3``
# label rather than just the first label of the host.
_S3_HOST_PATTERN = re.compile(
    r"^(?:(?P<bucket>.+?)\.)?s3(?:[.-][a-z0-9.-]*)?\.amazonaws\.com(?:\.cn)?$"
)


def parse_s3_url_to_s3_uri(url: str) -> str:
    """Convert an S3 HTTP(S) URL to an S3 URI.

    Both virtual-hosted-style and path-style URLs are supported. Any query
    string (e.g. presigned-URL signature parameters) is dropped.

    Args:
        url (str): The S3 URL to convert, in one of the formats
            'http(s)://<bucket>.s3.<region>.amazonaws.com/<key>',
            'http(s)://<bucket>.s3.amazonaws.com/<key>', or
            'http(s)://s3[.<region>].amazonaws.com/<bucket>/<key>'.

    Returns:
        str: The corresponding S3 URI in the format 's3://<bucket>/<key>'.

    Raises:
        ValueError: If the provided URL is not a valid S3 URL.
    """
    parsed_url = urlparse(url)

    # Validate the scheme and the *host* only. Matching against the whole URL
    # would accept non-S3 hosts that merely mention S3 in their path or query.
    host = parsed_url.hostname or ""  # hostname drops any port / credentials
    match = (
        _S3_HOST_PATTERN.match(host)
        if parsed_url.scheme in ("http", "https")
        else None
    )
    if not match:
        raise ValueError("Invalid S3 URL")

    path = parsed_url.path.lstrip("/")
    bucket_name = match.group("bucket")
    if bucket_name:
        # Virtual-hosted style: the bucket lives in the hostname.
        key = path
    else:
        # Path-style: the first path segment is the bucket.
        bucket_name, _, key = path.partition("/")
        if not bucket_name:
            raise ValueError("Invalid S3 URL")

    return f"s3://{bucket_name}/{key}"
|
39
|
+
|
40
|
+
|
41
|
+
class PathType(Enum):
    """Format of a user-supplied path string.

    Users pass many different path shapes when fetching artifacts, projects,
    and similar objects. This enum records which shape a given string path is
    expected to follow.
    """

    # A project path, e.g. ``<entity>/<project>`` or ``<project>``.
    PROJECT = "PROJECT"
    # An artifact path, e.g. ``<entity>/<project>/<artifact>``.
    ARTIFACT = "ARTIFACT"
|
49
|
+
|
50
|
+
|
51
|
+
def parse_org_from_registry_path(path: str, path_type: PathType) -> str:
    """Return the org encoded in a registry path, or "" if there is none.

    For registries the leading "entity" segment of a path actually names the
    organization, so this extracts that segment when the path points into a
    registry project.

    Args:
        path (str): The path to parse. Can be a project path
            (``<entity>/<project>`` or ``<project>``) or an artifact path
            (``<entity>/<project>/<artifact>``, ``<project>/<artifact>`` or
            ``<artifact>``).
        path_type (PathType): Which of the two formats ``path`` is in.

    Returns:
        str: The first path segment when the path has at least 2 segments
        (project paths) or 3 segments (artifact paths) and its second segment
        is a registry project according to ``is_artifact_registry_project``;
        otherwise an empty string.
    """
    segments = path.split("/")
    min_segments = 3 if path_type == PathType.ARTIFACT else 2

    # Too few segments means no leading org/entity segment was supplied.
    if len(segments) < min_segments:
        return ""

    org, project = segments[0], segments[1]
    return org if is_artifact_registry_project(project) else ""
|
wandb/bin/gpu_stats
CHANGED
Binary file
|
wandb/cli/cli.py
CHANGED
@@ -30,7 +30,9 @@ import wandb.sdk.verify.verify as wandb_verify
|
|
30
30
|
from wandb import Config, Error, env, util, wandb_agent, wandb_sdk
|
31
31
|
from wandb.apis import InternalApi, PublicApi
|
32
32
|
from wandb.apis.public import RunQueue
|
33
|
+
from wandb.errors.links import url_registry
|
33
34
|
from wandb.integration.magic import magic_install
|
35
|
+
from wandb.sdk.artifacts._validators import is_artifact_registry_project
|
34
36
|
from wandb.sdk.artifacts.artifact_file_cache import get_artifact_file_cache
|
35
37
|
from wandb.sdk.launch import utils as launch_utils
|
36
38
|
from wandb.sdk.launch._launch_add import _launch_add
|
@@ -38,7 +40,6 @@ from wandb.sdk.launch.errors import ExecutionError, LaunchError
|
|
38
40
|
from wandb.sdk.launch.sweeps import utils as sweep_utils
|
39
41
|
from wandb.sdk.launch.sweeps.scheduler import Scheduler
|
40
42
|
from wandb.sdk.lib import filesystem
|
41
|
-
from wandb.sdk.lib.wburls import wburls
|
42
43
|
from wandb.sync import SyncManager, get_run_from_path, get_runs
|
43
44
|
|
44
45
|
from .beta import beta
|
@@ -1172,7 +1173,7 @@ def launch_sweep(
|
|
1172
1173
|
wandb.termlog(f"Scheduler added to launch queue ({queue})")
|
1173
1174
|
|
1174
1175
|
|
1175
|
-
@cli.command(help=f"Launch or queue a W&B Job. See {
|
1176
|
+
@cli.command(help=f"Launch or queue a W&B Job. See {url_registry.url('wandb-launch')}")
|
1176
1177
|
@click.option(
|
1177
1178
|
"--uri",
|
1178
1179
|
"-u",
|
@@ -1566,7 +1567,6 @@ def launch(
|
|
1566
1567
|
"queues",
|
1567
1568
|
default=None,
|
1568
1569
|
multiple=True,
|
1569
|
-
metavar="<queue(s)>",
|
1570
1570
|
help="The name of a queue for the agent to watch. Multiple -q flags supported.",
|
1571
1571
|
)
|
1572
1572
|
@click.option(
|
@@ -2391,6 +2391,15 @@ def get(path, root, type):
|
|
2391
2391
|
artifact_name = artifact_parts[0]
|
2392
2392
|
else:
|
2393
2393
|
version = "latest"
|
2394
|
+
if is_artifact_registry_project(project):
|
2395
|
+
organization = path.split("/")[0] if path.count("/") == 2 else ""
|
2396
|
+
# set entity to match the settings since in above code it was potentially set to an org
|
2397
|
+
settings_entity = public_api.settings["entity"] or public_api.default_entity
|
2398
|
+
# Registry artifacts are under the org entity. Because we offer a shorthand and alias for this path,
|
2399
|
+
# we need to fetch the org entity for the user behind the scenes.
|
2400
|
+
entity = InternalApi()._resolve_org_entity_name(
|
2401
|
+
entity=settings_entity, organization=organization
|
2402
|
+
)
|
2394
2403
|
full_path = f"{entity}/{project}/{artifact_name}:{version}"
|
2395
2404
|
wandb.termlog(
|
2396
2405
|
"Downloading {type} artifact {full_path}".format(
|
wandb/data_types.py
CHANGED
@@ -6,7 +6,7 @@ flexible containers for information, like tables and HTML, and more.
|
|
6
6
|
For more on logging media, see [our guide](https://docs.wandb.com/guides/track/log/media)
|
7
7
|
|
8
8
|
For more on logging structured data for interactive dataset and model analysis,
|
9
|
-
see [our guide to W&B Tables](https://docs.wandb.com/guides/
|
9
|
+
see [our guide to W&B Tables](https://docs.wandb.com/guides/tables/).
|
10
10
|
|
11
11
|
All of these special data types are subclasses of WBValue. All the data types
|
12
12
|
serialize to JSON, since that is what wandb uses to save the objects locally
|
wandb/docker/__init__.py
CHANGED
@@ -87,7 +87,8 @@ def is_docker_installed() -> bool:
|
|
87
87
|
try:
|
88
88
|
# Run the docker --version command
|
89
89
|
result = subprocess.run(
|
90
|
-
["docker", "--version"],
|
90
|
+
["docker", "--version"],
|
91
|
+
capture_output=True,
|
91
92
|
)
|
92
93
|
if result.returncode == 0:
|
93
94
|
return True
|
wandb/docker/auth.py
CHANGED
@@ -151,7 +151,7 @@ class AuthConfig(dict):
|
|
151
151
|
) -> Dict[str, Dict[str, Any]]:
|
152
152
|
"""Parse authentication entries.
|
153
153
|
|
154
|
-
|
154
|
+
Args:
|
155
155
|
entries: Dict of authentication entries.
|
156
156
|
raise_on_error: If set to true, an invalid format will raise
|
157
157
|
InvalidConfigFileError
|
@@ -386,7 +386,7 @@ def parse_auth(
|
|
386
386
|
) -> Dict[str, Dict[str, Any]]:
|
387
387
|
"""Parse authentication entries.
|
388
388
|
|
389
|
-
|
389
|
+
Args:
|
390
390
|
entries: Dict of authentication entries.
|
391
391
|
raise_on_error: If set to true, an invalid format will raise
|
392
392
|
InvalidConfigFileError
|
@@ -430,7 +430,6 @@ def _load_legacy_config(
|
|
430
430
|
}
|
431
431
|
except Exception as e:
|
432
432
|
log.debug(e)
|
433
|
-
pass
|
434
433
|
|
435
434
|
log.debug("All parsing attempts failed - returning empty config")
|
436
435
|
return {}
|
wandb/errors/links.py
ADDED
@@ -0,0 +1,73 @@
|
|
1
|
+
"""Module containing the Registry class and WBURL dataclass.
|
2
|
+
|
3
|
+
Used to store predefined URLs that can be associated with a name. The URLs are
|
4
|
+
shortened using the `wandb.me` domain, with dub.co as the shortening service.
|
5
|
+
If the URLs need to be updated, use the dub.co service to point them to the new URL.
|
6
|
+
"""
|
7
|
+
|
8
|
+
from __future__ import annotations
|
9
|
+
|
10
|
+
from dataclasses import dataclass
|
11
|
+
|
12
|
+
|
13
|
+
@dataclass
class WBURL:
    """A single registered link: its target URL and a note on what it points to."""

    url: str
    description: str


class Registry:
    """Lookup table mapping short, stable names to predefined W&B links."""

    def __init__(self):
        # Every link goes through the wandb.me shortener, so the destination
        # can be retargeted without touching this table.
        self.urls: dict[str, WBURL] = {
            "wandb-launch": WBURL(
                "https://wandb.me/launch",
                "Link to the W&B launch marketing page",
            ),
            "wandb-init": WBURL(
                "https://wandb.me/wandb-init",
                "Link to the wandb.init reference documentation page",
            ),
            "define-metric": WBURL(
                "https://wandb.me/define-metric",
                "Link to the W&B developer guide documentation page on wandb.define_metric",
            ),
            "developer-guide": WBURL(
                "https://wandb.me/developer-guide",
                "Link to the W&B developer guide top level page",
            ),
            "wandb-core": WBURL(
                "https://wandb.me/wandb-core",
                "Link to the documentation for the wandb-core service",
            ),
            "wandb-server": WBURL(
                "https://wandb.me/wandb-server",
                "Link to the documentation for the self-hosted W&B server",
            ),
            "multiprocess": WBURL(
                "https://wandb.me/multiprocess",
                "Link to the W&B developer guide documentation page"
                " on how to use wandb in a multiprocess environment",
            ),
        }

    def url(self, name: str) -> str:
        """Return the URL registered under ``name``.

        Raises:
            ValueError: If no link is registered under ``name``.
        """
        entry = self.urls.get(name)
        if not entry:
            raise ValueError(f"URL not found for {name}")
        return entry.url

    def description(self, name: str) -> str:
        """Return the human-readable description registered under ``name``.

        Raises:
            ValueError: If no link is registered under ``name``.
        """
        entry = self.urls.get(name)
        if not entry:
            raise ValueError(f"Description not found for {name}")
        return entry.description


# Shared module-level Registry instance used to look up links by name.
url_registry = Registry()
|
wandb/errors/term.py
CHANGED
@@ -136,12 +136,13 @@ def dynamic_text() -> Iterator[DynamicBlock | None]:
|
|
136
136
|
with _dynamic_text_lock:
|
137
137
|
_dynamic_blocks.append(block)
|
138
138
|
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
|
139
|
+
try:
|
140
|
+
yield block
|
141
|
+
finally:
|
142
|
+
with _dynamic_text_lock:
|
143
|
+
block._lines_to_print = []
|
144
|
+
_l_rerender_dynamic_blocks()
|
145
|
+
_dynamic_blocks.remove(block)
|
145
146
|
|
146
147
|
|
147
148
|
def _sys_stderr_isatty() -> bool:
|
wandb/filesync/step_prepare.py
CHANGED
@@ -149,7 +149,7 @@ class StepPrepare:
|
|
149
149
|
) -> Mapping[str, "CreateArtifactFilesResponseFile"]:
|
150
150
|
"""Execute the prepareFiles API call.
|
151
151
|
|
152
|
-
|
152
|
+
Args:
|
153
153
|
batch: List of RequestPrepare objects
|
154
154
|
Returns:
|
155
155
|
dict of (save_name: ResponseFile) pairs where ResponseFile is a dict with
|
wandb/filesync/upload_job.py
CHANGED
@@ -13,7 +13,7 @@ from wandb.sdk.lib import telemetry as wb_telemetry
|
|
13
13
|
class WandbCallback:
|
14
14
|
"""`WandbCallback` automatically integrates CatBoost with wandb.
|
15
15
|
|
16
|
-
|
16
|
+
Args:
|
17
17
|
- metric_period: (int) if you are passing `metric_period` to your CatBoost model please pass the same value here (default=1).
|
18
18
|
|
19
19
|
Passing `WandbCallback` to CatBoost will:
|
@@ -115,7 +115,7 @@ def log_summary(
|
|
115
115
|
) -> None:
|
116
116
|
"""`log_summary` logs useful metrics about catboost model after training is done.
|
117
117
|
|
118
|
-
|
118
|
+
Args:
|
119
119
|
model: it can be CatBoostClassifier or CatBoostRegressor.
|
120
120
|
log_all_params: (boolean) if True (default) log the model hyperparameters as W&B config.
|
121
121
|
save_model_checkpoint: (boolean) if True saves the model upload as W&B artifacts.
|
@@ -28,7 +28,7 @@ class DiffusersPipelineResolver:
|
|
28
28
|
) -> Any:
|
29
29
|
"""Main call method for the `DiffusersPipelineResolver` class.
|
30
30
|
|
31
|
-
|
31
|
+
Args:
|
32
32
|
args: (Sequence[Any]) List of arguments.
|
33
33
|
kwargs: (Dict[str, Any]) Dictionary of keyword arguments.
|
34
34
|
response: (wandb.sdk.integration_utils.auto_logging.Response) The response from
|
@@ -597,7 +597,7 @@ class DiffusersMultiModalPipelineResolver:
|
|
597
597
|
`__call__` for `wandb.integration.diffusers.pipeline_resolver.DiffusersPipelineResolver`.
|
598
598
|
This is based on `wandb.sdk.integration_utils.auto_logging.RequestResponseResolver`.
|
599
599
|
|
600
|
-
|
600
|
+
Args:
|
601
601
|
pipeline_name: (str) The name of the Diffusion Pipeline.
|
602
602
|
"""
|
603
603
|
|
@@ -621,7 +621,7 @@ class DiffusersMultiModalPipelineResolver:
|
|
621
621
|
) -> Any:
|
622
622
|
"""Main call method for the `DiffusersPipelineResolver` class.
|
623
623
|
|
624
|
-
|
624
|
+
Args:
|
625
625
|
args: (Sequence[Any]) List of arguments.
|
626
626
|
kwargs: (Dict[str, Any]) Dictionary of keyword arguments.
|
627
627
|
response: (wandb.sdk.integration_utils.auto_logging.Response) The response from
|
@@ -678,7 +678,7 @@ class DiffusersMultiModalPipelineResolver:
|
|
678
678
|
def get_output_images(self, response: Response) -> List:
|
679
679
|
"""Unpack the generated images, audio, video, etc. from the Diffusion Pipeline's response.
|
680
680
|
|
681
|
-
|
681
|
+
Args:
|
682
682
|
response: (wandb.sdk.integration_utils.auto_logging.Response) The response from
|
683
683
|
the request.
|
684
684
|
|
@@ -704,7 +704,7 @@ class DiffusersMultiModalPipelineResolver:
|
|
704
704
|
def log_media(self, image: Any, loggable_kwarg_chunks: List, idx: int) -> None:
|
705
705
|
"""Log the generated images, audio, video, etc. from the Diffusion Pipeline's response along with an optional caption to a media panel in the run.
|
706
706
|
|
707
|
-
|
707
|
+
Args:
|
708
708
|
image: (Any) The generated images, audio, video, etc. from the Diffusion
|
709
709
|
Pipeline's response.
|
710
710
|
loggable_kwarg_chunks: (List) Loggable chunks of kwargs.
|
@@ -780,7 +780,7 @@ class DiffusersMultiModalPipelineResolver:
|
|
780
780
|
) -> None:
|
781
781
|
"""Populate the row of the `wandb.Table`.
|
782
782
|
|
783
|
-
|
783
|
+
Args:
|
784
784
|
image: (Any) The generated images, audio, video, etc. from the Diffusion
|
785
785
|
Pipeline's response.
|
786
786
|
loggable_kwarg_chunks: (List) Loggable chunks of kwargs.
|
@@ -819,7 +819,7 @@ class DiffusersMultiModalPipelineResolver:
|
|
819
819
|
) -> Dict[str, Any]:
|
820
820
|
"""Prepare the loggable dictionary, which is the packed data as a dictionary for logging to wandb, None if an exception occurred.
|
821
821
|
|
822
|
-
|
822
|
+
Args:
|
823
823
|
pipeline: (Any) The Diffusion Pipeline.
|
824
824
|
response: (wandb.sdk.integration_utils.auto_logging.Response) The response from
|
825
825
|
the request.
|
@@ -69,7 +69,7 @@ def postprocess_np_arrays_for_video(
|
|
69
69
|
def decode_sdxl_t2i_latents(pipeline: Any, latents: "torch_float_tensor") -> List:
|
70
70
|
"""Decode latents generated by [`diffusers.StableDiffusionXLPipeline`](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion/stable_diffusion_xl#stable-diffusion-xl).
|
71
71
|
|
72
|
-
|
72
|
+
Args:
|
73
73
|
pipeline: (diffusers.DiffusionPipeline) The Diffusion Pipeline from
|
74
74
|
[`diffusers`](https://huggingface.co/docs/diffusers).
|
75
75
|
latents (torch.FloatTensor): The generated latents.
|
@@ -50,11 +50,12 @@ import fastai
|
|
50
50
|
from fastai.callbacks import TrackerCallback
|
51
51
|
|
52
52
|
import wandb
|
53
|
+
from wandb.sdk.lib import ipython
|
53
54
|
|
54
55
|
try:
|
55
56
|
import matplotlib
|
56
57
|
|
57
|
-
if
|
58
|
+
if not ipython.in_jupyter():
|
58
59
|
matplotlib.use("Agg") # non-interactive backend (avoid tkinter issues)
|
59
60
|
import matplotlib.pyplot as plt
|
60
61
|
except ImportError:
|
@@ -66,7 +67,7 @@ class WandbCallback(TrackerCallback):
|
|
66
67
|
|
67
68
|
Optionally logs weights, gradients, sample predictions and best trained model.
|
68
69
|
|
69
|
-
|
70
|
+
Args:
|
70
71
|
learn (fastai.basic_train.Learner): the fast.ai learner to hook.
|
71
72
|
log (str): "gradients", "parameters", "all", or None. Losses & metrics are always logged.
|
72
73
|
save_model (bool): save model at the end of each epoch. It will also load best model at the end of training.
|
@@ -38,7 +38,7 @@ class WandbMetricsLogger(callbacks.Callback):
|
|
38
38
|
`step_size` is number of training steps per epoch. `step_size` can be calculated as
|
39
39
|
the cardinality of the training dataset divided by the batch size.
|
40
40
|
|
41
|
-
|
41
|
+
Args:
|
42
42
|
log_freq: ("epoch", "batch", or int) if "epoch", logs metrics
|
43
43
|
at the end of each epoch. If "batch", logs metrics at the end
|
44
44
|
of each batch. If an integer, logs metrics at the end of that
|
@@ -45,7 +45,7 @@ class WandbModelCheckpoint(callbacks.ModelCheckpoint):
|
|
45
45
|
- Save only model weights, or save the whole model.
|
46
46
|
- Save the model either in SavedModel format or in `.h5` format.
|
47
47
|
|
48
|
-
|
48
|
+
Args:
|
49
49
|
filepath: (Union[str, os.PathLike]) path to save the model file. `filepath`
|
50
50
|
can contain named formatting options, which will be filled by the value
|
51
51
|
of `epoch` and keys in `logs` (passed in `on_epoch_end`). For example:
|
wandb/integration/keras/keras.py
CHANGED
@@ -313,7 +313,7 @@ class WandbCallback(tf.keras.callbacks.Callback):
|
|
313
313
|
|
314
314
|
`WandbCallback` can optionally save training and validation data for wandb to visualize.
|
315
315
|
|
316
|
-
|
316
|
+
Args:
|
317
317
|
monitor: (str) name of metric to monitor. Defaults to `val_loss`.
|
318
318
|
mode: (str) one of {`auto`, `min`, `max`}.
|
319
319
|
`min` - save model when monitor is minimized
|
@@ -163,7 +163,7 @@ def _get_function_source_definition(func: Callable) -> str:
|
|
163
163
|
|
164
164
|
# For wandb, allow decorators (so we can use the @wandb_log decorator)
|
165
165
|
func_code_lines = itertools.dropwhile(
|
166
|
-
lambda x: not (x.startswith("def"
|
166
|
+
lambda x: not (x.startswith(("def", "@wandb_log"))),
|
167
167
|
func_code_lines,
|
168
168
|
)
|
169
169
|
|
@@ -155,7 +155,7 @@ class _WandbCallback:
|
|
155
155
|
def wandb_callback(log_params: bool = True, define_metric: bool = True) -> Callable:
|
156
156
|
"""Automatically integrates LightGBM with wandb.
|
157
157
|
|
158
|
-
|
158
|
+
Args:
|
159
159
|
log_params: (boolean) if True (default) logs params passed to lightgbm.train as W&B config
|
160
160
|
define_metric: (boolean) if True (default) capture model performance at the best step, instead of the last step, of training in your `wandb.summary`
|
161
161
|
|
@@ -190,7 +190,7 @@ def log_summary(
|
|
190
190
|
) -> None:
|
191
191
|
"""Log useful metrics about lightgbm model after training is done.
|
192
192
|
|
193
|
-
|
193
|
+
Args:
|
194
194
|
model: (Booster) is an instance of lightgbm.basic.Booster.
|
195
195
|
feature_importance: (boolean) if True (default), logs the feature importance plot.
|
196
196
|
save_model_checkpoint: (boolean) if True saves the best model and upload as W&B artifacts.
|
wandb/integration/magic.py
CHANGED
@@ -9,6 +9,7 @@ import yaml
|
|
9
9
|
|
10
10
|
import wandb
|
11
11
|
from wandb import trigger
|
12
|
+
from wandb.sdk.lib import ipython
|
12
13
|
from wandb.util import add_import_hook, get_optional_module
|
13
14
|
|
14
15
|
_import_hook = None
|
@@ -505,8 +506,7 @@ def magic_install(init_args=None):
|
|
505
506
|
# process system args
|
506
507
|
_process_system_args()
|
507
508
|
# install argparse wrapper
|
508
|
-
|
509
|
-
if not in_jupyter_or_ipython:
|
509
|
+
if not ipython.in_notebook():
|
510
510
|
_monkey_argparse()
|
511
511
|
|
512
512
|
# track init calls
|
@@ -300,7 +300,7 @@ def wandb_log(
|
|
300
300
|
- Decorating the flow is equivalent to decorating all steps with a default
|
301
301
|
- Decorating a step after decorating the flow will overwrite the flow decoration
|
302
302
|
|
303
|
-
|
303
|
+
Args:
|
304
304
|
func: (`Callable`). The method or class being decorated (if decorating a step or flow respectively).
|
305
305
|
datasets: (`bool`). If `True`, log datasets. Datasets can be a `pd.DataFrame` or `pathlib.Path`. The default value is `False`, so datasets are not logged.
|
306
306
|
models: (`bool`). If `True`, log models. Models can be a `nn.Module` or `sklearn.base.BaseEstimator`. The default value is `False`, so models are not logged.
|
@@ -10,7 +10,7 @@ import wandb
|
|
10
10
|
class WandbObserver(RunObserver):
|
11
11
|
"""Log sacred experiment data to W&B.
|
12
12
|
|
13
|
-
|
13
|
+
Args:
|
14
14
|
Accepts all the arguments accepted by wandb.init().
|
15
15
|
|
16
16
|
name — A display name for this run, which shows up in the UI and is editable, doesn't have to be unique
|
@@ -7,7 +7,7 @@ from wandb import env
|
|
7
7
|
def sagemaker_auth(overrides=None, path=".", api_key=None):
|
8
8
|
"""Write a secrets.env file with the W&B ApiKey and any additional secrets passed.
|
9
9
|
|
10
|
-
|
10
|
+
Args:
|
11
11
|
overrides (dict, optional): Additional environment variables to write
|
12
12
|
to secrets.env
|
13
13
|
path (str, optional): The path to write the secrets file.
|
@@ -37,7 +37,7 @@ def classifier(
|
|
37
37
|
|
38
38
|
Should only be called with a fitted classifier (otherwise an error is thrown).
|
39
39
|
|
40
|
-
|
40
|
+
Args:
|
41
41
|
model: (classifier) Takes in a fitted classifier.
|
42
42
|
X_train: (arr) Training set features.
|
43
43
|
y_train: (arr) Training set labels.
|
@@ -117,7 +117,7 @@ def roc(
|
|
117
117
|
):
|
118
118
|
"""Log the receiver-operating characteristic curve.
|
119
119
|
|
120
|
-
|
120
|
+
Args:
|
121
121
|
y_true: (arr) Test set labels.
|
122
122
|
y_probas: (arr) Test set predicted probabilities.
|
123
123
|
labels: (list) Named labels for target variable (y). Makes plots easier to
|
@@ -150,7 +150,7 @@ def confusion_matrix(
|
|
150
150
|
|
151
151
|
Confusion matrices depict the pattern of misclassifications by a model.
|
152
152
|
|
153
|
-
|
153
|
+
Args:
|
154
154
|
y_true: (arr) Test set labels.
|
155
155
|
y_probas: (arr) Test set predicted probabilities.
|
156
156
|
labels: (list) Named labels for target variable (y). Makes plots easier to
|
@@ -194,7 +194,7 @@ def precision_recall(
|
|
194
194
|
Precision-recall curves depict the tradeoff between positive predictive value (precision)
|
195
195
|
and true positive rate (recall) as the threshold of a classifier is shifted.
|
196
196
|
|
197
|
-
|
197
|
+
Args:
|
198
198
|
y_true: (arr) Test set labels.
|
199
199
|
y_probas: (arr) Test set predicted probabilities.
|
200
200
|
labels: (list) Named labels for target variable (y). Makes plots easier to
|
@@ -226,7 +226,7 @@ def feature_importances(
|
|
226
226
|
Should only be called with a fitted classifier (otherwise an error is thrown).
|
227
227
|
Only works with classifiers that have a feature_importances_ attribute, like trees.
|
228
228
|
|
229
|
-
|
229
|
+
Args:
|
230
230
|
model: (clf) Takes in a fitted classifier.
|
231
231
|
feature_names: (list) Names for features. Makes plots easier to read by
|
232
232
|
replacing feature indexes with corresponding names.
|
@@ -254,7 +254,7 @@ def class_proportions(y_train=None, y_test=None, labels=None):
|
|
254
254
|
|
255
255
|
Useful for detecting imbalanced classes.
|
256
256
|
|
257
|
-
|
257
|
+
Args:
|
258
258
|
y_train: (arr) Training set labels.
|
259
259
|
y_test: (arr) Test set labels.
|
260
260
|
labels: (list) Named labels for target variable (y). Makes plots easier to
|
@@ -298,7 +298,7 @@ def calibration_curve(clf=None, X=None, y=None, clf_name="Classifier"): # noqa:
|
|
298
298
|
|
299
299
|
Please note this function fits variations of the model on the training set when called.
|
300
300
|
|
301
|
-
|
301
|
+
Args:
|
302
302
|
clf: (clf) Takes in a fitted classifier.
|
303
303
|
X: (arr) Training set features.
|
304
304
|
y: (arr) Training set labels.
|
@@ -20,7 +20,7 @@ def clusterer(model, X_train, cluster_labels, labels=None, model_name="Clusterer
|
|
20
20
|
|
21
21
|
Should only be called with a fitted clusterer (otherwise an error is thrown).
|
22
22
|
|
23
|
-
|
23
|
+
Args:
|
24
24
|
model: (clusterer) Takes in a fitted clusterer.
|
25
25
|
X_train: (arr) Training set features.
|
26
26
|
cluster_labels: (list) Names for cluster labels. Makes plots easier to read
|
@@ -68,7 +68,7 @@ def elbow_curve(
|
|
68
68
|
|
69
69
|
Please note this function fits the model on the training set when called.
|
70
70
|
|
71
|
-
|
71
|
+
Args:
|
72
72
|
model: (clusterer) Takes in a fitted clusterer.
|
73
73
|
X: (arr) Training set features.
|
74
74
|
|
@@ -118,7 +118,7 @@ def silhouette(
|
|
118
118
|
|
119
119
|
Please note this function fits the model on the training set when called.
|
120
120
|
|
121
|
-
|
121
|
+
Args:
|
122
122
|
model: (clusterer) Takes in a fitted clusterer.
|
123
123
|
X: (arr) Training set features.
|
124
124
|
cluster_labels: (list) Names for cluster labels. Makes plots easier to read
|