xax 0.0.1__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +256 -1
- xax/core/conf.py +193 -0
- xax/core/state.py +81 -0
- xax/nn/__init__.py +0 -0
- xax/nn/embeddings.py +355 -0
- xax/nn/functions.py +77 -0
- xax/nn/parallel.py +211 -0
- xax/requirements-dev.txt +15 -0
- xax/requirements.txt +23 -0
- xax/task/__init__.py +0 -0
- xax/task/base.py +207 -0
- xax/task/launchers/__init__.py +0 -0
- xax/task/launchers/base.py +28 -0
- xax/task/launchers/cli.py +42 -0
- xax/task/launchers/single_process.py +30 -0
- xax/task/launchers/staged.py +29 -0
- xax/task/logger.py +783 -0
- xax/task/loggers/__init__.py +0 -0
- xax/task/loggers/callback.py +56 -0
- xax/task/loggers/json.py +121 -0
- xax/task/loggers/state.py +45 -0
- xax/task/loggers/stdout.py +170 -0
- xax/task/loggers/tensorboard.py +223 -0
- xax/task/mixins/__init__.py +12 -0
- xax/task/mixins/artifacts.py +114 -0
- xax/task/mixins/checkpointing.py +209 -0
- xax/task/mixins/cpu_stats.py +251 -0
- xax/task/mixins/data_loader.py +149 -0
- xax/task/mixins/gpu_stats.py +257 -0
- xax/task/mixins/logger.py +66 -0
- xax/task/mixins/process.py +51 -0
- xax/task/mixins/runnable.py +63 -0
- xax/task/mixins/step_wrapper.py +63 -0
- xax/task/mixins/train.py +541 -0
- xax/task/script.py +53 -0
- xax/task/task.py +65 -0
- xax/utils/__init__.py +0 -0
- xax/utils/data/__init__.py +0 -0
- xax/utils/data/collate.py +206 -0
- xax/utils/experiments.py +802 -0
- xax/utils/jax.py +14 -0
- xax/utils/logging.py +223 -0
- xax/utils/numpy.py +47 -0
- xax/utils/tensorboard.py +258 -0
- xax/utils/text.py +350 -0
- xax-0.0.5.dist-info/METADATA +40 -0
- xax-0.0.5.dist-info/RECORD +52 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
- xax-0.0.5.dist-info/top_level.txt +1 -0
- examples/mnist.py +0 -148
- xax-0.0.1.dist-info/METADATA +0 -21
- xax-0.0.1.dist-info/RECORD +0 -9
- xax-0.0.1.dist-info/top_level.txt +0 -2
- {examples → xax/core}/__init__.py +0 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
xax/task/base.py
ADDED
@@ -0,0 +1,207 @@
|
|
1
|
+
"""Defines the base task interface.
|
2
|
+
|
3
|
+
This interface is built upon by a large number of other interfaces which
|
4
|
+
compose various functionality into a single cohesive unit. The base task
|
5
|
+
just stores the configuration and provides hooks which are overridden by
|
6
|
+
subclasses.
|
7
|
+
"""
|
8
|
+
|
9
|
+
import functools
|
10
|
+
import inspect
|
11
|
+
import logging
|
12
|
+
import sys
|
13
|
+
from dataclasses import dataclass, is_dataclass
|
14
|
+
from pathlib import Path
|
15
|
+
from types import TracebackType
|
16
|
+
from typing import Generic, Self, TypeVar, cast
|
17
|
+
|
18
|
+
from omegaconf import Container, DictConfig, OmegaConf
|
19
|
+
|
20
|
+
from xax.core.state import State
|
21
|
+
from xax.utils.text import camelcase_to_snakecase
|
22
|
+
|
23
|
+
logger = logging.getLogger(__name__)
|
24
|
+
|
25
|
+
|
26
|
+
@dataclass
class BaseConfig:
    """Root dataclass that every task config derives from; it declares no fields itself."""
|
29
|
+
|
30
|
+
|
31
|
+
# Type variable for a task's concrete config dataclass; bound to ``BaseConfig``.
Config = TypeVar("Config", bound=BaseConfig)
|
32
|
+
|
33
|
+
# Any value accepted as a config source by ``get_config``: a config dataclass,
# a plain dict, an existing ``DictConfig``, or a path (str or Path) to a config file.
RawConfigType = BaseConfig | dict | DictConfig | str | Path
|
34
|
+
|
35
|
+
|
36
|
+
def _load_as_dict(path: str | Path) -> DictConfig:
    """Loads a config file with OmegaConf and requires its top level to be a mapping.

    Args:
        path: Location of the config file.

    Returns:
        The loaded mapping config.

    Raises:
        TypeError: If the file's top-level value is not a dictionary.
    """
    loaded = OmegaConf.load(path)
    if isinstance(loaded, DictConfig):
        return loaded
    raise TypeError(f"Config file at {path} must be a dictionary, not {type(loaded)}!")
|
41
|
+
|
42
|
+
|
43
|
+
def get_config(cfg: RawConfigType, task_path: Path | None) -> DictConfig:
    """Normalizes a raw config value into an OmegaConf ``DictConfig``.

    Args:
        cfg: The raw config. A ``str``/``Path`` is treated as a config file path;
            if it does not exist and is a bare file name, it is also looked up
            next to ``task_path``. A ``dict`` is wrapped with ``OmegaConf.create``,
            a dataclass with ``OmegaConf.structured``; anything else (e.g. an
            existing ``DictConfig``) is returned as-is.
        task_path: Path to the task's source file, used as the fallback search
            location for bare config file names. May be ``None`` to disable the
            fallback.

    Returns:
        The config as a ``DictConfig``.

    Raises:
        FileNotFoundError: If ``cfg`` is a path that cannot be found.
        TypeError: If a loaded config file is not a mapping.
    """
    if not isinstance(cfg, (str, Path)):
        if isinstance(cfg, dict):
            return cast(DictConfig, OmegaConf.create(cfg))
        if is_dataclass(cfg):
            return cast(DictConfig, OmegaConf.structured(cfg))
        return cast(DictConfig, cfg)

    cfg_path = Path(cfg)
    if cfg_path.exists():
        return _load_as_dict(cfg_path)

    # Bare file names (no directory component) may live next to the task's source file.
    if task_path is not None and len(cfg_path.parts) == 1:
        sibling_path = task_path.parent / cfg_path
        if sibling_path.exists():
            return _load_as_dict(sibling_path)

    raise FileNotFoundError(f"Could not find config file at {cfg_path}!")
|
57
|
+
|
58
|
+
|
59
|
+
class BaseTask(Generic[Config]):
    """Base class that stores a task's config and defines its lifecycle hooks.

    The hooks (``on_step_start``, ``on_step_end``, ``on_training_start`` and
    ``on_training_end``) return the state unchanged here and are meant to be
    overridden by mixins and concrete tasks. The class also provides helpers for
    building the config (``get_config``), instantiating the task (``get_task``)
    and locating a task class from a string key (``task_key``/``from_task_key``).
    """

    config: Config

    def __init__(self, config: Config) -> None:
        super().__init__()

        self.config = config

        # OmegaConf containers may hold ``${...}`` interpolations; resolve them in
        # place so the rest of the task sees concrete values.
        if isinstance(self.config, Container):
            OmegaConf.resolve(self.config)

    def on_step_start(self, state: State) -> State:
        """Step-start hook; the base implementation returns ``state`` unchanged."""
        return state

    def on_step_end(self, state: State) -> State:
        """Step-end hook; the base implementation returns ``state`` unchanged."""
        return state

    def on_training_start(self, state: State) -> State:
        """Training-start hook; the base implementation returns ``state`` unchanged."""
        return state

    def on_training_end(self, state: State) -> State:
        """Training-end hook; the base implementation returns ``state`` unchanged."""
        return state

    @functools.cached_property
    def task_class_name(self) -> str:
        """The name of the concrete task class."""
        return self.__class__.__name__

    @functools.cached_property
    def task_name(self) -> str:
        """The task class name converted from CamelCase to snake_case."""
        return camelcase_to_snakecase(self.task_class_name)

    @functools.cached_property
    def task_path(self) -> Path:
        """Path to the source file that defines the task class."""
        return Path(inspect.getfile(self.__class__))

    @functools.cached_property
    def task_module(self) -> str:
        """Fully qualified name of the module defining the task class.

        Raises:
            RuntimeError: If the module or its import spec cannot be found.
        """
        if (mod := inspect.getmodule(self.__class__)) is None:
            raise RuntimeError(f"Could not find module for task {self.__class__}!")
        if (spec := mod.__spec__) is None:
            raise RuntimeError(f"Could not find spec for module {mod}!")
        return spec.name

    @property
    def task_key(self) -> str:
        """A ``<module>.<ClassName>`` key which ``from_task_key`` can resolve back to the class."""
        return f"{self.task_module}.{self.task_class_name}"

    @classmethod
    def from_task_key(cls, task_key: str) -> type[Self]:
        """Imports and returns the task class referenced by ``task_key``.

        Args:
            task_key: A ``<module>.<ClassName>`` string, as produced by ``task_key``.

        Returns:
            The referenced class, which is checked to be a subclass of ``cls``.

        Raises:
            ValueError: If ``task_key`` has no module component.
            ImportError: If the module cannot be imported.
            RuntimeError: If the class is missing from the module, or is not a
                class deriving from ``cls``.
        """
        if "." not in task_key:
            raise ValueError(f"Invalid task key {task_key!r}; expected '<module>.<ClassName>'")
        task_module, task_class_name = task_key.rsplit(".", 1)
        try:
            mod = __import__(task_module, fromlist=[task_class_name])
        except ImportError as e:
            raise ImportError(f"Could not import module {task_module} for task {task_key}") from e
        if not hasattr(mod, task_class_name):
            raise RuntimeError(f"Could not find class {task_class_name} in module {task_module}")
        task_class = getattr(mod, task_class_name)
        # ``issubclass`` raises TypeError for non-class attributes; report those the same way.
        if not isinstance(task_class, type) or not issubclass(task_class, cls):
            raise RuntimeError(f"Class {task_class_name} in module {task_module} is not a subclass of {cls}")
        return task_class

    def debug(self) -> bool:
        """Whether the task runs in debug mode; subclasses may override. Defaults to False."""
        return False

    @property
    def debugging(self) -> bool:
        """Property alias for ``debug()``."""
        return self.debug()

    def __enter__(self) -> Self:
        """No-op context manager entry; returns the task itself."""
        return self

    def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None:
        """No-op context manager exit; exceptions are not suppressed."""
        pass

    @classmethod
    def get_config_class(cls) -> type[Config]:
        """Recursively retrieves the config class from the generic type.

        Walks the class MRO and inspects the parameterized generic bases that
        each class declares (``__orig_bases__``), returning the first type
        argument that is a ``BaseConfig`` subclass. Type variables and other
        non-class arguments are skipped.

        Returns:
            The parsed config class.

        Raises:
            ValueError: If the config class cannot be found, usually meaning
                that the generic class has not been used correctly.
        """
        for klass in cls.__mro__:
            # Read from ``__dict__`` so each class's own generic bases are inspected,
            # not an attribute inherited from a parent.
            for base in klass.__dict__.get("__orig_bases__", ()):
                for arg in getattr(base, "__args__", ()):
                    # TypeVars and parameterized aliases are not classes, and
                    # ``issubclass`` would raise TypeError on them.
                    if isinstance(arg, type) and issubclass(arg, BaseConfig):
                        return arg

        raise ValueError(
            "The config class could not be parsed from the generic type, which usually means that the task is not "
            "being instantiated correctly. Your class should be defined as follows:\n\n"
            "    class ExampleTask(xax.Task[Config]):\n        ...\n\nThis lets both the task and the type "
            "checker know what config the task is using."
        )

    @classmethod
    def get_config(cls, *cfgs: RawConfigType, use_cli: bool | list[str] = True) -> Config:
        """Builds the structured config from the provided config classes.

        Args:
            cfgs: The config classes to merge. If a string or Path is provided,
                it will be loaded as a YAML file.
            use_cli: Whether to allow additional overrides from the CLI. If a
                list is given, it is used instead of ``sys.argv[1:]``. Arguments
                naming existing files are merged as config files; the rest are
                parsed as OmegaConf dotlist overrides. ``-h``/``--help`` prints
                the current config to stderr and exits.

        Returns:
            The merged configs.
        """
        task_path = Path(inspect.getfile(cls))
        cfg = OmegaConf.structured(cls.get_config_class())
        cfg = OmegaConf.merge(cfg, *(get_config(other_cfg, task_path) for other_cfg in cfgs))
        if use_cli:
            args = use_cli if isinstance(use_cli, list) else sys.argv[1:]
            if "-h" in args or "--help" in args:
                sys.stderr.write(OmegaConf.to_yaml(cfg))
                sys.stderr.flush()
                sys.exit(0)

            # Split the CLI arguments into config files (merged first, in order) and
            # dotlist overrides (applied last). A relative path is looked up in the
            # working directory first, then in the directory containing the task's
            # source file (``task_path`` is the file itself, hence ``.parent``).
            config_paths: list[Path] = []
            overrides: list[str] = []
            for arg in args:
                candidates = (Path(arg), task_path.parent / arg)
                found = next((candidate for candidate in candidates if candidate.is_file()), None)
                if found is None:
                    overrides.append(arg)
                else:
                    config_paths.append(found)
            if config_paths:
                cfg = OmegaConf.merge(cfg, *(get_config(path, task_path) for path in config_paths))
            cfg = OmegaConf.merge(cfg, OmegaConf.from_cli(overrides))

        return cast(Config, cfg)

    @classmethod
    def config_str(cls, *cfgs: RawConfigType, use_cli: bool | list[str] = True) -> str:
        """Returns the config built by ``get_config`` rendered as YAML."""
        return OmegaConf.to_yaml(cls.get_config(*cfgs, use_cli=use_cli))

    @classmethod
    def get_task(cls, *cfgs: RawConfigType, use_cli: bool | list[str] = True) -> Self:
        """Builds the task from the provided config classes.

        Args:
            cfgs: The config classes to merge. If a string or Path is provided,
                it will be loaded as a YAML file.
            use_cli: Whether to allow additional overrides from the CLI.

        Returns:
            The task.
        """
        cfg = cls.get_config(*cfgs, use_cli=use_cli)
        return cls(cfg)
|
File without changes
|
@@ -0,0 +1,28 @@
|
|
1
|
+
"""Defines the base launcher class."""
|
2
|
+
|
3
|
+
from abc import ABC, abstractmethod
|
4
|
+
from typing import TYPE_CHECKING
|
5
|
+
|
6
|
+
from xax.task.base import RawConfigType
|
7
|
+
|
8
|
+
if TYPE_CHECKING:
|
9
|
+
from xax.task.mixins.runnable import Config, RunnableMixin
|
10
|
+
|
11
|
+
|
12
|
+
class BaseLauncher(ABC):
    """Abstract interface that every launcher implementation fulfils."""

    @abstractmethod
    def launch(
        self,
        task: "type[RunnableMixin[Config]]",
        *cfgs: RawConfigType,
        use_cli: bool | list[str] = True,
    ) -> None:
        """Starts training for ``task``.

        Args:
            task: Class of the task to train (not an instance).
            cfgs: Raw config sources used to build the task's configuration.
            use_cli: Whether command-line arguments are folded into the
                configuration; a list supplies the arguments explicitly.
        """
|
@@ -0,0 +1,42 @@
|
|
1
|
+
"""Defines a launcher that can be toggled from the command line."""
|
2
|
+
|
3
|
+
import argparse
|
4
|
+
import sys
|
5
|
+
from typing import TYPE_CHECKING, Literal, get_args
|
6
|
+
|
7
|
+
from xax.task.base import RawConfigType
|
8
|
+
from xax.task.launchers.base import BaseLauncher
|
9
|
+
from xax.task.launchers.single_process import SingleProcessLauncher
|
10
|
+
|
11
|
+
if TYPE_CHECKING:
|
12
|
+
from xax.task.mixins.runnable import Config, RunnableMixin
|
13
|
+
|
14
|
+
|
15
|
+
# Launcher names accepted by ``CliLauncher``'s ``-l/--launcher`` flag.
LauncherChoice = Literal["single"]
|
16
|
+
|
17
|
+
|
18
|
+
class CliLauncher(BaseLauncher):
|
19
|
+
def launch(
|
20
|
+
self,
|
21
|
+
task: "type[RunnableMixin[Config]]",
|
22
|
+
*cfgs: RawConfigType,
|
23
|
+
use_cli: bool | list[str] = True,
|
24
|
+
) -> None:
|
25
|
+
args = use_cli if isinstance(use_cli, list) else sys.argv[1:]
|
26
|
+
parser = argparse.ArgumentParser(add_help=False)
|
27
|
+
parser.add_argument(
|
28
|
+
"-l",
|
29
|
+
"--launcher",
|
30
|
+
choices=get_args(LauncherChoice),
|
31
|
+
default="single",
|
32
|
+
help="The launcher to use",
|
33
|
+
)
|
34
|
+
args, cli_args_rest = parser.parse_known_intermixed_args(args=args)
|
35
|
+
launcher_choice: LauncherChoice = args.launcher
|
36
|
+
use_cli_next: bool | list[str] = False if not use_cli else cli_args_rest
|
37
|
+
|
38
|
+
match launcher_choice:
|
39
|
+
case "single":
|
40
|
+
SingleProcessLauncher().launch(task, *cfgs, use_cli=use_cli_next)
|
41
|
+
case _:
|
42
|
+
raise ValueError(f"Invalid launcher choice: {launcher_choice}")
|
@@ -0,0 +1,30 @@
|
|
1
|
+
"""Defines a launcher to train a model locally, in a single process."""
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING
|
4
|
+
|
5
|
+
from xax.task.base import RawConfigType
|
6
|
+
from xax.task.launchers.base import BaseLauncher
|
7
|
+
from xax.utils.logging import configure_logging
|
8
|
+
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
from xax.task.mixins.runnable import Config, RunnableMixin
|
11
|
+
|
12
|
+
|
13
|
+
def run_single_process_training(
    task: "type[RunnableMixin[Config]]",
    *cfgs: RawConfigType,
    use_cli: bool | list[str] = True,
) -> None:
    """Configures logging, builds the task in this process and runs it.

    Args:
        task: The task class to instantiate.
        cfgs: Raw configs forwarded to ``task.get_task``.
        use_cli: CLI-override setting forwarded to ``task.get_task``.
    """
    configure_logging()
    task.get_task(*cfgs, use_cli=use_cli).run()
|
21
|
+
|
22
|
+
|
23
|
+
class SingleProcessLauncher(BaseLauncher):
    """Launcher that trains the task locally, within the calling process."""

    def launch(
        self,
        task: "type[RunnableMixin[Config]]",
        *cfgs: RawConfigType,
        use_cli: bool | list[str] = True,
    ) -> None:
        """Delegates directly to ``run_single_process_training``."""
        run_single_process_training(task, *cfgs, use_cli=use_cli)
|
@@ -0,0 +1,29 @@
|
|
1
|
+
"""Defines a base class with utility functions for staged training runs."""
|
2
|
+
|
3
|
+
from abc import ABC
|
4
|
+
from pathlib import Path
|
5
|
+
|
6
|
+
from xax.task.launchers.base import BaseLauncher
|
7
|
+
from xax.task.mixins.artifacts import ArtifactsMixin, Config
|
8
|
+
|
9
|
+
|
10
|
+
class StagedLauncher(BaseLauncher, ABC):
    """Launcher base for runs where a task is rebuilt from a saved config file.

    ``get_config_path`` snapshots a task's config into its experiment directory,
    and ``from_components`` reconstructs a task from its task key and that file.
    """

    def __init__(self, config_file_name: str = "config.yaml") -> None:
        super().__init__()
        # File name (inside the experiment directory) used for the config snapshot.
        self.config_file_name = config_file_name

    def get_config_path(self, task: "ArtifactsMixin[Config]", use_cli: bool | list[str] = True) -> Path:
        """Writes the task's config into its experiment directory and returns the file path.

        As a side effect, ``task.config.exp_dir`` is set to the experiment
        directory so the saved config records where it belongs.
        """
        target = task.exp_dir / self.config_file_name
        task.config.exp_dir = str(task.exp_dir)
        with target.open("w", encoding="utf-8") as fh:
            fh.write(task.config_str(task.config, use_cli=use_cli))
        return target

    @classmethod
    def from_components(cls, task_key: str, config_path: Path, use_cli: bool | list[str] = True) -> "ArtifactsMixin":
        """Rebuilds a task from its ``task_key`` and a saved config file.

        The config file's parent directory becomes the task's experiment directory.
        """
        task_class = ArtifactsMixin.from_task_key(task_key)
        task = task_class.get_task(config_path, use_cli=use_cli)
        return task.set_exp_dir(config_path.parent)
|