xax 0.0.3__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +122 -8
- xax/core/conf.py +9 -33
- xax/core/state.py +13 -23
- xax/nn/embeddings.py +355 -0
- xax/nn/functions.py +8 -4
- xax/requirements-dev.txt +9 -1
- xax/requirements.txt +17 -10
- xax/task/base.py +2 -6
- xax/task/logger.py +419 -412
- xax/task/loggers/callback.py +44 -0
- xax/task/loggers/state.py +5 -18
- xax/task/loggers/tensorboard.py +16 -33
- xax/task/mixins/__init__.py +3 -1
- xax/task/mixins/artifacts.py +19 -9
- xax/task/mixins/checkpointing.py +221 -0
- xax/task/mixins/compile.py +104 -0
- xax/task/mixins/cpu_stats.py +26 -15
- xax/task/mixins/data_loader.py +27 -19
- xax/task/mixins/gpu_stats.py +22 -8
- xax/task/mixins/logger.py +5 -251
- xax/task/mixins/process.py +8 -1
- xax/task/mixins/runnable.py +3 -0
- xax/task/mixins/step_wrapper.py +5 -0
- xax/task/mixins/train.py +236 -145
- xax/task/script.py +1 -1
- xax/task/task.py +13 -5
- xax/utils/data/collate.py +6 -6
- xax/utils/experiments.py +45 -1
- xax/utils/logging.py +29 -0
- xax/utils/tensorboard.py +89 -21
- xax-0.0.6.dist-info/METADATA +50 -0
- xax-0.0.6.dist-info/RECORD +52 -0
- {xax-0.0.3.dist-info → xax-0.0.6.dist-info}/WHEEL +1 -1
- xax/task/launchers/staged.py +0 -29
- xax-0.0.3.dist-info/METADATA +0 -39
- xax-0.0.3.dist-info/RECORD +0 -49
- {xax-0.0.3.dist-info → xax-0.0.6.dist-info}/LICENSE +0 -0
- {xax-0.0.3.dist-info → xax-0.0.6.dist-info}/top_level.txt +0 -0
xax/task/loggers/callback.py
ADDED
@@ -0,0 +1,44 @@
+"""Defines a logger that calls a callback function with the log line."""
+
+from typing import Callable
+
+from xax.task.logger import LogError, LogErrorSummary, LoggerImpl, LogLine, LogPing, LogStatus
+
+
+class CallbackLogger(LoggerImpl):
+    def __init__(
+        self,
+        *,
+        callback: Callable[[LogLine], None] = lambda x: None,
+        error_summary_callback: Callable[[LogErrorSummary], None] = lambda x: None,
+        error_callback: Callable[[LogError], None] = lambda x: None,
+        status_callback: Callable[[LogStatus], None] = lambda x: None,
+        ping_callback: Callable[[LogPing], None] = lambda x: None,
+        file_callback: Callable[[str, str], None] = lambda x, y: None,
+    ) -> None:
+        super().__init__()
+
+        self.callback = callback
+        self.error_summary_callback = error_summary_callback
+        self.error_callback = error_callback
+        self.status_callback = status_callback
+        self.ping_callback = ping_callback
+        self.file_callback = file_callback
+
+    def write(self, line: LogLine) -> None:
+        self.callback(line)
+
+    def write_error_summary(self, error_summary: LogErrorSummary) -> None:
+        self.error_summary_callback(error_summary)
+
+    def write_error(self, error: LogError) -> None:
+        self.error_callback(error)
+
+    def write_status(self, status: LogStatus) -> None:
+        self.status_callback(status)
+
+    def write_ping(self, ping: LogPing) -> None:
+        self.ping_callback(ping)
+
+    def log_file(self, name: str, contents: str) -> None:
+        self.file_callback(name, contents)
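The new CallbackLogger routes every log event to a user-supplied callable, which makes it convenient for tests and custom sinks. A minimal usage sketch (the callbacks and values shown here are illustrative, not part of the package):

    from xax.task.loggers.callback import CallbackLogger

    # Capture logged files in memory rather than writing them anywhere.
    captured: dict[str, str] = {}

    cb_logger = CallbackLogger(
        callback=lambda line: print("got line at step", line.state.num_steps),
        file_callback=lambda name, contents: captured.__setitem__(name, contents),
    )

    # Normally the task's LoggerMixin drives these calls; invoking log_file
    # directly just demonstrates the dispatch.
    cb_logger.log_file("config.yaml", "batch_size: 32\n")
    assert captured["config.yaml"] == "batch_size: 32\n"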
xax/task/loggers/state.py
CHANGED
@@ -3,8 +3,6 @@
 from pathlib import Path
 from typing import Literal
 
-from omegaconf import DictConfig, OmegaConf
-
 from xax.task.logger import LoggerImpl, LogLine
 
 
@@ -12,9 +10,6 @@ class StateLogger(LoggerImpl):
     def __init__(
         self,
         run_directory: str | Path,
-        git_state_name: str = "git_state.txt",
-        train_code_name: str = "train_code.py",
-        config_name: str = "config.yaml",
         flush_immediately: bool = False,
         open_mode: Literal["w", "a"] = "w",
         line_sep: str = "\n",
@@ -22,24 +17,16 @@ class StateLogger(LoggerImpl):
     ) -> None:
         super().__init__(float("inf"))
 
-        self.git_state_file = Path(run_directory).expanduser().resolve() / git_state_name
-        self.train_code_file = Path(run_directory).expanduser().resolve() / train_code_name
-        self.config_file = Path(run_directory).expanduser().resolve() / config_name
+        self.run_directory = Path(run_directory).expanduser().resolve()
+
         self.flush_immediately = flush_immediately
         self.open_mode = open_mode
         self.line_sep = line_sep
         self.remove_unicode_from_namespaces = remove_unicode_from_namespaces
 
-    def log_git_state(self, git_state: str) -> None:
-        with open(self.git_state_file, "w") as f:
-            f.write(git_state)
-
-    def log_training_code(self, training_code: str) -> None:
-        with open(self.train_code_file, "w") as f:
-            f.write(training_code)
-
-    def log_config(self, config: DictConfig) -> None:
-        OmegaConf.save(config, self.config_file)
+    def log_file(self, name: str, contents: str) -> None:
+        with open(self.run_directory / name, "w") as f:
+            f.write(contents)
 
     def write(self, line: LogLine) -> None:
         pass
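The three special-purpose methods (log_git_state, log_training_code, log_config) collapse into one generic log_file, so StateLogger no longer depends on OmegaConf and every artifact lands under run_directory by name. A sketch of the new call pattern (paths and contents are illustrative):

    from pathlib import Path

    from xax.task.loggers.state import StateLogger

    run_dir = Path("runs/example")
    run_dir.mkdir(parents=True, exist_ok=True)
    state_logger = StateLogger(run_directory=run_dir)

    # One generic entry point instead of log_git_state / log_training_code / log_config:
    state_logger.log_file("git_state.txt", "commit abc123\n")
    state_logger.log_file("config.yaml", "batch_size: 32\n")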
xax/task/loggers/tensorboard.py
CHANGED
@@ -12,10 +12,6 @@ import time
 from pathlib import Path
 from typing import TypeVar
 
-import jax
-import PIL.Image
-from omegaconf import DictConfig, OmegaConf
-
 from xax.core.state import Phase
 from xax.nn.parallel import is_master
 from xax.task.logger import LoggerImpl, LogLine
@@ -62,10 +58,7 @@ class TensorboardLogger(LoggerImpl):
 
         self.proc: subprocess.Popen | None = None
 
-        self.git_state: str | None = None
-        self.training_code: str | None = None
-        self.config: DictConfig | None = None
-
+        self.files: dict[str, str] = {}
         self.writers = TensorboardWriters(log_directory=self.log_directory, flush_seconds=flush_seconds)
         self._started = False
 
@@ -84,7 +77,7 @@ class TensorboardLogger(LoggerImpl):
         port = int(os.environ.get("TENSORBOARD_PORT", DEFAULT_TENSORBOARD_PORT))
 
         while port_is_busy(port):
-            logger.warning(…
+            logger.warning("Port %s is busy, waiting...", port)
             time.sleep(10)
 
         def make_localhost(s: str) -> str:
@@ -160,20 +153,10 @@ class TensorboardLogger(LoggerImpl):
         self._start()
         return self.writers.writer(phase)
 
-    def log_git_state(self, git_state: str) -> None:
-        if not is_master():
-            return
-        self.git_state = f"```\n{git_state}\n```"
-
-    def log_training_code(self, training_code: str) -> None:
+    def log_file(self, name: str, contents: str) -> None:
         if not is_master():
             return
-        self.training_code = f"```\n{training_code}\n```"
-
-    def log_config(self, config: DictConfig) -> None:
-        if not is_master():
-            return
-        self.config = config
+        self.files[name] = f"```\n{contents}\n```"
 
     def write(self, line: LogLine) -> None:
         if not is_master():
@@ -205,22 +188,22 @@ class TensorboardLogger(LoggerImpl):
 
         for namespace, images in line.images.items():
             for image_key, image_value in images.items():
-                image = PIL.Image.fromarray(jax.device_get(image_value.pixels))
                 writer.add_image(
                     f"{namespace}/{image_key}",
-                    image,
+                    image_value.image,
                     global_step=line.state.num_steps,
                     walltime=walltime,
                 )
 
-        [seven lines removed here; content not captured]
+        for namespace, videos in line.videos.items():
+            for video_key, video_value in videos.items():
+                writer.add_video(
+                    f"{namespace}/{video_key}",
+                    video_value.frames,
+                    fps=video_value.fps,
+                    global_step=line.state.num_steps,
+                )
 
-        [line removed here; content not captured]
-        writer.add_text(…
-        [line removed here; content not captured]
+        for name, contents in self.files.items():
+            writer.add_text(name, contents)
+        self.files.clear()
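TensorboardLogger now follows the same log_file protocol: files are buffered in a dict, wrapped in Markdown fences so TensorBoard renders them verbatim, and flushed as text exactly once on the next write. A standalone sketch of that buffering scheme (a simplified illustration, not the package code itself):

    from typing import Callable

    # Files accumulate between writes, then are emitted once and cleared.
    files: dict[str, str] = {}

    def log_file(name: str, contents: str) -> None:
        files[name] = f"```\n{contents}\n```"

    def flush(add_text: Callable[[str, str], None]) -> None:
        for name, contents in files.items():
            add_text(name, contents)
        files.clear()

    log_file("config.yaml", "lr: 3e-4")
    flush(lambda name, text: print(name, text))
    assert not files  # emitted once, then cleared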
xax/task/mixins/__init__.py
CHANGED
@@ -1,6 +1,8 @@
 """Defines a single interface for all the mixins."""
 
 from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
+from xax.task.mixins.checkpointing import CheckpointingConfig, CheckpointingMixin
+from xax.task.mixins.compile import CompileConfig, CompileMixin
 from xax.task.mixins.cpu_stats import CPUStatsConfig, CPUStatsMixin
 from xax.task.mixins.data_loader import DataloadersConfig, DataloadersMixin
 from xax.task.mixins.gpu_stats import GPUStatsConfig, GPUStatsMixin
@@ -8,4 +10,4 @@ from xax.task.mixins.logger import LoggerConfig, LoggerMixin
 from xax.task.mixins.process import ProcessConfig, ProcessMixin
 from xax.task.mixins.runnable import RunnableConfig, RunnableMixin
 from xax.task.mixins.step_wrapper import StepContextConfig, StepContextMixin
-from xax.task.mixins.train import …
+from xax.task.mixins.train import TrainConfig, TrainMixin
xax/task/mixins/artifacts.py
CHANGED
@@ -8,6 +8,8 @@ from dataclasses import dataclass
 from pathlib import Path
 from typing import Self, TypeVar
 
+import jax
+
 from xax.core.conf import field, get_run_dir
 from xax.core.state import State
 from xax.nn.parallel import is_master
@@ -19,6 +21,7 @@ from xax.utils.text import show_info
 logger = logging.getLogger(__name__)
 
 
+@jax.tree_util.register_dataclass
 @dataclass
 class ArtifactsConfig(BaseConfig):
     exp_dir: str | None = field(None, help="The fixed experiment directory")
@@ -43,8 +46,12 @@ class ArtifactsMixin(BaseTask[Config]):
         run_dir = Path(task_file).resolve().parent
         return run_dir / self.task_name
 
-    [two lines removed here; content not captured]
+    @property
+    def exp_dir(self) -> Path:
+        return self.get_exp_dir()
+
+    def set_exp_dir(self, exp_dir: str | Path) -> Self:
+        self._exp_dir = Path(exp_dir).expanduser().resolve()
         return self
 
     def add_lock_file(self, lock_type: str, *, exists_ok: bool = False) -> None:
@@ -61,13 +68,16 @@ class ArtifactsMixin(BaseTask[Config]):
         elif not missing_ok:
             raise RuntimeError(f"Lock file not found at {lock_file}")
 
-    [two lines removed here; content not captured]
+    def get_exp_dir(self) -> Path:
+        if self._exp_dir is not None:
+            return self._exp_dir
+
         if self.config.exp_dir is not None:
             exp_dir = Path(self.config.exp_dir).expanduser().resolve()
             exp_dir.mkdir(parents=True, exist_ok=True)
-            [two lines removed here; content not captured]
+            self._exp_dir = exp_dir
+            logger.log(LOG_STATUS, self._exp_dir)
+            return self._exp_dir
 
         def get_exp_dir(run_id: int) -> Path:
             return self.run_dir / f"run_{run_id}"
@@ -81,9 +91,9 @@ class ArtifactsMixin(BaseTask[Config]):
         while (exp_dir := get_exp_dir(run_id)).is_dir() and has_lock_file(exp_dir):
             run_id += 1
         exp_dir.mkdir(exist_ok=True, parents=True)
-        [line removed here; content not captured]
-        logger.log(LOG_STATUS, …
-        return …
+        self._exp_dir = exp_dir.expanduser().resolve()
+        logger.log(LOG_STATUS, self._exp_dir)
+        return self._exp_dir
 
     @functools.lru_cache(maxsize=None)
     def stage_environment(self) -> Path | None:
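The refactor caches the resolved experiment directory on self._exp_dir and exposes it through an exp_dir property, so the directory is resolved and created once no matter how many mixins ask for it. A standalone illustration of that caching pattern (simplified; not the mixin itself):

    from pathlib import Path

    class ExpDirHolder:
        def __init__(self, configured_dir: str | None) -> None:
            self._exp_dir: Path | None = None
            self._configured_dir = configured_dir

        @property
        def exp_dir(self) -> Path:
            return self.get_exp_dir()

        def get_exp_dir(self) -> Path:
            if self._exp_dir is not None:  # already resolved on an earlier call
                return self._exp_dir
            if self._configured_dir is None:
                raise RuntimeError("no experiment directory configured")
            exp_dir = Path(self._configured_dir).expanduser().resolve()
            exp_dir.mkdir(parents=True, exist_ok=True)
            self._exp_dir = exp_dir
            return self._exp_dir

    holder = ExpDirHolder("runs/demo")
    assert holder.exp_dir == holder.exp_dir  # resolved once, then cached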
xax/task/mixins/checkpointing.py
ADDED
@@ -0,0 +1,221 @@
+"""Defines a mixin for handling model checkpointing."""
+
+import io
+import json
+import logging
+import tarfile
+from dataclasses import asdict, dataclass
+from pathlib import Path
+from typing import Any, Callable, Generic, Literal, TypeVar, cast, overload
+
+import cloudpickle
+import jax
+import optax
+from jaxtyping import PyTree
+from omegaconf import DictConfig, OmegaConf
+
+from xax.core.conf import field
+from xax.core.state import State
+from xax.nn.parallel import is_master
+from xax.task.mixins.artifacts import ArtifactsConfig, ArtifactsMixin
+
+logger = logging.getLogger(__name__)
+
+CheckpointPart = Literal["model", "opt", "opt_state", "state", "config"]
+
+
+def get_ckpt_path(exp_dir: Path, state: State | None = None) -> Path:
+    """Defines the path to the checkpoint for a given state.
+
+    Args:
+        exp_dir: The experiment directory
+        state: The current trainer state
+
+    Returns:
+        The path to the checkpoint file.
+    """
+    if state is None:
+        return exp_dir / "checkpoints" / "ckpt.bin"
+    return exp_dir / "checkpoints" / f"ckpt.{state.num_steps}.bin"
+
+
+@jax.tree_util.register_dataclass
+@dataclass
+class CheckpointingConfig(ArtifactsConfig):
+    save_every_n_steps: int | None = field(None, help="Save a checkpoint every N steps")
+    save_every_n_seconds: float | None = field(60.0 * 60.0, help="Save a checkpoint every N seconds")
+    only_save_most_recent: bool = field(True, help="Only keep the most recent checkpoint")
+    load_from_ckpt_path: str | None = field(None, help="If set, load initial model weights from this path")
+    load_ckpt_strict: bool = field(True, help="If set, only load weights for which have a matching key in the model")
+    save_tf_model: bool = field(False, help="If set, saves a Tensorflow version of the model")
+
+
+Config = TypeVar("Config", bound=CheckpointingConfig)
+
+
+class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
+    def __init__(self, config: Config) -> None:
+        super().__init__(config)
+
+        self.__last_ckpt_time = 0.0
+
+    def get_ckpt_path(self, state: State | None = None) -> Path:
+        return get_ckpt_path(self.exp_dir, state)
+
+    def get_init_ckpt_path(self) -> Path | None:
+        if self._exp_dir is not None:
+            ckpt_path = self.get_ckpt_path()
+            if ckpt_path.exists():
+                return ckpt_path
+        if self.config.load_from_ckpt_path is not None:
+            ckpt_path = Path(self.config.load_from_ckpt_path)
+            assert ckpt_path.exists(), f"Checkpoint path {ckpt_path} does not exist."
+            return ckpt_path
+        return None
+
+    def should_checkpoint(self, state: State) -> bool:
+        if self.config.save_every_n_steps is not None:
+            if state.num_steps % self.config.save_every_n_steps == 0:
+                return True
+        if self.config.save_every_n_seconds is not None:
+            last_time, cur_time = self.__last_ckpt_time, state.elapsed_time_s
+            if cur_time - last_time >= self.config.save_every_n_seconds:
+                self.__last_ckpt_time = cur_time
+                return True
+        return False
+
+    @overload
+    def load_checkpoint(
+        self,
+        path: Path,
+    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]: ...
+
+    @overload
+    def load_checkpoint(self, path: Path, part: Literal["model"]) -> PyTree: ...
+
+    @overload
+    def load_checkpoint(self, path: Path, part: Literal["opt"]) -> optax.GradientTransformation: ...
+
+    @overload
+    def load_checkpoint(self, path: Path, part: Literal["opt_state"]) -> optax.OptState: ...
+
+    @overload
+    def load_checkpoint(self, path: Path, part: Literal["state"]) -> State: ...
+
+    @overload
+    def load_checkpoint(self, path: Path, part: Literal["config"]) -> DictConfig: ...
+
+    def load_checkpoint(
+        self,
+        path: Path,
+        part: CheckpointPart | None = None,
+    ) -> (
+        tuple[PyTree, optax.GradientTransformation, optax.OptState, State, DictConfig]
+        | PyTree
+        | optax.GradientTransformation
+        | optax.OptState
+        | State
+        | DictConfig
+    ):
+        with tarfile.open(path, "r:gz") as tar:
+
+            def get_model() -> PyTree:
+                if (model := tar.extractfile("model")) is None:
+                    raise ValueError(f"Checkpoint does not contain a model file: {path}")
+                return cloudpickle.load(model)
+
+            def get_opt() -> optax.GradientTransformation:
+                if (opt := tar.extractfile("opt")) is None:
+                    raise ValueError(f"Checkpoint does not contain an opt file: {path}")
+                return cloudpickle.load(opt)
+
+            def get_opt_state() -> optax.OptState:
+                if (opt_state := tar.extractfile("opt_state")) is None:
+                    raise ValueError(f"Checkpoint does not contain an opt_state file: {path}")
+                return cloudpickle.load(opt_state)
+
+            def get_state() -> State:
+                if (state := tar.extractfile("state")) is None:
+                    raise ValueError(f"Checkpoint does not contain a state file: {path}")
+                return State(**json.loads(state.read().decode()))
+
+            def get_config() -> DictConfig:
+                if (config := tar.extractfile("config")) is None:
+                    raise ValueError(f"Checkpoint does not contain a config file: {path}")
+                return cast(DictConfig, OmegaConf.load(config))
+
+            match part:
+                case "model":
+                    return get_model()
+                case "opt":
+                    return get_opt()
+                case "opt_state":
+                    return get_opt_state()
+                case "state":
+                    return get_state()
+                case "config":
+                    return get_config()
+                case None:
+                    return get_model(), get_opt(), get_opt_state(), get_state(), get_config()
+                case _:
+                    raise ValueError(f"Invalid checkpoint part: {part}")
+
+    def save_checkpoint(
+        self,
+        model: PyTree,
+        optimizer: optax.GradientTransformation,
+        opt_state: optax.OptState,
+        state: State,
+    ) -> Path:
+        ckpt_path = self.get_ckpt_path(state)
+
+        if not is_master():
+            return ckpt_path
+
+        # Gets the path to the last checkpoint.
+        logger.info("Saving checkpoint to %s", ckpt_path)
+        last_ckpt_path = self.get_ckpt_path()
+        ckpt_path.parent.mkdir(exist_ok=True, parents=True)
+
+        # Potentially removes the last checkpoint.
+        if last_ckpt_path.exists() and self.config.only_save_most_recent:
+            if (base_ckpt := last_ckpt_path.resolve()).is_file():
+                base_ckpt.unlink()
+
+        # Combines all temporary files into a single checkpoint TAR file.
+        with tarfile.open(ckpt_path, "w:gz") as tar:
+
+            def add_file(name: str, write_fn: Callable[[io.BytesIO], Any]) -> None:
+                with io.BytesIO() as buf:
+                    write_fn(buf)
+                    tarinfo = tarfile.TarInfo(name)
+                    tarinfo.size = buf.tell()
+                    buf.seek(0)
+                    tar.addfile(tarinfo, buf)
+
+            add_file("model", lambda buf: cloudpickle.dump(model, buf))
+            add_file("opt", lambda buf: cloudpickle.dump(optimizer, buf))
+            add_file("opt_state", lambda buf: cloudpickle.dump(opt_state, buf))
+            add_file("state", lambda buf: buf.write(json.dumps(asdict(state), indent=2).encode()))
+            add_file("config", lambda buf: buf.write(OmegaConf.to_yaml(self.config).encode()))
+
+            if self.config.save_tf_model:
+                try:
+                    from jax.experimental import jax2tf
+                except ModuleNotFoundError:
+                    raise ImportError("Tensorflow is not installed. Install it with `pip install tensorflow`")
+
+                tf_model = jax2tf.convert(model)
+                add_file("model.tf", lambda buf: cloudpickle.dump(tf_model, buf))
+
+        # Updates the symlink to the new checkpoint.
+        last_ckpt_path.unlink(missing_ok=True)
+        try:
+            last_ckpt_path.symlink_to(ckpt_path.relative_to(last_ckpt_path.parent))
+        except FileExistsError:
+            logger.exception("Exception while trying to update %s", ckpt_path)
+
+        # Marks directory as having artifacts which shouldn't be overwritten.
+        self.add_lock_file("ckpt", exists_ok=True)
+
+        return ckpt_path
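Checkpoints written by save_checkpoint are ordinary gzipped TAR archives whose members are named model, opt, opt_state, state, and config, so they can be inspected without importing xax. A minimal sketch (the path is illustrative):

    import json
    import tarfile

    with tarfile.open("exp_dir/checkpoints/ckpt.bin", "r:gz") as tar:
        print(tar.getnames())  # ["model", "opt", "opt_state", "state", "config"]
        member = tar.extractfile("state")
        assert member is not None
        state = json.loads(member.read().decode())  # the trainer State, serialized as JSON
        print(state.get("num_steps"))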
xax/task/mixins/compile.py
ADDED
@@ -0,0 +1,104 @@
+"""Defines a mixin for handling JAX compilation behavior.
+
+This mixin allows control over JAX compilation settings like jit, pmap, and vmap
+behavior during initialization and training.
+"""
+
+import logging
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Generic, TypeVar
+
+import jax
+
+from xax.core.conf import field
+from xax.task.base import BaseConfig, BaseTask
+
+logger = logging.getLogger(__name__)
+
+
+@jax.tree_util.register_dataclass
+@dataclass
+class CompileOptions:
+    # JAX compilation options
+    disable_jit: bool = field(
+        value=False,
+        help="If True, disables JIT compilation",
+    )
+    enable_x64: bool = field(
+        value=False,
+        help="If True, enables 64-bit precision",
+    )
+    default_device: str | None = field(
+        value=None,
+        help="Default device to use (e.g. 'cpu', 'gpu')",
+    )
+
+    # JAX logging options
+    logging_level: str = field(
+        value="INFO",
+        help="JAX logging verbosity level",
+    )
+
+    # JAX cache options
+    cache_dir: str | None = field(
+        value=lambda: str((Path.home() / ".cache" / "jax" / "jaxcache").resolve()),
+        help="Directory for JAX compilation cache. If None, caching is disabled",
+    )
+    cache_min_size_bytes: int = field(
+        value=-1,
+        help="Minimum size in bytes for cache entries. -1 means no minimum",
+    )
+    cache_min_compile_time_secs: float = field(
+        value=0.0,
+        help="Minimum compilation time in seconds for cache entries. 0 means no minimum",
+    )
+    cache_enable_xla: str = field(
+        value="all",
+        help="Which XLA caches to enable",
+    )
+
+
+@jax.tree_util.register_dataclass
+@dataclass
+class CompileConfig(BaseConfig):
+    compile: CompileOptions = field(CompileOptions(), help="Compilation configuration")
+
+
+Config = TypeVar("Config", bound=CompileConfig)
+
+
+class CompileMixin(BaseTask[Config], Generic[Config]):
+    """Defines a task mixin for controlling JAX compilation behavior."""
+
+    def __init__(self, config: Config) -> None:
+        super().__init__(config)
+
+        cc = self.config.compile
+
+        # Set basic compilation flags
+        if cc.disable_jit:
+            logger.info("Disabling JIT compilation")
+            jax.config.update("jax_disable_jit", True)
+
+        if cc.enable_x64:
+            logger.info("Enabling 64-bit precision")
+            jax.config.update("jax_enable_x64", True)
+
+        if cc.default_device is not None:
+            logger.info("Setting default device to %s", cc.default_device)
+            jax.config.update("jax_default_device", cc.default_device)
+
+        # Set logging level
+        logger.info("Setting JAX logging level to %s", cc.logging_level)
+        jax.config.update("jax_logging_level", cc.logging_level)
+
+        # Configure compilation cache
+        if cc.cache_dir is not None:
+            logger.info("Setting JAX compilation cache directory to %s", cc.cache_dir)
+            jax.config.update("jax_compilation_cache_dir", cc.cache_dir)
+
+        logger.info("Configuring JAX compilation cache parameters")
+        jax.config.update("jax_persistent_cache_min_entry_size_bytes", cc.cache_min_size_bytes)
+        jax.config.update("jax_persistent_cache_min_compile_time_secs", cc.cache_min_compile_time_secs)
+        jax.config.update("jax_persistent_cache_enable_xla_caches", cc.cache_enable_xla)
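CompileMixin is a thin wrapper over jax.config.update: each CompileOptions field maps onto one JAX flag at task construction time. The equivalent manual calls for the cache options, using the defaults above (the cache directory shown is illustrative):

    import jax

    # What CompileMixin.__init__ does for the cache settings, spelled out:
    jax.config.update("jax_compilation_cache_dir", "/tmp/jaxcache")        # cc.cache_dir
    jax.config.update("jax_persistent_cache_min_entry_size_bytes", -1)    # cc.cache_min_size_bytes
    jax.config.update("jax_persistent_cache_min_compile_time_secs", 0.0)  # cc.cache_min_compile_time_secs
    jax.config.update("jax_persistent_cache_enable_xla_caches", "all")    # cc.cache_enable_xla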