xax 0.0.1__py3-none-any.whl → 0.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +256 -1
- xax/core/conf.py +193 -0
- xax/core/state.py +81 -0
- xax/nn/__init__.py +0 -0
- xax/nn/embeddings.py +355 -0
- xax/nn/functions.py +77 -0
- xax/nn/parallel.py +211 -0
- xax/requirements-dev.txt +15 -0
- xax/requirements.txt +23 -0
- xax/task/__init__.py +0 -0
- xax/task/base.py +207 -0
- xax/task/launchers/__init__.py +0 -0
- xax/task/launchers/base.py +28 -0
- xax/task/launchers/cli.py +42 -0
- xax/task/launchers/single_process.py +30 -0
- xax/task/launchers/staged.py +29 -0
- xax/task/logger.py +783 -0
- xax/task/loggers/__init__.py +0 -0
- xax/task/loggers/callback.py +56 -0
- xax/task/loggers/json.py +121 -0
- xax/task/loggers/state.py +45 -0
- xax/task/loggers/stdout.py +170 -0
- xax/task/loggers/tensorboard.py +223 -0
- xax/task/mixins/__init__.py +12 -0
- xax/task/mixins/artifacts.py +114 -0
- xax/task/mixins/checkpointing.py +209 -0
- xax/task/mixins/cpu_stats.py +251 -0
- xax/task/mixins/data_loader.py +149 -0
- xax/task/mixins/gpu_stats.py +257 -0
- xax/task/mixins/logger.py +66 -0
- xax/task/mixins/process.py +51 -0
- xax/task/mixins/runnable.py +63 -0
- xax/task/mixins/step_wrapper.py +63 -0
- xax/task/mixins/train.py +541 -0
- xax/task/script.py +53 -0
- xax/task/task.py +65 -0
- xax/utils/__init__.py +0 -0
- xax/utils/data/__init__.py +0 -0
- xax/utils/data/collate.py +206 -0
- xax/utils/experiments.py +802 -0
- xax/utils/jax.py +14 -0
- xax/utils/logging.py +223 -0
- xax/utils/numpy.py +47 -0
- xax/utils/tensorboard.py +258 -0
- xax/utils/text.py +350 -0
- xax-0.0.5.dist-info/METADATA +40 -0
- xax-0.0.5.dist-info/RECORD +52 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
- xax-0.0.5.dist-info/top_level.txt +1 -0
- examples/mnist.py +0 -148
- xax-0.0.1.dist-info/METADATA +0 -21
- xax-0.0.1.dist-info/RECORD +0 -9
- xax-0.0.1.dist-info/top_level.txt +0 -2
- {examples → xax/core}/__init__.py +0 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
@@ -0,0 +1,149 @@
|
|
1
|
+
"""Defines a mixin for instantiating dataloaders."""
|
2
|
+
|
3
|
+
import logging
|
4
|
+
from abc import ABC, abstractmethod
|
5
|
+
from dataclasses import dataclass
|
6
|
+
from typing import Generic, TypeVar
|
7
|
+
|
8
|
+
import jax
|
9
|
+
from dpshdl.dataloader import CollatedDataloaderItem, Dataloader
|
10
|
+
from dpshdl.dataset import Dataset, ErrorHandlingDataset
|
11
|
+
from dpshdl.prefetcher import Prefetcher
|
12
|
+
from omegaconf import II, MISSING
|
13
|
+
|
14
|
+
from xax.core.conf import field, is_missing
|
15
|
+
from xax.core.state import Phase
|
16
|
+
from xax.nn.functions import recursive_apply, set_random_seed
|
17
|
+
from xax.task.base import BaseConfig, BaseTask
|
18
|
+
from xax.task.mixins.process import ProcessConfig, ProcessMixin
|
19
|
+
from xax.utils.logging import LOG_ERROR_SUMMARY, configure_logging
|
20
|
+
|
21
|
+
# Module-level logger for this mixin.
logger = logging.getLogger(__name__)

# Generic parameters mirroring dpshdl's Dataset/Dataloader type arguments;
# presumably T is the per-sample type and Tc_co the collated batch type
# (TODO confirm against dpshdl).
T = TypeVar("T")
Tc_co = TypeVar("Tc_co", covariant=True)
|
25
|
+
|
26
|
+
|
27
|
+
@dataclass
class DataloaderErrorConfig:
    """Error-handling options for a dataloader.

    ``DataloadersMixin.get_dataloader`` forwards every field below, under the
    same keyword name, to dpshdl's ``ErrorHandlingDataset`` wrapper.
    """

    sleep_backoff: float = field(0.1, help="The initial sleep time after an exception")
    sleep_backoff_power: float = field(2.0, help="Power to raise the sleep time by after each consecutive exception")
    maximum_exceptions: int = field(10, help="The maximum number of consecutive exceptions before raising an error")
    backoff_after: int = field(5, help="The number of consecutive exceptions before starting to backoff")
    traceback_depth: int = field(5, help="The depth of the traceback to print when an exception occurs")
    flush_every_n_steps: int | None = field(None, help="Flush the error summary after this many steps")
    flush_every_n_seconds: float | None = field(10.0, help="Flush the error summary after this many seconds")
    log_exceptions_all_workers: bool = field(False, help="If set, log exceptions from all workers")
|
37
|
+
|
38
|
+
|
39
|
+
@dataclass
class DataloaderConfig:
    """Settings for one phase's dataloader: worker count, prefetching and error handling."""

    # OmegaConf's MISSING marker: no default here; expected to be supplied by
    # the user config or an interpolation (see DataloadersConfig.train_dl).
    num_workers: int | None = field(MISSING, help="Number of workers for loading samples")
    prefetch_factor: int = field(2, help="Number of items to pre-fetch on each worker")
    error: DataloaderErrorConfig = field(DataloaderErrorConfig(), help="Dataloader error configuration")
|
44
|
+
|
45
|
+
|
46
|
+
@dataclass
class DataloadersConfig(ProcessConfig, BaseConfig):
    """Top-level config for tasks that use ``DataloadersMixin``."""

    # If left MISSING, DataloadersMixin.__init__ fills this in from get_batch_size().
    batch_size: int = field(MISSING, help="Size of each batch")
    # Forwarded as `raise_errs` to dpshdl's Dataloader in get_dataloader.
    raise_dataloader_errors: bool = field(False, help="If set, raise dataloader errors inside the worker processes")
    # NOTE(review): this interpolation uses an "mlfab." resolver prefix, which
    # looks carried over from the mlfab package. Confirm that a resolver with
    # this exact name is registered when using xax; otherwise resolving
    # `train_dl.num_workers` will fail at runtime.
    train_dl: DataloaderConfig = field(
        DataloaderConfig(num_workers=II("mlfab.num_workers:-1")),
        help="Train dataloader config",
    )
    valid_dl: DataloaderConfig = field(
        DataloaderConfig(num_workers=1),
        help="Valid dataloader config",
    )
    # When set, get_dataloader logs a warning and uses num_workers=0.
    debug_dataloader: bool = field(False, help="Debug dataloaders")


# Type variable for configs accepted by DataloadersMixin.
Config = TypeVar("Config", bound=DataloadersConfig)
|
62
|
+
|
63
|
+
|
64
|
+
class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config], ABC):
    """Task mixin that builds per-phase datasets, dataloaders and prefetchers."""

    def __init__(self, config: Config) -> None:
        # Resolve a missing batch size through the subclass hook before the
        # rest of the mixin chain sees the config.
        if is_missing(config, "batch_size"):
            config.batch_size = self.get_batch_size()
        super().__init__(config)

    def get_batch_size(self) -> int:
        """Hook used when ``batch_size`` is missing from the config.

        Raises:
            NotImplementedError: Always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            "When `batch_size` is not specified in your training config, you should override the `get_batch_size` "
            "method to return the desired training batch size."
        )

    def dataloader_config(self, phase: Phase) -> DataloaderConfig:
        """Returns the dataloader config associated with ``phase``.

        Raises:
            KeyError: If ``phase`` is neither ``"train"`` nor ``"valid"``.
        """
        if phase == "train":
            return self.config.train_dl
        if phase == "valid":
            return self.config.valid_dl
        raise KeyError(f"Unknown phase: {phase}")

    @abstractmethod
    def get_dataset(self, phase: Phase) -> Dataset:
        """Returns the dataset for the given phase.

        Args:
            phase: The phase for the dataset to return.

        Returns:
            The dataset for the given phase.
        """

    def get_dataloader(self, dataset: Dataset[T, Tc_co], phase: Phase) -> Dataloader[T, Tc_co]:
        """Builds a dpshdl ``Dataloader`` for ``dataset`` using the config for ``phase``.

        The dataset is first wrapped in ``ErrorHandlingDataset`` configured
        from the phase's ``error`` options. In debug mode the loader is built
        with zero workers.
        """
        in_debug_mode = self.config.debug_dataloader
        if in_debug_mode:
            logger.warning("Parallel dataloaders disabled in debugging mode")

        dl_cfg = self.dataloader_config(phase)
        err_cfg = dl_cfg.error

        wrapped = ErrorHandlingDataset(
            dataset=dataset,
            sleep_backoff=err_cfg.sleep_backoff,
            sleep_backoff_power=err_cfg.sleep_backoff_power,
            maximum_exceptions=err_cfg.maximum_exceptions,
            backoff_after=err_cfg.backoff_after,
            traceback_depth=err_cfg.traceback_depth,
            flush_every_n_steps=err_cfg.flush_every_n_steps,
            flush_every_n_seconds=err_cfg.flush_every_n_seconds,
            log_exceptions_all_workers=err_cfg.log_exceptions_all_workers,
            log_level=LOG_ERROR_SUMMARY,
        )

        worker_count = 0 if in_debug_mode else dl_cfg.num_workers
        return Dataloader(
            dataset=wrapped,
            batch_size=self.config.batch_size,
            num_workers=worker_count,
            prefetch_factor=dl_cfg.prefetch_factor,
            mp_manager=self.multiprocessing_manager,
            dataloader_worker_init_fn=self.dataloader_worker_init_fn,
            collate_worker_init_fn=self.collate_worker_init_fn,
            item_callback=self.dataloader_item_callback,
            raise_errs=self.config.raise_dataloader_errors,
        )

    def get_prefetcher(self, dataloader: Dataloader[T, Tc_co]) -> Prefetcher[Tc_co, Tc_co]:
        """Wraps ``dataloader`` in a ``Prefetcher`` that moves items with ``to_device_fn``."""
        return Prefetcher(dataloader=dataloader, to_device_func=self.to_device_fn)

    @classmethod
    def to_device_fn(cls, sample: T) -> T:
        """Recursively applies ``jax.device_put`` to ``sample`` (with ``include_numpy=True``)."""
        return recursive_apply(sample, jax.device_put, include_numpy=True)

    @classmethod
    def dataloader_worker_init_fn(cls, worker_id: int, num_workers: int) -> None:
        """Per-worker setup: logging prefixed with the worker id, and a seed offset of ``worker_id + 1``."""
        configure_logging(prefix=f"{worker_id}")
        set_random_seed(offset=worker_id + 1)

    @classmethod
    def collate_worker_init_fn(cls) -> None:
        """Collate-worker setup: ``collate`` logging prefix and a seed offset of -1."""
        configure_logging(prefix="collate")
        set_random_seed(offset=-1)

    @classmethod
    def dataloader_item_callback(cls, item: CollatedDataloaderItem) -> None:
        """Hook called by the dataloader for each item; does nothing by default."""
|
@@ -0,0 +1,257 @@
|
|
1
|
+
"""A task mixin for logging GPU statistics.
|
2
|
+
|
3
|
+
This logs GPU memory and utilization in a background process using
|
4
|
+
``nvidia-smi``, if a GPU is available in the system.
|
5
|
+
"""
|
6
|
+
|
7
|
+
import functools
|
8
|
+
import logging
|
9
|
+
import multiprocessing as mp
|
10
|
+
import os
|
11
|
+
import re
|
12
|
+
import shutil
|
13
|
+
import subprocess
|
14
|
+
from ctypes import Structure, c_double, c_uint32
|
15
|
+
from dataclasses import dataclass
|
16
|
+
from multiprocessing.managers import SyncManager, ValueProxy
|
17
|
+
from multiprocessing.synchronize import Event
|
18
|
+
from typing import Generic, Iterable, Pattern, TypeVar
|
19
|
+
|
20
|
+
from xax.core.conf import field
|
21
|
+
from xax.core.state import State
|
22
|
+
from xax.task.mixins.logger import LoggerConfig, LoggerMixin
|
23
|
+
from xax.task.mixins.process import ProcessConfig, ProcessMixin
|
24
|
+
|
25
|
+
logger: logging.Logger = logging.getLogger(__name__)


@dataclass
class GPUStatsOptions:
    """Options controlling GPU statistics collection."""

    # Handed to GPUStatsMonitor and ultimately to `nvidia-smi --loop=<n>`.
    ping_interval: int = field(10, help="How often to check stats (in seconds)")
    # When set, GPUStatsMixin.on_step_start logs only GPUs updated since the
    # previous read (monitor.get_if_set) rather than the latest known stats
    # for every GPU (monitor.get).
    only_log_once: bool = field(False, help="If set, only log read stats one time")


@dataclass
class GPUStatsConfig(ProcessConfig, LoggerConfig):
    """Config for tasks that use ``GPUStatsMixin``."""

    gpu_stats: GPUStatsOptions = field(GPUStatsOptions(), help="GPU stats configuration")


# Type variable for configs accepted by GPUStatsMixin.
Config = TypeVar("Config", bound=GPUStatsConfig)

# Matches the first run of digits and dots in a CSV field (e.g. a number
# followed by a unit suffix); used by parse_number.
NUMBER_REGEX: Pattern[str] = re.compile(r"[\d\.]+")
|
42
|
+
|
43
|
+
|
44
|
+
class GPUStats(Structure):
    """ctypes record for one GPU sample, shareable through a multiprocessing manager.

    ``memory_used`` holds a percentage of total memory (as produced by
    ``parse_gpu_stats``), not an absolute amount.
    """

    _fields_ = [
        ("index", c_uint32),
        ("memory_used", c_double),
        ("temperature", c_double),
        ("utilization", c_double),
    ]


@dataclass(frozen=True)
class GPUStatsInfo:
    """Immutable, plain-Python copy of a :class:`GPUStats` record."""

    index: int
    memory_used: float
    temperature: float
    utilization: float

    @classmethod
    def from_stats(cls, stats: GPUStats) -> "GPUStatsInfo":
        """Copies every field declared on ``GPUStats`` into a new ``GPUStatsInfo``."""
        copied = {name: getattr(stats, name) for name, _ctype in GPUStats._fields_}
        return cls(**copied)
|
68
|
+
|
69
|
+
|
70
|
+
@functools.lru_cache(maxsize=None)
def get_num_gpus() -> int:
    """Returns the number of GPUs reported by ``nvidia-smi``.

    The result is cached for the lifetime of the process. Only output lines
    that parse as an integer GPU index are counted, so any non-index text that
    ``nvidia-smi`` writes to stdout is not mistaken for a GPU. If the command
    cannot be run, the exception is logged and 0 is returned.
    """
    command = "nvidia-smi --query-gpu=index --format=csv,noheader"

    try:
        with subprocess.Popen(command.split(), stdout=subprocess.PIPE, universal_newlines=True) as proc:
            stdout = proc.stdout
            assert stdout is not None
            num_gpus = 0
            for row in iter(stdout.readline, ""):
                try:
                    int(row.strip())
                except ValueError:
                    # Not a GPU index row (blank line or diagnostic text).
                    continue
                num_gpus += 1
            return num_gpus

    except Exception:
        logger.exception("Caught exception while trying to query `nvidia-smi`")
        return 0
|
84
|
+
|
85
|
+
|
86
|
+
def parse_number(s: str) -> float:
    """Extracts the first run of digits/dots in ``s`` and converts it to a float.

    Raises:
        ValueError: If ``s`` contains no digits or dots, or if the matched
            text is not a valid float literal (e.g. ``"1.2.3"``).
    """
    found = NUMBER_REGEX.search(s)
    if found:
        return float(found.group())
    raise ValueError(s)
|
91
|
+
|
92
|
+
|
93
|
+
def parse_gpu_stats(row: str) -> GPUStats:
    """Parses one CSV row of ``nvidia-smi`` output into a :class:`GPUStats`.

    The row is expected to contain, in order: index, memory.total, memory.used,
    temperature.gpu and utilization.gpu (the fields queried by
    ``gen_gpu_stats``). Memory usage is converted to a percentage of the total.

    Raises:
        ValueError: If the row has the wrong number of columns, a field cannot
            be parsed, or the reported total memory is not positive.
    """
    cols = row.split(",")
    index = int(cols[0].strip())
    memory_total, memory_used, temperature, utilization = (parse_number(col) for col in cols[1:])

    # A zero total would otherwise raise ZeroDivisionError. gen_gpu_stats only
    # skips rows that raise ValueError, so any other error ends the monitor.
    if memory_total <= 0:
        raise ValueError(f"Invalid total memory in row: {row!r}")

    return GPUStats(
        index=index,
        memory_used=100 * memory_used / memory_total,
        temperature=temperature,
        utilization=utilization,
    )
|
104
|
+
|
105
|
+
|
106
|
+
def _parse_visible_device_ids(visible_devices: str | None) -> set[int] | None:
    """Parses a ``CUDA_VISIBLE_DEVICES`` value into a set of integer GPU indices.

    Returns ``None`` (meaning "do not filter") when the variable is unset, or
    when it contains entries that are not integer indices (such as device
    UUIDs), which cannot be matched against ``nvidia-smi`` indices. Empty
    entries are ignored, so an empty value yields an empty set.
    """
    if visible_devices is None:
        return None
    device_ids: set[int] = set()
    for token in visible_devices.split(","):
        token = token.strip()
        if not token:
            continue
        try:
            device_ids.add(int(token))
        except ValueError:
            logger.warning(
                "Could not parse CUDA_VISIBLE_DEVICES=%r as GPU indices; reporting stats for all GPUs",
                visible_devices,
            )
            return None
    return device_ids


def gen_gpu_stats(loop_secs: int = 5) -> Iterable[GPUStats]:
    """Yields GPU statistics from a long-running ``nvidia-smi --loop`` process.

    Rows that fail to parse are skipped. When ``CUDA_VISIBLE_DEVICES`` lists
    integer indices, only those GPUs are yielded.

    Args:
        loop_secs: Interval passed to ``nvidia-smi --loop``.

    Yields:
        One :class:`GPUStats` per successfully parsed, visible row.
    """
    fields = ",".join(["index", "memory.total", "memory.used", "temperature.gpu", "utilization.gpu"])
    command = f"nvidia-smi --query-gpu={fields} --format=csv,noheader --loop={loop_secs}"
    visible_device_ids = _parse_visible_device_ids(os.environ.get("CUDA_VISIBLE_DEVICES"))

    try:
        with subprocess.Popen(command.split(), stdout=subprocess.PIPE, universal_newlines=True) as proc:
            stdout = proc.stdout
            assert stdout is not None
            for row in iter(stdout.readline, ""):
                try:
                    stats = parse_gpu_stats(row)
                except ValueError:
                    # Unparseable row (e.g. a field with no number); skip it.
                    continue
                if visible_device_ids is None or stats.index in visible_device_ids:
                    yield stats

    except Exception:
        # Real failures (e.g. the command cannot be started) keep their traceback.
        logger.exception("Closing GPU stats monitor")
    except BaseException:
        # Interrupts and generator shutdown end monitoring quietly, as before.
        logger.error("Closing GPU stats monitor")
|
127
|
+
|
128
|
+
|
129
|
+
def worker(
    ping_interval: int,
    smems: list[ValueProxy[GPUStats]],
    main_event: Event,
    events: list[Event],
    start_event: Event,
) -> None:
    """Entry point of the background process that publishes GPU stats.

    Sets ``start_event`` right away. Then, for every sample from
    ``gen_gpu_stats``, stores it in the proxy slot for its GPU index and sets
    that GPU's event and ``main_event``. Samples whose index has no slot are
    skipped with a warning.
    """
    start_event.set()
    logger.debug("Starting GPU stats monitor with PID %d", os.getpid())

    num_slots = len(smems)
    for stat in gen_gpu_stats(ping_interval):
        idx = stat.index
        if idx < num_slots:
            smems[idx].set(stat)
            events[idx].set()
            main_event.set()
        else:
            logger.warning("GPU index %d is out of range", idx)
|
147
|
+
|
148
|
+
|
149
|
+
class GPUStatsMonitor:
    """Owns the background GPU-stats process and the shared state it writes to.

    One shared ``GPUStats`` slot and one ``Event`` per GPU (as counted by
    ``get_num_gpus``) are allocated through the given ``SyncManager``. A
    separate ``main_event`` flags that at least one slot has been updated.
    """

    def __init__(self, ping_interval: float, manager: SyncManager) -> None:
        self._ping_interval = ping_interval
        self._manager = manager

        gpu_ids = range(get_num_gpus())
        self._main_event = manager.Event()
        self._events = [manager.Event() for _ in gpu_ids]
        self._start_event = manager.Event()
        self._smems = [
            manager.Value(GPUStats, GPUStats(index=i, memory_used=0.0, temperature=0.0, utilization=0.0))
            for i in gpu_ids
        ]
        self._gpu_stats: dict[int, GPUStatsInfo] = {}
        self._proc: mp.Process | None = None

    def get_if_set(self) -> dict[int, GPUStatsInfo]:
        """Returns stats only for GPUs updated since the last read, clearing their flags."""
        updated: dict[int, GPUStatsInfo] = {}
        if not self._main_event.is_set():
            return updated
        self._main_event.clear()
        for gpu_id, event in enumerate(self._events):
            if not event.is_set():
                continue
            event.clear()
            updated[gpu_id] = GPUStatsInfo.from_stats(self._smems[gpu_id].get())
        return updated

    def get(self) -> dict[int, GPUStatsInfo]:
        """Merges any fresh stats into the cache and returns the latest stats for every GPU seen."""
        latest = self.get_if_set()
        self._gpu_stats.update(latest)
        return self._gpu_stats

    def start(self, wait: bool = False) -> None:
        """Resets shared flags and the cache, then spawns the daemon worker process.

        Args:
            wait: If set, block until the worker signals that it has started.

        Raises:
            RuntimeError: If the monitor is already running.
        """
        if self._proc is not None:
            raise RuntimeError("GPUStatsMonitor already started")
        for event in (self._main_event, *self._events, self._start_event):
            if event.is_set():
                event.clear()
        self._gpu_stats.clear()
        self._proc = mp.Process(
            target=worker,
            args=(self._ping_interval, self._smems, self._main_event, self._events, self._start_event),
            daemon=True,
            name="xax-gpu-stats",
        )
        self._proc.start()
        if wait:
            self._start_event.wait()

    def stop(self) -> None:
        """Terminates the worker if still alive, joins it, and marks the monitor stopped.

        Raises:
            RuntimeError: If the monitor was never started.
        """
        proc = self._proc
        if proc is None:
            raise RuntimeError("GPUStatsMonitor not started")
        if proc.is_alive():
            proc.terminate()
            logger.debug("Terminated GPU stats monitor; joining...")
        proc.join()
        self._proc = None
|
217
|
+
|
218
|
+
|
219
|
+
class GPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
    """Defines a task mixin for getting GPU statistics.

    If ``nvidia-smi`` is on the PATH, a :class:`GPUStatsMonitor` is created on
    the task's multiprocessing manager. It is started when training starts and
    stopped when training ends. On each step the available stats are logged as
    scalars in the ``"🔧 gpu"`` namespace.
    """

    _gpu_stats_monitor: GPUStatsMonitor | None

    def __init__(self, config: Config) -> None:
        super().__init__(config)

        # Only monitor GPUs when `nvidia-smi` is available.
        self._gpu_stats_monitor = None
        if shutil.which("nvidia-smi") is not None:
            self._gpu_stats_monitor = GPUStatsMonitor(config.gpu_stats.ping_interval, self._mp_manager)

    def on_training_start(self, state: State) -> State:
        state = super().on_training_start(state)
        if self._gpu_stats_monitor is not None:
            self._gpu_stats_monitor.start()
        return state

    def on_training_end(self, state: State) -> State:
        # Stop the monitor before delegating up the chain. ProcessMixin (a
        # base of this class, so later in the super() chain) shuts down the
        # multiprocessing manager that backs the monitor's shared values and
        # events. Stopping first keeps the worker from writing to proxies of
        # a dead manager, and reverses the set-up order of on_training_start.
        if self._gpu_stats_monitor is not None:
            self._gpu_stats_monitor.stop()
        return super().on_training_end(state)

    def on_step_start(self, state: State) -> State:
        state = super().on_step_start(state)

        if (monitor := self._gpu_stats_monitor) is None:
            return state
        # `get_if_set` returns only GPUs updated since the last read; `get`
        # returns the latest known stats for every GPU seen so far.
        stats = monitor.get_if_set() if self.config.gpu_stats.only_log_once else monitor.get()

        for gpu_stat in stats.values():
            self.logger.log_scalar(f"mem/{gpu_stat.index}", gpu_stat.memory_used, namespace="🔧 gpu")
            self.logger.log_scalar(f"temp/{gpu_stat.index}", gpu_stat.temperature, namespace="🔧 gpu")
            self.logger.log_scalar(f"util/{gpu_stat.index}", gpu_stat.utilization, namespace="🔧 gpu")

        return state
|
@@ -0,0 +1,66 @@
|
|
1
|
+
"""Defines a mixin for incorporating some logging functionality."""
|
2
|
+
|
3
|
+
import os
|
4
|
+
from dataclasses import dataclass
|
5
|
+
from pathlib import Path
|
6
|
+
from types import TracebackType
|
7
|
+
from typing import Generic, Self, TypeVar
|
8
|
+
|
9
|
+
from xax.core.conf import Device as BaseDeviceConfig, field
|
10
|
+
from xax.core.state import State
|
11
|
+
from xax.task.base import BaseConfig, BaseTask
|
12
|
+
from xax.task.logger import Logger, LoggerImpl
|
13
|
+
from xax.task.loggers.json import JsonLogger
|
14
|
+
from xax.task.loggers.state import StateLogger
|
15
|
+
from xax.task.loggers.stdout import StdoutLogger
|
16
|
+
from xax.task.loggers.tensorboard import TensorboardLogger
|
17
|
+
from xax.task.mixins.artifacts import ArtifactsMixin
|
18
|
+
from xax.utils.text import is_interactive_session
|
19
|
+
|
20
|
+
|
21
|
+
@dataclass
class LoggerConfig(BaseConfig):
    """Config for tasks that use ``LoggerMixin``."""

    device: BaseDeviceConfig = field(BaseDeviceConfig(), help="Device configuration")


# Type variable for configs accepted by LoggerMixin.
Config = TypeVar("Config", bound=LoggerConfig)
|
27
|
+
|
28
|
+
|
29
|
+
def get_env_var(name: str, default: bool) -> bool:
    """Reads a boolean flag from the environment.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is not set.

    Returns:
        ``default`` if unset; otherwise whether the value, ignoring surrounding
        whitespace, is exactly ``"1"``.
    """
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip() == "1"
|
33
|
+
|
34
|
+
|
35
|
+
class LoggerMixin(BaseTask[Config], Generic[Config]):
    """Task mixin that owns a :class:`Logger` and registers default backends."""

    logger: Logger

    def __init__(self, config: Config) -> None:
        super().__init__(config)

        self.logger = Logger()

    def log_directory(self) -> Path | None:
        """Hook for a log directory; this base implementation returns ``None``."""
        return None

    def add_logger(self, *logger: LoggerImpl) -> None:
        """Registers one or more logger backends with ``self.logger``."""
        self.logger.add_logger(*logger)

    def set_loggers(self) -> None:
        """Adds the default backends.

        Uses a stdout logger in interactive sessions and a JSON logger
        otherwise. Tasks that also use ``ArtifactsMixin`` additionally get
        state and TensorBoard loggers writing under ``self.exp_dir``.
        """
        self.add_logger(StdoutLogger() if is_interactive_session() else JsonLogger())
        if isinstance(self, ArtifactsMixin):
            self.add_logger(
                StateLogger(self.exp_dir),
                TensorboardLogger(self.exp_dir),
            )

    def write_logs(self, state: State) -> None:
        """Flushes the accumulated log values for ``state`` to every backend."""
        self.logger.write(state)

    def __enter__(self) -> Self:
        self.logger.__enter__()
        return self

    def __exit__(self, t: type[BaseException] | None, e: BaseException | None, tr: TracebackType | None) -> None:
        # Always run the parent's cleanup, even if closing the loggers raises;
        # an exception from the logger still propagates after that cleanup.
        try:
            self.logger.__exit__(t, e, tr)
        finally:
            result = super().__exit__(t, e, tr)
        return result
|
@@ -0,0 +1,51 @@
|
|
1
|
+
"""Defines a base trainer mixin for handling subprocess monitoring jobs."""
|
2
|
+
|
3
|
+
import logging
|
4
|
+
import multiprocessing as mp
|
5
|
+
from dataclasses import dataclass
|
6
|
+
from multiprocessing.context import BaseContext
|
7
|
+
from multiprocessing.managers import SyncManager
|
8
|
+
from typing import Generic, TypeVar
|
9
|
+
|
10
|
+
from xax.core.conf import field
|
11
|
+
from xax.core.state import State
|
12
|
+
from xax.task.base import BaseConfig, BaseTask
|
13
|
+
|
14
|
+
logger: logging.Logger = logging.getLogger(__name__)


@dataclass
class ProcessConfig(BaseConfig):
    """Config for tasks that use ``ProcessMixin``."""

    # Passed to `multiprocessing.get_context`; None selects the default context.
    multiprocessing_context: str | None = field(None, help="The multiprocessing context to use")


# Type variable for configs accepted by ProcessMixin.
Config = TypeVar("Config", bound=ProcessConfig)
|
23
|
+
|
24
|
+
|
25
|
+
class ProcessMixin(BaseTask[Config], Generic[Config]):
    """Defines a base trainer mixin for handling monitoring processes.

    At construction time it creates a multiprocessing context (chosen by
    ``config.multiprocessing_context``) and a ``SyncManager`` from that
    context. The manager is shut down when training ends.
    """

    _mp_ctx: BaseContext
    _mp_manager: SyncManager

    def __init__(self, config: Config) -> None:
        super().__init__(config)

        ctx = mp.get_context(config.multiprocessing_context)
        self._mp_ctx = ctx
        self._mp_manager = ctx.Manager()

    @property
    def multiprocessing_context(self) -> BaseContext:
        """The multiprocessing context the manager was created from."""
        return self._mp_ctx

    @property
    def multiprocessing_manager(self) -> SyncManager:
        """The shared ``SyncManager`` used for cross-process state."""
        return self._mp_manager

    def on_training_end(self, state: State) -> State:
        next_state = super().on_training_end(state)

        manager = self._mp_manager
        manager.shutdown()
        manager.join()

        return next_state
|
@@ -0,0 +1,63 @@
|
|
1
|
+
"""Defines a mixin which provides a "run" method."""
|
2
|
+
|
3
|
+
import signal
|
4
|
+
from abc import ABC, abstractmethod
|
5
|
+
from dataclasses import dataclass
|
6
|
+
from types import FrameType
|
7
|
+
from typing import Callable, TypeVar
|
8
|
+
|
9
|
+
from xax.task.base import BaseConfig, BaseTask, RawConfigType
|
10
|
+
from xax.task.launchers.base import BaseLauncher
|
11
|
+
|
12
|
+
|
13
|
+
@dataclass
class RunnableConfig(BaseConfig):
    """Config for tasks that use ``RunnableMixin``; adds no fields of its own."""

    pass


# Type variable for configs accepted by RunnableMixin.
Config = TypeVar("Config", bound=RunnableConfig)
|
19
|
+
|
20
|
+
|
21
|
+
class RunnableMixin(BaseTask[Config], ABC):
    """Mixin which provides a "run" method.

    It also keeps a registry of per-signal callbacks. The first time a signal
    gets a callback, ``call_signal_handler`` is installed as that signal's
    process-level handler, and it dispatches to every registered callback.
    """

    _signal_handlers: dict[signal.Signals, list[Callable[[], None]]]
    _set_signal_handlers: set[signal.Signals]

    def __init__(self, config: Config) -> None:
        super().__init__(config)

        self._signal_handlers = {}
        self._set_signal_handlers = set()

    @abstractmethod
    def run(self) -> None:
        """Runs the task."""

    @classmethod
    def launch(
        cls,
        *cfgs: RawConfigType,
        launcher: BaseLauncher | None = None,
        use_cli: bool | list[str] = True,
    ) -> None:
        """Launches this task class with ``launcher`` (a ``CliLauncher`` when omitted)."""
        if launcher is None:
            # Imported lazily, presumably to avoid an import cycle at module load.
            from xax.task.launchers.cli import CliLauncher

            launcher = CliLauncher()
        launcher.launch(cls, *cfgs, use_cli=use_cli)

    def call_signal_handler(self, sig: int | signal.Signals, frame: FrameType | None = None) -> None:
        """Invokes, in registration order, every callback registered for ``sig``."""
        resolved = signal.Signals(sig) if isinstance(sig, int) else sig
        for callback in self._signal_handlers.get(resolved, []):
            callback()

    def add_signal_handler(self, handler: Callable[[], None], *sigs: signal.Signals) -> None:
        """Registers ``handler`` for each signal in ``sigs``.

        The dispatcher is installed with ``signal.signal`` only once per signal.
        """
        for sig in sigs:
            registered = self._signal_handlers.setdefault(sig, [])
            if sig not in self._set_signal_handlers:
                self._set_signal_handlers.add(sig)
                signal.signal(sig, self.call_signal_handler)
            registered.append(handler)
|
@@ -0,0 +1,63 @@
|
|
1
|
+
"""Defines a mixin to wrap some steps in a context manager."""
|
2
|
+
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from types import TracebackType
|
5
|
+
from typing import ContextManager, Literal, TypeVar
|
6
|
+
|
7
|
+
from xax.task.base import BaseConfig, BaseTask
|
8
|
+
|
9
|
+
StepType = Literal[
|
10
|
+
"backward",
|
11
|
+
"change_mode",
|
12
|
+
"clip_grads",
|
13
|
+
"create_optimizers",
|
14
|
+
"forward",
|
15
|
+
"get_dataloader",
|
16
|
+
"get_dataset",
|
17
|
+
"get_prefetcher",
|
18
|
+
"get_model",
|
19
|
+
"get_optimizer",
|
20
|
+
"get_initial_opt_state",
|
21
|
+
"get_update_fn",
|
22
|
+
"load_checkpoint",
|
23
|
+
"log_losses",
|
24
|
+
"model_to_device",
|
25
|
+
"on_step_end",
|
26
|
+
"on_step_start",
|
27
|
+
"save_checkpoint",
|
28
|
+
"step",
|
29
|
+
"update_state",
|
30
|
+
"write_logs",
|
31
|
+
"zero_grads",
|
32
|
+
]
|
33
|
+
|
34
|
+
|
35
|
+
class StepContext(ContextManager):
|
36
|
+
"""Context manager to get the current step type."""
|
37
|
+
|
38
|
+
CURRENT_STEP: StepType | None = None
|
39
|
+
|
40
|
+
def __init__(self, step: StepType) -> None:
|
41
|
+
self.step = step
|
42
|
+
|
43
|
+
def __enter__(self) -> None:
|
44
|
+
StepContext.CURRENT_STEP = self.step
|
45
|
+
|
46
|
+
def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None:
|
47
|
+
StepContext.CURRENT_STEP = None
|
48
|
+
|
49
|
+
|
50
|
+
@dataclass
class StepContextConfig(BaseConfig):
    """Config for tasks that use ``StepContextMixin``; adds no fields of its own."""

    pass


# Type variable for configs accepted by StepContextMixin.
Config = TypeVar("Config", bound=StepContextConfig)
|
56
|
+
|
57
|
+
|
58
|
+
class StepContextMixin(BaseTask[Config]):
    """Task mixin that exposes ``StepContext`` through ``step_context``."""

    def __init__(self, config: Config) -> None:
        super().__init__(config)

    def step_context(self, step: StepType) -> ContextManager:
        """Returns a ``StepContext`` that sets ``StepContext.CURRENT_STEP`` to ``step`` while active."""
        context = StepContext(step)
        return context
|