xax 0.1.15__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- xax/__init__.py +1 -1
- xax/core/state.py +26 -1
- xax/requirements.txt +5 -5
- xax/task/base.py +1 -1
- xax/task/logger.py +149 -12
- xax/task/loggers/json.py +12 -4
- xax/task/loggers/stdout.py +21 -16
- xax/task/loggers/tensorboard.py +18 -2
- xax/task/mixins/checkpointing.py +118 -41
- xax/task/mixins/cpu_stats.py +10 -10
- xax/task/mixins/data_loader.py +2 -1
- xax/task/mixins/gpu_stats.py +3 -3
- xax/task/mixins/train.py +59 -29
- xax/utils/experiments.py +34 -30
- xax/utils/tensorboard.py +91 -3
- {xax-0.1.15.dist-info → xax-0.2.0.dist-info}/METADATA +6 -6
- {xax-0.1.15.dist-info → xax-0.2.0.dist-info}/RECORD +20 -20
- {xax-0.1.15.dist-info → xax-0.2.0.dist-info}/WHEEL +0 -0
- {xax-0.1.15.dist-info → xax-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {xax-0.1.15.dist-info → xax-0.2.0.dist-info}/top_level.txt +0 -0
xax/task/mixins/checkpointing.py
CHANGED

@@ -6,9 +6,9 @@ import logging
 import tarfile
 from dataclasses import asdict, dataclass
 from pathlib import Path
-from typing import
+from typing import Generic, Literal, TypeVar, cast, overload

-import
+import equinox as eqx
 import jax
 import optax
 from jaxtyping import PyTree
@@ -64,7 +64,9 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     def get_init_ckpt_path(self) -> Path | None:
         if self._exp_dir is not None:
             ckpt_path = self.get_ckpt_path()
-            if ckpt_path.exists():
+            if not ckpt_path.exists():
+                logger.warning("No checkpoint found in experiment directory: %s", ckpt_path)
+            else:
                 return ckpt_path
         if self.config.load_from_ckpt_path is not None:
             ckpt_path = Path(self.config.load_from_ckpt_path)
@@ -87,41 +89,54 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     def load_checkpoint(
         self,
         path: Path,
-
-
+        *,
+        part: Literal["all"],
+        model_template: PyTree,
+        optimizer_template: PyTree,
+        opt_state_template: PyTree,
+    ) -> tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]: ...

     @overload
     def load_checkpoint(
         self,
         path: Path,
-
-
+        *,
+        part: Literal["model_state_config"],
+        model_template: PyTree,
+    ) -> tuple[PyTree, State, Config]: ...

     @overload
     def load_checkpoint(
         self,
         path: Path,
+        *,
         part: Literal["model"],
+        model_template: PyTree,
     ) -> PyTree: ...

     @overload
     def load_checkpoint(
         self,
         path: Path,
+        *,
         part: Literal["opt"],
+        optimizer_template: PyTree,
     ) -> optax.GradientTransformation: ...

     @overload
     def load_checkpoint(
         self,
         path: Path,
+        *,
         part: Literal["opt_state"],
+        opt_state_template: PyTree,
     ) -> optax.OptState: ...

     @overload
     def load_checkpoint(
         self,
         path: Path,
+        *,
         part: Literal["state"],
     ) -> State: ...

@@ -129,48 +144,71 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):
     def load_checkpoint(
         self,
         path: Path,
+        *,
         part: Literal["config"],
-    ) ->
+    ) -> Config: ...

     def load_checkpoint(
         self,
         path: Path,
+        *,
         part: CheckpointPart = "all",
+        model_template: PyTree | None = None,
+        optimizer_template: PyTree | None = None,
+        opt_state_template: PyTree | None = None,
     ) -> (
-        tuple[PyTree, optax.GradientTransformation, optax.OptState, State,
-        | tuple[PyTree, State,
+        tuple[PyTree, optax.GradientTransformation, optax.OptState, State, Config]
+        | tuple[PyTree, State, Config]
         | PyTree
         | optax.GradientTransformation
         | optax.OptState
         | State
-        |
+        | Config
     ):
+        """Load a checkpoint.
+
+        Args:
+            path: Path to the checkpoint directory
+            part: Which part of the checkpoint to load
+            model_template: Template model with correct structure but uninitialized weights
+            optimizer_template: Template optimizer with correct structure but uninitialized weights
+            opt_state_template: Template optimizer state with correct structure but uninitialized weights
+
+        Returns:
+            The requested checkpoint components
+        """
         with tarfile.open(path, "r:gz") as tar:

             def get_model() -> PyTree:
+                if model_template is None:
+                    raise ValueError("model_template must be provided to load model weights")
                 if (model := tar.extractfile("model")) is None:
                     raise ValueError(f"Checkpoint does not contain a model file: {path}")
-                return
+                return eqx.tree_deserialise_leaves(io.BytesIO(model.read()), model_template)

             def get_opt() -> optax.GradientTransformation:
-                if
-                    raise ValueError(
-
+                if optimizer_template is None:
+                    raise ValueError("optimizer_template must be provided to load optimizer")
+                if (opt := tar.extractfile("optimizer")) is None:
+                    raise ValueError(f"Checkpoint does not contain an optimizer file: {path}")
+                return eqx.tree_deserialise_leaves(io.BytesIO(opt.read()), optimizer_template)

             def get_opt_state() -> optax.OptState:
+                if opt_state_template is None:
+                    raise ValueError("opt_state_template must be provided to load optimizer state")
                 if (opt_state := tar.extractfile("opt_state")) is None:
-                    raise ValueError(f"Checkpoint does not contain an
-                return
+                    raise ValueError(f"Checkpoint does not contain an optimizer state file: {path}")
+                return eqx.tree_deserialise_leaves(io.BytesIO(opt_state.read()), opt_state_template)

             def get_state() -> State:
                 if (state := tar.extractfile("state")) is None:
                     raise ValueError(f"Checkpoint does not contain a state file: {path}")
                 return State(**json.loads(state.read().decode()))

-            def get_config() ->
+            def get_config() -> Config:
                 if (config := tar.extractfile("config")) is None:
                     raise ValueError(f"Checkpoint does not contain a config file: {path}")
-                return cast(DictConfig, OmegaConf.load(config))
+                return self.get_config(cast(DictConfig, OmegaConf.load(config)), use_cli=False)

             match part:
                 case "model":
@@ -192,51 +230,90 @@ class CheckpointingMixin(ArtifactsMixin[Config], Generic[Config]):

     def save_checkpoint(
         self,
-        model: PyTree,
-        optimizer: optax.GradientTransformation,
-        opt_state: optax.OptState,
-
+        model: PyTree | None = None,
+        optimizer: optax.GradientTransformation | None = None,
+        opt_state: optax.OptState | None = None,
+        aux_data: PyTree | None = None,
+        state: State | None = None,
     ) -> Path:
+        """Save a checkpoint.
+
+        Args:
+            model: The model to save
+            state: The current training state
+            optimizer: The optimizer to save
+            aux_data: Additional data to save
+            opt_state: The optimizer state to save
+
+        Returns:
+            Path to the saved checkpoint
+        """
         ckpt_path = self.get_ckpt_path(state)

         if not is_master():
             return ckpt_path

-        # Gets the path to the last checkpoint
+        # Gets the path to the last checkpoint
         logger.info("Saving checkpoint to %s", ckpt_path)
         last_ckpt_path = self.get_ckpt_path()
         ckpt_path.parent.mkdir(exist_ok=True, parents=True)

-        # Potentially removes the last checkpoint
+        # Potentially removes the last checkpoint
         if last_ckpt_path.exists() and self.config.only_save_most_recent:
             if (base_ckpt := last_ckpt_path.resolve()).is_file():
                 base_ckpt.unlink()

-        #
+        # Save the checkpoint components
         with tarfile.open(ckpt_path, "w:gz") as tar:

-            def add_file(name: str,
+            def add_file(name: str, buf: io.BytesIO) -> None:
+                tarinfo = tarfile.TarInfo(name)
+                tarinfo.size = buf.tell()
+                buf.seek(0)
+                tar.addfile(tarinfo, buf)
+
+            # Save model using Equinox
+            if model is not None:
+                with io.BytesIO() as buf:
+                    eqx.tree_serialise_leaves(buf, model)
+                    add_file("model", buf)
+
+            # Save optimizer using Equinox
+            if optimizer is not None:
+                with io.BytesIO() as buf:
+                    eqx.tree_serialise_leaves(buf, optimizer)
+                    add_file("optimizer", buf)
+
+            # Save optimizer state using Equinox
+            if opt_state is not None:
                 with io.BytesIO() as buf:
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    eqx.tree_serialise_leaves(buf, opt_state)
+                    add_file("opt_state", buf)
+
+            # Save aux data using Equinox.
+            if aux_data is not None:
+                with io.BytesIO() as buf:
+                    eqx.tree_serialise_leaves(buf, aux_data)
+                    add_file("aux_data", buf)
+
+            # Save state and config as JSON
+            def add_file_bytes(name: str, data: bytes) -> None:  # noqa: ANN401
+                info = tarfile.TarInfo(name=name)
+                info.size = len(data)
+                tar.addfile(info, io.BytesIO(data))
+
+            if state is not None:
+                add_file_bytes("state", json.dumps(asdict(state), indent=2).encode())
+            add_file_bytes("config", OmegaConf.to_yaml(self.config).encode())
+
+        # Updates the symlink to the new checkpoint
         last_ckpt_path.unlink(missing_ok=True)
         try:
             last_ckpt_path.symlink_to(ckpt_path.relative_to(last_ckpt_path.parent))
         except FileExistsError:
             logger.exception("Exception while trying to update %s", ckpt_path)

-        # Calls the base callback
+        # Calls the base callback
         self.on_after_checkpoint_save(ckpt_path, state)

         return ckpt_path
xax/task/mixins/cpu_stats.py
CHANGED

@@ -248,15 +248,15 @@ class CPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
         stats = monitor.get_if_set() if self.config.cpu_stats.only_log_once else monitor.get()

         if stats is not None:
-            self.logger.log_scalar("child_procs", stats.num_child_procs, namespace="🔧 cpu")
-            self.logger.log_scalar("percent", stats.cpu_percent, namespace="🔧 cpu")
-            self.logger.log_scalar("child_percent", stats.child_cpu_percent, namespace="🔧 cpu")
-            self.logger.log_scalar("percent", stats.mem_percent, namespace="🔧 mem")
-            self.logger.log_scalar("shared", stats.mem_shared, namespace="🔧 mem")
-            self.logger.log_scalar("child_percent", stats.child_mem_percent, namespace="🔧 mem")
-            self.logger.log_scalar("rss/cur", stats.mem_rss, namespace="🔧 mem")
-            self.logger.log_scalar("rss/total", stats.mem_rss_total, namespace="🔧 mem")
-            self.logger.log_scalar("vms/cur", stats.mem_vms, namespace="🔧 mem")
-            self.logger.log_scalar("vms/total", stats.mem_vms_total, namespace="🔧 mem")
+            self.logger.log_scalar("child_procs", stats.num_child_procs, namespace="🔧 cpu", secondary=True)
+            self.logger.log_scalar("percent", stats.cpu_percent, namespace="🔧 cpu", secondary=True)
+            self.logger.log_scalar("child_percent", stats.child_cpu_percent, namespace="🔧 cpu", secondary=True)
+            self.logger.log_scalar("percent", stats.mem_percent, namespace="🔧 mem", secondary=True)
+            self.logger.log_scalar("shared", stats.mem_shared, namespace="🔧 mem", secondary=True)
+            self.logger.log_scalar("child_percent", stats.child_mem_percent, namespace="🔧 mem", secondary=True)
+            self.logger.log_scalar("rss/cur", stats.mem_rss, namespace="🔧 mem", secondary=True)
+            self.logger.log_scalar("rss/total", stats.mem_rss_total, namespace="🔧 mem", secondary=True)
+            self.logger.log_scalar("vms/cur", stats.mem_vms, namespace="🔧 mem", secondary=True)
+            self.logger.log_scalar("vms/total", stats.mem_vms_total, namespace="🔧 mem", secondary=True)

         return state
xax/task/mixins/data_loader.py
CHANGED

@@ -9,6 +9,7 @@ import jax
 from dpshdl.dataloader import CollatedDataloaderItem, Dataloader
 from dpshdl.dataset import Dataset, ErrorHandlingDataset
 from dpshdl.prefetcher import Prefetcher
+from jaxtyping import PRNGKeyArray
 from omegaconf import II, MISSING

 from xax.core.conf import field, is_missing
@@ -103,7 +104,7 @@ class DataloadersMixin(ProcessMixin[Config], BaseTask[Config], Generic[Config],
             "or `get_data_iterator` to return an iterator for the given dataset."
         )

-    def get_data_iterator(self, phase: Phase) -> Iterator:
+    def get_data_iterator(self, phase: Phase, key: PRNGKeyArray) -> Iterator:
         raise NotImplementedError(
             "You must implement either the `get_dataset` method to return the dataset for the given phase, "
             "or `get_data_iterator` to return an iterator for the given dataset."
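
Because `get_data_iterator` now receives a PRNG key, custom iterators can derive shuffling or augmentation randomness from the task's key instead of global state. A minimal sketch of an override under the new signature; `MyTask`, its base class, and the `sample_batch` helper are hypothetical:

```python
from typing import Iterator

import jax
from jaxtyping import PRNGKeyArray

from xax.core.state import Phase


class MyTask(MyDataloaderTaskBase):  # hypothetical task built on DataloadersMixin
    def get_data_iterator(self, phase: Phase, key: PRNGKeyArray) -> Iterator:
        while True:
            key, subkey = jax.random.split(key)
            # sample_batch is a hypothetical helper that draws one batch for the phase.
            yield self.sample_batch(phase, subkey)
```

The train mixin below splits its own key into separate train and validation keys before opening these iterators, so the two streams stay independently reproducible.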
xax/task/mixins/gpu_stats.py
CHANGED

@@ -264,8 +264,8 @@ class GPUStatsMixin(ProcessMixin[Config], LoggerMixin[Config], Generic[Config]):
         for gpu_stat in stats.values():
             if gpu_stat is None:
                 continue
-            self.logger.log_scalar(f"mem/{gpu_stat.index}", gpu_stat.memory_used, namespace="🔧 gpu")
-            self.logger.log_scalar(f"temp/{gpu_stat.index}", gpu_stat.temperature, namespace="🔧 gpu")
-            self.logger.log_scalar(f"util/{gpu_stat.index}", gpu_stat.utilization, namespace="🔧 gpu")
+            self.logger.log_scalar(f"mem/{gpu_stat.index}", gpu_stat.memory_used, namespace="🔧 gpu", secondary=True)
+            self.logger.log_scalar(f"temp/{gpu_stat.index}", gpu_stat.temperature, namespace="🔧 gpu", secondary=True)
+            self.logger.log_scalar(f"util/{gpu_stat.index}", gpu_stat.utilization, namespace="🔧 gpu", secondary=True)

         return state
xax/task/mixins/train.py
CHANGED

@@ -11,7 +11,7 @@ import textwrap
 import time
 import traceback
 from abc import ABC, abstractmethod
-from dataclasses import dataclass, is_dataclass
+from dataclasses import asdict, dataclass, is_dataclass
 from threading import Thread
 from typing import (
     Any,
@@ -33,7 +33,6 @@ import jax.numpy as jnp
 import numpy as np
 import optax
 from jaxtyping import Array, PRNGKeyArray, PyTree
-from omegaconf import DictConfig

 from xax.core.conf import field
 from xax.core.state import Phase, State
@@ -50,6 +49,7 @@ from xax.utils.experiments import (
     TrainingFinishedError,
     diff_configs,
     get_diff_string,
+    get_info_json,
     get_state_file_string,
     get_training_code,
 )
@@ -218,7 +218,12 @@ class TrainMixin(
         return state.replace(elapsed_time_s=time.time() - state.start_time_s)

     def log_train_step(
-        self,
+        self,
+        model: PyTree,
+        batch: Batch,
+        output: Output,
+        metrics: FrozenDict[str, Array],
+        state: State,
     ) -> None:
         """Override this function to do logging during the training phase.

@@ -234,7 +239,12 @@ class TrainMixin(
         """

     def log_valid_step(
-        self,
+        self,
+        model: PyTree,
+        batch: Batch,
+        output: Output,
+        metrics: FrozenDict[str, Array],
+        state: State,
     ) -> None:
         """Override this function to do logging during the validation phase.

@@ -252,12 +262,20 @@ class TrainMixin(
     def log_state_timers(self, state: State) -> None:
         timer = self.state_timers[state.phase]
         timer.step(state)
-        for
-
-
+        for k, v in timer.log_dict().items():
+            if isinstance(v, tuple):
+                v, secondary = v
+            else:
+                secondary = False
+            self.logger.log_scalar(k, v, namespace="⌛ timers", secondary=secondary)

     def log_step(
-        self,
+        self,
+        model: PyTree,
+        batch: Batch,
+        output: Output,
+        metrics: FrozenDict[str, Array],
+        state: State,
     ) -> None:
         phase = state.phase

@@ -322,20 +340,30 @@ class TrainMixin(

         if init_ckpt_path is not None:
             logger.info("Loading checkpoint from %s", init_ckpt_path)
-
-
-
-
-
-
+            model_spec = eqx.filter_eval_shape(self.get_model, key)
+            model, state, config = self.load_checkpoint(
+                init_ckpt_path,
+                part="model_state_config",
+                model_template=model_spec,
+            )
+            config_diff = get_diff_string(diff_configs(asdict(config), asdict(self.config)))
+            if config_diff:
+                logger.warning("Loaded config differs from current config:\n%s", config_diff)

-
-            model, state, config = self.load_checkpoint(init_ckpt_path, "model_state_config")
-            config_diff = get_diff_string(diff_configs(config, cast(DictConfig, self.config)))
-            if config_diff:
-                logger.warning("Loaded config differs from current config:\n%s", config_diff)
+            if not load_optimizer:
                 return model, state

+            # Loads the optimizer.
+            optimizer_spec = eqx.filter_eval_shape(self.get_optimizer)
+            optimizer = self.load_checkpoint(init_ckpt_path, part="opt", optimizer_template=optimizer_spec)
+
+            # Loads the optimizer state.
+            opt_state_spec = eqx.filter_eval_shape(self.get_initial_opt_state, model, optimizer)
+            opt_state = self.load_checkpoint(init_ckpt_path, part="opt_state", opt_state_template=opt_state_spec)
+
+            return model, optimizer, opt_state, state
+
+        logger.info("No checkpoint found. Initializing a new model.")
         model = self.get_model(key)
         state = State.init_state()

@@ -536,6 +564,7 @@ class TrainMixin(
         self.logger.log_file("state.txt", get_state_file_string(self))
         self.logger.log_file("training_code.py", get_training_code(self))
         self.logger.log_file("config.yaml", self.config_str(self.config, use_cli=False))
+        self.logger.log_file("info.json", get_info_json())

     def model_partition_fn(self, item: Any) -> bool:  # noqa: ANN401
         return eqx.is_inexact_array(item)
@@ -609,16 +638,16 @@ class TrainMixin(

         if self.should_checkpoint(state):
             model = eqx.combine(model_arr, model_static)
-            self.save_checkpoint(model, optimizer, opt_state, state)
+            self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)

         # After finishing training, save the final checkpoint.
         model = eqx.combine(model_arr, model_static)
-        self.save_checkpoint(model, optimizer, opt_state, state)
+        self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)

     @contextlib.contextmanager
-    def get_train_iterator(self) -> Generator[Iterator[Batch], None, None]:
+    def get_train_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:
         try:
-            train_iterator: Iterator[Batch] = self.get_data_iterator("train")
+            train_iterator: Iterator[Batch] = self.get_data_iterator("train", key=key)
             yield train_iterator
             return
         except NotImplementedError:
@@ -635,9 +664,9 @@ class TrainMixin(
             logger.info("Closing train prefetcher")

     @contextlib.contextmanager
-    def get_valid_iterator(self) -> Generator[Iterator[Batch], None, None]:
+    def get_valid_iterator(self, key: PRNGKeyArray) -> Generator[Iterator[Batch], None, None]:
         try:
-            valid_iterator: Iterator[Batch] = self.get_data_iterator("valid")
+            valid_iterator: Iterator[Batch] = self.get_data_iterator("valid", key=key)
             yield valid_iterator
             return
         except NotImplementedError:
@@ -681,12 +710,13 @@ class TrainMixin(
         state = self.on_training_start(state)

         def on_exit() -> None:
-            self.save_checkpoint(model, optimizer, opt_state, state)
+            self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)

         # Handle user-defined interrupts during the training loop.
         self.add_signal_handler(on_exit, signal.SIGUSR1, signal.SIGTERM)

-
+        key, tkey, vkey = jax.random.split(key, 3)
+        with self.get_train_iterator(tkey) as train_pf, self.get_valid_iterator(vkey) as valid_pf:
             try:
                 self.train_loop(
                     model=model,
@@ -703,7 +733,7 @@ class TrainMixin(
                         f"Finished training after {state.num_steps} steps, {state.num_samples} samples",
                         important=True,
                     )
-                self.save_checkpoint(model, optimizer, opt_state, state)
+                self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)

             except (KeyboardInterrupt, bdb.BdbQuit):
                 if is_master():
@@ -713,7 +743,7 @@ class TrainMixin(
                 exception_tb = textwrap.indent(highlight_exception_message(traceback.format_exc()), " ")
                 sys.stdout.write(f"Caught exception during training loop:\n\n{exception_tb}\n")
                 sys.stdout.flush()
-                self.save_checkpoint(model, optimizer, opt_state, state)
+                self.save_checkpoint(model=model, optimizer=optimizer, opt_state=opt_state, state=state)

             finally:
                 state = self.on_training_end(state)
xax/utils/experiments.py
CHANGED

@@ -7,6 +7,7 @@ import functools
 import hashlib
 import inspect
 import itertools
+import json
 import logging
 import math
 import os
@@ -24,7 +25,7 @@ import warnings
 from abc import ABC, abstractmethod
 from pathlib import Path
 from types import TracebackType
-from typing import Any, Iterator, Self, TypeVar, cast
+from typing import Any, Iterator, Mapping, Self, Sequence, TypeVar, cast
 from urllib.parse import urlparse

 import git
@@ -114,28 +115,13 @@ class StateTimer:
         self.sample_timer.step(state.num_samples if state.phase == "train" else state.num_valid_samples, cur_time)
         self.iter_timer.step(cur_time)

-    def log_dict(self) -> dict[str,
-
-
-
-
-            "total": self.step_timer.steps,
-            "per-second": self.step_timer.steps_per_second,
-        }
-
-        # Logs sample statistics.
-        logs["⌛ samples"] = {
-            "total": self.sample_timer.steps,
-            "per-second": self.sample_timer.steps_per_second,
+    def log_dict(self) -> dict[str, int | float | tuple[int | float, bool]]:
+        return {
+            "steps/second": self.step_timer.steps_per_second,
+            "samples/second": (self.sample_timer.steps_per_second, True),
+            "dt": self.iter_timer.iter_seconds,
         }

-        # Logs full iteration statistics.
-        logs["⌛ dt"] = {
-            "iter": self.iter_timer.iter_seconds,
-        }
-
-        return logs
-

 class IntervalTicker:
     def __init__(self, interval: float) -> None:
@@ -217,8 +203,8 @@ class MinGradScaleError(TrainingFinishedError):


 def diff_configs(
-    first:
-    second:
+    first: Mapping | Sequence,
+    second: Mapping | Sequence,
     prefix: str | None = None,
 ) -> tuple[list[str], list[str]]:
     """Returns the difference between two configs.
@@ -245,7 +231,7 @@ def diff_configs(

     any_config = (ListConfig, DictConfig)

-    if isinstance(first,
+    if isinstance(first, Mapping) and isinstance(second, Mapping):
         first_keys, second_keys = cast(set[str], set(first.keys())), cast(set[str], set(second.keys()))

         # Gets the new keys in each config.
@@ -255,11 +241,12 @@ def diff_configs(
         # Gets the new sub-keys in each config.
         for key in first_keys.intersection(second_keys):
             sub_prefix = key if prefix is None else f"{prefix}.{key}"
-            if
-                if
-
-
-
+            if isinstance(first, DictConfig) and isinstance(second, DictConfig):
+                if OmegaConf.is_missing(first, key) or OmegaConf.is_missing(second, key):
+                    if not OmegaConf.is_missing(first, key):
+                        new_first += [get_diff_string(sub_prefix, first[key])]
+                    if not OmegaConf.is_missing(second, key):
+                        new_second += [get_diff_string(sub_prefix, second[key])]
             elif isinstance(first[key], any_config) and isinstance(second[key], any_config):
                 sub_new_first, sub_new_second = diff_configs(first[key], second[key], prefix=sub_prefix)
                 new_first, new_second = new_first + sub_new_first, new_second + sub_new_second
@@ -268,7 +255,7 @@ def diff_configs(
                 new_first += [get_diff_string(sub_prefix, first_val)]
                 new_second += [get_diff_string(sub_prefix, second_val)]

-    elif isinstance(first,
+    elif isinstance(first, Sequence) and isinstance(second, Sequence):
         if len(first) > len(second):
             for i in range(len(second), len(first)):
                 new_first += [get_diff_string(prefix, first[i])]
@@ -483,16 +470,33 @@ def get_command_line_string() -> str:
     return " ".join(sys.argv)


+def get_environment_variables() -> str:
+    return "\n".join([f"{key}={value}" for key, value in sorted(os.environ.items())])
+
+
 def get_state_file_string(obj: object) -> str:
     return "\n\n".join(
         [
             f"=== Command Line ===\n\n{get_command_line_string()}",
             f"=== Git State ===\n\n{get_git_state(obj)}",
             f"=== Packages ===\n\n{get_packages_with_versions()}",
+            f"=== Environment Variables ===\n\n{get_environment_variables()}",
         ]
     )


+def get_info_json() -> str:
+    return json.dumps(
+        {
+            "process_id": os.getpid(),
+            "job": {
+                "start_time": datetime.datetime.now().isoformat(),
+            },
+        },
+        indent=2,
+    )
+
+
 def get_training_code(obj: object) -> str:
     """Gets the text from the file containing the provided object.

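
Since `diff_configs` now accepts plain `Mapping`/`Sequence` values, callers can diff ordinary dicts, which is what lets `train.py` above compare dataclass configs via `asdict` when resuming. A small sketch with a hypothetical config dataclass:

```python
from dataclasses import asdict, dataclass

from xax.utils.experiments import diff_configs, get_diff_string


@dataclass
class Config:  # hypothetical config; any dataclass turned into a dict works
    learning_rate: float = 1e-3
    batch_size: int = 32


saved, current = Config(), Config(learning_rate=3e-4)
config_diff = get_diff_string(diff_configs(asdict(saved), asdict(current)))
if config_diff:
    print(config_diff)  # shows the differing learning_rate entries
```

Relatedly, `StateTimer.log_dict` is now a flat mapping whose values are either a number or a `(number, secondary)` tuple; the train mixin unpacks the tuple and forwards `secondary=True` to `log_scalar`, the same flag the CPU and GPU stats mixins now set on their system metrics.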