xax 0.0.1__tar.gz → 0.0.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65) hide show
  1. {xax-0.0.1/xax.egg-info → xax-0.0.3}/PKG-INFO +19 -1
  2. xax-0.0.3/README.md +9 -0
  3. {xax-0.0.1 → xax-0.0.3}/pyproject.toml +14 -7
  4. {xax-0.0.1 → xax-0.0.3}/setup.cfg +2 -0
  5. {xax-0.0.1 → xax-0.0.3}/setup.py +6 -0
  6. xax-0.0.3/tests/test_dummy.py +5 -0
  7. xax-0.0.3/xax/__init__.py +214 -0
  8. xax-0.0.3/xax/core/conf.py +192 -0
  9. xax-0.0.3/xax/core/state.py +81 -0
  10. xax-0.0.3/xax/nn/functions.py +73 -0
  11. xax-0.0.3/xax/nn/parallel.py +211 -0
  12. xax-0.0.3/xax/py.typed +0 -0
  13. xax-0.0.3/xax/requirements.txt +18 -0
  14. xax-0.0.3/xax/task/__init__.py +0 -0
  15. xax-0.0.3/xax/task/base.py +213 -0
  16. xax-0.0.3/xax/task/launchers/__init__.py +0 -0
  17. xax-0.0.3/xax/task/launchers/base.py +28 -0
  18. xax-0.0.3/xax/task/launchers/cli.py +42 -0
  19. xax-0.0.3/xax/task/launchers/single_process.py +30 -0
  20. xax-0.0.3/xax/task/launchers/staged.py +29 -0
  21. xax-0.0.3/xax/task/logger.py +848 -0
  22. xax-0.0.3/xax/task/loggers/__init__.py +0 -0
  23. xax-0.0.3/xax/task/loggers/json.py +121 -0
  24. xax-0.0.3/xax/task/loggers/state.py +45 -0
  25. xax-0.0.3/xax/task/loggers/stdout.py +170 -0
  26. xax-0.0.3/xax/task/loggers/tensorboard.py +226 -0
  27. xax-0.0.3/xax/task/mixins/__init__.py +11 -0
  28. xax-0.0.3/xax/task/mixins/artifacts.py +107 -0
  29. xax-0.0.3/xax/task/mixins/cpu_stats.py +251 -0
  30. xax-0.0.3/xax/task/mixins/data_loader.py +152 -0
  31. xax-0.0.3/xax/task/mixins/gpu_stats.py +257 -0
  32. xax-0.0.3/xax/task/mixins/logger.py +314 -0
  33. xax-0.0.3/xax/task/mixins/process.py +47 -0
  34. xax-0.0.3/xax/task/mixins/runnable.py +63 -0
  35. xax-0.0.3/xax/task/mixins/step_wrapper.py +63 -0
  36. xax-0.0.3/xax/task/mixins/train.py +510 -0
  37. xax-0.0.3/xax/task/script.py +53 -0
  38. xax-0.0.3/xax/task/task.py +64 -0
  39. xax-0.0.3/xax/utils/__init__.py +0 -0
  40. xax-0.0.3/xax/utils/data/__init__.py +0 -0
  41. xax-0.0.3/xax/utils/data/collate.py +206 -0
  42. xax-0.0.3/xax/utils/experiments.py +758 -0
  43. xax-0.0.3/xax/utils/jax.py +14 -0
  44. xax-0.0.3/xax/utils/logging.py +194 -0
  45. xax-0.0.3/xax/utils/numpy.py +47 -0
  46. xax-0.0.3/xax/utils/tensorboard.py +238 -0
  47. xax-0.0.3/xax/utils/text.py +350 -0
  48. {xax-0.0.1 → xax-0.0.3/xax.egg-info}/PKG-INFO +19 -1
  49. xax-0.0.3/xax.egg-info/SOURCES.txt +56 -0
  50. xax-0.0.3/xax.egg-info/requires.txt +21 -0
  51. xax-0.0.3/xax.egg-info/top_level.txt +1 -0
  52. xax-0.0.1/README.md +0 -3
  53. xax-0.0.1/examples/mnist.py +0 -148
  54. xax-0.0.1/tests/test_dummy.py +0 -22
  55. xax-0.0.1/xax/__init__.py +0 -1
  56. xax-0.0.1/xax/requirements.txt +0 -4
  57. xax-0.0.1/xax.egg-info/SOURCES.txt +0 -18
  58. xax-0.0.1/xax.egg-info/requires.txt +0 -9
  59. xax-0.0.1/xax.egg-info/top_level.txt +0 -2
  60. {xax-0.0.1 → xax-0.0.3}/LICENSE +0 -0
  61. {xax-0.0.1 → xax-0.0.3}/MANIFEST.in +0 -0
  62. {xax-0.0.1/examples → xax-0.0.3/xax/core}/__init__.py +0 -0
  63. /xax-0.0.1/xax/py.typed → /xax-0.0.3/xax/nn/__init__.py +0 -0
  64. {xax-0.0.1 → xax-0.0.3}/xax/requirements-dev.txt +0 -0
  65. {xax-0.0.1 → xax-0.0.3}/xax.egg-info/dependency_links.txt +0 -0
@@ -1,14 +1,26 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: xax
3
- Version: 0.0.1
3
+ Version: 0.0.3
4
4
  Summary: The xax project
5
5
  Home-page: https://github.com/dpshai/xax
6
6
  Author: Benjamin Bolte
7
7
  Requires-Python: >=3.11
8
8
  Description-Content-Type: text/markdown
9
9
  License-File: LICENSE
10
+ Requires-Dist: dpshdl
11
+ Requires-Dist: equinox
12
+ Requires-Dist: gitpython
10
13
  Requires-Dist: jax
11
14
  Requires-Dist: jaxtyping
15
+ Requires-Dist: omegaconf
16
+ Requires-Dist: optax
17
+ Requires-Dist: pillow
18
+ Requires-Dist: psutil
19
+ Requires-Dist: requests
20
+ Requires-Dist: tensorboard
21
+ Requires-Dist: types-pillow
22
+ Requires-Dist: types-psutil
23
+ Requires-Dist: types-requests
12
24
  Provides-Extra: dev
13
25
  Requires-Dist: black; extra == "dev"
14
26
  Requires-Dist: darglint; extra == "dev"
@@ -19,3 +31,9 @@ Requires-Dist: ruff; extra == "dev"
19
31
  # xax
20
32
 
21
33
  JAX library for fast experimentation.
34
+
35
+ ## Installation
36
+
37
+ ```bash
38
+ pip install xax
39
+ ```
xax-0.0.3/README.md ADDED
@@ -0,0 +1,9 @@
1
+ # xax
2
+
3
+ JAX library for fast experimentation.
4
+
5
+ ## Installation
6
+
7
+ ```bash
8
+ pip install xax
9
+ ```
@@ -30,11 +30,14 @@ warn_unused_ignores = true
30
30
  warn_redundant_casts = true
31
31
 
32
32
  incremental = true
33
- namespace_packages = false
33
+ explicit_package_bases = true
34
34
 
35
35
  [[tool.mypy.overrides]]
36
36
 
37
37
  module = [
38
+ "optax.*",
39
+ "setuptools.*",
40
+ "tensorboard.*",
38
41
  "transformers.*",
39
42
  ]
40
43
 
@@ -46,33 +49,37 @@ profile = "black"
46
49
 
47
50
  [tool.ruff]
48
51
 
52
+ line-length = 120
53
+ target-version = "py311"
54
+
55
+ [tool.ruff.lint]
56
+
49
57
  select = ["ANN", "D", "E", "F", "I", "N", "PGH", "PLC", "PLE", "PLR", "PLW", "W"]
50
58
 
51
59
  ignore = [
52
60
  "ANN101", "ANN102", "ANN401",
53
61
  "D101", "D102", "D103", "D104", "D105", "D106", "D107",
62
+ "F722",
54
63
  "N812", "N817",
55
64
  "PLR0911", "PLR0912", "PLR0913", "PLR0915", "PLR2004",
56
65
  "PLW0603", "PLW2901",
57
66
  ]
58
67
 
59
- line-length = 120
60
68
  dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
61
- target-version = "py311"
62
69
 
63
- [tool.ruff.per-file-ignores]
70
+ [tool.ruff.lint.per-file-ignores]
64
71
 
65
72
  "__init__.py" = ["E402", "F401", "F403", "F811"]
66
73
 
67
- [tool.ruff.mccabe]
74
+ [tool.ruff.lint.mccabe]
68
75
 
69
76
  max-complexity = 10
70
77
 
71
- [tool.ruff.isort]
78
+ [tool.ruff.lint.isort]
72
79
 
73
80
  known-first-party = ["xax", "tests"]
74
81
  combine-as-imports = true
75
82
 
76
- [tool.ruff.pydocstyle]
83
+ [tool.ruff.lint.pydocstyle]
77
84
 
78
85
  convention = "google"
@@ -3,6 +3,8 @@ packages = find:
3
3
 
4
4
  [options.packages.find]
5
5
  exclude =
6
+ .vscode
7
+ .github
6
8
  tests
7
9
 
8
10
  [egg_info]
@@ -35,4 +35,10 @@ setup(
35
35
  install_requires=requirements,
36
36
  tests_require=requirements_dev,
37
37
  extras_require={"dev": requirements_dev},
38
+ package_data={
39
+ "xax": [
40
+ "py.typed",
41
+ "requirements*.txt",
42
+ ],
43
+ },
38
44
  )
@@ -0,0 +1,5 @@
1
+ """Runs a dummy placeholder test."""
2
+
3
+
4
def test_dummy() -> None:
    """Placeholder test that always passes."""
    assert True
@@ -0,0 +1,214 @@
1
+ """Defines the top-level xax API.
2
+
3
+ This package is structured so that all the important stuff can be accessed
4
+ without having to dig around through the internals. This is done by lazily
5
+ importing the module by name.
6
+
7
+ This file can be maintained by running the update script:
8
+
9
+ .. code-block:: bash
10
+
11
+ python -m scripts.update_api --inplace
12
+ """
13
+
14
+ __version__ = "0.0.3"
15
+
16
+ # This list shouldn't be modified by hand; instead, run the update script.
17
+ __all__ = [
18
+ "UserConfig",
19
+ "field",
20
+ "get_data_dir",
21
+ "get_pretrained_models_dir",
22
+ "get_run_dir",
23
+ "load_user_config",
24
+ "State",
25
+ "cast_phase",
26
+ "BaseLauncher",
27
+ "CliLauncher",
28
+ "SingleProcessLauncher",
29
+ "LogAudio",
30
+ "LogImage",
31
+ "LogLine",
32
+ "LogVideo",
33
+ "Logger",
34
+ "LoggerImpl",
35
+ "JsonLogger",
36
+ "StateLogger",
37
+ "StdoutLogger",
38
+ "TensorboardLogger",
39
+ "CPUStatsOptions",
40
+ "DataloaderConfig",
41
+ "GPUStatsOptions",
42
+ "Script",
43
+ "ScriptConfig",
44
+ "Config",
45
+ "Task",
46
+ "collate",
47
+ "collate_non_null",
48
+ "BaseFileDownloader",
49
+ "DataDownloader",
50
+ "ModelDownloader",
51
+ "check_md5",
52
+ "check_sha256",
53
+ "get_git_state",
54
+ "get_state_dict_prefix",
55
+ "get_training_code",
56
+ "save_config",
57
+ "ColoredFormatter",
58
+ "configure_logging",
59
+ "one_hot",
60
+ "partial_flatten",
61
+ "worker_chunk",
62
+ "TextBlock",
63
+ "colored",
64
+ "format_datetime",
65
+ "format_timedelta",
66
+ "outlined",
67
+ "render_text_blocks",
68
+ "show_error",
69
+ "show_warning",
70
+ "uncolored",
71
+ "wrapped",
72
+ ]
73
+
74
+ __all__ += [
75
+ "CollateMode",
76
+ "Phase",
77
+ ]
78
+
79
+ import os
80
+ from typing import TYPE_CHECKING
81
+
82
+ # If this flag is set, eagerly imports the entire package (not recommended).
83
+ IMPORT_ALL = int(os.environ.get("XAX_IMPORT_ALL", "0")) != 0
84
+
85
+ del os
86
+
87
+ # This dictionary is auto-generated and shouldn't be modified by hand; instead,
88
+ # run the update script.
89
# Maps each lazily-exported name to the submodule (relative to ``xax``) that
# defines it. Every key must match the spelling used in ``__all__`` and in the
# eager import block, otherwise the lazy lookup raises ``AttributeError``.
NAME_MAP: dict[str, str] = {
    "UserConfig": "core.conf",
    "field": "core.conf",
    "get_data_dir": "core.conf",
    "get_pretrained_models_dir": "core.conf",
    "get_run_dir": "core.conf",
    "load_user_config": "core.conf",
    "State": "core.state",
    "cast_phase": "core.state",
    "BaseLauncher": "task.launchers.base",
    "CliLauncher": "task.launchers.cli",
    "SingleProcessLauncher": "task.launchers.single_process",
    "LogAudio": "task.logger",
    "LogImage": "task.logger",
    "LogLine": "task.logger",
    "LogVideo": "task.logger",
    "Logger": "task.logger",
    "LoggerImpl": "task.logger",
    "JsonLogger": "task.loggers.json",
    "StateLogger": "task.loggers.state",
    "StdoutLogger": "task.loggers.stdout",
    "TensorboardLogger": "task.loggers.tensorboard",
    "CPUStatsOptions": "task.mixins.cpu_stats",
    "DataloaderConfig": "task.mixins.data_loader",
    "GPUStatsOptions": "task.mixins.gpu_stats",
    "Script": "task.script",
    "ScriptConfig": "task.script",
    "Config": "task.task",
    "Task": "task.task",
    "collate": "utils.data.collate",
    "collate_non_null": "utils.data.collate",
    "BaseFileDownloader": "utils.experiments",
    "DataDownloader": "utils.experiments",
    "ModelDownloader": "utils.experiments",
    "check_md5": "utils.experiments",
    "check_sha256": "utils.experiments",
    "get_git_state": "utils.experiments",
    "get_state_dict_prefix": "utils.experiments",
    "get_training_code": "utils.experiments",
    "save_config": "utils.experiments",
    "ColoredFormatter": "utils.logging",
    "configure_logging": "utils.logging",
    "one_hot": "utils.numpy",
    "partial_flatten": "utils.numpy",
    "worker_chunk": "utils.numpy",
    "TextBlock": "utils.text",
    "colored": "utils.text",
    "format_datetime": "utils.text",
    "format_timedelta": "utils.text",
    "outlined": "utils.text",
    "render_text_blocks": "utils.text",
    "show_error": "utils.text",
    "show_warning": "utils.text",
    "uncolored": "utils.text",
    "wrapped": "utils.text",
}
145
+
146
+ # Need to manually set some values which can't be auto-generated.
147
+ NAME_MAP.update(
148
+ {
149
+ "CollateMode": "utils.data.collate",
150
+ "Phase": "core.state",
151
+ },
152
+ )
153
+
154
+
155
def __getattr__(name: str) -> object:
    """Lazily resolves a top-level attribute by importing the module that defines it.

    Args:
        name: The attribute being looked up on the package.

    Returns:
        The attribute with that name from the submodule listed in ``NAME_MAP``.

    Raises:
        AttributeError: If the name is not listed in ``NAME_MAP``.
    """
    # Imported locally so that the package namespace stays free of helper modules.
    import importlib

    if name not in NAME_MAP:
        raise AttributeError(f"{__name__} has no attribute {name}")

    module = importlib.import_module(f"xax.{NAME_MAP[name]}")
    return getattr(module, name)
162
+
163
+
164
+ if IMPORT_ALL or TYPE_CHECKING:
165
+ from xax.core.conf import (
166
+ UserConfig,
167
+ field,
168
+ get_data_dir,
169
+ get_pretrained_models_dir,
170
+ get_run_dir,
171
+ load_user_config,
172
+ )
173
+ from xax.core.state import Phase, State, cast_phase
174
+ from xax.task.launchers.base import BaseLauncher
175
+ from xax.task.launchers.cli import CliLauncher
176
+ from xax.task.launchers.single_process import SingleProcessLauncher
177
+ from xax.task.logger import LogAudio, Logger, LoggerImpl, LogImage, LogLine, LogVideo
178
+ from xax.task.loggers.json import JsonLogger
179
+ from xax.task.loggers.state import StateLogger
180
+ from xax.task.loggers.stdout import StdoutLogger
181
+ from xax.task.loggers.tensorboard import TensorboardLogger
182
+ from xax.task.mixins.cpu_stats import CPUStatsOptions
183
+ from xax.task.mixins.data_loader import DataloaderConfig
184
+ from xax.task.mixins.gpu_stats import GPUStatsOptions
185
+ from xax.task.script import Script, ScriptConfig
186
+ from xax.task.task import Config, Task
187
+ from xax.utils.data.collate import CollateMode, collate, collate_non_null
188
+ from xax.utils.experiments import (
189
+ BaseFileDownloader,
190
+ DataDownloader,
191
+ ModelDownloader,
192
+ check_md5,
193
+ check_sha256,
194
+ get_git_state,
195
+ get_state_dict_prefix,
196
+ get_training_code,
197
+ save_config,
198
+ )
199
+ from xax.utils.logging import ColoredFormatter, configure_logging
200
+ from xax.utils.numpy import one_hot, partial_flatten, worker_chunk
201
+ from xax.utils.text import (
202
+ TextBlock,
203
+ colored,
204
+ format_datetime,
205
+ format_timedelta,
206
+ outlined,
207
+ render_text_blocks,
208
+ show_error,
209
+ show_warning,
210
+ uncolored,
211
+ wrapped,
212
+ )
213
+
214
+ del TYPE_CHECKING, IMPORT_ALL
@@ -0,0 +1,192 @@
1
+ """Defines base configuration functions and utilities."""
2
+
3
+ import functools
4
+ import os
5
+ from dataclasses import dataclass, field as field_base
6
+ from pathlib import Path
7
+ from typing import Any, cast
8
+
9
+ import jax.numpy as jnp
10
+ from omegaconf import II, MISSING, Container as OmegaConfContainer, OmegaConf
11
+
12
+ from xax.utils.text import show_error
13
+
14
FieldType = Any


def field(value: FieldType, **kwargs: str) -> FieldType:
    """Short-hand function for getting a config field.

    Callable values are used directly as the default factory. Unhashable
    values (lists, dicts, ...) are wrapped in a factory which returns a fresh
    deep copy each time, so that instances never share one mutable default.
    Every other value is used as a plain default.

    Args:
        value: The current field's default value.
        kwargs: Additional metadata fields to supply.

    Returns:
        The dataclass field.
    """
    import copy

    metadata: dict[str, Any] = dict(kwargs)

    if callable(value):
        return field_base(default_factory=value, metadata=metadata)
    if value.__class__.__hash__ is None:
        # Returning ``value`` itself would hand the same object to every instance.
        return field_base(default_factory=lambda: copy.deepcopy(value), metadata=metadata)
    return field_base(default=value, metadata=metadata)
35
+
36
+
37
def is_missing(cfg: Any, key: str) -> bool:  # noqa: ANN401
    """Checks whether a config key has no usable value.

    Works for both OmegaConf containers and raw dataclasses, so that callers
    can treat the two the same way.

    Args:
        cfg: The config to check
        key: The key to check

    Returns:
        Whether or not the key is missing a value in the config
    """
    if isinstance(cfg, OmegaConfContainer):
        if OmegaConf.is_missing(cfg, key):
            return True
        if OmegaConf.is_interpolation(cfg, key):
            # An interpolation is treated as missing when reading it raises.
            try:
                getattr(cfg, key)
            except Exception:
                return True
            else:
                return False
    return getattr(cfg, key) is MISSING
62
+
63
+
64
@dataclass
class Logging:
    """Logging section of the user configuration."""

    hide_third_party_logs: bool = field(True, help="If set, hide third-party logs")
    log_level: str = field("INFO", help="The logging level to use")
68
+
69
+
70
@dataclass
class Device:
    """Device section of the user configuration.

    The ``gpu`` and ``metal`` defaults are interpolations of the ``USE_GPU``
    and ``USE_METAL`` environment variables, each falling back to ``1``.
    """

    cpu: bool = field(True, help="Whether to use the CPU")
    gpu: bool = field(II("oc.env:USE_GPU,1"), help="Whether to use the GPU")
    metal: bool = field(II("oc.env:USE_METAL,1"), help="Whether to use the Apple Silicon accelerator")
    use_fp64: bool = field(False, help="Always use the 64-bit floating point type")
    use_fp32: bool = field(False, help="Always use the 32-bit floating point type")
    use_bf16: bool = field(False, help="Always use the 16-bit bfloat type")
    use_fp16: bool = field(False, help="Always use the 16-bit floating point type")
79
+
80
+
81
def parse_dtype(cfg: Device) -> jnp.dtype | None:
    """Returns the floating point type forced by the device config, if any.

    The flags are checked in order from widest to narrowest type, and the
    first one which is set wins.

    Args:
        cfg: The device configuration to read the ``use_*`` flags from.

    Returns:
        The selected dtype, or None when no flag is set.
    """
    flag_to_dtype = (
        ("use_fp64", jnp.float64),
        ("use_fp32", jnp.float32),
        ("use_bf16", jnp.bfloat16),
        ("use_fp16", jnp.float16),
    )
    for flag_name, dtype in flag_to_dtype:
        # Flags are read one at a time so later ones are untouched once a match is found.
        if getattr(cfg, flag_name):
            return dtype
    return None
91
+
92
+
93
@dataclass
class Triton:
    """Triton section of the user configuration."""

    use_triton_if_available: bool = field(True, help="Use Triton if available")
96
+
97
+
98
@dataclass
class Experiment:
    """Experiment section of the user configuration."""

    default_random_seed: int = field(1337, help="The default random seed to use")
101
+
102
+
103
@dataclass
class Directories:
    """Directories section of the user configuration.

    Each default is an interpolation of an environment variable
    (``RUN_DIR``, ``DATA_DIR`` and ``MODEL_DIR`` respectively).
    """

    run: str = field(II("oc.env:RUN_DIR"), help="The run directory")
    data: str = field(II("oc.env:DATA_DIR"), help="The data directory")
    pretrained_models: str = field(II("oc.env:MODEL_DIR"), help="The models directory")
108
+
109
+
110
@dataclass
class SlurmPartition:
    """A single Slurm launch configuration; the partition name is mandatory."""

    partition: str = field(MISSING, help="The partition name")
    num_nodes: int = field(1, help="The number of nodes to use")
114
+
115
+
116
@dataclass
class Slurm:
    """Slurm section of the user configuration, keyed by launch configuration name."""

    launch: dict[str, SlurmPartition] = field({}, help="The available launch configurations")
119
+
120
+
121
@dataclass
class UserConfig:
    """Top-level user configuration, made up of one section per concern.

    Each section defaults to a freshly constructed instance of its class.
    """

    logging: Logging = field(Logging)
    device: Device = field(Device)
    triton: Triton = field(Triton)
    experiment: Experiment = field(Experiment)
    directories: Directories = field(Directories)
    slurm: Slurm = field(Slurm)
129
+
130
+
131
def user_config_path() -> Path:
    """Returns the path of the user config file.

    The path is read from the ``XAXRC_PATH`` environment variable, falling
    back to ``~/.xax.yml``, with a leading ``~`` expanded.
    """
    return Path(os.environ.get("XAXRC_PATH", "~/.xax.yml")).expanduser()
135
+
136
+
137
@functools.lru_cache(maxsize=None)
def _load_user_config_cached() -> UserConfig:
    """Builds the user configuration once and caches it for later calls.

    The structured defaults are merged with the user config file when it
    exists; otherwise the defaults are written to that path. A ``xax.yml``
    file in the current directory, if present, is merged on top.

    Returns:
        The merged configuration.
    """
    xaxrc_path = user_config_path()
    base_cfg = OmegaConf.structured(UserConfig)

    # Loads the user config file, or writes the defaults if there isn't one.
    if xaxrc_path.exists():
        cfg = OmegaConf.merge(base_cfg, OmegaConf.load(xaxrc_path))
    else:
        show_error(f"No config file was found in {xaxrc_path}; writing one...", important=True)
        # ``XAXRC_PATH`` may point into a directory which doesn't exist yet.
        xaxrc_path.parent.mkdir(parents=True, exist_ok=True)
        OmegaConf.save(base_cfg, xaxrc_path)
        cfg = base_cfg

    # Looks in the current directory for a config file.
    local_cfg_path = Path("xax.yml")
    if local_cfg_path.exists():
        cfg = OmegaConf.merge(cfg, OmegaConf.load(local_cfg_path))

    return cast(UserConfig, cfg)
156
+
157
+
158
def load_user_config() -> UserConfig:
    """Loads the user configuration.

    The configuration file is ``~/.xax.yml`` unless the ``XAXRC_PATH``
    environment variable points elsewhere. The result is cached, so the
    files are only read on the first call.

    Returns:
        The loaded configuration.
    """
    return _load_user_config_cached()
165
+
166
+
167
def get_run_dir() -> Path | None:
    """Returns the configured run directory, creating it if needed.

    Returns:
        The run directory, or None when it is not set in the configuration.
    """
    directories = load_user_config().directories
    if is_missing(directories, "run"):
        return None
    run_dir = Path(directories.run)
    run_dir.mkdir(parents=True, exist_ok=True)
    return run_dir
173
+
174
+
175
def get_data_dir() -> Path:
    """Returns the configured data directory.

    Returns:
        The data directory.

    Raises:
        RuntimeError: If the data directory is not set in the configuration.
    """
    directories = load_user_config().directories
    if not is_missing(directories, "data"):
        return Path(directories.data)
    raise RuntimeError(
        "The data directory has not been set! You should set it in your config file "
        f"in {user_config_path()} or set the DATA_DIR environment variable."
    )
183
+
184
+
185
def get_pretrained_models_dir() -> Path:
    """Returns the configured pretrained models directory.

    Returns:
        The pretrained models directory.

    Raises:
        RuntimeError: If the directory is not set in the configuration.
    """
    config = load_user_config().directories
    if is_missing(config, "pretrained_models"):
        raise RuntimeError(
            "The pretrained models directory has not been set! You should set it in your config file "
            f"in {user_config_path()} or set the MODEL_DIR environment variable."
        )
    return Path(config.pretrained_models)
@@ -0,0 +1,81 @@
1
+ """Defines a dataclass for keeping track of the current training state."""
2
+
3
+ import time
4
+ from dataclasses import dataclass
5
+ from typing import Literal, TypedDict, cast, get_args
6
+
7
+ from omegaconf import MISSING
8
+
9
+ from xax.core.conf import field
10
+
11
Phase = Literal["train", "valid"]


def cast_phase(raw_phase: str) -> Phase:
    """Checks that a raw string names a valid phase and narrows its type.

    Args:
        raw_phase: The phase name to validate.

    Returns:
        The same string, typed as a ``Phase``.
    """
    valid_phases = get_args(Phase)
    assert raw_phase in valid_phases, f"Invalid phase: '{raw_phase}' Valid options are {valid_phases}"
    return cast(Phase, raw_phase)
18
+
19
+
20
class StateDict(TypedDict, total=False):
    """Partial set of ``State`` field values; every key is optional."""

    num_steps: int
    num_samples: int
    num_valid_steps: int
    num_valid_samples: int
    start_time_s: float
    elapsed_time_s: float
    raw_phase: str
28
+
29
+
30
+ @dataclass(frozen=True)
31
+ class State:
32
+ num_steps: int = field(MISSING, help="Number of steps so far")
33
+ num_samples: int = field(MISSING, help="Number of sample so far")
34
+ num_valid_steps: int = field(MISSING, help="Number of validation steps so far")
35
+ num_valid_samples: int = field(MISSING, help="Number of validation samples so far")
36
+ start_time_s: float = field(MISSING, help="Start time of training")
37
+ elapsed_time_s: float = field(MISSING, help="Total elapsed time so far")
38
+ raw_phase: str = field(MISSING, help="Current training phase")
39
+
40
+ @property
41
+ def phase(self) -> Phase:
42
+ return cast_phase(self.raw_phase)
43
+
44
+ @classmethod
45
+ def init_state(cls) -> "State":
46
+ return cls(
47
+ num_steps=0,
48
+ num_samples=0,
49
+ num_valid_steps=0,
50
+ num_valid_samples=0,
51
+ start_time_s=time.time(),
52
+ elapsed_time_s=0.0,
53
+ raw_phase="train",
54
+ )
55
+
56
+ @property
57
+ def training(self) -> bool:
58
+ return self.phase == "train"
59
+
60
+ def num_phase_steps(self, phase: Phase) -> int:
61
+ match phase:
62
+ case "train":
63
+ return self.num_steps
64
+ case "valid":
65
+ return self.num_valid_steps
66
+ case _:
67
+ raise ValueError(f"Invalid phase: {phase}")
68
+
69
+ def replace(self, values: StateDict) -> "State":
70
+ return State(
71
+ num_steps=values.get("num_steps", self.num_steps),
72
+ num_samples=values.get("num_samples", self.num_samples),
73
+ num_valid_steps=values.get("num_valid_steps", self.num_valid_steps),
74
+ num_valid_samples=values.get("num_valid_samples", self.num_valid_samples),
75
+ start_time_s=values.get("start_time_s", self.start_time_s),
76
+ elapsed_time_s=values.get("elapsed_time_s", self.elapsed_time_s),
77
+ raw_phase=values.get("raw_phase", self.raw_phase),
78
+ )
79
+
80
+ def with_phase(self, phase: Phase) -> "State":
81
+ return self.replace({"raw_phase": phase})
@@ -0,0 +1,73 @@
1
+ # mypy: disable-error-code="override"
2
+ """Defines helper functions for nested containers of arrays and for random seeding."""
3
+
4
+ import random
5
+ from dataclasses import is_dataclass
6
+ from typing import Any, Callable, Iterable, Mapping, ParamSpec, Sequence, TypeVar
7
+
8
+ import numpy as np
9
+ from jaxtyping import Array
10
+
11
+ from xax.core.conf import load_user_config
12
+
13
+ T = TypeVar("T")
14
+ P = ParamSpec("P")
15
+
16
+
17
def recursive_apply(item: Any, func: Callable[[Array], Array]) -> Any:  # noqa: ANN401
    """Applies a function recursively to the arrays in an item.

    Dataclasses, mappings and sequences are traversed and rebuilt (mappings
    as ``dict``, sequences as ``list``). Strings, bytes, numbers and any
    other kind of object are returned unchanged.

    Args:
        item: The item to apply the function to
        func: The function to apply (for each array)

    Returns:
        The same item, with the function applied
    """
    # Strings and bytes are sequences too, so they have to be passed through
    # before the generic sequence branch would split them into elements.
    if isinstance(item, (str, bytes, int, float)):
        return item
    if isinstance(item, Array):
        return func(item)
    if is_dataclass(item):
        return item.__class__(**{k: recursive_apply(v, func) for k, v in item.__dict__.items()})
    if isinstance(item, Mapping):
        return {k: recursive_apply(v, func) for k, v in item.items()}
    if isinstance(item, Sequence):
        return [recursive_apply(i, func) for i in item]
    return item
38
+
39
+
40
def _chunk_children(children: Iterable[Any], num_chunks: int, dim: int) -> Iterable[tuple[Any, ...]]:
    """Chunks every child and yields, for each chunk index, the tuple of child chunks."""
    child_list = list(children)
    if not child_list:
        # Zipping over no children yields nothing at all, which would silently
        # drop every chunk of the enclosing container as well.
        yield from (() for _ in range(num_chunks))
        return
    yield from zip(*(recursive_chunk(child, num_chunks, dim) for child in child_list))


def recursive_chunk(item: Any, num_chunks: int, dim: int = 0) -> Iterable[Any]:  # noqa: ANN401
    """Recursively chunk arrays N times.

    NumPy arrays are split along ``dim``. Dataclasses, mappings and sequences
    are traversed and rebuilt for every chunk (mappings as ``dict``, sequences
    as ``list``); empty containers are repeated as empty containers. Strings,
    bytes, numbers and any other kind of object are repeated unchanged.

    Args:
        item: The item to recursively chunk
        num_chunks: The number of splits to make
        dim: The split dimension

    Yields:
        ``num_chunks`` chunks of items
    """
    # Strings and bytes are sequences too, so they have to be handled before
    # the generic sequence branch would split them into elements.
    if isinstance(item, (str, bytes, int, float)):
        yield from (item for _ in range(num_chunks))
    elif isinstance(item, np.ndarray):
        yield from np.array_split(item, num_chunks, axis=dim)
    elif is_dataclass(item):
        for parts in _chunk_children(item.__dict__.values(), num_chunks, dim):
            yield item.__class__(**dict(zip(item.__dict__, parts)))
    elif isinstance(item, Mapping):
        for parts in _chunk_children(item.values(), num_chunks, dim):
            yield dict(zip(item, parts))
    elif isinstance(item, Sequence):
        for parts in _chunk_children(item, num_chunks, dim):
            yield list(parts)
    else:
        yield from (item for _ in range(num_chunks))
66
+
67
+
68
+ def set_random_seed(seed: int | None = None, offset: int = 0) -> None:
69
+ if seed is None:
70
+ seed = load_user_config().experiment.default_random_seed
71
+ seed += offset
72
+ random.seed(seed)
73
+ np.random.seed(seed)