xax-0.0.1-py3-none-any.whl → xax-0.0.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (55)
  1. xax/__init__.py +256 -1
  2. xax/core/conf.py +193 -0
  3. xax/core/state.py +81 -0
  4. xax/nn/__init__.py +0 -0
  5. xax/nn/embeddings.py +355 -0
  6. xax/nn/functions.py +77 -0
  7. xax/nn/parallel.py +211 -0
  8. xax/requirements-dev.txt +15 -0
  9. xax/requirements.txt +23 -0
  10. xax/task/__init__.py +0 -0
  11. xax/task/base.py +207 -0
  12. xax/task/launchers/__init__.py +0 -0
  13. xax/task/launchers/base.py +28 -0
  14. xax/task/launchers/cli.py +42 -0
  15. xax/task/launchers/single_process.py +30 -0
  16. xax/task/launchers/staged.py +29 -0
  17. xax/task/logger.py +783 -0
  18. xax/task/loggers/__init__.py +0 -0
  19. xax/task/loggers/callback.py +56 -0
  20. xax/task/loggers/json.py +121 -0
  21. xax/task/loggers/state.py +45 -0
  22. xax/task/loggers/stdout.py +170 -0
  23. xax/task/loggers/tensorboard.py +223 -0
  24. xax/task/mixins/__init__.py +12 -0
  25. xax/task/mixins/artifacts.py +114 -0
  26. xax/task/mixins/checkpointing.py +209 -0
  27. xax/task/mixins/cpu_stats.py +251 -0
  28. xax/task/mixins/data_loader.py +149 -0
  29. xax/task/mixins/gpu_stats.py +257 -0
  30. xax/task/mixins/logger.py +66 -0
  31. xax/task/mixins/process.py +51 -0
  32. xax/task/mixins/runnable.py +63 -0
  33. xax/task/mixins/step_wrapper.py +63 -0
  34. xax/task/mixins/train.py +541 -0
  35. xax/task/script.py +53 -0
  36. xax/task/task.py +65 -0
  37. xax/utils/__init__.py +0 -0
  38. xax/utils/data/__init__.py +0 -0
  39. xax/utils/data/collate.py +206 -0
  40. xax/utils/experiments.py +802 -0
  41. xax/utils/jax.py +14 -0
  42. xax/utils/logging.py +223 -0
  43. xax/utils/numpy.py +47 -0
  44. xax/utils/tensorboard.py +258 -0
  45. xax/utils/text.py +350 -0
  46. xax-0.0.5.dist-info/METADATA +40 -0
  47. xax-0.0.5.dist-info/RECORD +52 -0
  48. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
  49. xax-0.0.5.dist-info/top_level.txt +1 -0
  50. examples/mnist.py +0 -148
  51. xax-0.0.1.dist-info/METADATA +0 -21
  52. xax-0.0.1.dist-info/RECORD +0 -9
  53. xax-0.0.1.dist-info/top_level.txt +0 -2
  54. {examples → xax/core}/__init__.py +0 -0
  55. {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
xax/task/mixins/artifacts.py
@@ -0,0 +1,114 @@
+ """Defines a mixin for storing any task artifacts."""
+
+ import functools
+ import inspect
+ import logging
+ import os
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Self, TypeVar
+
+ from xax.core.conf import field, get_run_dir
+ from xax.core.state import State
+ from xax.nn.parallel import is_master
+ from xax.task.base import BaseConfig, BaseTask
+ from xax.utils.experiments import stage_environment
+ from xax.utils.logging import LOG_STATUS
+ from xax.utils.text import show_info
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class ArtifactsConfig(BaseConfig):
+     exp_dir: str | None = field(None, help="The fixed experiment directory")
+
+
+ Config = TypeVar("Config", bound=ArtifactsConfig)
+
+
+ class ArtifactsMixin(BaseTask[Config]):
+     _exp_dir: Path | None
+
+     def __init__(self, config: Config) -> None:
+         super().__init__(config)
+
+         self._exp_dir = None
+
+     @functools.cached_property
+     def run_dir(self) -> Path:
+         run_dir = get_run_dir()
+         if run_dir is None:
+             task_file = inspect.getfile(self.__class__)
+             run_dir = Path(task_file).resolve().parent
+         return run_dir / self.task_name
+
+     def set_exp_dir(self, exp_dir: Path) -> Self:
+         self._exp_dir = exp_dir
+         return self
+
+     @property
+     def exp_dir(self) -> Path:
+         return self.get_exp_dir()
+
+     def add_lock_file(self, lock_type: str, *, exists_ok: bool = False) -> None:
+         if (lock_file := self.exp_dir / f".lock_{lock_type}").exists():
+             if not exists_ok:
+                 raise RuntimeError(f"Lock file already exists at {lock_file}")
+         else:
+             with open(lock_file, "w", encoding="utf-8") as f:
+                 f.write(f"PID: {os.getpid()}")
+
+     def remove_lock_file(self, lock_type: str, *, missing_ok: bool = False) -> None:
+         if (lock_file := self.exp_dir / f".lock_{lock_type}").exists():
+             lock_file.unlink()
+         elif not missing_ok:
+             raise RuntimeError(f"Lock file not found at {lock_file}")
+
+     def get_exp_dir(self) -> Path:
+         if self._exp_dir is not None:
+             return self._exp_dir
+
+         if self.config.exp_dir is not None:
+             exp_dir = Path(self.config.exp_dir).expanduser().resolve()
+             exp_dir.mkdir(parents=True, exist_ok=True)
+             self._exp_dir = exp_dir
+             logger.log(LOG_STATUS, self._exp_dir)
+             return self._exp_dir
+
+         def get_exp_dir(run_id: int) -> Path:
+             return self.run_dir / f"run_{run_id}"
+
+         def has_lock_file(exp_dir: Path, lock_type: str | None = None) -> bool:
+             if lock_type is not None:
+                 return (exp_dir / f".lock_{lock_type}").exists()
+             return any(exp_dir.glob(".lock_*"))
+
+         run_id = 0
+         while (exp_dir := get_exp_dir(run_id)).is_dir() and has_lock_file(exp_dir):
+             run_id += 1
+         exp_dir.mkdir(exist_ok=True, parents=True)
+         self._exp_dir = exp_dir.expanduser().resolve()
+         logger.log(LOG_STATUS, self._exp_dir)
+         return self._exp_dir
+
+     @functools.lru_cache(maxsize=None)
+     def stage_environment(self) -> Path | None:
+         stage_dir = (self.exp_dir / "code").resolve()
+         try:
+             stage_environment(self, stage_dir)
+         except Exception:
+             logger.exception("Failed to stage environment!")
+             return None
+         return stage_dir
+
+     def on_training_end(self, state: State) -> State:
+         state = super().on_training_end(state)
+
+         if is_master():
+             if self._exp_dir is None:
+                 show_info("Exiting training job", important=True)
+             else:
+                 show_info(f"Exiting training job for {self.exp_dir}", important=True)
+
+         return state
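For orientation, here is a small usage sketch of the mixin above. It is not part of the diff: the DemoTask name is hypothetical, and it assumes ArtifactsConfig can be instantiated directly with just exp_dir (with a fixed exp_dir, the run_dir/task_name machinery is never touched).

# Hypothetical sketch, not from the package: exercises the fixed-exp_dir path
# of ArtifactsMixin.get_exp_dir and the lock-file helpers.
from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin


class DemoTask(ArtifactsMixin[ArtifactsConfig]):
    pass


task = DemoTask(ArtifactsConfig(exp_dir="~/runs/demo"))
print(task.exp_dir)            # ~/runs/demo, expanded, resolved, and created
task.add_lock_file("ckpt")     # writes ~/runs/demo/.lock_ckpt with the current PID
task.remove_lock_file("ckpt")  # removes it again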
xax/task/mixins/checkpointing.py
@@ -0,0 +1,209 @@
+ """Defines a mixin for handling model checkpointing."""
+
+ import io
+ import json
+ import logging
+ import tarfile
+ from dataclasses import asdict, dataclass
+ from pathlib import Path
+ from typing import Any, Callable, Generic, Literal, TypeVar, cast, overload
+
+ import cloudpickle
+ import optax
+ from jaxtyping import PyTree
+ from omegaconf import DictConfig, OmegaConf
+
+ from xax.core.conf import field
+ from xax.core.state import State
+ from xax.nn.parallel import is_master
+ from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
+
+ logger = logging.getLogger(__name__)
+
+ CheckpointPart = Literal["model", "opt", "opt_state", "state", "config"]
+
+
+ def get_ckpt_path(exp_dir: Path, state: State | None = None) -> Path:
+     """Defines the path to the checkpoint for a given state.
+
+     Args:
+         exp_dir: The experiment directory
+         state: The current trainer state
+
+     Returns:
+         The path to the checkpoint file.
+     """
+     if state is None:
+         return exp_dir / "checkpoints" / "ckpt.bin"
+     return exp_dir / "checkpoints" / f"ckpt.{state.num_steps}.bin"
+
+
+ @dataclass
+ class CheckpointingConfig(ArtifactsConfig):
+     save_every_n_steps: int | None = field(None, help="Save a checkpoint every N steps")
+     save_every_n_seconds: float | None = field(60.0 * 60.0, help="Save a checkpoint every N seconds")
+     only_save_most_recent: bool = field(True, help="Only keep the most recent checkpoint")
+     load_from_ckpt_path: str | None = field(None, help="If set, load initial model weights from this path")
+     load_ckpt_strict: bool = field(True, help="If set, only load weights which have a matching key in the model")
+
+
+ Config = TypeVar("Config", bound=CheckpointingConfig)
+
+
+ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
+     def __init__(self, config: Config) -> None:
+         super().__init__(config)
+
+         self.__last_ckpt_time = 0.0
+
+     def get_ckpt_path(self, state: State | None = None) -> Path:
+         return get_ckpt_path(self.exp_dir, state)
+
+     def get_init_ckpt_path(self) -> Path | None:
+         if self._exp_dir is not None:
+             ckpt_path = self.get_ckpt_path()
+             if ckpt_path.exists():
+                 return ckpt_path
+         if self.config.load_from_ckpt_path is not None:
+             ckpt_path = Path(self.config.load_from_ckpt_path)
+             assert ckpt_path.exists(), f"Checkpoint path {ckpt_path} does not exist."
+             return ckpt_path
+         return None
+
+     def should_checkpoint(self, state: State) -> bool:
+         if self.config.save_every_n_steps is not None:
+             if state.num_steps % self.config.save_every_n_steps == 0:
+                 return True
+         if self.config.save_every_n_seconds is not None:
+             last_time, cur_time = self.__last_ckpt_time, state.elapsed_time_s
+             if cur_time - last_time >= self.config.save_every_n_seconds:
+                 self.__last_ckpt_time = cur_time
+                 return True
+         return False
+
+     @overload
+     def load_checkpoint(
+         self,
+         path: Path,
+     ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...
+
+     @overload
+     def load_checkpoint(self, path: Path, part: Literal["model"]) -> PyTree: ...
+
+     @overload
+     def load_checkpoint(self, path: Path, part: Literal["opt"]) -> optax.GradientTransformation: ...
+
+     @overload
+     def load_checkpoint(self, path: Path, part: Literal["opt_state"]) -> optax.OptState: ...
+
+     @overload
+     def load_checkpoint(self, path: Path, part: Literal["state"]) -> State: ...
+
+     @overload
+     def load_checkpoint(self, path: Path, part: Literal["config"]) -> DictConfig: ...
+
+     def load_checkpoint(
+         self,
+         path: Path,
+         part: CheckpointPart | None = None,
+     ) -> (
+         tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
+         | PyTree
+         | optax.GradientTransformation
+         | optax.OptState
+         | State
+         | DictConfig
+     ):
+         with tarfile.open(path, "r:gz") as tar:
+
+             def get_model() -> PyTree:
+                 if (model := tar.extractfile("model")) is None:
+                     raise ValueError(f"Checkpoint does not contain a model file: {path}")
+                 return cloudpickle.load(model)
+
+             def get_opt() -> optax.GradientTransformation:
+                 if (opt := tar.extractfile("opt")) is None:
+                     raise ValueError(f"Checkpoint does not contain an opt file: {path}")
+                 return cloudpickle.load(opt)
+
+             def get_opt_state() -> optax.OptState:
+                 if (opt_state := tar.extractfile("opt_state")) is None:
+                     raise ValueError(f"Checkpoint does not contain an opt_state file: {path}")
+                 return cloudpickle.load(opt_state)
+
+             def get_state() -> State:
+                 if (state := tar.extractfile("state")) is None:
+                     raise ValueError(f"Checkpoint does not contain a state file: {path}")
+                 return State(**json.loads(state.read().decode()))
+
+             def get_config() -> DictConfig:
+                 if (config := tar.extractfile("config")) is None:
+                     raise ValueError(f"Checkpoint does not contain a config file: {path}")
+                 return cast(DictConfig, OmegaConf.load(config))
+
+             match part:
+                 case "model":
+                     return get_model()
+                 case "opt":
+                     return get_opt()
+                 case "opt_state":
+                     return get_opt_state()
+                 case "state":
+                     return get_state()
+                 case "config":
+                     return get_config()
+                 case None:
+                     return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
+                 case _:
+                     raise ValueError(f"Invalid checkpoint part: {part}")
+
+     def save_checkpoint(
+         self,
+         model: PyTree,
+         optimizer: optax.GradientTransformation,
+         opt_state: optax.OptState,
+         state: State,
+     ) -> Path:
+         ckpt_path = self.get_ckpt_path(state)
+
+         if not is_master():
+             return ckpt_path
+
+         # Gets the path to the last checkpoint.
+         logger.info("Saving checkpoint to %s", ckpt_path)
+         last_ckpt_path = self.get_ckpt_path()
+         ckpt_path.parent.mkdir(exist_ok=True, parents=True)
+
+         # Potentially removes the last checkpoint.
+         if last_ckpt_path.exists() and self.config.only_save_most_recent:
+             if (base_ckpt := last_ckpt_path.resolve()).is_file():
+                 base_ckpt.unlink()
+
+         # Combines all temporary files into a single checkpoint TAR file.
+         with tarfile.open(ckpt_path, "w:gz") as tar:
+
+             def add_file(name: str, write_fn: Callable[[io.BytesIO], Any]) -> None:
+                 with io.BytesIO() as buf:
+                     write_fn(buf)
+                     tarinfo = tarfile.TarInfo(name)
+                     tarinfo.size = buf.tell()
+                     buf.seek(0)
+                     tar.addfile(tarinfo, buf)
+
+             add_file("model", lambda buf: cloudpickle.dump(model, buf))
+             add_file("opt", lambda buf: cloudpickle.dump(optimizer, buf))
+             add_file("opt_state", lambda buf: cloudpickle.dump(opt_state, buf))
+             add_file("state", lambda buf: buf.write(json.dumps(asdict(state), indent=2).encode()))
+             add_file("config", lambda buf: buf.write(OmegaConf.to_yaml(self.config).encode()))
+
+         # Updates the symlink to the new checkpoint.
+         last_ckpt_path.unlink(missing_ok=True)
+         try:
+             last_ckpt_path.symlink_to(ckpt_path.relative_to(last_ckpt_path.parent))
+         except FileExistsError:
+             logger.exception("Exception while trying to update %s", ckpt_path)
+
+         # Marks directory as having artifacts which shouldn't be overwritten.
+         self.add_lock_file("ckpt", exists_ok=True)
+
+         return ckpt_path
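One thing worth noting from save_checkpoint above is the on-disk layout it fixes: a gzipped tarball whose members are model, opt, and opt_state (cloudpickle), state (JSON), and config (YAML). A rough standalone sketch of peeking into such a file, assuming those members are present; the path below is illustrative, not a value from the diff.

# Illustrative reader for the checkpoint tarball written by save_checkpoint.
import json
import tarfile

import cloudpickle
from omegaconf import OmegaConf

with tarfile.open("checkpoints/ckpt.bin", "r:gz") as tar:
    config = OmegaConf.load(tar.extractfile("config"))                 # YAML -> DictConfig
    raw_state = json.loads(tar.extractfile("state").read().decode())   # plain dict, not a State
    model = cloudpickle.load(tar.extractfile("model"))                 # the pickled model PyTree

print(raw_state)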
xax/task/mixins/cpu_stats.py
@@ -0,0 +1,251 @@
+ """A trainer mixin for logging CPU statistics.
+
+ This logs memory and CPU utilization in a background process, sending it to
+ the logging process every now and then. This is useful for detecting memory
+ leaks in your dataloader, among other issues.
+ """
+
+ import logging
+ import multiprocessing as mp
+ import os
+ import time
+ from ctypes import Structure, c_double, c_uint16, c_uint64
+ from dataclasses import dataclass
+ from multiprocessing.managers import SyncManager, ValueProxy
+ from multiprocessing.synchronize import Event
+ from typing import Generic, TypeVar
+
+ import psutil
+
+ from xax.core.conf import field
+ from xax.core.state import State
+ from xax.task.base import BaseConfig
+ from xax.task.mixins.logger import LoggerConfig, LoggerMixin
+ from xax.task.mixins.process import ProcessConfig, ProcessMixin
+
+ logger: logging.Logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class CPUStatsOptions:
+     ping_interval: int = field(1, help="How often to check stats (in seconds)")
+     only_log_once: bool = field(False, help="If set, only log read stats one time")
+
+
+ @dataclass
+ class CPUStatsConfig(ProcessConfig, LoggerConfig, BaseConfig):
+     cpu_stats: CPUStatsOptions = field(CPUStatsOptions(), help="CPU stats configuration")
+
+
+ Config = TypeVar("Config", bound=CPUStatsConfig)
+
+
+ class CPUStats(Structure):
+     _fields_ = [
+         ("cpu_percent", c_double),
+         ("mem_percent", c_double),
+         ("mem_rss", c_uint64),
+         ("mem_vms", c_uint64),
+         ("mem_shared", c_uint64),
+         ("mem_rss_total", c_uint64),
+         ("mem_vms_total", c_uint64),
+         ("child_cpu_percent", c_double),
+         ("child_mem_percent", c_double),
+         ("num_child_procs", c_uint16),
+     ]
+
+
+ @dataclass
+ class CPUStatsInfo:
+     cpu_percent: float
+     mem_percent: float
+     mem_rss: int
+     mem_vms: int
+     mem_shared: int
+     mem_rss_total: int
+     mem_vms_total: int
+     child_cpu_percent: float
+     child_mem_percent: float
+     num_child_procs: int
+
+     @classmethod
+     def from_stats(cls, stats: CPUStats) -> "CPUStatsInfo":
+         return cls(
+             cpu_percent=stats.cpu_percent,
+             mem_percent=stats.mem_percent,
+             mem_rss=stats.mem_rss,
+             mem_vms=stats.mem_vms,
+             mem_shared=stats.mem_shared,
+             mem_rss_total=stats.mem_rss_total,
+             mem_vms_total=stats.mem_vms_total,
+             child_cpu_percent=stats.child_cpu_percent,
+             child_mem_percent=stats.child_mem_percent,
+             num_child_procs=stats.num_child_procs,
+         )
+
+
+ def worker(
+     ping_interval: float,
+     stats: ValueProxy[CPUStats],
+     monitor_event: Event,
+     start_event: Event,
+     pid: int,
+ ) -> None:
+     start_event.set()
+
+     proc, cur_pid = psutil.Process(pid), os.getpid()
+     logger.debug("Starting CPU stats monitor for PID %d with PID %d", pid, cur_pid)
+
+     def get_children() -> dict[int, psutil.Process]:
+         return {p.pid: p for p in proc.children(recursive=True) if p.pid != cur_pid}
+
+     child_procs = get_children()
+
+     try:
+         while True:
+             # Updates child processes, preserving the previous child process
+             # object. Otherwise the CPU percentage will be zero.
+             new_procs = get_children()
+             child_procs = {**new_procs, **child_procs}
+             child_procs = {pid: child_procs[pid] for pid in new_procs.keys()}
+
+             # Gets process memory info.
+             mem_info = proc.memory_info()
+             mem_rss_total = sum(p.memory_info().rss for p in child_procs.values()) + mem_info.rss
+             mem_vms_total = sum(p.memory_info().vms for p in child_procs.values()) + mem_info.vms
+
+             # Gets child CPU and memory percentages.
+             child_cpu_percent_total = sum(p.cpu_percent() for p in child_procs.values()) if child_procs else 0.0
+             child_mem_percent_total = sum(p.memory_percent() for p in child_procs.values()) if child_procs else 0.0
+
+             # Sets the CPU stats.
+             stats.set(
+                 CPUStats(
+                     cpu_percent=proc.cpu_percent(),
+                     mem_percent=proc.memory_percent(),
+                     mem_rss=int(mem_info.rss),
+                     mem_vms=int(mem_info.vms),
+                     mem_shared=int(getattr(mem_info, "shared", 0)),
+                     mem_rss_total=int(mem_rss_total),
+                     mem_vms_total=int(mem_vms_total),
+                     child_cpu_percent=child_cpu_percent_total / len(child_procs) if child_procs else 0.0,
+                     child_mem_percent=child_mem_percent_total / len(child_procs) if child_procs else 0.0,
+                     num_child_procs=len(child_procs),
+                 ),
+             )
+
+             monitor_event.set()
+             time.sleep(ping_interval)
+
+     except BaseException:
+         logger.error("Closing CPU stats monitor")
+
+
+ class CPUStatsMonitor:
+     def __init__(self, ping_interval: float, manager: SyncManager) -> None:
+         self._ping_interval = ping_interval
+         self._manager = manager
+         self._monitor_event = self._manager.Event()
+         self._start_event = self._manager.Event()
+         self._cpu_stats_smem = self._manager.Value(
+             CPUStats,
+             CPUStats(
+                 cpu_percent=0.0,
+                 mem_percent=0.0,
+                 mem_rss=0,
+                 mem_vms=0,
+                 mem_shared=0,
+                 mem_rss_total=0,
+                 mem_vms_total=0,
+                 child_cpu_percent=0.0,
+                 child_mem_percent=0.0,
+                 num_child_procs=0,
+             ),
+         )
+         self._cpu_stats: CPUStatsInfo | None = None
+         self._proc: mp.Process | None = None
+
+     def get_if_set(self) -> CPUStatsInfo | None:
+         if self._monitor_event.is_set():
+             self._monitor_event.clear()
+             return CPUStatsInfo.from_stats(self._cpu_stats_smem.get())
+         return None
+
+     def get(self) -> CPUStatsInfo | None:
+         if (stats := self.get_if_set()) is not None:
+             self._cpu_stats = stats
+         return self._cpu_stats
+
+     def start(self, wait: bool = False) -> None:
+         if self._proc is not None:
+             raise RuntimeError("CPU stats monitor already started")
+         if self._monitor_event.is_set():
+             self._monitor_event.clear()
+         if self._start_event.is_set():
+             self._start_event.clear()
+         self._cpu_stats = None
+         self._proc = mp.Process(
+             target=worker,
+             args=(self._ping_interval, self._cpu_stats_smem, self._monitor_event, self._start_event, os.getpid()),
+             daemon=True,
+             name="xax-cpu-stats",
+         )
+         self._proc.start()
+         if wait:
+             self._start_event.wait()
+
+     def stop(self) -> None:
+         if self._proc is None:
+             raise RuntimeError("CPU stats monitor not started")
+         if self._proc.is_alive():
+             self._proc.terminate()
+             logger.debug("Terminated CPU stats monitor; joining...")
+             self._proc.join()
+         self._proc = None
+         self._cpu_stats = None
+
+
+ class CPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
+     """Defines a task mixin for getting CPU statistics."""
+
+     _cpu_stats_monitor: CPUStatsMonitor
+
+     def __init__(self, config: Config) -> None:
+         super().__init__(config)
+
+         self._cpu_stats_monitor = CPUStatsMonitor(
+             ping_interval=self.config.cpu_stats.ping_interval,
+             manager=self._mp_manager,
+         )
+
+     def on_training_start(self, state: State) -> State:
+         state = super().on_training_start(state)
+
+         self._cpu_stats_monitor.start()
+         return state
+
+     def on_training_end(self, state: State) -> State:
+         state = super().on_training_end(state)
+
+         self._cpu_stats_monitor.stop()
+         return state
+
+     def on_step_start(self, state: State) -> State:
+         state = super().on_step_start(state)
+
+         monitor = self._cpu_stats_monitor
+         stats = monitor.get_if_set() if self.config.cpu_stats.only_log_once else monitor.get()
+
+         if stats is not None:
+             self.logger.log_scalar("child_procs", stats.num_child_procs, namespace="🔧 cpu")
+             self.logger.log_scalar("percent", stats.cpu_percent, namespace="🔧 cpu")
+             self.logger.log_scalar("child_percent", stats.child_cpu_percent, namespace="🔧 cpu")
+             self.logger.log_scalar("percent", stats.mem_percent, namespace="🔧 mem")
+             self.logger.log_scalar("shared", stats.mem_shared, namespace="🔧 mem")
+             self.logger.log_scalar("child_percent", stats.child_mem_percent, namespace="🔧 mem")
+             self.logger.log_scalar("rss/cur", stats.mem_rss, namespace="🔧 mem")
+             self.logger.log_scalar("rss/total", stats.mem_rss_total, namespace="🔧 mem")
+             self.logger.log_scalar("vms/cur", stats.mem_vms, namespace="🔧 mem")
+             self.logger.log_scalar("vms/total", stats.mem_vms_total, namespace="🔧 mem")
+
+         return state
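As a closing aside, CPUStatsMonitor only needs a ping interval and a multiprocessing manager, so it can in principle be driven outside the mixin. A minimal sketch, assuming it is launched from a plain script; the loop and sleep are illustrative only and not part of the package.

# Hypothetical standalone driver for CPUStatsMonitor.
import multiprocessing as mp
import time

from xax.task.mixins.cpu_stats import CPUStatsMonitor

if __name__ == "__main__":
    with mp.Manager() as manager:
        monitor = CPUStatsMonitor(ping_interval=1.0, manager=manager)
        monitor.start(wait=True)  # blocks until the worker process has started
        try:
            for _ in range(5):
                time.sleep(1.0)
                stats = monitor.get()  # latest CPUStatsInfo, or None before the first sample
                if stats is not None:
                    print(f"cpu={stats.cpu_percent:.1f}% rss={stats.mem_rss}")
        finally:
            monitor.stop()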