dr_wandb-0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
dr_wandb/__init__.py ADDED
@@ -0,0 +1,9 @@
+ """dr_wandb public API."""
+
+ from .fetch import fetch_project_runs, serialize_history_entry, serialize_run
+
+ __all__ = [
+     "fetch_project_runs",
+     "serialize_history_entry",
+     "serialize_run",
+ ]
dr_wandb/cli/__init__.py ADDED
File without changes
dr_wandb/cli/download.py ADDED
@@ -0,0 +1,97 @@
+ import logging
+ import pickle
+ from pathlib import Path
+ from typing import Any
+
+ import typer
+ from pydantic import BaseModel, Field, computed_field
+
+ from dr_wandb.fetch import fetch_project_runs
+
+ app = typer.Typer()
+
+
+ class ProjDownloadConfig(BaseModel):
+     entity: str
+     project: str
+     output_dir: Path = Field(
+         default_factory=lambda: Path(__file__).parent.parent.parent.parent / "data"
+     )
+     runs_only: bool = False
+     runs_per_page: int = 500
+     log_every: int = 20
+
+     runs_output_filename: str = Field(
+         default_factory=lambda data: f"{data['entity']}_{data['project']}_runs.pkl"
+     )
+     histories_output_filename: str = Field(
+         default_factory=lambda data: f"{data['entity']}_{data['project']}_histories.pkl"
+     )
+
+     def progress_callback(self, run_index: int, total_runs: int, message: str) -> None:
+         if run_index % self.log_every == 0:
+             logging.info(f">> {run_index}/{total_runs}: {message}")
+
+     @computed_field
+     @property
+     def fetch_runs_cfg(self) -> dict[str, Any]:
+         return {
+             "entity": self.entity,
+             "project": self.project,
+             "runs_per_page": self.runs_per_page,
+             "progress_callback": self.progress_callback,
+             "include_history": not self.runs_only,
+         }
+
+
+ def setup_logging(level: str = "INFO") -> None:
+     logging.basicConfig(
+         level=getattr(logging, level.upper()),
+         format="%(asctime)s - %(levelname)s - %(message)s",
+         datefmt="%Y-%m-%d %H:%M:%S",
+     )
+
+
+ @app.command()
+ def download_project(
+     entity: str,
+     project: str,
+     output_dir: str,
+     runs_only: bool = False,
+     runs_per_page: int = 500,
+     log_every: int = 20,
+ ) -> None:
+     setup_logging()
+     logging.info("\n:: Beginning Dr. Wandb Project Downloading Tool ::\n")
+
+     cfg = ProjDownloadConfig(
+         entity=entity,
+         project=project,
+         output_dir=output_dir,
+         runs_only=runs_only,
+         runs_per_page=runs_per_page,
+         log_every=log_every,
+     )
+     # exclude must be a set: the computed field holds a bound method,
+     # which cannot be JSON-serialized.
+     logging.info(str(cfg.model_dump_json(indent=4, exclude={"fetch_runs_cfg"})))
+     logging.info("")
+
+     runs, histories = fetch_project_runs(**cfg.fetch_runs_cfg)
+     runs_filename = f"{output_dir}/{cfg.runs_output_filename}"
+     histories_filename = f"{output_dir}/{cfg.histories_output_filename}"
+     with open(runs_filename, "wb") as run_file:
+         pickle.dump(runs, run_file)
+     logging.info(f">> Dumped runs data to: {runs_filename}")
+     if not cfg.runs_only:
+         with open(histories_filename, "wb") as hist_file:
+             pickle.dump(histories, hist_file)
+         logging.info(f">> Dumped histories data to: {histories_filename}")
+     else:
+         logging.info(f">> Runs only, not dumping histories to: {histories_filename}")
+
+
+ if __name__ == "__main__":
+     app()
dr_wandb/cli/postgres_download.py ADDED
@@ -0,0 +1,128 @@
+ import logging
+ from pathlib import Path
+
+ import click
+ from pydantic_settings import BaseSettings, SettingsConfigDict
+
+ from dr_wandb.downloader import Downloader
+ from dr_wandb.store import ProjectStore
+
+
+ class ProjDownloadSettings(BaseSettings):
+     model_config = SettingsConfigDict(env_file=".env", env_prefix="DR_WANDB_")
+
+     entity: str | None = None
+     project: str | None = None
+     database_url: str = "postgresql+psycopg2://localhost/wandb"
+     output_dir: Path = Path(__file__).parent.parent / "data"
+     runs_per_page: int = 500
+
+
+ def setup_logging(level: str = "INFO") -> None:
+     logging.basicConfig(
+         level=getattr(logging, level.upper()),
+         format="%(asctime)s - %(levelname)s - %(message)s",
+         datefmt="%Y-%m-%d %H:%M:%S",
+     )
+
+
+ def validate_settings(entity: str | None, project: str | None) -> None:
+     if not entity:
+         raise click.ClickException(
+             "--entity is required, or set DR_WANDB_ENTITY in .env"
+         )
+     if not project:
+         raise click.ClickException(
+             "--project is required, or set DR_WANDB_PROJECT in .env"
+         )
+
+
+ def resolve_config(
+     entity: str | None,
+     project: str | None,
+     db_url: str | None,
+     output_dir: str | None,
+ ) -> ProjDownloadSettings:
+     cfg = ProjDownloadSettings()
+     final_entity = entity if entity else cfg.entity
+     final_project = project if project else cfg.project
+     final_db_url = db_url if db_url else cfg.database_url
+     final_output_dir = output_dir if output_dir else cfg.output_dir
+     validate_settings(final_entity, final_project)
+     return ProjDownloadSettings(
+         entity=final_entity,
+         project=final_project,
+         database_url=final_db_url,
+         output_dir=final_output_dir,
+         runs_per_page=cfg.runs_per_page,
+     )
+
+
+ def execute_download(
+     cfg: ProjDownloadSettings, runs_only: bool, force_refresh: bool
+ ) -> Downloader:
+     store = ProjectStore(
+         cfg.database_url,
+         output_dir=cfg.output_dir,
+     )
+     downloader = Downloader(store, runs_per_page=cfg.runs_per_page)
+     click.echo(">> Beginning download:")
+     stats = downloader.download_project(
+         entity=cfg.entity,
+         project=cfg.project,
+         runs_only=runs_only,
+         force_refresh=force_refresh,
+     )
+     click.echo(str(stats))
+     return downloader
+
+
+ @click.command()
+ @click.option(
+     "--entity",
+     envvar="DR_WANDB_ENTITY",
+     help="WandB entity (username or team name)",
+ )
+ @click.option("--project", envvar="DR_WANDB_PROJECT", help="WandB project name")
+ @click.option(
+     "--runs-only",
+     is_flag=True,
+     help="Only download runs, don't download history",
+ )
+ @click.option(
+     "--force-refresh",
+     is_flag=True,
+     help="Force refresh, download all data",
+ )
+ @click.option(
+     "--db-url",
+     envvar="DR_WANDB_DATABASE_URL",
+     help="PostgreSQL connection string",
+ )
+ @click.option(
+     "--output-dir",
+     envvar="DR_WANDB_OUTPUT_DIR",
+     help="Output directory",
+ )
+ def download_project(
+     entity: str | None,
+     project: str | None,
+     runs_only: bool,
+     force_refresh: bool,
+     db_url: str | None,
+     output_dir: str | None,
+ ) -> None:
+     setup_logging()
+     click.echo("\n:: Beginning Dr. Wandb Project Downloading Tool ::\n")
+     cfg = resolve_config(entity, project, db_url, output_dir)
+     click.echo(f">> Downloading project {cfg.entity}/{cfg.project}")
+     click.echo(f">> Database: {cfg.database_url}")
+     click.echo(f">> Output directory: {cfg.output_dir}")
+     click.echo(f">> Force refresh: {force_refresh} Runs only: {runs_only}")
+     click.echo()
+     downloader = execute_download(cfg, runs_only, force_refresh)
+     downloader.write_downloaded_to_parquet()
+
+
+ if __name__ == "__main__":
+     download_project()
dr_wandb/constants.py ADDED
@@ -0,0 +1,23 @@
+ from collections.abc import Callable
+ from typing import Literal
+
+ from sqlalchemy import String
+ from sqlalchemy.orm import DeclarativeBase
+
+
+ class Base(DeclarativeBase):
+     pass
+
+
+ MAX_INT = 2**31 - 1
+
+ SUPPORTED_FILTER_FIELDS = ["project", "entity", "state", "run_ids"]
+ type FilterField = Literal["project", "entity", "state", "run_ids"]
+
+ WANDB_RUN_STATES = ["finished", "running", "crashed", "failed", "killed"]
+ type RunState = Literal["finished", "running", "crashed", "failed", "killed"]
+ type RunId = str
+
+ Base.type_annotation_map = {RunId: String}
+
+ type ProgressCallback = Callable[[int, int, str], None]
dr_wandb/downloader.py ADDED
@@ -0,0 +1,118 @@
+ from __future__ import annotations
+
+ import logging
+ from dataclasses import dataclass
+
+ import wandb
+
+ from dr_wandb.constants import ProgressCallback
+ from dr_wandb.store import ProjectStore
+ from dr_wandb.utils import default_progress_callback, select_updated_runs
+
+
+ @dataclass
+ class DownloaderStats:
+     num_wandb_runs: int = 0
+     num_stored_runs: int = 0
+     num_new_runs: int = 0
+     num_updated_runs: int = 0
+
+     def __str__(self) -> str:
+         return "\n".join(
+             [
+                 "",
+                 ":: Downloader Stats ::",
+                 f" - # WandB runs: {self.num_wandb_runs:,}",
+                 f" - # Stored runs: {self.num_stored_runs:,}",
+                 f" - # New runs: {self.num_new_runs:,}",
+                 f" - # Updated runs: {self.num_updated_runs:,}",
+                 "",
+             ]
+         )
+
+
+ class Downloader:
+     def __init__(
+         self,
+         store: ProjectStore,
+         runs_per_page: int = 500,
+     ) -> None:
+         self.store = store
+         self._api: wandb.Api | None = None
+         self.runs_per_page = runs_per_page
+         self.progress_callback: ProgressCallback = default_progress_callback
+
+     @property
+     def api(self) -> wandb.Api:
+         if self._api is None:
+             try:
+                 self._api = wandb.Api()
+             except wandb.errors.UsageError as e:
+                 if "api_key not configured" in str(e):
+                     raise RuntimeError(
+                         "WandB API key not configured. "
+                         "Please run 'wandb login' or set WANDB_API_KEY env var"
+                     ) from e
+                 raise
+         return self._api
+
+     def set_progress_callback(self, progress_callback: ProgressCallback) -> None:
+         self.progress_callback = progress_callback
+
+     def get_all_runs(self, entity: str, project: str) -> list[wandb.apis.public.Run]:
+         return list(self.api.runs(f"{entity}/{project}", per_page=self.runs_per_page))
+
+     def download_runs(
+         self,
+         entity: str,
+         project: str,
+         force_refresh: bool = False,
+         with_history: bool = False,
+     ) -> DownloaderStats:
+         wandb_runs = self.get_all_runs(entity, project)
+         stored_states = self.store.get_existing_run_states(
+             {"entity": entity, "project": project}
+         )
+         runs_to_download = (
+             wandb_runs
+             if force_refresh
+             else select_updated_runs(wandb_runs, stored_states)
+         )
+         num_new_runs = len([r for r in runs_to_download if r.id not in stored_states])
+         stats = DownloaderStats(
+             num_wandb_runs=len(wandb_runs),
+             num_stored_runs=len(stored_states),
+             num_new_runs=num_new_runs,
+             num_updated_runs=len(runs_to_download) - num_new_runs,
+         )
+         if len(runs_to_download) == 0:
+             logging.info(">> No runs to download")
+             return stats
+
+         if not with_history:
+             logging.info(">> Runs only mode, bulk downloading runs")
+             self.store.store_runs(runs_to_download)
+             return stats
+
+         logging.info(">> Downloading runs and history data together")
+         for i, run in enumerate(runs_to_download):
+             self.store.store_run_and_history(run, list(run.scan_history()))
+             self.progress_callback(i + 1, len(runs_to_download), run.name)
+         return stats
+
+     def download_project(
+         self,
+         entity: str,
+         project: str,
+         runs_only: bool = False,
+         force_refresh: bool = False,
+     ) -> DownloaderStats:
+         stats = self.download_runs(
+             entity, project, force_refresh, with_history=not runs_only
+         )
+         logging.info(">> Download completed")
+         return stats
+
+     def write_downloaded_to_parquet(self) -> None:
+         logging.info(">> Beginning export to parquet")
+         self.store.export_to_parquet()
dr_wandb/fetch.py ADDED
@@ -0,0 +1,84 @@
+ """Lightweight WandB fetch utilities that avoid database storage."""
+
+ from __future__ import annotations
+
+ import logging
+ from collections.abc import Callable, Iterator
+ from typing import Any
+
+ import wandb
+
+ from dr_wandb.history_entry_record import HistoryEntryRecord
+ from dr_wandb.run_record import RunRecord
+ from dr_wandb.utils import default_progress_callback
+
+ ProgressFn = Callable[[int, int, str], None]
+
+
+ def _iterate_runs(
+     entity: str,
+     project: str,
+     *,
+     runs_per_page: int,
+ ) -> Iterator[wandb.apis.public.Run]:
+     api = wandb.Api()
+     yield from api.runs(f"{entity}/{project}", per_page=runs_per_page)
+
+
+ def serialize_run(run: wandb.apis.public.Run) -> dict[str, Any]:
+     """Convert a WandB run into a JSON-friendly dict."""
+     record = RunRecord.from_wandb_run(run)
+     return record.to_dict(include="all")
+
+
+ def serialize_history_entry(
+     run: wandb.apis.public.Run, history_entry: dict[str, Any]
+ ) -> dict[str, Any]:
+     """Convert a raw history payload into a structured dict."""
+     record = HistoryEntryRecord.from_wandb_history(history_entry, run.id)
+     return {
+         "run_id": record.run_id,
+         "step": record.step,
+         "timestamp": record.timestamp,
+         "runtime": record.runtime,
+         "wandb_metadata": record.wandb_metadata,
+         "metrics": record.metrics,
+     }
+
+
+ def fetch_project_runs(
+     entity: str,
+     project: str,
+     *,
+     runs_per_page: int = 500,
+     include_history: bool = True,
+     progress_callback: ProgressFn | None = None,
+ ) -> tuple[list[dict[str, Any]], list[list[dict[str, Any]]]]:
+     """Download runs (and optional history) without requiring Postgres."""
+     progress = progress_callback or default_progress_callback
+
+     runs: list[dict[str, Any]] = []
+     histories: list[list[dict[str, Any]]] = []
+
+     logging.info(">> Downloading runs, this will take a while (minutes)")
+     run_iter = list(_iterate_runs(entity, project, runs_per_page=runs_per_page))
+     total = len(run_iter)
+     logging.info(f" - total runs found: {total}")
+
+     logging.info(f">> Serializing runs and maybe getting histories: {include_history}")
+     for index, run in enumerate(run_iter, start=1):
+         runs.append(serialize_run(run))
+         if include_history:
+             history_payloads = [
+                 serialize_history_entry(run, entry) for entry in run.scan_history()
+             ]
+             histories.append(history_payloads)
+         progress(index, total, run.name)
+
+     return runs, histories
dr_wandb/history_entry_record.py ADDED
@@ -0,0 +1,62 @@
+ from __future__ import annotations
+
+ from datetime import datetime
+ from typing import Any
+
+ from sqlalchemy import Select, select
+ from sqlalchemy.dialects.postgresql import JSONB
+ from sqlalchemy.orm import Mapped, mapped_column
+
+ from dr_wandb.constants import Base, RunId
+ from dr_wandb.utils import extract_as_datetime
+
+ type HistoryEntry = dict[str, Any]
+
+
+ class HistoryEntryRecord(Base):
+     __tablename__ = "wandb_history"
+
+     id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
+     run_id: Mapped[str]
+     step: Mapped[int | None]
+     timestamp: Mapped[datetime | None]
+     runtime: Mapped[int | None]
+     wandb_metadata: Mapped[dict[str, Any]] = mapped_column(JSONB)
+     metrics: Mapped[dict[str, Any]] = mapped_column(JSONB)
+
+     @classmethod
+     def from_wandb_history(
+         cls, history_entry: HistoryEntry, run_id: str
+     ) -> HistoryEntryRecord:
+         return cls(
+             run_id=run_id,
+             step=history_entry.get("_step"),
+             timestamp=extract_as_datetime(history_entry, "_timestamp"),
+             runtime=history_entry.get("_runtime"),
+             wandb_metadata=history_entry.get("_wandb", {}),
+             metrics={k: v for k, v in history_entry.items() if not k.startswith("_")},
+         )
+
+     @classmethod
+     def standard_fields(cls) -> list[str]:
+         return [
+             col.name
+             for col in cls.__table__.columns
+             if col.name not in ["wandb_metadata", "metrics"]
+         ]
+
+     def to_dict(self, include_metadata: bool = False) -> dict[str, Any]:
+         return {
+             **{field: getattr(self, field) for field in self.standard_fields()},
+             **self.metrics,
+             **({"wandb_metadata": self.wandb_metadata} if include_metadata else {}),
+         }
+
+
+ def build_history_query(
+     run_ids: list[RunId] | None = None,
+ ) -> Select[HistoryEntryRecord]:
+     query = select(HistoryEntryRecord)
+     if run_ids is not None:
+         query = query.where(HistoryEntryRecord.run_id.in_(run_ids))
+     return query
dr_wandb/py.typed ADDED
File without changes
dr_wandb/run_record.py ADDED
@@ -0,0 +1,115 @@
+ from __future__ import annotations
+
+ from datetime import datetime
+ from typing import Any, Literal
+
+ import wandb
+ from sqlalchemy import Select, select
+ from sqlalchemy.dialects.postgresql import JSONB
+ from sqlalchemy.orm import Mapped, mapped_column
+
+ from dr_wandb.constants import (
+     SUPPORTED_FILTER_FIELDS,
+     Base,
+     FilterField,
+     RunId,
+     RunState,
+ )
+
+ RUN_DATA_COMPONENTS = [
+     "config",
+     "summary",
+     "wandb_metadata",
+     "system_metrics",
+     "system_attrs",
+     "sweep_info",
+ ]
+ type All = Literal["all"]
+ type RunDataComponent = Literal[
+     "config",
+     "summary",
+     "wandb_metadata",
+     "system_metrics",
+     "system_attrs",
+     "sweep_info",
+ ]
+
+
+ class RunRecord(Base):
+     __tablename__ = "wandb_runs"
+
+     run_id: Mapped[RunId] = mapped_column(primary_key=True)
+     run_name: Mapped[str]
+     state: Mapped[RunState]
+     project: Mapped[str]
+     entity: Mapped[str]
+     created_at: Mapped[datetime | None]
+
+     config: Mapped[dict[str, Any]] = mapped_column(JSONB)
+     summary: Mapped[dict[str, Any]] = mapped_column(JSONB)
+     wandb_metadata: Mapped[dict[str, Any]] = mapped_column(JSONB)
+     system_metrics: Mapped[dict[str, Any]] = mapped_column(JSONB)
+     system_attrs: Mapped[dict[str, Any]] = mapped_column(JSONB)
+     sweep_info: Mapped[dict[str, Any]] = mapped_column(JSONB)
+
+     @classmethod
+     def standard_fields(cls) -> list[str]:
+         return [
+             col.name
+             for col in cls.__table__.columns
+             if col.name not in RUN_DATA_COMPONENTS
+         ]
+
+     @classmethod
+     def from_wandb_run(cls, wandb_run: wandb.apis.public.Run) -> RunRecord:
+         return cls(
+             run_id=wandb_run.id,
+             run_name=wandb_run.name,
+             state=wandb_run.state,
+             project=wandb_run.project,
+             entity=wandb_run.entity,
+             created_at=wandb_run.created_at,
+             config=dict(wandb_run.config),
+             summary=dict(wandb_run.summary._json_dict) if wandb_run.summary else {},  # noqa: SLF001
+             wandb_metadata=wandb_run.metadata or {},
+             system_metrics=wandb_run.system_metrics or {},
+             system_attrs=dict(wandb_run._attrs),  # noqa: SLF001
+             sweep_info={
+                 "sweep_id": getattr(wandb_run, "sweep_id", None),
+                 "sweep_url": getattr(wandb_run, "sweep_url", None),
+             },
+         )
+
+     def update_from_wandb_run(self, wandb_run: wandb.apis.public.Run) -> None:
+         updated = self.__class__.from_wandb_run(wandb_run)
+         for col in self.__table__.columns:
+             if col.name != "run_id":
+                 setattr(self, col.name, getattr(updated, col.name))
+
+     def to_dict(
+         self, include: list[RunDataComponent] | All | None = None
+     ) -> dict[str, Any]:
+         include = include or []
+         if include == "all":
+             include = RUN_DATA_COMPONENTS
+         assert all(field in RUN_DATA_COMPONENTS for field in include)
+         data = {k: getattr(self, k) for k in self.standard_fields()}
+         for field in include:
+             data[field] = getattr(self, field)
+         return data
+
+
+ def build_run_query(kwargs: dict[FilterField, Any] | None = None) -> Select[RunRecord]:
+     query = select(RunRecord)
+     if kwargs is not None:
+         assert all(k in SUPPORTED_FILTER_FIELDS for k in kwargs)
+         assert all(v is not None for v in kwargs.values())
+         if "project" in kwargs:
+             query = query.where(RunRecord.project == kwargs["project"])
+         if "entity" in kwargs:
+             query = query.where(RunRecord.entity == kwargs["entity"])
+         if "state" in kwargs:
+             query = query.where(RunRecord.state == kwargs["state"])
+         if "run_ids" in kwargs:
+             query = query.where(RunRecord.run_id.in_(kwargs["run_ids"]))
+     return query
dr_wandb/store.py ADDED
@@ -0,0 +1,193 @@
+ from __future__ import annotations
+
+ import logging
+ from pathlib import Path
+ from typing import Any
+ from urllib.parse import urlparse
+
+ import pandas as pd
+ import wandb
+ from sqlalchemy import Engine, create_engine, text
+ from sqlalchemy.exc import OperationalError
+ from sqlalchemy.orm import Session
+
+ from dr_wandb.constants import (
+     Base,
+     FilterField,
+     RunId,
+     RunState,
+ )
+ from dr_wandb.history_entry_record import (
+     HistoryEntry,
+     HistoryEntryRecord,
+     build_history_query,
+ )
+ from dr_wandb.run_record import (
+     RUN_DATA_COMPONENTS,
+     All,
+     RunDataComponent,
+     RunRecord,
+     build_run_query,
+ )
+ from dr_wandb.utils import safe_convert_for_parquet
+
+ DEFAULT_OUTPUT_DIR = Path(__file__).parent.parent / "data"
+ DEFAULT_RUNS_FILENAME = "runs_metadata"
+ DEFAULT_HISTORY_FILENAME = "runs_history"
+ type History = list[HistoryEntry]
+
+
+ def delete_history_for_runs(session: Session, run_ids: list[RunId]) -> None:
+     if not run_ids:
+         return
+     session.execute(
+         text("DELETE FROM wandb_history WHERE run_id = ANY(:run_ids)"),
+         {"run_ids": run_ids},
+     )
+
+
+ def save_update_run(session: Session, run: wandb.apis.public.Run) -> None:
+     existing_run = session.get(RunRecord, run.id)
+     if existing_run:
+         existing_run.update_from_wandb_run(run)
+     else:
+         session.add(RunRecord.from_wandb_run(run))
+
+
+ def delete_add_history(session: Session, run_id: RunId, history: History) -> None:
+     delete_history_for_runs(session, [run_id])
+     for history_entry in history:
+         session.add(HistoryEntryRecord.from_wandb_history(history_entry, run_id))
+
+
+ def ensure_database_exists(database_url: str) -> str:
+     parsed = urlparse(database_url)
+     db_name = parsed.path.lstrip("/")
+     postgres_url = database_url.replace(f"/{db_name}", "/postgres")
+
+     try:
+         test_engine = create_engine(database_url)
+         with test_engine.connect():
+             pass
+         return database_url
+     except OperationalError as e:
+         if "does not exist" in str(e):
+             logging.info(f"Database '{db_name}' doesn't exist, creating it...")
+             postgres_engine = create_engine(postgres_url)
+             with postgres_engine.connect() as conn:
+                 # CREATE DATABASE cannot run inside a transaction block, so
+                 # close the implicit transaction first.
+                 conn.execute(text("COMMIT"))
+                 conn.execute(text(f'CREATE DATABASE "{db_name}"'))
+             logging.info(f"Created database '{db_name}'")
+             return database_url
+         else:
+             raise
+
+
+ class ProjectStore:
+     def __init__(self, connection_string: str, output_dir: Path | None = None) -> None:
+         connection_string = ensure_database_exists(connection_string)
+         self.engine: Engine = create_engine(connection_string)
+         self.create_tables()
+         self.output_dir = output_dir if output_dir is not None else DEFAULT_OUTPUT_DIR
+
+     def create_tables(self) -> None:
+         Base.metadata.create_all(self.engine)
+
+     def store_run(self, run: wandb.apis.public.Run) -> None:
+         with Session(self.engine) as session:
+             save_update_run(session, run)
+             session.commit()
+
+     def store_runs(self, runs: list[wandb.apis.public.Run]) -> None:
+         with Session(self.engine) as session:
+             for run in runs:
+                 save_update_run(session, run)
+             session.commit()
+
+     def store_history(self, run_id: RunId, history: History) -> None:
+         with Session(self.engine) as session:
+             delete_add_history(session, run_id, history)
+             session.commit()
+
+     def store_histories(
+         self,
+         runs: list[wandb.apis.public.Run],
+         histories: list[History],
+     ) -> None:
+         assert len(runs) == len(histories)
+         run_ids = [run.id for run in runs]
+         with Session(self.engine) as session:
+             delete_history_for_runs(session, run_ids)
+             for run_id, history in zip(run_ids, histories, strict=False):
+                 for history_entry in history:
+                     session.add(
+                         HistoryEntryRecord.from_wandb_history(history_entry, run_id)
+                     )
+             session.commit()
+
+     def store_run_and_history(
+         self, run: wandb.apis.public.Run, history: History
+     ) -> None:
+         with Session(self.engine) as session:
+             delete_add_history(session, run.id, history)
+             save_update_run(session, run)
+             session.commit()
+
+     def get_runs_df(
+         self,
+         include: list[RunDataComponent] | All | None = None,
+         kwargs: dict[FilterField, Any] | None = None,
+     ) -> pd.DataFrame:
+         with Session(self.engine) as session:
+             result = session.execute(build_run_query(kwargs=kwargs))
+             return pd.DataFrame(
+                 [run.to_dict(include=include) for run in result.scalars().all()]
+             )
+
+     def get_history_df(
+         self,
+         include_metadata: bool = False,
+         run_ids: list[RunId] | None = None,
+     ) -> pd.DataFrame:
+         with Session(self.engine) as session:
+             result = session.execute(build_history_query(run_ids=run_ids))
+             return pd.DataFrame(
+                 [
+                     history.to_dict(include_metadata=include_metadata)
+                     for history in result.scalars().all()
+                 ]
+             )
+
+     def get_existing_run_states(
+         self, kwargs: dict[FilterField, Any] | None = None
+     ) -> dict[RunId, RunState]:
+         with Session(self.engine) as session:
+             result = session.execute(build_run_query(kwargs=kwargs))
+             return {run.run_id: run.state for run in result.scalars().all()}
+
+     def export_to_parquet(
+         self,
+         runs_filename: str = DEFAULT_RUNS_FILENAME,
+         history_filename: str = DEFAULT_HISTORY_FILENAME,
+     ) -> None:
+         self.output_dir.mkdir(exist_ok=True)
+         logging.info(f">> Using data output directory: {self.output_dir}")
+         history_df = self.get_history_df()
+         if not history_df.empty:
+             history_path = self.output_dir / f"{history_filename}.parquet"
+             history_df = safe_convert_for_parquet(history_df)
+             history_df.to_parquet(history_path, engine="pyarrow", index=False)
+             logging.info(f">> Wrote history_df to {history_path}")
+         for include_type in RUN_DATA_COMPONENTS:
+             runs_df = self.get_runs_df(include=[include_type])
+             if not runs_df.empty:
+                 runs_path = self.output_dir / f"{runs_filename}_{include_type}.parquet"
+                 runs_df = safe_convert_for_parquet(runs_df)
+                 runs_df.to_parquet(runs_path, engine="pyarrow", index=False)
+                 logging.info(f">> Wrote runs_df with {include_type} to {runs_path}")
+         runs_df_full = self.get_runs_df(include="all")
+         if not runs_df_full.empty:
+             runs_path = self.output_dir / f"{runs_filename}.parquet"
+             runs_df_full = safe_convert_for_parquet(runs_df_full)
+             runs_df_full.to_parquet(runs_path, engine="pyarrow", index=False)
+             logging.info(f">> Wrote runs_df with all parts to {runs_path}")
dr_wandb/utils.py ADDED
@@ -0,0 +1,57 @@
+ import json
+ import logging
+ from datetime import datetime
+ from typing import Any
+
+ import pandas as pd
+ import wandb
+
+ from dr_wandb.constants import MAX_INT, RunId, RunState
+
+
+ def extract_as_datetime(data: dict[str, Any], key: str) -> datetime | None:
+     timestamp = data.get(key)
+     return datetime.fromtimestamp(timestamp) if timestamp is not None else None
+
+
+ def select_updated_runs(
+     all_runs: list[wandb.apis.public.Run],
+     existing_run_states: dict[RunId, RunState],
+ ) -> list[wandb.apis.public.Run]:
+     return [
+         run
+         for run in all_runs
+         if run.id not in existing_run_states
+         or existing_run_states[run.id] == "running"
+     ]
+
+
+ def default_progress_callback(run_index: int, total_runs: int, message: str) -> None:
+     logging.info(f">> {run_index}/{total_runs}: {message}")
+
+
+ def convert_large_ints_in_data(data: Any, max_int: int = MAX_INT) -> Any:
+     if isinstance(data, dict):
+         return {k: convert_large_ints_in_data(v, max_int) for k, v in data.items()}
+     elif isinstance(data, list):
+         return [convert_large_ints_in_data(item, max_int) for item in data]
+     elif isinstance(data, int) and abs(data) > max_int:
+         return float(data)
+     return data
+
+
+ def safe_convert_for_parquet(df: pd.DataFrame) -> pd.DataFrame:
+     df = df.copy()
+     for col in df.columns:
+         if df[col].dtype == "int64":
+             mask = df[col].abs() > MAX_INT
+             if mask.any():
+                 df[col] = df[col].astype("float64")
+         elif df[col].dtype == "object":
+             df[col] = df[col].apply(
+                 lambda x: json.dumps(convert_large_ints_in_data(x), default=str)
+                 if isinstance(x, dict | list)
+                 else str(x)
+                 if x is not None
+                 else None
+             )
+     return df
dr_wandb-0.1.2.dist-info/METADATA ADDED
@@ -0,0 +1,179 @@
+ Metadata-Version: 2.4
+ Name: dr-wandb
+ Version: 0.1.2
+ Summary: Interact with wandb from python
+ Author-email: Danielle Rothermel <danielle.rothermel@gmail.com>
+ License-File: LICENSE
+ Requires-Python: >=3.12
+ Requires-Dist: pandas>=2.3.2
+ Requires-Dist: pyarrow>=21.0.0
+ Requires-Dist: sqlalchemy>=2.0.43
+ Requires-Dist: typer>=0.20.0
+ Requires-Dist: wandb>=0.21.4
+ Description-Content-Type: text/markdown
+
+ # dr_wandb
+
+ A command-line utility for downloading and archiving Weights & Biases experiment data to local storage formats optimized for offline analysis.
+
+ ## Installation
+
+ CLI tool install (provides the `wandb-download` command):
+ ```bash
+ uv tool install dr_wandb
+ ```
+
+ Or, to use the library functions:
+ ```bash
+ uv add dr_wandb
+ # Optionally, with Postgres support
+ uv add "dr_wandb[postgres]"
+ uv sync
+ ```
+
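+ The fetch helpers are also importable from Python. A minimal sketch of calling `fetch_project_runs` (the entity and project names below are placeholders):
+
+ ```python
+ from dr_wandb import fetch_project_runs
+
+ # Metadata only; pass include_history=True to also pull per-step metrics.
+ runs, histories = fetch_project_runs(
+     "my-entity", "my-project", include_history=False
+ )
+ print(f"Fetched {len(runs)} runs")
+ ```
+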
+ ### Authentication
+
+ Configure Weights & Biases authentication using one of these methods:
+
+ ```bash
+ wandb login
+ ```
+
+ Or set the API key as an environment variable:
+
+ ```bash
+ export WANDB_API_KEY=your_api_key_here
+ ```
+
+ ## Quickstart
+
+ The default workflow doesn't involve Postgres: it fetches the runs, and optionally their histories, and dumps them to local `.pkl` files.
+
+ ```bash
+ » wandb-download --help
+
+  Usage: wandb-download [OPTIONS] ENTITY PROJECT OUTPUT_DIR
+
+ ╭─ Arguments ──────────────────────────────────────────────────────────────╮
+ │ *  entity      TEXT  [required]                                          │
+ │ *  project     TEXT  [required]                                          │
+ │ *  output_dir  TEXT  [required]                                          │
+ ╰──────────────────────────────────────────────────────────────────────────╯
+ ╭─ Options ────────────────────────────────────────────────────────────────╮
+ │ --runs-only            --no-runs-only      [default: no-runs-only]       │
+ │ --runs-per-page        INTEGER             [default: 500]                │
+ │ --log-every            INTEGER             [default: 20]                 │
+ │ --install-completion           Install completion for the current shell. │
+ │ --show-completion              Show completion for the current shell, to │
+ │                                copy it or customize the installation.    │
+ │ --help                         Show this message and exit.               │
+ ╰──────────────────────────────────────────────────────────────────────────╯
+ ```
+
+ An example:
+ ```bash
+ » wandb-download --runs-only "ml-moe" "ft-scaling" "./data"
+ 2025-11-10 21:47:54 - INFO -
+ :: Beginning Dr. Wandb Project Downloading Tool ::
+
+ 2025-11-10 21:47:54 - INFO - {
+     "entity": "ml-moe",
+     "project": "ft-scaling",
+     "output_dir": "data",
+     "runs_only": true,
+     "runs_per_page": 500,
+     "log_every": 20,
+     "runs_output_filename": "ml-moe_ft-scaling_runs.pkl",
+     "histories_output_filename": "ml-moe_ft-scaling_histories.pkl"
+ }
+ 2025-11-10 21:47:54 - INFO -
+ 2025-11-10 21:47:54 - INFO - >> Downloading runs, this will take a while (minutes)
+ wandb: Currently logged in as: danielle-rothermel (ml-moe) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
+ 2025-11-10 21:48:00 - INFO -  - total runs found: 517
+ 2025-11-10 21:48:00 - INFO - >> Serializing runs and maybe getting histories: False
+ 2025-11-10 21:48:07 - INFO - >> 20/517: 2025_08_21-08_24_43_test_finetune_DD-dolma1_7-10M_main_1Mtx1_--learning_rate=5e-05
+ 2025-11-10 21:48:12 - INFO - >> 40/517: 2025_08_21-08_24_43_test_finetune_DD-dolma1_7-150M_main_10Mtx1_--learning_rate=5e-06
+ ...
+ 2025-11-10 21:50:46 - INFO - >> Dumped runs data to: ./data/ml-moe_ft-scaling_runs.pkl
+ 2025-11-10 21:50:46 - INFO - >> Runs only, not dumping histories to: ./data/ml-moe_ft-scaling_histories.pkl
+ ```
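+
+ The dumped `.pkl` files are plain pickled lists of dicts, so they load back with the standard library. A minimal sketch (the path follows the `{entity}_{project}_runs.pkl` pattern from the log above):
+
+ ```python
+ import pickle
+
+ import pandas as pd
+
+ with open("./data/ml-moe_ft-scaling_runs.pkl", "rb") as f:
+     runs = pickle.load(f)  # list[dict], one dict per run
+
+ runs_df = pd.DataFrame(runs)
+ print(runs_df[["run_id", "run_name", "state"]].head())
+ ```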
+
+ ## Very Alpha: Postgres Version
+
+ **It's very likely this won't currently work.** Download all runs from a Weights & Biases project:
+
+ ```bash
+ uv run python src/dr_wandb/cli/postgres_download.py --entity your_entity --project your_project
+
+ Options:
+   --entity TEXT       WandB entity (username or team name)
+   --project TEXT      WandB project name
+   --runs-only         Download only run metadata, skip training history
+   --force-refresh     Download all data, ignoring existing records
+   --db-url TEXT       PostgreSQL connection string
+   --output-dir TEXT   Directory for exported Parquet files
+   --help              Show help message and exit
+ ```
+
+ The tool creates a PostgreSQL database, downloads experiment data, and exports Parquet files to the configured output directory. The tool tracks existing data and downloads only new or updated runs by default. A run is considered for update if (see the predicate sketch after this list):
+
+ - It does not exist in the local database
+ - Its state is "running" (indicating potential new data)
+
+ Use `--force-refresh` to download all runs regardless of existing data.
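+
+ In code, that rule reduces to a single predicate, mirroring `select_updated_runs` in `dr_wandb/utils.py`:
+
+ ```python
+ def needs_download(run_id: str, existing_run_states: dict[str, str]) -> bool:
+     # New run, or a run that was still "running" when last stored.
+     return run_id not in existing_run_states or existing_run_states[run_id] == "running"
+
+ assert needs_download("abc", {})                       # never stored
+ assert needs_download("abc", {"abc": "running"})       # may have new data
+ assert not needs_download("abc", {"abc": "finished"})  # already final
+ ```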
+
+ ### Environment Variables
+
+ The tool reads configuration from environment variables with the `DR_WANDB_` prefix and supports `.env` files:
+
+ | Variable | Description | Default |
+ |----------|-------------|---------|
+ | `DR_WANDB_ENTITY` | Weights & Biases entity name | None |
+ | `DR_WANDB_PROJECT` | Weights & Biases project name | None |
+ | `DR_WANDB_DATABASE_URL` | PostgreSQL connection string | `postgresql+psycopg2://localhost/wandb` |
+ | `DR_WANDB_OUTPUT_DIR` | Directory for exported files | `./data` |
+
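+ For example, a `.env` file in the working directory might look like this (values are placeholders):
+
+ ```bash
+ DR_WANDB_ENTITY=my-entity
+ DR_WANDB_PROJECT=my-project
+ DR_WANDB_DATABASE_URL=postgresql+psycopg2://localhost/wandb
+ DR_WANDB_OUTPUT_DIR=./data
+ ```
+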
+ ### Database Configuration
+
+ The PostgreSQL connection string follows the standard format:
+
+ ```
+ postgresql+psycopg2://username:password@host:port/database_name
+ ```
+
+ If the specified database does not exist, the tool will attempt to create it automatically.
+
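+ The same machinery is available from Python, subject to the same alpha caveat as the CLI. A minimal sketch using the store and downloader directly (connection string, entity, and project are placeholders):
+
+ ```python
+ from dr_wandb.downloader import Downloader
+ from dr_wandb.store import ProjectStore
+
+ # Creates the database (if missing) and the tables on construction.
+ store = ProjectStore("postgresql+psycopg2://localhost/wandb")
+ downloader = Downloader(store)
+ stats = downloader.download_project(entity="my-entity", project="my-project")
+ print(stats)
+ downloader.write_downloaded_to_parquet()
+ ```
+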
+ ### Data Schema
+
+ The tool generates the following files in the output directory (a loading sketch follows the list):
+
+ - `runs_metadata.parquet` - Complete run metadata including configurations, summaries, and system information
+ - `runs_history.parquet` - Training metrics and logged values over time
+ - `runs_metadata_{component}.parquet` - Component-specific files for config, summary, wandb_metadata, system_metrics, system_attrs, and sweep_info
+
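+ Once exported, the Parquet files load directly into pandas (pandas and pyarrow are already dependencies). Paths here assume the default `./data` output directory:
+
+ ```python
+ import pandas as pd
+
+ runs = pd.read_parquet("./data/runs_metadata.parquet")
+ history = pd.read_parquet("./data/runs_history.parquet")
+ print(runs[["run_id", "run_name", "state"]].head())
+ print(history.columns.tolist())
+ ```
+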
+ **Run Records**
+ - **run_id**: Unique identifier for the experiment run
+ - **run_name**: Human-readable name assigned to the run
+ - **state**: Current state (finished, running, crashed, failed, killed)
+ - **project**: Project name
+ - **entity**: Entity name
+ - **created_at**: Timestamp of run creation
+ - **config**: Experiment configuration parameters (JSONB)
+ - **summary**: Final metrics and outputs (JSONB)
+ - **wandb_metadata**: Platform-specific metadata (JSONB)
+ - **system_metrics**: Hardware and system information (JSONB)
+ - **system_attrs**: Additional system attributes (JSONB)
+ - **sweep_info**: Hyperparameter sweep information (JSONB)
+
+ **Training History Records**
+ - **run_id**: Reference to the parent run
+ - **step**: Training step number
+ - **timestamp**: Time of metric logging
+ - **runtime**: Elapsed time since run start
+ - **wandb_metadata**: Platform logging metadata (JSONB)
+ - **metrics**: All logged metrics and values (JSONB, flattened in the Parquet export; see the sketch below)
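+
+ The flattening comes from `HistoryEntryRecord.to_dict`, which merges the `metrics` dict into the top level, so the Parquet export gets one column per logged metric. A toy illustration (the metric names are hypothetical):
+
+ ```python
+ record = {"run_id": "abc", "step": 10, "metrics": {"loss": 0.42, "lr": 1e-4}}
+ flat = {k: v for k, v in record.items() if k != "metrics"} | record["metrics"]
+ assert flat == {"run_id": "abc", "step": 10, "loss": 0.42, "lr": 1e-4}
+ ```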
dr_wandb-0.1.2.dist-info/RECORD ADDED
@@ -0,0 +1,17 @@
+ dr_wandb/__init__.py,sha256=C1FWh869zNWF5XU4XKyGfBPJ2hVpW_UsawIIusfXuQQ,199
+ dr_wandb/constants.py,sha256=HuIDOe_MRp2BTTuD1uyVzPJPUm3DbDQDIDk7HNltspc,608
+ dr_wandb/downloader.py,sha256=X-NN1A1GilnUoxdEyHCsKJolqGBke_dIzS5wJWEAvvE,3888
+ dr_wandb/fetch.py,sha256=wtpY78-VeNjCjro4Ata0N6-uV6neqBq8ooLPxJgXE7k,2533
+ dr_wandb/history_entry_record.py,sha256=ni9rXhYWxOg2kdidQ4norYAK37tGWj8xaB8R_lU4tw0,2010
+ dr_wandb/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ dr_wandb/run_record.py,sha256=X3fmNOhsBfJsaCZEWhOtnulI16mXAYiH8K194CPdjfk,3794
+ dr_wandb/store.py,sha256=gWvlC0NIjcKeRP1rZooBz6dDq2nS2wIudsgahLso3VM,7063
+ dr_wandb/utils.py,sha256=zzpHVOVo0QD82ik9ksQCP_vN7Zw0ov9dPGFfNMFgfmg,1796
+ dr_wandb/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ dr_wandb/cli/download.py,sha256=V_V1q5HPXbnBorS0l1gMJlVRg4QlVl5LYL1K7-_j--s,2870
+ dr_wandb/cli/postgres_download.py,sha256=XvUY8Jl2u9BGo1l8QXn0foEFi5a3SfCARbvzZ-HxPoA,3710
+ dr_wandb-0.1.2.dist-info/METADATA,sha256=EI1oFoFETRG-3eYnzrdJBBUtuRulOEChqhEJE3HU4co,9250
+ dr_wandb-0.1.2.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ dr_wandb-0.1.2.dist-info/entry_points.txt,sha256=BATf5eJjnFMRULrNGiXfzL3ImYPdNK-MlatSzOFrtII,61
+ dr_wandb-0.1.2.dist-info/licenses/LICENSE,sha256=6tUm1Q55M1UBMbbawzFlF0-DgCazM1BELo_5-RXA1K4,1075
+ dr_wandb-0.1.2.dist-info/RECORD,,
dr_wandb-0.1.2.dist-info/WHEEL ADDED
@@ -0,0 +1,4 @@
+ Wheel-Version: 1.0
+ Generator: hatchling 1.27.0
+ Root-Is-Purelib: true
+ Tag: py3-none-any
dr_wandb-0.1.2.dist-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+ [console_scripts]
+ wandb-download = dr_wandb.cli.download:app
dr_wandb-0.1.2.dist-info/licenses/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2025 Danielle Rothermel
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.