xax 0.0.1__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. xax/__init__.py +256 -1
  2. xax/core/conf.py +193 -0
  3. xax/core/state.py +81 -0
  4. xax/nn/__init__.py +0 -0
  5. xax/nn/embeddings.py +355 -0
  6. xax/nn/functions.py +77 -0
  7. xax/nn/parallel.py +211 -0
  8. xax/requirements-dev.txt +15 -0
  9. xax/requirements.txt +23 -0
  10. xax/task/__init__.py +0 -0
  11. xax/task/base.py +207 -0
  12. xax/task/launchers/__init__.py +0 -0
  13. xax/task/launchers/base.py +28 -0
  14. xax/task/launchers/cli.py +42 -0
  15. xax/task/launchers/single_process.py +30 -0
  16. xax/task/launchers/staged.py +29 -0
  17. xax/task/logger.py +783 -0
  18. xax/task/loggers/__init__.py +0 -0
  19. xax/task/loggers/callback.py +56 -0
  20. xax/task/loggers/json.py +121 -0
  21. xax/task/loggers/state.py +45 -0
  22. xax/task/loggers/stdout.py +170 -0
  23. xax/task/loggers/tensorboard.py +223 -0
  24. xax/task/mixins/__init__.py +12 -0
  25. xax/task/mixins/artifacts.py +114 -0
  26. xax/task/mixins/checkpointing.py +209 -0
  27. xax/task/mixins/cpu_stats.py +251 -0
  28. xax/task/mixins/data_loader.py +149 -0
  29. xax/task/mixins/gpu_stats.py +257 -0
  30. xax/task/mixins/logger.py +66 -0
  31. xax/task/mixins/process.py +51 -0
  32. xax/task/mixins/runnable.py +63 -0
  33. xax/task/mixins/step_wrapper.py +63 -0
  34. xax/task/mixins/train.py +541 -0
  35. xax/task/script.py +53 -0
  36. xax/task/task.py +65 -0
  37. xax/utils/__init__.py +0 -0
  38. xax/utils/data/__init__.py +0 -0
  39. xax/utils/data/collate.py +206 -0
  40. xax/utils/experiments.py +802 -0
  41. xax/utils/jax.py +14 -0
  42. xax/utils/logging.py +223 -0
  43. xax/utils/numpy.py +47 -0
  44. xax/utils/tensorboard.py +258 -0
  45. xax/utils/text.py +350 -0
  46. xax-0.0.5.dist-info/METADATA +40 -0
  47. xax-0.0.5.dist-info/RECORD +52 -0
  48. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
  49. xax-0.0.5.dist-info/top_level.txt +1 -0
  50. examples/mnist.py +0 -148
  51. xax-0.0.1.dist-info/METADATA +0 -21
  52. xax-0.0.1.dist-info/RECORD +0 -9
  53. xax-0.0.1.dist-info/top_level.txt +0 -2
  54. {examples → xax/core}/__init__.py +0 -0
  55. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
xax/task/base.py ADDED
@@ -0,0 +1,207 @@
1
+ """Defines the base task interface.
2
+
3
+ This interface is built upon by a large number of other interfaces which
4
+ compose various functionality into a single cohesive unit. The base task
5
+ just stores the configuration and provides hooks which are overridden by
6
+ upstream classes.
7
+ """
8
+
9
+ import functools
10
+ import inspect
11
+ import logging
12
+ import sys
13
+ from dataclasses import dataclass, is_dataclass
14
+ from pathlib import Path
15
+ from types import TracebackType
16
+ from typing import Generic, Self, TypeVar, cast
17
+
18
+ from omegaconf import Container, DictConfig, OmegaConf
19
+
20
+ from xax.core.state import State
21
+ from xax.utils.text import camelcase_to_snakecase
22
+
23
# Logger for this module, named after the module itself.
logger = logging.getLogger(name=__name__)
24
+
25
+
26
@dataclass
class BaseConfig:
    """Root dataclass for task configurations.

    It declares no fields of its own; task-specific configs subclass it and
    add their options. ``BaseTask.get_config_class`` searches a task's generic
    type arguments for a subclass of this type.
    """
29
+
30
+
31
# Every form a raw configuration may take before it is merged: a config
# dataclass instance, a plain dict, an OmegaConf ``DictConfig``, or a path
# (string or ``Path``) to a config file.
RawConfigType = BaseConfig | dict | DictConfig | str | Path

# Type variable for a task's concrete config class, which must subclass
# ``BaseConfig``.
Config = TypeVar("Config", bound=BaseConfig)
34
+
35
+
36
def _load_as_dict(path: str | Path) -> DictConfig:
    """Loads the config file at ``path`` and ensures it is a mapping.

    Args:
        path: Location of the config file.

    Returns:
        The loaded config as a ``DictConfig``.

    Raises:
        TypeError: If the file's top-level value is not a dictionary.
    """
    loaded = OmegaConf.load(path)
    if isinstance(loaded, DictConfig):
        return loaded
    raise TypeError(f"Config file at {path} must be a dictionary, not {type(loaded)}!")
41
+
42
+
43
def get_config(cfg: RawConfigType, task_path: Path | None) -> DictConfig:
    """Converts a single raw config into an OmegaConf ``DictConfig``.

    Args:
        cfg: The raw config. Strings and paths are loaded as config files; a
            bare file name that does not exist relative to the working
            directory is also looked up next to ``task_path``. Dicts are
            converted with ``OmegaConf.create``, dataclasses (instances or
            types) with ``OmegaConf.structured``, and ``DictConfig`` objects
            are returned unchanged.
        task_path: Path of the source file defining the task; its parent
            directory is the fallback location for bare config file names.
            ``None`` disables that fallback.

    Returns:
        The config as a ``DictConfig``.

    Raises:
        FileNotFoundError: If a path-like ``cfg`` cannot be found.
        TypeError: If ``cfg`` is not a supported raw config type, or a loaded
            file is not a dictionary.
    """
    if isinstance(cfg, (str, Path)):
        cfg_path = Path(cfg)
        if cfg_path.exists():
            return _load_as_dict(cfg_path)
        # Only single-component names are resolved relative to the directory
        # that contains the task's source file.
        if task_path is not None and len(cfg_path.parts) == 1:
            other_cfg_path = task_path.parent / cfg_path
            if other_cfg_path.exists():
                return _load_as_dict(other_cfg_path)
        raise FileNotFoundError(f"Could not find config file at {cfg_path}!")
    if isinstance(cfg, DictConfig):
        return cfg
    if isinstance(cfg, dict):
        return cast(DictConfig, OmegaConf.create(cfg))
    if is_dataclass(cfg):
        return cast(DictConfig, OmegaConf.structured(cfg))
    # Previously unsupported values were silently passed through and mis-typed
    # as ``DictConfig``; fail loudly instead.
    raise TypeError(f"Unsupported config type {type(cfg)}; expected one of {RawConfigType}")
57
+
58
+
59
class BaseTask(Generic[Config]):
    """Root of the task hierarchy.

    Holds the task's config and defines no-op lifecycle hooks that mixins and
    concrete tasks override. The class methods build a structured config from
    files, dicts, dataclasses and command-line overrides, and construct the
    task from that config.
    """

    config: Config

    def __init__(self, config: Config) -> None:
        super().__init__()

        self.config = config

        # Resolve OmegaConf interpolations in place so later reads see
        # concrete values.
        if isinstance(self.config, Container):
            OmegaConf.resolve(self.config)

    def on_step_start(self, state: State) -> State:
        """Hook for the start of a step; the base implementation returns ``state`` unchanged."""
        return state

    def on_step_end(self, state: State) -> State:
        """Hook for the end of a step; the base implementation returns ``state`` unchanged."""
        return state

    def on_training_start(self, state: State) -> State:
        """Hook for the start of training; the base implementation returns ``state`` unchanged."""
        return state

    def on_training_end(self, state: State) -> State:
        """Hook for the end of training; the base implementation returns ``state`` unchanged."""
        return state

    @functools.cached_property
    def task_class_name(self) -> str:
        """Name of the concrete task class."""
        return self.__class__.__name__

    @functools.cached_property
    def task_name(self) -> str:
        """Snake-case form of ``task_class_name``."""
        return camelcase_to_snakecase(self.task_class_name)

    @functools.cached_property
    def task_path(self) -> Path:
        """Path of the source file that defines the concrete task class."""
        return Path(inspect.getfile(self.__class__))

    @functools.cached_property
    def task_module(self) -> str:
        """Fully-qualified name of the module that defines the task class.

        Raises:
            RuntimeError: If the module, or its import spec, cannot be found
                (e.g. ``__main__`` run directly as a script has no spec).
        """
        if (mod := inspect.getmodule(self.__class__)) is None:
            raise RuntimeError(f"Could not find module for task {self.__class__}!")
        if (spec := mod.__spec__) is None:
            raise RuntimeError(f"Could not find spec for module {mod}!")
        return spec.name

    @property
    def task_key(self) -> str:
        """Import key ``"<module>.<ClassName>"``; the inverse of ``from_task_key``."""
        return f"{self.task_module}.{self.task_class_name}"

    @classmethod
    def from_task_key(cls, task_key: str) -> type[Self]:
        """Imports and returns the task class identified by ``task_key``.

        Args:
            task_key: A ``"<module>.<ClassName>"`` string, as produced by
                ``task_key``.

        Returns:
            The task class.

        Raises:
            ValueError: If ``task_key`` has no module part.
            ImportError: If the module cannot be imported.
            RuntimeError: If the module lacks the class, or the attribute is
                not a subclass of ``cls``.
        """
        task_module, sep, task_class_name = task_key.rpartition(".")
        if not sep or not task_module:
            raise ValueError(f"Invalid task key {task_key!r}; expected '<module>.<ClassName>'")
        try:
            # A non-empty ``fromlist`` makes ``__import__`` return the leaf
            # module instead of the top-level package.
            mod = __import__(task_module, fromlist=[task_class_name])
        except ImportError as e:
            raise ImportError(f"Could not import module {task_module} for task {task_key}") from e
        if not hasattr(mod, task_class_name):
            raise RuntimeError(f"Could not find class {task_class_name} in module {task_module}")
        task_class = getattr(mod, task_class_name)
        # ``issubclass`` raises ``TypeError`` for non-class attributes, so check
        # that the attribute is a class first.
        if not isinstance(task_class, type) or not issubclass(task_class, cls):
            raise RuntimeError(f"Class {task_class_name} in module {task_module} is not a subclass of {cls}")
        return task_class

    def debug(self) -> bool:
        """Whether the task is in debug mode; the base implementation returns ``False``."""
        return False

    @property
    def debugging(self) -> bool:
        """Property form of ``debug()``."""
        return self.debug()

    def __enter__(self) -> Self:
        return self

    def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None:
        # Nothing to clean up here; returning ``None`` lets any exception
        # propagate to the caller.
        pass

    @classmethod
    def get_config_class(cls) -> type[Config]:
        """Recursively retrieves the config class from the generic type.

        Walks the class hierarchy in MRO order and inspects the generic bases
        each class was declared with, returning the first type argument that
        is a subclass of ``BaseConfig``.

        Returns:
            The parsed config class.

        Raises:
            ValueError: If the config class cannot be found, usually meaning
                that the generic class has not been used correctly.
        """
        for klass in cls.__mro__:
            # Read ``__orig_bases__`` from each class's own namespace so every
            # class's declared generic bases are visited exactly once.
            for base in klass.__dict__.get("__orig_bases__", ()):
                for arg in getattr(base, "__args__", ()):
                    # Arguments may be TypeVars or other typing constructs,
                    # which ``issubclass`` would reject with a TypeError.
                    if isinstance(arg, type) and issubclass(arg, BaseConfig):
                        return arg

        raise ValueError(
            "The config class could not be parsed from the generic type, which usually means that the task is not "
            "being instantiated correctly. Your class should be defined as follows:\n\n"
            "    class ExampleTask(xax.Task[Config]):\n        ...\n\nThis lets both the task and the type "
            "checker know what config the task is using."
        )

    @classmethod
    def get_config(cls, *cfgs: RawConfigType, use_cli: bool | list[str] = True) -> Config:
        """Builds the structured config from the provided config classes.

        Args:
            cfgs: The config classes to merge. If a string or Path is provided,
                it will be loaded as a YAML file.
            use_cli: Whether to allow additional overrides from the CLI. A
                list is used in place of ``sys.argv[1:]``. If ``-h`` or
                ``--help`` is present, the current config is written to
                stderr and the process exits with status 0.

        Returns:
            The merged configs.
        """
        task_path = Path(inspect.getfile(cls))
        cfg = OmegaConf.structured(cls.get_config_class())
        cfg = OmegaConf.merge(cfg, *(get_config(other_cfg, task_path) for other_cfg in cfgs))
        if use_cli:
            args = use_cli if isinstance(use_cli, list) else sys.argv[1:]
            if "-h" in args or "--help" in args:
                sys.stderr.write(OmegaConf.to_yaml(cfg))
                sys.stderr.flush()
                sys.exit(0)

            # Arguments naming an existing config file are loaded as configs;
            # everything else is treated as a dotlist override. Bare file names
            # are also looked up next to the task's source file (its parent
            # directory, not the file itself), matching the module-level
            # ``get_config`` helper.
            task_dir = task_path.parent
            paths: list[str] = []
            non_paths: list[str] = []
            for arg in args:
                arg_path = Path(arg)
                is_config_file = arg_path.is_file() or (len(arg_path.parts) == 1 and (task_dir / arg_path).is_file())
                if is_config_file:
                    paths.append(arg)
                else:
                    non_paths.append(arg)
            if paths:
                cfg = OmegaConf.merge(cfg, *(get_config(path, task_path) for path in paths))
            cfg = OmegaConf.merge(cfg, OmegaConf.from_cli(non_paths))

        return cast(Config, cfg)

    @classmethod
    def config_str(cls, *cfgs: RawConfigType, use_cli: bool | list[str] = True) -> str:
        """Returns the merged config (see ``get_config``) rendered as YAML."""
        return OmegaConf.to_yaml(cls.get_config(*cfgs, use_cli=use_cli))

    @classmethod
    def get_task(cls, *cfgs: RawConfigType, use_cli: bool | list[str] = True) -> Self:
        """Builds the task from the provided config classes.

        Args:
            cfgs: The config classes to merge. If a string or Path is provided,
                it will be loaded as a YAML file.
            use_cli: Whether to allow additional overrides from the CLI.

        Returns:
            The task.
        """
        cfg = cls.get_config(*cfgs, use_cli=use_cli)
        return cls(cfg)
xax/task/launchers/__init__.py ADDED
File without changes
xax/task/launchers/base.py ADDED
@@ -0,0 +1,28 @@
1
+ """Defines the base launcher class."""
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import TYPE_CHECKING
5
+
6
+ from xax.task.base import RawConfigType
7
+
8
+ if TYPE_CHECKING:
9
+ from xax.task.mixins.runnable import Config, RunnableMixin
10
+
11
+
12
class BaseLauncher(ABC):
    """Abstract interface for objects that start a task.

    Concrete launchers receive a task class together with raw configs and are
    responsible for building and running the task.
    """

    @abstractmethod
    def launch(
        self,
        task: "type[RunnableMixin[Config]]",
        *cfgs: RawConfigType,
        use_cli: bool | list[str] = True,
    ) -> None:
        """Starts ``task`` using this launcher's strategy.

        Args:
            task: The task class (not an instance) to run.
            cfgs: Raw configs to merge into the task's config.
            use_cli: Whether command-line overrides are applied; a list is
                treated as the argument list to use instead of the real CLI.
        """
xax/task/launchers/cli.py ADDED
@@ -0,0 +1,42 @@
1
+ """Defines a launcher that can be toggled from the command line."""
2
+
3
+ import argparse
4
+ import sys
5
+ from typing import TYPE_CHECKING, Literal, get_args
6
+
7
+ from xax.task.base import RawConfigType
8
+ from xax.task.launchers.base import BaseLauncher
9
+ from xax.task.launchers.single_process import SingleProcessLauncher
10
+
11
+ if TYPE_CHECKING:
12
+ from xax.task.mixins.runnable import Config, RunnableMixin
13
+
14
+
15
# Launcher names accepted by ``-l/--launcher``; only the single-process
# launcher is available.
LauncherChoice = Literal["single"]
16
+
17
+
18
+ class CliLauncher(BaseLauncher):
19
+ def launch(
20
+ self,
21
+ task: "type[RunnableMixin[Config]]",
22
+ *cfgs: RawConfigType,
23
+ use_cli: bool | list[str] = True,
24
+ ) -> None:
25
+ args = use_cli if isinstance(use_cli, list) else sys.argv[1:]
26
+ parser = argparse.ArgumentParser(add_help=False)
27
+ parser.add_argument(
28
+ "-l",
29
+ "--launcher",
30
+ choices=get_args(LauncherChoice),
31
+ default="single",
32
+ help="The launcher to use",
33
+ )
34
+ args, cli_args_rest = parser.parse_known_intermixed_args(args=args)
35
+ launcher_choice: LauncherChoice = args.launcher
36
+ use_cli_next: bool | list[str] = False if not use_cli else cli_args_rest
37
+
38
+ match launcher_choice:
39
+ case "single":
40
+ SingleProcessLauncher().launch(task, *cfgs, use_cli=use_cli_next)
41
+ case _:
42
+ raise ValueError(f"Invalid launcher choice: {launcher_choice}")
xax/task/launchers/single_process.py ADDED
@@ -0,0 +1,30 @@
1
+ """Defines a launcher to train a model locally, in a single process."""
2
+
3
+ from typing import TYPE_CHECKING
4
+
5
+ from xax.task.base import RawConfigType
6
+ from xax.task.launchers.base import BaseLauncher
7
+ from xax.utils.logging import configure_logging
8
+
9
+ if TYPE_CHECKING:
10
+ from xax.task.mixins.runnable import Config, RunnableMixin
11
+
12
+
13
def run_single_process_training(
    task: "type[RunnableMixin[Config]]",
    *cfgs: RawConfigType,
    use_cli: bool | list[str] = True,
) -> None:
    """Configures logging, then builds and runs ``task`` in the current process.

    Args:
        task: The task class to instantiate via ``get_task``.
        cfgs: Raw configs passed through to ``task.get_task``.
        use_cli: Passed through to ``task.get_task`` to control CLI overrides.
    """
    configure_logging()
    task.get_task(*cfgs, use_cli=use_cli).run()
21
+
22
+
23
class SingleProcessLauncher(BaseLauncher):
    """Launcher that runs the task directly in the calling process."""

    def launch(
        self,
        task: "type[RunnableMixin[Config]]",
        *cfgs: RawConfigType,
        use_cli: bool | list[str] = True,
    ) -> None:
        """Hands all arguments to ``run_single_process_training`` unchanged."""
        run_single_process_training(task, *cfgs, use_cli=use_cli)
xax/task/launchers/staged.py ADDED
@@ -0,0 +1,29 @@
1
+ """Defines a base class with utility functions for staged training runs."""
2
+
3
+ from abc import ABC
4
+ from pathlib import Path
5
+
6
+ from xax.task.launchers.base import BaseLauncher
7
+ from xax.task.mixins.artifacts import ArtifactsMixin, Config
8
+
9
+
10
class StagedLauncher(BaseLauncher, ABC):
    """Abstract launcher with helpers for runs staged through a config file.

    It can write a task's config into the task's experiment directory, and
    rebuild a task from its task key plus such a file.
    """

    def __init__(self, config_file_name: str = "config.yaml") -> None:
        super().__init__()

        # File name, inside the experiment directory, used for the saved config.
        self.config_file_name = config_file_name

    def get_config_path(self, task: "ArtifactsMixin[Config]", use_cli: bool | list[str] = True) -> Path:
        """Writes ``task``'s config into its experiment directory.

        Sets ``task.config.exp_dir`` to the experiment directory before the
        config is serialized, so the saved file records that directory.

        Args:
            task: The task whose config is written.
            use_cli: Passed through to ``task.config_str``.

        Returns:
            Path of the written config file.
        """
        config_path = task.exp_dir / self.config_file_name
        task.config.exp_dir = str(task.exp_dir)
        # The file is opened (and truncated) before the config is rendered.
        with config_path.open("w", encoding="utf-8") as config_file:
            config_file.write(task.config_str(task.config, use_cli=use_cli))
        return config_path

    @classmethod
    def from_components(cls, task_key: str, config_path: Path, use_cli: bool | list[str] = True) -> "ArtifactsMixin":
        """Rebuilds a task from its key and a previously written config file.

        Args:
            task_key: ``"<module>.<ClassName>"`` key of the task class.
            config_path: Path of the saved config; its parent directory becomes
                the task's experiment directory.
            use_cli: Passed through to ``get_task``.

        Returns:
            The value returned by ``set_exp_dir`` on the newly built task.
        """
        task_class = ArtifactsMixin.from_task_key(task_key)
        task = task_class.get_task(config_path, use_cli=use_cli)
        return task.set_exp_dir(config_path.parent)
+ )