dr-wandb 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dr-wandb might be problematic. Click here for more details.

dr_wandb/__init__.py ADDED
@@ -0,0 +1,2 @@
1
def hello() -> str:
    """Return the package's fixed greeting string."""
    greeting = "Hello from dr-wandb!"
    return greeting
File without changes
@@ -0,0 +1,128 @@
1
+ import logging
2
+ from pathlib import Path
3
+
4
+ import click
5
+ from pydantic_settings import BaseSettings, SettingsConfigDict
6
+
7
+ from dr_wandb.downloader import Downloader
8
+ from dr_wandb.store import ProjectStore
9
+
10
+
11
class ProjDownloadSettings(BaseSettings):
    """Configuration for one project download.

    Fields are populated by pydantic-settings from ``DR_WANDB_``-prefixed
    environment variables and a ``.env`` file; values may also be passed
    explicitly as constructor keyword arguments.
    """

    model_config = SettingsConfigDict(env_prefix="DR_WANDB_", env_file=".env")

    # WandB entity and project; both optional here and validated by the CLI.
    entity: str | None = None
    project: str | None = None
    # SQLAlchemy connection string for the PostgreSQL database.
    database_url: str = "postgresql+psycopg2://localhost/wandb"
    # NOTE(review): this default is anchored two levels above this module, not
    # at the working directory; the packaged README lists `./data` as the
    # default. Confirm which one is intended.
    output_dir: Path = Path(__file__).parent.parent / "data"
    # Page size passed to the WandB API when listing runs.
    runs_per_page: int = 500
19
+
20
+
21
def setup_logging(level: str = "INFO") -> None:
    """Configure logging through ``logging.basicConfig``.

    Args:
        level: Case-insensitive name of a ``logging`` level constant
            (for example ``"info"`` or ``"DEBUG"``).

    Raises:
        AttributeError: If ``logging`` has no attribute named ``level.upper()``.
    """
    resolved_level = getattr(logging, level.upper())
    message_format = "%(asctime)s - %(levelname)s - %(message)s"
    date_format = "%Y-%m-%d %H:%M:%S"
    logging.basicConfig(
        level=resolved_level,
        format=message_format,
        datefmt=date_format,
    )
27
+
28
+
29
+ def validate_settings(entity: str | None, project: str | None) -> None:
30
+ if not entity:
31
+ raise click.ClickException(
32
+ "--entity is required, or set DR_WANDB_ENTITY in .env"
33
+ )
34
+ if not project:
35
+ raise click.ClickException(
36
+ "--project is required, or set DR_WANDB_PROJECT in .env"
37
+ )
38
+
39
+
40
def resolve_config(
    entity: str | None,
    project: str | None,
    db_url: str | None,
    output_dir: str | None,
) -> ProjDownloadSettings:
    """Merge command-line values with the environment-derived settings.

    Each truthy argument takes precedence over the corresponding value loaded
    by ``ProjDownloadSettings()``; ``runs_per_page`` always comes from the
    loaded settings.

    Raises:
        click.ClickException: Via ``validate_settings`` when no entity or
            project is available from either source.
    """
    defaults = ProjDownloadSettings()
    chosen_entity = entity or defaults.entity
    chosen_project = project or defaults.project
    chosen_db_url = db_url or defaults.database_url
    chosen_output_dir = output_dir or defaults.output_dir
    validate_settings(chosen_entity, chosen_project)
    merged = {
        "entity": chosen_entity,
        "project": chosen_project,
        "database_url": chosen_db_url,
        "output_dir": chosen_output_dir,
        "runs_per_page": defaults.runs_per_page,
    }
    return ProjDownloadSettings(**merged)
59
+
60
+
61
def execute_download(
    cfg: ProjDownloadSettings, runs_only: bool, force_refresh: bool
) -> Downloader:
    """Run the download described by ``cfg`` and echo the resulting stats.

    Args:
        cfg: Resolved settings (database URL, output directory, page size,
            entity and project).
        runs_only: When true, history is not downloaded.
        force_refresh: When true, every run is downloaded again.

    Returns:
        The ``Downloader`` that performed the download, so the caller can
        trigger the Parquet export on it. (The original annotation said
        ``None`` although the downloader was returned and used.)
    """
    store = ProjectStore(
        cfg.database_url,
        output_dir=cfg.output_dir,
    )
    downloader = Downloader(store, runs_per_page=cfg.runs_per_page)
    click.echo(">> Beginning download:")
    stats = downloader.download_project(
        entity=cfg.entity,
        project=cfg.project,
        runs_only=runs_only,
        force_refresh=force_refresh,
    )
    click.echo(str(stats))
    return downloader
78
+
79
+
80
def _redact_db_password(database_url: str) -> str:
    """Return ``database_url`` with any password replaced by ``***``."""
    from urllib.parse import urlsplit, urlunsplit

    try:
        parts = urlsplit(database_url)
        has_password = bool(parts.password)
    except ValueError:
        # Unparseable URL: never risk echoing embedded credentials.
        return "<unparseable database URL>"
    if not has_password:
        return database_url
    credentials, _, host = parts.netloc.rpartition("@")
    username = credentials.partition(":")[0]
    return urlunsplit(parts._replace(netloc=f"{username}:***@{host}"))


@click.command()
@click.option(
    "--entity",
    envvar="DR_WANDB_ENTITY",
    help="WandB entity (username or team name)",
)
@click.option("--project", envvar="DR_WANDB_PROJECT", help="WandB project name")
@click.option(
    "--runs-only",
    is_flag=True,
    help="Only download runs, don't download history",
)
@click.option(
    "--force-refresh",
    is_flag=True,
    help="Force refresh, download all data",
)
@click.option(
    "--db-url",
    envvar="DR_WANDB_DATABASE_URL",
    help="PostgreSQL connection string",
)
@click.option(
    "--output-dir",
    envvar="DR_WANDB_OUTPUT_DIR",
    help="Output directory",
)
def download_project(
    entity: str | None,
    project: str | None,
    runs_only: bool,
    force_refresh: bool,
    db_url: str | None,
    output_dir: str | None,
) -> None:
    """Download a WandB project into PostgreSQL and export it to Parquet."""
    setup_logging()
    click.echo("\n:: Beginning Dr. Wandb Project Downloading Tool ::\n")
    cfg = resolve_config(entity, project, db_url, output_dir)
    click.echo(f">> Downloading project {cfg.entity}/{cfg.project}")
    # The connection string may embed a password; never echo it verbatim.
    click.echo(f">> Database: {_redact_db_password(cfg.database_url)}")
    click.echo(f">> Output directory: {cfg.output_dir}")
    click.echo(f">> Force refresh: {force_refresh} Runs only: {runs_only}")
    click.echo()
    downloader = execute_download(cfg, runs_only, force_refresh)
    downloader.write_downloaded_to_parquet()
125
+
126
+
127
+ if __name__ == "__main__":
128
+ download_project()
dr_wandb/constants.py ADDED
@@ -0,0 +1,20 @@
1
+ from collections.abc import Callable
2
+ from typing import Literal
3
+
4
+ from sqlalchemy.orm import DeclarativeBase
5
+
6
+
7
class Base(DeclarativeBase):
    """Declarative base class that the package's ORM record classes extend."""
9
+
10
+
11
+ MAX_INT = 2**31 - 1
12
+
13
+ SUPPORTED_FILTER_FIELDS = ["project", "entity", "state", "run_ids"]
14
+ type FilterField = Literal["project", "entity", "state", "run_ids"]
15
+
16
+ WANDB_RUN_STATES = ["finished", "running", "crashed", "failed", "killed"]
17
+ type RunState = Literal["finished", "running", "crashed", "failed", "killed"]
18
+ type RunId = str
19
+
20
+ type ProgressCallback = Callable[[int, int, str], None]
dr_wandb/downloader.py ADDED
@@ -0,0 +1,118 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from dataclasses import dataclass
5
+
6
+ import wandb
7
+
8
+ from dr_wandb.constants import ProgressCallback
9
+ from dr_wandb.store import ProjectStore
10
+ from dr_wandb.utils import default_progress_callback, select_updated_runs
11
+
12
+
13
@dataclass
class DownloaderStats:
    """Run counts collected during one download pass."""

    num_wandb_runs: int = 0
    num_stored_runs: int = 0
    num_new_runs: int = 0
    num_updated_runs: int = 0

    def __str__(self) -> str:
        """Render the counts as a multi-line report with thousands separators."""
        rows = (
            ("# WandB runs", self.num_wandb_runs),
            ("# Stored runs", self.num_stored_runs),
            ("# New runs", self.num_new_runs),
            ("# Updated runs", self.num_updated_runs),
        )
        body = [f" - {label}: {count:,}" for label, count in rows]
        # The empty first and last items yield leading and trailing newlines.
        return "\n".join(["", ":: Downloader Stats ::", *body, ""])
32
+
33
+
34
class Downloader:
    """Downloads a WandB project's runs (and optionally history) into a store."""

    def __init__(
        self,
        store: ProjectStore,
        runs_per_page: int = 500,
    ) -> None:
        """Remember the target store and the page size used to list runs."""
        self.store = store
        self.runs_per_page = runs_per_page
        self.progress_callback: ProgressCallback = default_progress_callback
        # The API client is created on first access of ``api``.
        self._api: wandb.Api | None = None

    @property
    def api(self) -> wandb.Api:
        """Return the cached ``wandb.Api``, creating it on first use.

        Raises:
            RuntimeError: If WandB reports that no API key is configured.
            wandb.errors.UsageError: For any other usage error.
        """
        if self._api is not None:
            return self._api
        try:
            self._api = wandb.Api()
        except wandb.errors.UsageError as e:
            if "api_key not configured" not in str(e):
                raise
            raise RuntimeError(
                "WandB API key not configured. "
                "Please run 'wandb login' or set WANDB_API_KEY env var"
            ) from e
        return self._api

    def set_progress_callback(self, progress_callback: ProgressCallback) -> None:
        """Replace the callback invoked after each run's history is stored."""
        self.progress_callback = progress_callback

    def get_all_runs(self, entity: str, project: str) -> list[wandb.apis.public.Run]:
        """List every run of ``entity/project`` as a materialised list."""
        project_path = f"{entity}/{project}"
        listing = self.api.runs(project_path, per_page=self.runs_per_page)
        return list(listing)

    def download_runs(
        self,
        entity: str,
        project: str,
        force_refresh: bool = False,
        with_history: bool = False,
    ) -> DownloaderStats:
        """Store new or updated runs, with their history when requested.

        Args:
            entity: WandB entity that owns the project.
            project: WandB project name.
            force_refresh: Download every run instead of only those chosen by
                ``select_updated_runs``.
            with_history: Also scan and store each run's history.

        Returns:
            Counts of remote, stored, new and updated runs.
        """
        wandb_runs = self.get_all_runs(entity, project)
        stored_states = self.store.get_existing_run_states(
            {"entity": entity, "project": project}
        )
        if force_refresh:
            pending = wandb_runs
        else:
            pending = select_updated_runs(wandb_runs, stored_states)
        new_count = sum(1 for run in pending if run.id not in stored_states)
        stats = DownloaderStats(
            num_wandb_runs=len(wandb_runs),
            num_stored_runs=len(stored_states),
            num_new_runs=new_count,
            num_updated_runs=len(pending) - new_count,
        )
        if not pending:
            logging.info(">> No runs to download")
            return stats

        if not with_history:
            logging.info(">> Runs only mode, bulk downloading runs")
            self.store.store_runs(pending)
            return stats

        logging.info(">> Downloading runs and history data together")
        total = len(pending)
        for position, run in enumerate(pending, start=1):
            self.store.store_run_and_history(run, list(run.scan_history()))
            self.progress_callback(position, total, run.name)
        return stats

    def download_project(
        self,
        entity: str,
        project: str,
        runs_only: bool = False,
        force_refresh: bool = False,
    ) -> DownloaderStats:
        """Download the project; history is included unless ``runs_only``."""
        include_history = not runs_only
        stats = self.download_runs(
            entity, project, force_refresh, with_history=include_history
        )
        logging.info(">> Download completed")
        return stats

    def write_downloaded_to_parquet(self) -> None:
        """Ask the store to export its contents to Parquet files."""
        logging.info(">> Beginning export to parquet")
        self.store.export_to_parquet()
@@ -0,0 +1,62 @@
1
+ from __future__ import annotations
2
+
3
+ from datetime import datetime
4
+ from typing import Any
5
+
6
+ from sqlalchemy import Select, select
7
+ from sqlalchemy.dialects.postgresql import JSONB
8
+ from sqlalchemy.orm import Mapped, mapped_column
9
+
10
+ from dr_wandb.constants import Base, RunId
11
+ from dr_wandb.utils import extract_as_datetime
12
+
13
+ type HistoryEntry = dict[str, Any]
14
+
15
+
16
class HistoryEntryRecord(Base):
    """ORM row of ``wandb_history``: one logged history entry of a run."""

    __tablename__ = "wandb_history"

    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    run_id: Mapped[str]
    step: Mapped[int | None]
    timestamp: Mapped[datetime | None]
    runtime: Mapped[int | None]
    wandb_metadata: Mapped[dict[str, Any]] = mapped_column(JSONB)
    metrics: Mapped[dict[str, Any]] = mapped_column(JSONB)

    @classmethod
    def from_wandb_history(
        cls, history_entry: HistoryEntry, run_id: str
    ) -> HistoryEntryRecord:
        """Build a record from a raw WandB history entry.

        Keys starting with ``_`` are treated as bookkeeping (``_step``,
        ``_timestamp``, ``_runtime``, ``_wandb``); all other keys are stored
        together in ``metrics``.
        """
        fields: dict[str, Any] = {
            "run_id": run_id,
            "step": history_entry.get("_step"),
            "timestamp": extract_as_datetime(history_entry, "_timestamp"),
            "runtime": history_entry.get("_runtime"),
            "wandb_metadata": history_entry.get("_wandb", {}),
            "metrics": {
                name: value
                for name, value in history_entry.items()
                if not name.startswith("_")
            },
        }
        return cls(**fields)

    @classmethod
    def standard_fields(cls) -> list[str]:
        """Return the column names other than the two JSONB payload columns."""
        json_columns = ("wandb_metadata", "metrics")
        return [
            column.name
            for column in cls.__table__.columns
            if column.name not in json_columns
        ]

    def to_dict(self, include_metadata: bool = False) -> dict[str, Any]:
        """Flatten the record into one dict, merging ``metrics`` at top level.

        A metric whose name equals a standard field overrides that field's
        value, because the metrics are merged after the standard fields.
        """
        flat = {name: getattr(self, name) for name in self.standard_fields()}
        flat.update(self.metrics)
        if include_metadata:
            flat["wandb_metadata"] = self.wandb_metadata
        return flat
54
+
55
+
56
def build_history_query(
    run_ids: list[RunId] | None = None,
) -> Select[tuple[HistoryEntryRecord]]:
    """Build a ``SELECT`` over ``wandb_history``, optionally limited by run.

    Args:
        run_ids: When not ``None``, only entries whose ``run_id`` is in this
            list are selected. ``None`` selects every entry.

    Returns:
        The statement. ``Select`` is parameterised by the row tuple type,
        hence ``tuple[HistoryEntryRecord]`` in the annotation.
    """
    query = select(HistoryEntryRecord)
    if run_ids is not None:
        query = query.where(HistoryEntryRecord.run_id.in_(run_ids))
    return query
dr_wandb/py.typed ADDED
File without changes
dr_wandb/run_record.py ADDED
@@ -0,0 +1,115 @@
1
+ from __future__ import annotations
2
+
3
+ from datetime import datetime
4
+ from typing import Any, Literal
5
+
6
+ import wandb
7
+ from sqlalchemy import Select, select
8
+ from sqlalchemy.dialects.postgresql import JSONB
9
+ from sqlalchemy.orm import Mapped, mapped_column
10
+
11
+ from dr_wandb.constants import (
12
+ SUPPORTED_FILTER_FIELDS,
13
+ Base,
14
+ FilterField,
15
+ RunId,
16
+ RunState,
17
+ )
18
+
19
+ RUN_DATA_COMPONENTS = [
20
+ "config",
21
+ "summary",
22
+ "wandb_metadata",
23
+ "system_metrics",
24
+ "system_attrs",
25
+ "sweep_info",
26
+ ]
27
+ type All = Literal["all"]
28
+ type RunDataComponent = Literal[
29
+ "config",
30
+ "summary",
31
+ "wandb_metadata",
32
+ "system_metrics",
33
+ "system_attrs",
34
+ "sweep_info",
35
+ ]
36
+
37
+
38
class RunRecord(Base):
    """ORM row of ``wandb_runs``: one WandB run and its JSONB payloads."""

    __tablename__ = "wandb_runs"

    run_id: Mapped[RunId] = mapped_column(primary_key=True)
    run_name: Mapped[str]
    state: Mapped[RunState]
    project: Mapped[str]
    entity: Mapped[str]
    created_at: Mapped[datetime | None]

    config: Mapped[dict[str, Any]] = mapped_column(JSONB)
    summary: Mapped[dict[str, Any]] = mapped_column(JSONB)
    wandb_metadata: Mapped[dict[str, Any]] = mapped_column(JSONB)
    system_metrics: Mapped[dict[str, Any]] = mapped_column(JSONB)
    system_attrs: Mapped[dict[str, Any]] = mapped_column(JSONB)
    sweep_info: Mapped[dict[str, Any]] = mapped_column(JSONB)

    @classmethod
    def standard_fields(cls) -> list[str]:
        """Return the column names that are not JSONB data components."""
        return [
            column.name
            for column in cls.__table__.columns
            if column.name not in RUN_DATA_COMPONENTS
        ]

    @classmethod
    def from_wandb_run(cls, wandb_run: wandb.apis.public.Run) -> RunRecord:
        """Build a record from a WandB API run object.

        Falsy ``summary``, ``metadata`` and ``system_metrics`` are stored as
        empty dicts; ``sweep_id`` and ``sweep_url`` default to ``None`` when
        the run object lacks them.
        """
        fields: dict[str, Any] = {
            "run_id": wandb_run.id,
            "run_name": wandb_run.name,
            "state": wandb_run.state,
            "project": wandb_run.project,
            "entity": wandb_run.entity,
            "created_at": wandb_run.created_at,
            "config": dict(wandb_run.config),
            "summary": dict(wandb_run.summary._json_dict) if wandb_run.summary else {},  # noqa: SLF001
            "wandb_metadata": wandb_run.metadata or {},
            "system_metrics": wandb_run.system_metrics or {},
            "system_attrs": dict(wandb_run._attrs),  # noqa: SLF001
            "sweep_info": {
                "sweep_id": getattr(wandb_run, "sweep_id", None),
                "sweep_url": getattr(wandb_run, "sweep_url", None),
            },
        }
        return cls(**fields)

    def update_from_wandb_run(self, wandb_run: wandb.apis.public.Run) -> None:
        """Overwrite every column except ``run_id`` from ``wandb_run``."""
        fresh = self.__class__.from_wandb_run(wandb_run)
        column_names = [column.name for column in self.__table__.columns]
        for name in column_names:
            if name == "run_id":
                continue
            setattr(self, name, getattr(fresh, name))

    def to_dict(
        self, include: list[RunDataComponent] | All | None = None
    ) -> dict[str, Any]:
        """Return the standard fields plus the requested data components.

        Args:
            include: Component names to add, ``"all"`` for every component,
                or ``None``/empty for none.

        Raises:
            AssertionError: If ``include`` names an unknown component.
        """
        if include == "all":
            components = RUN_DATA_COMPONENTS
        else:
            components = include or []
        assert all(name in RUN_DATA_COMPONENTS for name in components)
        data = {name: getattr(self, name) for name in self.standard_fields()}
        data.update({name: getattr(self, name) for name in components})
        return data
100
+
101
+
102
def build_run_query(
    kwargs: dict[FilterField, Any] | None = None,
) -> Select[tuple[RunRecord]]:
    """Build a ``SELECT`` over ``wandb_runs`` with optional filters.

    Args:
        kwargs: Mapping of filter name to value. ``project``, ``entity`` and
            ``state`` are matched by equality; ``run_ids`` is matched with
            ``IN``. ``None`` applies no filter.

    Returns:
        The statement. ``Select`` is parameterised by the row tuple type,
        hence ``tuple[RunRecord]`` in the annotation.

    Raises:
        AssertionError: If a key is not in ``SUPPORTED_FILTER_FIELDS`` or a
            value is ``None``.
    """
    query = select(RunRecord)
    if kwargs is None:
        return query

    unsupported = [key for key in kwargs if key not in SUPPORTED_FILTER_FIELDS]
    assert not unsupported, f"Unsupported filter field(s): {unsupported}"
    missing = [key for key, value in kwargs.items() if value is None]
    assert not missing, f"Filter value(s) must not be None: {missing}"

    # Equality filters are applied in a fixed order, then the IN filter.
    equality_columns = {
        "project": RunRecord.project,
        "entity": RunRecord.entity,
        "state": RunRecord.state,
    }
    for field, column in equality_columns.items():
        if field in kwargs:
            query = query.where(column == kwargs[field])
    if "run_ids" in kwargs:
        query = query.where(RunRecord.run_id.in_(kwargs["run_ids"]))
    return query
dr_wandb/store.py ADDED
@@ -0,0 +1,193 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from pathlib import Path
5
+ from typing import Any
6
+ from urllib.parse import urlparse
7
+
8
+ import pandas as pd
9
+ import wandb
10
+ from sqlalchemy import Engine, create_engine, text
11
+ from sqlalchemy.exc import OperationalError
12
+ from sqlalchemy.orm import Session
13
+
14
+ from dr_wandb.constants import (
15
+ Base,
16
+ FilterField,
17
+ RunId,
18
+ RunState,
19
+ )
20
+ from dr_wandb.history_entry_record import (
21
+ HistoryEntry,
22
+ HistoryEntryRecord,
23
+ build_history_query,
24
+ )
25
+ from dr_wandb.run_record import (
26
+ RUN_DATA_COMPONENTS,
27
+ All,
28
+ RunDataComponent,
29
+ RunRecord,
30
+ build_run_query,
31
+ )
32
+ from dr_wandb.utils import safe_convert_for_parquet
33
+
34
+ DEFAULT_OUTPUT_DIR = Path(__file__).parent.parent / "data"
35
+ DEFAULT_RUNS_FILENAME = "runs_metadata"
36
+ DEFAULT_HISTORY_FILENAME = "runs_history"
37
+ type History = list[HistoryEntry]
38
+
39
+
40
def delete_history_for_runs(session: Session, run_ids: list[RunId]) -> None:
    """Delete every ``wandb_history`` row belonging to the given runs.

    Does nothing when ``run_ids`` is empty. The statement is executed on
    ``session`` but not committed here.
    """
    if run_ids:
        statement = text("DELETE FROM wandb_history WHERE run_id = ANY(:run_ids)")
        session.execute(statement, {"run_ids": run_ids})
47
+
48
+
49
def save_update_run(session: Session, run: wandb.apis.public.Run) -> None:
    """Insert ``run`` as a new record, or refresh the record stored for its id.

    The change is staged on ``session`` but not committed here.
    """
    record = session.get(RunRecord, run.id)
    if not record:
        session.add(RunRecord.from_wandb_run(run))
        return
    record.update_from_wandb_run(run)
55
+
56
+
57
def delete_add_history(session: Session, run_id: RunId, history: History) -> None:
    """Replace the stored history of ``run_id`` with ``history``.

    Existing rows for the run are deleted first, then one record per entry is
    added to ``session``; nothing is committed here.
    """
    delete_history_for_runs(session, [run_id])
    for entry in history:
        record = HistoryEntryRecord.from_wandb_history(entry, run_id)
        session.add(record)
61
+
62
+
63
def _maintenance_url(database_url: str, db_path: str) -> str:
    """Return ``database_url`` with its database path swapped for ``/postgres``.

    Only the trailing path component before any ``?query`` is replaced, so a
    database name that also appears in the host or credentials is untouched.
    """
    base, separator, query = database_url.partition("?")
    if db_path and base.endswith(db_path):
        base = base[: len(base) - len(db_path)]
    return f"{base}/postgres{separator}{query}"


def ensure_database_exists(database_url: str) -> str:
    """Make sure the database named in ``database_url`` exists.

    A connection is attempted first. If it fails with an error mentioning
    "does not exist", the database is created through the server's
    ``postgres`` database.

    Args:
        database_url: SQLAlchemy connection string for the target database.

    Returns:
        ``database_url`` unchanged.

    Raises:
        sqlalchemy.exc.OperationalError: If the connection fails for any
            other reason, or if creating the database fails.
    """
    parsed = urlparse(database_url)
    db_name = parsed.path.lstrip("/")

    probe_engine = create_engine(database_url)
    try:
        with probe_engine.connect():
            pass
        return database_url
    except OperationalError as e:
        if "does not exist" not in str(e):
            raise
    finally:
        # Release the throwaway engine's pool; the caller builds its own.
        probe_engine.dispose()

    logging.info(f"Database '{db_name}' doesn't exist, creating it...")
    # CREATE DATABASE cannot run inside a transaction block, so use autocommit.
    admin_engine = create_engine(
        _maintenance_url(database_url, parsed.path),
        isolation_level="AUTOCOMMIT",
    )
    try:
        with admin_engine.connect() as conn:
            # Identifiers cannot be bound parameters; escape embedded quotes.
            quoted_name = db_name.replace('"', '""')
            conn.execute(text(f'CREATE DATABASE "{quoted_name}"'))
    finally:
        admin_engine.dispose()
    logging.info(f"Created database '{db_name}'")
    return database_url
84
+
85
+
86
class ProjectStore:
    """PostgreSQL-backed store for WandB runs and history, with Parquet export."""

    def __init__(
        self, connection_string: str, output_dir: str | Path | None = None
    ) -> None:
        """Connect to the database (creating it if needed) and create tables.

        Args:
            connection_string: SQLAlchemy connection string.
            output_dir: Directory for Parquet exports; defaults to
                ``DEFAULT_OUTPUT_DIR``. Strings are converted to ``Path``
                because the export relies on ``Path`` operations.
        """
        connection_string = ensure_database_exists(connection_string)
        self.engine: Engine = create_engine(connection_string)
        self.create_tables()
        self.output_dir: Path = (
            Path(output_dir) if output_dir is not None else DEFAULT_OUTPUT_DIR
        )

    def create_tables(self) -> None:
        """Create all tables registered on ``Base`` that do not exist yet."""
        Base.metadata.create_all(self.engine)

    def store_run(self, run: wandb.apis.public.Run) -> None:
        """Insert or update a single run and commit."""
        with Session(self.engine) as session:
            save_update_run(session, run)
            session.commit()

    def store_runs(self, runs: list[wandb.apis.public.Run]) -> None:
        """Insert or update all ``runs`` in a single committed session."""
        with Session(self.engine) as session:
            for run in runs:
                save_update_run(session, run)
            session.commit()

    def store_history(self, run_id: RunId, history: History) -> None:
        """Replace the stored history of ``run_id`` and commit."""
        with Session(self.engine) as session:
            delete_add_history(session, run_id, history)
            session.commit()

    def store_histories(
        self,
        runs: list[wandb.apis.public.Run],
        histories: list[History],
    ) -> None:
        """Replace the histories of several runs in one committed session.

        Args:
            runs: Runs whose history is replaced.
            histories: One history per run, in the same order as ``runs``.

        Raises:
            AssertionError: If the two lists differ in length.
        """
        assert len(runs) == len(histories), "runs and histories differ in length"
        run_ids = [run.id for run in runs]
        with Session(self.engine) as session:
            delete_history_for_runs(session, run_ids)
            # strict=True keeps the pairing honest even if asserts are disabled.
            for run_id, history in zip(run_ids, histories, strict=True):
                for history_entry in history:
                    session.add(
                        HistoryEntryRecord.from_wandb_history(history_entry, run_id)
                    )
            session.commit()

    def store_run_and_history(
        self, run: wandb.apis.public.Run, history: History
    ) -> None:
        """Replace a run's history and upsert the run in one commit."""
        with Session(self.engine) as session:
            delete_add_history(session, run.id, history)
            save_update_run(session, run)
            session.commit()

    def get_runs_df(
        self,
        include: list[RunDataComponent] | All | None = None,
        kwargs: dict[FilterField, Any] | None = None,
    ) -> pd.DataFrame:
        """Return stored runs as a DataFrame.

        Args:
            include: Data components to add as columns (see
                ``RunRecord.to_dict``).
            kwargs: Filters forwarded to ``build_run_query``.
        """
        with Session(self.engine) as session:
            result = session.execute(build_run_query(kwargs=kwargs))
            return pd.DataFrame(
                [run.to_dict(include=include) for run in result.scalars().all()]
            )

    def get_history_df(
        self,
        include_metadata: bool = False,
        run_ids: list[RunId] | None = None,
    ) -> pd.DataFrame:
        """Return stored history entries as a DataFrame, metrics flattened.

        Args:
            include_metadata: Also include the ``wandb_metadata`` column.
            run_ids: Restrict the result to these runs; ``None`` means all.
        """
        with Session(self.engine) as session:
            result = session.execute(build_history_query(run_ids=run_ids))
            return pd.DataFrame(
                [
                    history.to_dict(include_metadata=include_metadata)
                    for history in result.scalars().all()
                ]
            )

    def get_existing_run_states(
        self, kwargs: dict[FilterField, Any] | None = None
    ) -> dict[RunId, RunState]:
        """Map each stored run id (after filtering) to its stored state."""
        with Session(self.engine) as session:
            result = session.execute(build_run_query(kwargs=kwargs))
            return {run.run_id: run.state for run in result.scalars().all()}

    def _write_parquet(self, df: pd.DataFrame, filename: str) -> Path | None:
        """Write ``df`` to ``<output_dir>/<filename>.parquet`` unless it is empty.

        Returns the written path, or ``None`` when nothing was written.
        """
        if df.empty:
            return None
        path = self.output_dir / f"{filename}.parquet"
        safe_convert_for_parquet(df).to_parquet(path, engine="pyarrow", index=False)
        return path

    def export_to_parquet(
        self,
        runs_filename: str = DEFAULT_RUNS_FILENAME,
        history_filename: str = DEFAULT_HISTORY_FILENAME,
    ) -> None:
        """Export history and runs to Parquet files in ``output_dir``.

        Writes the history file, one runs file per data component, and one
        runs file with every component. Empty DataFrames are skipped.
        """
        # parents=True so a nested, not-yet-existing output directory works.
        self.output_dir.mkdir(parents=True, exist_ok=True)
        logging.info(f">> Using data output directory: {self.output_dir}")

        history_path = self._write_parquet(self.get_history_df(), history_filename)
        if history_path is not None:
            logging.info(f">> Wrote history_df to {history_path}")

        for include_type in RUN_DATA_COMPONENTS:
            runs_path = self._write_parquet(
                self.get_runs_df(include=[include_type]),
                f"{runs_filename}_{include_type}",
            )
            if runs_path is not None:
                logging.info(f">> Wrote runs_df with {include_type} to {runs_path}")

        runs_path = self._write_parquet(self.get_runs_df(include="all"), runs_filename)
        if runs_path is not None:
            logging.info(f">> Wrote runs_df with all parts to {runs_path}")
dr_wandb/utils.py ADDED
@@ -0,0 +1,57 @@
1
+ import json
2
+ import logging
3
+ from datetime import datetime
4
+ from typing import Any
5
+
6
+ import pandas as pd
7
+ import wandb
8
+
9
+ from dr_wandb.constants import MAX_INT, RunId, RunState
10
+
11
+
12
def extract_as_datetime(data: dict[str, Any], key: str) -> datetime | None:
    """Convert the POSIX timestamp stored under ``key`` into a ``datetime``.

    Returns ``None`` when the key is absent or its value is ``None``. The
    conversion uses ``datetime.fromtimestamp`` without a timezone argument.
    """
    raw = data.get(key)
    if raw is None:
        return None
    return datetime.fromtimestamp(raw)
15
+
16
+
17
def select_updated_runs(
    all_runs: list[wandb.apis.public.Run],
    existing_run_states: dict[RunId, RunState],
) -> list[wandb.apis.public.Run]:
    """Pick the runs that still need downloading.

    A run is kept when its id has no stored state, or when its stored state
    is ``"running"``. Input order is preserved.
    """
    selected = []
    for run in all_runs:
        already_final = (
            run.id in existing_run_states
            and existing_run_states[run.id] != "running"
        )
        if already_final:
            continue
        selected.append(run)
    return selected
26
+
27
+
28
def default_progress_callback(run_index: int, total_runs: int, message: str) -> None:
    """Log ``">> <run_index>/<total_runs>: <message>"`` at INFO level."""
    progress_line = f">> {run_index}/{total_runs}: {message}"
    logging.info(progress_line)
30
+
31
+
32
def convert_large_ints_in_data(data: Any, max_int: int = MAX_INT) -> Any:
    """Recursively turn integers whose magnitude exceeds ``max_int`` into floats.

    Dicts and lists are rebuilt with their values converted; any other value
    is returned unchanged.
    """
    if isinstance(data, int):
        return float(data) if abs(data) > max_int else data
    if isinstance(data, list):
        return [convert_large_ints_in_data(element, max_int) for element in data]
    if isinstance(data, dict):
        return {
            key: convert_large_ints_in_data(value, max_int)
            for key, value in data.items()
        }
    return data
40
+
41
+
42
+ def safe_convert_for_parquet(df: pd.DataFrame) -> pd.DataFrame:
43
+ df = df.copy()
44
+ for col in df.columns:
45
+ if df[col].dtype == "int64":
46
+ mask = df[col].abs() > MAX_INT
47
+ if mask.any():
48
+ df[col] = df[col].astype("float64")
49
+ elif df[col].dtype == "object":
50
+ df[col] = df[col].apply(
51
+ lambda x: json.dumps(convert_large_ints_in_data(x), default=str)
52
+ if isinstance(x, dict | list)
53
+ else str(x)
54
+ if x is not None
55
+ else None
56
+ )
57
+ return df
@@ -0,0 +1,123 @@
1
+ Metadata-Version: 2.4
2
+ Name: dr-wandb
3
+ Version: 0.1.0
4
+ Summary: Interact with wandb from python
5
+ Author-email: Danielle Rothermel <danielle.rothermel@gmail.com>
6
+ License-File: LICENSE
7
+ Requires-Python: >=3.12
8
+ Requires-Dist: pandas>=2.3.2
9
+ Requires-Dist: psycopg2>=2.9.10
10
+ Requires-Dist: pyarrow>=21.0.0
11
+ Requires-Dist: pydantic-settings>=2.10.1
12
+ Requires-Dist: sqlalchemy>=2.0.43
13
+ Requires-Dist: wandb>=0.21.4
14
+ Description-Content-Type: text/markdown
15
+
16
+ # dr_wandb
17
+
18
+ A command-line utility for downloading and archiving Weights & Biases experiment data to local storage formats optimized for offline analysis. Stores to PostgreSQL db + Parquet files, supports incremental updates and selective data retrieval.
19
+
20
+ ## Installation
21
+
22
+ ```bash
23
+ uv add dr_wandb
24
+ ```
25
+
26
+ ### Prerequisites
27
+
28
+ - Python 3.12 or higher
29
+ - PostgreSQL database server
30
+ - Weights & Biases account with API access
31
+ - PyArrow for Parquet file operations
32
+
33
+ ### Authentication
34
+
35
+ Configure Weights & Biases authentication using one of these methods:
36
+
37
+ ```bash
38
+ wandb login
39
+ ```
40
+
41
+ Or set the API key as an environment variable:
42
+
43
+ ```bash
44
+ export WANDB_API_KEY=your_api_key_here
45
+ ```
46
+
47
+ ## Basic Usage
48
+
49
+ Download all runs from a Weights & Biases project:
50
+
51
+ ```bash
52
+ wandb-download --entity your_entity --project your_project
53
+
54
+ Options:
55
+ --entity TEXT WandB entity (username or team name)
56
+ --project TEXT WandB project name
57
+ --runs-only Download only run metadata, skip training history
58
+ --force-refresh Download all data, ignoring existing records
59
+ --db-url TEXT PostgreSQL connection string
60
+ --output-dir TEXT Directory for exported Parquet files
61
+ --help Show help message and exit
62
+ ```
63
+
64
+ The tool creates a PostgreSQL database, downloads experiment data, and exports Parquet files to the configured output directory. It tracks existing data and downloads only new or updated runs by default. A run is considered for update if:
65
+
66
+ - It does not exist in the local database
67
+ - Its state is "running" (indicating potential new data)
68
+
69
+ Use `--force-refresh` to download all runs regardless of existing data.
70
+
71
+ ### Environment Variables
72
+
73
+ The tool reads configuration from environment variables with the `DR_WANDB_` prefix and supports `.env` files:
74
+
75
+ | Variable | Description | Default |
76
+ |----------|-------------|---------|
77
+ | `DR_WANDB_ENTITY` | Weights & Biases entity name | None |
78
+ | `DR_WANDB_PROJECT` | Weights & Biases project name | None |
79
+ | `DR_WANDB_DATABASE_URL` | PostgreSQL connection string | `postgresql+psycopg2://localhost/wandb` |
80
+ | `DR_WANDB_OUTPUT_DIR` | Directory for exported files | `./data` |
81
+
82
+ ### Database Configuration
83
+
84
+ The PostgreSQL connection string follows the standard format:
85
+
86
+ ```
87
+ postgresql+psycopg2://username:password@host:port/database_name
88
+ ```
89
+
90
+ If the specified database does not exist, the tool will attempt to create it automatically.
91
+
92
+ ## Data Schema
93
+
94
+
95
+ The tool generates the following files in the output directory:
96
+
97
+ - `runs_metadata.parquet` - Complete run metadata including configurations, summaries, and system information
98
+ - `runs_history.parquet` - Training metrics and logged values over time
99
+ - `runs_metadata_{component}.parquet` - Component-specific files for config, summary, wandb_metadata, system_metrics, system_attrs, and sweep_info
100
+
101
+
102
+ **Run Records**
103
+ - **run_id**: Unique identifier for the experiment run
104
+ - **run_name**: Human-readable name assigned to the run
105
+ - **state**: Current state (finished, running, crashed, failed, killed)
106
+ - **project**: Project name
107
+ - **entity**: Entity name
108
+ - **created_at**: Timestamp of run creation
109
+ - **config**: Experiment configuration parameters (JSONB)
110
+ - **summary**: Final metrics and outputs (JSONB)
111
+ - **wandb_metadata**: Platform-specific metadata (JSONB)
112
+ - **system_metrics**: Hardware and system information (JSONB)
113
+ - **system_attrs**: Additional system attributes (JSONB)
114
+ - **sweep_info**: Hyperparameter sweep information (JSONB)
115
+
116
+ **Training History Records**
117
+ - **run_id**: Reference to the parent run
118
+ - **step**: Training step number
119
+ - **timestamp**: Time of metric logging
120
+ - **runtime**: Elapsed time since run start
121
+ - **wandb_metadata**: Platform logging metadata (JSONB)
122
+ - **metrics**: All logged metrics and values (JSONB, flattened in Parquet export)
123
+
@@ -0,0 +1,15 @@
1
+ dr_wandb/__init__.py,sha256=aAqpPH5MBqIT_dPF5aEdqnjglgsas1MTaYqEJxyHc6s,54
2
+ dr_wandb/constants.py,sha256=aKbkVU08aRdfrcSYu_UYXPI48ZpSTZUSpVZeXK2L4L8,534
3
+ dr_wandb/downloader.py,sha256=X-NN1A1GilnUoxdEyHCsKJolqGBke_dIzS5wJWEAvvE,3888
4
+ dr_wandb/history_entry_record.py,sha256=ni9rXhYWxOg2kdidQ4norYAK37tGWj8xaB8R_lU4tw0,2010
5
+ dr_wandb/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
6
+ dr_wandb/run_record.py,sha256=X3fmNOhsBfJsaCZEWhOtnulI16mXAYiH8K194CPdjfk,3794
7
+ dr_wandb/store.py,sha256=gWvlC0NIjcKeRP1rZooBz6dDq2nS2wIudsgahLso3VM,7063
8
+ dr_wandb/utils.py,sha256=zzpHVOVo0QD82ik9ksQCP_vN7Zw0ov9dPGFfNMFgfmg,1796
9
+ dr_wandb/cli/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
10
+ dr_wandb/cli/download.py,sha256=XvUY8Jl2u9BGo1l8QXn0foEFi5a3SfCARbvzZ-HxPoA,3710
11
+ dr_wandb-0.1.0.dist-info/METADATA,sha256=2UBB8JfOPTCMWJbjXG8jOa0KC91GFwKlZguSzBdRwc8,4226
12
+ dr_wandb-0.1.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
13
+ dr_wandb-0.1.0.dist-info/entry_points.txt,sha256=l4X0h3JbfOr_-3pgqiq3iy4MqUTSiaFUMeVf0DTck88,74
14
+ dr_wandb-0.1.0.dist-info/licenses/LICENSE,sha256=6tUm1Q55M1UBMbbawzFlF0-DgCazM1BELo_5-RXA1K4,1075
15
+ dr_wandb-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,4 @@
1
+ Wheel-Version: 1.0
2
+ Generator: hatchling 1.27.0
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ wandb-download = dr_wandb.cli.download:download_project
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Danielle Rothermel
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.