xax-0.0.3-py3-none-any.whl → xax-0.0.6-py3-none-any.whl
- xax/__init__.py +122 -8
- xax/core/conf.py +9 -33
- xax/core/state.py +13 -23
- xax/nn/embeddings.py +355 -0
- xax/nn/functions.py +8 -4
- xax/requirements-dev.txt +9 -1
- xax/requirements.txt +17 -10
- xax/task/base.py +2 -6
- xax/task/logger.py +419 -412
- xax/task/loggers/callback.py +44 -0
- xax/task/loggers/state.py +5 -18
- xax/task/loggers/tensorboard.py +16 -33
- xax/task/mixins/__init__.py +3 -1
- xax/task/mixins/artifacts.py +19 -9
- xax/task/mixins/checkpointing.py +221 -0
- xax/task/mixins/compile.py +104 -0
- xax/task/mixins/cpu_stats.py +26 -15
- xax/task/mixins/data_loader.py +27 -19
- xax/task/mixins/gpu_stats.py +22 -8
- xax/task/mixins/logger.py +5 -251
- xax/task/mixins/process.py +8 -1
- xax/task/mixins/runnable.py +3 -0
- xax/task/mixins/step_wrapper.py +5 -0
- xax/task/mixins/train.py +236 -145
- xax/task/script.py +1 -1
- xax/task/task.py +13 -5
- xax/utils/data/collate.py +6 -6
- xax/utils/experiments.py +45 -1
- xax/utils/logging.py +29 -0
- xax/utils/tensorboard.py +89 -21
- xax-0.0.6.dist-info/METADATA +50 -0
- xax-0.0.6.dist-info/RECORD +52 -0
- {xax-0.0.3.dist-info → xax-0.0.6.dist-info}/WHEEL +1 -1
- xax/task/launchers/staged.py +0 -29
- xax-0.0.3.dist-info/METADATA +0 -39
- xax-0.0.3.dist-info/RECORD +0 -49
- {xax-0.0.3.dist-info → xax-0.0.6.dist-info}/LICENSE +0 -0
- {xax-0.0.3.dist-info → xax-0.0.6.dist-info}/top_level.txt +0 -0
xax/task/mixins/cpu_stats.py
CHANGED
@@ -6,15 +6,16 @@ leaks in your dataloader, among other issues.
 """
 
 import logging
-import multiprocessing as mp
 import os
 import time
 from ctypes import Structure, c_double, c_uint16, c_uint64
 from dataclasses import dataclass
+from multiprocessing.context import BaseContext, Process
 from multiprocessing.managers import SyncManager, ValueProxy
 from multiprocessing.synchronize import Event
 from typing import Generic, TypeVar
 
+import jax
 import psutil
 
 from xax.core.conf import field
@@ -26,12 +27,14 @@ from xax.task.mixins.process import ProcessConfig, ProcessMixin
 logger: logging.Logger = logging.getLogger(__name__)
 
 
+@jax.tree_util.register_dataclass
 @dataclass
 class CPUStatsOptions:
     ping_interval: int = field(1, help="How often to check stats (in seconds)")
     only_log_once: bool = field(False, help="If set, only log read stats one time")
 
 
+@jax.tree_util.register_dataclass
 @dataclass
 class CPUStatsConfig(ProcessConfig, LoggerConfig, BaseConfig):
     cpu_stats: CPUStatsOptions = field(CPUStatsOptions(), help="CPU stats configuration")
@@ -55,7 +58,7 @@ class CPUStats(Structure):
     ]
 
 
-@dataclass
+@dataclass(kw_only=True)
 class CPUStatsInfo:
     cpu_percent: float
     mem_percent: float
@@ -142,9 +145,16 @@ def worker(
 
 
 class CPUStatsMonitor:
-    def __init__(
+    def __init__(
+        self,
+        ping_interval: float,
+        context: BaseContext,
+        manager: SyncManager,
+    ) -> None:
        self._ping_interval = ping_interval
        self._manager = manager
+        self._context = context
+
        self._monitor_event = self._manager.Event()
        self._start_event = self._manager.Event()
        self._cpu_stats_smem = self._manager.Value(
@@ -163,7 +173,7 @@ class CPUStatsMonitor:
            ),
        )
        self._cpu_stats: CPUStatsInfo | None = None
-        self._proc:
+        self._proc: Process | None = None
 
    def get_if_set(self) -> CPUStatsInfo | None:
        if self._monitor_event.is_set():
@@ -184,7 +194,7 @@ class CPUStatsMonitor:
        if self._start_event.is_set():
            self._start_event.clear()
        self._cpu_stats = None
-        self._proc =
+        self._proc = self._context.Process(  # type: ignore[attr-defined]
            target=worker,
            args=(self._ping_interval, self._cpu_stats_smem, self._monitor_event, self._start_event, os.getpid()),
            daemon=True,
@@ -215,6 +225,7 @@ class CPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
 
        self._cpu_stats_monitor = CPUStatsMonitor(
            ping_interval=self.config.cpu_stats.ping_interval,
+            context=self._mp_ctx,
            manager=self._mp_manager,
        )
 
@@ -237,15 +248,15 @@ class CPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
        stats = monitor.get_if_set() if self.config.cpu_stats.only_log_once else monitor.get()
 
        if stats is not None:
-            self.log_scalar("child_procs", stats.num_child_procs, namespace="🔧 cpu")
-            self.log_scalar("percent", stats.cpu_percent, namespace="🔧 cpu")
-            self.log_scalar("child_percent", stats.child_cpu_percent, namespace="🔧 cpu")
-            self.log_scalar("percent", stats.mem_percent, namespace="🔧 mem")
-            self.log_scalar("shared", stats.mem_shared, namespace="🔧 mem")
-            self.log_scalar("child_percent", stats.child_mem_percent, namespace="🔧 mem")
-            self.log_scalar("rss/cur", stats.mem_rss, namespace="🔧 mem")
-            self.log_scalar("rss/total", stats.mem_rss_total, namespace="🔧 mem")
-            self.log_scalar("vms/cur", stats.mem_vms, namespace="🔧 mem")
-            self.log_scalar("vms/total", stats.mem_vms_total, namespace="🔧 mem")
+            self.logger.log_scalar("child_procs", stats.num_child_procs, namespace="🔧 cpu")
+            self.logger.log_scalar("percent", stats.cpu_percent, namespace="🔧 cpu")
+            self.logger.log_scalar("child_percent", stats.child_cpu_percent, namespace="🔧 cpu")
+            self.logger.log_scalar("percent", stats.mem_percent, namespace="🔧 mem")
+            self.logger.log_scalar("shared", stats.mem_shared, namespace="🔧 mem")
+            self.logger.log_scalar("child_percent", stats.child_mem_percent, namespace="🔧 mem")
+            self.logger.log_scalar("rss/cur", stats.mem_rss, namespace="🔧 mem")
+            self.logger.log_scalar("rss/total", stats.mem_rss_total, namespace="🔧 mem")
+            self.logger.log_scalar("vms/cur", stats.mem_vms, namespace="🔧 mem")
+            self.logger.log_scalar("vms/total", stats.mem_vms_total, namespace="🔧 mem")
 
        return state
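A recurring change across these mixins is stacking `@jax.tree_util.register_dataclass` on top of `@dataclass` so the config objects are registered as JAX pytrees. Below is a minimal sketch of that decorator stacking; `PingOptions` is a hypothetical stand-in for the xax config classes, and it assumes a JAX release recent enough to accept `register_dataclass` as a bare decorator:

    from dataclasses import dataclass

    import jax

    @jax.tree_util.register_dataclass
    @dataclass
    class PingOptions:
        ping_interval: int = 1
        only_log_once: bool = False

    # Once registered, the dataclass behaves as a pytree node, so jax.tree_util
    # can traverse it like any other container.
    opts = PingOptions()
    print(jax.tree_util.tree_leaves(opts))  # e.g. [1, False]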
xax/task/mixins/data_loader.py
CHANGED
@@ -1,9 +1,9 @@
 """Defines a mixin for instantiating dataloaders."""
 
 import logging
-from abc import ABC
+from abc import ABC
 from dataclasses import dataclass
-from typing import Generic, TypeVar
+from typing import Generic, Iterator, TypeVar
 
 import jax
 from dpshdl.dataloader import CollatedDataloaderItem, Dataloader
@@ -13,7 +13,7 @@ from omegaconf import II, MISSING
 
 from xax.core.conf import field, is_missing
 from xax.core.state import Phase
-from xax.nn.functions import
+from xax.nn.functions import set_random_seed
 from xax.task.base import BaseConfig, BaseTask
 from xax.task.mixins.process import ProcessConfig, ProcessMixin
 from xax.utils.logging import LOG_ERROR_SUMMARY, configure_logging
@@ -24,6 +24,7 @@ T = TypeVar("T")
 Tc_co = TypeVar("Tc_co", covariant=True)
 
 
+@jax.tree_util.register_dataclass
 @dataclass
 class DataloaderErrorConfig:
     sleep_backoff: float = field(0.1, help="The initial sleep time after an exception")
@@ -36,24 +37,25 @@ class DataloaderErrorConfig:
     log_exceptions_all_workers: bool = field(False, help="If set, log exceptions from all workers")
 
 
+@jax.tree_util.register_dataclass
 @dataclass
 class DataloaderConfig:
-    batch_size: int = field(MISSING, help="Size of each batch")
     num_workers: int | None = field(MISSING, help="Number of workers for loading samples")
     prefetch_factor: int = field(2, help="Number of items to pre-fetch on each worker")
     error: DataloaderErrorConfig = field(DataloaderErrorConfig(), help="Dataloader error configuration")
 
 
+@jax.tree_util.register_dataclass
 @dataclass
 class DataloadersConfig(ProcessConfig, BaseConfig):
     batch_size: int = field(MISSING, help="Size of each batch")
     raise_dataloader_errors: bool = field(False, help="If set, raise dataloader errors inside the worker processes")
     train_dl: DataloaderConfig = field(
-        DataloaderConfig(
+        DataloaderConfig(num_workers=II("mlfab.num_workers:-1")),
         help="Train dataloader config",
     )
     valid_dl: DataloaderConfig = field(
-        DataloaderConfig(
+        DataloaderConfig(num_workers=1),
         help="Valid dataloader config",
     )
     debug_dataloader: bool = field(False, help="Debug dataloaders")
@@ -64,11 +66,6 @@ Config = TypeVar("Config", bound=DataloadersConfig)
 
 class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config], ABC):
     def __init__(self, config: Config) -> None:
-        if is_missing(config, "batch_size") and (
-            is_missing(config.train_dl, "batch_size") or is_missing(config.valid_dl, "batch_size")
-        ):
-            config.batch_size = self.get_batch_size()
-
         super().__init__(config)
 
     def get_batch_size(self) -> int:
@@ -77,6 +74,12 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
             "method to return the desired training batch size."
         )
 
+    @property
+    def batch_size(self) -> int:
+        if is_missing(self.config, "batch_size"):
+            self.config.batch_size = self.get_batch_size()
+        return self.config.batch_size
+
     def dataloader_config(self, phase: Phase) -> DataloaderConfig:
         match phase:
             case "train":
@@ -86,7 +89,6 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
             case _:
                 raise KeyError(f"Unknown phase: {phase}")
 
-    @abstractmethod
     def get_dataset(self, phase: Phase) -> Dataset:
         """Returns the dataset for the given phase.
 
@@ -96,6 +98,16 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
         Returns:
             The dataset for the given phase.
         """
+        raise NotImplementedError(
+            "You must implement either the `get_dataset` method to return the dataset for the given phase, "
+            "or `get_data_iterator` to return an iterator for the given dataset."
+        )
+
+    def get_data_iterator(self, phase: Phase) -> Iterator:
+        raise NotImplementedError(
+            "You must implement either the `get_dataset` method to return the dataset for the given phase, "
+            "or `get_data_iterator` to return an iterator for the given dataset."
+        )
 
     def get_dataloader(self, dataset: Dataset[T, Tc_co], phase: Phase) -> Dataloader[T, Tc_co]:
         debugging = self.config.debug_dataloader
@@ -120,10 +132,10 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
 
         return Dataloader(
             dataset=dataset,
-            batch_size=
+            batch_size=self.config.batch_size,
             num_workers=0 if debugging else cfg.num_workers,
             prefetch_factor=cfg.prefetch_factor,
-
+            mp_manager=self.multiprocessing_manager,
             dataloader_worker_init_fn=self.dataloader_worker_init_fn,
             collate_worker_init_fn=self.collate_worker_init_fn,
             item_callback=self.dataloader_item_callback,
@@ -131,11 +143,7 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
         )
 
     def get_prefetcher(self, dataloader: Dataloader[T, Tc_co]) -> Prefetcher[Tc_co, Tc_co]:
-        return Prefetcher(to_device_func=
-
-    @classmethod
-    def to_device_fn(cls, sample: T) -> T:
-        return recursive_apply(sample, jax.device_put)
+        return Prefetcher(to_device_func=jax.device_put, dataloader=dataloader)
 
     @classmethod
     def dataloader_worker_init_fn(cls, worker_id: int, num_workers: int) -> None:
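The dataloader mixin also drops the `@abstractmethod` requirement on `get_dataset` and adds a sibling `get_data_iterator` hook, each defaulting to a `NotImplementedError` that tells you to implement one or the other. A standalone sketch of that pattern follows; the class names are illustrative, not xax's real API:

    from typing import Iterator

    class DataMixin:
        # Subclasses implement whichever hook fits their data source.
        def get_dataset(self, phase: str):
            raise NotImplementedError("Implement either `get_dataset` or `get_data_iterator`.")

        def get_data_iterator(self, phase: str) -> Iterator:
            raise NotImplementedError("Implement either `get_dataset` or `get_data_iterator`.")

    class MyTask(DataMixin):
        # Only the iterator hook is provided; `get_dataset` keeps its default
        # NotImplementedError behavior.
        def get_data_iterator(self, phase: str) -> Iterator:
            return iter(range(10))

    print(list(MyTask().get_data_iterator("train"))[:3])  # [0, 1, 2]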
xax/task/mixins/gpu_stats.py
CHANGED
@@ -6,17 +6,19 @@ This logs GPU memory and utilization in a background process using
 
 import functools
 import logging
-import multiprocessing as mp
 import os
 import re
 import shutil
 import subprocess
 from ctypes import Structure, c_double, c_uint32
 from dataclasses import dataclass
+from multiprocessing.context import BaseContext, Process
 from multiprocessing.managers import SyncManager, ValueProxy
 from multiprocessing.synchronize import Event
 from typing import Generic, Iterable, Pattern, TypeVar
 
+import jax
+
 from xax.core.conf import field
 from xax.core.state import State
 from xax.task.mixins.logger import LoggerConfig, LoggerMixin
@@ -25,12 +27,14 @@ from xax.task.mixins.process import ProcessConfig, ProcessMixin
 logger: logging.Logger = logging.getLogger(__name__)
 
 
+@jax.tree_util.register_dataclass
 @dataclass
 class GPUStatsOptions:
     ping_interval: int = field(10, help="How often to check stats (in seconds)")
     only_log_once: bool = field(False, help="If set, only log read stats one time")
 
 
+@jax.tree_util.register_dataclass
 @dataclass
 class GPUStatsConfig(ProcessConfig, LoggerConfig):
     gpu_stats: GPUStatsOptions = field(GPUStatsOptions(), help="GPU stats configuration")
@@ -147,8 +151,14 @@ def worker(
 
 
 class GPUStatsMonitor:
-    def __init__(
+    def __init__(
+        self,
+        ping_interval: float,
+        context: BaseContext,
+        manager: SyncManager,
+    ) -> None:
        self._ping_interval = ping_interval
+        self._context = context
        self._manager = manager
 
        num_gpus = get_num_gpus()
@@ -169,7 +179,7 @@ class GPUStatsMonitor:
            for i in range(num_gpus)
        ]
        self._gpu_stats: dict[int, GPUStatsInfo] = {}
-        self._proc:
+        self._proc: Process | None = None
 
    def get_if_set(self) -> dict[int, GPUStatsInfo]:
        gpu_stats: dict[int, GPUStatsInfo] = {}
@@ -196,7 +206,7 @@ class GPUStatsMonitor:
        if self._start_event.is_set():
            self._start_event.clear()
            self._gpu_stats.clear()
-        self._proc =
+        self._proc = self._context.Process(  # type: ignore[attr-defined]
            target=worker,
            args=(self._ping_interval, self._smems, self._main_event, self._events, self._start_event),
            daemon=True,
@@ -226,7 +236,11 @@ class GPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
 
        self._gpu_stats_monitor = None
        if shutil.which("nvidia-smi") is not None:
-            self._gpu_stats_monitor = GPUStatsMonitor(
+            self._gpu_stats_monitor = GPUStatsMonitor(
+                config.gpu_stats.ping_interval,
+                self._mp_ctx,
+                self._mp_manager,
+            )
 
    def on_training_start(self, state: State) -> State:
        state = super().on_training_start(state)
@@ -250,8 +264,8 @@ class GPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
        for gpu_stat in stats.values():
            if gpu_stat is None:
                continue
-            self.log_scalar(f"mem/{gpu_stat.index}", gpu_stat.memory_used, namespace="🔧 gpu")
-            self.log_scalar(f"temp/{gpu_stat.index}", gpu_stat.temperature, namespace="🔧 gpu")
-            self.log_scalar(f"util/{gpu_stat.index}", gpu_stat.utilization, namespace="🔧 gpu")
+            self.logger.log_scalar(f"mem/{gpu_stat.index}", gpu_stat.memory_used, namespace="🔧 gpu")
+            self.logger.log_scalar(f"temp/{gpu_stat.index}", gpu_stat.temperature, namespace="🔧 gpu")
+            self.logger.log_scalar(f"util/{gpu_stat.index}", gpu_stat.utilization, namespace="🔧 gpu")
 
        return state
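Both stats monitors now receive an explicit multiprocessing context and start their worker through `context.Process(...)` rather than the module-level `multiprocessing` API; the `# type: ignore[attr-defined]` comment is there because `BaseContext` itself does not declare a `Process` attribute even though every concrete context provides one. A standard-library-only sketch of the same pattern, with a placeholder worker:

    import multiprocessing as mp

    def worker(interval: float) -> None:
        # Stand-in for the stats-polling loop run by the real worker process.
        print(f"polling every {interval}s")

    if __name__ == "__main__":
        # An explicit context pins the start method ("spawn" here) rather than
        # relying on the platform default.
        ctx = mp.get_context("spawn")
        proc = ctx.Process(target=worker, args=(10.0,), daemon=True)
        proc.start()
        proc.join()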
xax/task/mixins/logger.py
CHANGED
@@ -4,14 +4,13 @@ import os
 from dataclasses import dataclass
 from pathlib import Path
 from types import TracebackType
-from typing import
+from typing import Generic, Self, TypeVar
 
-
+import jax
 
-from xax.core.conf import Device as BaseDeviceConfig, field
 from xax.core.state import State
 from xax.task.base import BaseConfig, BaseTask
-from xax.task.logger import
+from xax.task.logger import Logger, LoggerImpl
 from xax.task.loggers.json import JsonLogger
 from xax.task.loggers.state import StateLogger
 from xax.task.loggers.stdout import StdoutLogger
@@ -20,9 +19,10 @@ from xax.task.mixins.artifacts import ArtifactsMixin
 from xax.utils.text import is_interactive_session
 
 
+@jax.tree_util.register_dataclass
 @dataclass
 class LoggerConfig(BaseConfig):
-
+    pass
 
 
 Config = TypeVar("Config", bound=LoggerConfig)
@@ -59,252 +59,6 @@ class LoggerMixin(BaseTask[Config], Generic[Config]):
     def write_logs(self, state: State) -> None:
         self.logger.write(state)
 
-    def log_scalar(self, key: str, value: Callable[[], Number] | Number, *, namespace: str | None = None) -> None:
-        self.logger.log_scalar(key, value, namespace=namespace)
-
-    def log_string(self, key: str, value: Callable[[], str] | str, *, namespace: str | None = None) -> None:
-        self.logger.log_string(key, value, namespace=namespace)
-
-    def log_image(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        keep_resolution: bool = False,
-    ) -> None:
-        self.logger.log_image(
-            key,
-            value,
-            namespace=namespace,
-            keep_resolution=keep_resolution,
-        )
-
-    def log_labeled_image(
-        self,
-        key: str,
-        value: Callable[[], tuple[Array, str]] | tuple[Array, str],
-        *,
-        namespace: str | None = None,
-        max_line_length: int | None = None,
-        keep_resolution: bool = False,
-        centered: bool = True,
-    ) -> None:
-        self.logger.log_labeled_image(
-            key,
-            value,
-            namespace=namespace,
-            max_line_length=max_line_length,
-            keep_resolution=keep_resolution,
-            centered=centered,
-        )
-
-    def log_images(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        keep_resolution: bool = False,
-        max_images: int | None = None,
-        sep: int = 0,
-    ) -> None:
-        self.logger.log_images(
-            key,
-            value,
-            namespace=namespace,
-            keep_resolution=keep_resolution,
-            max_images=max_images,
-            sep=sep,
-        )
-
-    def log_labeled_images(
-        self,
-        key: str,
-        value: Callable[[], tuple[Array, Sequence[str]]] | tuple[Array, Sequence[str]],
-        *,
-        namespace: str | None = None,
-        max_line_length: int | None = None,
-        keep_resolution: bool = False,
-        max_images: int | None = None,
-        sep: int = 0,
-        centered: bool = True,
-    ) -> None:
-        self.logger.log_labeled_images(
-            key,
-            value,
-            namespace=namespace,
-            max_line_length=max_line_length,
-            keep_resolution=keep_resolution,
-            max_images=max_images,
-            sep=sep,
-            centered=centered,
-        )
-
-    def log_audio(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        sample_rate: int = 44100,
-        log_spec: bool = True,
-        n_fft_ms: float = 32.0,
-        hop_length_ms: float | None = None,
-        channel_select_mode: ChannelSelectMode = "first",
-        keep_resolution: bool = False,
-    ) -> None:
-        self.logger.log_audio(
-            key,
-            value,
-            namespace=namespace,
-            sample_rate=sample_rate,
-            log_spec=log_spec,
-            n_fft_ms=n_fft_ms,
-            hop_length_ms=hop_length_ms,
-            channel_select_mode=channel_select_mode,
-            keep_resolution=keep_resolution,
-        )
-
-    def log_audios(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        sep_ms: float = 0.0,
-        max_audios: int | None = None,
-        sample_rate: int = 44100,
-        log_spec: bool = True,
-        n_fft_ms: float = 32.0,
-        hop_length_ms: float | None = None,
-        channel_select_mode: ChannelSelectMode = "first",
-        spec_sep: int = 0,
-        keep_resolution: bool = False,
-    ) -> None:
-        self.logger.log_audios(
-            key,
-            value,
-            namespace=namespace,
-            sep_ms=sep_ms,
-            max_audios=max_audios,
-            sample_rate=sample_rate,
-            log_spec=log_spec,
-            n_fft_ms=n_fft_ms,
-            hop_length_ms=hop_length_ms,
-            channel_select_mode=channel_select_mode,
-            spec_sep=spec_sep,
-            keep_resolution=keep_resolution,
-        )
-
-    def log_spectrogram(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        sample_rate: int = 44100,
-        n_fft_ms: float = 32.0,
-        hop_length_ms: float | None = None,
-        channel_select_mode: ChannelSelectMode = "first",
-        keep_resolution: bool = False,
-    ) -> None:
-        self.logger.log_spectrogram(
-            key,
-            value,
-            namespace=namespace,
-            sample_rate=sample_rate,
-            n_fft_ms=n_fft_ms,
-            hop_length_ms=hop_length_ms,
-            channel_select_mode=channel_select_mode,
-            keep_resolution=keep_resolution,
-        )
-
-    def log_spectrograms(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        max_audios: int | None = None,
-        sample_rate: int = 44100,
-        n_fft_ms: float = 32.0,
-        hop_length_ms: float | None = None,
-        channel_select_mode: ChannelSelectMode = "first",
-        spec_sep: int = 0,
-        keep_resolution: bool = False,
-    ) -> None:
-        self.logger.log_spectrograms(
-            key,
-            value,
-            namespace=namespace,
-            max_audios=max_audios,
-            sample_rate=sample_rate,
-            n_fft_ms=n_fft_ms,
-            hop_length_ms=hop_length_ms,
-            channel_select_mode=channel_select_mode,
-            spec_sep=spec_sep,
-            keep_resolution=keep_resolution,
-        )
-
-    def log_video(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        fps: int | None = None,
-        length: float | None = None,
-    ) -> None:
-        self.logger.log_video(
-            key,
-            value,
-            namespace=namespace,
-            fps=fps,
-            length=length,
-        )
-
-    def log_videos(
-        self,
-        key: str,
-        value: Callable[[], Array | list[Array]] | Array | list[Array],
-        *,
-        namespace: str | None = None,
-        max_videos: int | None = None,
-        sep: int = 0,
-        fps: int | None = None,
-        length: int | None = None,
-    ) -> None:
-        self.logger.log_videos(
-            key,
-            value,
-            namespace=namespace,
-            max_videos=max_videos,
-            sep=sep,
-            fps=fps,
-            length=length,
-        )
-
-    def log_histogram(self, key: str, value: Callable[[], Array] | Array, *, namespace: str | None = None) -> None:
-        self.logger.log_histogram(key, value, namespace=namespace)
-
-    def log_point_cloud(
-        self,
-        key: str,
-        value: Callable[[], Array] | Array,
-        *,
-        namespace: str | None = None,
-        max_points: int = 1000,
-        colors: Callable[[], Array] | Array | None = None,
-    ) -> None:
-        self.logger.log_point_cloud(
-            key,
-            value,
-            namespace=namespace,
-            max_points=max_points,
-            colors=colors,
-        )
-
     def __enter__(self) -> Self:
         self.logger.__enter__()
         return self
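With the convenience wrappers removed from `LoggerMixin`, call sites now go through the `logger` attribute directly, which is exactly the change visible in the cpu/gpu stats mixins above. A toy sketch of the before/after call site; the `Logger` class here is a stand-in, not the real `xax.task.logger.Logger`:

    class Logger:
        def log_scalar(self, key: str, value: float, *, namespace: str | None = None) -> None:
            print(f"[{namespace}] {key} = {value}")

    class LoggerMixin:
        def __init__(self) -> None:
            self.logger = Logger()

    class Task(LoggerMixin):
        def on_step(self) -> None:
            # Before 0.0.6 (roughly): self.log_scalar("loss", 0.25, namespace="train")
            # From 0.0.6 on, the mixin no longer re-exports the wrappers:
            self.logger.log_scalar("loss", 0.25, namespace="train")

    Task().on_step()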
xax/task/mixins/process.py
CHANGED
@@ -7,6 +7,8 @@ from multiprocessing.context import BaseContext
 from multiprocessing.managers import SyncManager
 from typing import Generic, TypeVar
 
+import jax
+
 from xax.core.conf import field
 from xax.core.state import State
 from xax.task.base import BaseConfig, BaseTask
@@ -14,9 +16,10 @@ from xax.task.base import BaseConfig, BaseTask
 logger: logging.Logger = logging.getLogger(__name__)
 
 
+@jax.tree_util.register_dataclass
 @dataclass
 class ProcessConfig(BaseConfig):
-    multiprocessing_context: str | None = field(
+    multiprocessing_context: str | None = field("spawn", help="The multiprocessing context to use")
 
 
 Config = TypeVar("Config", bound=ProcessConfig)
@@ -38,6 +41,10 @@ class ProcessMixin(BaseTask[Config], Generic[Config]):
     def multiprocessing_context(self) -> BaseContext:
         return self._mp_ctx
 
+    @property
+    def multiprocessing_manager(self) -> SyncManager:
+        return self._mp_manager
+
     def on_training_end(self, state: State) -> State:
         state = super().on_training_end(state)
 
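`ProcessConfig.multiprocessing_context` now defaults to "spawn", and the mixin exposes its `SyncManager` through a `multiprocessing_manager` property (consumed above by the dataloader's `mp_manager` argument). A small sketch, independent of xax, of creating a manager from a spawn context and sharing a value through it:

    import multiprocessing as mp

    if __name__ == "__main__":
        # "spawn" starts workers in a fresh interpreter, avoiding state that
        # "fork" would inherit (open accelerator handles, thread locks, ...).
        ctx = mp.get_context("spawn")
        manager = ctx.Manager()  # a SyncManager bound to the spawn context

        counter = manager.Value("i", 0)  # shared value, accessed via a proxy
        counter.value += 1
        print(counter.value)  # 1

        manager.shutdown()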
xax/task/mixins/runnable.py
CHANGED
@@ -6,10 +6,13 @@ from dataclasses import dataclass
 from types import FrameType
 from typing import Callable, TypeVar
 
+import jax
+
 from xax.task.base import BaseConfig, BaseTask, RawConfigType
 from xax.task.launchers.base import BaseLauncher
 
 
+@jax.tree_util.register_dataclass
 @dataclass
 class RunnableConfig(BaseConfig):
     pass
xax/task/mixins/step_wrapper.py
CHANGED
@@ -4,6 +4,9 @@ from dataclasses import dataclass
 from types import TracebackType
 from typing import ContextManager, Literal, TypeVar
 
+import equinox as eqx
+import jax
+
 from xax.task.base import BaseConfig, BaseTask
 
 StepType = Literal[
@@ -47,6 +50,7 @@ class StepContext(ContextManager):
        StepContext.CURRENT_STEP = None
 
 
+@jax.tree_util.register_dataclass
 @dataclass
 class StepContextConfig(BaseConfig):
     pass
@@ -59,5 +63,6 @@ class StepContextMixin(BaseTask[Config]):
     def __init__(self, config: Config) -> None:
         super().__init__(config)
 
+    @eqx.filter_jit
     def step_context(self, step: StepType) -> ContextManager:
         return StepContext(step)