xax 0.0.3__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -6,15 +6,16 @@ leaks in your dataloader, among other issues.
6
6
  """
7
7
 
8
8
  import logging
9
- import multiprocessing as mp
10
9
  import os
11
10
  import time
12
11
  from ctypes import Structure, c_double, c_uint16, c_uint64
13
12
  from dataclasses import dataclass
13
+ from multiprocessing.context import BaseContext, Process
14
14
  from multiprocessing.managers import SyncManager, ValueProxy
15
15
  from multiprocessing.synchronize import Event
16
16
  from typing import Generic, TypeVar
17
17
 
18
+ import jax
18
19
  import psutil
19
20
 
20
21
  from xax.core.conf import field
@@ -26,12 +27,14 @@ from xax.task.mixins.process import ProcessConfig, ProcessMixin
26
27
  logger: logging.Logger = logging.getLogger(__name__)
27
28
 
28
29
 
30
+ @jax.tree_util.register_dataclass
29
31
  @dataclass
30
32
  class CPUStatsOptions:
31
33
  ping_interval: int = field(1, help="How often to check stats (in seconds)")
32
34
  only_log_once: bool = field(False, help="If set, only log read stats one time")
33
35
 
34
36
 
37
+ @jax.tree_util.register_dataclass
35
38
  @dataclass
36
39
  class CPUStatsConfig(ProcessConfig, LoggerConfig, BaseConfig):
37
40
  cpu_stats: CPUStatsOptions = field(CPUStatsOptions(), help="CPU stats configuration")
@@ -55,7 +58,7 @@ class CPUStats(Structure):
55
58
  ]
56
59
 
57
60
 
58
- @dataclass
61
+ @dataclass(kw_only=True)
59
62
  class CPUStatsInfo:
60
63
  cpu_percent: float
61
64
  mem_percent: float
@@ -142,9 +145,16 @@ def worker(
142
145
 
143
146
 
144
147
  class CPUStatsMonitor:
145
- def __init__(self, ping_interval: float, manager: SyncManager) -> None:
148
+ def __init__(
149
+ self,
150
+ ping_interval: float,
151
+ context: BaseContext,
152
+ manager: SyncManager,
153
+ ) -> None:
146
154
  self._ping_interval = ping_interval
147
155
  self._manager = manager
156
+ self._context = context
157
+
148
158
  self._monitor_event = self._manager.Event()
149
159
  self._start_event = self._manager.Event()
150
160
  self._cpu_stats_smem = self._manager.Value(
@@ -163,7 +173,7 @@ class CPUStatsMonitor:
163
173
  ),
164
174
  )
165
175
  self._cpu_stats: CPUStatsInfo | None = None
166
- self._proc: mp.Process | None = None
176
+ self._proc: Process | None = None
167
177
 
168
178
  def get_if_set(self) -> CPUStatsInfo | None:
169
179
  if self._monitor_event.is_set():
@@ -184,7 +194,7 @@ class CPUStatsMonitor:
184
194
  if self._start_event.is_set():
185
195
  self._start_event.clear()
186
196
  self._cpu_stats = None
187
- self._proc = mp.Process(
197
+ self._proc = self._context.Process( # type: ignore[attr-defined]
188
198
  target=worker,
189
199
  args=(self._ping_interval, self._cpu_stats_smem, self._monitor_event, self._start_event, os.getpid()),
190
200
  daemon=True,
@@ -215,6 +225,7 @@ class CPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
215
225
 
216
226
  self._cpu_stats_monitor = CPUStatsMonitor(
217
227
  ping_interval=self.config.cpu_stats.ping_interval,
228
+ context=self._mp_ctx,
218
229
  manager=self._mp_manager,
219
230
  )
220
231
 
@@ -237,15 +248,15 @@ class CPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
237
248
  stats = monitor.get_if_set() if self.config.cpu_stats.only_log_once else monitor.get()
238
249
 
239
250
  if stats is not None:
240
- self.log_scalar("child_procs", stats.num_child_procs, namespace="🔧 cpu")
241
- self.log_scalar("percent", stats.cpu_percent, namespace="🔧 cpu")
242
- self.log_scalar("child_percent", stats.child_cpu_percent, namespace="🔧 cpu")
243
- self.log_scalar("percent", stats.mem_percent, namespace="🔧 mem")
244
- self.log_scalar("shared", stats.mem_shared, namespace="🔧 mem")
245
- self.log_scalar("child_percent", stats.child_mem_percent, namespace="🔧 mem")
246
- self.log_scalar("rss/cur", stats.mem_rss, namespace="🔧 mem")
247
- self.log_scalar("rss/total", stats.mem_rss_total, namespace="🔧 mem")
248
- self.log_scalar("vms/cur", stats.mem_vms, namespace="🔧 mem")
249
- self.log_scalar("vms/total", stats.mem_vms_total, namespace="🔧 mem")
251
+ self.logger.log_scalar("child_procs", stats.num_child_procs, namespace="🔧 cpu")
252
+ self.logger.log_scalar("percent", stats.cpu_percent, namespace="🔧 cpu")
253
+ self.logger.log_scalar("child_percent", stats.child_cpu_percent, namespace="🔧 cpu")
254
+ self.logger.log_scalar("percent", stats.mem_percent, namespace="🔧 mem")
255
+ self.logger.log_scalar("shared", stats.mem_shared, namespace="🔧 mem")
256
+ self.logger.log_scalar("child_percent", stats.child_mem_percent, namespace="🔧 mem")
257
+ self.logger.log_scalar("rss/cur", stats.mem_rss, namespace="🔧 mem")
258
+ self.logger.log_scalar("rss/total", stats.mem_rss_total, namespace="🔧 mem")
259
+ self.logger.log_scalar("vms/cur", stats.mem_vms, namespace="🔧 mem")
260
+ self.logger.log_scalar("vms/total", stats.mem_vms_total, namespace="🔧 mem")
250
261
 
251
262
  return state
@@ -1,9 +1,9 @@
1
1
  """Defines a mixin for instantiating dataloaders."""
2
2
 
3
3
  import logging
4
- from abc import ABC, abstractmethod
4
+ from abc import ABC
5
5
  from dataclasses import dataclass
6
- from typing import Generic, TypeVar
6
+ from typing import Generic, Iterator, TypeVar
7
7
 
8
8
  import jax
9
9
  from dpshdl.dataloader import CollatedDataloaderItem, Dataloader
@@ -13,7 +13,7 @@ from omegaconf import II, MISSING
13
13
 
14
14
  from xax.core.conf import field, is_missing
15
15
  from xax.core.state import Phase
16
- from xax.nn.functions import recursive_apply, set_random_seed
16
+ from xax.nn.functions import set_random_seed
17
17
  from xax.task.base import BaseConfig, BaseTask
18
18
  from xax.task.mixins.process import ProcessConfig, ProcessMixin
19
19
  from xax.utils.logging import LOG_ERROR_SUMMARY, configure_logging
@@ -24,6 +24,7 @@ T = TypeVar("T")
24
24
  Tc_co = TypeVar("Tc_co", covariant=True)
25
25
 
26
26
 
27
+ @jax.tree_util.register_dataclass
27
28
  @dataclass
28
29
  class DataloaderErrorConfig:
29
30
  sleep_backoff: float = field(0.1, help="The initial sleep time after an exception")
@@ -36,24 +37,25 @@ class DataloaderErrorConfig:
36
37
  log_exceptions_all_workers: bool = field(False, help="If set, log exceptions from all workers")
37
38
 
38
39
 
40
+ @jax.tree_util.register_dataclass
39
41
  @dataclass
40
42
  class DataloaderConfig:
41
- batch_size: int = field(MISSING, help="Size of each batch")
42
43
  num_workers: int | None = field(MISSING, help="Number of workers for loading samples")
43
44
  prefetch_factor: int = field(2, help="Number of items to pre-fetch on each worker")
44
45
  error: DataloaderErrorConfig = field(DataloaderErrorConfig(), help="Dataloader error configuration")
45
46
 
46
47
 
48
+ @jax.tree_util.register_dataclass
47
49
  @dataclass
48
50
  class DataloadersConfig(ProcessConfig, BaseConfig):
49
51
  batch_size: int = field(MISSING, help="Size of each batch")
50
52
  raise_dataloader_errors: bool = field(False, help="If set, raise dataloader errors inside the worker processes")
51
53
  train_dl: DataloaderConfig = field(
52
- DataloaderConfig(batch_size=II("batch_size")),
54
+ DataloaderConfig(num_workers=II("mlfab.num_workers:-1")),
53
55
  help="Train dataloader config",
54
56
  )
55
57
  valid_dl: DataloaderConfig = field(
56
- DataloaderConfig(batch_size=II("batch_size"), num_workers=1),
58
+ DataloaderConfig(num_workers=1),
57
59
  help="Valid dataloader config",
58
60
  )
59
61
  debug_dataloader: bool = field(False, help="Debug dataloaders")
@@ -64,11 +66,6 @@ Config = TypeVar("Config", bound=DataloadersConfig)
64
66
 
65
67
  class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config], ABC):
66
68
  def __init__(self, config: Config) -> None:
67
- if is_missing(config, "batch_size") and (
68
- is_missing(config.train_dl, "batch_size") or is_missing(config.valid_dl, "batch_size")
69
- ):
70
- config.batch_size = self.get_batch_size()
71
-
72
69
  super().__init__(config)
73
70
 
74
71
  def get_batch_size(self) -> int:
@@ -77,6 +74,12 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
77
74
  "method to return the desired training batch size."
78
75
  )
79
76
 
77
+ @property
78
+ def batch_size(self) -> int:
79
+ if is_missing(self.config, "batch_size"):
80
+ self.config.batch_size = self.get_batch_size()
81
+ return self.config.batch_size
82
+
80
83
  def dataloader_config(self, phase: Phase) -> DataloaderConfig:
81
84
  match phase:
82
85
  case "train":
@@ -86,7 +89,6 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
86
89
  case _:
87
90
  raise KeyError(f"Unknown phase: {phase}")
88
91
 
89
- @abstractmethod
90
92
  def get_dataset(self, phase: Phase) -> Dataset:
91
93
  """Returns the dataset for the given phase.
92
94
 
@@ -96,6 +98,16 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
96
98
  Returns:
97
99
  The dataset for the given phase.
98
100
  """
101
+ raise NotImplementedError(
102
+ "You must implement either the `get_dataset` method to return the dataset for the given phase, "
103
+ "or `get_data_iterator` to return an iterator for the given dataset."
104
+ )
105
+
106
+ def get_data_iterator(self, phase: Phase) -> Iterator:
107
+ raise NotImplementedError(
108
+ "You must implement either the `get_dataset` method to return the dataset for the given phase, "
109
+ "or `get_data_iterator` to return an iterator for the given dataset."
110
+ )
99
111
 
100
112
  def get_dataloader(self, dataset: Dataset[T, Tc_co], phase: Phase) -> Dataloader[T, Tc_co]:
101
113
  debugging = self.config.debug_dataloader
@@ -120,10 +132,10 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
120
132
 
121
133
  return Dataloader(
122
134
  dataset=dataset,
123
- batch_size=cfg.batch_size,
135
+ batch_size=self.config.batch_size,
124
136
  num_workers=0 if debugging else cfg.num_workers,
125
137
  prefetch_factor=cfg.prefetch_factor,
126
- ctx=self.multiprocessing_context,
138
+ mp_manager=self.multiprocessing_manager,
127
139
  dataloader_worker_init_fn=self.dataloader_worker_init_fn,
128
140
  collate_worker_init_fn=self.collate_worker_init_fn,
129
141
  item_callback=self.dataloader_item_callback,
@@ -131,11 +143,7 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
131
143
  )
132
144
 
133
145
  def get_prefetcher(self, dataloader: Dataloader[T, Tc_co]) -> Prefetcher[Tc_co, Tc_co]:
134
- return Prefetcher(to_device_func=self.to_device_fn, dataloader=dataloader)
135
-
136
- @classmethod
137
- def to_device_fn(cls, sample: T) -> T:
138
- return recursive_apply(sample, jax.device_put)
146
+ return Prefetcher(to_device_func=jax.device_put, dataloader=dataloader)
139
147
 
140
148
  @classmethod
141
149
  def dataloader_worker_init_fn(cls, worker_id: int, num_workers: int) -> None:
@@ -6,17 +6,19 @@ This logs GPU memory and utilization in a background process using
6
6
 
7
7
  import functools
8
8
  import logging
9
- import multiprocessing as mp
10
9
  import os
11
10
  import re
12
11
  import shutil
13
12
  import subprocess
14
13
  from ctypes import Structure, c_double, c_uint32
15
14
  from dataclasses import dataclass
15
+ from multiprocessing.context import BaseContext, Process
16
16
  from multiprocessing.managers import SyncManager, ValueProxy
17
17
  from multiprocessing.synchronize import Event
18
18
  from typing import Generic, Iterable, Pattern, TypeVar
19
19
 
20
+ import jax
21
+
20
22
  from xax.core.conf import field
21
23
  from xax.core.state import State
22
24
  from xax.task.mixins.logger import LoggerConfig, LoggerMixin
@@ -25,12 +27,14 @@ from xax.task.mixins.process import ProcessConfig, ProcessMixin
25
27
  logger: logging.Logger = logging.getLogger(__name__)
26
28
 
27
29
 
30
+ @jax.tree_util.register_dataclass
28
31
  @dataclass
29
32
  class GPUStatsOptions:
30
33
  ping_interval: int = field(10, help="How often to check stats (in seconds)")
31
34
  only_log_once: bool = field(False, help="If set, only log read stats one time")
32
35
 
33
36
 
37
+ @jax.tree_util.register_dataclass
34
38
  @dataclass
35
39
  class GPUStatsConfig(ProcessConfig, LoggerConfig):
36
40
  gpu_stats: GPUStatsOptions = field(GPUStatsOptions(), help="GPU stats configuration")
@@ -147,8 +151,14 @@ def worker(
147
151
 
148
152
 
149
153
  class GPUStatsMonitor:
150
- def __init__(self, ping_interval: float, manager: SyncManager) -> None:
154
+ def __init__(
155
+ self,
156
+ ping_interval: float,
157
+ context: BaseContext,
158
+ manager: SyncManager,
159
+ ) -> None:
151
160
  self._ping_interval = ping_interval
161
+ self._context = context
152
162
  self._manager = manager
153
163
 
154
164
  num_gpus = get_num_gpus()
@@ -169,7 +179,7 @@ class GPUStatsMonitor:
169
179
  for i in range(num_gpus)
170
180
  ]
171
181
  self._gpu_stats: dict[int, GPUStatsInfo] = {}
172
- self._proc: mp.Process | None = None
182
+ self._proc: Process | None = None
173
183
 
174
184
  def get_if_set(self) -> dict[int, GPUStatsInfo]:
175
185
  gpu_stats: dict[int, GPUStatsInfo] = {}
@@ -196,7 +206,7 @@ class GPUStatsMonitor:
196
206
  if self._start_event.is_set():
197
207
  self._start_event.clear()
198
208
  self._gpu_stats.clear()
199
- self._proc = mp.Process(
209
+ self._proc = self._context.Process( # type: ignore[attr-defined]
200
210
  target=worker,
201
211
  args=(self._ping_interval, self._smems, self._main_event, self._events, self._start_event),
202
212
  daemon=True,
@@ -226,7 +236,11 @@ class GPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
226
236
 
227
237
  self._gpu_stats_monitor = None
228
238
  if shutil.which("nvidia-smi") is not None:
229
- self._gpu_stats_monitor = GPUStatsMonitor(config.gpu_stats.ping_interval, self._mp_manager)
239
+ self._gpu_stats_monitor = GPUStatsMonitor(
240
+ config.gpu_stats.ping_interval,
241
+ self._mp_ctx,
242
+ self._mp_manager,
243
+ )
230
244
 
231
245
  def on_training_start(self, state: State) -> State:
232
246
  state = super().on_training_start(state)
@@ -250,8 +264,8 @@ class GPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
250
264
  for gpu_stat in stats.values():
251
265
  if gpu_stat is None:
252
266
  continue
253
- self.log_scalar(f"mem/{gpu_stat.index}", gpu_stat.memory_used, namespace="🔧 gpu")
254
- self.log_scalar(f"temp/{gpu_stat.index}", gpu_stat.temperature, namespace="🔧 gpu")
255
- self.log_scalar(f"util/{gpu_stat.index}", gpu_stat.utilization, namespace="🔧 gpu")
267
+ self.logger.log_scalar(f"mem/{gpu_stat.index}", gpu_stat.memory_used, namespace="🔧 gpu")
268
+ self.logger.log_scalar(f"temp/{gpu_stat.index}", gpu_stat.temperature, namespace="🔧 gpu")
269
+ self.logger.log_scalar(f"util/{gpu_stat.index}", gpu_stat.utilization, namespace="🔧 gpu")
256
270
 
257
271
  return state
xax/task/mixins/logger.py CHANGED
@@ -4,14 +4,13 @@ import os
4
4
  from dataclasses import dataclass
5
5
  from pathlib import Path
6
6
  from types import TracebackType
7
- from typing import Callable, Generic, Self, Sequence, TypeVar
7
+ from typing import Generic, Self, TypeVar
8
8
 
9
- from jaxtyping import Array
9
+ import jax
10
10
 
11
- from xax.core.conf import Device as BaseDeviceConfig, field
12
11
  from xax.core.state import State
13
12
  from xax.task.base import BaseConfig, BaseTask
14
- from xax.task.logger import ChannelSelectMode, Logger, LoggerImpl, Number
13
+ from xax.task.logger import Logger, LoggerImpl
15
14
  from xax.task.loggers.json import JsonLogger
16
15
  from xax.task.loggers.state import StateLogger
17
16
  from xax.task.loggers.stdout import StdoutLogger
@@ -20,9 +19,10 @@ from xax.task.mixins.artifacts import ArtifactsMixin
20
19
  from xax.utils.text import is_interactive_session
21
20
 
22
21
 
22
+ @jax.tree_util.register_dataclass
23
23
  @dataclass
24
24
  class LoggerConfig(BaseConfig):
25
- device: BaseDeviceConfig = field(BaseDeviceConfig(), help="Device configuration")
25
+ pass
26
26
 
27
27
 
28
28
  Config = TypeVar("Config", bound=LoggerConfig)
@@ -59,252 +59,6 @@ class LoggerMixin(BaseTask[Config], Generic[Config]):
59
59
  def write_logs(self, state: State) -> None:
60
60
  self.logger.write(state)
61
61
 
62
- def log_scalar(self, key: str, value: Callable[[], Number] | Number, *, namespace: str | None = None) -> None:
63
- self.logger.log_scalar(key, value, namespace=namespace)
64
-
65
- def log_string(self, key: str, value: Callable[[], str] | str, *, namespace: str | None = None) -> None:
66
- self.logger.log_string(key, value, namespace=namespace)
67
-
68
- def log_image(
69
- self,
70
- key: str,
71
- value: Callable[[], Array] | Array,
72
- *,
73
- namespace: str | None = None,
74
- keep_resolution: bool = False,
75
- ) -> None:
76
- self.logger.log_image(
77
- key,
78
- value,
79
- namespace=namespace,
80
- keep_resolution=keep_resolution,
81
- )
82
-
83
- def log_labeled_image(
84
- self,
85
- key: str,
86
- value: Callable[[], tuple[Array, str]] | tuple[Array, str],
87
- *,
88
- namespace: str | None = None,
89
- max_line_length: int | None = None,
90
- keep_resolution: bool = False,
91
- centered: bool = True,
92
- ) -> None:
93
- self.logger.log_labeled_image(
94
- key,
95
- value,
96
- namespace=namespace,
97
- max_line_length=max_line_length,
98
- keep_resolution=keep_resolution,
99
- centered=centered,
100
- )
101
-
102
- def log_images(
103
- self,
104
- key: str,
105
- value: Callable[[], Array] | Array,
106
- *,
107
- namespace: str | None = None,
108
- keep_resolution: bool = False,
109
- max_images: int | None = None,
110
- sep: int = 0,
111
- ) -> None:
112
- self.logger.log_images(
113
- key,
114
- value,
115
- namespace=namespace,
116
- keep_resolution=keep_resolution,
117
- max_images=max_images,
118
- sep=sep,
119
- )
120
-
121
- def log_labeled_images(
122
- self,
123
- key: str,
124
- value: Callable[[], tuple[Array, Sequence[str]]] | tuple[Array, Sequence[str]],
125
- *,
126
- namespace: str | None = None,
127
- max_line_length: int | None = None,
128
- keep_resolution: bool = False,
129
- max_images: int | None = None,
130
- sep: int = 0,
131
- centered: bool = True,
132
- ) -> None:
133
- self.logger.log_labeled_images(
134
- key,
135
- value,
136
- namespace=namespace,
137
- max_line_length=max_line_length,
138
- keep_resolution=keep_resolution,
139
- max_images=max_images,
140
- sep=sep,
141
- centered=centered,
142
- )
143
-
144
- def log_audio(
145
- self,
146
- key: str,
147
- value: Callable[[], Array] | Array,
148
- *,
149
- namespace: str | None = None,
150
- sample_rate: int = 44100,
151
- log_spec: bool = True,
152
- n_fft_ms: float = 32.0,
153
- hop_length_ms: float | None = None,
154
- channel_select_mode: ChannelSelectMode = "first",
155
- keep_resolution: bool = False,
156
- ) -> None:
157
- self.logger.log_audio(
158
- key,
159
- value,
160
- namespace=namespace,
161
- sample_rate=sample_rate,
162
- log_spec=log_spec,
163
- n_fft_ms=n_fft_ms,
164
- hop_length_ms=hop_length_ms,
165
- channel_select_mode=channel_select_mode,
166
- keep_resolution=keep_resolution,
167
- )
168
-
169
- def log_audios(
170
- self,
171
- key: str,
172
- value: Callable[[], Array] | Array,
173
- *,
174
- namespace: str | None = None,
175
- sep_ms: float = 0.0,
176
- max_audios: int | None = None,
177
- sample_rate: int = 44100,
178
- log_spec: bool = True,
179
- n_fft_ms: float = 32.0,
180
- hop_length_ms: float | None = None,
181
- channel_select_mode: ChannelSelectMode = "first",
182
- spec_sep: int = 0,
183
- keep_resolution: bool = False,
184
- ) -> None:
185
- self.logger.log_audios(
186
- key,
187
- value,
188
- namespace=namespace,
189
- sep_ms=sep_ms,
190
- max_audios=max_audios,
191
- sample_rate=sample_rate,
192
- log_spec=log_spec,
193
- n_fft_ms=n_fft_ms,
194
- hop_length_ms=hop_length_ms,
195
- channel_select_mode=channel_select_mode,
196
- spec_sep=spec_sep,
197
- keep_resolution=keep_resolution,
198
- )
199
-
200
- def log_spectrogram(
201
- self,
202
- key: str,
203
- value: Callable[[], Array] | Array,
204
- *,
205
- namespace: str | None = None,
206
- sample_rate: int = 44100,
207
- n_fft_ms: float = 32.0,
208
- hop_length_ms: float | None = None,
209
- channel_select_mode: ChannelSelectMode = "first",
210
- keep_resolution: bool = False,
211
- ) -> None:
212
- self.logger.log_spectrogram(
213
- key,
214
- value,
215
- namespace=namespace,
216
- sample_rate=sample_rate,
217
- n_fft_ms=n_fft_ms,
218
- hop_length_ms=hop_length_ms,
219
- channel_select_mode=channel_select_mode,
220
- keep_resolution=keep_resolution,
221
- )
222
-
223
- def log_spectrograms(
224
- self,
225
- key: str,
226
- value: Callable[[], Array] | Array,
227
- *,
228
- namespace: str | None = None,
229
- max_audios: int | None = None,
230
- sample_rate: int = 44100,
231
- n_fft_ms: float = 32.0,
232
- hop_length_ms: float | None = None,
233
- channel_select_mode: ChannelSelectMode = "first",
234
- spec_sep: int = 0,
235
- keep_resolution: bool = False,
236
- ) -> None:
237
- self.logger.log_spectrograms(
238
- key,
239
- value,
240
- namespace=namespace,
241
- max_audios=max_audios,
242
- sample_rate=sample_rate,
243
- n_fft_ms=n_fft_ms,
244
- hop_length_ms=hop_length_ms,
245
- channel_select_mode=channel_select_mode,
246
- spec_sep=spec_sep,
247
- keep_resolution=keep_resolution,
248
- )
249
-
250
- def log_video(
251
- self,
252
- key: str,
253
- value: Callable[[], Array] | Array,
254
- *,
255
- namespace: str | None = None,
256
- fps: int | None = None,
257
- length: float | None = None,
258
- ) -> None:
259
- self.logger.log_video(
260
- key,
261
- value,
262
- namespace=namespace,
263
- fps=fps,
264
- length=length,
265
- )
266
-
267
- def log_videos(
268
- self,
269
- key: str,
270
- value: Callable[[], Array | list[Array]] | Array | list[Array],
271
- *,
272
- namespace: str | None = None,
273
- max_videos: int | None = None,
274
- sep: int = 0,
275
- fps: int | None = None,
276
- length: int | None = None,
277
- ) -> None:
278
- self.logger.log_videos(
279
- key,
280
- value,
281
- namespace=namespace,
282
- max_videos=max_videos,
283
- sep=sep,
284
- fps=fps,
285
- length=length,
286
- )
287
-
288
- def log_histogram(self, key: str, value: Callable[[], Array] | Array, *, namespace: str | None = None) -> None:
289
- self.logger.log_histogram(key, value, namespace=namespace)
290
-
291
- def log_point_cloud(
292
- self,
293
- key: str,
294
- value: Callable[[], Array] | Array,
295
- *,
296
- namespace: str | None = None,
297
- max_points: int = 1000,
298
- colors: Callable[[], Array] | Array | None = None,
299
- ) -> None:
300
- self.logger.log_point_cloud(
301
- key,
302
- value,
303
- namespace=namespace,
304
- max_points=max_points,
305
- colors=colors,
306
- )
307
-
308
62
  def __enter__(self) -> Self:
309
63
  self.logger.__enter__()
310
64
  return self
@@ -7,6 +7,8 @@ from multiprocessing.context import BaseContext
7
7
  from multiprocessing.managers import SyncManager
8
8
  from typing import Generic, TypeVar
9
9
 
10
+ import jax
11
+
10
12
  from xax.core.conf import field
11
13
  from xax.core.state import State
12
14
  from xax.task.base import BaseConfig, BaseTask
@@ -14,9 +16,10 @@ from xax.task.base import BaseConfig, BaseTask
14
16
  logger: logging.Logger = logging.getLogger(__name__)
15
17
 
16
18
 
19
+ @jax.tree_util.register_dataclass
17
20
  @dataclass
18
21
  class ProcessConfig(BaseConfig):
19
- multiprocessing_context: str | None = field(None, help="The multiprocessing context to use")
22
+ multiprocessing_context: str | None = field("spawn", help="The multiprocessing context to use")
20
23
 
21
24
 
22
25
  Config = TypeVar("Config", bound=ProcessConfig)
@@ -38,6 +41,10 @@ class ProcessMixin(BaseTask[Config], Generic[Config]):
38
41
  def multiprocessing_context(self) -> BaseContext:
39
42
  return self._mp_ctx
40
43
 
44
+ @property
45
+ def multiprocessing_manager(self) -> SyncManager:
46
+ return self._mp_manager
47
+
41
48
  def on_training_end(self, state: State) -> State:
42
49
  state = super().on_training_end(state)
43
50
 
@@ -6,10 +6,13 @@ from dataclasses import dataclass
6
6
  from types import FrameType
7
7
  from typing import Callable, TypeVar
8
8
 
9
+ import jax
10
+
9
11
  from xax.task.base import BaseConfig, BaseTask, RawConfigType
10
12
  from xax.task.launchers.base import BaseLauncher
11
13
 
12
14
 
15
+ @jax.tree_util.register_dataclass
13
16
  @dataclass
14
17
  class RunnableConfig(BaseConfig):
15
18
  pass
@@ -4,6 +4,9 @@ from dataclasses import dataclass
4
4
  from types import TracebackType
5
5
  from typing import ContextManager, Literal, TypeVar
6
6
 
7
+ import equinox as eqx
8
+ import jax
9
+
7
10
  from xax.task.base import BaseConfig, BaseTask
8
11
 
9
12
  StepType = Literal[
@@ -47,6 +50,7 @@ class StepContext(ContextManager):
47
50
  StepContext.CURRENT_STEP = None
48
51
 
49
52
 
53
+ @jax.tree_util.register_dataclass
50
54
  @dataclass
51
55
  class StepContextConfig(BaseConfig):
52
56
  pass
@@ -59,5 +63,6 @@ class StepContextMixin(BaseTask[Config]):
59
63
  def __init__(self, config: Config) -> None:
60
64
  super().__init__(config)
61
65
 
66
+ @eqx.filter_jit
62
67
  def step_context(self, step: StepType) -> ContextManager:
63
68
  return StepContext(step)