xax 0.0.5__py3-none-any.whl → 0.0.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +102 -2
- xax/core/conf.py +8 -33
- xax/core/state.py +13 -23
- xax/nn/geom.py +75 -0
- xax/requirements.txt +2 -0
- xax/task/base.py +2 -0
- xax/task/logger.py +194 -122
- xax/task/loggers/callback.py +4 -16
- xax/task/loggers/state.py +5 -18
- xax/task/loggers/tensorboard.py +14 -28
- xax/task/mixins/__init__.py +1 -0
- xax/task/mixins/artifacts.py +7 -4
- xax/task/mixins/checkpointing.py +12 -0
- xax/task/mixins/compile.py +104 -0
- xax/task/mixins/cpu_stats.py +16 -5
- xax/task/mixins/data_loader.py +23 -12
- xax/task/mixins/gpu_stats.py +19 -5
- xax/task/mixins/logger.py +4 -2
- xax/task/mixins/process.py +4 -1
- xax/task/mixins/runnable.py +3 -0
- xax/task/mixins/step_wrapper.py +5 -0
- xax/task/mixins/train.py +189 -129
- xax/task/script.py +1 -1
- xax/task/task.py +7 -0
- xax/utils/jax.py +126 -0
- xax/utils/profile.py +61 -0
- xax/utils/pytree.py +50 -0
- xax/utils/tensorboard.py +48 -0
- {xax-0.0.5.dist-info → xax-0.0.7.dist-info}/METADATA +12 -2
- xax-0.0.7.dist-info/RECORD +55 -0
- {xax-0.0.5.dist-info → xax-0.0.7.dist-info}/WHEEL +1 -1
- xax/task/launchers/staged.py +0 -29
- xax-0.0.5.dist-info/RECORD +0 -52
- {xax-0.0.5.dist-info → xax-0.0.7.dist-info}/LICENSE +0 -0
- {xax-0.0.5.dist-info → xax-0.0.7.dist-info}/top_level.txt +0 -0
xax/task/mixins/checkpointing.py
CHANGED
@@ -9,6 +9,7 @@ from pathlib import Path
|
|
9
9
|
from typing import Any, Callable, Generic, Literal, TypeVar, cast, overload
|
10
10
|
|
11
11
|
import cloudpickle
|
12
|
+
import jax
|
12
13
|
import optax
|
13
14
|
from jaxtyping import PyTree
|
14
15
|
from omegaconf import DictConfig, OmegaConf
|
@@ -38,6 +39,7 @@ def get_ckpt_path(exp_dir: Path, state: State | None = None) -> Path:
|
|
38
39
|
return exp_dir / "checkpoints" / f"ckpt.{state.num_steps}.bin"
|
39
40
|
|
40
41
|
|
42
|
+
@jax.tree_util.register_dataclass
|
41
43
|
@dataclass
|
42
44
|
class CheckpointingConfig(ArtifactsConfig):
|
43
45
|
save_every_n_steps: int | None = field(None, help="Save a checkpoint every N steps")
|
@@ -45,6 +47,7 @@ class CheckpointingConfig(ArtifactsConfig):
|
|
45
47
|
only_save_most_recent: bool = field(True, help="Only keep the most recent checkpoint")
|
46
48
|
load_from_ckpt_path: str | None = field(None, help="If set, load initial model weights from this path")
|
47
49
|
load_ckpt_strict: bool = field(True, help="If set, only load weights for which have a matching key in the model")
|
50
|
+
save_tf_model: bool = field(False, help="If set, saves a Tensorflow version of the model")
|
48
51
|
|
49
52
|
|
50
53
|
Config = TypeVar("Config", bound=CheckpointingConfig)
|
@@ -196,6 +199,15 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
|
|
196
199
|
add_file("state", lambda buf: buf.write(json.dumps(asdict(state), indent=2).encode()))
|
197
200
|
add_file("config", lambda buf: buf.write(OmegaConf.to_yaml(self.config).encode()))
|
198
201
|
|
202
|
+
if self.config.save_tf_model:
|
203
|
+
try:
|
204
|
+
from jax.experimental import jax2tf
|
205
|
+
except ModuleNotFoundError:
|
206
|
+
raise ImportError("Tensorflow is not installed. Install it with `pip install tensorflow`")
|
207
|
+
|
208
|
+
tf_model = jax2tf.convert(model)
|
209
|
+
add_file("model.tf", lambda buf: cloudpickle.dump(tf_model, buf))
|
210
|
+
|
199
211
|
# Updates the symlink to the new checkpoint.
|
200
212
|
last_ckpt_path.unlink(missing_ok=True)
|
201
213
|
try:
|
@@ -0,0 +1,104 @@
|
|
1
|
+
"""Defines a mixin for handling JAX compilation behavior.
|
2
|
+
|
3
|
+
This mixin allows control over JAX compilation settings like jit, pmap, and vmap
|
4
|
+
behavior during initialization and training.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import logging
|
8
|
+
from dataclasses import dataclass
|
9
|
+
from pathlib import Path
|
10
|
+
from typing import Generic, TypeVar
|
11
|
+
|
12
|
+
import jax
|
13
|
+
|
14
|
+
from xax.core.conf import field
|
15
|
+
from xax.task.base import BaseConfig, BaseTask
|
16
|
+
|
17
|
+
logger = logging.getLogger(__name__)
|
18
|
+
|
19
|
+
|
20
|
+
@jax.tree_util.register_dataclass
|
21
|
+
@dataclass
|
22
|
+
class CompileOptions:
|
23
|
+
# JAX compilation options
|
24
|
+
disable_jit: bool = field(
|
25
|
+
value=False,
|
26
|
+
help="If True, disables JIT compilation",
|
27
|
+
)
|
28
|
+
enable_x64: bool = field(
|
29
|
+
value=False,
|
30
|
+
help="If True, enables 64-bit precision",
|
31
|
+
)
|
32
|
+
default_device: str | None = field(
|
33
|
+
value=None,
|
34
|
+
help="Default device to use (e.g. 'cpu', 'gpu')",
|
35
|
+
)
|
36
|
+
|
37
|
+
# JAX logging options
|
38
|
+
logging_level: str = field(
|
39
|
+
value="INFO",
|
40
|
+
help="JAX logging verbosity level",
|
41
|
+
)
|
42
|
+
|
43
|
+
# JAX cache options
|
44
|
+
cache_dir: str | None = field(
|
45
|
+
value=lambda: str((Path.home() / ".cache" / "jax" / "jaxcache").resolve()),
|
46
|
+
help="Directory for JAX compilation cache. If None, caching is disabled",
|
47
|
+
)
|
48
|
+
cache_min_size_bytes: int = field(
|
49
|
+
value=-1,
|
50
|
+
help="Minimum size in bytes for cache entries. -1 means no minimum",
|
51
|
+
)
|
52
|
+
cache_min_compile_time_secs: float = field(
|
53
|
+
value=0.0,
|
54
|
+
help="Minimum compilation time in seconds for cache entries. 0 means no minimum",
|
55
|
+
)
|
56
|
+
cache_enable_xla: str = field(
|
57
|
+
value="all",
|
58
|
+
help="Which XLA caches to enable",
|
59
|
+
)
|
60
|
+
|
61
|
+
|
62
|
+
@jax.tree_util.register_dataclass
|
63
|
+
@dataclass
|
64
|
+
class CompileConfig(BaseConfig):
|
65
|
+
compile: CompileOptions = field(CompileOptions(), help="Compilation configuration")
|
66
|
+
|
67
|
+
|
68
|
+
Config = TypeVar("Config", bound=CompileConfig)
|
69
|
+
|
70
|
+
|
71
|
+
class CompileMixin(BaseTask[Config], Generic[Config]):
|
72
|
+
"""Defines a task mixin for controlling JAX compilation behavior."""
|
73
|
+
|
74
|
+
def __init__(self, config: Config) -> None:
|
75
|
+
super().__init__(config)
|
76
|
+
|
77
|
+
cc = self.config.compile
|
78
|
+
|
79
|
+
# Set basic compilation flags
|
80
|
+
if cc.disable_jit:
|
81
|
+
logger.info("Disabling JIT compilation")
|
82
|
+
jax.config.update("jax_disable_jit", True)
|
83
|
+
|
84
|
+
if cc.enable_x64:
|
85
|
+
logger.info("Enabling 64-bit precision")
|
86
|
+
jax.config.update("jax_enable_x64", True)
|
87
|
+
|
88
|
+
if cc.default_device is not None:
|
89
|
+
logger.info("Setting default device to %s", cc.default_device)
|
90
|
+
jax.config.update("jax_default_device", cc.default_device)
|
91
|
+
|
92
|
+
# Set logging level
|
93
|
+
logger.info("Setting JAX logging level to %s", cc.logging_level)
|
94
|
+
jax.config.update("jax_logging_level", cc.logging_level)
|
95
|
+
|
96
|
+
# Configure compilation cache
|
97
|
+
if cc.cache_dir is not None:
|
98
|
+
logger.info("Setting JAX compilation cache directory to %s", cc.cache_dir)
|
99
|
+
jax.config.update("jax_compilation_cache_dir", cc.cache_dir)
|
100
|
+
|
101
|
+
logger.info("Configuring JAX compilation cache parameters")
|
102
|
+
jax.config.update("jax_persistent_cache_min_entry_size_bytes", cc.cache_min_size_bytes)
|
103
|
+
jax.config.update("jax_persistent_cache_min_compile_time_secs", cc.cache_min_compile_time_secs)
|
104
|
+
jax.config.update("jax_persistent_cache_enable_xla_caches", cc.cache_enable_xla)
|
xax/task/mixins/cpu_stats.py
CHANGED
@@ -6,15 +6,16 @@ leaks in your dataloader, among other issues.
|
|
6
6
|
"""
|
7
7
|
|
8
8
|
import logging
|
9
|
-
import multiprocessing as mp
|
10
9
|
import os
|
11
10
|
import time
|
12
11
|
from ctypes import Structure, c_double, c_uint16, c_uint64
|
13
12
|
from dataclasses import dataclass
|
13
|
+
from multiprocessing.context import BaseContext, Process
|
14
14
|
from multiprocessing.managers import SyncManager, ValueProxy
|
15
15
|
from multiprocessing.synchronize import Event
|
16
16
|
from typing import Generic, TypeVar
|
17
17
|
|
18
|
+
import jax
|
18
19
|
import psutil
|
19
20
|
|
20
21
|
from xax.core.conf import field
|
@@ -26,12 +27,14 @@ from xax.task.mixins.process import ProcessConfig, ProcessMixin
|
|
26
27
|
logger: logging.Logger = logging.getLogger(__name__)
|
27
28
|
|
28
29
|
|
30
|
+
@jax.tree_util.register_dataclass
|
29
31
|
@dataclass
|
30
32
|
class CPUStatsOptions:
|
31
33
|
ping_interval: int = field(1, help="How often to check stats (in seconds)")
|
32
34
|
only_log_once: bool = field(False, help="If set, only log read stats one time")
|
33
35
|
|
34
36
|
|
37
|
+
@jax.tree_util.register_dataclass
|
35
38
|
@dataclass
|
36
39
|
class CPUStatsConfig(ProcessConfig, LoggerConfig, BaseConfig):
|
37
40
|
cpu_stats: CPUStatsOptions = field(CPUStatsOptions(), help="CPU stats configuration")
|
@@ -55,7 +58,7 @@ class CPUStats(Structure):
|
|
55
58
|
]
|
56
59
|
|
57
60
|
|
58
|
-
@dataclass
|
61
|
+
@dataclass(kw_only=True)
|
59
62
|
class CPUStatsInfo:
|
60
63
|
cpu_percent: float
|
61
64
|
mem_percent: float
|
@@ -142,9 +145,16 @@ def worker(
|
|
142
145
|
|
143
146
|
|
144
147
|
class CPUStatsMonitor:
|
145
|
-
def __init__(
|
148
|
+
def __init__(
|
149
|
+
self,
|
150
|
+
ping_interval: float,
|
151
|
+
context: BaseContext,
|
152
|
+
manager: SyncManager,
|
153
|
+
) -> None:
|
146
154
|
self._ping_interval = ping_interval
|
147
155
|
self._manager = manager
|
156
|
+
self._context = context
|
157
|
+
|
148
158
|
self._monitor_event = self._manager.Event()
|
149
159
|
self._start_event = self._manager.Event()
|
150
160
|
self._cpu_stats_smem = self._manager.Value(
|
@@ -163,7 +173,7 @@ class CPUStatsMonitor:
|
|
163
173
|
),
|
164
174
|
)
|
165
175
|
self._cpu_stats: CPUStatsInfo | None = None
|
166
|
-
self._proc:
|
176
|
+
self._proc: Process | None = None
|
167
177
|
|
168
178
|
def get_if_set(self) -> CPUStatsInfo | None:
|
169
179
|
if self._monitor_event.is_set():
|
@@ -184,7 +194,7 @@ class CPUStatsMonitor:
|
|
184
194
|
if self._start_event.is_set():
|
185
195
|
self._start_event.clear()
|
186
196
|
self._cpu_stats = None
|
187
|
-
self._proc =
|
197
|
+
self._proc = self._context.Process( # type: ignore[attr-defined]
|
188
198
|
target=worker,
|
189
199
|
args=(self._ping_interval, self._cpu_stats_smem, self._monitor_event, self._start_event, os.getpid()),
|
190
200
|
daemon=True,
|
@@ -215,6 +225,7 @@ class CPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
|
|
215
225
|
|
216
226
|
self._cpu_stats_monitor = CPUStatsMonitor(
|
217
227
|
ping_interval=self.config.cpu_stats.ping_interval,
|
228
|
+
context=self._mp_ctx,
|
218
229
|
manager=self._mp_manager,
|
219
230
|
)
|
220
231
|
|
xax/task/mixins/data_loader.py
CHANGED
@@ -1,9 +1,9 @@
|
|
1
1
|
"""Defines a mixin for instantiating dataloaders."""
|
2
2
|
|
3
3
|
import logging
|
4
|
-
from abc import ABC
|
4
|
+
from abc import ABC
|
5
5
|
from dataclasses import dataclass
|
6
|
-
from typing import Generic, TypeVar
|
6
|
+
from typing import Generic, Iterator, TypeVar
|
7
7
|
|
8
8
|
import jax
|
9
9
|
from dpshdl.dataloader import CollatedDataloaderItem, Dataloader
|
@@ -13,7 +13,7 @@ from omegaconf import II, MISSING
|
|
13
13
|
|
14
14
|
from xax.core.conf import field, is_missing
|
15
15
|
from xax.core.state import Phase
|
16
|
-
from xax.nn.functions import
|
16
|
+
from xax.nn.functions import set_random_seed
|
17
17
|
from xax.task.base import BaseConfig, BaseTask
|
18
18
|
from xax.task.mixins.process import ProcessConfig, ProcessMixin
|
19
19
|
from xax.utils.logging import LOG_ERROR_SUMMARY, configure_logging
|
@@ -24,6 +24,7 @@ T = TypeVar("T")
|
|
24
24
|
Tc_co = TypeVar("Tc_co", covariant=True)
|
25
25
|
|
26
26
|
|
27
|
+
@jax.tree_util.register_dataclass
|
27
28
|
@dataclass
|
28
29
|
class DataloaderErrorConfig:
|
29
30
|
sleep_backoff: float = field(0.1, help="The initial sleep time after an exception")
|
@@ -36,6 +37,7 @@ class DataloaderErrorConfig:
|
|
36
37
|
log_exceptions_all_workers: bool = field(False, help="If set, log exceptions from all workers")
|
37
38
|
|
38
39
|
|
40
|
+
@jax.tree_util.register_dataclass
|
39
41
|
@dataclass
|
40
42
|
class DataloaderConfig:
|
41
43
|
num_workers: int | None = field(MISSING, help="Number of workers for loading samples")
|
@@ -43,6 +45,7 @@ class DataloaderConfig:
|
|
43
45
|
error: DataloaderErrorConfig = field(DataloaderErrorConfig(), help="Dataloader error configuration")
|
44
46
|
|
45
47
|
|
48
|
+
@jax.tree_util.register_dataclass
|
46
49
|
@dataclass
|
47
50
|
class DataloadersConfig(ProcessConfig, BaseConfig):
|
48
51
|
batch_size: int = field(MISSING, help="Size of each batch")
|
@@ -63,9 +66,6 @@ Config = TypeVar("Config", bound=DataloadersConfig)
|
|
63
66
|
|
64
67
|
class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config], ABC):
|
65
68
|
def __init__(self, config: Config) -> None:
|
66
|
-
if is_missing(config, "batch_size"):
|
67
|
-
config.batch_size = self.get_batch_size()
|
68
|
-
|
69
69
|
super().__init__(config)
|
70
70
|
|
71
71
|
def get_batch_size(self) -> int:
|
@@ -74,6 +74,12 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
|
|
74
74
|
"method to return the desired training batch size."
|
75
75
|
)
|
76
76
|
|
77
|
+
@property
|
78
|
+
def batch_size(self) -> int:
|
79
|
+
if is_missing(self.config, "batch_size"):
|
80
|
+
self.config.batch_size = self.get_batch_size()
|
81
|
+
return self.config.batch_size
|
82
|
+
|
77
83
|
def dataloader_config(self, phase: Phase) -> DataloaderConfig:
|
78
84
|
match phase:
|
79
85
|
case "train":
|
@@ -83,7 +89,6 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
|
|
83
89
|
case _:
|
84
90
|
raise KeyError(f"Unknown phase: {phase}")
|
85
91
|
|
86
|
-
@abstractmethod
|
87
92
|
def get_dataset(self, phase: Phase) -> Dataset:
|
88
93
|
"""Returns the dataset for the given phase.
|
89
94
|
|
@@ -93,6 +98,16 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
|
|
93
98
|
Returns:
|
94
99
|
The dataset for the given phase.
|
95
100
|
"""
|
101
|
+
raise NotImplementedError(
|
102
|
+
"You must implement either the `get_dataset` method to return the dataset for the given phase, "
|
103
|
+
"or `get_data_iterator` to return an iterator for the given dataset."
|
104
|
+
)
|
105
|
+
|
106
|
+
def get_data_iterator(self, phase: Phase) -> Iterator:
|
107
|
+
raise NotImplementedError(
|
108
|
+
"You must implement either the `get_dataset` method to return the dataset for the given phase, "
|
109
|
+
"or `get_data_iterator` to return an iterator for the given dataset."
|
110
|
+
)
|
96
111
|
|
97
112
|
def get_dataloader(self, dataset: Dataset[T, Tc_co], phase: Phase) -> Dataloader[T, Tc_co]:
|
98
113
|
debugging = self.config.debug_dataloader
|
@@ -128,11 +143,7 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
|
|
128
143
|
)
|
129
144
|
|
130
145
|
def get_prefetcher(self, dataloader: Dataloader[T, Tc_co]) -> Prefetcher[Tc_co, Tc_co]:
|
131
|
-
return Prefetcher(to_device_func=
|
132
|
-
|
133
|
-
@classmethod
|
134
|
-
def to_device_fn(cls, sample: T) -> T:
|
135
|
-
return recursive_apply(sample, jax.device_put, include_numpy=True)
|
146
|
+
return Prefetcher(to_device_func=jax.device_put, dataloader=dataloader)
|
136
147
|
|
137
148
|
@classmethod
|
138
149
|
def dataloader_worker_init_fn(cls, worker_id: int, num_workers: int) -> None:
|
xax/task/mixins/gpu_stats.py
CHANGED
@@ -6,17 +6,19 @@ This logs GPU memory and utilization in a background process using
|
|
6
6
|
|
7
7
|
import functools
|
8
8
|
import logging
|
9
|
-
import multiprocessing as mp
|
10
9
|
import os
|
11
10
|
import re
|
12
11
|
import shutil
|
13
12
|
import subprocess
|
14
13
|
from ctypes import Structure, c_double, c_uint32
|
15
14
|
from dataclasses import dataclass
|
15
|
+
from multiprocessing.context import BaseContext, Process
|
16
16
|
from multiprocessing.managers import SyncManager, ValueProxy
|
17
17
|
from multiprocessing.synchronize import Event
|
18
18
|
from typing import Generic, Iterable, Pattern, TypeVar
|
19
19
|
|
20
|
+
import jax
|
21
|
+
|
20
22
|
from xax.core.conf import field
|
21
23
|
from xax.core.state import State
|
22
24
|
from xax.task.mixins.logger import LoggerConfig, LoggerMixin
|
@@ -25,12 +27,14 @@ from xax.task.mixins.process import ProcessConfig, ProcessMixin
|
|
25
27
|
logger: logging.Logger = logging.getLogger(__name__)
|
26
28
|
|
27
29
|
|
30
|
+
@jax.tree_util.register_dataclass
|
28
31
|
@dataclass
|
29
32
|
class GPUStatsOptions:
|
30
33
|
ping_interval: int = field(10, help="How often to check stats (in seconds)")
|
31
34
|
only_log_once: bool = field(False, help="If set, only log read stats one time")
|
32
35
|
|
33
36
|
|
37
|
+
@jax.tree_util.register_dataclass
|
34
38
|
@dataclass
|
35
39
|
class GPUStatsConfig(ProcessConfig, LoggerConfig):
|
36
40
|
gpu_stats: GPUStatsOptions = field(GPUStatsOptions(), help="GPU stats configuration")
|
@@ -147,8 +151,14 @@ def worker(
|
|
147
151
|
|
148
152
|
|
149
153
|
class GPUStatsMonitor:
|
150
|
-
def __init__(
|
154
|
+
def __init__(
|
155
|
+
self,
|
156
|
+
ping_interval: float,
|
157
|
+
context: BaseContext,
|
158
|
+
manager: SyncManager,
|
159
|
+
) -> None:
|
151
160
|
self._ping_interval = ping_interval
|
161
|
+
self._context = context
|
152
162
|
self._manager = manager
|
153
163
|
|
154
164
|
num_gpus = get_num_gpus()
|
@@ -169,7 +179,7 @@ class GPUStatsMonitor:
|
|
169
179
|
for i in range(num_gpus)
|
170
180
|
]
|
171
181
|
self._gpu_stats: dict[int, GPUStatsInfo] = {}
|
172
|
-
self._proc:
|
182
|
+
self._proc: Process | None = None
|
173
183
|
|
174
184
|
def get_if_set(self) -> dict[int, GPUStatsInfo]:
|
175
185
|
gpu_stats: dict[int, GPUStatsInfo] = {}
|
@@ -196,7 +206,7 @@ class GPUStatsMonitor:
|
|
196
206
|
if self._start_event.is_set():
|
197
207
|
self._start_event.clear()
|
198
208
|
self._gpu_stats.clear()
|
199
|
-
self._proc =
|
209
|
+
self._proc = self._context.Process( # type: ignore[attr-defined]
|
200
210
|
target=worker,
|
201
211
|
args=(self._ping_interval, self._smems, self._main_event, self._events, self._start_event),
|
202
212
|
daemon=True,
|
@@ -226,7 +236,11 @@ class GPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
|
|
226
236
|
|
227
237
|
self._gpu_stats_monitor = None
|
228
238
|
if shutil.which("nvidia-smi") is not None:
|
229
|
-
self._gpu_stats_monitor = GPUStatsMonitor(
|
239
|
+
self._gpu_stats_monitor = GPUStatsMonitor(
|
240
|
+
config.gpu_stats.ping_interval,
|
241
|
+
self._mp_ctx,
|
242
|
+
self._mp_manager,
|
243
|
+
)
|
230
244
|
|
231
245
|
def on_training_start(self, state: State) -> State:
|
232
246
|
state = super().on_training_start(state)
|
xax/task/mixins/logger.py
CHANGED
@@ -6,7 +6,8 @@ from pathlib import Path
|
|
6
6
|
from types import TracebackType
|
7
7
|
from typing import Generic, Self, TypeVar
|
8
8
|
|
9
|
-
|
9
|
+
import jax
|
10
|
+
|
10
11
|
from xax.core.state import State
|
11
12
|
from xax.task.base import BaseConfig, BaseTask
|
12
13
|
from xax.task.logger import Logger, LoggerImpl
|
@@ -18,9 +19,10 @@ from xax.task.mixins.artifacts import ArtifactsMixin
|
|
18
19
|
from xax.utils.text import is_interactive_session
|
19
20
|
|
20
21
|
|
22
|
+
@jax.tree_util.register_dataclass
|
21
23
|
@dataclass
|
22
24
|
class LoggerConfig(BaseConfig):
|
23
|
-
|
25
|
+
pass
|
24
26
|
|
25
27
|
|
26
28
|
Config = TypeVar("Config", bound=LoggerConfig)
|
xax/task/mixins/process.py
CHANGED
@@ -7,6 +7,8 @@ from multiprocessing.context import BaseContext
|
|
7
7
|
from multiprocessing.managers import SyncManager
|
8
8
|
from typing import Generic, TypeVar
|
9
9
|
|
10
|
+
import jax
|
11
|
+
|
10
12
|
from xax.core.conf import field
|
11
13
|
from xax.core.state import State
|
12
14
|
from xax.task.base import BaseConfig, BaseTask
|
@@ -14,9 +16,10 @@ from xax.task.base import BaseConfig, BaseTask
|
|
14
16
|
logger: logging.Logger = logging.getLogger(__name__)
|
15
17
|
|
16
18
|
|
19
|
+
@jax.tree_util.register_dataclass
|
17
20
|
@dataclass
|
18
21
|
class ProcessConfig(BaseConfig):
|
19
|
-
multiprocessing_context: str | None = field(
|
22
|
+
multiprocessing_context: str | None = field("spawn", help="The multiprocessing context to use")
|
20
23
|
|
21
24
|
|
22
25
|
Config = TypeVar("Config", bound=ProcessConfig)
|
xax/task/mixins/runnable.py
CHANGED
@@ -6,10 +6,13 @@ from dataclasses import dataclass
|
|
6
6
|
from types import FrameType
|
7
7
|
from typing import Callable, TypeVar
|
8
8
|
|
9
|
+
import jax
|
10
|
+
|
9
11
|
from xax.task.base import BaseConfig, BaseTask, RawConfigType
|
10
12
|
from xax.task.launchers.base import BaseLauncher
|
11
13
|
|
12
14
|
|
15
|
+
@jax.tree_util.register_dataclass
|
13
16
|
@dataclass
|
14
17
|
class RunnableConfig(BaseConfig):
|
15
18
|
pass
|
xax/task/mixins/step_wrapper.py
CHANGED
@@ -4,6 +4,9 @@ from dataclasses import dataclass
|
|
4
4
|
from types import TracebackType
|
5
5
|
from typing import ContextManager, Literal, TypeVar
|
6
6
|
|
7
|
+
import equinox as eqx
|
8
|
+
import jax
|
9
|
+
|
7
10
|
from xax.task.base import BaseConfig, BaseTask
|
8
11
|
|
9
12
|
StepType = Literal[
|
@@ -47,6 +50,7 @@ class StepContext(ContextManager):
|
|
47
50
|
StepContext.CURRENT_STEP = None
|
48
51
|
|
49
52
|
|
53
|
+
@jax.tree_util.register_dataclass
|
50
54
|
@dataclass
|
51
55
|
class StepContextConfig(BaseConfig):
|
52
56
|
pass
|
@@ -59,5 +63,6 @@ class StepContextMixin(BaseTask[Config]):
|
|
59
63
|
def __init__(self, config: Config) -> None:
|
60
64
|
super().__init__(config)
|
61
65
|
|
66
|
+
@eqx.filter_jit
|
62
67
|
def step_context(self, step: StepType) -> ContextManager:
|
63
68
|
return StepContext(step)
|