xax 0.0.3__tar.gz → 0.0.6__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66) hide show
  1. {xax-0.0.3/xax.egg-info → xax-0.0.6}/PKG-INFO +22 -11
  2. {xax-0.0.3 → xax-0.0.6}/pyproject.toml +16 -3
  3. {xax-0.0.3 → xax-0.0.6}/setup.py +3 -4
  4. {xax-0.0.3 → xax-0.0.6}/xax/__init__.py +122 -8
  5. {xax-0.0.3 → xax-0.0.6}/xax/core/conf.py +9 -33
  6. {xax-0.0.3 → xax-0.0.6}/xax/core/state.py +13 -23
  7. xax-0.0.6/xax/nn/embeddings.py +355 -0
  8. {xax-0.0.3 → xax-0.0.6}/xax/nn/functions.py +8 -4
  9. xax-0.0.6/xax/requirements-dev.txt +15 -0
  10. xax-0.0.6/xax/requirements.txt +25 -0
  11. {xax-0.0.3 → xax-0.0.6}/xax/task/base.py +2 -6
  12. xax-0.0.6/xax/task/logger.py +855 -0
  13. xax-0.0.6/xax/task/loggers/callback.py +44 -0
  14. xax-0.0.6/xax/task/loggers/state.py +32 -0
  15. {xax-0.0.3 → xax-0.0.6}/xax/task/loggers/tensorboard.py +16 -33
  16. {xax-0.0.3 → xax-0.0.6}/xax/task/mixins/__init__.py +3 -1
  17. {xax-0.0.3 → xax-0.0.6}/xax/task/mixins/artifacts.py +19 -9
  18. xax-0.0.6/xax/task/mixins/checkpointing.py +221 -0
  19. xax-0.0.6/xax/task/mixins/compile.py +104 -0
  20. {xax-0.0.3 → xax-0.0.6}/xax/task/mixins/cpu_stats.py +26 -15
  21. {xax-0.0.3 → xax-0.0.6}/xax/task/mixins/data_loader.py +27 -19
  22. {xax-0.0.3 → xax-0.0.6}/xax/task/mixins/gpu_stats.py +22 -8
  23. xax-0.0.6/xax/task/mixins/logger.py +68 -0
  24. {xax-0.0.3 → xax-0.0.6}/xax/task/mixins/process.py +8 -1
  25. {xax-0.0.3 → xax-0.0.6}/xax/task/mixins/runnable.py +3 -0
  26. {xax-0.0.3 → xax-0.0.6}/xax/task/mixins/step_wrapper.py +5 -0
  27. {xax-0.0.3 → xax-0.0.6}/xax/task/mixins/train.py +236 -145
  28. {xax-0.0.3 → xax-0.0.6}/xax/task/script.py +1 -1
  29. {xax-0.0.3 → xax-0.0.6}/xax/task/task.py +13 -5
  30. {xax-0.0.3 → xax-0.0.6}/xax/utils/data/collate.py +6 -6
  31. {xax-0.0.3 → xax-0.0.6}/xax/utils/experiments.py +45 -1
  32. {xax-0.0.3 → xax-0.0.6}/xax/utils/logging.py +29 -0
  33. {xax-0.0.3 → xax-0.0.6}/xax/utils/tensorboard.py +89 -21
  34. {xax-0.0.3 → xax-0.0.6/xax.egg-info}/PKG-INFO +22 -11
  35. {xax-0.0.3 → xax-0.0.6}/xax.egg-info/SOURCES.txt +4 -2
  36. {xax-0.0.3 → xax-0.0.6}/xax.egg-info/requires.txt +12 -9
  37. xax-0.0.3/tests/test_dummy.py +0 -5
  38. xax-0.0.3/xax/requirements-dev.txt +0 -7
  39. xax-0.0.3/xax/requirements.txt +0 -18
  40. xax-0.0.3/xax/task/launchers/staged.py +0 -29
  41. xax-0.0.3/xax/task/logger.py +0 -848
  42. xax-0.0.3/xax/task/loggers/state.py +0 -45
  43. xax-0.0.3/xax/task/mixins/logger.py +0 -314
  44. {xax-0.0.3 → xax-0.0.6}/LICENSE +0 -0
  45. {xax-0.0.3 → xax-0.0.6}/MANIFEST.in +0 -0
  46. {xax-0.0.3 → xax-0.0.6}/README.md +0 -0
  47. {xax-0.0.3 → xax-0.0.6}/setup.cfg +0 -0
  48. {xax-0.0.3 → xax-0.0.6}/xax/core/__init__.py +0 -0
  49. {xax-0.0.3 → xax-0.0.6}/xax/nn/__init__.py +0 -0
  50. {xax-0.0.3 → xax-0.0.6}/xax/nn/parallel.py +0 -0
  51. {xax-0.0.3 → xax-0.0.6}/xax/py.typed +0 -0
  52. {xax-0.0.3 → xax-0.0.6}/xax/task/__init__.py +0 -0
  53. {xax-0.0.3 → xax-0.0.6}/xax/task/launchers/__init__.py +0 -0
  54. {xax-0.0.3 → xax-0.0.6}/xax/task/launchers/base.py +0 -0
  55. {xax-0.0.3 → xax-0.0.6}/xax/task/launchers/cli.py +0 -0
  56. {xax-0.0.3 → xax-0.0.6}/xax/task/launchers/single_process.py +0 -0
  57. {xax-0.0.3 → xax-0.0.6}/xax/task/loggers/__init__.py +0 -0
  58. {xax-0.0.3 → xax-0.0.6}/xax/task/loggers/json.py +0 -0
  59. {xax-0.0.3 → xax-0.0.6}/xax/task/loggers/stdout.py +0 -0
  60. {xax-0.0.3 → xax-0.0.6}/xax/utils/__init__.py +0 -0
  61. {xax-0.0.3 → xax-0.0.6}/xax/utils/data/__init__.py +0 -0
  62. {xax-0.0.3 → xax-0.0.6}/xax/utils/jax.py +0 -0
  63. {xax-0.0.3 → xax-0.0.6}/xax/utils/numpy.py +0 -0
  64. {xax-0.0.3 → xax-0.0.6}/xax/utils/text.py +0 -0
  65. {xax-0.0.3 → xax-0.0.6}/xax.egg-info/dependency_links.txt +0 -0
  66. {xax-0.0.3 → xax-0.0.6}/xax.egg-info/top_level.txt +0 -0
@@ -1,32 +1,43 @@
1
- Metadata-Version: 2.1
1
+ Metadata-Version: 2.2
2
2
  Name: xax
3
- Version: 0.0.3
3
+ Version: 0.0.6
4
4
  Summary: The xax project
5
5
  Home-page: https://github.com/dpshai/xax
6
6
  Author: Benjamin Bolte
7
7
  Requires-Python: >=3.11
8
8
  Description-Content-Type: text/markdown
9
9
  License-File: LICENSE
10
- Requires-Dist: dpshdl
11
- Requires-Dist: equinox
12
- Requires-Dist: gitpython
13
10
  Requires-Dist: jax
14
11
  Requires-Dist: jaxtyping
15
- Requires-Dist: omegaconf
12
+ Requires-Dist: equinox
16
13
  Requires-Dist: optax
14
+ Requires-Dist: dpshdl
15
+ Requires-Dist: chex
16
+ Requires-Dist: importlib-resources
17
+ Requires-Dist: cloudpickle
17
18
  Requires-Dist: pillow
19
+ Requires-Dist: omegaconf
20
+ Requires-Dist: gitpython
21
+ Requires-Dist: tensorboard
18
22
  Requires-Dist: psutil
19
23
  Requires-Dist: requests
20
- Requires-Dist: tensorboard
21
- Requires-Dist: types-pillow
22
- Requires-Dist: types-psutil
23
- Requires-Dist: types-requests
24
24
  Provides-Extra: dev
25
25
  Requires-Dist: black; extra == "dev"
26
26
  Requires-Dist: darglint; extra == "dev"
27
27
  Requires-Dist: mypy; extra == "dev"
28
- Requires-Dist: pytest; extra == "dev"
29
28
  Requires-Dist: ruff; extra == "dev"
29
+ Requires-Dist: pytest; extra == "dev"
30
+ Requires-Dist: types-pillow; extra == "dev"
31
+ Requires-Dist: types-psutil; extra == "dev"
32
+ Requires-Dist: types-requests; extra == "dev"
33
+ Dynamic: author
34
+ Dynamic: description
35
+ Dynamic: description-content-type
36
+ Dynamic: home-page
37
+ Dynamic: provides-extra
38
+ Dynamic: requires-dist
39
+ Dynamic: requires-python
40
+ Dynamic: summary
30
41
 
31
42
  # xax
32
43
 
@@ -35,6 +35,7 @@ explicit_package_bases = true
35
35
  [[tool.mypy.overrides]]
36
36
 
37
37
  module = [
38
+ "cloudpickle.*",
38
39
  "optax.*",
39
40
  "setuptools.*",
40
41
  "tensorboard.*",
@@ -54,12 +55,24 @@ target-version = "py311"
54
55
 
55
56
  [tool.ruff.lint]
56
57
 
57
- select = ["ANN", "D", "E", "F", "I", "N", "PGH", "PLC", "PLE", "PLR", "PLW", "W"]
58
+ select = [
59
+ "ANN",
60
+ "D",
61
+ "E",
62
+ "F",
63
+ "G",
64
+ "I",
65
+ "N",
66
+ "PGH",
67
+ "PLC",
68
+ "PLE",
69
+ "PLR",
70
+ "PLW",
71
+ "W",
72
+ ]
58
73
 
59
74
  ignore = [
60
- "ANN101", "ANN102", "ANN401",
61
75
  "D101", "D102", "D103", "D104", "D105", "D106", "D107",
62
- "F722",
63
76
  "N812", "N817",
64
77
  "PLR0911", "PLR0912", "PLR0913", "PLR0915", "PLR2004",
65
78
  "PLW0603", "PLW2901",
@@ -8,15 +8,12 @@ from setuptools import setup
8
8
  with open("README.md", "r", encoding="utf-8") as f:
9
9
  long_description: str = f.read()
10
10
 
11
-
12
11
  with open("xax/requirements.txt", "r", encoding="utf-8") as f:
13
12
  requirements: list[str] = f.read().splitlines()
14
13
 
15
-
16
14
  with open("xax/requirements-dev.txt", "r", encoding="utf-8") as f:
17
15
  requirements_dev: list[str] = f.read().splitlines()
18
16
 
19
-
20
17
  with open("xax/__init__.py", "r", encoding="utf-8") as fh:
21
18
  version_re = re.search(r"^__version__ = \"([^\"]*)\"", fh.read(), re.MULTILINE)
22
19
  assert version_re is not None, "Could not find version in xax/__init__.py"
@@ -34,7 +31,9 @@ setup(
34
31
  python_requires=">=3.11",
35
32
  install_requires=requirements,
36
33
  tests_require=requirements_dev,
37
- extras_require={"dev": requirements_dev},
34
+ extras_require={
35
+ "dev": requirements_dev,
36
+ },
38
37
  package_data={
39
38
  "xax": [
40
39
  "py.typed",
@@ -11,7 +11,7 @@ This file can be maintained by running the update script:
11
11
  python -m scripts.update_api --inplace
12
12
  """
13
13
 
14
- __version__ = "0.0.3"
14
+ __version__ = "0.0.6"
15
15
 
16
16
  # This list shouldn't be modified by hand; instead, run the update script.
17
17
  __all__ = [
@@ -23,15 +23,26 @@ __all__ = [
23
23
  "load_user_config",
24
24
  "State",
25
25
  "cast_phase",
26
+ "FourierEmbeddings",
27
+ "IdentityPositionalEmbeddings",
28
+ "LearnedPositionalEmbeddings",
29
+ "RotaryEmbeddings",
30
+ "SinusoidalEmbeddings",
31
+ "apply_rotary_embeddings",
32
+ "cast_embedding_kind",
33
+ "fourier_embeddings",
34
+ "get_positional_embeddings",
35
+ "get_rotary_embeddings",
36
+ "rotary_embeddings",
37
+ "is_master",
26
38
  "BaseLauncher",
27
39
  "CliLauncher",
28
40
  "SingleProcessLauncher",
29
- "LogAudio",
30
41
  "LogImage",
31
42
  "LogLine",
32
- "LogVideo",
33
43
  "Logger",
34
44
  "LoggerImpl",
45
+ "CallbackLogger",
35
46
  "JsonLogger",
36
47
  "StateLogger",
37
48
  "StdoutLogger",
@@ -46,34 +57,59 @@ __all__ = [
46
57
  "collate",
47
58
  "collate_non_null",
48
59
  "BaseFileDownloader",
60
+ "CumulativeTimer",
49
61
  "DataDownloader",
62
+ "IntervalTicker",
63
+ "IterationTimer",
64
+ "MinGradScaleError",
50
65
  "ModelDownloader",
66
+ "NaNError",
67
+ "StateTimer",
68
+ "TrainingFinishedError",
51
69
  "check_md5",
52
70
  "check_sha256",
71
+ "cpu_count",
72
+ "date_str",
73
+ "diff_configs",
53
74
  "get_git_state",
75
+ "get_random_port",
54
76
  "get_state_dict_prefix",
55
77
  "get_training_code",
56
78
  "save_config",
79
+ "stage_environment",
80
+ "to_markdown_table",
57
81
  "ColoredFormatter",
58
82
  "configure_logging",
59
83
  "one_hot",
60
84
  "partial_flatten",
61
85
  "worker_chunk",
62
86
  "TextBlock",
87
+ "camelcase_to_snakecase",
63
88
  "colored",
64
89
  "format_datetime",
65
90
  "format_timedelta",
91
+ "highlight_exception_message",
92
+ "is_interactive_session",
66
93
  "outlined",
67
94
  "render_text_blocks",
68
95
  "show_error",
96
+ "show_info",
69
97
  "show_warning",
98
+ "snakecase_to_camelcase",
70
99
  "uncolored",
71
100
  "wrapped",
72
101
  ]
73
102
 
74
103
  __all__ += [
104
+ "Batch",
75
105
  "CollateMode",
106
+ "EmbeddingKind",
107
+ "LOG_ERROR_SUMMARY",
108
+ "LOG_PING",
109
+ "LOG_STATUS",
110
+ "Output",
76
111
  "Phase",
112
+ "RawConfigType",
77
113
  ]
78
114
 
79
115
  import os
@@ -95,21 +131,32 @@ NAME_MAP: dict[str, str] = {
95
131
  "load_user_config": "core.conf",
96
132
  "State": "core.state",
97
133
  "cast_phase": "core.state",
134
+ "FourierEmbeddings": "nn.embeddings",
135
+ "IdentityPositionalEmbeddings": "nn.embeddings",
136
+ "LearnedPositionalEmbeddings": "nn.embeddings",
137
+ "RotaryEmbeddings": "nn.embeddings",
138
+ "SinusoidalEmbeddings": "nn.embeddings",
139
+ "apply_rotary_embeddings": "nn.embeddings",
140
+ "cast_embedding_kind": "nn.embeddings",
141
+ "fourier_embeddings": "nn.embeddings",
142
+ "get_positional_embeddings": "nn.embeddings",
143
+ "get_rotary_embeddings": "nn.embeddings",
144
+ "rotary_embeddings": "nn.embeddings",
145
+ "is_master": "nn.parallel",
98
146
  "BaseLauncher": "task.launchers.base",
99
147
  "CliLauncher": "task.launchers.cli",
100
148
  "SingleProcessLauncher": "task.launchers.single_process",
101
- "LogAudio": "task.logger",
102
149
  "LogImage": "task.logger",
103
150
  "LogLine": "task.logger",
104
- "LogVideo": "task.logger",
105
151
  "Logger": "task.logger",
106
152
  "LoggerImpl": "task.logger",
153
+ "CallbackLogger": "task.loggers.callback",
107
154
  "JsonLogger": "task.loggers.json",
108
155
  "StateLogger": "task.loggers.state",
109
156
  "StdoutLogger": "task.loggers.stdout",
110
157
  "TensorboardLogger": "task.loggers.tensorboard",
111
158
  "CPUStatsOptions": "task.mixins.cpu_stats",
112
- "DataLoaderConfig": "task.mixins.data_loader",
159
+ "DataloaderConfig": "task.mixins.data_loader",
113
160
  "GPUStatsOptions": "task.mixins.gpu_stats",
114
161
  "Script": "task.script",
115
162
  "ScriptConfig": "task.script",
@@ -118,27 +165,45 @@ NAME_MAP: dict[str, str] = {
118
165
  "collate": "utils.data.collate",
119
166
  "collate_non_null": "utils.data.collate",
120
167
  "BaseFileDownloader": "utils.experiments",
168
+ "CumulativeTimer": "utils.experiments",
121
169
  "DataDownloader": "utils.experiments",
170
+ "IntervalTicker": "utils.experiments",
171
+ "IterationTimer": "utils.experiments",
172
+ "MinGradScaleError": "utils.experiments",
122
173
  "ModelDownloader": "utils.experiments",
174
+ "NaNError": "utils.experiments",
175
+ "StateTimer": "utils.experiments",
176
+ "TrainingFinishedError": "utils.experiments",
123
177
  "check_md5": "utils.experiments",
124
178
  "check_sha256": "utils.experiments",
179
+ "cpu_count": "utils.experiments",
180
+ "date_str": "utils.experiments",
181
+ "diff_configs": "utils.experiments",
125
182
  "get_git_state": "utils.experiments",
183
+ "get_random_port": "utils.experiments",
126
184
  "get_state_dict_prefix": "utils.experiments",
127
185
  "get_training_code": "utils.experiments",
128
186
  "save_config": "utils.experiments",
187
+ "stage_environment": "utils.experiments",
188
+ "to_markdown_table": "utils.experiments",
129
189
  "ColoredFormatter": "utils.logging",
130
190
  "configure_logging": "utils.logging",
131
191
  "one_hot": "utils.numpy",
132
192
  "partial_flatten": "utils.numpy",
133
193
  "worker_chunk": "utils.numpy",
134
194
  "TextBlock": "utils.text",
195
+ "camelcase_to_snakecase": "utils.text",
135
196
  "colored": "utils.text",
136
197
  "format_datetime": "utils.text",
137
198
  "format_timedelta": "utils.text",
199
+ "highlight_exception_message": "utils.text",
200
+ "is_interactive_session": "utils.text",
138
201
  "outlined": "utils.text",
139
202
  "render_text_blocks": "utils.text",
140
203
  "show_error": "utils.text",
204
+ "show_info": "utils.text",
141
205
  "show_warning": "utils.text",
206
+ "snakecase_to_camelcase": "utils.text",
142
207
  "uncolored": "utils.text",
143
208
  "wrapped": "utils.text",
144
209
  }
@@ -146,8 +211,15 @@ NAME_MAP: dict[str, str] = {
146
211
  # Need to manually set some values which can't be auto-generated.
147
212
  NAME_MAP.update(
148
213
  {
214
+ "Batch": "task.mixins.train",
149
215
  "CollateMode": "utils.data.collate",
216
+ "EmbeddingKind": "nn.embeddings",
217
+ "LOG_ERROR_SUMMARY": "utils.logging",
218
+ "LOG_PING": "utils.logging",
219
+ "LOG_STATUS": "utils.logging",
220
+ "Output": "task.mixins.output",
150
221
  "Phase": "core.state",
222
+ "RawConfigType": "task.base",
151
223
  },
152
224
  )
153
225
 
@@ -171,10 +243,27 @@ if IMPORT_ALL or TYPE_CHECKING:
171
243
  load_user_config,
172
244
  )
173
245
  from xax.core.state import Phase, State, cast_phase
246
+ from xax.nn.embeddings import (
247
+ EmbeddingKind,
248
+ FourierEmbeddings,
249
+ IdentityPositionalEmbeddings,
250
+ LearnedPositionalEmbeddings,
251
+ RotaryEmbeddings,
252
+ SinusoidalEmbeddings,
253
+ apply_rotary_embeddings,
254
+ cast_embedding_kind,
255
+ fourier_embeddings,
256
+ get_positional_embeddings,
257
+ get_rotary_embeddings,
258
+ rotary_embeddings,
259
+ )
260
+ from xax.nn.parallel import is_master
261
+ from xax.task.base import RawConfigType
174
262
  from xax.task.launchers.base import BaseLauncher
175
263
  from xax.task.launchers.cli import CliLauncher
176
264
  from xax.task.launchers.single_process import SingleProcessLauncher
177
- from xax.task.logger import LogAudio, Logger, LoggerImpl, LogImage, LogLine, LogVideo
265
+ from xax.task.logger import Logger, LoggerImpl, LogImage, LogLine
266
+ from xax.task.loggers.callback import CallbackLogger
178
267
  from xax.task.loggers.json import JsonLogger
179
268
  from xax.task.loggers.state import StateLogger
180
269
  from xax.task.loggers.stdout import StdoutLogger
@@ -182,31 +271,56 @@ if IMPORT_ALL or TYPE_CHECKING:
182
271
  from xax.task.mixins.cpu_stats import CPUStatsOptions
183
272
  from xax.task.mixins.data_loader import DataloaderConfig
184
273
  from xax.task.mixins.gpu_stats import GPUStatsOptions
274
+ from xax.task.mixins.train import Batch, Output
185
275
  from xax.task.script import Script, ScriptConfig
186
276
  from xax.task.task import Config, Task
187
277
  from xax.utils.data.collate import CollateMode, collate, collate_non_null
188
278
  from xax.utils.experiments import (
189
279
  BaseFileDownloader,
280
+ CumulativeTimer,
190
281
  DataDownloader,
282
+ IntervalTicker,
283
+ IterationTimer,
284
+ MinGradScaleError,
191
285
  ModelDownloader,
286
+ NaNError,
287
+ StateTimer,
288
+ TrainingFinishedError,
192
289
  check_md5,
193
290
  check_sha256,
291
+ cpu_count,
292
+ date_str,
293
+ diff_configs,
194
294
  get_git_state,
295
+ get_random_port,
195
296
  get_state_dict_prefix,
196
297
  get_training_code,
197
298
  save_config,
299
+ stage_environment,
300
+ to_markdown_table,
301
+ )
302
+ from xax.utils.logging import (
303
+ LOG_ERROR_SUMMARY,
304
+ LOG_PING,
305
+ LOG_STATUS,
306
+ ColoredFormatter,
307
+ configure_logging,
198
308
  )
199
- from xax.utils.logging import ColoredFormatter, configure_logging
200
309
  from xax.utils.numpy import one_hot, partial_flatten, worker_chunk
201
310
  from xax.utils.text import (
202
311
  TextBlock,
312
+ camelcase_to_snakecase,
203
313
  colored,
204
314
  format_datetime,
205
315
  format_timedelta,
316
+ highlight_exception_message,
317
+ is_interactive_session,
206
318
  outlined,
207
319
  render_text_blocks,
208
320
  show_error,
321
+ show_info,
209
322
  show_warning,
323
+ snakecase_to_camelcase,
210
324
  uncolored,
211
325
  wrapped,
212
326
  )
@@ -6,7 +6,6 @@ from dataclasses import dataclass, field as field_base
6
6
  from pathlib import Path
7
7
  from typing import Any, cast
8
8
 
9
- import jax.numpy as jnp
10
9
  from omegaconf import II, MISSING, Container as OmegaConfContainer, OmegaConf
11
10
 
12
11
  from xax.utils.text import show_error
@@ -61,67 +60,44 @@ def is_missing(cfg: Any, key: str) -> bool: # noqa: ANN401
61
60
  return False
62
61
 
63
62
 
64
- @dataclass
63
+ @dataclass(kw_only=True)
65
64
  class Logging:
66
65
  hide_third_party_logs: bool = field(True, help="If set, hide third-party logs")
67
- log_level: str = field("INFO", help="The logging level to use")
66
+ log_level: str = field(II("oc.env:XAX_LOG_LEVEL,INFO"), help="The logging level to use")
68
67
 
69
68
 
70
- @dataclass
71
- class Device:
72
- cpu: bool = field(True, help="Whether to use the CPU")
73
- gpu: bool = field(II("oc.env:USE_GPU,1"), help="Whether to use the GPU")
74
- metal: bool = field(II("oc.env:USE_METAL,1"), help="Whether to use the Apple Silicon accelerator")
75
- use_fp64: bool = field(False, help="Always use the 64-bit floating point type")
76
- use_fp32: bool = field(False, help="Always use the 32-bit floating point type")
77
- use_bf16: bool = field(False, help="Always use the 16-bit bfloat type")
78
- use_fp16: bool = field(False, help="Always use the 16-bit floating point type")
79
-
80
-
81
- def parse_dtype(cfg: Device) -> jnp.dtype | None:
82
- if cfg.use_fp64:
83
- return jnp.float64
84
- if cfg.use_fp32:
85
- return jnp.float32
86
- if cfg.use_bf16:
87
- return jnp.bfloat16
88
- if cfg.use_fp16:
89
- return jnp.float16
90
- return None
91
-
92
-
93
- @dataclass
69
+ @dataclass(kw_only=True)
94
70
  class Triton:
95
71
  use_triton_if_available: bool = field(True, help="Use Triton if available")
96
72
 
97
73
 
98
- @dataclass
74
+ @dataclass(kw_only=True)
99
75
  class Experiment:
100
76
  default_random_seed: int = field(1337, help="The default random seed to use")
77
+ max_workers: int = field(32, help="Maximum number of workers to use")
101
78
 
102
79
 
103
- @dataclass
80
+ @dataclass(kw_only=True)
104
81
  class Directories:
105
82
  run: str = field(II("oc.env:RUN_DIR"), help="The run directory")
106
83
  data: str = field(II("oc.env:DATA_DIR"), help="The data directory")
107
84
  pretrained_models: str = field(II("oc.env:MODEL_DIR"), help="The models directory")
108
85
 
109
86
 
110
- @dataclass
87
+ @dataclass(kw_only=True)
111
88
  class SlurmPartition:
112
89
  partition: str = field(MISSING, help="The partition name")
113
90
  num_nodes: int = field(1, help="The number of nodes to use")
114
91
 
115
92
 
116
- @dataclass
93
+ @dataclass(kw_only=True)
117
94
  class Slurm:
118
95
  launch: dict[str, SlurmPartition] = field({}, help="The available launch configurations")
119
96
 
120
97
 
121
- @dataclass
98
+ @dataclass(kw_only=True)
122
99
  class UserConfig:
123
100
  logging: Logging = field(Logging)
124
- device: Device = field(Device)
125
101
  triton: Triton = field(Triton)
126
102
  experiment: Experiment = field(Experiment)
127
103
  directories: Directories = field(Directories)
@@ -2,7 +2,7 @@
2
2
 
3
3
  import time
4
4
  from dataclasses import dataclass
5
- from typing import Literal, TypedDict, cast, get_args
5
+ from typing import Literal, NotRequired, TypedDict, cast, get_args
6
6
 
7
7
  from omegaconf import MISSING
8
8
 
@@ -18,16 +18,16 @@ def cast_phase(raw_phase: str) -> Phase:
18
18
 
19
19
 
20
20
  class StateDict(TypedDict, total=False):
21
- num_steps: int
22
- num_samples: int
23
- num_valid_steps: int
24
- num_valid_samples: int
25
- start_time_s: float
26
- elapsed_time_s: float
27
- raw_phase: str
21
+ num_steps: NotRequired[int]
22
+ num_samples: NotRequired[int]
23
+ num_valid_steps: NotRequired[int]
24
+ num_valid_samples: NotRequired[int]
25
+ start_time_s: NotRequired[float]
26
+ elapsed_time_s: NotRequired[float]
27
+ raw_phase: NotRequired[str]
28
28
 
29
29
 
30
- @dataclass(frozen=True)
30
+ @dataclass
31
31
  class State:
32
32
  num_steps: int = field(MISSING, help="Number of steps so far")
33
33
  num_samples: int = field(MISSING, help="Number of sample so far")
@@ -41,6 +41,10 @@ class State:
41
41
  def phase(self) -> Phase:
42
42
  return cast_phase(self.raw_phase)
43
43
 
44
+ @phase.setter
45
+ def phase(self, phase: Phase) -> None:
46
+ self.raw_phase = phase
47
+
44
48
  @classmethod
45
49
  def init_state(cls) -> "State":
46
50
  return cls(
@@ -65,17 +69,3 @@ class State:
65
69
  return self.num_valid_steps
66
70
  case _:
67
71
  raise ValueError(f"Invalid phase: {phase}")
68
-
69
- def replace(self, values: StateDict) -> "State":
70
- return State(
71
- num_steps=values.get("num_steps", self.num_steps),
72
- num_samples=values.get("num_samples", self.num_samples),
73
- num_valid_steps=values.get("num_valid_steps", self.num_valid_steps),
74
- num_valid_samples=values.get("num_valid_samples", self.num_valid_samples),
75
- start_time_s=values.get("start_time_s", self.start_time_s),
76
- elapsed_time_s=values.get("elapsed_time_s", self.elapsed_time_s),
77
- raw_phase=values.get("raw_phase", self.raw_phase),
78
- )
79
-
80
- def with_phase(self, phase: Phase) -> "State":
81
- return self.replace({"raw_phase": phase})