wandb-util 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,11 @@
1
+ Metadata-Version: 2.3
2
+ Name: wandb-util
3
+ Version: 0.1.0
4
+ Summary: Utility tool providing extended functionality for W&B (Weights & Biases)
5
+ Author: Xiao Li
6
+ Author-email: Xiao Li <xiaoli3397@gmail.com>
7
+ Requires-Dist: click>=8.3.3
8
+ Requires-Dist: wandb>=0.26.1
9
+ Requires-Python: >=3.13
10
+ Description-Content-Type: text/markdown
11
+
File without changes
@@ -0,0 +1,20 @@
1
+ [project]
2
+ name = "wandb-util"
3
+ version = "0.1.0"
4
+ description = "Utility tool providing extended functionality for W&B (Weights & Biases)"
5
+ readme = "README.md"
6
+ authors = [
7
+ { name = "Xiao Li", email = "xiaoli3397@gmail.com" }
8
+ ]
9
+ requires-python = ">=3.13"
10
+ dependencies = [
11
+ "click>=8.3.3",
12
+ "wandb>=0.26.1",
13
+ ]
14
+
15
+ [project.scripts]
16
+ wandb-util = "wandb_util.__main__:cli"
17
+
18
+ [build-system]
19
+ requires = ["uv_build>=0.11.7,<0.12.0"]
20
+ build-backend = "uv_build"
@@ -0,0 +1,2 @@
1
def hello() -> str:
    """Return the package's fixed greeting string."""
    greeting = "Hello from wandb-util!"
    return greeting
@@ -0,0 +1,23 @@
1
+ import click
2
+
3
+ from .commands import artifact, log, run
4
+
5
+
6
@click.group()
@click.option("--entity", "-e", required=True, help="W&B entity (username or team).")
@click.option("--project", "-p", required=True, help="W&B project name.")
@click.pass_context
def cli(ctx, entity, project):
    """WandB utility tool providing extended functionality."""
    # Sub-commands read the target entity/project back out of ctx.obj.
    shared = ctx.ensure_object(dict)
    shared["entity"] = entity
    shared["project"] = project
15
+
16
+
17
# Attach each sub-command group to the root CLI, in this order.
for _group in (artifact, log, run):
    cli.add_command(_group)
del _group

if __name__ == "__main__":
    cli()
@@ -0,0 +1,5 @@
1
"""Click command groups exposed by the wandb-util CLI."""

from .artifact import artifact
from .log import log
from .run import run

__all__ = [
    "artifact",
    "log",
    "run",
]
@@ -0,0 +1,88 @@
1
+ from pathlib import Path
2
+
3
+ import click
4
+ import wandb
5
+
6
+ from ..utils import handle_api_errors
7
+
8
+
9
def fmt_size(n_bytes: int) -> str:
    """Render a byte count as a human-readable string with one decimal place.

    Scales by powers of 1024 through B, KB, MB and GB; anything that is still
    >= 1024 GB is reported in TB (without a further cap).
    """
    units = ("B", "KB", "MB", "GB", "TB")
    value = float(n_bytes)
    idx = 0
    # `not value < 1024` (rather than `>=`) keeps non-finite inputs scaling
    # all the way up to TB.
    while idx < len(units) - 1 and not value < 1024:
        value /= 1024
        idx += 1
    return f"{value:.1f} {units[idx]}"
16
+
17
+
18
def _list_artifacts(
    api: wandb.Api, entity: str, project: str, run_id: str, artifact_type: str
) -> None:
    """Print a table of the artifacts of one type logged by a run.

    Rows are ordered by the ``global_step`` value stored in each artifact's
    metadata; artifacts without a numeric step sort first.

    Args:
        api: W&B public API client.
        entity: W&B entity (username or team) owning the project.
        project: W&B project name.
        run_id: ID of the run whose logged artifacts are listed.
        artifact_type: Only artifacts whose ``type`` equals this are shown.
    """

    def step_key(a) -> float:
        # Missing / None / non-numeric steps would otherwise be compared with
        # ints by sorted() and raise TypeError; treat them as step 0.
        step = (a.metadata or {}).get("global_step")
        return step if isinstance(step, (int, float)) else 0

    run = api.run(f"{entity}/{project}/{run_id}")
    artifacts = sorted(
        (a for a in run.logged_artifacts() if a.type == artifact_type),
        key=step_key,
    )

    if not artifacts:
        # Mention the type: the run may well have artifacts of other types.
        click.echo(f"No artifacts of type '{artifact_type}' found for run '{run_id}'")
        return

    click.echo(f"Run: {run.name} ({run_id}) state={run.state}")
    click.echo()
    click.echo(f"{'Version':<28} {'Step':>8} {'Size':>10} {'Created'}")
    click.echo("-" * 72)
    for a in artifacts:
        step = (a.metadata or {}).get("global_step", "-")
        # Trim the timestamp to "YYYY-MM-DD HH:MM:SS".
        created = a.created_at[:19].replace("T", " ") if a.created_at else "-"
        click.echo(f"{a.name:<28} {str(step):>8} {fmt_size(a.size):>10} {created}")
40
+
41
+
42
def _download_artifact(
    api: wandb.Api,
    entity: str,
    project: str,
    artifact_path: str,
    output_dir: str,
    artifact_type: str | None = None,
) -> None:
    """Download an artifact's files into a local directory.

    Args:
        api: W&B public API client.
        entity: W&B entity (username or team) owning the project.
        project: W&B project name.
        artifact_path: Artifact name within the project (including its
            version or alias), appended to ``entity/project/``.
        output_dir: Local directory the artifact files are written to.
        artifact_type: Optional artifact type to require. ``None`` (the
            default) accepts any type; previously this was hard-coded to
            ``"model"``, so datasets and other artifact types could not be
            downloaded.
    """
    artifact = api.artifact(f"{entity}/{project}/{artifact_path}", type=artifact_type)
    out = Path(output_dir)
    click.echo(f"Downloading {artifact.name} ({fmt_size(artifact.size)}) -> {out}/")
    artifact.download(root=str(out))
    click.echo("Done.")
51
+
52
+
53
@click.group()
@click.pass_context
def artifact(ctx):
    """Manage W&B artifacts."""
    # Pure container group: the docstring above is its click help text and
    # the sub-commands are registered on it elsewhere in this module.
58
+
59
+
60
@artifact.command()
@click.argument("run_id")
@click.option("--type", "-t", default="model", help="Artifact type (default: model).")
@click.pass_context
@handle_api_errors
def list(ctx, run_id, type):
    """List artifacts for a run."""
    # NOTE: click derives the command name "list" from this function name, so
    # the module-level name `list` shadows the builtin for this whole module.
    entity, project = ctx.obj["entity"], ctx.obj["project"]
    _list_artifacts(wandb.Api(), entity, project, run_id, type)


# Short alias: `artifact ls` behaves exactly like `artifact list`.
artifact.add_command(list, name="ls")
74
+
75
+
76
@artifact.command()
@click.argument("artifact_path")
@click.option(
    "--output-dir", "-o", default=".", help="Directory to download into (default: .)."
)
@click.pass_context
@handle_api_errors
def download(ctx, artifact_path, output_dir):
    """Download an artifact."""
    # Entity/project come from the root `cli` group via ctx.obj.
    settings = ctx.obj
    entity = settings["entity"]
    project = settings["project"]
    _download_artifact(
        api=wandb.Api(),
        entity=entity,
        project=project,
        artifact_path=artifact_path,
        output_dir=output_dir,
    )
@@ -0,0 +1,88 @@
1
+ import click
2
+ import wandb
3
+
4
+ from ..utils import handle_api_errors
5
+
6
+
7
def _list_logs(
    api: wandb.Api,
    entity: str,
    project: str,
    run_id: str,
    show_history: bool = False,
    filter_keys: str | None = None,
    prefix: str | None = None,
    last: int = 10,
) -> None:
    """Print a run's summary metrics, or the tail of its logged history.

    Args:
        api: W&B public API client.
        entity: W&B entity (username or team) owning the project.
        project: W&B project name.
        run_id: ID of the run to inspect.
        show_history: If true, print the last ``last`` rows of
            ``run.history()`` instead of the run summary.
        filter_keys: Optional comma-separated metric names to keep. Whitespace
            around names is ignored (``"loss, acc"`` == ``"loss,acc"``). In
            history mode the ``_step`` column is always kept.
        prefix: Optional key prefix; only applied to the summary view.
        last: Number of trailing history rows to print.
    """
    run = api.run(f"{entity}/{project}/{run_id}")

    click.echo(f"Run: {run.name} ({run.id})")
    click.echo(f"State: {run.state}")
    click.echo(f"Project: {entity}/{project}")
    click.echo()

    filter_set = (
        {key.strip() for key in filter_keys.split(",") if key.strip()}
        if filter_keys
        else None
    )

    if show_history:
        # NOTE(review): wandb's run.history() appears to return a down-sampled
        # view by default, so the tail may not be the exact last N steps.
        history = run.history()
        if history.empty:
            click.echo("No history found.")
            return
        # Do NOT call list(...) in this module: the name `list` is rebound at
        # module level to the click `list`/`ls` command, so it would invoke
        # that command instead of the builtin.
        cols = history.columns.tolist()
        if filter_set:
            cols = [c for c in cols if c in filter_set or c == "_step"]
        if not cols:
            click.echo("No matching history columns found.")
            return
        click.echo(history[cols].tail(last).to_string(index=False))
        return

    summary = {
        key: value
        for key, value in run.summary.items()
        if (not prefix or key.startswith(prefix))
        and (not filter_set or key in filter_set)
    }
    if not summary:
        click.echo("No metrics found in summary.")
        return

    # Size the key column to fit the header as well, so "Value" lines up.
    key_width = max(len(key) for key in ("Metric", *summary))
    click.echo(f"{'Metric':<{key_width}} Value")
    click.echo("-" * (key_width + 20))
    for key, value in sorted(summary.items()):
        shown = f"{value:.6g}" if isinstance(value, float) else value
        click.echo(f"{key:<{key_width}} {shown}")
57
+
58
+
59
@click.group()
@click.pass_context
def log(ctx):
    """Manage W&B logs."""
    # Pure container group: the docstring above is its click help text and
    # the sub-commands are registered on it elsewhere in this module.
64
+
65
+
66
@log.command()
@click.argument("run_id")
@click.option("--history", is_flag=True, help="Show history instead of summary.")
@click.option("--prefix", help="Filter summary by key prefix (e.g., eval/, train/).")
@click.option("--keys", "-k", help="Filter by comma-separated keys.")
@click.option(
    "--last",
    "-l",
    default=10,
    type=int,
    help="Show last N history entries (default: 10).",
)
@click.pass_context
@handle_api_errors
def list(ctx, run_id, history, prefix, keys, last):
    """List logs for a run."""
    # NOTE: binding `list` at module level shadows the builtin for this whole
    # module; helpers here must not call list(...).
    entity, project = ctx.obj["entity"], ctx.obj["project"]
    _list_logs(
        wandb.Api(),
        entity,
        project,
        run_id,
        show_history=history,
        filter_keys=keys,
        prefix=prefix,
        last=last,
    )


# Short alias: `log ls` behaves exactly like `log list`.
log.add_command(list, name="ls")
@@ -0,0 +1,47 @@
1
+ import click
2
+ import wandb
3
+
4
+ from ..utils import handle_api_errors
5
+
6
+
7
def _list_runs(api: wandb.Api, entity: str, project: str, limit: int) -> None:
    """Print a table of the most recently created runs in a project.

    Args:
        api: W&B public API client.
        entity: W&B entity (username or team) owning the project.
        project: W&B project name.
        limit: Maximum number of runs to print; values <= 0 print none.
    """
    import itertools

    runs = api.runs(f"{entity}/{project}", order="-created_at")
    click.echo(f"{'ID':<12} {'State':<10} {'Name':<40} {'Created':<20} {'Steps':>8}")
    click.echo("-" * 96)
    shown = 0
    # islice stops after `limit` runs without pulling one more item from the
    # (paginated) iterator, unlike a check-then-break loop.
    for run in itertools.islice(runs, max(limit, 0)):
        # Trim the timestamp to "YYYY-MM-DD HH:MM:SS".
        created = run.created_at[:19].replace("T", " ") if run.created_at else "-"
        steps = run.summary.get("_step", "-")
        # Guard against a missing name, which would make the slice raise.
        name = (run.name or "-")[:40]
        click.echo(
            f"{run.id:<12} {run.state:<10} {name:<40} {created:<20} {str(steps):>8}"
        )
        shown += 1
    if not shown:
        click.echo(f"No runs found in '{entity}/{project}'")
24
+
25
+
26
@click.group()
@click.pass_context
def run(ctx):
    """Manage W&B runs."""
    # Pure container group: the docstring above is its click help text and
    # the sub-commands are registered on it elsewhere in this module.
31
+
32
+
33
@run.command()
@click.option(
    "--limit", "-l", default=10, type=int, help="Maximum number of runs to display."
)
@click.pass_context
@handle_api_errors
def list(ctx, limit):
    """List runs."""
    # NOTE: binding `list` at module level shadows the builtin for this whole
    # module; helpers here must not call list(...).
    entity, project = ctx.obj["entity"], ctx.obj["project"]
    _list_runs(wandb.Api(), entity, project, limit=limit)


# Short alias: `run ls` behaves exactly like `run list`.
run.add_command(list, name="ls")
File without changes
@@ -0,0 +1,30 @@
1
+ import functools
2
+
3
+ import click
4
+ import wandb
5
+
6
+
7
def handle_api_errors(func):
    """Decorator turning wandb API failures into a red stderr message + exit 1.

    The wrapped callable's return value is passed through unchanged. On error:

    * ``wandb.errors.AuthenticationError`` -> hint about the API key.
    * ``wandb.errors.CommError`` -> communication-failure message.
    * any other ``Exception`` -> its message.

    Each of these prints in red to stderr and raises ``SystemExit(1)``.
    Click's own control-flow exceptions (usage errors, ``Abort``, and the
    ``Exit`` raised by ``ctx.exit()``) are re-raised untouched so Click can
    handle them itself.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except (click.ClickException, click.Abort, click.exceptions.Exit):
            # Not API errors: let Click report/handle these normally instead
            # of collapsing them into a generic "Error:" exit 1.
            raise
        except wandb.errors.AuthenticationError:
            # Must be checked before CommError: in wandb's error hierarchy
            # AuthenticationError derives from CommError, so the reverse
            # order would make this branch unreachable.
            click.secho(
                "Error: Authentication failed. Check your API key in ~/.netrc",
                fg="red",
                err=True,
            )
            raise SystemExit(1)
        except wandb.errors.CommError as e:
            click.secho(
                f"Error: Failed to connect to wandb API. {e}", fg="red", err=True
            )
            raise SystemExit(1)
        except Exception as e:
            # Top-level CLI boundary: report any other failure concisely.
            click.secho(f"Error: {e}", fg="red", err=True)
            raise SystemExit(1)

    return wrapper