xax-0.0.1-py3-none-any.whl → xax-0.0.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. xax/__init__.py +256 -1
  2. xax/core/conf.py +193 -0
  3. xax/core/state.py +81 -0
  4. xax/nn/__init__.py +0 -0
  5. xax/nn/embeddings.py +355 -0
  6. xax/nn/functions.py +77 -0
  7. xax/nn/parallel.py +211 -0
  8. xax/requirements-dev.txt +15 -0
  9. xax/requirements.txt +23 -0
  10. xax/task/__init__.py +0 -0
  11. xax/task/base.py +207 -0
  12. xax/task/launchers/__init__.py +0 -0
  13. xax/task/launchers/base.py +28 -0
  14. xax/task/launchers/cli.py +42 -0
  15. xax/task/launchers/single_process.py +30 -0
  16. xax/task/launchers/staged.py +29 -0
  17. xax/task/logger.py +783 -0
  18. xax/task/loggers/__init__.py +0 -0
  19. xax/task/loggers/callback.py +56 -0
  20. xax/task/loggers/json.py +121 -0
  21. xax/task/loggers/state.py +45 -0
  22. xax/task/loggers/stdout.py +170 -0
  23. xax/task/loggers/tensorboard.py +223 -0
  24. xax/task/mixins/__init__.py +12 -0
  25. xax/task/mixins/artifacts.py +114 -0
  26. xax/task/mixins/checkpointing.py +209 -0
  27. xax/task/mixins/cpu_stats.py +251 -0
  28. xax/task/mixins/data_loader.py +149 -0
  29. xax/task/mixins/gpu_stats.py +257 -0
  30. xax/task/mixins/logger.py +66 -0
  31. xax/task/mixins/process.py +51 -0
  32. xax/task/mixins/runnable.py +63 -0
  33. xax/task/mixins/step_wrapper.py +63 -0
  34. xax/task/mixins/train.py +541 -0
  35. xax/task/script.py +53 -0
  36. xax/task/task.py +65 -0
  37. xax/utils/__init__.py +0 -0
  38. xax/utils/data/__init__.py +0 -0
  39. xax/utils/data/collate.py +206 -0
  40. xax/utils/experiments.py +802 -0
  41. xax/utils/jax.py +14 -0
  42. xax/utils/logging.py +223 -0
  43. xax/utils/numpy.py +47 -0
  44. xax/utils/tensorboard.py +258 -0
  45. xax/utils/text.py +350 -0
  46. xax-0.0.5.dist-info/METADATA +40 -0
  47. xax-0.0.5.dist-info/RECORD +52 -0
  48. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
  49. xax-0.0.5.dist-info/top_level.txt +1 -0
  50. examples/mnist.py +0 -148
  51. xax-0.0.1.dist-info/METADATA +0 -21
  52. xax-0.0.1.dist-info/RECORD +0 -9
  53. xax-0.0.1.dist-info/top_level.txt +0 -2
  54. {examples → xax/core}/__init__.py +0 -0
  55. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
xax/task/mixins/data_loader.py
@@ -0,0 +1,149 @@
+ """Defines a mixin for instantiating dataloaders."""
+
+ import logging
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from typing import Generic, TypeVar
+
+ import jax
+ from dpshdl.dataloader import CollatedDataloaderItem, Dataloader
+ from dpshdl.dataset import Dataset, ErrorHandlingDataset
+ from dpshdl.prefetcher import Prefetcher
+ from omegaconf import II, MISSING
+
+ from xax.core.conf import field, is_missing
+ from xax.core.state import Phase
+ from xax.nn.functions import recursive_apply, set_random_seed
+ from xax.task.base import BaseConfig, BaseTask
+ from xax.task.mixins.process import ProcessConfig, ProcessMixin
+ from xax.utils.logging import LOG_ERROR_SUMMARY, configure_logging
+
+ logger = logging.getLogger(__name__)
+
+ T = TypeVar("T")
+ Tc_co = TypeVar("Tc_co", covariant=True)
+
+
+ @dataclass
+ class DataloaderErrorConfig:
+     sleep_backoff: float = field(0.1, help="The initial sleep time after an exception")
+     sleep_backoff_power: float = field(2.0, help="Power to raise the sleep time by after each consecutive exception")
+     maximum_exceptions: int = field(10, help="The maximum number of consecutive exceptions before raising an error")
+     backoff_after: int = field(5, help="The number of consecutive exceptions before starting to backoff")
+     traceback_depth: int = field(5, help="The depth of the traceback to print when an exception occurs")
+     flush_every_n_steps: int | None = field(None, help="Flush the error summary after this many steps")
+     flush_every_n_seconds: float | None = field(10.0, help="Flush the error summary after this many seconds")
+     log_exceptions_all_workers: bool = field(False, help="If set, log exceptions from all workers")
+
+
+ @dataclass
+ class DataloaderConfig:
+     num_workers: int | None = field(MISSING, help="Number of workers for loading samples")
+     prefetch_factor: int = field(2, help="Number of items to pre-fetch on each worker")
+     error: DataloaderErrorConfig = field(DataloaderErrorConfig(), help="Dataloader error configuration")
+
+
+ @dataclass
+ class DataloadersConfig(ProcessConfig, BaseConfig):
+     batch_size: int = field(MISSING, help="Size of each batch")
+     raise_dataloader_errors: bool = field(False, help="If set, raise dataloader errors inside the worker processes")
+     train_dl: DataloaderConfig = field(
+         DataloaderConfig(num_workers=II("mlfab.num_workers:-1")),
+         help="Train dataloader config",
+     )
+     valid_dl: DataloaderConfig = field(
+         DataloaderConfig(num_workers=1),
+         help="Valid dataloader config",
+     )
+     debug_dataloader: bool = field(False, help="Debug dataloaders")
+
+
+ Config = TypeVar("Config", bound=DataloadersConfig)
+
+
+ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config], ABC):
+     def __init__(self, config: Config) -> None:
+         if is_missing(config, "batch_size"):
+             config.batch_size = self.get_batch_size()
+
+         super().__init__(config)
+
+     def get_batch_size(self) -> int:
+         raise NotImplementedError(
+             "When `batch_size` is not specified in your training config, you should override the `get_batch_size` "
+             "method to return the desired training batch size."
+         )
+
+     def dataloader_config(self, phase: Phase) -> DataloaderConfig:
+         match phase:
+             case "train":
+                 return self.config.train_dl
+             case "valid":
+                 return self.config.valid_dl
+             case _:
+                 raise KeyError(f"Unknown phase: {phase}")
+
+     @abstractmethod
+     def get_dataset(self, phase: Phase) -> Dataset:
+         """Returns the dataset for the given phase.
+
+         Args:
+             phase: The phase for the dataset to return.
+
+         Returns:
+             The dataset for the given phase.
+         """
+
+     def get_dataloader(self, dataset: Dataset[T, Tc_co], phase: Phase) -> Dataloader[T, Tc_co]:
+         debugging = self.config.debug_dataloader
+         if debugging:
+             logger.warning("Parallel dataloaders disabled in debugging mode")
+
+         cfg = self.dataloader_config(phase)
+
+         # Wraps the dataset to handle errors.
+         dataset = ErrorHandlingDataset(
+             dataset=dataset,
+             sleep_backoff=cfg.error.sleep_backoff,
+             sleep_backoff_power=cfg.error.sleep_backoff_power,
+             maximum_exceptions=cfg.error.maximum_exceptions,
+             backoff_after=cfg.error.backoff_after,
+             traceback_depth=cfg.error.traceback_depth,
+             flush_every_n_steps=cfg.error.flush_every_n_steps,
+             flush_every_n_seconds=cfg.error.flush_every_n_seconds,
+             log_exceptions_all_workers=cfg.error.log_exceptions_all_workers,
+             log_level=LOG_ERROR_SUMMARY,
+         )
+
+         return Dataloader(
+             dataset=dataset,
+             batch_size=self.config.batch_size,
+             num_workers=0 if debugging else cfg.num_workers,
+             prefetch_factor=cfg.prefetch_factor,
+             mp_manager=self.multiprocessing_manager,
+             dataloader_worker_init_fn=self.dataloader_worker_init_fn,
+             collate_worker_init_fn=self.collate_worker_init_fn,
+             item_callback=self.dataloader_item_callback,
+             raise_errs=self.config.raise_dataloader_errors,
+         )
+
+     def get_prefetcher(self, dataloader: Dataloader[T, Tc_co]) -> Prefetcher[Tc_co, Tc_co]:
+         return Prefetcher(to_device_func=self.to_device_fn, dataloader=dataloader)
+
+     @classmethod
+     def to_device_fn(cls, sample: T) -> T:
+         return recursive_apply(sample, jax.device_put, include_numpy=True)
+
+     @classmethod
+     def dataloader_worker_init_fn(cls, worker_id: int, num_workers: int) -> None:
+         configure_logging(prefix=f"{worker_id}")
+         set_random_seed(offset=worker_id + 1)
+
+     @classmethod
+     def collate_worker_init_fn(cls) -> None:
+         configure_logging(prefix="collate")
+         set_random_seed(offset=-1)
+
+     @classmethod
+     def dataloader_item_callback(cls, item: CollatedDataloaderItem) -> None:
+         pass
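To make the contract above concrete, here is a minimal sketch of a task that fills in the two hooks `DataloadersMixin` looks for. `MyConfig`, `MyTask`, and the batch size are illustrative, not part of the package, and the dataset body is stubbed out since it depends on your own `dpshdl.Dataset`:

```python
from dataclasses import dataclass

from dpshdl.dataset import Dataset

from xax.core.state import Phase
from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin


@dataclass
class MyConfig(DataloadersConfig):
    pass


class MyTask(DataloadersMixin[MyConfig]):
    def get_batch_size(self) -> int:
        # Consulted by __init__ only when `batch_size` is MISSING in the config.
        return 32

    def get_dataset(self, phase: Phase) -> Dataset:
        # Build and return a dpshdl Dataset; `phase` is "train" or "valid".
        raise NotImplementedError
```

Whatever `get_dataset` returns gets wrapped in `ErrorHandlingDataset`, so worker exceptions are retried with exponential backoff (per the `DataloaderErrorConfig` settings) instead of killing the run.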
xax/task/mixins/gpu_stats.py
@@ -0,0 +1,257 @@
+ """A task mixin for logging GPU statistics.
+
+ This logs GPU memory and utilization in a background process using
+ ``nvidia-smi``, if a GPU is available in the system.
+ """
+
+ import functools
+ import logging
+ import multiprocessing as mp
+ import os
+ import re
+ import shutil
+ import subprocess
+ from ctypes import Structure, c_double, c_uint32
+ from dataclasses import dataclass
+ from multiprocessing.managers import SyncManager, ValueProxy
+ from multiprocessing.synchronize import Event
+ from typing import Generic, Iterable, Pattern, TypeVar
+
+ from xax.core.conf import field
+ from xax.core.state import State
+ from xax.task.mixins.logger import LoggerConfig, LoggerMixin
+ from xax.task.mixins.process import ProcessConfig, ProcessMixin
+
+ logger: logging.Logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class GPUStatsOptions:
+     ping_interval: int = field(10, help="How often to check stats (in seconds)")
+     only_log_once: bool = field(False, help="If set, only log read stats one time")
+
+
+ @dataclass
+ class GPUStatsConfig(ProcessConfig, LoggerConfig):
+     gpu_stats: GPUStatsOptions = field(GPUStatsOptions(), help="GPU stats configuration")
+
+
+ Config = TypeVar("Config", bound=GPUStatsConfig)
+
+ NUMBER_REGEX: Pattern[str] = re.compile(r"[\d\.]+")
+
+
+ class GPUStats(Structure):
+     _fields_ = [
+         ("index", c_uint32),
+         ("memory_used", c_double),
+         ("temperature", c_double),
+         ("utilization", c_double),
+     ]
+
+
+ @dataclass(frozen=True)
+ class GPUStatsInfo:
+     index: int
+     memory_used: float
+     temperature: float
+     utilization: float
+
+     @classmethod
+     def from_stats(cls, stats: GPUStats) -> "GPUStatsInfo":
+         return cls(
+             index=stats.index,
+             memory_used=stats.memory_used,
+             temperature=stats.temperature,
+             utilization=stats.utilization,
+         )
+
+
+ @functools.lru_cache(maxsize=None)
+ def get_num_gpus() -> int:
+     command = "nvidia-smi --query-gpu=index --format=csv --format=csv,noheader"
+
+     try:
+         with subprocess.Popen(command.split(), stdout=subprocess.PIPE, universal_newlines=True) as proc:
+             stdout = proc.stdout
+             assert stdout is not None
+             rows = iter(stdout.readline, "")
+             return len(list(rows))
+
+     except Exception:
+         logger.exception("Caught exception while trying to query `nvidia-smi`")
+         return 0
+
+
+ def parse_number(s: str) -> float:
+     match = NUMBER_REGEX.search(s)
+     if match is None:
+         raise ValueError(s)
+     return float(match.group())
+
+
+ def parse_gpu_stats(row: str) -> GPUStats:
+     cols = row.split(",")
+     index = int(cols[0].strip())
+     memory_total, memory_used, temperature, utilization = (parse_number(col) for col in cols[1:])
+
+     return GPUStats(
+         index=index,
+         memory_used=100 * memory_used / memory_total,
+         temperature=temperature,
+         utilization=utilization,
+     )
+
+
+ def gen_gpu_stats(loop_secs: int = 5) -> Iterable[GPUStats]:
+     fields = ",".join(["index", "memory.total", "memory.used", "temperature.gpu", "utilization.gpu"])
+     command = f"nvidia-smi --query-gpu={fields} --format=csv,noheader --loop={loop_secs}"
+     visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
+     visible_device_ids = None if visible_devices is None else {int(i.strip()) for i in visible_devices.split(",")}
+
+     try:
+         with subprocess.Popen(command.split(), stdout=subprocess.PIPE, universal_newlines=True) as proc:
+             stdout = proc.stdout
+             assert stdout is not None
+             rows = iter(stdout.readline, "")
+             for row in rows:
+                 try:
+                     stats = parse_gpu_stats(row)
+                 except ValueError:
+                     continue
+                 if visible_device_ids is None or stats.index in visible_device_ids:
+                     yield stats
+
+     except BaseException:
+         logger.error("Closing GPU stats monitor")
+
+
+ def worker(
+     ping_interval: int,
+     smems: list[ValueProxy[GPUStats]],
+     main_event: Event,
+     events: list[Event],
+     start_event: Event,
+ ) -> None:
+     start_event.set()
+
+     logger.debug("Starting GPU stats monitor with PID %d", os.getpid())
+
+     for gpu_stat in gen_gpu_stats(ping_interval):
+         if gpu_stat.index >= len(smems):
+             logger.warning("GPU index %d is out of range", gpu_stat.index)
+             continue
+         smems[gpu_stat.index].set(gpu_stat)
+         events[gpu_stat.index].set()
+         main_event.set()
+
+
+ class GPUStatsMonitor:
+     def __init__(self, ping_interval: float, manager: SyncManager) -> None:
+         self._ping_interval = ping_interval
+         self._manager = manager
+
+         num_gpus = get_num_gpus()
+         self._main_event = manager.Event()
+         self._events = [manager.Event() for _ in range(num_gpus)]
+         self._start_event = manager.Event()
+
+         self._smems = [
+             manager.Value(
+                 GPUStats,
+                 GPUStats(
+                     index=i,
+                     memory_used=0.0,
+                     temperature=0.0,
+                     utilization=0.0,
+                 ),
+             )
+             for i in range(num_gpus)
+         ]
+         self._gpu_stats: dict[int, GPUStatsInfo] = {}
+         self._proc: mp.Process | None = None
+
+     def get_if_set(self) -> dict[int, GPUStatsInfo]:
+         gpu_stats: dict[int, GPUStatsInfo] = {}
+         if self._main_event.is_set():
+             self._main_event.clear()
+             for i, event in enumerate(self._events):
+                 if event.is_set():
+                     event.clear()
+                     gpu_stats[i] = GPUStatsInfo.from_stats(self._smems[i].get())
+         return gpu_stats
+
+     def get(self) -> dict[int, GPUStatsInfo]:
+         self._gpu_stats.update(self.get_if_set())
+         return self._gpu_stats
+
+     def start(self, wait: bool = False) -> None:
+         if self._proc is not None:
+             raise RuntimeError("GPUStatsMonitor already started")
+         if self._main_event.is_set():
+             self._main_event.clear()
+         for event in self._events:
+             if event.is_set():
+                 event.clear()
+         if self._start_event.is_set():
+             self._start_event.clear()
+         self._gpu_stats.clear()
+         self._proc = mp.Process(
+             target=worker,
+             args=(self._ping_interval, self._smems, self._main_event, self._events, self._start_event),
+             daemon=True,
+             name="xax-gpu-stats",
+         )
+         self._proc.start()
+         if wait:
+             self._start_event.wait()
+
+     def stop(self) -> None:
+         if self._proc is None:
+             raise RuntimeError("GPUStatsMonitor not started")
+         if self._proc.is_alive():
+             self._proc.terminate()
+             logger.debug("Terminated GPU stats monitor; joining...")
+         self._proc.join()
+         self._proc = None
+
+
+ class GPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
+     """Defines a task mixin for getting GPU statistics."""
+
+     _gpu_stats_monitor: GPUStatsMonitor | None
+
+     def __init__(self, config: Config) -> None:
+         super().__init__(config)
+
+         self._gpu_stats_monitor = None
+         if shutil.which("nvidia-smi") is not None:
+             self._gpu_stats_monitor = GPUStatsMonitor(config.gpu_stats.ping_interval, self._mp_manager)
+
+     def on_training_start(self, state: State) -> State:
+         state = super().on_training_start(state)
+         if self._gpu_stats_monitor is not None:
+             self._gpu_stats_monitor.start()
+         return state
+
+     def on_training_end(self, state: State) -> State:
+         state = super().on_training_end(state)
+         if self._gpu_stats_monitor is not None:
+             self._gpu_stats_monitor.stop()
+         return state
+
+     def on_step_start(self, state: State) -> State:
+         state = super().on_step_start(state)
+
+         if (monitor := self._gpu_stats_monitor) is None:
+             return state
+         stats = monitor.get_if_set() if self.config.gpu_stats.only_log_once else monitor.get()
+
+         for gpu_stat in stats.values():
+             if gpu_stat is None:
+                 continue
+             self.logger.log_scalar(f"mem/{gpu_stat.index}", gpu_stat.memory_used, namespace="🔧 gpu")
+             self.logger.log_scalar(f"temp/{gpu_stat.index}", gpu_stat.temperature, namespace="🔧 gpu")
+             self.logger.log_scalar(f"util/{gpu_stat.index}", gpu_stat.utilization, namespace="🔧 gpu")
+
+         return state
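`parse_gpu_stats` is pure string handling, so its behavior is easy to sanity-check in isolation. A small example using a hand-written row in the CSV shape that `gen_gpu_stats` queries (index, memory.total, memory.used, temperature.gpu, utilization.gpu); the numbers are made up:

```python
from xax.task.mixins.gpu_stats import GPUStatsInfo, parse_gpu_stats

# One line of `nvidia-smi --query-gpu=... --format=csv,noheader` output.
row = "0, 24576 MiB, 6144 MiB, 41, 17 %"

info = GPUStatsInfo.from_stats(parse_gpu_stats(row))
print(info.index)        # 0
print(info.memory_used)  # 25.0 (stored as a percentage of total memory)
print(info.temperature)  # 41.0
print(info.utilization)  # 17.0
```

That percentage convention is why `on_step_start` can log `mem/{index}` directly as a scalar alongside temperature and utilization.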
xax/task/mixins/logger.py
@@ -0,0 +1,66 @@
+ """Defines a mixin for incorporating some logging functionality."""
+
+ import os
+ from dataclasses import dataclass
+ from pathlib import Path
+ from types import TracebackType
+ from typing import Generic, Self, TypeVar
+
+ from xax.core.conf import Device as BaseDeviceConfig, field
+ from xax.core.state import State
+ from xax.task.base import BaseConfig, BaseTask
+ from xax.task.logger import Logger, LoggerImpl
+ from xax.task.loggers.json import JsonLogger
+ from xax.task.loggers.state import StateLogger
+ from xax.task.loggers.stdout import StdoutLogger
+ from xax.task.loggers.tensorboard import TensorboardLogger
+ from xax.task.mixins.artifacts import ArtifactsMixin
+ from xax.utils.text import is_interactive_session
+
+
+ @dataclass
+ class LoggerConfig(BaseConfig):
+     device: BaseDeviceConfig = field(BaseDeviceConfig(), help="Device configuration")
+
+
+ Config = TypeVar("Config", bound=LoggerConfig)
+
+
+ def get_env_var(name: str, default: bool) -> bool:
+     if name not in os.environ:
+         return default
+     return os.environ[name].strip() == "1"
+
+
+ class LoggerMixin(BaseTask[Config], Generic[Config]):
+     logger: Logger
+
+     def __init__(self, config: Config) -> None:
+         super().__init__(config)
+
+         self.logger = Logger()
+
+     def log_directory(self) -> Path | None:
+         return None
+
+     def add_logger(self, *logger: LoggerImpl) -> None:
+         self.logger.add_logger(*logger)
+
+     def set_loggers(self) -> None:
+         self.add_logger(StdoutLogger() if is_interactive_session() else JsonLogger())
+         if isinstance(self, ArtifactsMixin):
+             self.add_logger(
+                 StateLogger(self.exp_dir),
+                 TensorboardLogger(self.exp_dir),
+             )
+
+     def write_logs(self, state: State) -> None:
+         self.logger.write(state)
+
+     def __enter__(self) -> Self:
+         self.logger.__enter__()
+         return self
+
+     def __exit__(self, t: type[BaseException] | None, e: BaseException | None, tr: TracebackType | None) -> None:
+         self.logger.__exit__(t, e, tr)
+         return super().__exit__(t, e, tr)
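The default `set_loggers` picks stdout or JSON output depending on whether the session is interactive, and only adds the TensorBoard and state loggers when the task also mixes in `ArtifactsMixin` (which provides `exp_dir`). Customizing the set is just a matter of overriding it; a minimal sketch, with `MyTask` as an illustrative name:

```python
from xax.task.loggers.stdout import StdoutLogger
from xax.task.mixins.logger import LoggerConfig, LoggerMixin


class MyTask(LoggerMixin[LoggerConfig]):
    def set_loggers(self) -> None:
        # Always log to stdout, even in a non-interactive session.
        self.add_logger(StdoutLogger())
```

Note that the mixin is itself a context manager: `__enter__` starts the underlying `Logger`, and `__exit__` tears it down before deferring to the parent class.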
xax/task/mixins/process.py
@@ -0,0 +1,51 @@
+ """Defines a base trainer mixin for handling subprocess monitoring jobs."""
+
+ import logging
+ import multiprocessing as mp
+ from dataclasses import dataclass
+ from multiprocessing.context import BaseContext
+ from multiprocessing.managers import SyncManager
+ from typing import Generic, TypeVar
+
+ from xax.core.conf import field
+ from xax.core.state import State
+ from xax.task.base import BaseConfig, BaseTask
+
+ logger: logging.Logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class ProcessConfig(BaseConfig):
+     multiprocessing_context: str | None = field(None, help="The multiprocessing context to use")
+
+
+ Config = TypeVar("Config", bound=ProcessConfig)
+
+
+ class ProcessMixin(BaseTask[Config], Generic[Config]):
+     """Defines a base trainer mixin for handling monitoring processes."""
+
+     _mp_ctx: BaseContext
+     _mp_manager: SyncManager
+
+     def __init__(self, config: Config) -> None:
+         super().__init__(config)
+
+         self._mp_ctx = mp.get_context(config.multiprocessing_context)
+         self._mp_manager = self._mp_ctx.Manager()
+
+     @property
+     def multiprocessing_context(self) -> BaseContext:
+         return self._mp_ctx
+
+     @property
+     def multiprocessing_manager(self) -> SyncManager:
+         return self._mp_manager
+
+     def on_training_end(self, state: State) -> State:
+         state = super().on_training_end(state)
+
+         self._mp_manager.shutdown()
+         self._mp_manager.join()
+
+         return state
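In stdlib terms, the constructor does the equivalent of the following sketch; passing `None` to `mp.get_context` selects the platform default, and the manager process is what hands out the cross-process proxies (events, shared values) that monitors like `GPUStatsMonitor` receive:

```python
import multiprocessing as mp

# Equivalent of ProcessMixin.__init__ with multiprocessing_context="spawn".
ctx = mp.get_context("spawn")  # or "fork", "forkserver", or None for the default
manager = ctx.Manager()

event = manager.Event()  # the kind of proxy GPUStatsMonitor is handed

# Equivalent of the on_training_end cleanup above.
manager.shutdown()
manager.join()
```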
xax/task/mixins/runnable.py
@@ -0,0 +1,63 @@
+ """Defines a mixin which provides a "run" method."""
+
+ import signal
+ from abc import ABC, abstractmethod
+ from dataclasses import dataclass
+ from types import FrameType
+ from typing import Callable, TypeVar
+
+ from xax.task.base import BaseConfig, BaseTask, RawConfigType
+ from xax.task.launchers.base import BaseLauncher
+
+
+ @dataclass
+ class RunnableConfig(BaseConfig):
+     pass
+
+
+ Config = TypeVar("Config", bound=RunnableConfig)
+
+
+ class RunnableMixin(BaseTask[Config], ABC):
+     """Mixin which provides a "run" method."""
+
+     _signal_handlers: dict[signal.Signals, list[Callable[[], None]]]
+     _set_signal_handlers: set[signal.Signals]
+
+     def __init__(self, config: Config) -> None:
+         super().__init__(config)
+
+         self._signal_handlers = {}
+         self._set_signal_handlers = set()
+
+     @abstractmethod
+     def run(self) -> None:
+         """Runs the task."""
+
+     @classmethod
+     def launch(
+         cls,
+         *cfgs: RawConfigType,
+         launcher: BaseLauncher | None = None,
+         use_cli: bool | list[str] = True,
+     ) -> None:
+         if launcher is None:
+             from xax.task.launchers.cli import CliLauncher
+
+             launcher = CliLauncher()
+         launcher.launch(cls, *cfgs, use_cli=use_cli)
+
+     def call_signal_handler(self, sig: int | signal.Signals, frame: FrameType | None = None) -> None:
+         if isinstance(sig, int):
+             sig = signal.Signals(sig)
+         for signal_handler in self._signal_handlers.get(sig, []):
+             signal_handler()
+
+     def add_signal_handler(self, handler: Callable[[], None], *sigs: signal.Signals) -> None:
+         for sig in sigs:
+             if sig not in self._signal_handlers:
+                 self._signal_handlers[sig] = []
+             if sig not in self._set_signal_handlers:
+                 self._set_signal_handlers.add(sig)
+                 signal.signal(sig, self.call_signal_handler)
+             self._signal_handlers[sig].append(handler)
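A sketch of the intended entry-point pattern, assuming a config dataclass instance is an accepted `RawConfigType`; `MyTask` is an illustrative stand-in for a concrete task:

```python
import signal

from xax.task.mixins.runnable import RunnableConfig, RunnableMixin


class MyTask(RunnableMixin[RunnableConfig]):
    def run(self) -> None:
        # Registered handlers fire when the process receives SIGINT/SIGTERM.
        self.add_signal_handler(lambda: print("cleaning up"), signal.SIGINT, signal.SIGTERM)


if __name__ == "__main__":
    # With no launcher given, CliLauncher is used, parsing overrides from argv.
    MyTask.launch(RunnableConfig())
```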
xax/task/mixins/step_wrapper.py
@@ -0,0 +1,63 @@
+ """Defines a mixin to wrap some steps in a context manager."""
+
+ from dataclasses import dataclass
+ from types import TracebackType
+ from typing import ContextManager, Literal, TypeVar
+
+ from xax.task.base import BaseConfig, BaseTask
+
+ StepType = Literal[
+     "backward",
+     "change_mode",
+     "clip_grads",
+     "create_optimizers",
+     "forward",
+     "get_dataloader",
+     "get_dataset",
+     "get_prefetcher",
+     "get_model",
+     "get_optimizer",
+     "get_initial_opt_state",
+     "get_update_fn",
+     "load_checkpoint",
+     "log_losses",
+     "model_to_device",
+     "on_step_end",
+     "on_step_start",
+     "save_checkpoint",
+     "step",
+     "update_state",
+     "write_logs",
+     "zero_grads",
+ ]
+
+
+ class StepContext(ContextManager):
+     """Context manager to get the current step type."""
+
+     CURRENT_STEP: StepType | None = None
+
+     def __init__(self, step: StepType) -> None:
+         self.step = step
+
+     def __enter__(self) -> None:
+         StepContext.CURRENT_STEP = self.step
+
+     def __exit__(self, _t: type[BaseException] | None, _e: BaseException | None, _tr: TracebackType | None) -> None:
+         StepContext.CURRENT_STEP = None
+
+
+ @dataclass
+ class StepContextConfig(BaseConfig):
+     pass
+
+
+ Config = TypeVar("Config", bound=StepContextConfig)
+
+
+ class StepContextMixin(BaseTask[Config]):
+     def __init__(self, config: Config) -> None:
+         super().__init__(config)
+
+     def step_context(self, step: StepType) -> ContextManager:
+         return StepContext(step)
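Because `StepContext` records the active step in a class attribute, any code running inside the `with` block can query it without threading arguments through; a small self-contained check (`MyTask` is illustrative):

```python
from xax.task.mixins.step_wrapper import StepContext, StepContextConfig, StepContextMixin


class MyTask(StepContextMixin[StepContextConfig]):
    def forward_pass(self) -> None:
        with self.step_context("forward"):
            # Anything called from here can observe the current step type.
            assert StepContext.CURRENT_STEP == "forward"
        assert StepContext.CURRENT_STEP is None  # cleared on exit
```

Since the attribute is process-global and `__exit__` always resets it to `None`, nested step contexts do not restore the outer step type on exit.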