xax 0.0.1__py3-none-any.whl → 0.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55) hide show
  1. xax/__init__.py +256 -1
  2. xax/core/conf.py +193 -0
  3. xax/core/state.py +81 -0
  4. xax/nn/__init__.py +0 -0
  5. xax/nn/embeddings.py +355 -0
  6. xax/nn/functions.py +77 -0
  7. xax/nn/parallel.py +211 -0
  8. xax/requirements-dev.txt +15 -0
  9. xax/requirements.txt +23 -0
  10. xax/task/__init__.py +0 -0
  11. xax/task/base.py +207 -0
  12. xax/task/launchers/__init__.py +0 -0
  13. xax/task/launchers/base.py +28 -0
  14. xax/task/launchers/cli.py +42 -0
  15. xax/task/launchers/single_process.py +30 -0
  16. xax/task/launchers/staged.py +29 -0
  17. xax/task/logger.py +783 -0
  18. xax/task/loggers/__init__.py +0 -0
  19. xax/task/loggers/callback.py +56 -0
  20. xax/task/loggers/json.py +121 -0
  21. xax/task/loggers/state.py +45 -0
  22. xax/task/loggers/stdout.py +170 -0
  23. xax/task/loggers/tensorboard.py +223 -0
  24. xax/task/mixins/__init__.py +12 -0
  25. xax/task/mixins/artifacts.py +114 -0
  26. xax/task/mixins/checkpointing.py +209 -0
  27. xax/task/mixins/cpu_stats.py +251 -0
  28. xax/task/mixins/data_loader.py +149 -0
  29. xax/task/mixins/gpu_stats.py +257 -0
  30. xax/task/mixins/logger.py +66 -0
  31. xax/task/mixins/process.py +51 -0
  32. xax/task/mixins/runnable.py +63 -0
  33. xax/task/mixins/step_wrapper.py +63 -0
  34. xax/task/mixins/train.py +541 -0
  35. xax/task/script.py +53 -0
  36. xax/task/task.py +65 -0
  37. xax/utils/__init__.py +0 -0
  38. xax/utils/data/__init__.py +0 -0
  39. xax/utils/data/collate.py +206 -0
  40. xax/utils/experiments.py +802 -0
  41. xax/utils/jax.py +14 -0
  42. xax/utils/logging.py +223 -0
  43. xax/utils/numpy.py +47 -0
  44. xax/utils/tensorboard.py +258 -0
  45. xax/utils/text.py +350 -0
  46. xax-0.0.5.dist-info/METADATA +40 -0
  47. xax-0.0.5.dist-info/RECORD +52 -0
  48. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
  49. xax-0.0.5.dist-info/top_level.txt +1 -0
  50. examples/mnist.py +0 -148
  51. xax-0.0.1.dist-info/METADATA +0 -21
  52. xax-0.0.1.dist-info/RECORD +0 -9
  53. xax-0.0.1.dist-info/top_level.txt +0 -2
  54. {examples → xax/core}/__init__.py +0 -0
  55. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
xax/__init__.py CHANGED
@@ -1 +1,256 @@
1
- __version__ = "0.0.1"
1
+ """Defines the top-level xax API.
2
+
3
+ This package is structured so that all the important stuff can be accessed
4
+ without having to dig around through the internals. This is done by lazily
5
+ importing the module by name.
6
+
7
+ This file can be maintained by running the update script:
8
+
9
+ .. code-block:: bash
10
+
11
+ python -m scripts.update_api --inplace
12
+ """
13
+
14
+ __version__ = "0.0.5"
15
+
16
+ # This list shouldn't be modified by hand; instead, run the update script.
17
+ __all__ = [
18
+ "UserConfig",
19
+ "field",
20
+ "get_data_dir",
21
+ "get_pretrained_models_dir",
22
+ "get_run_dir",
23
+ "load_user_config",
24
+ "State",
25
+ "cast_phase",
26
+ "FourierEmbeddings",
27
+ "IdentityPositionalEmbeddings",
28
+ "LearnedPositionalEmbeddings",
29
+ "RotaryEmbeddings",
30
+ "SinusoidalEmbeddings",
31
+ "apply_rotary_embeddings",
32
+ "cast_embedding_kind",
33
+ "fourier_embeddings",
34
+ "get_positional_embeddings",
35
+ "get_rotary_embeddings",
36
+ "rotary_embeddings",
37
+ "BaseLauncher",
38
+ "CliLauncher",
39
+ "SingleProcessLauncher",
40
+ "LogImage",
41
+ "LogLine",
42
+ "Logger",
43
+ "LoggerImpl",
44
+ "CallbackLogger",
45
+ "JsonLogger",
46
+ "StateLogger",
47
+ "StdoutLogger",
48
+ "TensorboardLogger",
49
+ "CPUStatsOptions",
50
+ "DataloaderConfig",
51
+ "GPUStatsOptions",
52
+ "Script",
53
+ "ScriptConfig",
54
+ "Config",
55
+ "Task",
56
+ "collate",
57
+ "collate_non_null",
58
+ "BaseFileDownloader",
59
+ "DataDownloader",
60
+ "ModelDownloader",
61
+ "check_md5",
62
+ "check_sha256",
63
+ "get_git_state",
64
+ "get_state_dict_prefix",
65
+ "get_training_code",
66
+ "save_config",
67
+ "ColoredFormatter",
68
+ "configure_logging",
69
+ "one_hot",
70
+ "partial_flatten",
71
+ "worker_chunk",
72
+ "TextBlock",
73
+ "colored",
74
+ "format_datetime",
75
+ "format_timedelta",
76
+ "outlined",
77
+ "render_text_blocks",
78
+ "show_error",
79
+ "show_warning",
80
+ "uncolored",
81
+ "wrapped",
82
+ ]
83
+
84
+ __all__ += [
85
+ "Batch",
86
+ "CollateMode",
87
+ "EmbeddingKind",
88
+ "Output",
89
+ "Phase",
90
+ ]
91
+
92
+ import os
93
+ from typing import TYPE_CHECKING
94
+
95
+ # If this flag is set, eagerly imports the entire package (not recommended).
96
+ IMPORT_ALL = int(os.environ.get("XAX_IMPORT_ALL", "0")) != 0
97
+
98
+ del os
99
+
100
+ # This dictionary is auto-generated and shouldn't be modified by hand; instead,
101
+ # run the update script.
102
+ NAME_MAP: dict[str, str] = {
103
+ "UserConfig": "core.conf",
104
+ "field": "core.conf",
105
+ "get_data_dir": "core.conf",
106
+ "get_pretrained_models_dir": "core.conf",
107
+ "get_run_dir": "core.conf",
108
+ "load_user_config": "core.conf",
109
+ "State": "core.state",
110
+ "cast_phase": "core.state",
111
+ "FourierEmbeddings": "nn.embeddings",
112
+ "IdentityPositionalEmbeddings": "nn.embeddings",
113
+ "LearnedPositionalEmbeddings": "nn.embeddings",
114
+ "RotaryEmbeddings": "nn.embeddings",
115
+ "SinusoidalEmbeddings": "nn.embeddings",
116
+ "apply_rotary_embeddings": "nn.embeddings",
117
+ "cast_embedding_kind": "nn.embeddings",
118
+ "fourier_embeddings": "nn.embeddings",
119
+ "get_positional_embeddings": "nn.embeddings",
120
+ "get_rotary_embeddings": "nn.embeddings",
121
+ "rotary_embeddings": "nn.embeddings",
122
+ "BaseLauncher": "task.launchers.base",
123
+ "CliLauncher": "task.launchers.cli",
124
+ "SingleProcessLauncher": "task.launchers.single_process",
125
+ "LogImage": "task.logger",
126
+ "LogLine": "task.logger",
127
+ "Logger": "task.logger",
128
+ "LoggerImpl": "task.logger",
129
+ "CallbackLogger": "task.loggers.callback",
130
+ "JsonLogger": "task.loggers.json",
131
+ "StateLogger": "task.loggers.state",
132
+ "StdoutLogger": "task.loggers.stdout",
133
+ "TensorboardLogger": "task.loggers.tensorboard",
134
+ "CPUStatsOptions": "task.mixins.cpu_stats",
135
+ "DataloaderConfig": "task.mixins.data_loader",
136
+ "GPUStatsOptions": "task.mixins.gpu_stats",
137
+ "Script": "task.script",
138
+ "ScriptConfig": "task.script",
139
+ "Config": "task.task",
140
+ "Task": "task.task",
141
+ "collate": "utils.data.collate",
142
+ "collate_non_null": "utils.data.collate",
143
+ "BaseFileDownloader": "utils.experiments",
144
+ "DataDownloader": "utils.experiments",
145
+ "ModelDownloader": "utils.experiments",
146
+ "check_md5": "utils.experiments",
147
+ "check_sha256": "utils.experiments",
148
+ "get_git_state": "utils.experiments",
149
+ "get_state_dict_prefix": "utils.experiments",
150
+ "get_training_code": "utils.experiments",
151
+ "save_config": "utils.experiments",
152
+ "ColoredFormatter": "utils.logging",
153
+ "configure_logging": "utils.logging",
154
+ "one_hot": "utils.numpy",
155
+ "partial_flatten": "utils.numpy",
156
+ "worker_chunk": "utils.numpy",
157
+ "TextBlock": "utils.text",
158
+ "colored": "utils.text",
159
+ "format_datetime": "utils.text",
160
+ "format_timedelta": "utils.text",
161
+ "outlined": "utils.text",
162
+ "render_text_blocks": "utils.text",
163
+ "show_error": "utils.text",
164
+ "show_warning": "utils.text",
165
+ "uncolored": "utils.text",
166
+ "wrapped": "utils.text",
167
+ }
168
+
169
# Need to manually set some values which can't be auto-generated.
NAME_MAP.update(
    {
        "Batch": "task.mixins.train",
        "CollateMode": "utils.data.collate",
        "EmbeddingKind": "nn.embeddings",
        # ``Output`` lives next to ``Batch`` in ``task.mixins.train`` (this matches
        # the eager import block further down); the package ships no
        # ``task.mixins.output`` module, so the old mapping broke ``xax.Output``.
        "Output": "task.mixins.train",
        "Phase": "core.state",
    },
)
179
+
180
+
181
def __getattr__(name: str) -> object:
    """Lazily resolves a top-level attribute of the package.

    Python calls this module-level hook when ``name`` is not found through
    normal module attribute lookup. The name is looked up in ``NAME_MAP``, the
    submodule it points to is imported, and the attribute is read from it.

    Args:
        name: The attribute being requested from the package.

    Returns:
        The object called ``name`` from the mapped submodule.

    Raises:
        AttributeError: If ``name`` is not a key of ``NAME_MAP``.
    """
    relative_module = NAME_MAP.get(name)
    if relative_module is None:
        raise AttributeError(f"{__name__} has no attribute {name}")

    # A non-empty ``fromlist`` makes ``__import__`` return the leaf submodule
    # instead of the top-level package.
    target = __import__(f"xax.{relative_module}", fromlist=[name])
    return getattr(target, name)
188
+
189
+
190
+ if IMPORT_ALL or TYPE_CHECKING:
191
+ from xax.core.conf import (
192
+ UserConfig,
193
+ field,
194
+ get_data_dir,
195
+ get_pretrained_models_dir,
196
+ get_run_dir,
197
+ load_user_config,
198
+ )
199
+ from xax.core.state import Phase, State, cast_phase
200
+ from xax.nn.embeddings import (
201
+ EmbeddingKind,
202
+ FourierEmbeddings,
203
+ IdentityPositionalEmbeddings,
204
+ LearnedPositionalEmbeddings,
205
+ RotaryEmbeddings,
206
+ SinusoidalEmbeddings,
207
+ apply_rotary_embeddings,
208
+ cast_embedding_kind,
209
+ fourier_embeddings,
210
+ get_positional_embeddings,
211
+ get_rotary_embeddings,
212
+ rotary_embeddings,
213
+ )
214
+ from xax.task.launchers.base import BaseLauncher
215
+ from xax.task.launchers.cli import CliLauncher
216
+ from xax.task.launchers.single_process import SingleProcessLauncher
217
+ from xax.task.logger import Logger, LoggerImpl, LogImage, LogLine
218
+ from xax.task.loggers.callback import CallbackLogger
219
+ from xax.task.loggers.json import JsonLogger
220
+ from xax.task.loggers.state import StateLogger
221
+ from xax.task.loggers.stdout import StdoutLogger
222
+ from xax.task.loggers.tensorboard import TensorboardLogger
223
+ from xax.task.mixins.cpu_stats import CPUStatsOptions
224
+ from xax.task.mixins.data_loader import DataloaderConfig
225
+ from xax.task.mixins.gpu_stats import GPUStatsOptions
226
+ from xax.task.mixins.train import Batch, Output
227
+ from xax.task.script import Script, ScriptConfig
228
+ from xax.task.task import Config, Task
229
+ from xax.utils.data.collate import CollateMode, collate, collate_non_null
230
+ from xax.utils.experiments import (
231
+ BaseFileDownloader,
232
+ DataDownloader,
233
+ ModelDownloader,
234
+ check_md5,
235
+ check_sha256,
236
+ get_git_state,
237
+ get_state_dict_prefix,
238
+ get_training_code,
239
+ save_config,
240
+ )
241
+ from xax.utils.logging import ColoredFormatter, configure_logging
242
+ from xax.utils.numpy import one_hot, partial_flatten, worker_chunk
243
+ from xax.utils.text import (
244
+ TextBlock,
245
+ colored,
246
+ format_datetime,
247
+ format_timedelta,
248
+ outlined,
249
+ render_text_blocks,
250
+ show_error,
251
+ show_warning,
252
+ uncolored,
253
+ wrapped,
254
+ )
255
+
256
+ del TYPE_CHECKING, IMPORT_ALL
xax/core/conf.py ADDED
@@ -0,0 +1,193 @@
1
+ """Defines base configuration functions and utilities."""
2
+
3
+ import functools
4
+ import os
5
+ from dataclasses import dataclass, field as field_base
6
+ from pathlib import Path
7
+ from typing import Any, cast
8
+
9
+ import jax.numpy as jnp
10
+ from omegaconf import II, MISSING, Container as OmegaConfContainer, OmegaConf
11
+
12
+ from xax.utils.text import show_error
13
+
14
FieldType = Any


def field(value: FieldType, **kwargs: str) -> FieldType:
    """Short-hand function for getting a config field.

    Args:
        value: The current field's default value. A callable is used as the
            default factory. An unhashable (i.e. mutable) value is deep-copied
            for every new instance so that instances never share state. Any
            other value is used directly as the default.
        kwargs: Additional metadata fields to supply.

    Returns:
        The dataclass field.
    """
    metadata: dict[str, Any] = dict(kwargs)

    if callable(value):
        return field_base(default_factory=value, metadata=metadata)
    if value.__class__.__hash__ is None:
        # Imported locally so the module's import block stays untouched.
        import copy

        # Returning ``value`` itself would hand the same list/dict to every
        # instance, so mutating one instance's field would leak into the others.
        return field_base(default_factory=lambda: copy.deepcopy(value), metadata=metadata)
    return field_base(default=value, metadata=metadata)
35
+
36
+
37
def is_missing(cfg: Any, key: str) -> bool:  # noqa: ANN401
    """Utility function for checking if a config key is missing.

    This is for cases when you are using a raw dataclass rather than an
    OmegaConf container but want to treat them the same way.

    Args:
        cfg: The config to check
        key: The key to check

    Returns:
        Whether or not the key is missing a value in the config
    """
    if isinstance(cfg, OmegaConfContainer):
        if OmegaConf.is_missing(cfg, key):
            return True
        if OmegaConf.is_interpolation(cfg, key):
            # An interpolation which cannot be resolved (for example, one that
            # reads an unset environment variable) counts as missing. Any
            # resolution failure is deliberately treated the same way.
            try:
                getattr(cfg, key)
            except Exception:
                return True
            return False

    # ``MISSING`` is a plain string sentinel, so compare by value: an equal
    # string coming from another source is not guaranteed to be the same
    # object. The ``isinstance`` guard keeps ``==`` away from array-like values.
    value = getattr(cfg, key)
    return isinstance(value, str) and value == MISSING
62
+
63
+
64
@dataclass
class Logging:
    """User-level logging options."""

    hide_third_party_logs: bool = field(True, help="If set, hide third-party logs")
    log_level: str = field("INFO", help="The logging level to use")
68
+
69
+
70
@dataclass
class Device:
    """User-level device and floating point precision options.

    The ``use_*`` flags are read in order by ``parse_dtype``.
    """

    cpu: bool = field(True, help="Whether to use the CPU")
    # The next two defaults are OmegaConf interpolations; presumably they read
    # the ``USE_GPU`` / ``USE_METAL`` environment variables and fall back to 1
    # (confirm against the ``oc.env`` resolver documentation).
    gpu: bool = field(II("oc.env:USE_GPU,1"), help="Whether to use the GPU")
    metal: bool = field(II("oc.env:USE_METAL,1"), help="Whether to use the Apple Silicon accelerator")
    use_fp64: bool = field(False, help="Always use the 64-bit floating point type")
    use_fp32: bool = field(False, help="Always use the 32-bit floating point type")
    use_bf16: bool = field(False, help="Always use the 16-bit bfloat type")
    use_fp16: bool = field(False, help="Always use the 16-bit floating point type")
79
+
80
+
81
def parse_dtype(cfg: Device) -> jnp.dtype | None:
    """Picks the floating point type forced by the device config, if any.

    The flags are checked from widest to narrowest type and the first one that
    is set wins.

    Args:
        cfg: The device configuration to inspect.

    Returns:
        The matching ``jax.numpy`` floating point type, or ``None`` when no
        ``use_*`` flag is set.
    """
    precedence = (
        ("use_fp64", jnp.float64),
        ("use_fp32", jnp.float32),
        ("use_bf16", jnp.bfloat16),
        ("use_fp16", jnp.float16),
    )
    for flag_name, dtype in precedence:
        # Flags are read lazily, in order, stopping at the first truthy one.
        if getattr(cfg, flag_name):
            return dtype
    return None
91
+
92
+
93
@dataclass
class Triton:
    """User-level Triton options."""

    use_triton_if_available: bool = field(True, help="Use Triton if available")
96
+
97
+
98
@dataclass
class Experiment:
    """User-level experiment defaults."""

    default_random_seed: int = field(1337, help="The default random seed to use")
    max_workers: int = field(32, help="Maximum number of workers to use")
102
+
103
+
104
@dataclass
class Directories:
    """Locations of the run, data and pretrained model directories.

    Every default is an OmegaConf interpolation with no fallback value.
    ``is_missing`` reports an interpolation that fails to resolve as missing,
    which is how ``get_run_dir``, ``get_data_dir`` and
    ``get_pretrained_models_dir`` detect an unset directory.
    """

    run: str = field(II("oc.env:RUN_DIR"), help="The run directory")
    data: str = field(II("oc.env:DATA_DIR"), help="The data directory")
    pretrained_models: str = field(II("oc.env:MODEL_DIR"), help="The models directory")
109
+
110
+
111
@dataclass
class SlurmPartition:
    """A single Slurm launch configuration."""

    # ``MISSING`` marks the partition name as required: it has no usable default.
    partition: str = field(MISSING, help="The partition name")
    num_nodes: int = field(1, help="The number of nodes to use")
115
+
116
+
117
@dataclass
class Slurm:
    """User-level Slurm options."""

    # Maps a string key to a partition config; empty by default.
    launch: dict[str, SlurmPartition] = field({}, help="The available launch configurations")
120
+
121
+
122
@dataclass
class UserConfig:
    """Root of the user configuration, grouping every section.

    Each section class is passed to ``field`` as a callable, so it is used as
    the default factory and every ``UserConfig`` gets fresh section instances.
    """

    logging: Logging = field(Logging)
    device: Device = field(Device)
    triton: Triton = field(Triton)
    experiment: Experiment = field(Experiment)
    directories: Directories = field(Directories)
    slurm: Slurm = field(Slurm)
130
+
131
+
132
def user_config_path() -> Path:
    """Returns the location of the user configuration file.

    The path is taken from the ``XAXRC_PATH`` environment variable, falling
    back to ``~/.xax.yml``, with a leading ``~`` expanded.

    Returns:
        The path to the user configuration file.
    """
    return Path(os.environ.get("XAXRC_PATH", "~/.xax.yml")).expanduser()
136
+
137
+
138
@functools.lru_cache(maxsize=None)
def _load_user_config_cached() -> UserConfig:
    """Builds the user configuration, caching the result.

    The structured defaults are merged with the user config file when it
    exists; otherwise the defaults are written to that location. A ``xax.yml``
    file in the current working directory, if present, is merged on top.

    Returns:
        The merged configuration (an OmegaConf object cast to ``UserConfig``).
    """
    xaxrc_path = user_config_path()
    base_cfg = OmegaConf.structured(UserConfig)

    # Loads the user config file, or writes the defaults if there isn't one yet.
    if xaxrc_path.exists():
        cfg = OmegaConf.merge(base_cfg, OmegaConf.load(xaxrc_path))
    else:
        show_error(f"No config file was found in {xaxrc_path}; writing one...", important=True)
        # ``XAXRC_PATH`` may point inside a directory which doesn't exist yet,
        # in which case saving would fail without creating it first.
        xaxrc_path.parent.mkdir(parents=True, exist_ok=True)
        OmegaConf.save(base_cfg, xaxrc_path)
        cfg = base_cfg

    # Looks in the current directory for a config file.
    local_cfg_path = Path("xax.yml")
    if local_cfg_path.exists():
        cfg = OmegaConf.merge(cfg, OmegaConf.load(local_cfg_path))

    return cast(UserConfig, cfg)
157
+
158
+
159
def load_user_config() -> UserConfig:
    """Loads the user configuration.

    This is a thin wrapper around ``_load_user_config_cached``, which is
    memoized, so the files are only read on the first call. The config file
    location comes from ``user_config_path`` (``XAXRC_PATH``, defaulting to
    ``~/.xax.yml``), and a ``xax.yml`` in the current directory is merged on
    top when it exists.

    Returns:
        The loaded configuration.
    """
    return _load_user_config_cached()
166
+
167
+
168
def get_run_dir() -> Path | None:
    """Returns the configured run directory, creating it if needed.

    Returns:
        The run directory, or ``None`` if it has not been configured.
    """
    directories = load_user_config().directories
    if is_missing(directories, "run"):
        return None

    run_dir = Path(directories.run)
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir
174
+
175
+
176
def get_data_dir() -> Path:
    """Returns the configured data directory.

    Returns:
        The data directory.

    Raises:
        RuntimeError: If the data directory has not been configured.
    """
    directories = load_user_config().directories
    if not is_missing(directories, "data"):
        return Path(directories.data)

    raise RuntimeError(
        "The data directory has not been set! You should set it in your config file "
        f"in {user_config_path()} or set the DATA_DIR environment variable."
    )
184
+
185
+
186
def get_pretrained_models_dir() -> Path:
    """Returns the configured pretrained models directory.

    Returns:
        The pretrained models directory.

    Raises:
        RuntimeError: If the pretrained models directory has not been
            configured.
    """
    config = load_user_config().directories
    if is_missing(config, "pretrained_models"):
        # The message names the pretrained models directory; it used to be a
        # copy of the data directory message.
        raise RuntimeError(
            "The pretrained models directory has not been set! You should set it in your config file "
            f"in {user_config_path()} or set the MODEL_DIR environment variable."
        )
    return Path(config.pretrained_models)
xax/core/state.py ADDED
@@ -0,0 +1,81 @@
1
+ """Defines a dataclass for keeping track of the current training state."""
2
+
3
+ import time
4
+ from dataclasses import dataclass
5
+ from typing import Literal, TypedDict, cast, get_args
6
+
7
+ from omegaconf import MISSING
8
+
9
+ from xax.core.conf import field
10
+
11
Phase = Literal["train", "valid"]


def cast_phase(raw_phase: str) -> Phase:
    """Checks that a raw string names a phase and narrows its type.

    Args:
        raw_phase: The phase name to check.

    Returns:
        The same string, typed as ``Phase``.

    Raises:
        AssertionError: If ``raw_phase`` is not one of the ``Phase`` options
            (only while assertions are enabled).
    """
    valid_phases = get_args(Phase)
    assert raw_phase in valid_phases, f"Invalid phase: '{raw_phase}' Valid options are {valid_phases}"
    return cast(Phase, raw_phase)
18
+
19
+
20
class StateDict(TypedDict, total=False):
    """Partial set of ``State`` fields, as accepted by ``State.replace``.

    ``total=False`` makes every key optional, so callers only list the fields
    they want to override.
    """

    num_steps: int
    num_samples: int
    num_valid_steps: int
    num_valid_samples: int
    start_time_s: float
    elapsed_time_s: float
    raw_phase: str
28
+
29
+
30
@dataclass(frozen=True)
class State:
    """Immutable record of training progress.

    Instances are frozen; use ``replace`` or ``with_phase`` to derive an
    updated copy.
    """

    num_steps: int = field(MISSING, help="Number of steps so far")
    num_samples: int = field(MISSING, help="Number of sample so far")
    num_valid_steps: int = field(MISSING, help="Number of validation steps so far")
    num_valid_samples: int = field(MISSING, help="Number of validation samples so far")
    start_time_s: float = field(MISSING, help="Start time of training")
    elapsed_time_s: float = field(MISSING, help="Total elapsed time so far")
    raw_phase: str = field(MISSING, help="Current training phase")

    @property
    def phase(self) -> Phase:
        """The current phase, validated through ``cast_phase``."""
        return cast_phase(self.raw_phase)

    @classmethod
    def init_state(cls) -> "State":
        """Creates a fresh state: zeroed counters, the training phase, and the
        start time set to the current wall-clock time.
        """
        initial: StateDict = {
            "num_steps": 0,
            "num_samples": 0,
            "num_valid_steps": 0,
            "num_valid_samples": 0,
            "start_time_s": time.time(),
            "elapsed_time_s": 0.0,
            "raw_phase": "train",
        }
        return cls(**initial)

    @property
    def training(self) -> bool:
        """Whether the current phase is the training phase."""
        current_phase = self.phase
        return current_phase == "train"

    def num_phase_steps(self, phase: Phase) -> int:
        """Returns the step counter which belongs to the given phase.

        Args:
            phase: The phase to look up.

        Returns:
            ``num_steps`` for ``"train"``, ``num_valid_steps`` for ``"valid"``.

        Raises:
            ValueError: If ``phase`` is neither of the two.
        """
        if phase == "train":
            return self.num_steps
        if phase == "valid":
            return self.num_valid_steps
        raise ValueError(f"Invalid phase: {phase}")

    def replace(self, values: StateDict) -> "State":
        """Returns a copy of the state with some fields overridden.

        Args:
            values: The fields to override; fields which are not listed keep
                their current value.

        Returns:
            A new ``State`` instance.
        """
        overrides = values
        return State(
            num_steps=overrides.get("num_steps", self.num_steps),
            num_samples=overrides.get("num_samples", self.num_samples),
            num_valid_steps=overrides.get("num_valid_steps", self.num_valid_steps),
            num_valid_samples=overrides.get("num_valid_samples", self.num_valid_samples),
            start_time_s=overrides.get("start_time_s", self.start_time_s),
            elapsed_time_s=overrides.get("elapsed_time_s", self.elapsed_time_s),
            raw_phase=overrides.get("raw_phase", self.raw_phase),
        )

    def with_phase(self, phase: Phase) -> "State":
        """Returns a copy of the state switched to the given phase."""
        return self.replace({"raw_phase": phase})
xax/nn/__init__.py ADDED
File without changes