wandb-util 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- wandb_util-0.1.0/PKG-INFO +11 -0
- wandb_util-0.1.0/README.md +0 -0
- wandb_util-0.1.0/pyproject.toml +20 -0
- wandb_util-0.1.0/src/wandb_util/__init__.py +2 -0
- wandb_util-0.1.0/src/wandb_util/__main__.py +23 -0
- wandb_util-0.1.0/src/wandb_util/commands/__init__.py +5 -0
- wandb_util-0.1.0/src/wandb_util/commands/artifact.py +88 -0
- wandb_util-0.1.0/src/wandb_util/commands/log.py +88 -0
- wandb_util-0.1.0/src/wandb_util/commands/run.py +47 -0
- wandb_util-0.1.0/src/wandb_util/py.typed +0 -0
- wandb_util-0.1.0/src/wandb_util/utils.py +30 -0
|
@@ -0,0 +1,11 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: wandb-util
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Utility tool providing extended functionality for W&B (Weights & Biases)
|
|
5
|
+
Author: Xiao Li
|
|
6
|
+
Author-email: Xiao Li <xiaoli3397@gmail.com>
|
|
7
|
+
Requires-Dist: click>=8.3.3
|
|
8
|
+
Requires-Dist: wandb>=0.26.1
|
|
9
|
+
Requires-Python: >=3.13
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
|
|
File without changes
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "wandb-util"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "Utility tool providing extended functionality for W&B (Weights & Biases)"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
authors = [
|
|
7
|
+
{ name = "Xiao Li", email = "xiaoli3397@gmail.com" }
|
|
8
|
+
]
|
|
9
|
+
requires-python = ">=3.13"
|
|
10
|
+
dependencies = [
|
|
11
|
+
"click>=8.3.3",
|
|
12
|
+
"wandb>=0.26.1",
|
|
13
|
+
]
|
|
14
|
+
|
|
15
|
+
[project.scripts]
|
|
16
|
+
wandb-util = "wandb_util.__main__:cli"
|
|
17
|
+
|
|
18
|
+
[build-system]
|
|
19
|
+
requires = ["uv_build>=0.11.7,<0.12.0"]
|
|
20
|
+
build-backend = "uv_build"
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
import click
|
|
2
|
+
|
|
3
|
+
from .commands import artifact, log, run
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
@click.group()
@click.option("--entity", "-e", required=True, help="W&B entity (username or team).")
@click.option("--project", "-p", required=True, help="W&B project name.")
@click.pass_context
def cli(ctx, entity, project):
    """WandB utility tool providing extended functionality."""
    # (The docstring above is Click's --help text, so it is kept verbatim.)
    # Record the global options on the context object; subcommands read them
    # back as ctx.obj["entity"] / ctx.obj["project"].
    ctx.ensure_object(dict).update(entity=entity, project=project)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
# Attach each subcommand group to the root ``cli`` group, in this order.
for _group in (artifact, log, run):
    cli.add_command(_group)
del _group  # keep the module namespace free of the loop variable


if __name__ == "__main__":
    cli()
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
|
|
3
|
+
import click
|
|
4
|
+
import wandb
|
|
5
|
+
|
|
6
|
+
from ..utils import handle_api_errors
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def fmt_size(n_bytes: int) -> str:
    """Render a byte count as a short human-readable string.

    The value is scaled down by powers of 1024 through B, KB, MB and GB;
    anything still >= 1024 GB is reported in TB. One decimal place is always
    shown, e.g. ``fmt_size(1536) == "1.5 KB"``.
    """
    units = ("B", "KB", "MB", "GB")
    scaled = float(n_bytes)
    idx = 0
    # ``not scaled < 1024`` (rather than ``>=``) mirrors a "stop once below
    # 1024" rule exactly, including for non-finite floats.
    while idx < len(units) and not scaled < 1024:
        scaled /= 1024
        idx += 1
    label = units[idx] if idx < len(units) else "TB"
    return f"{scaled:.1f} {label}"
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _list_artifacts(
    api: wandb.Api, entity: str, project: str, run_id: str, artifact_type: str
) -> None:
    """Print a table of the artifacts of one type logged by a run.

    Artifacts are ordered by their ``global_step`` metadata entry; entries
    without one sort as step 0 and are displayed as ``-``.

    Args:
        api: W&B public API client.
        entity: W&B entity (username or team).
        project: W&B project name.
        run_id: ID of the run whose logged artifacts are listed.
        artifact_type: Only artifacts whose ``type`` equals this are shown.
    """
    run = api.run(f"{entity}/{project}/{run_id}")
    artifacts = sorted(
        (a for a in run.logged_artifacts() if a.type == artifact_type),
        key=lambda a: a.metadata.get("global_step", 0),
    )

    if not artifacts:
        # Name the type in the message: the run may well have artifacts of
        # other types, so "no artifacts" alone would be misleading.
        click.echo(f"No '{artifact_type}' artifacts found for run '{run_id}'")
        return

    click.echo(f"Run: {run.name} ({run_id}) state={run.state}")
    click.echo()
    click.echo(f"{'Version':<28} {'Step':>8} {'Size':>10} {'Created'}")
    click.echo("-" * 72)
    for a in artifacts:
        step = a.metadata.get("global_step", "-")
        # Trim the timestamp to "YYYY-MM-DD HH:MM:SS" (assumes an ISO-8601
        # string, as the "T" replacement implies).
        created = a.created_at[:19].replace("T", " ") if a.created_at else "-"
        click.echo(f"{a.name:<28} {str(step):>8} {fmt_size(a.size):>10} {created}")
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _download_artifact(
    api: wandb.Api,
    entity: str,
    project: str,
    artifact_path: str,
    output_dir: str,
    artifact_type: str | None = "model",
) -> None:
    """Download an artifact into a local directory.

    Args:
        api: W&B public API client.
        entity: W&B entity (username or team).
        project: W&B project name.
        artifact_path: Artifact reference relative to the project; it is
            prefixed with ``entity/project/`` here (presumably ``name:alias``
            or ``name:vN`` -- confirm against callers).
        output_dir: Directory passed to ``Artifact.download(root=...)``.
        artifact_type: Expected artifact type, forwarded as ``type=`` to
            ``api.artifact``. Defaults to ``"model"``, which was previously
            hard-coded and made non-model artifacts impossible to download.
            ``None`` places no type constraint on the lookup.
    """
    artifact = api.artifact(f"{entity}/{project}/{artifact_path}", type=artifact_type)
    out = Path(output_dir)
    click.echo(f"Downloading {artifact.name} ({fmt_size(artifact.size)}) -> {out}/")
    artifact.download(root=str(out))
    click.echo("Done.")


@click.group()
@click.pass_context
def artifact(ctx):
    """Manage W&B artifacts."""
    # Container group only; the subcommands below do the work.


@artifact.command()
@click.argument("run_id")
@click.option("--type", "-t", default="model", help="Artifact type (default: model).")
@click.pass_context
@handle_api_errors
def list(ctx, run_id, type):
    """List artifacts for a run."""
    entity = ctx.obj["entity"]
    project = ctx.obj["project"]
    api = wandb.Api()
    _list_artifacts(api, entity, project, run_id, type)


# ``@artifact.command()`` registers this as "list"; also expose it as "ls".
artifact.add_command(list, name="ls")


@artifact.command()
@click.argument("artifact_path")
@click.option(
    "--output-dir", "-o", default=".", help="Directory to download into (default: .)."
)
@click.option(
    "--type",
    "-t",
    "artifact_type",
    default="model",
    help="Expected artifact type (default: model).",
)
@click.pass_context
@handle_api_errors
def download(ctx, artifact_path, output_dir, artifact_type="model"):
    """Download an artifact."""
    entity = ctx.obj["entity"]
    project = ctx.obj["project"]
    api = wandb.Api()
    _download_artifact(api, entity, project, artifact_path, output_dir, artifact_type)
|
|
@@ -0,0 +1,88 @@
|
|
|
1
|
+
import click
|
|
2
|
+
import wandb
|
|
3
|
+
|
|
4
|
+
from ..utils import handle_api_errors
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def _parse_keys(filter_keys: str | None) -> set[str] | None:
    """Split a comma-separated ``--keys`` value into a set of metric names.

    Whitespace around each name is stripped and empty items are dropped, so
    ``"loss, acc,"`` yields ``{"loss", "acc"}``. Returns ``None`` (meaning
    "no key filter") when nothing usable was supplied.
    """
    if not filter_keys:
        return None
    keys = {k.strip() for k in filter_keys.split(",")}
    keys.discard("")
    return keys or None


def _print_history(run, filter_set: set[str] | None, last: int) -> None:
    """Print the last ``last`` rows of ``run.history()`` as a text table.

    When ``filter_set`` is given, only those columns (plus ``_step``) are shown.
    """
    # NOTE(review): wandb's Run.history() returns a sampled subset of rows by
    # default, so these are the last rows of that sample -- confirm whether an
    # exact tail (e.g. via scan_history) is required.
    history = run.history()
    if history.empty:
        print("No history found.")
        return
    if filter_set:
        cols = [c for c in history.columns if c in filter_set or c == "_step"]
    else:
        cols = list(history.columns)
    # Clamp: pandas tail(-n) would mean "all but the first n rows".
    df = history[cols].tail(max(last, 0))
    print(df.to_string(index=False))


def _print_summary(run, filter_set: set[str] | None, prefix: str | None) -> None:
    """Print the run summary as a two-column ``Metric``/``Value`` table.

    Keys are optionally restricted to those starting with ``prefix`` and/or
    those in ``filter_set``; floats are printed with 6 significant digits.
    """
    summary = dict(run.summary.items())
    if prefix:
        summary = {k: v for k, v in summary.items() if k.startswith(prefix)}
    if filter_set:
        summary = {k: v for k, v in summary.items() if k in filter_set}

    if not summary:
        print("No metrics found in summary.")
        return

    # Include the header in the width so short keys do not misalign it.
    key_width = max(len("Metric"), max(len(k) for k in summary))
    print(f"{'Metric':<{key_width}} Value")
    print("-" * (key_width + 20))
    for key, value in sorted(summary.items()):
        if isinstance(value, float):
            print(f"{key:<{key_width}} {value:.6g}")
        else:
            print(f"{key:<{key_width}} {value}")


def _list_logs(
    api: wandb.Api,
    entity: str,
    project: str,
    run_id: str,
    show_history: bool = False,
    filter_keys: str | None = None,
    prefix: str | None = None,
    last: int = 10,
) -> None:
    """Print a run header followed by either its summary or recent history.

    Args:
        api: W&B public API client.
        entity: W&B entity (username or team).
        project: W&B project name.
        run_id: ID of the run to inspect.
        show_history: Print history rows instead of the summary.
        filter_keys: Comma-separated metric names to keep (both modes).
        prefix: Keep only summary keys with this prefix (summary mode only).
        last: Number of trailing history rows to print (history mode only).
    """
    run = api.run(f"{entity}/{project}/{run_id}")

    print(f"Run: {run.name} ({run.id})")
    print(f"State: {run.state}")
    print(f"Project: {entity}/{project}")
    print()

    filter_set = _parse_keys(filter_keys)
    if show_history:
        _print_history(run, filter_set, last)
    else:
        _print_summary(run, filter_set, prefix)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@click.group()
@click.pass_context
def log(ctx):
    """Manage W&B logs."""
    # Container group only; its ``list``/``ls`` subcommand does the work.
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@log.command()
@click.argument("run_id")
@click.option("--history", is_flag=True, help="Show history instead of summary.")
@click.option("--prefix", help="Filter summary by key prefix (e.g., eval/, train/).")
@click.option("--keys", "-k", help="Filter by comma-separated keys.")
@click.option(
    "--last",
    "-l",
    default=10,
    type=int,
    help="Show last N history entries (default: 10).",
)
@click.pass_context
@handle_api_errors
def list(ctx, run_id, history, prefix, keys, last):
    """List logs for a run."""
    # Entity/project are placed on ctx.obj by the root command group.
    settings = ctx.obj
    entity, project = settings["entity"], settings["project"]
    # Keyword arguments make the CLI-option -> helper-parameter mapping explicit.
    _list_logs(
        wandb.Api(),
        entity,
        project,
        run_id,
        show_history=history,
        filter_keys=keys,
        prefix=prefix,
        last=last,
    )


# ``@log.command()`` registers this as "list"; also expose it as "ls".
log.add_command(list, name="ls")
|
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
import click
|
|
2
|
+
import wandb
|
|
3
|
+
|
|
4
|
+
from ..utils import handle_api_errors
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def _list_runs(api: wandb.Api, entity: str, project: str, limit: int) -> None:
    """Print a table of the most recently created runs in a project.

    Args:
        api: W&B public API client.
        entity: W&B entity (username or team).
        project: W&B project name.
        limit: Maximum number of runs to print; values <= 0 print none.
    """
    from itertools import islice

    runs = api.runs(f"{entity}/{project}", order="-created_at")
    print(f"{'ID':<12} {'State':<10} {'Name':<40} {'Created':<20} {'Steps':>8}")
    print("-" * 96)

    count = 0
    # islice stops after `limit` items without pulling one more from the
    # (paginated) runs iterator, so no page is fetched only to be discarded.
    # It rejects negative stops, hence the clamp.
    for run in islice(runs, max(limit, 0)):
        created = run.created_at[:19].replace("T", " ") if run.created_at else "-"
        # Value of "_step" recorded in the run summary, or "-" if absent.
        steps = run.summary.get("_step", "-")
        name = (run.name or "")[:40]
        print(f"{run.id:<12} {run.state:<10} {name:<40} {created:<20} {str(steps):>8}")
        count += 1

    if count == 0:
        print(f"No runs found in '{entity}/{project}'")
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@click.group()
@click.pass_context
def run(ctx):
    """Manage W&B runs."""
    # Container group only; its ``list``/``ls`` subcommand does the work.
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@run.command()
@click.option(
    "--limit", "-l", default=10, type=int, help="Maximum number of runs to display."
)
@click.pass_context
@handle_api_errors
def list(ctx, limit):
    """List runs."""
    # Entity/project are placed on ctx.obj by the root command group.
    settings = ctx.obj
    entity, project = settings["entity"], settings["project"]
    _list_runs(wandb.Api(), entity, project, limit)


# ``@run.command()`` registers this as "list"; also expose it as "ls".
run.add_command(list, name="ls")
|
|
File without changes
|
|
@@ -0,0 +1,30 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
|
|
3
|
+
import click
|
|
4
|
+
import wandb
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def handle_api_errors(func):
    """Decorate a command callback so failures exit with a readable message.

    The wrapped callable receives the same arguments, and its return value is
    passed through unchanged. Errors are printed in red to stderr and the
    process exits with status 1:

    - ``wandb.errors.AuthenticationError``: hint about the API key.
    - ``wandb.errors.CommError``: "Failed to connect" plus the error text.
    - any other ``Exception``: the error text.

    Click's own exceptions (``ClickException``, ``Abort``, ``Exit``) are
    re-raised untouched, so Click still renders usage errors and honours
    ``ctx.exit()`` normally.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except (click.ClickException, click.Abort, click.exceptions.Exit):
            # Previously swallowed by ``except Exception`` and flattened into
            # a generic "Error: ..." with exit code 1.
            raise
        except wandb.errors.AuthenticationError:
            # Must come before CommError: AuthenticationError subclasses
            # CommError, so this handler was unreachable in the old order.
            click.secho(
                "Error: Authentication failed. Check your API key in ~/.netrc",
                fg="red",
                err=True,
            )
            raise SystemExit(1)
        except wandb.errors.CommError as e:
            click.secho(
                f"Error: Failed to connect to wandb API. {e}", fg="red", err=True
            )
            raise SystemExit(1)
        except Exception as e:
            click.secho(f"Error: {e}", fg="red", err=True)
            raise SystemExit(1)

    return wrapper
|