trainloop-0.1.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- trainloop-0.1.0/PKG-INFO +24 -0
- trainloop-0.1.0/README.md +13 -0
- trainloop-0.1.0/pyproject.toml +26 -0
- trainloop-0.1.0/src/trainloop/__init__.py +25 -0
- trainloop-0.1.0/src/trainloop/hooks.py +716 -0
- trainloop-0.1.0/src/trainloop/py.typed +0 -0
- trainloop-0.1.0/src/trainloop/trainer.py +406 -0
- trainloop-0.1.0/src/trainloop/utils.py +67 -0
trainloop-0.1.0/PKG-INFO
ADDED
@@ -0,0 +1,24 @@
+Metadata-Version: 2.3
+Name: trainloop
+Version: 0.1.0
+Summary: Minimal PyTorch training loop with hooks and checkpointing.
+Author: Karim Abou Zeid
+Author-email: Karim Abou Zeid <contact@ka.codes>
+Requires-Dist: pillow>=11.3.0
+Requires-Dist: torch>=2.0.0
+Requires-Python: >=3.10
+Description-Content-Type: text/markdown
+
+# trainloop
+
+[](https://pypi.org/project/trainloop/)
+
+Minimal PyTorch training loop with hooks for logging, checkpointing, and customization.
+
+Docs: https://kabouzeid.github.io/trainloop/
+
+## Install
+
+```bash
+pip install trainloop
+```
trainloop-0.1.0/README.md
ADDED
@@ -0,0 +1,13 @@
+# trainloop
+
+[](https://pypi.org/project/trainloop/)
+
+Minimal PyTorch training loop with hooks for logging, checkpointing, and customization.
+
+Docs: https://kabouzeid.github.io/trainloop/
+
+## Install
+
+```bash
+pip install trainloop
+```
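Before the project configuration and sources below, a minimal usage sketch may help orientation. It assumes the `BaseTrainer`/hook API shown in `src/trainloop/trainer.py` and `src/trainloop/hooks.py`; the toy model, data stream, and hyperparameters are illustrative assumptions, not part of the package.

```python
# Minimal sketch, assuming the BaseTrainer API from trainer.py below.
# The toy data, model, and hyperparameters are made up for illustration.
import logging

import torch
import torch.nn as nn

from trainloop import BaseTrainer, ProgressHook


class ToyTrainer(BaseTrainer):
    def build_data_loader(self):
        def batches():
            # infinite stream: the loop pulls one batch per (micro)step until max_steps
            while True:
                x = torch.randn(32, 10)
                yield x, x.sum(dim=1, keepdim=True)

        return batches()

    def build_model(self):
        return nn.Linear(10, 1)

    def build_optimizer(self):
        return torch.optim.SGD(self.model.parameters(), lr=1e-2)

    def build_hooks(self):
        # sync=False keeps ProgressHook single-process friendly
        return [ProgressHook(interval=10, sync=False)]

    def forward(self, input):
        x, y = input
        loss = nn.functional.mse_loss(self.model(x), y)
        return loss, {"mse": loss.item()}


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)  # make the "trainer" logger output visible
    ToyTrainer(max_steps=100, device="cpu").train()
```

Note that `CheckpointingHook`, `LoggingHook`, and `WandbHook` call `torch.distributed` collectives and rank checks, so they expect an initialized process group; `ProgressHook` with `sync=False` does not.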
trainloop-0.1.0/pyproject.toml
ADDED
@@ -0,0 +1,26 @@
+[project]
+name = "trainloop"
+version = "0.1.0"
+description = "Minimal PyTorch training loop with hooks and checkpointing."
+readme = "README.md"
+authors = [
+    { name = "Karim Abou Zeid", email = "contact@ka.codes" }
+]
+requires-python = ">=3.10"
+dependencies = [
+    "pillow>=11.3.0",
+    "torch>=2.0.0",
+]
+
+[build-system]
+requires = ["uv_build>=0.9.0,<0.10.0"]
+build-backend = "uv_build"
+
+[dependency-groups]
+dev = [
+    "wandb>=0.20.1",
+    "pytest>=8.4.0",
+    "ruff>=0.11.13",
+    "zensical>=0.0.10",
+    "mkdocstrings-python>=2.0.1",
+]
trainloop-0.1.0/src/trainloop/__init__.py
ADDED
@@ -0,0 +1,25 @@
+from .hooks import (
+    BaseHook,
+    CheckpointingHook,
+    CudaMaxMemoryHook,
+    EmaHook,
+    ImageFileLoggerHook,
+    LoggingHook,
+    ProgressHook,
+    WandbHook,
+)
+from .trainer import BaseTrainer, LossNoneWarning, map_nested_tensor
+
+__all__ = [
+    "BaseTrainer",
+    "BaseHook",
+    "CheckpointingHook",
+    "CudaMaxMemoryHook",
+    "LoggingHook",
+    "ProgressHook",
+    "EmaHook",
+    "WandbHook",
+    "ImageFileLoggerHook",
+    "LossNoneWarning",
+    "map_nested_tensor",
+]
trainloop-0.1.0/src/trainloop/hooks.py
ADDED
@@ -0,0 +1,716 @@
+import os
+import shutil
+import signal
+import sys
+import tempfile
+import time
+import warnings
+from datetime import timedelta
+from numbers import Number
+from pathlib import Path
+from typing import Any, Callable, Iterable, Literal, Sequence
+
+import torch
+import torch.distributed as dist
+from PIL import Image
+from PIL.Image import Image as PILImage
+from torch.distributed.checkpoint.state_dict import (
+    get_model_state_dict,
+    set_model_state_dict,
+)
+from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn
+
+try:
+    import wandb
+except ImportError:
+    # only needed for WandbHook
+    pass
+
+from .trainer import BaseTrainer, Records
+from .utils import flatten_nested_dict, key_average
+
+
+class BaseHook:
+    """Lifecycle hooks for `BaseTrainer`."""
+
+    def on_before_train(self, trainer: BaseTrainer):
+        pass
+
+    def on_before_step(self, trainer: BaseTrainer):
+        pass
+
+    def on_before_optimizer_step(self, trainer: BaseTrainer):
+        pass
+
+    def on_after_step(self, trainer: BaseTrainer):
+        pass
+
+    def on_after_train(self, trainer: BaseTrainer):
+        pass
+
+    def on_log(self, trainer: BaseTrainer, records: dict, dry_run: bool = False):
+        pass
+
+    def on_log_images(self, trainer: BaseTrainer, records: dict, dry_run: bool = False):
+        pass
+
+    def on_state_dict(self, trainer: BaseTrainer, state_dict: dict):
+        pass
+
+    def on_load_state_dict(self, trainer: BaseTrainer, state_dict: dict):
+        pass
+
+
+class _StatsHook(BaseHook):
+    """Collect step statistics and hand them to subclasses for reporting.
+
+    Args:
+        interval: Emit stats every N steps.
+        sync: If True, aggregate stats across distributed ranks.
+    """
+
+    def __init__(
+        self,
+        interval: int,
+        sync: bool,
+    ):
+        self.interval = interval
+        self.sync = sync
+        self.reset()
+
+    def reset(self):
+        self.losses = []
+        self.records_ls = []
+        self.grad_norms = []
+        self.data_times = []
+        self.step_times = []
+        self.max_memories = []
+
+    def on_after_step(self, trainer: BaseTrainer):
+        # collect and aggregate over accumulation steps
+        self.losses.append(torch.stack(trainer.step_info["loss"]).mean())
+        if trainer.grad_clip is not None:
+            self.grad_norms.append(trainer.step_info["grad_norm"])
+        self.records_ls.append(key_average(trainer.step_info["records"]))
+        self.data_times.append(sum(trainer.step_info["data_time"]))  # total
+        self.step_times.append(trainer.step_info["step_time"])
+        if "max_memory" in trainer.step_info:
+            self.max_memories.append(trainer.step_info["max_memory"])
+
+        if trainer.step % self.interval == 0 or trainer.step == trainer.max_steps:
+            # aggregate over steps
+            loss = torch.stack(self.losses).mean()
+            grad_norm = torch.stack(self.grad_norms).mean() if self.grad_norms else None
+            records = key_average(self.records_ls)
+            data_time = sum(self.data_times) / len(self.data_times)
+            step_time = sum(self.step_times) / len(self.step_times)
+            max_memory = max(self.max_memories) if self.max_memories else None
+
+            if self.sync:
+                # aggregate across all ranks
+                dist.all_reduce(loss, op=dist.ReduceOp.AVG)
+                if grad_norm is not None:
+                    dist.all_reduce(grad_norm, op=dist.ReduceOp.AVG)
+
+                gathered = [None] * dist.get_world_size()
+                dist.all_gather_object(
+                    gathered,
+                    {
+                        "records": records,
+                        "data_time": data_time,
+                        "step_time": step_time,
+                        "max_memory": max_memory,
+                    },
+                )
+                records = key_average([stat["records"] for stat in gathered])
+                data_time = sum(stat["data_time"] for stat in gathered) / len(gathered)
+                step_time = sum(stat["step_time"] for stat in gathered) / len(gathered)
+                if "max_memory" in trainer.step_info:
+                    max_memory = max(stat["max_memory"] for stat in gathered)
+
+            self.process_stats(
+                trainer,
+                loss.item(),
+                grad_norm.item() if grad_norm is not None else None,
+                step_time,
+                data_time,
+                max_memory,
+                records,
+            )
+            self.reset()
+
+    def process_stats(
+        self,
+        trainer: BaseTrainer,
+        loss: float,
+        grad_norm: float | None,
+        step_time: float,
+        data_time: float,
+        max_memory: float | None,
+        records: Records,
+    ):
+        raise NotImplementedError("Subclasses must implement this method.")
+
+
+class ETATracker:
+    def __init__(self, warmup_steps: int):
+        """Track ETA across training steps after a warmup period.
+
+        Args:
+            warmup_steps: Number of steps to skip before timing begins.
+        """
+        assert warmup_steps > 0, "Warmup steps must be greater than 0"
+        self.warmup_steps = warmup_steps
+        self.steps = 0
+        self.timing_start = None
+        self.timed_steps = 0
+
+    def step(self):
+        self.steps += 1
+        if self.steps == self.warmup_steps:
+            self.timing_start = time.perf_counter()
+        if self.steps > self.warmup_steps:
+            self.timed_steps += 1
+
+    def get_eta(self, steps_remaining: int):
+        if self.timed_steps == 0:
+            return None
+
+        elapsed = time.perf_counter() - self.timing_start
+        avg_step_time = elapsed / self.timed_steps
+        eta_seconds = avg_step_time * steps_remaining
+        return timedelta(seconds=int(eta_seconds))
+
+
+class ProgressHook(_StatsHook):
+    """Log progress to stdout with optional metrics, ETA, and memory.
+
+    Args:
+        interval: Log every N steps.
+        with_records: Include per-step records in the log line.
+        sync: If True, aggregate across distributed ranks.
+        eta_warmup: Steps to warm up ETA calculation.
+        show_units: Whether to print units (s, GiB) alongside values.
+    """
+
+    def __init__(
+        self,
+        interval: int = 1,
+        with_records: bool = False,
+        sync: bool = False,
+        eta_warmup: int = 10,
+        show_units: bool = True,
+    ):
+        super().__init__(interval=interval, sync=sync)
+        self.with_records = with_records
+        self.eta_warmup = eta_warmup
+        self.show_units = show_units
+
+    def on_before_train(self, trainer: BaseTrainer):
+        super().on_before_train(trainer)
+        trainer.logger.info("=> Starting training ...")
+        self.eta_tracker = ETATracker(warmup_steps=self.eta_warmup)
+
+    def on_after_train(self, trainer: BaseTrainer):
+        super().on_after_train(trainer)
+        trainer.logger.info("=> Finished training")
+
+    def on_before_step(self, trainer: BaseTrainer):
+        super().on_before_step(trainer)
+        self.lrs = [
+            (i, param_group["lr"])
+            for i, param_group in enumerate(trainer.optimizer.param_groups)
+        ]  # record the LR before the scheduler steps
+
+    def on_after_step(self, trainer: BaseTrainer):
+        self.eta_tracker.step()  # should be called before process_stats
+        super().on_after_step(trainer)
+
+    def process_stats(
+        self,
+        trainer: BaseTrainer,
+        loss: float,
+        grad_norm: float | None,
+        step_time: float,
+        data_time: float,
+        max_memory: float | None,
+        records: Records,
+    ):
+        eta = self.eta_tracker.get_eta(trainer.max_steps - trainer.step)
+        trainer.logger.info(
+            f"Step {trainer.step:>{len(str(trainer.max_steps))}}/{trainer.max_steps}:"
+            + f" step {step_time:.4f}{'s' if self.show_units else ''} data {data_time:.4f}{'s' if self.show_units else ''}"
+            + (f" eta {eta}" if eta is not None else "")
+            + (
+                f" mem {max_memory:#.3g}{'GiB' if self.show_units else ''}"
+                if max_memory is not None
+                else ""
+            )
+            + f" loss {loss:.4f}"
+            + (f" grad_norm {grad_norm:.4f}" if grad_norm is not None else "")
+            + (" " + " ".join(f"lr_{i} {lr:.2e}" for i, lr in self.lrs))
+            + (
+                (
+                    " | "
+                    + " ".join(
+                        f"{'/'.join(k)} {f'{v:#.4g}' if isinstance(v, Number) else v}"
+                        for k, v in flatten_nested_dict(records).items()
+                    )
+                )
+                if self.with_records
+                else ""
+            )
+        )
+
+
+class LoggingHook(_StatsHook):
+    """Aggregate stats and forward them to ``trainer.log``.
+
+    Args:
+        interval: Log every N steps.
+        sync: If True, aggregate across distributed ranks.
+    """
+
+    def __init__(
+        self,
+        interval: int = 10,
+        sync: bool = True,
+    ):
+        super().__init__(interval, sync)
+
+    def process_stats(
+        self,
+        trainer: BaseTrainer,
+        loss: float,
+        grad_norm: float | None,
+        step_time: float,
+        data_time: float,
+        max_memory: float | None,
+        records: Records,
+    ):
+        lrs = [
+            (i, param_group["lr"])
+            for i, param_group in enumerate(trainer.optimizer.param_groups)
+        ]
+        trainer.log(
+            {
+                "train": records
+                | ({"grad_norm": grad_norm} if grad_norm is not None else {})
+                | ({"max_memory": max_memory} if max_memory is not None else {})
+                | {
+                    "loss": loss,
+                    "data_time": data_time,
+                    "step_time": step_time,
+                    "lr": {f"group_{i}": lr for i, lr in lrs},
+                }
+            }
+        )
+
+
+class CheckpointingHook(BaseHook):
+    """Save and optionally restore checkpoints at regular intervals.
+
+    Args:
+        interval: Save every ``interval`` steps.
+        keep_previous: Keep the last N checkpoints in addition to the latest.
+        keep_interval: Keep checkpoints every ``keep_interval`` steps.
+        path: Directory (relative to workspace unless absolute) for checkpoints.
+        load: Path to load at startup or ``"latest"`` to auto-resume.
+        exit_signals: Signals that trigger a checkpoint then exit.
+        exit_code: Exit code after handling an exit signal.
+        exit_wait: Optional sleep before exit (useful for schedulers).
+    """
+
+    def __init__(
+        self,
+        interval: int,
+        keep_previous: int = 0,  # keep N previous checkpoints
+        keep_interval: int = 0,  # keep checkpoints of every N-th step
+        path: Path | str = "checkpoint",
+        load: Path | str | Literal["latest"] | None = "latest",
+        exit_signals: list[signal.Signals] | signal.Signals | None = None,
+        exit_code: int | Literal["128+signal"] = "128+signal",
+        exit_wait: timedelta | float = 0.0,
+    ):
+        assert interval > 0
+        assert keep_previous >= 0
+        self.interval = interval
+        self.keep_previous = keep_previous
+        self.keep_interval = keep_interval
+        self.path = Path(path)
+        self.load_path = Path(load) if load is not None else None
+
+        self.local_exit_signal: signal.Signals = -1  # not a valid value
+        exit_signals = exit_signals if exit_signals is not None else []
+        if not isinstance(exit_signals, Iterable):
+            exit_signals = [exit_signals]
+        for sig in exit_signals:
+            signal.signal(sig, lambda *args, sig=sig: setattr(self, "local_exit_signal", sig))
+        self.has_exit_signal_handlers = len(exit_signals) > 0
+        self.exit_code = exit_code
+        self.exit_wait = (
+            exit_wait.total_seconds() if isinstance(exit_wait, timedelta) else exit_wait
+        )
+
+    def on_before_train(self, trainer: BaseTrainer):
+        if self.has_exit_signal_handlers:
+            # micro optimization: allocate signal tensor only once
+            self.dist_exit_signal = torch.tensor(
+                self.local_exit_signal, dtype=torch.int32, device=trainer.device
+            )
+        load_path = self.load_path
+        if load_path is not None:
+            # handles 'latest' and regular checkpoints
+            if len(load_path.parts) == 1 and not load_path.is_absolute():
+                load_path = self.path / load_path
+            if not load_path.is_absolute():
+                assert trainer.workspace is not None
+                load_path = trainer.workspace / load_path
+            if not load_path.is_dir():
+                # nonexistent path is only ok if we're loading the 'latest' checkpoint
+                assert str(self.load_path) == "latest", (
+                    f"Checkpoint path {load_path} does not exist"
+                )
+                return
+
+            trainer.logger.info(f"=> Loading checkpoint from {load_path} ...")
+            state_dict = {
+                file.with_suffix("").name: torch.load(
+                    file, map_location=trainer.device, weights_only=True
+                )
+                for file in load_path.iterdir()
+                if file.is_file() and file.suffix == ".pt"
+            }
+            trainer.logger.debug(f"Checkpoint contains: {', '.join(state_dict.keys())}")
+            trainer.load_state_dict(state_dict)
+
+    def on_before_step(self, trainer: BaseTrainer):
+        if self.has_exit_signal_handlers:
+            self.dist_exit_signal.fill_(self.local_exit_signal)
+            # micro optimization: reduce async during step and read after step
+            self.dist_exit_signal_work = dist.all_reduce(
+                self.dist_exit_signal, op=dist.ReduceOp.MAX, async_op=True
+            )
+
+    def on_after_step(self, trainer: BaseTrainer):
+        save_and_exit = False
+        if self.has_exit_signal_handlers:
+            self.dist_exit_signal_work.wait()
+            exit_signal = self.dist_exit_signal.item()
+            save_and_exit = exit_signal != -1
+
+        # NOTE: Check if last step here (not in on_after_train) to avoid saving twice
+        if (
+            trainer.step % self.interval == 0
+            or trainer.step == trainer.max_steps
+            or save_and_exit
+        ):
+            if save_and_exit:
+                trainer.logger.info(
+                    f"=> Caught signal {exit_signal}. Saving checkpoint before exit ..."
+                )
+            self._save_checkpoint(
+                trainer,
+                keep=self.keep_interval > 0 and trainer.step % self.keep_interval == 0,
+            )
+            if save_and_exit:
+                dist.barrier()
+                if self.exit_wait > 0:
+                    trainer.logger.info(
+                        f"=> Waiting {self.exit_wait:.0f} seconds before exit ..."
+                    )
+                    time.sleep(self.exit_wait)  # try wait for the Slurm job timeout
+                exit_code = (
+                    128 + exit_signal
+                    if self.exit_code == "128+signal"
+                    else self.exit_code
+                )
+                trainer.logger.info(f"=> Exiting (code: {exit_code})")
+                sys.exit(exit_code)
+
+    def _save_checkpoint(self, trainer: BaseTrainer, keep: bool):
+        """Save a model checkpoint.
+
+        Raises only if writing the current checkpoint fails. Issues encountered
+        while retaining or pruning older checkpoints are logged but not raised.
+        """
+
+        dist.barrier()
+
+        state_dict = trainer.state_dict()
+
+        # TODO: all rank gathered states
+        # gathered_random_states = [None] * dist.get_world_size()
+        # dist.gather_object(
+        #     get_random_state(),
+        #     gathered_random_states if dist.get_rank() == 0 else None,
+        #     dst=0,
+        # )
+
+        if dist.get_rank() == 0:
+            # make dir
+            save_path = self.path / str(trainer.step)
+            if not save_path.is_absolute():
+                assert trainer.workspace is not None
+                save_path = trainer.workspace / save_path
+            save_path.parent.mkdir(parents=True, exist_ok=True)
+            trainer.logger.info(f"=> Saving checkpoint to {save_path} ...")
+
+            # save
+            tmp_save_path = self._get_tmp_save_dir(save_path)
+            for name, sub_state_dict in state_dict.items():
+                torch.save(sub_state_dict, tmp_save_path / f"{name}.pt")
+            tmp_save_path.rename(save_path)
+
+            # symlink latest
+            latest_symlink = save_path.parent / "latest"
+            if latest_symlink.is_symlink():
+                latest_symlink.unlink()
+            if latest_symlink.exists():
+                trainer.logger.error(
+                    f"{latest_symlink} already exists and is not a symlink. Will not create 'latest' symlink."
+                )
+            else:
+                latest_symlink.symlink_to(save_path.name, target_is_directory=True)
+
+            if keep:
+                keep_path = save_path.with_name(save_path.name + "_keep")
+                trainer.logger.info(
+                    f"=> Marking checkpoint for keeping {keep_path} ..."
+                )
+                # retain checkpoint via symlink
+                try:
+                    save_path.rename(keep_path)
+                    save_path.symlink_to(keep_path.name, target_is_directory=True)
+                except Exception:
+                    trainer.logger.exception(
+                        f"Could not rename/symlink checkpoint for keeping {keep_path} ..."
+                    )
+                # # retain checkpoint via hard-linked copy (saves space, survives pruning of original)
+                # try:
+                #     shutil.copytree(save_path, keep_path, copy_function=os.link)
+                # except Exception:
+                #     trainer.logger.exception(
+                #         f"Could not copy checkpoint for keeping {keep_path} ..."
+                #     )
+
+            # prune
+            prev_ckpts = sorted(
+                [
+                    p
+                    for p in save_path.parent.iterdir()
+                    if p.is_dir()
+                    and self._is_int(p.name)
+                    and int(p.name) < trainer.step
+                ],
+                key=lambda p: int(p.name),
+            )
+            for p in (
+                prev_ckpts[: -self.keep_previous]
+                if self.keep_previous > 0
+                else prev_ckpts
+            ):
+                trainer.logger.info(f"=> Pruning checkpoint {p} ...")
+                try:
+                    if p.is_symlink():
+                        p.unlink()
+                    else:
+                        shutil.rmtree(p)
+                except Exception:
+                    trainer.logger.exception(f"Could not remove {p}")
+
+    @staticmethod
+    def _get_tmp_save_dir(path: Path):
+        mask = os.umask(0)  # only way to get the umask is to set it
+        os.umask(mask)
+        tmp_save_path = Path(
+            tempfile.mkdtemp(prefix=path.name + ".tmp.", dir=path.parent)
+        )
+        os.chmod(tmp_save_path, 0o777 & ~mask)  # set default mkdir permissions
+        return tmp_save_path
+
+    @staticmethod
+    def _is_int(s: str):
+        try:
+            int(s)  # let's make absolutely sure that constructing an int will work
+            return str.isdecimal(s)  # this filters out stuff like '+3' and '-3'
+        except ValueError:
+            return False
+
+
+class CudaMaxMemoryHook(BaseHook):
+    """Record peak CUDA memory per step into ``trainer.step_info``."""
+
+    def on_before_step(self, trainer: BaseTrainer):
+        torch.cuda.reset_peak_memory_stats(trainer.device)
+
+    def on_after_step(self, trainer: BaseTrainer):
+        trainer.step_info["max_memory"] = torch.cuda.max_memory_allocated(
+            trainer.device
+        ) / (1024**3)  # GiB
+
+
+class EmaHook(BaseHook):
+    """Maintain an exponential moving average of model weights.
+
+    Args:
+        decay: EMA decay rate.
+    """
+
+    def __init__(self, decay: float):
+        self.decay = decay
+
+    def on_before_train(self, trainer: BaseTrainer):
+        trainer.logger.info("=> Creating EMA model ...")
+        # Note that AveragedModel does not seem to support FSDP. It will crash here.
+        self.ema_model = AveragedModel(trainer.model, avg_fn=get_ema_avg_fn(self.decay))
+
+    def on_after_step(self, trainer: BaseTrainer):
+        self.ema_model.update_parameters(trainer.model)
+
+    def on_load_state_dict(self, trainer: BaseTrainer, state_dict: dict):
+        trainer.logger.info("=> Loading EMA model state ...")
+        set_model_state_dict(self.ema_model, state_dict["ema_model"])
+
+    def on_state_dict(self, trainer: BaseTrainer, state_dict: dict):
+        # Note: sadly, we need to keep the AveragedModel wrapper, to save its n_averaged buffer
+        state_dict["ema_model"] = get_model_state_dict(self.ema_model)
+
+
+class WandbHook(BaseHook):
+    """Log metrics and images to Weights & Biases (rank 0 only).
+
+    Args:
+        project: W&B project name.
+        config: Optional config dict or JSON file path to log.
+        tags: Optional tag list.
+        image_format: File format for images or a callable to derive it per key.
+        **wandb_kwargs: Extra arguments forwarded to ``wandb.init``.
+    """
+
+    def __init__(
+        self,
+        project: str,
+        config: dict[str, Any] | str | None = None,
+        tags: Sequence[str] | None = None,
+        image_format: str | None | Callable[[str], str | None] = "png",
+        **wandb_kwargs,
+    ):
+        self.project = project
+        self.config = config
+        self.tags = tags
+        if callable(image_format):
+            self.image_format = image_format
+        else:
+            self.image_format = lambda _: image_format
+        self.wandb_kwargs = wandb_kwargs
+
+    def on_before_train(self, trainer: BaseTrainer):
+        if dist.get_rank() == 0:
+            wandb_run_id = self._load_wandb_run_id(trainer)
+
+            tags = os.getenv("WANDB_TAGS", "")
+            tags = list(self.tags or []) + (tags.split(",") if tags else [])  # concat
+            tags = list(dict.fromkeys(tags))  # deduplicate while preserving order
+
+            # it seems that we should use resume_from={run_id}?_{step} in wandb.init instead, but it's not well documented
+            self.wandb = wandb.init(
+                project=os.getenv("WANDB_PROJECT", self.project),
+                dir=os.getenv("WANDB_DIR", trainer.workspace),
+                id=os.getenv("WANDB_RUN_ID", wandb_run_id),
+                resume=os.getenv("WANDB_RESUME", "must" if wandb_run_id else None),
+                config=self.config,
+                tags=tags,
+                **self.wandb_kwargs,
+            )
+            if not self.wandb.disabled:
+                self._save_wandb_run_id(trainer, self.wandb.id)
+
+    def on_after_train(self, trainer: BaseTrainer):
+        if dist.get_rank() == 0:
+            self.wandb.finish()
+
+    def on_log(self, trainer: BaseTrainer, records: dict, dry_run: bool = False):
+        if dist.get_rank() == 0:
+            data = {"/".join(k): v for k, v in flatten_nested_dict(records).items()}
+            if not dry_run:
+                self.wandb.log(data, step=trainer.step)
+            else:
+                trainer.logger.debug(f"Dry run log. Would log: {data}")
+
+    def on_log_images(self, trainer: BaseTrainer, records: dict, dry_run: bool = False):
+        if dist.get_rank() == 0:
+            wandb_data = {}
+            for k, img in flatten_nested_dict({"vis": records}).items():
+                file_type = self.image_format(k[-1])
+                wandb_data.setdefault("/".join(k[:-1]), []).append(
+                    wandb.Image(
+                        self._ensure_jpeg_compatible(img)
+                        if file_type in ["jpg", "jpeg"]
+                        else img,
+                        caption=k[-1],
+                        file_type=file_type,
+                    )
+                )
+
+            if not dry_run:
+                self.wandb.log(wandb_data, step=trainer.step)
+            else:
+                trainer.logger.debug(f"Dry run log. Would log: {wandb_data}")
+
+    @staticmethod
+    def _ensure_jpeg_compatible(img: PILImage, bg_color: tuple = (255, 255, 255)):
+        if img.mode in ("RGB", "L"):
+            return img
+        elif img.mode in ("RGBA", "LA"):
+            background = Image.new("RGB", img.size, bg_color)
+            background.paste(img, mask=img.getchannel("A"))
+            return background
+        else:
+            warnings.warn(
+                f"Trying to convert {img.mode} to RGB in a best-effort manner."
+            )
+            return img.convert("RGB")
+
+    @staticmethod
+    def _wandb_run_id_file_name(trainer: BaseTrainer):
+        return trainer.workspace / "wandb_run_id"
+
+    @classmethod
+    def _save_wandb_run_id(cls, trainer: BaseTrainer, run_id: str):
+        cls._wandb_run_id_file_name(trainer).write_text(run_id)
+
+    @classmethod
+    def _load_wandb_run_id(cls, trainer: BaseTrainer):
+        f = cls._wandb_run_id_file_name(trainer)
+        if f.exists():
+            return f.read_text()
+        return None
+
+
+class ImageFileLoggerHook(BaseHook):
+    """Persist logged images to ``workspace/visualizations`` on rank 0.
+
+    Args:
+        image_format: File extension or callable taking the leaf key.
+    """
+
+    def __init__(
+        self,
+        image_format: str | Callable[[str], str] = "png",
+    ):
+        if callable(image_format):
+            self.image_format = image_format
+        else:
+            self.image_format = lambda _: image_format
+
+    def on_log_images(self, trainer: BaseTrainer, records: dict, dry_run: bool = False):
+        if dist.get_rank() == 0:
+            for k, img in flatten_nested_dict(records).items():
+                p = trainer.workspace / "visualizations" / str(trainer.step) / Path(*k)
+                p = Path(str(p) + "." + self.image_format(k[-1]))
+                if not dry_run:
+                    p.parent.mkdir(parents=True, exist_ok=True)
+                    img.save(p)
+                else:
+                    trainer.logger.debug(f"Dry run log. Would save {img} to: {p}")
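The hooks above are meant to be composed through `BaseTrainer.build_hooks`. A hedged sketch of a fuller hook stack for a distributed GPU run follows; the intervals, EMA decay, project name, and signal choice are assumptions for illustration. Hook order matters: `CudaMaxMemoryHook` must populate `step_info["max_memory"]` before the stats hooks read it.

```python
# Hypothetical build_hooks body for a BaseTrainer subclass; all values are assumptions.
# CheckpointingHook, LoggingHook, and WandbHook use torch.distributed collectives and
# rank checks, so this stack assumes an initialized process group and CUDA devices.
import signal

from trainloop import (
    CheckpointingHook,
    CudaMaxMemoryHook,
    EmaHook,
    LoggingHook,
    ProgressHook,
    WandbHook,
)


def build_hooks(self):
    return [
        CudaMaxMemoryHook(),  # fills step_info["max_memory"]; keep it before the stats hooks
        EmaHook(decay=0.999),  # maintains an AveragedModel copy of the weights
        ProgressHook(interval=10),  # stdout progress line with ETA and memory
        LoggingHook(interval=50),  # aggregates stats and forwards them to trainer.log(...)
        WandbHook(project="my-project"),  # rank-0 W&B logging, resumes via the saved run id
        CheckpointingHook(
            interval=1000,  # save every 1000 steps
            keep_previous=1,  # also keep the previous checkpoint
            keep_interval=10_000,  # mark every 10k-step checkpoint for keeping
            exit_signals=signal.SIGUSR1,  # checkpoint and exit when the job scheduler signals
        ),
    ]
```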
trainloop-0.1.0/src/trainloop/py.typed
File without changes
trainloop-0.1.0/src/trainloop/trainer.py
ADDED
@@ -0,0 +1,406 @@
+import logging
+import time
+import warnings
+from contextlib import closing, nullcontext
+from logging import Logger
+from numbers import Number
+from pathlib import Path
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Iterable,
+    Iterator,
+    TypeAlias,
+    Union,
+)
+
+import torch
+import torch.nn as nn
+from torch.distributed.checkpoint.state_dict import (
+    StateDictOptions,
+    get_state_dict,
+    set_state_dict,
+)
+
+if TYPE_CHECKING:
+    from .hooks import BaseHook
+
+Records: TypeAlias = dict[str, Union[Number, "Records"]]
+
+
+class BaseTrainer:
+    """
+    Minimal training loop that orchestrates builds, accumulation, retries, and hooks.
+
+    Subclasses provide component factories and a forward pass; the base class handles
+    sequencing, mixed precision, accumulation, state management, and hook dispatch.
+
+    Args:
+        max_steps: Number of training steps to run.
+        grad_clip: Max gradient norm; if set, gradients are clipped before stepping.
+        max_non_finite_grad_retries: Number of retries when encountering non-finite gradients (scaler disabled).
+        mixed_precision: ``"fp16"`` or ``"bf16"`` to enable autocast; ``None`` disables it.
+        gradient_accumulation_steps: Number of microsteps to accumulate before stepping.
+        workspace: Optional working directory used by hooks (e.g., checkpoints, logs).
+        device: Device for the model and tensors.
+        no_sync_accumulate: Whether to call ``no_sync`` on distributed modules during accumulation.
+        state_dict_options: Torch distributed checkpoint options.
+        logger: Logger instance; a default logger is created when omitted.
+    """
+
+    def __init__(
+        self,
+        max_steps: int,
+        grad_clip: float | None = None,
+        max_non_finite_grad_retries: int | None = None,
+        mixed_precision: str | None = None,
+        gradient_accumulation_steps: int | None = None,
+        workspace: Path | str | None = None,
+        device: torch.device | str | int | None = None,
+        no_sync_accumulate: bool = True,  # can make sense to disable this for FSDP
+        state_dict_options: StateDictOptions | None = None,
+        logger: Logger | None = None,
+    ):
+        self.step = 0  # refers to the last begun step. incremented *before* each step
+        self.max_steps = max_steps
+        self.grad_clip = grad_clip
+        self.max_non_finite_grad_retries = max_non_finite_grad_retries
+        match mixed_precision:
+            case "fp16":
+                self.mixed_precision = torch.float16
+            case "bf16":
+                self.mixed_precision = torch.bfloat16
+            case None:
+                self.mixed_precision = None
+            case _:
+                raise ValueError(f"Unsupported mixed precision: {mixed_precision}")
+        self.device = (
+            torch.device(device) if device is not None else torch.get_default_device()
+        )
+        self.gradient_accumulation_steps = gradient_accumulation_steps or 1
+        self.workspace = Path(workspace) if workspace is not None else None
+        self.logger = logger if logger is not None else logging.getLogger("trainer")
+        self.no_sync_accumulate = no_sync_accumulate
+        self.state_dict_options = state_dict_options
+
+    def _build(self):
+        self.logger.debug("_build()")
+        self.data_loader = self.build_data_loader()
+        self.model = self.build_model()
+        self.optimizer = self.build_optimizer()
+        self.lr_scheduler = self.build_lr_scheduler()
+        self.grad_scaler = self.build_grad_scaler()
+        self.hooks = self.build_hooks()
+
+    def build_data_loader(self) -> Iterable:
+        """Return the training data iterator."""
+        raise NotImplementedError
+
+    def build_model(self) -> nn.Module:
+        """Construct and return the model."""
+        raise NotImplementedError
+
+    def build_optimizer(self) -> torch.optim.Optimizer:
+        """Create the optimizer for the model."""
+        raise NotImplementedError
+
+    def build_lr_scheduler(self) -> torch.optim.lr_scheduler.LRScheduler | None:
+        """Optionally create a learning-rate scheduler."""
+        return None
+
+    def build_hooks(self) -> list["BaseHook"]:
+        """Return hooks to run during training."""
+        return []
+
+    def build_grad_scaler(self) -> torch.amp.GradScaler:
+        """Create the gradient scaler used for mixed precision."""
+        return torch.amp.GradScaler(
+            self.device.type, enabled=self.mixed_precision == torch.float16
+        )
+
+    def state_dict(self) -> dict[str, Any]:
+        self.logger.debug("state_dict()")
+        model_state_dict, optimizer_state_dict = get_state_dict(
+            self.model, self.optimizer, options=self.state_dict_options
+        )
+        state_dict = {
+            "model": model_state_dict,
+            "training_state": {
+                "step": self.step,
+                "optimizer": optimizer_state_dict,
+                "lr_scheduler": self.lr_scheduler.state_dict()
+                if self.lr_scheduler
+                else None,
+                "grad_scaler": self.grad_scaler.state_dict(),
+            },
+        }
+        for h in self.hooks:
+            h.on_state_dict(self, state_dict)
+        return state_dict
+
+    def load_state_dict(self, state_dict: dict[str, Any]):
+        self.logger.debug("load_state_dict()")
+        training_state = state_dict["training_state"]
+
+        self.step = training_state["step"]
+        self.logger.info(f"=> Resuming from step {self.step} ...")
+
+        if self.lr_scheduler is not None:
+            # NOTE: order is important. load the optimizer AFTER lr_scheduler. https://github.com/pytorch/pytorch/issues/119168
+            self.logger.info("=> Loading LR scheduler state ...")
+            self.lr_scheduler.load_state_dict(training_state["lr_scheduler"])
+        self.logger.info("=> Loading grad scaler state ...")
+        self.grad_scaler.load_state_dict(training_state["grad_scaler"])
+
+        self.logger.info("=> Loading model and optimizer state ...")
+        set_state_dict(
+            self.model,
+            self.optimizer,
+            model_state_dict=state_dict["model"],
+            optim_state_dict=training_state["optimizer"],
+            options=self.state_dict_options,
+        )
+
+        self.logger.info("=> Loading hook states ...")
+        for h in self.hooks:
+            h.on_load_state_dict(self, state_dict)
+
+    def train(self):
+        """Run the training loop until ``max_steps`` are completed."""
+        self._build()
+        self._before_train()
+
+        self.model.train()
+        self.optimizer.zero_grad()  # just in case
+
+        # attempt to explicitly close the iterator since it likely owns resources such as worker processes
+        with maybe_closing(iter(self.data_loader)) as data_iter:
+            while self.step < self.max_steps:
+                self.step += 1
+                self.step_info = {}
+                self._before_step()
+
+                step_time = time.perf_counter()
+                self._run_step(data_iter)
+                self.step_info["step_time"] = time.perf_counter() - step_time
+
+                self._after_step()
+
+        self._after_train()
+
+    # compared to a plain optimizer step, this adds the accumulation context and the loss-None warning
+    def _run_step(self, data_iter: Iterator):
+        """
+        Run a single optimizer step of training.
+        Args:
+            data_iter (Iterator): Data iterator.
+        """
+
+        def reset_step_info():
+            self.step_info["loss"] = []
+            self.step_info["records"] = []
+
+        reset_step_info()
+        self.step_info["data_time"] = []
+        non_finite_grad_retry_count = 0
+        i_acc = 0
+        while i_acc < self.gradient_accumulation_steps:
+            is_accumulating = i_acc < self.gradient_accumulation_steps - 1
+            no_sync_accumulate = (
+                self.model.no_sync()
+                if self.no_sync_accumulate
+                and is_accumulating
+                and hasattr(self.model, "no_sync")
+                else nullcontext()
+            )  # for DDP and FSDP
+
+            data_time = time.perf_counter()
+            input = next(data_iter)
+            self.step_info["data_time"].append(time.perf_counter() - data_time)
+
+            with no_sync_accumulate:
+                with torch.autocast(
+                    device_type=self.device.type,
+                    dtype=self.mixed_precision,
+                    enabled=bool(self.mixed_precision),
+                ):
+                    self.logger.debug(f"{self.step}-{i_acc} forward()")
+                    loss, records = self.forward(input)
+                if loss is None:
+                    if isinstance(
+                        self.model,
+                        (
+                            torch.nn.parallel.DistributedDataParallel,
+                            torch.distributed.fsdp.FullyShardedDataParallel,
+                        ),
+                    ):
+                        # TODO: find a better way to handle this
+                        # It seems that each DDP forward call is expected to be followed by a backward pass, as DDP maintains internal state after the forward pass that anticipates a backward step.
+                        # While it might work with `broadcast_buffers=False` or if the backward pass is collectively skipped across all ranks,
+                        # this behavior is not officially documented as safe and could result in undefined behavior.
+                        # Since `Trainer.forward` may also return None before calling `DDP.forward`, this is just a warning rather than an error.
+                        # I think the same thing applies to FSDP, but I haven't confirmed it.
+                        warnings.warn(
+                            "Loss is None; skipping backward step. Ensure self.model.forward was not called in self.forward to avoid undefined behavior in DDP and FSDP.",
+                            LossNoneWarning,
+                        )
+                    continue  # skip the backward & optimizer step
+                if not torch.isfinite(
+                    loss
+                ):  # TODO: check if device sync slows down training
+                    self.logger.warning(
+                        f"Loss is non-finite ({loss.item()}). records={records}"
+                    )
+                    # we will handle non-finite later at the optimizer step, the warning is just for debugging
+                    # keep in mind that at least for DDP, we must still call backward() to avoid undefined behavior!
+
+                self.step_info["loss"].append(loss.detach())
+                self.step_info["records"].append(records)
+                loss = loss / self.gradient_accumulation_steps
+                self.logger.debug(f"{self.step}-{i_acc} backward()")
+                self.grad_scaler.scale(loss).backward()
+                i_acc += 1  # only increment after an actual backward pass
+
+                if not is_accumulating:
+                    if not self.grad_scaler.is_enabled():
+                        # only skip non-finite grads if the scaler is disabled (the scaler needs to process non-finite grads to adjust the scale)
+                        if any(
+                            (not torch.isfinite(p.grad).all())
+                            for p in self.model.parameters()
+                            if p.grad is not None
+                        ):
+                            if self.max_non_finite_grad_retries is None or (
+                                non_finite_grad_retry_count
+                                < self.max_non_finite_grad_retries
+                            ):
+                                non_finite_grad_retry_count += 1
+                                self.logger.warning(
+                                    f"Gradient is non-finite. Retrying step {self.step} (retry {non_finite_grad_retry_count}"
+                                    + (
+                                        f"/{self.max_non_finite_grad_retries})."
+                                        if self.max_non_finite_grad_retries is not None
+                                        else ")."
+                                    )
+                                )
+                                self.optimizer.zero_grad()
+                                # TODO: check if we also need to "reset" (is that a thing?) the scaler here
+                                reset_step_info()
+                                i_acc = 0  # start accumulation again
+                                continue
+                            else:
+                                raise RuntimeError(
+                                    "Gradient is non-finite. Exceeded maximum retries for non-finite gradients."
+                                )
+                    self.grad_scaler.unscale_(self.optimizer)
+                    self._before_optimizer_step()
+                    if self.grad_clip is not None:
+                        self.step_info["grad_norm"] = torch.nn.utils.clip_grad_norm_(
+                            self.model.parameters(), self.grad_clip
+                        )
+                    self.logger.debug(f"{self.step}-{i_acc - 1} step()")
+                    self.grad_scaler.step(self.optimizer)
+                    self.grad_scaler.update()
+                    self.optimizer.zero_grad()
+                    if self.lr_scheduler is not None:
+                        self.lr_scheduler.step()
+
+    def forward(self, input: Any) -> tuple[torch.Tensor | None, Records]:
+        """
+        Perform a forward pass and return loss plus records for logging.
+
+        Args:
+            input: Batch yielded by the data loader.
+
+        Returns:
+            The loss (``None`` skips backward/step; if using DDP/FSDP, avoid invoking the wrapped module's ``forward`` in that case).
+            A nested dict of numeric metrics that will be averaged and emitted to hooks.
+        """
+        raise NotImplementedError
+
+    @classmethod
+    def unwrap(cls, module: nn.Module) -> nn.Module:
+        match module:
+            case torch._dynamo.eval_frame.OptimizedModule():
+                return cls.unwrap(module._orig_mod)
+            case (
+                torch.nn.parallel.DistributedDataParallel()
+                | torch.nn.parallel.DataParallel()
+                | torch.optim.swa_utils.AveragedModel()
+            ):
+                return cls.unwrap(module.module)
+        return module
+
+    @property
+    def unwrapped_model(self):
+        return self.unwrap(self.model)
+
+    def log(self, records: dict[str, Any], dry_run: bool = False):
+        """
+        Dispatch numeric records to hooks (e.g., trackers or stdout).
+
+        Args:
+            records: Nested dict of numeric metrics to log.
+            dry_run: If True, hooks should avoid side effects and only report intent.
+        """
+        self.logger.debug("log()")
+        for h in self.hooks:
+            h.on_log(self, records, dry_run=dry_run)
+
+    def log_images(self, records: dict[str, Any], dry_run: bool = False):
+        """
+        Dispatch image records to hooks.
+
+        Args:
+            records: Nested dict of images to log.
+            dry_run: If True, hooks should avoid side effects and only report intent.
+        """
+        self.logger.debug("log_images()")
+        for h in self.hooks:
+            h.on_log_images(self, records, dry_run=dry_run)
+
+    def _before_train(self):
+        self.logger.debug("_before_train()")
+        for h in self.hooks:
+            h.on_before_train(self)
+
+    def _after_train(self):
+        self.logger.debug("_after_train()")
+        for h in self.hooks:
+            h.on_after_train(self)
+
+    def _before_step(self):
+        self.logger.debug("_before_step()")
+        for h in self.hooks:
+            h.on_before_step(self)
+
+    def _after_step(self):
+        self.logger.debug("_after_step()")
+        for h in self.hooks:
+            h.on_after_step(self)
+
+    def _before_optimizer_step(self):
+        self.logger.debug("_before_optimizer_step()")
+        for h in self.hooks:
+            h.on_before_optimizer_step(self)
+
+
+def maybe_closing(obj):
+    """Return a context manager that closes `obj` if it has a .close() method, otherwise does nothing."""
+    return closing(obj) if callable(getattr(obj, "close", None)) else nullcontext(obj)
+
+
+def map_nested_tensor(f: Callable[[torch.Tensor], Any], obj: Any):
+    """Apply ``f`` to every tensor contained in a nested structure."""
+    if isinstance(obj, torch.Tensor):
+        return f(obj)
+    elif isinstance(obj, (list, tuple, set)):
+        return type(obj)(map_nested_tensor(f, o) for o in obj)
+    elif isinstance(obj, dict):
+        return type(obj)((k, map_nested_tensor(f, v)) for k, v in obj.items())
+    else:
+        return obj
+
+
+class LossNoneWarning(UserWarning):
+    """Warning raised when ``forward`` returns ``None`` in distributed contexts."""
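The `forward` contract above (return `(loss, records)` and handle device placement yourself) pairs naturally with `map_nested_tensor`. A hedged sketch, assuming a batch shaped as a nested dict of tensors:

```python
# Hypothetical forward() body for a BaseTrainer subclass; the batch layout, model call,
# and metric names are assumptions. map_nested_tensor moves an arbitrarily nested batch
# onto trainer.device before the model is invoked.
import torch

from trainloop import map_nested_tensor


def forward(self, input):
    # e.g. input == {"image": Tensor, "target": {"label": Tensor}}
    batch = map_nested_tensor(lambda t: t.to(self.device, non_blocking=True), input)
    logits = self.model(batch["image"])
    labels = batch["target"]["label"]
    loss = torch.nn.functional.cross_entropy(logits, labels)
    # records must be a nested dict of plain numbers; hooks average and forward them
    acc = (logits.argmax(dim=-1) == labels).float().mean().item()
    return loss, {"loss": loss.item(), "acc": acc}
```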
trainloop-0.1.0/src/trainloop/utils.py
ADDED
@@ -0,0 +1,67 @@
+# modified from: https://github.com/microsoft/MoGe/blob/6b8b43db567ca4b08615c39b42cffd6c76cada29/moge/utils/tools.py
+
+import math
+from typing import Any, Generator, MutableMapping
+
+
+def traverse_nested_dict_keys(
+    d: dict[str, dict],
+) -> Generator[tuple[str, ...], None, None]:
+    for k, v in d.items():
+        if isinstance(v, dict):
+            for sub_key in traverse_nested_dict_keys(v):
+                yield (k,) + sub_key
+        else:
+            yield (k,)
+
+
+def get_nested_dict(d: dict[str, dict], keys: tuple[str, ...], default: Any = None):
+    for k in keys:
+        d = d.get(k, default)
+        if d is None:
+            break
+    return d
+
+
+def set_nested_dict(d: dict[str, dict], keys: tuple[str, ...], value: Any):
+    for k in keys[:-1]:
+        d = d.setdefault(k, {})
+    d[keys[-1]] = value
+
+
+def key_average(list_of_dicts: list, exclude_nan: bool = False) -> dict[str, Any]:
+    """
+    Returns a dictionary with the average value of each key in the input list of dictionaries.
+    """
+    _nested_dict_keys = set()
+    for d in list_of_dicts:
+        _nested_dict_keys.update(traverse_nested_dict_keys(d))
+    _nested_dict_keys = sorted(_nested_dict_keys)
+    result = {}
+    for k in _nested_dict_keys:
+        values = []
+        for d in list_of_dicts:
+            v = get_nested_dict(d, k)
+            if v is not None and (not exclude_nan or not math.isnan(v)):
+                values.append(v)
+        avg = sum(values) / len(values) if values else float("nan")
+        set_nested_dict(result, k, avg)
+    return result
+
+
+def flatten_nested_dict(
+    d: dict[str, Any], parent_key: tuple[str, ...] | None = None
+) -> dict[tuple[str, ...], Any]:
+    """
+    Flattens a nested dictionary into a single-level dictionary, with keys as tuples.
+    """
+    items = []
+    if parent_key is None:
+        parent_key = ()
+    for k, v in d.items():
+        new_key = parent_key + (k,)
+        if isinstance(v, MutableMapping):
+            items.extend(flatten_nested_dict(v, new_key).items())
+        else:
+            items.append((new_key, v))
+    return dict(items)