xax 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -9,6 +9,7 @@ from pathlib import Path
9
9
  from typing import Any, Callable, Generic, Literal, TypeVar, cast, overload
10
10
 
11
11
  import cloudpickle
12
+ import jax
12
13
  import optax
13
14
  from jaxtyping import PyTree
14
15
  from omegaconf import DictConfig, OmegaConf
@@ -38,6 +39,7 @@ def get_ckpt_path(exp_dir: Path, state: State | None = None) -> Path:
38
39
  return exp_dir / "checkpoints" / f"ckpt.{state.num_steps}.bin"
39
40
 
40
41
 
42
+ @jax.tree_util.register_dataclass
41
43
  @dataclass
42
44
  class CheckpointingConfig(ArtifactsConfig):
43
45
  save_every_n_steps: int | None = field(None, help="Save a checkpoint every N steps")
@@ -45,6 +47,7 @@ class CheckpointingConfig(ArtifactsConfig):
45
47
  only_save_most_recent: bool = field(True, help="Only keep the most recent checkpoint")
46
48
  load_from_ckpt_path: str | None = field(None, help="If set, load initial model weights from this path")
47
49
  load_ckpt_strict: bool = field(True, help="If set, only load weights for which have a matching key in the model")
50
+ save_tf_model: bool = field(False, help="If set, saves a Tensorflow version of the model")
48
51
 
49
52
 
50
53
  Config = TypeVar("Config", bound=CheckpointingConfig)
@@ -196,6 +199,15 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
196
199
  add_file("state", lambda buf: buf.write(json.dumps(asdict(state), indent=2).encode()))
197
200
  add_file("config", lambda buf: buf.write(OmegaConf.to_yaml(self.config).encode()))
198
201
 
202
+ if self.config.save_tf_model:
203
+ try:
204
+ from jax.experimental import jax2tf
205
+ except ModuleNotFoundError:
206
+ raise ImportError("Tensorflow is not installed. Install it with `pip install tensorflow`")
207
+
208
+ tf_model = jax2tf.convert(model)
209
+ add_file("model.tf", lambda buf: cloudpickle.dump(tf_model, buf))
210
+
199
211
  # Updates the symlink to the new checkpoint.
200
212
  last_ckpt_path.unlink(missing_ok=True)
201
213
  try:
@@ -0,0 +1,104 @@
1
+ """Defines a mixin for handling JAX compilation behavior.
2
+
3
+ This mixin applies JAX configuration settings (disabling jit, 64-bit precision,
4
+ default device, logging level, and the compilation cache) when the task is initialized.
5
+ """
6
+
7
+ import logging
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Generic, TypeVar
11
+
12
+ import jax
13
+
14
+ from xax.core.conf import field
15
+ from xax.task.base import BaseConfig, BaseTask
16
+
17
+ logger = logging.getLogger(__name__)
18
+
19
+
20
@jax.tree_util.register_dataclass
@dataclass
class CompileOptions:
    """Options that ``CompileMixin`` writes into the global ``jax.config``.

    They fall into three groups: compilation flags (jit, precision, default
    device), JAX's own logging verbosity, and the persistent compilation cache.
    """

    # --- Compilation flags ---
    disable_jit: bool = field(False, help="If True, disables JIT compilation")
    enable_x64: bool = field(False, help="If True, enables 64-bit precision")
    default_device: str | None = field(None, help="Default device to use (e.g. 'cpu', 'gpu')")

    # --- Logging ---
    logging_level: str = field("INFO", help="JAX logging verbosity level")

    # --- Persistent compilation cache ---
    # The default is supplied as a callable. This presumably relies on `field`
    # treating callables as default factories (confirm in xax.core.conf).
    cache_dir: str | None = field(
        lambda: str((Path.home() / ".cache" / "jax" / "jaxcache").resolve()),
        help="Directory for JAX compilation cache. If None, caching is disabled",
    )
    cache_min_size_bytes: int = field(-1, help="Minimum size in bytes for cache entries. -1 means no minimum")
    cache_min_compile_time_secs: float = field(
        0.0,
        help="Minimum compilation time in seconds for cache entries. 0 means no minimum",
    )
    cache_enable_xla: str = field("all", help="Which XLA caches to enable")
60
+
61
+
62
@jax.tree_util.register_dataclass
@dataclass
class CompileConfig(BaseConfig):
    """Task config section that holds the JAX compilation options."""

    compile: CompileOptions = field(value=CompileOptions(), help="Compilation configuration")
66
+
67
+
68
+ Config = TypeVar("Config", bound=CompileConfig)
69
+
70
+
71
class CompileMixin(BaseTask[Config], Generic[Config]):
    """Defines a task mixin for controlling JAX compilation behavior.

    The options in ``config.compile`` are written to the global ``jax.config``
    once, when the task is constructed.
    """

    def __init__(self, config: Config) -> None:
        super().__init__(config)

        cc = self.config.compile

        # Set basic compilation flags.
        if cc.disable_jit:
            logger.info("Disabling JIT compilation")
            jax.config.update("jax_disable_jit", True)

        if cc.enable_x64:
            logger.info("Enabling 64-bit precision")
            jax.config.update("jax_enable_x64", True)

        # Set logging level.
        logger.info("Setting JAX logging level to %s", cc.logging_level)
        jax.config.update("jax_logging_level", cc.logging_level)

        # Configure compilation cache.
        if cc.cache_dir is not None:
            # Expand "~" so a user-supplied path like "~/.cache/jax" does not
            # create a literal directory named "~".
            cache_dir = str(Path(cc.cache_dir).expanduser())
            logger.info("Setting JAX compilation cache directory to %s", cache_dir)
            jax.config.update("jax_compilation_cache_dir", cache_dir)

            logger.info("Configuring JAX compilation cache parameters")
            jax.config.update("jax_persistent_cache_min_entry_size_bytes", cc.cache_min_size_bytes)
            jax.config.update("jax_persistent_cache_min_compile_time_secs", cc.cache_min_compile_time_secs)
            jax.config.update("jax_persistent_cache_enable_xla_caches", cc.cache_enable_xla)

        # The default device is set last because resolving a platform name calls
        # `jax.devices`, which queries the backend. Every other option is
        # therefore already in place when that happens.
        if cc.default_device is not None:
            logger.info("Setting default device to %s", cc.default_device)
            jax.config.update("jax_default_device", self._resolve_default_device(cc.default_device))

    @staticmethod
    def _resolve_default_device(platform: str) -> jax.Device:
        """Returns the first device for ``platform`` (e.g. "cpu", "gpu").

        ``jax_default_device`` is documented as taking a ``Device`` object.
        Passing a concrete device avoids depending on whether the installed JAX
        version also accepts a bare platform string.

        Raises:
            RuntimeError: Propagated from ``jax.devices`` if no backend exists
                for ``platform``.
        """
        return jax.devices(platform)[0]
@@ -6,15 +6,16 @@ leaks in your dataloader, among other issues.
6
6
  """
7
7
 
8
8
  import logging
9
- import multiprocessing as mp
10
9
  import os
11
10
  import time
12
11
  from ctypes import Structure, c_double, c_uint16, c_uint64
13
12
  from dataclasses import dataclass
13
+ from multiprocessing.context import BaseContext, Process
14
14
  from multiprocessing.managers import SyncManager, ValueProxy
15
15
  from multiprocessing.synchronize import Event
16
16
  from typing import Generic, TypeVar
17
17
 
18
+ import jax
18
19
  import psutil
19
20
 
20
21
  from xax.core.conf import field
@@ -26,12 +27,14 @@ from xax.task.mixins.process import ProcessConfig, ProcessMixin
26
27
  logger: logging.Logger = logging.getLogger(__name__)
27
28
 
28
29
 
30
@jax.tree_util.register_dataclass
@dataclass
class CPUStatsOptions:
    """Options controlling CPU stats collection and logging."""

    ping_interval: int = field(value=1, help="How often to check stats (in seconds)")
    only_log_once: bool = field(value=False, help="If set, only log read stats one time")
33
35
 
34
36
 
37
+ @jax.tree_util.register_dataclass
35
38
  @dataclass
36
39
  class CPUStatsConfig(ProcessConfig, LoggerConfig, BaseConfig):
37
40
  cpu_stats: CPUStatsOptions = field(CPUStatsOptions(), help="CPU stats configuration")
@@ -55,7 +58,7 @@ class CPUStats(Structure):
55
58
  ]
56
59
 
57
60
 
58
- @dataclass
61
+ @dataclass(kw_only=True)
59
62
  class CPUStatsInfo:
60
63
  cpu_percent: float
61
64
  mem_percent: float
@@ -142,9 +145,16 @@ def worker(
142
145
 
143
146
 
144
147
  class CPUStatsMonitor:
145
- def __init__(self, ping_interval: float, manager: SyncManager) -> None:
148
+ def __init__(
149
+ self,
150
+ ping_interval: float,
151
+ context: BaseContext,
152
+ manager: SyncManager,
153
+ ) -> None:
146
154
  self._ping_interval = ping_interval
147
155
  self._manager = manager
156
+ self._context = context
157
+
148
158
  self._monitor_event = self._manager.Event()
149
159
  self._start_event = self._manager.Event()
150
160
  self._cpu_stats_smem = self._manager.Value(
@@ -163,7 +173,7 @@ class CPUStatsMonitor:
163
173
  ),
164
174
  )
165
175
  self._cpu_stats: CPUStatsInfo | None = None
166
- self._proc: mp.Process | None = None
176
+ self._proc: Process | None = None
167
177
 
168
178
  def get_if_set(self) -> CPUStatsInfo | None:
169
179
  if self._monitor_event.is_set():
@@ -184,7 +194,7 @@ class CPUStatsMonitor:
184
194
  if self._start_event.is_set():
185
195
  self._start_event.clear()
186
196
  self._cpu_stats = None
187
- self._proc = mp.Process(
197
+ self._proc = self._context.Process( # type: ignore[attr-defined]
188
198
  target=worker,
189
199
  args=(self._ping_interval, self._cpu_stats_smem, self._monitor_event, self._start_event, os.getpid()),
190
200
  daemon=True,
@@ -215,6 +225,7 @@ class CPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
215
225
 
216
226
  self._cpu_stats_monitor = CPUStatsMonitor(
217
227
  ping_interval=self.config.cpu_stats.ping_interval,
228
+ context=self._mp_ctx,
218
229
  manager=self._mp_manager,
219
230
  )
220
231
 
@@ -1,9 +1,9 @@
1
1
  """Defines a mixin for instantiating dataloaders."""
2
2
 
3
3
  import logging
4
- from abc import ABC, abstractmethod
4
+ from abc import ABC
5
5
  from dataclasses import dataclass
6
- from typing import Generic, TypeVar
6
+ from typing import Generic, Iterator, TypeVar
7
7
 
8
8
  import jax
9
9
  from dpshdl.dataloader import CollatedDataloaderItem, Dataloader
@@ -13,7 +13,7 @@ from omegaconf import II, MISSING
13
13
 
14
14
  from xax.core.conf import field, is_missing
15
15
  from xax.core.state import Phase
16
- from xax.nn.functions import recursive_apply, set_random_seed
16
+ from xax.nn.functions import set_random_seed
17
17
  from xax.task.base import BaseConfig, BaseTask
18
18
  from xax.task.mixins.process import ProcessConfig, ProcessMixin
19
19
  from xax.utils.logging import LOG_ERROR_SUMMARY, configure_logging
@@ -24,6 +24,7 @@ T = TypeVar("T")
24
24
  Tc_co = TypeVar("Tc_co", covariant=True)
25
25
 
26
26
 
27
+ @jax.tree_util.register_dataclass
27
28
  @dataclass
28
29
  class DataloaderErrorConfig:
29
30
  sleep_backoff: float = field(0.1, help="The initial sleep time after an exception")
@@ -36,6 +37,7 @@ class DataloaderErrorConfig:
36
37
  log_exceptions_all_workers: bool = field(False, help="If set, log exceptions from all workers")
37
38
 
38
39
 
40
+ @jax.tree_util.register_dataclass
39
41
  @dataclass
40
42
  class DataloaderConfig:
41
43
  num_workers: int | None = field(MISSING, help="Number of workers for loading samples")
@@ -43,6 +45,7 @@ class DataloaderConfig:
43
45
  error: DataloaderErrorConfig = field(DataloaderErrorConfig(), help="Dataloader error configuration")
44
46
 
45
47
 
48
+ @jax.tree_util.register_dataclass
46
49
  @dataclass
47
50
  class DataloadersConfig(ProcessConfig, BaseConfig):
48
51
  batch_size: int = field(MISSING, help="Size of each batch")
@@ -63,9 +66,6 @@ Config = TypeVar("Config", bound=DataloadersConfig)
63
66
 
64
67
  class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config], ABC):
65
68
  def __init__(self, config: Config) -> None:
66
- if is_missing(config, "batch_size"):
67
- config.batch_size = self.get_batch_size()
68
-
69
69
  super().__init__(config)
70
70
 
71
71
  def get_batch_size(self) -> int:
@@ -74,6 +74,12 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
74
74
  "method to return the desired training batch size."
75
75
  )
76
76
 
77
@property
def batch_size(self) -> int:
    """The training batch size.

    If ``config.batch_size`` is missing, it is computed once via
    ``get_batch_size()`` and written back onto the config, so later reads
    return the stored value.
    """
    if not is_missing(self.config, "batch_size"):
        return self.config.batch_size
    self.config.batch_size = self.get_batch_size()
    return self.config.batch_size
82
+
77
83
  def dataloader_config(self, phase: Phase) -> DataloaderConfig:
78
84
  match phase:
79
85
  case "train":
@@ -83,7 +89,6 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
83
89
  case _:
84
90
  raise KeyError(f"Unknown phase: {phase}")
85
91
 
86
- @abstractmethod
87
92
  def get_dataset(self, phase: Phase) -> Dataset:
88
93
  """Returns the dataset for the given phase.
89
94
 
@@ -93,6 +98,16 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
93
98
  Returns:
94
99
  The dataset for the given phase.
95
100
  """
101
+ raise NotImplementedError(
102
+ "You must implement either the `get_dataset` method to return the dataset for the given phase, "
103
+ "or `get_data_iterator` to return an iterator for the given dataset."
104
+ )
105
+
106
def get_data_iterator(self, phase: Phase) -> Iterator:
    """Returns an iterator for ``phase``; subclasses must override this.

    Tasks implement either this method or ``get_dataset``.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    message = (
        "You must implement either the `get_dataset` method to return the dataset for the given phase, "
        "or `get_data_iterator` to return an iterator for the given dataset."
    )
    raise NotImplementedError(message)
96
111
 
97
112
  def get_dataloader(self, dataset: Dataset[T, Tc_co], phase: Phase) -> Dataloader[T, Tc_co]:
98
113
  debugging = self.config.debug_dataloader
@@ -128,11 +143,7 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
128
143
  )
129
144
 
130
145
  def get_prefetcher(self, dataloader: Dataloader[T, Tc_co]) -> Prefetcher[Tc_co, Tc_co]:
131
- return Prefetcher(to_device_func=self.to_device_fn, dataloader=dataloader)
132
-
133
- @classmethod
134
- def to_device_fn(cls, sample: T) -> T:
135
- return recursive_apply(sample, jax.device_put, include_numpy=True)
146
+ return Prefetcher(to_device_func=jax.device_put, dataloader=dataloader)
136
147
 
137
148
  @classmethod
138
149
  def dataloader_worker_init_fn(cls, worker_id: int, num_workers: int) -> None:
@@ -6,17 +6,19 @@ This logs GPU memory and utilization in a background process using
6
6
 
7
7
  import functools
8
8
  import logging
9
- import multiprocessing as mp
10
9
  import os
11
10
  import re
12
11
  import shutil
13
12
  import subprocess
14
13
  from ctypes import Structure, c_double, c_uint32
15
14
  from dataclasses import dataclass
15
+ from multiprocessing.context import BaseContext, Process
16
16
  from multiprocessing.managers import SyncManager, ValueProxy
17
17
  from multiprocessing.synchronize import Event
18
18
  from typing import Generic, Iterable, Pattern, TypeVar
19
19
 
20
+ import jax
21
+
20
22
  from xax.core.conf import field
21
23
  from xax.core.state import State
22
24
  from xax.task.mixins.logger import LoggerConfig, LoggerMixin
@@ -25,12 +27,14 @@ from xax.task.mixins.process import ProcessConfig, ProcessMixin
25
27
  logger: logging.Logger = logging.getLogger(__name__)
26
28
 
27
29
 
30
@jax.tree_util.register_dataclass
@dataclass
class GPUStatsOptions:
    """Options controlling GPU stats collection and logging."""

    ping_interval: int = field(value=10, help="How often to check stats (in seconds)")
    only_log_once: bool = field(value=False, help="If set, only log read stats one time")
32
35
 
33
36
 
37
+ @jax.tree_util.register_dataclass
34
38
  @dataclass
35
39
  class GPUStatsConfig(ProcessConfig, LoggerConfig):
36
40
  gpu_stats: GPUStatsOptions = field(GPUStatsOptions(), help="GPU stats configuration")
@@ -147,8 +151,14 @@ def worker(
147
151
 
148
152
 
149
153
  class GPUStatsMonitor:
150
- def __init__(self, ping_interval: float, manager: SyncManager) -> None:
154
+ def __init__(
155
+ self,
156
+ ping_interval: float,
157
+ context: BaseContext,
158
+ manager: SyncManager,
159
+ ) -> None:
151
160
  self._ping_interval = ping_interval
161
+ self._context = context
152
162
  self._manager = manager
153
163
 
154
164
  num_gpus = get_num_gpus()
@@ -169,7 +179,7 @@ class GPUStatsMonitor:
169
179
  for i in range(num_gpus)
170
180
  ]
171
181
  self._gpu_stats: dict[int, GPUStatsInfo] = {}
172
- self._proc: mp.Process | None = None
182
+ self._proc: Process | None = None
173
183
 
174
184
  def get_if_set(self) -> dict[int, GPUStatsInfo]:
175
185
  gpu_stats: dict[int, GPUStatsInfo] = {}
@@ -196,7 +206,7 @@ class GPUStatsMonitor:
196
206
  if self._start_event.is_set():
197
207
  self._start_event.clear()
198
208
  self._gpu_stats.clear()
199
- self._proc = mp.Process(
209
+ self._proc = self._context.Process( # type: ignore[attr-defined]
200
210
  target=worker,
201
211
  args=(self._ping_interval, self._smems, self._main_event, self._events, self._start_event),
202
212
  daemon=True,
@@ -226,7 +236,11 @@ class GPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
226
236
 
227
237
  self._gpu_stats_monitor = None
228
238
  if shutil.which("nvidia-smi") is not None:
229
- self._gpu_stats_monitor = GPUStatsMonitor(config.gpu_stats.ping_interval, self._mp_manager)
239
+ self._gpu_stats_monitor = GPUStatsMonitor(
240
+ config.gpu_stats.ping_interval,
241
+ self._mp_ctx,
242
+ self._mp_manager,
243
+ )
230
244
 
231
245
  def on_training_start(self, state: State) -> State:
232
246
  state = super().on_training_start(state)
xax/task/mixins/logger.py CHANGED
@@ -6,7 +6,8 @@ from pathlib import Path
6
6
  from types import TracebackType
7
7
  from typing import Generic, Self, TypeVar
8
8
 
9
- from xax.core.conf import Device as BaseDeviceConfig, field
9
+ import jax
10
+
10
11
  from xax.core.state import State
11
12
  from xax.task.base import BaseConfig, BaseTask
12
13
  from xax.task.logger import Logger, LoggerImpl
@@ -18,9 +19,10 @@ from xax.task.mixins.artifacts import ArtifactsMixin
18
19
  from xax.utils.text import is_interactive_session
19
20
 
20
21
 
22
@jax.tree_util.register_dataclass
@dataclass
class LoggerConfig(BaseConfig):
    """Config section for the logger mixin; it currently declares no fields of its own."""
24
26
 
25
27
 
26
28
  Config = TypeVar("Config", bound=LoggerConfig)
@@ -7,6 +7,8 @@ from multiprocessing.context import BaseContext
7
7
  from multiprocessing.managers import SyncManager
8
8
  from typing import Generic, TypeVar
9
9
 
10
+ import jax
11
+
10
12
  from xax.core.conf import field
11
13
  from xax.core.state import State
12
14
  from xax.task.base import BaseConfig, BaseTask
@@ -14,9 +16,10 @@ from xax.task.base import BaseConfig, BaseTask
14
16
  logger: logging.Logger = logging.getLogger(__name__)
15
17
 
16
18
 
19
@jax.tree_util.register_dataclass
@dataclass
class ProcessConfig(BaseConfig):
    """Config section for multiprocessing settings."""

    # Defaults to "spawn". Presumably this is a start-method name handed to
    # `multiprocessing.get_context` by ProcessMixin (confirm there).
    multiprocessing_context: str | None = field(value="spawn", help="The multiprocessing context to use")
20
23
 
21
24
 
22
25
  Config = TypeVar("Config", bound=ProcessConfig)
@@ -6,10 +6,13 @@ from dataclasses import dataclass
6
6
  from types import FrameType
7
7
  from typing import Callable, TypeVar
8
8
 
9
+ import jax
10
+
9
11
  from xax.task.base import BaseConfig, BaseTask, RawConfigType
10
12
  from xax.task.launchers.base import BaseLauncher
11
13
 
12
14
 
15
@jax.tree_util.register_dataclass
@dataclass
class RunnableConfig(BaseConfig):
    """Config section for runnable tasks; it currently declares no fields of its own."""
@@ -4,6 +4,9 @@ from dataclasses import dataclass
4
4
  from types import TracebackType
5
5
  from typing import ContextManager, Literal, TypeVar
6
6
 
7
+ import equinox as eqx
8
+ import jax
9
+
7
10
  from xax.task.base import BaseConfig, BaseTask
8
11
 
9
12
  StepType = Literal[
@@ -47,6 +50,7 @@ class StepContext(ContextManager):
47
50
  StepContext.CURRENT_STEP = None
48
51
 
49
52
 
53
@jax.tree_util.register_dataclass
@dataclass
class StepContextConfig(BaseConfig):
    """Config section for the step-context mixin; it currently declares no fields of its own."""
@@ -59,5 +63,6 @@ class StepContextMixin(BaseTask[Config]):
59
63
  def __init__(self, config: Config) -> None:
60
64
  super().__init__(config)
61
65
 
66
def step_context(self, step: StepType) -> ContextManager:
    """Returns a context manager that marks ``step`` as the current step.

    This method is intentionally not wrapped in ``eqx.filter_jit``, because it
    only builds a host-side Python object and there is nothing to compile.
    Under JIT, the body would run only while tracing (once per distinct
    ``step`` string, each time paying a compile). Later calls would then
    return the ``StepContext`` captured during that trace rather than a
    fresh one.

    Args:
        step: The step type to mark as active.

    Returns:
        A new ``StepContext`` for ``step``.
    """
    return StepContext(step)