xax 0.0.1-py3-none-any.whl → 0.0.5-py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
- xax/__init__.py +256 -1
- xax/core/conf.py +193 -0
- xax/core/state.py +81 -0
- xax/nn/__init__.py +0 -0
- xax/nn/embeddings.py +355 -0
- xax/nn/functions.py +77 -0
- xax/nn/parallel.py +211 -0
- xax/requirements-dev.txt +15 -0
- xax/requirements.txt +23 -0
- xax/task/__init__.py +0 -0
- xax/task/base.py +207 -0
- xax/task/launchers/__init__.py +0 -0
- xax/task/launchers/base.py +28 -0
- xax/task/launchers/cli.py +42 -0
- xax/task/launchers/single_process.py +30 -0
- xax/task/launchers/staged.py +29 -0
- xax/task/logger.py +783 -0
- xax/task/loggers/__init__.py +0 -0
- xax/task/loggers/callback.py +56 -0
- xax/task/loggers/json.py +121 -0
- xax/task/loggers/state.py +45 -0
- xax/task/loggers/stdout.py +170 -0
- xax/task/loggers/tensorboard.py +223 -0
- xax/task/mixins/__init__.py +12 -0
- xax/task/mixins/artifacts.py +114 -0
- xax/task/mixins/checkpointing.py +209 -0
- xax/task/mixins/cpu_stats.py +251 -0
- xax/task/mixins/data_loader.py +149 -0
- xax/task/mixins/gpu_stats.py +257 -0
- xax/task/mixins/logger.py +66 -0
- xax/task/mixins/process.py +51 -0
- xax/task/mixins/runnable.py +63 -0
- xax/task/mixins/step_wrapper.py +63 -0
- xax/task/mixins/train.py +541 -0
- xax/task/script.py +53 -0
- xax/task/task.py +65 -0
- xax/utils/__init__.py +0 -0
- xax/utils/data/__init__.py +0 -0
- xax/utils/data/collate.py +206 -0
- xax/utils/experiments.py +802 -0
- xax/utils/jax.py +14 -0
- xax/utils/logging.py +223 -0
- xax/utils/numpy.py +47 -0
- xax/utils/tensorboard.py +258 -0
- xax/utils/text.py +350 -0
- xax-0.0.5.dist-info/METADATA +40 -0
- xax-0.0.5.dist-info/RECORD +52 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/WHEEL +1 -1
- xax-0.0.5.dist-info/top_level.txt +1 -0
- examples/mnist.py +0 -148
- xax-0.0.1.dist-info/METADATA +0 -21
- xax-0.0.1.dist-info/RECORD +0 -9
- xax-0.0.1.dist-info/top_level.txt +0 -2
- {examples → xax/core}/__init__.py +0 -0
- {xax-0.0.1.dist-info → xax-0.0.5.dist-info}/LICENSE +0 -0
xax/task/mixins/artifacts.py (added, +114 lines)

```diff
@@ -0,0 +1,114 @@
+"""Defines a mixin for storing any task artifacts."""
+
+import functools
+import inspect
+import logging
+import os
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Self, TypeVar
+
+from xax.core.conf import field, get_run_dir
+from xax.core.state import State
+from xax.nn.parallel import is_master
+from xax.task.base import BaseConfig, BaseTask
+from xax.utils.experiments import stage_environment
+from xax.utils.logging import LOG_STATUS
+from xax.utils.text import show_info
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class ArtifactsConfig(BaseConfig):
+    exp_dir: str | None = field(None, help="The fixed experiment directory")
+
+
+Config = TypeVar("Config", bound=ArtifactsConfig)
+
+
+class ArtifactsMixin(BaseTask[Config]):
+    _exp_dir: Path | None
+
+    def __init__(self, config: Config) -> None:
+        super().__init__(config)
+
+        self._exp_dir = None
+
+    @functools.cached_property
+    def run_dir(self) -> Path:
+        run_dir = get_run_dir()
+        if run_dir is None:
+            task_file = inspect.getfile(self.__class__)
+            run_dir = Path(task_file).resolve().parent
+        return run_dir / self.task_name
+
+    def set_exp_dir(self, exp_dir: Path) -> Self:
+        self._exp_dir = exp_dir
+        return self
+
+    @property
+    def exp_dir(self) -> Path:
+        return self.get_exp_dir()
+
+    def add_lock_file(self, lock_type: str, *, exists_ok: bool = False) -> None:
+        if (lock_file := self.exp_dir / f".lock_{lock_type}").exists():
+            if not exists_ok:
+                raise RuntimeError(f"Lock file already exists at {lock_file}")
+        else:
+            with open(lock_file, "w", encoding="utf-8") as f:
+                f.write(f"PID: {os.getpid()}")
+
+    def remove_lock_file(self, lock_type: str, *, missing_ok: bool = False) -> None:
+        if (lock_file := self.exp_dir / f".lock_{lock_type}").exists():
+            lock_file.unlink()
+        elif not missing_ok:
+            raise RuntimeError(f"Lock file not found at {lock_file}")
+
+    def get_exp_dir(self) -> Path:
+        if self._exp_dir is not None:
+            return self._exp_dir
+
+        if self.config.exp_dir is not None:
+            exp_dir = Path(self.config.exp_dir).expanduser().resolve()
+            exp_dir.mkdir(parents=True, exist_ok=True)
+            self._exp_dir = exp_dir
+            logger.log(LOG_STATUS, self._exp_dir)
+            return self._exp_dir
+
+        def get_exp_dir(run_id: int) -> Path:
+            return self.run_dir / f"run_{run_id}"
+
+        def has_lock_file(exp_dir: Path, lock_type: str | None = None) -> bool:
+            if lock_type is not None:
+                return (exp_dir / f".lock_{lock_type}").exists()
+            return any(exp_dir.glob(".lock_*"))
+
+        run_id = 0
+        while (exp_dir := get_exp_dir(run_id)).is_dir() and has_lock_file(exp_dir):
+            run_id += 1
+        exp_dir.mkdir(exist_ok=True, parents=True)
+        self._exp_dir = exp_dir.expanduser().resolve()
+        logger.log(LOG_STATUS, self._exp_dir)
+        return self._exp_dir
+
+    @functools.lru_cache(maxsize=None)
+    def stage_environment(self) -> Path | None:
+        stage_dir = (self.exp_dir / "code").resolve()
+        try:
+            stage_environment(self, stage_dir)
+        except Exception:
+            logger.exception("Failed to stage environment!")
+            return None
+        return stage_dir
+
+    def on_training_end(self, state: State) -> State:
+        state = super().on_training_end(state)
+
+        if is_master():
+            if self._exp_dir is None:
+                show_info("Exiting training job", important=True)
+            else:
+                show_info(f"Exiting training job for {self.exp_dir}", important=True)
+
+        return state
```
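
The part of `artifacts.py` worth calling out is the fallback in `get_exp_dir`: when no fixed `exp_dir` is configured, it scans `run_0`, `run_1`, … under `run_dir` and reuses the first directory that either does not exist yet or carries no `.lock_*` marker files. A minimal standalone sketch of that allocation logic (the name `pick_next_run_dir` is ours, not an xax API):

```python
from pathlib import Path


def pick_next_run_dir(run_dir: Path) -> Path:
    """Sketch of the ArtifactsMixin.get_exp_dir fallback: pick the first
    run_<N> directory that is absent or holds no .lock_* files."""

    def has_lock_file(exp_dir: Path) -> bool:
        # Any .lock_* file marks the directory as holding live artifacts.
        return any(exp_dir.glob(".lock_*"))

    run_id = 0
    while (exp_dir := run_dir / f"run_{run_id}").is_dir() and has_lock_file(exp_dir):
        run_id += 1
    exp_dir.mkdir(parents=True, exist_ok=True)
    return exp_dir.expanduser().resolve()
```

The lock files come from `add_lock_file`; `save_checkpoint` in the next hunk calls `add_lock_file("ckpt", exists_ok=True)`, so a directory that already holds checkpoints is never recycled.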
xax/task/mixins/checkpointing.py (added, +209 lines)

```diff
@@ -0,0 +1,209 @@
+"""Defines a mixin for handling model checkpointing."""
+
+import io
+import json
+import logging
+import tarfile
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from typing import Any, Callable, Generic, Literal, TypeVar, cast, overload
+
+import cloudpickle
+import optax
+from jaxtyping import PyTree
+from omegaconf import DictConfig, OmegaConf
+
+from xax.core.conf import field
+from xax.core.state import State
+from xax.nn.parallel import is_master
+from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
+
+logger = logging.getLogger(__name__)
+
+CheckpointPart = Literal["model", "opt", "opt_state", "state", "config"]
+
+
+def get_ckpt_path(exp_dir: Path, state: State | None = None) -> Path:
+    """Defines the path to the checkpoint for a given state.
+
+    Args:
+        exp_dir: The experiment directory
+        state: The current trainer state
+
+    Returns:
+        The path to the checkpoint file.
+    """
+    if state is None:
+        return exp_dir / "checkpoints" / "ckpt.bin"
+    return exp_dir / "checkpoints" / f"ckpt.{state.num_steps}.bin"
+
+
+@dataclass
+class CheckpointingConfig(ArtifactsConfig):
+    save_every_n_steps: int | None = field(None, help="Save a checkpoint every N steps")
+    save_every_n_seconds: float | None = field(60.0 * 60.0, help="Save a checkpoint every N seconds")
+    only_save_most_recent: bool = field(True, help="Only keep the most recent checkpoint")
+    load_from_ckpt_path: str | None = field(None, help="If set, load initial model weights from this path")
+    load_ckpt_strict: bool = field(True, help="If set, only load weights which have a matching key in the model")
+
+
+Config = TypeVar("Config", bound=CheckpointingConfig)
+
+
+class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
+    def __init__(self, config: Config) -> None:
+        super().__init__(config)
+
+        self.__last_ckpt_time = 0.0
+
+    def get_ckpt_path(self, state: State | None = None) -> Path:
+        return get_ckpt_path(self.exp_dir, state)
+
+    def get_init_ckpt_path(self) -> Path | None:
+        if self._exp_dir is not None:
+            ckpt_path = self.get_ckpt_path()
+            if ckpt_path.exists():
+                return ckpt_path
+        if self.config.load_from_ckpt_path is not None:
+            ckpt_path = Path(self.config.load_from_ckpt_path)
+            assert ckpt_path.exists(), f"Checkpoint path {ckpt_path} does not exist."
+            return ckpt_path
+        return None
+
+    def should_checkpoint(self, state: State) -> bool:
+        if self.config.save_every_n_steps is not None:
+            if state.num_steps % self.config.save_every_n_steps == 0:
+                return True
+        if self.config.save_every_n_seconds is not None:
+            last_time, cur_time = self.__last_ckpt_time, state.elapsed_time_s
+            if cur_time - last_time >= self.config.save_every_n_seconds:
+                self.__last_ckpt_time = cur_time
+                return True
+        return False
+
+    @overload
+    def load_checkpoint(
+        self,
+        path: Path,
+    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...
+
+    @overload
+    def load_checkpoint(self, path: Path, part: Literal["model"]) -> PyTree: ...
+
+    @overload
+    def load_checkpoint(self, path: Path, part: Literal["opt"]) -> optax.GradientTransformation: ...
+
+    @overload
+    def load_checkpoint(self, path: Path, part: Literal["opt_state"]) -> optax.OptState: ...
+
+    @overload
+    def load_checkpoint(self, path: Path, part: Literal["state"]) -> State: ...
+
+    @overload
+    def load_checkpoint(self, path: Path, part: Literal["config"]) -> DictConfig: ...
+
+    def load_checkpoint(
+        self,
+        path: Path,
+        part: CheckpointPart | None = None,
+    ) -> (
+        tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
+        | PyTree
+        | optax.GradientTransformation
+        | optax.OptState
+        | State
+        | DictConfig
+    ):
+        with tarfile.open(path, "r:gz") as tar:
+
+            def get_model() -> PyTree:
+                if (model := tar.extractfile("model")) is None:
+                    raise ValueError(f"Checkpoint does not contain a model file: {path}")
+                return cloudpickle.load(model)
+
+            def get_opt() -> optax.GradientTransformation:
+                if (opt := tar.extractfile("opt")) is None:
+                    raise ValueError(f"Checkpoint does not contain an opt file: {path}")
+                return cloudpickle.load(opt)
+
+            def get_opt_state() -> optax.OptState:
+                if (opt_state := tar.extractfile("opt_state")) is None:
+                    raise ValueError(f"Checkpoint does not contain an opt_state file: {path}")
+                return cloudpickle.load(opt_state)
+
+            def get_state() -> State:
+                if (state := tar.extractfile("state")) is None:
+                    raise ValueError(f"Checkpoint does not contain a state file: {path}")
+                return State(**json.loads(state.read().decode()))
+
+            def get_config() -> DictConfig:
+                if (config := tar.extractfile("config")) is None:
+                    raise ValueError(f"Checkpoint does not contain a config file: {path}")
+                return cast(DictConfig, OmegaConf.load(config))
+
+            match part:
+                case "model":
+                    return get_model()
+                case "opt":
+                    return get_opt()
+                case "opt_state":
+                    return get_opt_state()
+                case "state":
+                    return get_state()
+                case "config":
+                    return get_config()
+                case None:
+                    return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
+                case _:
+                    raise ValueError(f"Invalid checkpoint part: {part}")
+
+    def save_checkpoint(
+        self,
+        model: PyTree,
+        optimizer: optax.GradientTransformation,
+        opt_state: optax.OptState,
+        state: State,
+    ) -> Path:
+        ckpt_path = self.get_ckpt_path(state)
+
+        if not is_master():
+            return ckpt_path
+
+        # Gets the path to the last checkpoint.
+        logger.info("Saving checkpoint to %s", ckpt_path)
+        last_ckpt_path = self.get_ckpt_path()
+        ckpt_path.parent.mkdir(exist_ok=True, parents=True)
+
+        # Potentially removes the last checkpoint.
+        if last_ckpt_path.exists() and self.config.only_save_most_recent:
+            if (base_ckpt := last_ckpt_path.resolve()).is_file():
+                base_ckpt.unlink()
+
+        # Combines all temporary files into a single checkpoint TAR file.
+        with tarfile.open(ckpt_path, "w:gz") as tar:
+
+            def add_file(name: str, write_fn: Callable[[io.BytesIO], Any]) -> None:
+                with io.BytesIO() as buf:
+                    write_fn(buf)
+                    tarinfo = tarfile.TarInfo(name)
+                    tarinfo.size = buf.tell()
+                    buf.seek(0)
+                    tar.addfile(tarinfo, buf)
+
+            add_file("model", lambda buf: cloudpickle.dump(model, buf))
+            add_file("opt", lambda buf: cloudpickle.dump(optimizer, buf))
+            add_file("opt_state", lambda buf: cloudpickle.dump(opt_state, buf))
+            add_file("state", lambda buf: buf.write(json.dumps(asdict(state), indent=2).encode()))
+            add_file("config", lambda buf: buf.write(OmegaConf.to_yaml(self.config).encode()))
+
+        # Updates the symlink to the new checkpoint.
+        last_ckpt_path.unlink(missing_ok=True)
+        try:
+            last_ckpt_path.symlink_to(ckpt_path.relative_to(last_ckpt_path.parent))
+        except FileExistsError:
+            logger.exception("Exception while trying to update %s", ckpt_path)
+
+        # Marks directory as having artifacts which shouldn't be overwritten.
+        self.add_lock_file("ckpt", exists_ok=True)
+
+        return ckpt_path
```
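
As the hunk above shows, a checkpoint is a gzipped tarball with five members: `model`, `opt`, and `opt_state` are cloudpickle payloads, `state` is JSON, and `config` is OmegaConf YAML. That makes the non-pickle parts cheap to inspect without importing any model code; a hedged sketch using only the stdlib (the helper `peek_checkpoint` is illustrative, not part of xax):

```python
import json
import tarfile
from pathlib import Path


def peek_checkpoint(ckpt_path: Path) -> tuple[list[str], dict]:
    """Lists the tar members and decodes the JSON trainer state, mirroring
    the get_state() closure in load_checkpoint above."""
    with tarfile.open(ckpt_path, "r:gz") as tar:
        members = tar.getnames()  # expected: model, opt, opt_state, state, config
        if (state_file := tar.extractfile("state")) is None:
            raise ValueError(f"Checkpoint does not contain a state file: {ckpt_path}")
        return members, json.loads(state_file.read().decode())


# e.g. members, state = peek_checkpoint(Path("my_exp/checkpoints/ckpt.bin"))
```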
xax/task/mixins/cpu_stats.py (added, +251 lines)

```diff
@@ -0,0 +1,251 @@
+"""A trainer mixin for logging CPU statistics.
+
+This logs memory and CPU utilization in a background process, sending it to
+the logging process every now and then. This is useful for detecting memory
+leaks in your dataloader, among other issues.
+"""
+
+import logging
+import multiprocessing as mp
+import os
+import time
+from ctypes import Structure, c_double, c_uint16, c_uint64
+from dataclasses import dataclass
+from multiprocessing.managers import SyncManager, ValueProxy
+from multiprocessing.synchronize import Event
+from typing import Generic, TypeVar
+
+import psutil
+
+from xax.core.conf import field
+from xax.core.state import State
+from xax.task.base import BaseConfig
+from xax.task.mixins.logger import LoggerConfig, LoggerMixin
+from xax.task.mixins.process import ProcessConfig, ProcessMixin
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+
+@dataclass
+class CPUStatsOptions:
+    ping_interval: int = field(1, help="How often to check stats (in seconds)")
+    only_log_once: bool = field(False, help="If set, only log read stats one time")
+
+
+@dataclass
+class CPUStatsConfig(ProcessConfig, LoggerConfig, BaseConfig):
+    cpu_stats: CPUStatsOptions = field(CPUStatsOptions(), help="CPU stats configuration")
+
+
+Config = TypeVar("Config", bound=CPUStatsConfig)
+
+
+class CPUStats(Structure):
+    _fields_ = [
+        ("cpu_percent", c_double),
+        ("mem_percent", c_double),
+        ("mem_rss", c_uint64),
+        ("mem_vms", c_uint64),
+        ("mem_shared", c_uint64),
+        ("mem_rss_total", c_uint64),
+        ("mem_vms_total", c_uint64),
+        ("child_cpu_percent", c_double),
+        ("child_mem_percent", c_double),
+        ("num_child_procs", c_uint16),
+    ]
+
+
+@dataclass
+class CPUStatsInfo:
+    cpu_percent: float
+    mem_percent: float
+    mem_rss: int
+    mem_vms: int
+    mem_shared: int
+    mem_rss_total: int
+    mem_vms_total: int
+    child_cpu_percent: float
+    child_mem_percent: float
+    num_child_procs: int
+
+    @classmethod
+    def from_stats(cls, stats: CPUStats) -> "CPUStatsInfo":
+        return cls(
+            cpu_percent=stats.cpu_percent,
+            mem_percent=stats.mem_percent,
+            mem_rss=stats.mem_rss,
+            mem_vms=stats.mem_vms,
+            mem_shared=stats.mem_shared,
+            mem_rss_total=stats.mem_rss_total,
+            mem_vms_total=stats.mem_vms_total,
+            child_cpu_percent=stats.child_cpu_percent,
+            child_mem_percent=stats.child_mem_percent,
+            num_child_procs=stats.num_child_procs,
+        )
+
+
+def worker(
+    ping_interval: float,
+    stats: ValueProxy[CPUStats],
+    monitor_event: Event,
+    start_event: Event,
+    pid: int,
+) -> None:
+    start_event.set()
+
+    proc, cur_pid = psutil.Process(pid), os.getpid()
+    logger.debug("Starting CPU stats monitor for PID %d with PID %d", pid, cur_pid)
+
+    def get_children() -> dict[int, psutil.Process]:
+        return {p.pid: p for p in proc.children(recursive=True) if p.pid != cur_pid}
+
+    child_procs = get_children()
+
+    try:
+        while True:
+            # Updates child processes, preserving the previous child process
+            # object. Otherwise the CPU percentage will be zero.
+            new_procs = get_children()
+            child_procs = {**new_procs, **child_procs}
+            child_procs = {pid: child_procs[pid] for pid in new_procs.keys()}
+
+            # Gets process memory info.
+            mem_info = proc.memory_info()
+            mem_rss_total = sum(p.memory_info().rss for p in child_procs.values()) + mem_info.rss
+            mem_vms_total = sum(p.memory_info().vms for p in child_procs.values()) + mem_info.vms
+
+            # Gets child CPU and memory percentages.
+            child_cpu_percent_total = sum(p.cpu_percent() for p in child_procs.values()) if child_procs else 0.0
+            child_mem_percent_total = sum(p.memory_percent() for p in child_procs.values()) if child_procs else 0.0
+
+            # Sets the CPU stats.
+            stats.set(
+                CPUStats(
+                    cpu_percent=proc.cpu_percent(),
+                    mem_percent=proc.memory_percent(),
+                    mem_rss=int(mem_info.rss),
+                    mem_vms=int(mem_info.vms),
+                    mem_shared=int(getattr(mem_info, "shared", 0)),
+                    mem_rss_total=int(mem_rss_total),
+                    mem_vms_total=int(mem_vms_total),
+                    child_cpu_percent=child_cpu_percent_total / len(child_procs),
+                    child_mem_percent=child_mem_percent_total / len(child_procs),
+                    num_child_procs=len(child_procs),
+                ),
+            )
+
+            monitor_event.set()
+            time.sleep(ping_interval)
+
+    except BaseException:
+        logger.error("Closing CPU stats monitor")
+
+
+class CPUStatsMonitor:
+    def __init__(self, ping_interval: float, manager: SyncManager) -> None:
+        self._ping_interval = ping_interval
+        self._manager = manager
+        self._monitor_event = self._manager.Event()
+        self._start_event = self._manager.Event()
+        self._cpu_stats_smem = self._manager.Value(
+            CPUStats,
+            CPUStats(
+                cpu_percent=0.0,
+                mem_percent=0.0,
+                mem_rss=0,
+                mem_vms=0,
+                mem_shared=0,
+                mem_rss_total=0,
+                mem_vms_total=0,
+                child_cpu_percent=0.0,
+                child_mem_percent=0.0,
+                num_child_procs=0,
+            ),
+        )
+        self._cpu_stats: CPUStatsInfo | None = None
+        self._proc: mp.Process | None = None
+
+    def get_if_set(self) -> CPUStatsInfo | None:
+        if self._monitor_event.is_set():
+            self._monitor_event.clear()
+            return CPUStatsInfo.from_stats(self._cpu_stats_smem.get())
+        return None
+
+    def get(self) -> CPUStatsInfo | None:
+        if (stats := self.get_if_set()) is not None:
+            self._cpu_stats = stats
+        return self._cpu_stats
+
+    def start(self, wait: bool = False) -> None:
+        if self._proc is not None:
+            raise RuntimeError("CPU stats monitor already started")
+        if self._monitor_event.is_set():
+            self._monitor_event.clear()
+        if self._start_event.is_set():
+            self._start_event.clear()
+        self._cpu_stats = None
+        self._proc = mp.Process(
+            target=worker,
+            args=(self._ping_interval, self._cpu_stats_smem, self._monitor_event, self._start_event, os.getpid()),
+            daemon=True,
+            name="xax-cpu-stats",
+        )
+        self._proc.start()
+        if wait:
+            self._start_event.wait()
+
+    def stop(self) -> None:
+        if self._proc is None:
+            raise RuntimeError("CPU stats monitor not started")
+        if self._proc.is_alive():
+            self._proc.terminate()
+            logger.debug("Terminated CPU stats monitor; joining...")
+            self._proc.join()
+        self._proc = None
+        self._cpu_stats = None
+
+
+class CPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
+    """Defines a task mixin for getting CPU statistics."""
+
+    _cpu_stats_monitor: CPUStatsMonitor
+
+    def __init__(self, config: Config) -> None:
+        super().__init__(config)
+
+        self._cpu_stats_monitor = CPUStatsMonitor(
+            ping_interval=self.config.cpu_stats.ping_interval,
+            manager=self._mp_manager,
+        )
+
+    def on_training_start(self, state: State) -> State:
+        state = super().on_training_start(state)
+
+        self._cpu_stats_monitor.start()
+        return state
+
+    def on_training_end(self, state: State) -> State:
+        state = super().on_training_end(state)
+
+        self._cpu_stats_monitor.stop()
+        return state
+
+    def on_step_start(self, state: State) -> State:
+        state = super().on_step_start(state)
+
+        monitor = self._cpu_stats_monitor
+        stats = monitor.get_if_set() if self.config.cpu_stats.only_log_once else monitor.get()
+
+        if stats is not None:
+            self.logger.log_scalar("child_procs", stats.num_child_procs, namespace="🔧 cpu")
+            self.logger.log_scalar("percent", stats.cpu_percent, namespace="🔧 cpu")
+            self.logger.log_scalar("child_percent", stats.child_cpu_percent, namespace="🔧 cpu")
+            self.logger.log_scalar("percent", stats.mem_percent, namespace="🔧 mem")
+            self.logger.log_scalar("shared", stats.mem_shared, namespace="🔧 mem")
+            self.logger.log_scalar("child_percent", stats.child_mem_percent, namespace="🔧 mem")
+            self.logger.log_scalar("rss/cur", stats.mem_rss, namespace="🔧 mem")
+            self.logger.log_scalar("rss/total", stats.mem_rss_total, namespace="🔧 mem")
+            self.logger.log_scalar("vms/cur", stats.mem_vms, namespace="🔧 mem")
+            self.logger.log_scalar("vms/total", stats.mem_vms_total, namespace="🔧 mem")
+
+        return state
```
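
`CPUStatsMonitor` does not depend on the mixin and can be driven by hand, provided you supply the `SyncManager` that `ProcessMixin` would normally contribute via `self._mp_manager`. A rough standalone sketch, assuming the module import path shown in the file list above:

```python
import multiprocessing as mp
import time

from xax.task.mixins.cpu_stats import CPUStatsMonitor

if __name__ == "__main__":
    with mp.Manager() as manager:  # stand-in for ProcessMixin's _mp_manager
        monitor = CPUStatsMonitor(ping_interval=1.0, manager=manager)
        monitor.start(wait=True)  # block until the worker process is up
        try:
            for _ in range(5):
                time.sleep(1.0)
                if (stats := monitor.get()) is not None:
                    print(f"cpu={stats.cpu_percent:.1f}% rss={stats.mem_rss / 2**20:.0f} MiB")
        finally:
            monitor.stop()
```

Inside a task none of this is needed: `CPUStatsMixin.on_step_start` polls the monitor each step and logs under the `🔧 cpu` and `🔧 mem` namespaces shown above.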