dr-wandb 0.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dr_wandb/__init__.py +9 -0
- dr_wandb/cli/__init__.py +0 -0
- dr_wandb/cli/download.py +97 -0
- dr_wandb/cli/postgres_download.py +128 -0
- dr_wandb/constants.py +23 -0
- dr_wandb/downloader.py +118 -0
- dr_wandb/fetch.py +84 -0
- dr_wandb/history_entry_record.py +62 -0
- dr_wandb/py.typed +0 -0
- dr_wandb/run_record.py +115 -0
- dr_wandb/store.py +193 -0
- dr_wandb/utils.py +57 -0
- dr_wandb-0.1.2.dist-info/METADATA +179 -0
- dr_wandb-0.1.2.dist-info/RECORD +17 -0
- dr_wandb-0.1.2.dist-info/WHEEL +4 -0
- dr_wandb-0.1.2.dist-info/entry_points.txt +2 -0
- dr_wandb-0.1.2.dist-info/licenses/LICENSE +21 -0
dr_wandb/__init__.py
ADDED
dr_wandb/cli/__init__.py
ADDED
|
File without changes
|
dr_wandb/cli/download.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
import logging
|
|
3
|
+
from pydantic import BaseModel, Field, computed_field
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
import typer
|
|
6
|
+
import pickle
|
|
7
|
+
|
|
8
|
+
from dr_wandb.fetch import fetch_project_runs
|
|
9
|
+
|
|
10
|
+
# Typer application; ``download_project`` below is registered on it via
# ``@app.command()`` and it is invoked by the ``__main__`` guard.
app = typer.Typer()
|
|
11
|
+
|
|
12
|
+
class ProjDownloadConfig(BaseModel):
    """Validated options for the pickle-based ``download_project`` command.

    The two ``*_output_filename`` defaults are built from ``entity`` and
    ``project`` (their ``default_factory`` receives the already-validated
    field data), giving ``"<entity>_<project>_runs.pkl"`` and
    ``"<entity>_<project>_histories.pkl"``.
    """

    entity: str
    project: str
    # NOTE(review): this default climbs four directories above this module;
    # for an installed wheel that lands outside site-packages — confirm intent.
    output_dir: Path = Field(
        default_factory=lambda: (
            Path(__file__).parent.parent.parent.parent / "data"
        )
    )
    runs_only: bool = False
    runs_per_page: int = 500
    # Log a progress line every ``log_every`` runs; values <= 0 log every run.
    log_every: int = 20

    runs_output_filename: str = Field(
        default_factory=lambda data: (
            f"{data['entity']}_{data['project']}_runs.pkl"
        )
    )
    histories_output_filename: str = Field(
        default_factory=lambda data: (
            f"{data['entity']}_{data['project']}_histories.pkl"
        )
    )

    def progress_callback(self, run_index: int, total_runs: int, message: str) -> None:
        """Log ``message`` for every ``log_every``-th run and for the last run.

        Args:
            run_index: Position of the run just processed (``fetch_project_runs``
                passes it 1-based).
            total_runs: Number of runs being processed.
            message: Text to log, e.g. the run name.
        """
        should_log = (
            # Guard first: ``run_index % 0`` would raise ZeroDivisionError.
            self.log_every <= 0
            or run_index % self.log_every == 0
            # Always report completion, even when total_runs % log_every != 0.
            or run_index == total_runs
        )
        if should_log:
            logging.info(f">> {run_index}/{total_runs}: {message}")

    @computed_field
    @property
    def fetch_runs_cfg(self) -> dict[str, Any]:
        """Keyword arguments for ``fetch_project_runs``.

        Contains a bound method (``progress_callback``), so exclude this field
        when JSON-dumping the model.
        """
        return {
            "entity": self.entity,
            "project": self.project,
            "runs_per_page": self.runs_per_page,
            "progress_callback": self.progress_callback,
            "include_history": not self.runs_only,
        }
|
|
50
|
+
|
|
51
|
+
def setup_logging(level: str = "INFO") -> None:
    """Configure the root logger with a timestamped ``LEVEL - message`` format.

    ``level`` is a logging level name in any case (e.g. ``"debug"``); it is
    looked up as an attribute of the ``logging`` module, so a name that is not
    such an attribute raises ``AttributeError``.
    """
    numeric_level = getattr(logging, level.upper())
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=numeric_level,
    )
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@app.command()
def download_project(
    entity: str,
    project: str,
    output_dir: str,
    runs_only: bool = False,
    runs_per_page: int = 500,
    log_every: int = 20,
) -> None:
    """Download a WandB project's runs (and histories unless --runs-only) to pickle files."""
    setup_logging()
    logging.info("\n:: Beginning Dr. Wandb Project Downloading Tool ::\n")

    cfg = ProjDownloadConfig(
        entity=entity,
        project=project,
        output_dir=output_dir,
        runs_only=runs_only,
        runs_per_page=runs_per_page,
        log_every=log_every,
    )
    # ``exclude`` must be a set of field names (a bare string is not a valid
    # value); ``fetch_runs_cfg`` holds a bound method that cannot be dumped
    # to JSON, so it has to be excluded.
    logging.info(cfg.model_dump_json(indent=4, exclude={"fetch_runs_cfg"}))
    logging.info("")

    # Create the destination up front so a bad path fails before the slow
    # download rather than after it.
    cfg.output_dir.mkdir(parents=True, exist_ok=True)
    runs_filename = cfg.output_dir / cfg.runs_output_filename
    histories_filename = cfg.output_dir / cfg.histories_output_filename

    runs, histories = fetch_project_runs(**cfg.fetch_runs_cfg)

    with open(runs_filename, "wb") as run_file:
        pickle.dump(runs, run_file)
    logging.info(f">> Dumped runs data to: {runs_filename}")

    if not cfg.runs_only:
        with open(histories_filename, "wb") as hist_file:
            pickle.dump(histories, hist_file)
        logging.info(f">> Dumped histories data to: {histories_filename}")
    else:
        logging.info(f">> Runs only, not dumping histories to: {histories_filename}")
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
# Allow running this module directly as a script; dispatches to the Typer app.
if __name__ == "__main__":
    app()
|
|
@@ -0,0 +1,128 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
import click
|
|
5
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
6
|
+
|
|
7
|
+
from dr_wandb.downloader import Downloader
|
|
8
|
+
from dr_wandb.store import ProjectStore
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ProjDownloadSettings(BaseSettings):
    """Settings for the Postgres-backed download CLI.

    Per ``model_config``, values are loaded by pydantic-settings from
    ``DR_WANDB_``-prefixed environment variables and from a ``.env`` file.
    """

    model_config = SettingsConfigDict(env_file=".env", env_prefix="DR_WANDB_")

    # Both are required at run time; ``validate_settings`` rejects missing ones.
    entity: str | None = None
    project: str | None = None
    database_url: str = "postgresql+psycopg2://localhost/wandb"
    # This module lives in ``dr_wandb/cli``, so the default is a ``data``
    # directory inside the ``dr_wandb`` package directory.
    # NOTE(review): differs from ``store.DEFAULT_OUTPUT_DIR`` — confirm intent.
    output_dir: Path = Path(__file__).parent.parent / "data"
    runs_per_page: int = 500
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def setup_logging(level: str = "INFO") -> None:
    """Set up root logging with timestamps for the CLI.

    The level name is upper-cased and resolved on the ``logging`` module;
    names that are not attributes of that module raise ``AttributeError``.
    """
    config = {
        "level": getattr(logging, level.upper()),
        "format": "%(asctime)s - %(levelname)s - %(message)s",
        "datefmt": "%Y-%m-%d %H:%M:%S",
    }
    logging.basicConfig(**config)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def validate_settings(entity: str | None, project: str | None) -> None:
    """Abort the CLI unless both an entity and a project were resolved.

    Raises:
        click.ClickException: for the first of ``entity``/``project`` that is
            empty or None (entity is checked first).
    """
    required = (
        (entity, "--entity is required, or set DR_WANDB_ENTITY in .env"),
        (project, "--project is required, or set DR_WANDB_PROJECT in .env"),
    )
    for value, error_message in required:
        if not value:
            raise click.ClickException(error_message)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def resolve_config(
    entity: str | None,
    project: str | None,
    db_url: str | None,
    output_dir: str | None,
) -> ProjDownloadSettings:
    """Merge CLI arguments over env/.env settings and validate the result.

    Each argument is used when truthy; otherwise the value loaded by a default
    ``ProjDownloadSettings()`` is kept. ``runs_per_page`` always comes from the
    loaded settings. Raises ``click.ClickException`` (via
    ``validate_settings``) if entity or project is still missing.
    """
    loaded = ProjDownloadSettings()
    merged_entity = entity or loaded.entity
    merged_project = project or loaded.project
    merged_db_url = db_url or loaded.database_url
    merged_output_dir = output_dir or loaded.output_dir
    validate_settings(merged_entity, merged_project)
    return ProjDownloadSettings(
        entity=merged_entity,
        project=merged_project,
        database_url=merged_db_url,
        output_dir=merged_output_dir,
        runs_per_page=loaded.runs_per_page,
    )
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def execute_download(
    cfg: ProjDownloadSettings, runs_only: bool, force_refresh: bool
) -> Downloader:
    """Download ``cfg.entity/cfg.project`` into the configured Postgres store.

    Args:
        cfg: Resolved settings (database URL, output dir, entity, project).
        runs_only: Skip history download when True.
        force_refresh: Re-download every run instead of only new/running ones.

    Returns:
        The ``Downloader`` that performed the download, so the caller can run
        follow-up steps (e.g. parquet export) against the same store.
    """
    store = ProjectStore(
        cfg.database_url,
        output_dir=cfg.output_dir,
    )
    downloader = Downloader(store, runs_per_page=cfg.runs_per_page)
    click.echo(">> Beginning download:")
    stats = downloader.download_project(
        entity=cfg.entity,
        project=cfg.project,
        runs_only=runs_only,
        force_refresh=force_refresh,
    )
    click.echo(str(stats))
    return downloader
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _redact_db_url(database_url: str) -> str:
    """Return ``database_url`` with any password replaced by ``***`` for display."""
    from urllib.parse import urlsplit, urlunsplit

    parts = urlsplit(database_url)
    if parts.password is None:
        return database_url
    # rpartition: the host part follows the *last* "@", even if the
    # password itself contains "@".
    userinfo, _, hostinfo = parts.netloc.rpartition("@")
    username = userinfo.split(":", 1)[0]
    return urlunsplit(parts._replace(netloc=f"{username}:***@{hostinfo}"))


@click.command()
@click.option(
    "--entity",
    envvar="DR_WANDB_ENTITY",
    help="WandB entity (username or team name)",
)
@click.option("--project", envvar="DR_WANDB_PROJECT", help="WandB project name")
@click.option(
    "--runs-only",
    is_flag=True,
    help="Only download runs, don't download history",
)
@click.option(
    "--force-refresh",
    is_flag=True,
    help="Force refresh, download all data",
)
@click.option(
    "--db-url",
    envvar="DR_WANDB_DATABASE_URL",
    help="PostgreSQL connection string",
)
@click.option(
    "--output-dir",
    envvar="DR_WANDB_OUTPUT_DIR",
    help="Output directory",
)
def download_project(
    entity: str | None,
    project: str | None,
    runs_only: bool,
    force_refresh: bool,
    db_url: str | None,
    output_dir: str | None,
) -> None:
    """Download a WandB project into PostgreSQL, then export it to parquet."""
    setup_logging()
    click.echo("\n:: Beginning Dr. Wandb Project Downloading Tool ::\n")
    cfg = resolve_config(entity, project, db_url, output_dir)
    click.echo(f">> Downloading project {cfg.entity}/{cfg.project}")
    # Never echo credentials: the connection string may embed a password.
    click.echo(f">> Database: {_redact_db_url(cfg.database_url)}")
    click.echo(f">> Output directory: {cfg.output_dir}")
    click.echo(f">> Force refresh: {force_refresh} Runs only: {runs_only}")
    click.echo()
    downloader = execute_download(cfg, runs_only, force_refresh)
    downloader.write_downloaded_to_parquet()
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
# Allow running this module directly; invokes the click command, which reads
# its options from the command line / environment.
if __name__ == "__main__":
    download_project()
|
dr_wandb/constants.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from collections.abc import Callable
|
|
2
|
+
from typing import Literal
|
|
3
|
+
|
|
4
|
+
from sqlalchemy import String
|
|
5
|
+
from sqlalchemy.orm import DeclarativeBase
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Base(DeclarativeBase):
    """Shared SQLAlchemy declarative base for the dr_wandb ORM models.

    ``Base.metadata`` collects their tables (``ProjectStore`` creates them
    with ``Base.metadata.create_all``).
    """
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
MAX_INT = 2**31 - 1
|
|
13
|
+
|
|
14
|
+
SUPPORTED_FILTER_FIELDS = ["project", "entity", "state", "run_ids"]
|
|
15
|
+
type FilterField = Literal["project", "entity", "state", "run_ids"]
|
|
16
|
+
|
|
17
|
+
WANDB_RUN_STATES = ["finished", "running", "crashed", "failed", "killed"]
|
|
18
|
+
type RunState = Literal["finished", "running", "crashed", "failed", "killed"]
|
|
19
|
+
type RunId = str
|
|
20
|
+
|
|
21
|
+
Base.type_annotation_map = {RunId: String}
|
|
22
|
+
|
|
23
|
+
type ProgressCallback = Callable[[int, int, str], None]
|
dr_wandb/downloader.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
|
|
6
|
+
import wandb
|
|
7
|
+
|
|
8
|
+
from dr_wandb.constants import ProgressCallback
|
|
9
|
+
from dr_wandb.store import ProjectStore
|
|
10
|
+
from dr_wandb.utils import default_progress_callback, select_updated_runs
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
@dataclass
class DownloaderStats:
    """Run counters reported by one ``Downloader.download_runs`` call."""

    num_wandb_runs: int = 0
    num_stored_runs: int = 0
    num_new_runs: int = 0
    num_updated_runs: int = 0

    def __str__(self) -> str:
        header = ":: Downloader Stats ::"
        rows = [
            f" - # WandB runs: {self.num_wandb_runs:,}",
            f" - # Stored runs: {self.num_stored_runs:,}",
            f" - # New runs: {self.num_new_runs:,}",
            f" - # Updated runs: {self.num_updated_runs:,}",
        ]
        # Leading and trailing empty entries give a blank line on each side.
        return "\n".join(["", header, *rows, ""])
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
class Downloader:
    """Copy a WandB project's runs (and optionally their history) into a ProjectStore.

    Unless ``force_refresh`` is set, only runs that are new or that were last
    stored in the ``running`` state are downloaded (``select_updated_runs``).
    """

    def __init__(
        self,
        store: ProjectStore,
        runs_per_page: int = 500,
    ) -> None:
        self.store = store
        self.runs_per_page = runs_per_page
        self.progress_callback: ProgressCallback = default_progress_callback
        # The WandB client is built lazily by the ``api`` property, so the
        # constructor itself makes no WandB call.
        self._api: wandb.Api | None = None

    @property
    def api(self) -> wandb.Api:
        """Return the cached ``wandb.Api`` client, creating it on first access.

        Raises:
            RuntimeError: if WandB reports that no API key is configured.
        """
        if self._api is not None:
            return self._api
        try:
            self._api = wandb.Api()
        except wandb.errors.UsageError as err:
            if "api_key not configured" not in str(err):
                raise
            raise RuntimeError(
                "WandB API key not configured. "
                "Please run 'wandb login' or set WANDB_API_KEY env var"
            ) from err
        return self._api

    def set_progress_callback(self, progress_callback: ProgressCallback) -> None:
        """Replace the per-run hook called while downloading with history."""
        self.progress_callback = progress_callback

    def get_all_runs(self, entity: str, project: str) -> list[wandb.apis.public.Run]:
        """Fetch every run of ``entity/project``, paging by ``runs_per_page``."""
        project_path = f"{entity}/{project}"
        paged_runs = self.api.runs(project_path, per_page=self.runs_per_page)
        return list(paged_runs)

    def download_runs(
        self,
        entity: str,
        project: str,
        force_refresh: bool = False,
        with_history: bool = False,
    ) -> DownloaderStats:
        """Store the runs that need (re)downloading and return run counts.

        With ``with_history`` each run is stored together with its full
        history, one run at a time, calling ``progress_callback`` (1-based
        index) after each; otherwise all selected runs are stored in bulk.
        """
        wandb_runs = self.get_all_runs(entity, project)
        stored_states = self.store.get_existing_run_states(
            {"entity": entity, "project": project}
        )
        if force_refresh:
            runs_to_download = wandb_runs
        else:
            runs_to_download = select_updated_runs(wandb_runs, stored_states)

        total = len(runs_to_download)
        num_new_runs = sum(
            1 for run in runs_to_download if run.id not in stored_states
        )
        stats = DownloaderStats(
            num_wandb_runs=len(wandb_runs),
            num_stored_runs=len(stored_states),
            num_new_runs=num_new_runs,
            num_updated_runs=total - num_new_runs,
        )

        if total == 0:
            logging.info(">> No runs to download")
        elif with_history:
            logging.info(">> Downloading runs and history data together")
            for position, run in enumerate(runs_to_download, start=1):
                self.store.store_run_and_history(run, list(run.scan_history()))
                self.progress_callback(position, total, run.name)
        else:
            logging.info(">> Runs only mode, bulk downloading runs")
            self.store.store_runs(runs_to_download)
        return stats

    def download_project(
        self,
        entity: str,
        project: str,
        runs_only: bool = False,
        force_refresh: bool = False,
    ) -> DownloaderStats:
        """Download a project (history skipped when ``runs_only``) and return counts."""
        stats = self.download_runs(
            entity,
            project,
            force_refresh,
            with_history=not runs_only,
        )
        logging.info(">> Download completed")
        return stats

    def write_downloaded_to_parquet(self) -> None:
        """Export everything in the store to parquet files."""
        logging.info(">> Beginning export to parquet")
        self.store.export_to_parquet()
|
dr_wandb/fetch.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
"""Lightweight WandB fetch utilities that avoid database storage."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from collections.abc import Callable, Iterator
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
import wandb
|
|
9
|
+
import logging
|
|
10
|
+
|
|
11
|
+
from dr_wandb.history_entry_record import HistoryEntryRecord
|
|
12
|
+
from dr_wandb.run_record import RunRecord
|
|
13
|
+
from dr_wandb.utils import default_progress_callback
|
|
14
|
+
|
|
15
|
+
# Progress hook signature: (run index, total runs, run name) -> None;
# fetch_project_runs passes a 1-based index.
ProgressFn = Callable[[int, int, str], None]
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
def _iterate_runs(
    entity: str,
    project: str,
    *,
    runs_per_page: int,
) -> Iterator[wandb.apis.public.Run]:
    """Lazily yield every run of ``entity/project`` from a fresh ``wandb.Api``."""
    project_path = f"{entity}/{project}"
    client = wandb.Api()
    yield from client.runs(project_path, per_page=runs_per_page)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def serialize_run(run: wandb.apis.public.Run) -> dict[str, Any]:
    """Flatten a WandB run into a dict of scalar columns plus every data component."""
    return RunRecord.from_wandb_run(run).to_dict(include="all")
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
def serialize_history_entry(
    run: wandb.apis.public.Run, history_entry: dict[str, Any]
) -> dict[str, Any]:
    """Turn one raw history entry of ``run`` into a structured dict.

    Keys: run_id, step, timestamp, runtime, wandb_metadata, metrics (in that
    order), taken from a ``HistoryEntryRecord`` built from the entry.
    """
    record = HistoryEntryRecord.from_wandb_history(history_entry, run.id)
    exported_fields = (
        "run_id",
        "step",
        "timestamp",
        "runtime",
        "wandb_metadata",
        "metrics",
    )
    return {name: getattr(record, name) for name in exported_fields}
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def fetch_project_runs(
    entity: str,
    project: str,
    *,
    runs_per_page: int = 500,
    include_history: bool = True,
    progress_callback: ProgressFn | None = None,
) -> tuple[list[dict[str, Any]], list[list[dict[str, Any]]]]:
    """Fetch every run of ``entity/project`` as plain dicts, without a database.

    Returns:
        ``(runs, histories)``. When ``include_history`` is True,
        ``histories[i]`` holds the serialized history entries of ``runs[i]``;
        otherwise ``histories`` is an empty list.
    """
    report = progress_callback or default_progress_callback

    logging.info(">> Downloading runs, this will take a while (minutes)")
    # Materialized up front so the total is known for progress reporting.
    all_runs = list(_iterate_runs(entity, project, runs_per_page=runs_per_page))
    total = len(all_runs)
    logging.info(f" - total runs found: {total}")

    logging.info(f">> Serializing runs and maybe getting histories: {include_history}")
    serialized_runs: list[dict[str, Any]] = []
    serialized_histories: list[list[dict[str, Any]]] = []
    for position, run in enumerate(all_runs, start=1):
        serialized_runs.append(serialize_run(run))
        if include_history:
            serialized_histories.append(
                [serialize_history_entry(run, entry) for entry in run.scan_history()]
            )
        report(position, total, run.name)

    return serialized_runs, serialized_histories
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from sqlalchemy import Select, select
|
|
7
|
+
from sqlalchemy.dialects.postgresql import JSONB
|
|
8
|
+
from sqlalchemy.orm import Mapped, mapped_column
|
|
9
|
+
|
|
10
|
+
from dr_wandb.constants import Base, RunId
|
|
11
|
+
from dr_wandb.utils import extract_as_datetime
|
|
12
|
+
|
|
13
|
+
# One raw history entry (e.g. an item of ``run.scan_history()``): key -> logged
# value, including WandB's underscore-prefixed keys such as ``_step``.
type HistoryEntry = dict[str, Any]
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class HistoryEntryRecord(Base):
    """ORM row for one WandB history entry, stored in ``wandb_history``.

    ``_step``, ``_timestamp``, ``_runtime`` and ``_wandb`` map to dedicated
    columns; every key not starting with ``_`` goes into ``metrics``; any other
    underscore-prefixed keys are dropped.
    """

    __tablename__ = "wandb_history"

    # Surrogate key: a run has many history rows sharing the same run_id.
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    # Plain string column (no foreign key to ``wandb_runs``).
    run_id: Mapped[str]
    step: Mapped[int | None]
    timestamp: Mapped[datetime | None]
    # NOTE(review): WandB's ``_runtime`` may be fractional seconds while this
    # column is typed ``int`` — confirm whether values get truncated/rounded.
    runtime: Mapped[int | None]
    wandb_metadata: Mapped[dict[str, Any]] = mapped_column(JSONB)
    metrics: Mapped[dict[str, Any]] = mapped_column(JSONB)

    @classmethod
    def from_wandb_history(
        cls, history_entry: HistoryEntry, run_id: str
    ) -> HistoryEntryRecord:
        """Build an unsaved record for ``run_id`` from one raw history entry.

        Missing ``_step``/``_timestamp``/``_runtime`` become None (the timestamp
        is converted with ``extract_as_datetime``); missing ``_wandb`` becomes {}.
        """
        return cls(
            run_id=run_id,
            step=history_entry.get("_step"),
            timestamp=extract_as_datetime(history_entry, "_timestamp"),
            runtime=history_entry.get("_runtime"),
            wandb_metadata=history_entry.get("_wandb", {}),
            metrics={k: v for k, v in history_entry.items() if not k.startswith("_")},
        )

    @classmethod
    def standard_fields(cls) -> list[str]:
        """Names of all table columns except the two JSONB blobs."""
        return [
            col.name
            for col in cls.__table__.columns
            if col.name not in ["wandb_metadata", "metrics"]
        ]

    def to_dict(self, include_metadata: bool = False) -> dict[str, Any]:
        """Flatten to one dict: standard columns, then each metric as a top-level key.

        Metrics are unpacked after the standard columns, so a metric whose name
        matches a column (e.g. ``"step"``) overwrites that column's value.
        ``wandb_metadata`` is added only when ``include_metadata`` is True.
        """
        return {
            **{field: getattr(self, field) for field in self.standard_fields()},
            **self.metrics,
            **({"wandb_metadata": self.wandb_metadata} if include_metadata else {}),
        }
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def build_history_query(
    run_ids: list[RunId] | None = None,
) -> Select[HistoryEntryRecord]:
    """Build a SELECT over history rows, restricted to ``run_ids`` when given.

    ``None`` selects every row; an explicit (even empty) list filters with IN.
    """
    base_query = select(HistoryEntryRecord)
    if run_ids is None:
        return base_query
    return base_query.where(HistoryEntryRecord.run_id.in_(run_ids))
|
dr_wandb/py.typed
ADDED
|
File without changes
|
dr_wandb/run_record.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from typing import Any, Literal
|
|
5
|
+
|
|
6
|
+
import wandb
|
|
7
|
+
from sqlalchemy import Select, select
|
|
8
|
+
from sqlalchemy.dialects.postgresql import JSONB
|
|
9
|
+
from sqlalchemy.orm import Mapped, mapped_column
|
|
10
|
+
|
|
11
|
+
from dr_wandb.constants import (
|
|
12
|
+
SUPPORTED_FILTER_FIELDS,
|
|
13
|
+
Base,
|
|
14
|
+
FilterField,
|
|
15
|
+
RunId,
|
|
16
|
+
RunState,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
# Names of the JSONB payload columns on RunRecord; every other column is a
# scalar "standard" field (see RunRecord.standard_fields).
RUN_DATA_COMPONENTS = [
    "config",
    "summary",
    "wandb_metadata",
    "system_metrics",
    "system_attrs",
    "sweep_info",
]
# Value accepted by RunRecord.to_dict to request every component.
type All = Literal["all"]
# Static counterpart of RUN_DATA_COMPONENTS; keep the two in sync.
type RunDataComponent = Literal[
    "config",
    "summary",
    "wandb_metadata",
    "system_metrics",
    "system_attrs",
    "sweep_info",
]
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
class RunRecord(Base):
    """ORM row for one WandB run, stored in the ``wandb_runs`` table.

    Scalar identifying columns are followed by the JSONB columns listed in
    ``RUN_DATA_COMPONENTS``.
    """

    __tablename__ = "wandb_runs"

    run_id: Mapped[RunId] = mapped_column(primary_key=True)
    run_name: Mapped[str]
    # NOTE(review): RunState is a Literal alias — confirm how it maps to a
    # column type and whether unexpected WandB states are accepted.
    state: Mapped[RunState]
    project: Mapped[str]
    entity: Mapped[str]
    # NOTE(review): from_wandb_run stores ``wandb_run.created_at`` as-is;
    # confirm whether WandB returns a datetime or an ISO string here.
    created_at: Mapped[datetime | None]

    config: Mapped[dict[str, Any]] = mapped_column(JSONB)
    summary: Mapped[dict[str, Any]] = mapped_column(JSONB)
    wandb_metadata: Mapped[dict[str, Any]] = mapped_column(JSONB)
    system_metrics: Mapped[dict[str, Any]] = mapped_column(JSONB)
    system_attrs: Mapped[dict[str, Any]] = mapped_column(JSONB)
    sweep_info: Mapped[dict[str, Any]] = mapped_column(JSONB)

    @classmethod
    def standard_fields(cls) -> list[str]:
        """Names of the scalar columns (all columns not in RUN_DATA_COMPONENTS)."""
        return [
            col.name
            for col in cls.__table__.columns
            if col.name not in RUN_DATA_COMPONENTS
        ]

    @classmethod
    def from_wandb_run(cls, wandb_run: wandb.apis.public.Run) -> RunRecord:
        """Build an unsaved RunRecord by copying fields off a WandB public-API run.

        ``summary`` and ``system_attrs`` read private WandB attributes
        (``_json_dict``, ``_attrs``; hence the SLF001 suppressions), which may
        change between wandb releases.
        """
        return cls(
            run_id=wandb_run.id,
            run_name=wandb_run.name,
            state=wandb_run.state,
            project=wandb_run.project,
            entity=wandb_run.entity,
            created_at=wandb_run.created_at,
            config=dict(wandb_run.config),
            summary=dict(wandb_run.summary._json_dict) if wandb_run.summary else {},  # noqa: SLF001
            wandb_metadata=wandb_run.metadata or {},
            system_metrics=wandb_run.system_metrics or {},
            system_attrs=dict(wandb_run._attrs),  # noqa: SLF001
            # getattr with a default: these attributes may be absent.
            sweep_info={
                "sweep_id": getattr(wandb_run, "sweep_id", None),
                "sweep_url": getattr(wandb_run, "sweep_url", None),
            },
        )

    def update_from_wandb_run(self, wandb_run: wandb.apis.public.Run) -> None:
        """Overwrite every column except ``run_id`` with fresh values from ``wandb_run``."""
        updated = self.__class__.from_wandb_run(wandb_run)
        for col in self.__table__.columns:
            if col.name != "run_id":
                setattr(self, col.name, getattr(updated, col.name))

    def to_dict(
        self, include: list[RunDataComponent] | All | None = None
    ) -> dict[str, Any]:
        """Return the scalar columns plus the requested JSONB components.

        Args:
            include: Component names to add, ``"all"`` for every component, or
                None/empty for scalar columns only.

        Raises:
            AssertionError: if an unknown component is requested. This is an
                ``assert``, so the check is skipped under ``python -O``. A bare
                string other than ``"all"`` is iterated character by character
                and therefore fails the check.
        """
        include = include or []
        if include == "all":
            include = RUN_DATA_COMPONENTS
        assert all(field in RUN_DATA_COMPONENTS for field in include)
        data = {k: getattr(self, k) for k in self.standard_fields()}
        for field in include:
            data[field] = getattr(self, field)
        return data
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def build_run_query(kwargs: dict[FilterField, Any] | None = None) -> Select[RunRecord]:
    """Build a SELECT over runs filtered by the supported keys present in ``kwargs``.

    ``project``/``entity``/``state`` filter by equality and ``run_ids`` by IN.
    Unsupported keys or None values fail an ``assert`` (skipped under -O).
    """
    query = select(RunRecord)
    if kwargs is None:
        return query

    assert all(k in SUPPORTED_FILTER_FIELDS for k in kwargs)
    assert all(v is not None for v in kwargs.values())

    equality_filters = (
        ("project", RunRecord.project),
        ("entity", RunRecord.entity),
        ("state", RunRecord.state),
    )
    for key, column in equality_filters:
        if key in kwargs:
            query = query.where(column == kwargs[key])
    if "run_ids" in kwargs:
        query = query.where(RunRecord.run_id.in_(kwargs["run_ids"]))
    return query
|
dr_wandb/store.py
ADDED
|
@@ -0,0 +1,193 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Any
|
|
6
|
+
from urllib.parse import urlparse
|
|
7
|
+
|
|
8
|
+
import pandas as pd
|
|
9
|
+
import wandb
|
|
10
|
+
from sqlalchemy import Engine, create_engine, text
|
|
11
|
+
from sqlalchemy.exc import OperationalError
|
|
12
|
+
from sqlalchemy.orm import Session
|
|
13
|
+
|
|
14
|
+
from dr_wandb.constants import (
|
|
15
|
+
Base,
|
|
16
|
+
FilterField,
|
|
17
|
+
RunId,
|
|
18
|
+
RunState,
|
|
19
|
+
)
|
|
20
|
+
from dr_wandb.history_entry_record import (
|
|
21
|
+
HistoryEntry,
|
|
22
|
+
HistoryEntryRecord,
|
|
23
|
+
build_history_query,
|
|
24
|
+
)
|
|
25
|
+
from dr_wandb.run_record import (
|
|
26
|
+
RUN_DATA_COMPONENTS,
|
|
27
|
+
All,
|
|
28
|
+
RunDataComponent,
|
|
29
|
+
RunRecord,
|
|
30
|
+
build_run_query,
|
|
31
|
+
)
|
|
32
|
+
from dr_wandb.utils import safe_convert_for_parquet
|
|
33
|
+
|
|
34
|
+
# Default parquet export directory: a ``data`` directory next to the
# ``dr_wandb`` package directory (this module's grandparent directory).
# NOTE(review): in an installed wheel this resolves inside site-packages — confirm.
DEFAULT_OUTPUT_DIR = Path(__file__).parent.parent / "data"
# Base file names (``.parquet`` is appended) used by ProjectStore.export_to_parquet.
DEFAULT_RUNS_FILENAME = "runs_metadata"
DEFAULT_HISTORY_FILENAME = "runs_history"
# A run's full history: the list of its raw history entries.
type History = list[HistoryEntry]
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def delete_history_for_runs(session: Session, run_ids: list[RunId]) -> None:
    """Delete every ``wandb_history`` row whose ``run_id`` is in ``run_ids``.

    Issues a single PostgreSQL ``= ANY(:run_ids)`` statement (one array
    parameter regardless of list size) in the caller's session without
    committing. An empty list is a no-op.
    """
    if run_ids:
        statement = text("DELETE FROM wandb_history WHERE run_id = ANY(:run_ids)")
        session.execute(statement, {"run_ids": run_ids})
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
def save_update_run(session: Session, run: wandb.apis.public.Run) -> None:
    """Upsert ``run`` in the session without committing.

    Updates the stored RunRecord in place when one exists for ``run.id``,
    otherwise adds a new one.
    """
    stored_record = session.get(RunRecord, run.id)
    if stored_record is None:
        session.add(RunRecord.from_wandb_run(run))
        return
    stored_record.update_from_wandb_run(run)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def delete_add_history(session: Session, run_id: RunId, history: History) -> None:
    """Replace the stored history of ``run_id`` with ``history`` (not committed)."""
    delete_history_for_runs(session, [run_id])
    session.add_all(
        HistoryEntryRecord.from_wandb_history(entry, run_id) for entry in history
    )
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def ensure_database_exists(database_url: str) -> str:
    """Make sure the PostgreSQL database named in ``database_url`` exists.

    Connects once. If the server reports that this database does not exist,
    connects to the ``postgres`` maintenance database on the same server (same
    credentials) and issues ``CREATE DATABASE``. Any other connection error
    propagates.

    Returns:
        ``database_url`` unchanged.

    Raises:
        sqlalchemy.exc.OperationalError: for connection failures other than a
            missing database.
    """
    # Parse the URL instead of doing text replacement: replacing "/<db>" with
    # "/postgres" also rewrote a username equal to the database name
    # (e.g. ``postgresql://wandb@host/wandb``).
    from sqlalchemy.engine import make_url

    url = make_url(database_url)
    db_name = url.database or ""

    test_engine = create_engine(url)
    try:
        with test_engine.connect():
            pass
        return database_url
    except OperationalError as e:
        # Only a missing *database* is recoverable here; a missing role, for
        # example, also says "does not exist" and must propagate.
        if not db_name or f'database "{db_name}" does not exist' not in str(e):
            raise
    finally:
        # Release pooled connections; this engine is not reused.
        test_engine.dispose()

    logging.info(f"Database '{db_name}' doesn't exist, creating it...")
    # CREATE DATABASE cannot run inside a transaction block; use an AUTOCOMMIT
    # connection rather than the driver-specific "COMMIT" workaround.
    admin_engine = create_engine(
        url.set(database="postgres"), isolation_level="AUTOCOMMIT"
    )
    try:
        with admin_engine.connect() as conn:
            # Escape embedded double quotes in the quoted identifier.
            quoted_name = db_name.replace('"', '""')
            conn.execute(text(f'CREATE DATABASE "{quoted_name}"'))
    finally:
        admin_engine.dispose()
    logging.info(f"Created database '{db_name}'")
    return database_url
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class ProjectStore:
    """PostgreSQL-backed store for WandB runs and their history.

    Runs live in ``wandb_runs`` (``RunRecord``) and history entries in
    ``wandb_history`` (``HistoryEntryRecord``). ``export_to_parquet`` writes
    both tables to parquet files under ``output_dir``.
    """

    def __init__(
        self, connection_string: str, output_dir: str | Path | None = None
    ) -> None:
        """Connect (creating the database if missing), create tables, set output dir.

        Args:
            connection_string: SQLAlchemy database URL.
            output_dir: Directory for parquet exports, as ``str`` or ``Path``;
                defaults to ``DEFAULT_OUTPUT_DIR``.
        """
        connection_string = ensure_database_exists(connection_string)
        self.engine: Engine = create_engine(connection_string)
        self.create_tables()
        # Always hold a Path: export_to_parquet uses ``/`` and ``mkdir``,
        # which a plain ``str`` (allowed by the annotation) does not support.
        self.output_dir: Path = (
            Path(output_dir) if output_dir is not None else DEFAULT_OUTPUT_DIR
        )

    def create_tables(self) -> None:
        """Create the ORM tables registered on ``Base.metadata`` if they are missing."""
        Base.metadata.create_all(self.engine)

    def store_run(self, run: wandb.apis.public.Run) -> None:
        """Upsert one run and commit."""
        with Session(self.engine) as session:
            save_update_run(session, run)
            session.commit()

    def store_runs(self, runs: list[wandb.apis.public.Run]) -> None:
        """Upsert several runs in a single transaction."""
        with Session(self.engine) as session:
            for run in runs:
                save_update_run(session, run)
            session.commit()

    def store_history(self, run_id: RunId, history: History) -> None:
        """Replace the stored history of one run and commit."""
        with Session(self.engine) as session:
            delete_add_history(session, run_id, history)
            session.commit()

    def store_histories(
        self,
        runs: list[wandb.apis.public.Run],
        histories: list[History],
    ) -> None:
        """Replace the histories of several runs in one transaction.

        ``histories[i]`` belongs to ``runs[i]``; the lengths must match.
        """
        assert len(runs) == len(histories)
        run_ids = [run.id for run in runs]
        with Session(self.engine) as session:
            delete_history_for_runs(session, run_ids)
            for run_id, history in zip(run_ids, histories, strict=False):
                for history_entry in history:
                    session.add(
                        HistoryEntryRecord.from_wandb_history(history_entry, run_id)
                    )
            session.commit()

    def store_run_and_history(
        self, run: wandb.apis.public.Run, history: History
    ) -> None:
        """Replace a run's history and upsert the run in one transaction."""
        with Session(self.engine) as session:
            delete_add_history(session, run.id, history)
            save_update_run(session, run)
            session.commit()

    def get_runs_df(
        self,
        include: list[RunDataComponent] | All | None = None,
        kwargs: dict[FilterField, Any] | None = None,
    ) -> pd.DataFrame:
        """Return runs as a DataFrame.

        ``include`` selects JSONB components (see ``RunRecord.to_dict``) and
        ``kwargs`` filters rows (see ``build_run_query``).
        """
        with Session(self.engine) as session:
            result = session.execute(build_run_query(kwargs=kwargs))
            return pd.DataFrame(
                [run.to_dict(include=include) for run in result.scalars().all()]
            )

    def get_history_df(
        self,
        include_metadata: bool = False,
        run_ids: list[RunId] | None = None,
    ) -> pd.DataFrame:
        """Return history rows (optionally for ``run_ids`` only) as a DataFrame."""
        with Session(self.engine) as session:
            result = session.execute(build_history_query(run_ids=run_ids))
            return pd.DataFrame(
                [
                    history.to_dict(include_metadata=include_metadata)
                    for history in result.scalars().all()
                ]
            )

    def get_existing_run_states(
        self, kwargs: dict[FilterField, Any] | None = None
    ) -> dict[RunId, RunState]:
        """Map ``run_id`` -> stored state for runs matching ``kwargs``."""
        with Session(self.engine) as session:
            result = session.execute(build_run_query(kwargs=kwargs))
            return {run.run_id: run.state for run in result.scalars().all()}

    def export_to_parquet(
        self,
        runs_filename: str = DEFAULT_RUNS_FILENAME,
        history_filename: str = DEFAULT_HISTORY_FILENAME,
    ) -> None:
        """Write the history table and the runs table to parquet under ``output_dir``.

        Produces ``<history_filename>.parquet``, one
        ``<runs_filename>_<component>.parquet`` per data component (scalar
        columns plus that component), and ``<runs_filename>.parquet`` with all
        components. Empty tables produce no files.
        """
        # parents=True so a nested, not-yet-existing output path also works.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        logging.info(f">> Using data output directory: {self.output_dir}")

        history_df = self.get_history_df()
        if not history_df.empty:
            history_path = self.output_dir / f"{history_filename}.parquet"
            history_df = safe_convert_for_parquet(history_df)
            history_df.to_parquet(history_path, engine="pyarrow", index=False)
            logging.info(f">> Wrote history_df with to {history_path}".replace(" with to ", " to "))

        # Query the runs table once and slice per component instead of
        # re-querying and rebuilding every row for each component.
        runs_df_full = self.get_runs_df(include="all")
        if runs_df_full.empty:
            return
        standard_columns = RunRecord.standard_fields()
        for include_type in RUN_DATA_COMPONENTS:
            runs_df = runs_df_full[[*standard_columns, include_type]]
            runs_path = self.output_dir / f"{runs_filename}_{include_type}.parquet"
            runs_df = safe_convert_for_parquet(runs_df)
            runs_df.to_parquet(runs_path, engine="pyarrow", index=False)
            logging.info(f">> Wrote runs_df with {include_type} to {runs_path}")

        runs_path = self.output_dir / f"{runs_filename}.parquet"
        runs_df_full = safe_convert_for_parquet(runs_df_full)
        runs_df_full.to_parquet(runs_path, engine="pyarrow", index=False)
        logging.info(f">> Wrote runs_df with all parts to {runs_path}")
|
dr_wandb/utils.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
import pandas as pd
|
|
7
|
+
import wandb
|
|
8
|
+
|
|
9
|
+
from dr_wandb.constants import MAX_INT, RunId, RunState
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def extract_as_datetime(data: dict[str, Any], key: str) -> datetime | None:
|
|
13
|
+
timestamp = data.get(key)
|
|
14
|
+
return datetime.fromtimestamp(timestamp) if timestamp is not None else None
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def select_updated_runs(
    all_runs: list[wandb.apis.public.Run],
    existing_run_states: dict[RunId, RunState],
) -> list[wandb.apis.public.Run]:
    """Keep only the runs that still need to be (re)downloaded.

    A run is kept when its ``id`` is not a key of ``existing_run_states``,
    or when its stored state is exactly ``"running"``. The input order is
    preserved.
    """

    def needs_refresh(run) -> bool:
        if run.id not in existing_run_states:
            return True
        return existing_run_states[run.id] == "running"

    return list(filter(needs_refresh, all_runs))
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def default_progress_callback(run_index: int, total_runs: int, message: str) -> None:
    """Log one progress line of the form ``>> {run_index}/{total_runs}: {message}``.

    The line is emitted at INFO level through the root ``logging`` module.
    """
    progress = f"{run_index}/{total_runs}"
    logging.info(f">> {progress}: {message}")
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def convert_large_ints_in_data(data: Any, max_int: int = MAX_INT) -> Any:
|
|
33
|
+
if isinstance(data, dict):
|
|
34
|
+
return {k: convert_large_ints_in_data(v, max_int) for k, v in data.items()}
|
|
35
|
+
elif isinstance(data, list):
|
|
36
|
+
return [convert_large_ints_in_data(item, max_int) for item in data]
|
|
37
|
+
elif isinstance(data, int) and abs(data) > max_int:
|
|
38
|
+
return float(data)
|
|
39
|
+
return data
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def safe_convert_for_parquet(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``df`` whose columns can be written to Parquet safely.

    - ``int64`` columns that contain any value outside ``[-MAX_INT, MAX_INT]``
      are cast to ``float64``.
    - ``object`` columns are converted element by element:
        - dicts and lists are JSON-encoded, after oversized ints are converted
          by ``convert_large_ints_in_data``; non-JSON values are stringified
          with ``str``;
        - missing scalars (``None``, ``NaN``, ``pd.NA``, ``NaT``) stay ``None``;
        - every other value becomes ``str(value)``.

    The input frame is not modified.
    """

    def _to_text(value: Any) -> str | None:
        if isinstance(value, dict | list):
            return json.dumps(convert_large_ints_in_data(value), default=str)
        # Keep missing values null. str() would turn NaN/pd.NA/NaT into the
        # literal strings "nan"/"<NA>"/"NaT". is_scalar guards pd.isna from
        # array-likes, where it returns an array rather than a bool.
        if value is None or (pd.api.types.is_scalar(value) and pd.isna(value)):
            return None
        return str(value)

    df = df.copy()
    for col in df.columns:
        series = df[col]
        if series.dtype == "int64":
            # Check both bounds instead of using abs(): abs() of the minimum
            # int64 overflows and stays negative, so it would slip past the check.
            if ((series > MAX_INT) | (series < -MAX_INT)).any():
                df[col] = series.astype("float64")
        elif series.dtype == "object":
            df[col] = series.apply(_to_text)
    return df
|
|
@@ -0,0 +1,179 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: dr-wandb
|
|
3
|
+
Version: 0.1.2
|
|
4
|
+
Summary: Interact with wandb from python
|
|
5
|
+
Author-email: Danielle Rothermel <danielle.rothermel@gmail.com>
|
|
6
|
+
License-File: LICENSE
|
|
7
|
+
Requires-Python: >=3.12
|
|
8
|
+
Requires-Dist: pandas>=2.3.2
|
|
9
|
+
Requires-Dist: pyarrow>=21.0.0
|
|
10
|
+
Requires-Dist: sqlalchemy>=2.0.43
|
|
11
|
+
Requires-Dist: typer>=0.20.0
|
|
12
|
+
Requires-Dist: wandb>=0.21.4
|
|
13
|
+
Description-Content-Type: text/markdown
|
|
14
|
+
|
|
15
|
+
# dr_wandb
|
|
16
|
+
|
|
17
|
+
A command-line utility for downloading and archiving Weights & Biases experiment data to local storage formats optimized for offline analysis.
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
## Installation
|
|
21
|
+
|
|
22
|
+
CLI tool install (provides the `wandb-download` command):
|
|
23
|
+
```
|
|
24
|
+
uv tool install dr_wandb
|
|
25
|
+
```
|
|
26
|
+
|
|
27
|
+
Or, to use the library functions:
|
|
28
|
+
```bash
|
|
29
|
+
# To use the library functions
|
|
30
|
+
uv add dr_wandb
|
|
31
|
+
# Optionally
|
|
32
|
+
uv add "dr_wandb[postgres]"
|
|
33
|
+
uv sync
|
|
34
|
+
```
|
|
35
|
+
|
|
36
|
+
### Authentication
|
|
37
|
+
|
|
38
|
+
Configure Weights & Biases authentication using one of these methods:
|
|
39
|
+
|
|
40
|
+
```bash
|
|
41
|
+
wandb login
|
|
42
|
+
```
|
|
43
|
+
|
|
44
|
+
Or set the API key as an environment variable:
|
|
45
|
+
|
|
46
|
+
```bash
|
|
47
|
+
export WANDB_API_KEY=your_api_key_here
|
|
48
|
+
```
|
|
49
|
+
|
|
50
|
+
## Quickstart
|
|
51
|
+
|
|
52
|
+
The default approach doesn't use Postgres. It fetches the runs (and, optionally, their histories) and dumps them to local pickle (`.pkl`) files.
|
|
53
|
+
|
|
54
|
+
```bash
|
|
55
|
+
» wandb-download --help
|
|
56
|
+
|
|
57
|
+
Usage: wandb-download [OPTIONS] ENTITY PROJECT OUTPUT_DIR
|
|
58
|
+
|
|
59
|
+
╭─ Arguments ──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
|
|
60
|
+
│ * entity TEXT [required] │
|
|
61
|
+
│ * project TEXT [required] │
|
|
62
|
+
│ * output_dir TEXT [required] │
|
|
63
|
+
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
|
|
64
|
+
╭─ Options ────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╮
|
|
65
|
+
│ --runs-only --no-runs-only [default: no-runs-only] │
|
|
66
|
+
│ --runs-per-page INTEGER [default: 500] │
|
|
67
|
+
│ --log-every INTEGER [default: 20] │
|
|
68
|
+
│ --install-completion Install completion for the current shell. │
|
|
69
|
+
│ --show-completion Show completion for the current shell, to copy it or customize the installation. │
|
|
70
|
+
│ --help Show this message and exit. │
|
|
71
|
+
╰──────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────╯
|
|
72
|
+
```
|
|
73
|
+
|
|
74
|
+
An example:
|
|
75
|
+
```bash
|
|
76
|
+
» wandb-download --runs-only "ml-moe" "ft-scaling" "./data"
|
|
77
|
+
2025-11-10 21:47:54 - INFO -
|
|
78
|
+
:: Beginning Dr. Wandb Project Downloading Tool ::
|
|
79
|
+
|
|
80
|
+
2025-11-10 21:47:54 - INFO - {
|
|
81
|
+
"entity": "ml-moe",
|
|
82
|
+
"project": "ft-scaling",
|
|
83
|
+
"output_dir": "data",
|
|
84
|
+
"runs_only": true,
|
|
85
|
+
"runs_per_page": 500,
|
|
86
|
+
"log_every": 20,
|
|
87
|
+
"runs_output_filename": "ml-moe_ft-scaling_runs.pkl",
|
|
88
|
+
"histories_output_filename": "ml-moe_ft-scaling_histories.pkl"
|
|
89
|
+
}
|
|
90
|
+
2025-11-10 21:47:54 - INFO -
|
|
91
|
+
2025-11-10 21:47:54 - INFO - >> Downloading runs, this will take a while (minutes)
|
|
92
|
+
wandb: Currently logged in as: danielle-rothermel (ml-moe) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
|
|
93
|
+
2025-11-10 21:48:00 - INFO - - total runs found: 517
|
|
94
|
+
2025-11-10 21:48:00 - INFO - >> Serializing runs and maybe getting histories: False
|
|
95
|
+
2025-11-10 21:48:07 - INFO - >> 20/517: 2025_08_21-08_24_43_test_finetune_DD-dolma1_7-10M_main_1Mtx1_--learning_rate=5e-05
|
|
96
|
+
2025-11-10 21:48:12 - INFO - >> 40/517: 2025_08_21-08_24_43_test_finetune_DD-dolma1_7-150M_main_10Mtx1_--learning_rate=5e-06
|
|
97
|
+
...
|
|
98
|
+
2025-11-10 21:50:46 - INFO - >> Dumped runs data to: ./data/ml-moe_ft-scaling_runs.pkl
|
|
99
|
+
2025-11-10 21:50:46 - INFO - >> Runs only, not dumping histories to: ./data/ml-moe_ft-scaling_histories.pkl
|
|
100
|
+
```
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
## Very Alpha: Postgres Version
|
|
105
|
+
|
|
106
|
+
**It's very likely this won't currently work.** Download all runs from a Weights & Biases project:
|
|
107
|
+
|
|
108
|
+
```bash
|
|
109
|
+
uv run python src/dr_wandb/cli/postgres_download.py --entity your_entity --project your_project
|
|
110
|
+
|
|
111
|
+
Options:
|
|
112
|
+
--entity TEXT WandB entity (username or team name)
|
|
113
|
+
--project TEXT WandB project name
|
|
114
|
+
--runs-only Download only run metadata, skip training history
|
|
115
|
+
--force-refresh Download all data, ignoring existing records
|
|
116
|
+
--db-url TEXT PostgreSQL connection string
|
|
117
|
+
--output-dir TEXT Directory for exported Parquet files
|
|
118
|
+
--help Show help message and exit
|
|
119
|
+
```
|
|
120
|
+
|
|
121
|
+
The tool creates a PostgreSQL database, downloads experiment data, and exports Parquet files to the configured output directory. The tool tracks existing data and, by default, downloads only new or updated runs. A run is considered for update if:
|
|
122
|
+
|
|
123
|
+
- It does not exist in the local database
|
|
124
|
+
- Its state is "running" (indicating potential new data)
|
|
125
|
+
|
|
126
|
+
Use `--force-refresh` to download all runs regardless of existing data.
|
|
127
|
+
|
|
128
|
+
### Environment Variables
|
|
129
|
+
|
|
130
|
+
The tool reads configuration from environment variables with the `DR_WANDB_` prefix and supports `.env` files:
|
|
131
|
+
|
|
132
|
+
| Variable | Description | Default |
|
|
133
|
+
|----------|-------------|---------|
|
|
134
|
+
| `DR_WANDB_ENTITY` | Weights & Biases entity name | None |
|
|
135
|
+
| `DR_WANDB_PROJECT` | Weights & Biases project name | None |
|
|
136
|
+
| `DR_WANDB_DATABASE_URL` | PostgreSQL connection string | `postgresql+psycopg2://localhost/wandb` |
|
|
137
|
+
| `DR_WANDB_OUTPUT_DIR` | Directory for exported files | `./data` |
|
|
138
|
+
|
|
139
|
+
### Database Configuration
|
|
140
|
+
|
|
141
|
+
The PostgreSQL connection string follows the standard format:
|
|
142
|
+
|
|
143
|
+
```
|
|
144
|
+
postgresql+psycopg2://username:password@host:port/database_name
|
|
145
|
+
```
|
|
146
|
+
|
|
147
|
+
If the specified database does not exist, the tool will attempt to create it automatically.
|
|
148
|
+
|
|
149
|
+
### Data Schema
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
The tool generates the following files in the output directory:
|
|
153
|
+
|
|
154
|
+
- `runs_metadata.parquet` - Complete run metadata including configurations, summaries, and system information
|
|
155
|
+
- `runs_history.parquet` - Training metrics and logged values over time
|
|
156
|
+
- `runs_metadata_{component}.parquet` - Component-specific files for config, summary, wandb_metadata, system_metrics, system_attrs, and sweep_info
|
|
157
|
+
|
|
158
|
+
|
|
159
|
+
**Run Records**
|
|
160
|
+
- **run_id**: Unique identifier for the experiment run
|
|
161
|
+
- **run_name**: Human-readable name assigned to the run
|
|
162
|
+
- **state**: Current state (finished, running, crashed, failed, killed)
|
|
163
|
+
- **project**: Project name
|
|
164
|
+
- **entity**: Entity name
|
|
165
|
+
- **created_at**: Timestamp of run creation
|
|
166
|
+
- **config**: Experiment configuration parameters (JSONB)
|
|
167
|
+
- **summary**: Final metrics and outputs (JSONB)
|
|
168
|
+
- **wandb_metadata**: Platform-specific metadata (JSONB)
|
|
169
|
+
- **system_metrics**: Hardware and system information (JSONB)
|
|
170
|
+
- **system_attrs**: Additional system attributes (JSONB)
|
|
171
|
+
- **sweep_info**: Hyperparameter sweep information (JSONB)
|
|
172
|
+
|
|
173
|
+
**Training History Records**
|
|
174
|
+
- **run_id**: Reference to the parent run
|
|
175
|
+
- **step**: Training step number
|
|
176
|
+
- **timestamp**: Time of metric logging
|
|
177
|
+
- **runtime**: Elapsed time since run start
|
|
178
|
+
- **wandb_metadata**: Platform logging metadata (JSONB)
|
|
179
|
+
- **metrics**: All logged metrics and values (JSONB, flattened in Parquet export)
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
dr_wandb/__init__.py,sha256=C1FWh869zNWF5XU4XKyGfBPJ2hVpW_UsawIIusfXuQQ,199
|
|
2
|
+
dr_wandb/constants.py,sha256=HuIDOe_MRp2BTTuD1uyVzPJPUm3DbDQDIDk7HNltspc,608
|
|
3
|
+
dr_wandb/downloader.py,sha256=X-NN1A1GilnUoxdEyHCsKJolqGBke_dIzS5wJWEAvvE,3888
|
|
4
|
+
dr_wandb/fetch.py,sha256=wtpY78-VeNjCjro4Ata0N6-uV6neqBq8ooLPxJgXE7k,2533
|
|
5
|
+
dr_wandb/history_entry_record.py,sha256=ni9rXhYWxOg2kdidQ4norYAK37tGWj8xaB8R_lU4tw0,2010
|
|
6
|
+
dr_wandb/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
7
|
+
dr_wandb/run_record.py,sha256=X3fmNOhsBfJsaCZEWhOtnulI16mXAYiH8K194CPdjfk,3794
|
|
8
|
+
dr_wandb/store.py,sha256=gWvlC0NIjcKeRP1rZooBz6dDq2nS2wIudsgahLso3VM,7063
|
|
9
|
+
dr_wandb/utils.py,sha256=zzpHVOVo0QD82ik9ksQCP_vN7Zw0ov9dPGFfNMFgfmg,1796
|
|
10
|
+
dr_wandb/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
|
+
dr_wandb/cli/download.py,sha256=V_V1q5HPXbnBorS0l1gMJlVRg4QlVl5LYL1K7-_j--s,2870
|
|
12
|
+
dr_wandb/cli/postgres_download.py,sha256=XvUY8Jl2u9BGo1l8QXn0foEFi5a3SfCARbvzZ-HxPoA,3710
|
|
13
|
+
dr_wandb-0.1.2.dist-info/METADATA,sha256=EI1oFoFETRG-3eYnzrdJBBUtuRulOEChqhEJE3HU4co,9250
|
|
14
|
+
dr_wandb-0.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
15
|
+
dr_wandb-0.1.2.dist-info/entry_points.txt,sha256=BATf5eJjnFMRULrNGiXfzL3ImYPdNK-MlatSzOFrtII,61
|
|
16
|
+
dr_wandb-0.1.2.dist-info/licenses/LICENSE,sha256=6tUm1Q55M1UBMbbawzFlF0-DgCazM1BELo_5-RXA1K4,1075
|
|
17
|
+
dr_wandb-0.1.2.dist-info/RECORD,,
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Danielle Rothermel
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|