mouse-experiment 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mouse/__init__.py +0 -0
- mouse/runner/__init__.py +0 -0
- mouse/runner/__main__.py +20 -0
- mouse/runner/auth.py +111 -0
- mouse/runner/config.py +1491 -0
- mouse/runner/eval.py +129 -0
- mouse/runner/experiment.py +955 -0
- mouse/runner/hub.py +203 -0
- mouse/runner/py.typed +0 -0
- mouse/runner/test.py +565 -0
- mouse/runner/train.py +216 -0
- mouse_experiment-0.1.0.dist-info/METADATA +169 -0
- mouse_experiment-0.1.0.dist-info/RECORD +17 -0
- mouse_experiment-0.1.0.dist-info/WHEEL +5 -0
- mouse_experiment-0.1.0.dist-info/entry_points.txt +2 -0
- mouse_experiment-0.1.0.dist-info/licenses/LICENSE +674 -0
- mouse_experiment-0.1.0.dist-info/top_level.txt +1 -0
mouse/__init__.py
ADDED
|
File without changes
|
mouse/runner/__init__.py
ADDED
|
File without changes
|
mouse/runner/__main__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import sys
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def main() -> None:
    """Command-line entry point (also run by ``python -m mouse.runner``).

    Reads an optional config argument from ``sys.argv[1]`` and forwards it to
    ``mouse.runner.experiment.main``:

    * a bare name (not absolute, containing neither ``/`` nor ``\\``) is
      resolved against ``<cwd>/configs``;
    * anything that already looks like a path is passed through unchanged;
    * with no argument at all, an empty string is forwarded.
    """
    # Deferred import: the experiment module is only loaded when main() runs.
    from mouse.runner.experiment import main as run_main

    def _is_pathlike(text: str) -> bool:
        # True for absolute paths or anything containing a directory separator.
        return os.path.isabs(text) or any(sep in text for sep in ("/", "\\"))

    extra_args = sys.argv[1:]
    config_path = ""
    if extra_args:
        requested = extra_args[0]
        config_path = requested if _is_pathlike(requested) else str(Path.cwd() / "configs" / requested)

    run_main(config_path)


if __name__ == "__main__":
    main()
|
mouse/runner/auth.py
ADDED
|
@@ -0,0 +1,111 @@
|
|
|
1
|
+
"""Authentication helpers for third-party services (Hugging Face, Weights & Biases, Trackio)."""
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import Any
|
|
5
|
+
from huggingface_hub import login as hf_login
|
|
6
|
+
|
|
7
|
+
# Process-wide "already done" flags. Each starts False and is flipped to True
# by the matching setup_* function once its login/init call has returned.
_hf_authenticated: bool = False  # guards setup_huggingface()
_wandb_authenticated: bool = False  # guards setup_wandb()
_trackio_initialized: bool = False  # set by setup_trackio()
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def mirror_hf_token_env() -> None:
    """Populate ``HF_TOKEN`` from an alternate env var when it is missing or empty.

    Hub clients such as ``transformers`` / ``huggingface_hub`` read ``HF_TOKEN``
    straight from the environment (e.g. inside ``AutoConfig.from_pretrained``).
    Such calls can happen during config loading, before ``setup_huggingface()``
    has run. So a RunPod-provided ``RUNPOD_HF_TOKEN`` has to be copied into
    ``HF_TOKEN`` up front.

    A non-empty ``HF_TOKEN`` is never overwritten. If no alternate variable
    holds a non-empty value, the environment is left untouched.
    """
    if not os.environ.get("HF_TOKEN"):
        aliases = ("RUNPOD_HF_TOKEN",)
        # filter(None, ...) skips unset/empty values; next() takes the first hit.
        fallback = next(filter(None, map(os.environ.get, aliases)), None)
        if fallback:
            os.environ["HF_TOKEN"] = fallback
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def setup_huggingface() -> None:
    """Log in to the Hugging Face Hub using a token from the environment.

    Lookup order is ``HF_TOKEN`` then ``RUNPOD_HF_TOKEN``. The first non-empty
    value is used and is also exported as ``HF_TOKEN`` for hub clients that
    read the environment directly.

    A module-level flag makes repeat calls free. The ``hf_login`` call (and the
    /whoami-v2 token verification behind it) runs at most once per process,
    which keeps us clear of the Hub's strict rate limit.

    Raises:
        ValueError: if neither variable contains a non-empty token.
    """
    global _hf_authenticated
    if not _hf_authenticated:
        candidate_vars = ["HF_TOKEN", "RUNPOD_HF_TOKEN"]
        # filter(None, ...) drops unset (None) and empty ("") values;
        # next() stops at the first remaining one.
        token = next(filter(None, map(os.getenv, candidate_vars)), None)
        if not token:
            raise ValueError(f"None of the following environment variables are set: {candidate_vars}")

        os.environ["HF_TOKEN"] = token
        hf_login(token=token)
        _hf_authenticated = True
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def setup_wandb() -> None:
    """Authenticate with Weights & Biases using a key taken from the environment.

    The key is looked up, in order, in ``WANDB_TOKEN``, ``RUNPOD_WANDB_TOKEN``
    and finally ``WANDB_API_KEY`` (the variable the ``wandb`` client itself
    documents). The first non-empty value is used.

    Idempotent: login is performed at most once per process; later calls
    return immediately.

    Raises:
        ValueError: if none of the environment variables holds a non-empty key.
    """
    global _wandb_authenticated
    if _wandb_authenticated:
        return

    # Imported lazily, presumably so wandb is only needed when this is called.
    import wandb  # type: ignore[import-untyped]

    # WANDB_API_KEY is appended last so the project-specific variables keep
    # their precedence when several are set; previously a user who only set
    # the standard W&B variable got a spurious "not set" error here.
    wandb_token_envs = ["WANDB_TOKEN", "RUNPOD_WANDB_TOKEN", "WANDB_API_KEY"]
    wandb_token = None
    for env_var in wandb_token_envs:
        wandb_token = os.getenv(env_var)
        if wandb_token:
            break
    if not wandb_token:
        raise ValueError(f"None of the following environment variables are set: {wandb_token_envs}")

    # verify=True asks wandb to validate the key before we mark ourselves done.
    wandb.login(key=wandb_token, verify=True)
    _wandb_authenticated = True
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def setup_trackio(project: str, run_name: str, config: dict, space_id: str | None = None) -> Any:
    """Start a Trackio run for experiment tracking.

    Trackio records metrics in a local SQLite store and can optionally mirror
    them to a Hugging Face Space dashboard. It can run alongside W&B, with both
    logging the same metrics in parallel.

    Args:
        project: Trackio project name used to group related runs.
        run_name: Name for this particular run.
        config: Hyperparameters (flat or nested dict) attached to the run.
        space_id: HF Space to sync the dashboard to, e.g. ``"user/my-space"``.
            When falsy, ``TRACKIO_SPACE_ID`` from the environment is used
            instead if it is set and non-empty; otherwise no Space is passed.

    Returns:
        The active run returned by ``trackio.init`` (exposes ``.log()`` and
        ``.finish()``).
    """
    global _trackio_initialized
    import trackio  # type: ignore[import-untyped]

    options: dict = {"project": project, "name": run_name, "config": config}

    # An explicit argument wins over the environment; empty strings count as unset.
    target_space = space_id or os.environ.get("TRACKIO_SPACE_ID")
    if target_space:
        options["space_id"] = target_space

    active_run = trackio.init(**options)
    _trackio_initialized = True
    return active_run
|