trainloop 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,24 @@
1
+ Metadata-Version: 2.3
2
+ Name: trainloop
3
+ Version: 0.1.0
4
+ Summary: Minimal PyTorch training loop with hooks and checkpointing.
5
+ Author: Karim Abou Zeid
6
+ Author-email: Karim Abou Zeid <contact@ka.codes>
7
+ Requires-Dist: pillow>=11.3.0
8
+ Requires-Dist: torch>=2.0.0
9
+ Requires-Python: >=3.10
10
+ Description-Content-Type: text/markdown
11
+
12
+ # trainloop
13
+
14
+ [![PyPI version](https://img.shields.io/pypi/v/trainloop.svg)](https://pypi.org/project/trainloop/)
15
+
16
+ Minimal PyTorch training loop with hooks for logging, checkpointing, and customization.
17
+
18
+ Docs: https://kabouzeid.github.io/trainloop/
19
+
20
+ ## Install
21
+
22
+ ```bash
23
+ pip install trainloop
24
+ ```
@@ -0,0 +1,13 @@
1
+ # trainloop
2
+
3
+ [![PyPI version](https://img.shields.io/pypi/v/trainloop.svg)](https://pypi.org/project/trainloop/)
4
+
5
+ Minimal PyTorch training loop with hooks for logging, checkpointing, and customization.
6
+
7
+ Docs: https://kabouzeid.github.io/trainloop/
8
+
9
+ ## Install
10
+
11
+ ```bash
12
+ pip install trainloop
13
+ ```
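The README stops at installation; as an orientation aid, here is a minimal usage sketch pieced together from the `BaseTrainer` and `ProgressHook` APIs further down in this diff. It assumes a recent PyTorch; the model, data, and hyperparameters are purely illustrative.

```python
import logging

import torch
import torch.nn as nn

from trainloop import BaseTrainer, ProgressHook


class ToyTrainer(BaseTrainer):
    """Illustrative subclass; real projects plug in their own components."""

    def build_data_loader(self):
        # Any iterable of batches works; here an endless stream of random data.
        def batches():
            while True:
                yield torch.randn(32, 10), torch.randint(0, 2, (32,))

        return batches()

    def build_model(self):
        return nn.Linear(10, 2)

    def build_optimizer(self):
        return torch.optim.AdamW(self.model.parameters(), lr=1e-3)

    def build_hooks(self):
        return [ProgressHook(interval=10, with_records=True)]

    def forward(self, input):
        x, y = input
        logits = self.model(x)
        loss = nn.functional.cross_entropy(logits, y)
        # The second return value is a nested dict of numeric records for logging.
        return loss, {"acc": (logits.argmax(-1) == y).float().mean().item()}


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    ToyTrainer(max_steps=100, device="cpu").train()
```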
@@ -0,0 +1,26 @@
1
+ [project]
2
+ name = "trainloop"
3
+ version = "0.1.0"
4
+ description = "Minimal PyTorch training loop with hooks and checkpointing."
5
+ readme = "README.md"
6
+ authors = [
7
+ { name = "Karim Abou Zeid", email = "contact@ka.codes" }
8
+ ]
9
+ requires-python = ">=3.10"
10
+ dependencies = [
11
+ "pillow>=11.3.0",
12
+ "torch>=2.0.0",
13
+ ]
14
+
15
+ [build-system]
16
+ requires = ["uv_build>=0.9.0,<0.10.0"]
17
+ build-backend = "uv_build"
18
+
19
+ [dependency-groups]
20
+ dev = [
21
+ "wandb>=0.20.1",
22
+ "pytest>=8.4.0",
23
+ "ruff>=0.11.13",
24
+ "zensical>=0.0.10",
25
+ "mkdocstrings-python>=2.0.1",
26
+ ]
@@ -0,0 +1,25 @@
1
+ from .hooks import (
2
+ BaseHook,
3
+ CheckpointingHook,
4
+ CudaMaxMemoryHook,
5
+ EmaHook,
6
+ ImageFileLoggerHook,
7
+ LoggingHook,
8
+ ProgressHook,
9
+ WandbHook,
10
+ )
11
+ from .trainer import BaseTrainer, LossNoneWarning, map_nested_tensor
12
+
13
+ __all__ = [
14
+ "BaseTrainer",
15
+ "BaseHook",
16
+ "CheckpointingHook",
17
+ "CudaMaxMemoryHook",
18
+ "LoggingHook",
19
+ "ProgressHook",
20
+ "EmaHook",
21
+ "WandbHook",
22
+ "ImageFileLoggerHook",
23
+ "LossNoneWarning",
24
+ "map_nested_tensor",
25
+ ]
@@ -0,0 +1,716 @@
1
+ import os
2
+ import shutil
3
+ import signal
4
+ import sys
5
+ import tempfile
6
+ import time
7
+ import warnings
8
+ from datetime import timedelta
9
+ from numbers import Number
10
+ from pathlib import Path
11
+ from typing import Any, Callable, Iterable, Literal, Sequence
12
+
13
+ import torch
14
+ import torch.distributed as dist
15
+ from PIL import Image
16
+ from PIL.Image import Image as PILImage
17
+ from torch.distributed.checkpoint.state_dict import (
18
+ get_model_state_dict,
19
+ set_model_state_dict,
20
+ )
21
+ from torch.optim.swa_utils import AveragedModel, get_ema_avg_fn
22
+
23
+ try:
24
+ import wandb
25
+ except ImportError:
26
+ # only needed for WandbHook
27
+ pass
28
+
29
+ from .trainer import BaseTrainer, Records
30
+ from .utils import flatten_nested_dict, key_average
31
+
32
+
33
+ class BaseHook:
34
+ """Lifecycle hooks for `BaseTrainer`."""
35
+
36
+ def on_before_train(self, trainer: BaseTrainer):
37
+ pass
38
+
39
+ def on_before_step(self, trainer: BaseTrainer):
40
+ pass
41
+
42
+ def on_before_optimizer_step(self, trainer: BaseTrainer):
43
+ pass
44
+
45
+ def on_after_step(self, trainer: BaseTrainer):
46
+ pass
47
+
48
+ def on_after_train(self, trainer: BaseTrainer):
49
+ pass
50
+
51
+ def on_log(self, trainer: BaseTrainer, records: dict, dry_run: bool = False):
52
+ pass
53
+
54
+ def on_log_images(self, trainer: BaseTrainer, records: dict, dry_run: bool = False):
55
+ pass
56
+
57
+ def on_state_dict(self, trainer: BaseTrainer, state_dict: dict):
58
+ pass
59
+
60
+ def on_load_state_dict(self, trainer: BaseTrainer, state_dict: dict):
61
+ pass
62
+
63
+
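As a sketch of how the lifecycle above is meant to be used: a custom hook only overrides the callbacks it needs, and anything it adds in `on_state_dict` is persisted by `CheckpointingHook` and handed back via `on_load_state_dict`. The hook below is hypothetical, not part of the package; it tracks total wall-clock time across restarts.

```python
import time

from trainloop import BaseHook, BaseTrainer


class WallClockHook(BaseHook):  # illustrative custom hook
    def on_before_train(self, trainer: BaseTrainer):
        # may run before or after a checkpoint restore, so don't clobber a loaded value
        self.total_seconds = getattr(self, "total_seconds", 0.0)

    def on_before_step(self, trainer: BaseTrainer):
        self._t0 = time.perf_counter()

    def on_after_step(self, trainer: BaseTrainer):
        self.total_seconds += time.perf_counter() - self._t0
        if trainer.step % 100 == 0:
            trainer.logger.info(f"total wall clock: {self.total_seconds:.1f}s")

    def on_state_dict(self, trainer: BaseTrainer, state_dict: dict):
        # anything added here is saved alongside the model by CheckpointingHook
        state_dict["wall_clock"] = {"total_seconds": self.total_seconds}

    def on_load_state_dict(self, trainer: BaseTrainer, state_dict: dict):
        if "wall_clock" in state_dict:
            self.total_seconds = state_dict["wall_clock"]["total_seconds"]
```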
64
+ class _StatsHook(BaseHook):
65
+ """Collect step statistics and hand them to subclasses for reporting.
66
+
67
+ Args:
68
+ interval: Emit stats every N steps.
69
+ sync: If True, aggregate stats across distributed ranks.
70
+ """
71
+
72
+ def __init__(
73
+ self,
74
+ interval: int,
75
+ sync: bool,
76
+ ):
77
+ self.interval = interval
78
+ self.sync = sync
79
+ self.reset()
80
+
81
+ def reset(self):
82
+ self.losses = []
83
+ self.records_ls = []
84
+ self.grad_norms = []
85
+ self.data_times = []
86
+ self.step_times = []
87
+ self.max_memories = []
88
+
89
+ def on_after_step(self, trainer: BaseTrainer):
90
+ # collect and aggregate over accumulation steps
91
+ self.losses.append(torch.stack(trainer.step_info["loss"]).mean())
92
+ if trainer.grad_clip is not None:
93
+ self.grad_norms.append(trainer.step_info["grad_norm"])
94
+ self.records_ls.append(key_average(trainer.step_info["records"]))
95
+ self.data_times.append(sum(trainer.step_info["data_time"])) # total
96
+ self.step_times.append(trainer.step_info["step_time"])
97
+ if "max_memory" in trainer.step_info:
98
+ self.max_memories.append(trainer.step_info["max_memory"])
99
+
100
+ if trainer.step % self.interval == 0 or trainer.step == trainer.max_steps:
101
+ # aggregate over steps
102
+ loss = torch.stack(self.losses).mean()
103
+ grad_norm = torch.stack(self.grad_norms).mean() if self.grad_norms else None
104
+ records = key_average(self.records_ls)
105
+ data_time = sum(self.data_times) / len(self.data_times)
106
+ step_time = sum(self.step_times) / len(self.step_times)
107
+ max_memory = max(self.max_memories) if self.max_memories else None
108
+
109
+ if self.sync:
110
+ # aggregate across all ranks
111
+ dist.all_reduce(loss, op=dist.ReduceOp.AVG)
112
+ if grad_norm is not None:
113
+ dist.all_reduce(grad_norm, op=dist.ReduceOp.AVG)
114
+
115
+ gathered = [None] * dist.get_world_size()
116
+ dist.all_gather_object(
117
+ gathered,
118
+ {
119
+ "records": records,
120
+ "data_time": data_time,
121
+ "step_time": step_time,
122
+ "max_memory": max_memory,
123
+ },
124
+ )
125
+ records = key_average([stat["records"] for stat in gathered])
126
+ data_time = sum(stat["data_time"] for stat in gathered) / len(gathered)
127
+ step_time = sum(stat["step_time"] for stat in gathered) / len(gathered)
128
+ if "max_memory" in trainer.step_info:
129
+ max_memory = max(stat["max_memory"] for stat in gathered)
130
+
131
+ self.process_stats(
132
+ trainer,
133
+ loss.item(),
134
+ grad_norm.item() if grad_norm is not None else None,
135
+ step_time,
136
+ data_time,
137
+ max_memory,
138
+ records,
139
+ )
140
+ self.reset()
141
+
142
+ def process_stats(
143
+ self,
144
+ trainer: BaseTrainer,
145
+ loss: float,
146
+ grad_norm: float | None,
147
+ step_time: float,
148
+ data_time: float,
149
+ max_memory: float | None,
150
+ records: Records,
151
+ ):
152
+ raise NotImplementedError("Subclasses must implement this method.")
153
+
154
+
155
+ class ETATracker:
156
+ def __init__(self, warmup_steps: int):
157
+ """Track ETA across training steps after a warmup period.
158
+
159
+ Args:
160
+ warmup_steps: Number of steps to skip before timing begins.
161
+ """
162
+ assert warmup_steps > 0, "Warmup steps must be greater than 0"
163
+ self.warmup_steps = warmup_steps
164
+ self.steps = 0
165
+ self.timing_start = None
166
+ self.timed_steps = 0
167
+
168
+ def step(self):
169
+ self.steps += 1
170
+ if self.steps == self.warmup_steps:
171
+ self.timing_start = time.perf_counter()
172
+ if self.steps > self.warmup_steps:
173
+ self.timed_steps += 1
174
+
175
+ def get_eta(self, steps_remaining: int):
176
+ if self.timed_steps == 0:
177
+ return None
178
+
179
+ elapsed = time.perf_counter() - self.timing_start
180
+ avg_step_time = elapsed / self.timed_steps
181
+ eta_seconds = avg_step_time * steps_remaining
182
+ return timedelta(seconds=int(eta_seconds))
183
+
184
+
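A small standalone illustration of `ETATracker` (not re-exported at the package top level, but importable from `trainloop.hooks`): timing only starts once `warmup_steps` have elapsed, so slow early steps do not skew the estimate. The numbers below are arbitrary.

```python
import time

from trainloop.hooks import ETATracker

tracker = ETATracker(warmup_steps=3)
total_steps = 1000
for step in range(1, 11):  # simulate the first 10 of 1000 steps
    tracker.step()
    time.sleep(0.01)  # stand-in for real work
    # None during warmup, then a timedelta estimate based on the timed steps
    print(step, tracker.get_eta(total_steps - step))
```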
185
+ class ProgressHook(_StatsHook):
186
+ """Log progress to stdout with optional metrics, ETA, and memory.
187
+
188
+ Args:
189
+ interval: Log every N steps.
190
+ with_records: Include per-step records in the log line.
191
+ sync: If True, aggregate across distributed ranks.
192
+ eta_warmup: Steps to warm up ETA calculation.
193
+ show_units: Whether to print units (s, GiB) alongside values.
194
+ """
195
+
196
+ def __init__(
197
+ self,
198
+ interval: int = 1,
199
+ with_records: bool = False,
200
+ sync: bool = False,
201
+ eta_warmup: int = 10,
202
+ show_units: bool = True,
203
+ ):
204
+ super().__init__(interval=interval, sync=sync)
205
+ self.with_records = with_records
206
+ self.eta_warmup = eta_warmup
207
+ self.show_units = show_units
208
+
209
+ def on_before_train(self, trainer: BaseTrainer):
210
+ super().on_before_train(trainer)
211
+ trainer.logger.info("=> Starting training ...")
212
+ self.eta_tracker = ETATracker(warmup_steps=self.eta_warmup)
213
+
214
+ def on_after_train(self, trainer: BaseTrainer):
215
+ super().on_after_train(trainer)
216
+ trainer.logger.info("=> Finished training")
217
+
218
+ def on_before_step(self, trainer: BaseTrainer):
219
+ super().on_before_step(trainer)
220
+ self.lrs = [
221
+ (i, param_group["lr"])
222
+ for i, param_group in enumerate(trainer.optimizer.param_groups)
223
+ ] # record the LR before the scheduler steps
224
+
225
+ def on_after_step(self, trainer: BaseTrainer):
226
+ self.eta_tracker.step() # should be called before process_stats
227
+ super().on_after_step(trainer)
228
+
229
+ def process_stats(
230
+ self,
231
+ trainer: BaseTrainer,
232
+ loss: float,
233
+ grad_norm: float | None,
234
+ step_time: float,
235
+ data_time: float,
236
+ max_memory: float | None,
237
+ records: Records,
238
+ ):
239
+ eta = self.eta_tracker.get_eta(trainer.max_steps - trainer.step)
240
+ trainer.logger.info(
241
+ f"Step {trainer.step:>{len(str(trainer.max_steps))}}/{trainer.max_steps}:"
242
+ + f" step {step_time:.4f}{'s' if self.show_units else ''} data {data_time:.4f}{'s' if self.show_units else ''}"
243
+ + (f" eta {eta}" if eta is not None else "")
244
+ + (
245
+ f" mem {max_memory:#.3g}{'GiB' if self.show_units else ''}"
246
+ if max_memory is not None
247
+ else ""
248
+ )
249
+ + f" loss {loss:.4f}"
250
+ + (f" grad_norm {grad_norm:.4f}" if grad_norm is not None else "")
251
+ + (" " + " ".join(f"lr_{i} {lr:.2e}" for i, lr in self.lrs))
252
+ + (
253
+ (
254
+ " | "
255
+ + " ".join(
256
+ f"{'/'.join(k)} {f'{v:#.4g}' if isinstance(v, Number) else v}"
257
+ for k, v in flatten_nested_dict(records).items()
258
+ )
259
+ )
260
+ if self.with_records
261
+ else ""
262
+ )
263
+ )
264
+
265
+
266
+ class LoggingHook(_StatsHook):
267
+ """Aggregate stats and forward them to ``trainer.log``.
268
+
269
+ Args:
270
+ interval: Log every N steps.
271
+ sync: If True, aggregate across distributed ranks.
272
+ """
273
+
274
+ def __init__(
275
+ self,
276
+ interval: int = 10,
277
+ sync: bool = True,
278
+ ):
279
+ super().__init__(interval, sync)
280
+
281
+ def process_stats(
282
+ self,
283
+ trainer: BaseTrainer,
284
+ loss: float,
285
+ grad_norm: float | None,
286
+ step_time: float,
287
+ data_time: float,
288
+ max_memory: float | None,
289
+ records: Records,
290
+ ):
291
+ lrs = [
292
+ (i, param_group["lr"])
293
+ for i, param_group in enumerate(trainer.optimizer.param_groups)
294
+ ]
295
+ trainer.log(
296
+ {
297
+ "train": records
298
+ | ({"grad_norm": grad_norm} if grad_norm is not None else {})
299
+ | ({"max_memory": max_memory} if max_memory is not None else {})
300
+ | {
301
+ "loss": loss,
302
+ "data_time": data_time,
303
+ "step_time": step_time,
304
+ "lr": {f"group_{i}": lr for i, lr in lrs},
305
+ }
306
+ }
307
+ )
308
+
309
+
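For reference, the call above produces a nested record of roughly the following shape, which `WandbHook.on_log` later flattens into keys like `train/loss` and `train/lr/group_0`. The values below are illustrative.

```python
# Approximate shape of the dict LoggingHook passes to trainer.log (illustrative values)
example_log_record = {
    "train": {
        "loss": 0.4231,
        "grad_norm": 1.87,         # present only when grad_clip is set
        "max_memory": 5.32,        # GiB, present only with CudaMaxMemoryHook
        "data_time": 0.012,
        "step_time": 0.431,
        "lr": {"group_0": 3.0e-4},
        # plus whatever numeric records `forward` returned, averaged over the interval
    }
}
```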
310
+ class CheckpointingHook(BaseHook):
311
+ """Save and optionally restore checkpoints at regular intervals.
312
+
313
+ Args:
314
+ interval: Save every ``interval`` steps.
315
+ keep_previous: Keep the last N checkpoints in addition to the latest.
316
+ keep_interval: Keep checkpoints every ``keep_interval`` steps.
317
+ path: Directory (relative to workspace unless absolute) for checkpoints.
318
+ load: Path to load at startup or ``"latest"`` to auto-resume.
319
+ exit_signals: Signals that trigger a checkpoint then exit.
320
+ exit_code: Exit code after handling an exit signal.
321
+ exit_wait: Optional sleep before exit (useful for schedulers).
322
+ """
323
+
324
+ def __init__(
325
+ self,
326
+ interval: int,
327
+ keep_previous: int = 0, # keep N previous checkpoints
328
+ keep_interval: int = 0, # keep checkpoints of every N-th step
329
+ path: Path | str = "checkpoint",
330
+ load: Path | str | Literal["latest"] | None = "latest",
331
+ exit_signals: list[signal.Signals] | signal.Signals | None = None,
332
+ exit_code: int | Literal["128+signal"] = "128+signal",
333
+ exit_wait: timedelta | float = 0.0,
334
+ ):
335
+ assert interval > 0
336
+ assert keep_previous >= 0
337
+ self.interval = interval
338
+ self.keep_previous = keep_previous
339
+ self.keep_interval = keep_interval
340
+ self.path = Path(path)
341
+ self.load_path = Path(load) if load is not None else None
342
+
343
+ self.local_exit_signal: int = -1 # sentinel, not a valid signal value
344
+ exit_signals = exit_signals if exit_signals is not None else []
345
+ if not isinstance(exit_signals, Iterable):
346
+ exit_signals = [exit_signals]
347
+ for sig in exit_signals:
348
+ signal.signal(sig, lambda *args, sig=sig: setattr(self, "local_exit_signal", sig)) # bind sig per handler to avoid the late-binding closure pitfall
349
+ self.has_exit_signal_handlers = len(exit_signals) > 0
350
+ self.exit_code = exit_code
351
+ self.exit_wait = (
352
+ exit_wait.total_seconds() if isinstance(exit_wait, timedelta) else exit_wait
353
+ )
354
+
355
+ def on_before_train(self, trainer: BaseTrainer):
356
+ if self.has_exit_signal_handlers:
357
+ # micro optimization: allocate signal tensor only once
358
+ self.dist_exit_signal = torch.tensor(
359
+ self.local_exit_signal, dtype=torch.int32, device=trainer.device
360
+ )
361
+ load_path = self.load_path
362
+ if load_path is not None:
363
+ # handles 'latest' and regular checkpoints
364
+ if len(load_path.parts) == 1 and not load_path.is_absolute():
365
+ load_path = self.path / load_path
366
+ if not load_path.is_absolute():
367
+ assert trainer.workspace is not None
368
+ load_path = trainer.workspace / load_path
369
+ if not load_path.is_dir():
370
+ # nonexistent path is only ok if we're loading the 'latest' checkpoint
371
+ assert str(self.load_path) == "latest", (
372
+ f"Checkpoint path {load_path} does not exist"
373
+ )
374
+ return
375
+
376
+ trainer.logger.info(f"=> Loading checkpoint from {load_path} ...")
377
+ state_dict = {
378
+ file.with_suffix("").name: torch.load(
379
+ file, map_location=trainer.device, weights_only=True
380
+ )
381
+ for file in load_path.iterdir()
382
+ if file.is_file() and file.suffix == ".pt"
383
+ }
384
+ trainer.logger.debug(f"Checkpoint contains: {', '.join(state_dict.keys())}")
385
+ trainer.load_state_dict(state_dict)
386
+
387
+ def on_before_step(self, trainer: BaseTrainer):
388
+ if self.has_exit_signal_handlers:
389
+ self.dist_exit_signal.fill_(self.local_exit_signal)
390
+ # micro optimization: reduce async during step and read after step
391
+ self.dist_exit_signal_work = dist.all_reduce(
392
+ self.dist_exit_signal, op=dist.ReduceOp.MAX, async_op=True
393
+ )
394
+
395
+ def on_after_step(self, trainer: BaseTrainer):
396
+ save_and_exit = False
397
+ if self.has_exit_signal_handlers:
398
+ self.dist_exit_signal_work.wait()
399
+ exit_signal = self.dist_exit_signal.item()
400
+ save_and_exit = exit_signal != -1
401
+
402
+ # NOTE: Check if last step here (not in on_after_train) to avoid saving twice
403
+ if (
404
+ trainer.step % self.interval == 0
405
+ or trainer.step == trainer.max_steps
406
+ or save_and_exit
407
+ ):
408
+ if save_and_exit:
409
+ trainer.logger.info(
410
+ f"=> Caught signal {exit_signal}. Saving checkpoint before exit ..."
411
+ )
412
+ self._save_checkpoint(
413
+ trainer,
414
+ keep=self.keep_interval > 0 and trainer.step % self.keep_interval == 0,
415
+ )
416
+ if save_and_exit:
417
+ dist.barrier()
418
+ if self.exit_wait > 0:
419
+ trainer.logger.info(
420
+ f"=> Waiting {self.exit_wait:.0f} seconds before exit ..."
421
+ )
422
+ time.sleep(self.exit_wait) # try to wait for the Slurm job timeout
423
+ exit_code = (
424
+ 128 + exit_signal
425
+ if self.exit_code == "128+signal"
426
+ else sys.exit(self.exit_code)
427
+ )
428
+ trainer.logger.info(f"=> Exiting (code: {exit_code})")
429
+ sys.exit(exit_code)
430
+
431
+ def _save_checkpoint(self, trainer: BaseTrainer, keep: bool):
432
+ """Save a model checkpoint.
433
+
434
+ Raises only if writing the current checkpoint fails. Issues encountered
435
+ while retaining or pruning older checkpoints are logged but not raised.
436
+ """
437
+
438
+ dist.barrier()
439
+
440
+ state_dict = trainer.state_dict()
441
+
442
+ # TODO: all rank gathered states
443
+ # gathered_random_states = [None] * dist.get_world_size()
444
+ # dist.gather_object(
445
+ # get_random_state(),
446
+ # gathered_random_states if dist.get_rank() == 0 else None,
447
+ # dst=0,
448
+ # )
449
+
450
+ if dist.get_rank() == 0:
451
+ # make dir
452
+ save_path = self.path / str(trainer.step)
453
+ if not save_path.is_absolute():
454
+ assert trainer.workspace is not None
455
+ save_path = trainer.workspace / save_path
456
+ save_path.parent.mkdir(parents=True, exist_ok=True)
457
+ trainer.logger.info(f"=> Saving checkpoint to {save_path} ...")
458
+
459
+ # save
460
+ tmp_save_path = self._get_tmp_save_dir(save_path)
461
+ for name, sub_state_dict in state_dict.items():
462
+ torch.save(sub_state_dict, tmp_save_path / f"{name}.pt")
463
+ tmp_save_path.rename(save_path)
464
+
465
+ # symlink latest
466
+ latest_symlink = save_path.parent / "latest"
467
+ if latest_symlink.is_symlink():
468
+ latest_symlink.unlink()
469
+ if latest_symlink.exists():
470
+ trainer.logger.error(
471
+ f"{latest_symlink} already exists and is not a symlink. Will not create 'latest' symlink."
472
+ )
473
+ else:
474
+ latest_symlink.symlink_to(save_path.name, target_is_directory=True)
475
+
476
+ if keep:
477
+ keep_path = save_path.with_name(save_path.name + "_keep")
478
+ trainer.logger.info(
479
+ f"=> Marking checkpoint for keeping {keep_path} ..."
480
+ )
481
+ # retain checkpoint via symlink
482
+ try:
483
+ save_path.rename(keep_path)
484
+ save_path.symlink_to(keep_path.name, target_is_directory=True)
485
+ except Exception:
486
+ trainer.logger.exception(
487
+ f"Could not rename/symlink checkpoint for keeping {keep_path} ..."
488
+ )
489
+ # # retain checkpoint via hard-linked copy (saves space, survives pruning of original)
490
+ # try:
491
+ # shutil.copytree(save_path, keep_path, copy_function=os.link)
492
+ # except Exception:
493
+ # trainer.logger.exception(
494
+ # f"Could not copy checkpoint for keeping {keep_path} ..."
495
+ # )
496
+
497
+ # prune
498
+ prev_ckpts = sorted(
499
+ [
500
+ p
501
+ for p in save_path.parent.iterdir()
502
+ if p.is_dir()
503
+ and self._is_int(p.name)
504
+ and int(p.name) < trainer.step
505
+ ],
506
+ key=lambda p: int(p.name),
507
+ )
508
+ for p in (
509
+ prev_ckpts[: -self.keep_previous]
510
+ if self.keep_previous > 0
511
+ else prev_ckpts
512
+ ):
513
+ trainer.logger.info(f"=> Pruning checkpoint {p} ...")
514
+ try:
515
+ if p.is_symlink():
516
+ p.unlink()
517
+ else:
518
+ shutil.rmtree(p)
519
+ except Exception:
520
+ trainer.logger.exception(f"Could not remove {p}")
521
+
522
+ @staticmethod
523
+ def _get_tmp_save_dir(path: Path):
524
+ mask = os.umask(0) # only way to get the umask is to set it
525
+ os.umask(mask)
526
+ tmp_save_path = Path(
527
+ tempfile.mkdtemp(prefix=path.name + ".tmp.", dir=path.parent)
528
+ )
529
+ os.chmod(tmp_save_path, 0o777 & ~mask) # set default mkdir permissions
530
+ return tmp_save_path
531
+
532
+ @staticmethod
533
+ def _is_int(s: str):
534
+ try:
535
+ int(s) # let's make absolutely sure that constructing an int will work
536
+ return str.isdecimal(s) # this filters out stuff like '+3' and '-3'
537
+ except ValueError:
538
+ return False
539
+
540
+
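For cluster jobs, the signal-handling path above is typically wired up in `build_hooks`. Below is a hypothetical configuration that checkpoints every 1000 steps, auto-resumes from `latest`, and saves once more when Slurm delivers `SIGUSR1` ahead of preemption (note that `CheckpointingHook` assumes `torch.distributed` is initialized).

```python
import signal
from datetime import timedelta

from trainloop import BaseTrainer, CheckpointingHook, LoggingHook, ProgressHook


class ClusterTrainer(BaseTrainer):  # illustrative subclass; other builders omitted
    def build_hooks(self):
        return [
            ProgressHook(interval=10),
            LoggingHook(interval=50),
            CheckpointingHook(
                interval=1000,                # save to <workspace>/checkpoint/<step>
                keep_previous=2,              # keep the two most recent checkpoints
                keep_interval=10_000,         # additionally keep every 10k-th step
                load="latest",                # auto-resume if a checkpoint exists
                exit_signals=signal.SIGUSR1,  # e.g. requested via `#SBATCH --signal=USR1@300`
                exit_wait=timedelta(minutes=4),
            ),
        ]
```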
541
+ class CudaMaxMemoryHook(BaseHook):
542
+ """Record peak CUDA memory per step into ``trainer.step_info``."""
543
+
544
+ def on_before_step(self, trainer: BaseTrainer):
545
+ torch.cuda.reset_peak_memory_stats(trainer.device)
546
+
547
+ def on_after_step(self, trainer: BaseTrainer):
548
+ trainer.step_info["max_memory"] = torch.cuda.max_memory_allocated(
549
+ trainer.device
550
+ ) / (1024**3) # GiB
551
+
552
+
553
+ class EmaHook(BaseHook):
554
+ """Maintain an exponential moving average of model weights.
555
+
556
+ Args:
557
+ decay: EMA decay rate.
558
+ """
559
+
560
+ def __init__(self, decay: float):
561
+ self.decay = decay
562
+
563
+ def on_before_train(self, trainer: BaseTrainer):
564
+ trainer.logger.info("=> Creating EMA model ...")
565
+ # Note that AveragedModel does not seem to support FSDP. It will crash here.
566
+ self.ema_model = AveragedModel(trainer.model, avg_fn=get_ema_avg_fn(self.decay))
567
+
568
+ def on_after_step(self, trainer: BaseTrainer):
569
+ self.ema_model.update_parameters(trainer.model)
570
+
571
+ def on_load_state_dict(self, trainer: BaseTrainer, state_dict: dict):
572
+ trainer.logger.info("=> Loading EMA model state ...")
573
+ set_model_state_dict(self.ema_model, state_dict["ema_model"])
574
+
575
+ def on_state_dict(self, trainer: BaseTrainer, state_dict: dict):
576
+ # Note: sadly, we need to keep the AveragedModel wrapper, to save its n_averaged buffer
577
+ state_dict["ema_model"] = get_model_state_dict(self.ema_model)
578
+
579
+
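If a concrete trainer needs the averaged weights for evaluation or export, it can keep a reference to the hook. A hypothetical sketch follows; the `AveragedModel` wrapper exposes the wrapped network as `.module` (and `BaseTrainer.unwrap` handles it as well).

```python
import torch

from trainloop import BaseTrainer, EmaHook


class MyTrainer(BaseTrainer):  # illustrative subclass; other builders omitted
    def build_hooks(self):
        self.ema_hook = EmaHook(decay=0.999)
        return [self.ema_hook]

    def evaluate_ema(self, batch):
        # EmaHook creates `ema_model` in on_before_train, so this is only valid once training has started.
        ema_net = self.ema_hook.ema_model.module  # unwrap the AveragedModel
        ema_net.eval()
        with torch.no_grad():
            return ema_net(batch)
```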
580
+ class WandbHook(BaseHook):
581
+ """Log metrics and images to Weights & Biases (rank 0 only).
582
+
583
+ Args:
584
+ project: W&B project name.
585
+ config: Optional config dict or JSON file path to log.
586
+ tags: Optional tag list.
587
+ image_format: File format for images or a callable to derive it per key.
588
+ **wandb_kwargs: Extra arguments forwarded to ``wandb.init``.
589
+ """
590
+
591
+ def __init__(
592
+ self,
593
+ project: str,
594
+ config: dict[str, Any] | str | None = None,
595
+ tags: Sequence[str] | None = None,
596
+ image_format: str | None | Callable[[str], str | None] = "png",
597
+ **wandb_kwargs,
598
+ ):
599
+ self.project = project
600
+ self.config = config
601
+ self.tags = tags
602
+ if callable(image_format):
603
+ self.image_format = image_format
604
+ else:
605
+ self.image_format = lambda _: image_format
606
+ self.wandb_kwargs = wandb_kwargs
607
+
608
+ def on_before_train(self, trainer: BaseTrainer):
609
+ if dist.get_rank() == 0:
610
+ wandb_run_id = self._load_wandb_run_id(trainer)
611
+
612
+ tags = os.getenv("WANDB_TAGS", "")
613
+ tags = list(self.tags or []) + (tags.split(",") if tags else []) # concat; self.tags may be None
614
+ tags = list(dict.fromkeys(tags)) # deduplicate while preserving order
615
+
616
+ # it seems that we should use resume_from=f"{run_id}?_step={step}" in wandb.init instead, but it's not well documented
617
+ self.wandb = wandb.init(
618
+ project=os.getenv("WANDB_PROJECT", self.project),
619
+ dir=os.getenv("WANDB_DIR", trainer.workspace),
620
+ id=os.getenv("WANDB_RUN_ID", wandb_run_id),
621
+ resume=os.getenv("WANDB_RESUME", "must" if wandb_run_id else None),
622
+ config=self.config,
623
+ tags=tags,
624
+ **self.wandb_kwargs,
625
+ )
626
+ if not self.wandb.disabled:
627
+ self._save_wandb_run_id(trainer, self.wandb.id)
628
+
629
+ def on_after_train(self, trainer: BaseTrainer):
630
+ if dist.get_rank() == 0:
631
+ self.wandb.finish()
632
+
633
+ def on_log(self, trainer: BaseTrainer, records: dict, dry_run: bool = False):
634
+ if dist.get_rank() == 0:
635
+ data = {"/".join(k): v for k, v in flatten_nested_dict(records).items()}
636
+ if not dry_run:
637
+ self.wandb.log(data, step=trainer.step)
638
+ else:
639
+ trainer.logger.debug(f"Dry run log. Would log: {data}")
640
+
641
+ def on_log_images(self, trainer: BaseTrainer, records: dict, dry_run: bool = False):
642
+ if dist.get_rank() == 0:
643
+ wandb_data = {}
644
+ for k, img in flatten_nested_dict({"vis": records}).items():
645
+ file_type = self.image_format(k[-1])
646
+ wandb_data.setdefault("/".join(k[:-1]), []).append(
647
+ wandb.Image(
648
+ self._ensure_jpeg_compatible(img)
649
+ if file_type in ["jpg", "jpeg"]
650
+ else img,
651
+ caption=k[-1],
652
+ file_type=file_type,
653
+ )
654
+ )
655
+
656
+ if not dry_run:
657
+ self.wandb.log(wandb_data, step=trainer.step)
658
+ else:
659
+ trainer.logger.debug(f"Dry run log. Would log: {wandb_data}")
660
+
661
+ @staticmethod
662
+ def _ensure_jpeg_compatible(img: PILImage, bg_color: tuple = (255, 255, 255)):
663
+ if img.mode in ("RGB", "L"):
664
+ return img
665
+ elif img.mode in ("RGBA", "LA"):
666
+ background = Image.new("RGB", img.size, bg_color)
667
+ background.paste(img, mask=img.getchannel("A"))
668
+ return background
669
+ else:
670
+ warnings.warn(
671
+ f"Trying to convert {img.mode} to RGB in a best-effort manner."
672
+ )
673
+ return img.convert("RGB")
674
+
675
+ @staticmethod
676
+ def _wandb_run_id_file_name(trainer: BaseTrainer):
677
+ return trainer.workspace / "wandb_run_id"
678
+
679
+ @classmethod
680
+ def _save_wandb_run_id(cls, trainer: BaseTrainer, run_id: str):
681
+ cls._wandb_run_id_file_name(trainer).write_text(run_id)
682
+
683
+ @classmethod
684
+ def _load_wandb_run_id(cls, trainer: BaseTrainer):
685
+ f = cls._wandb_run_id_file_name(trainer)
686
+ if f.exists():
687
+ return f.read_text()
688
+ return None
689
+
690
+
691
+ class ImageFileLoggerHook(BaseHook):
692
+ """Persist logged images to ``workspace/visualizations`` on rank 0.
693
+
694
+ Args:
695
+ image_format: File extension or callable taking the leaf key.
696
+ """
697
+
698
+ def __init__(
699
+ self,
700
+ image_format: str | Callable[[str], str] = "png",
701
+ ):
702
+ if callable(image_format):
703
+ self.image_format = image_format
704
+ else:
705
+ self.image_format = lambda _: image_format
706
+
707
+ def on_log_images(self, trainer: BaseTrainer, records: dict, dry_run: bool = False):
708
+ if dist.get_rank() == 0:
709
+ for k, img in flatten_nested_dict(records).items():
710
+ p = trainer.workspace / "visualizations" / str(trainer.step) / Path(*k)
711
+ p = Path(str(p) + "." + self.image_format(k[-1]))
712
+ if not dry_run:
713
+ p.parent.mkdir(parents=True, exist_ok=True)
714
+ img.save(p)
715
+ else:
716
+ trainer.logger.debug(f"Dry run log. Would save {img} to: {p}")
File without changes
@@ -0,0 +1,406 @@
1
+ import logging
2
+ import time
3
+ import warnings
4
+ from contextlib import closing, nullcontext
5
+ from logging import Logger
6
+ from numbers import Number
7
+ from pathlib import Path
8
+ from typing import (
9
+ TYPE_CHECKING,
10
+ Any,
11
+ Callable,
12
+ Iterable,
13
+ Iterator,
14
+ TypeAlias,
15
+ Union,
16
+ )
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ from torch.distributed.checkpoint.state_dict import (
21
+ StateDictOptions,
22
+ get_state_dict,
23
+ set_state_dict,
24
+ )
25
+
26
+ if TYPE_CHECKING:
27
+ from .hooks import BaseHook
28
+
29
+ Records: TypeAlias = dict[str, Union[Number, "Records"]]
30
+
31
+
32
+ class BaseTrainer:
33
+ """
34
+ Minimal training loop that orchestrates builds, accumulation, retries, and hooks.
35
+
36
+ Subclasses provide component factories and a forward pass; the base class handles
37
+ sequencing, mixed precision, accumulation, state management, and hook dispatch.
38
+
39
+ Args:
40
+ max_steps: Number of training steps to run.
41
+ grad_clip: Max gradient norm; if set, gradients are clipped before stepping.
42
+ max_non_finite_grad_retries: Number of retries when encountering non-finite gradients (only applies while the grad scaler is disabled).
43
+ mixed_precision: ``"fp16"`` or ``"bf16"`` to enable autocast; ``None`` disables it.
44
+ gradient_accumulation_steps: Number of microsteps to accumulate before stepping.
45
+ workspace: Optional working directory used by hooks (e.g., checkpoints, logs).
46
+ device: Device for the model and tensors.
47
+ no_sync_accumulate: Whether to call ``no_sync`` on distributed modules during accumulation.
48
+ state_dict_options: Torch distributed checkpoint options.
49
+ logger: Logger instance; a default logger is created when omitted.
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ max_steps: int,
55
+ grad_clip: float | None = None,
56
+ max_non_finite_grad_retries: int | None = None,
57
+ mixed_precision: str | None = None,
58
+ gradient_accumulation_steps: int | None = None,
59
+ workspace: Path | str | None = None,
60
+ device: torch.device | str | int | None = None,
61
+ no_sync_accumulate: bool = True, # can make sense to disable this for FSDP
62
+ state_dict_options: StateDictOptions | None = None,
63
+ logger: Logger | None = None,
64
+ ):
65
+ self.step = 0 # refers to the last begun step. incremented *before* each step
66
+ self.max_steps = max_steps
67
+ self.grad_clip = grad_clip
68
+ self.max_non_finite_grad_retries = max_non_finite_grad_retries
69
+ match mixed_precision:
70
+ case "fp16":
71
+ self.mixed_precision = torch.float16
72
+ case "bf16":
73
+ self.mixed_precision = torch.bfloat16
74
+ case None:
75
+ self.mixed_precision = None
76
+ case _:
77
+ raise ValueError(f"Unsupported mixed precision: {mixed_precision}")
78
+ self.device = (
79
+ torch.device(device) if device is not None else torch.get_default_device()
80
+ )
81
+ self.gradient_accumulation_steps = gradient_accumulation_steps or 1
82
+ self.workspace = Path(workspace) if workspace is not None else None
83
+ self.logger = logger if logger is not None else logging.getLogger("trainer")
84
+ self.no_sync_accumulate = no_sync_accumulate
85
+ self.state_dict_options = state_dict_options
86
+
87
+ def _build(self):
88
+ self.logger.debug("_build()")
89
+ self.data_loader = self.build_data_loader()
90
+ self.model = self.build_model()
91
+ self.optimizer = self.build_optimizer()
92
+ self.lr_scheduler = self.build_lr_scheduler()
93
+ self.grad_scaler = self.build_grad_scaler()
94
+ self.hooks = self.build_hooks()
95
+
96
+ def build_data_loader(self) -> Iterable:
97
+ """Return the training data iterator."""
98
+ raise NotImplementedError
99
+
100
+ def build_model(self) -> nn.Module:
101
+ """Construct and return the model."""
102
+ raise NotImplementedError
103
+
104
+ def build_optimizer(self) -> torch.optim.Optimizer:
105
+ """Create the optimizer for the model."""
106
+ raise NotImplementedError
107
+
108
+ def build_lr_scheduler(self) -> torch.optim.lr_scheduler.LRScheduler | None:
109
+ """Optionally create a learning-rate scheduler."""
110
+ return None
111
+
112
+ def build_hooks(self) -> list["BaseHook"]:
113
+ """Return hooks to run during training."""
114
+ return []
115
+
116
+ def build_grad_scaler(self) -> torch.amp.GradScaler:
117
+ """Create the gradient scaler used for mixed precision."""
118
+ return torch.amp.GradScaler(
119
+ self.device.type, enabled=self.mixed_precision == torch.float16
120
+ )
121
+
122
+ def state_dict(self) -> dict[str, Any]:
123
+ self.logger.debug("state_dict()")
124
+ model_state_dict, optimizer_state_dict = get_state_dict(
125
+ self.model, self.optimizer, options=self.state_dict_options
126
+ )
127
+ state_dict = {
128
+ "model": model_state_dict,
129
+ "training_state": {
130
+ "step": self.step,
131
+ "optimizer": optimizer_state_dict,
132
+ "lr_scheduler": self.lr_scheduler.state_dict()
133
+ if self.lr_scheduler
134
+ else None,
135
+ "grad_scaler": self.grad_scaler.state_dict(),
136
+ },
137
+ }
138
+ for h in self.hooks:
139
+ h.on_state_dict(self, state_dict)
140
+ return state_dict
141
+
142
+ def load_state_dict(self, state_dict: dict[str, Any]):
143
+ self.logger.debug("load_state_dict()")
144
+ training_state = state_dict["training_state"]
145
+
146
+ self.step = training_state["step"]
147
+ self.logger.info(f"=> Resuming from step {self.step} ...")
148
+
149
+ if self.lr_scheduler is not None:
150
+ # NOTE: order is important. load the optimizer AFTER lr_scheduler. https://github.com/pytorch/pytorch/issues/119168
151
+ self.logger.info("=> Loading LR scheduler state ...")
152
+ self.lr_scheduler.load_state_dict(training_state["lr_scheduler"])
153
+ self.logger.info("=> Loading grad scaler state ...")
154
+ self.grad_scaler.load_state_dict(training_state["grad_scaler"])
155
+
156
+ self.logger.info("=> Loading model and optimizer state ...")
157
+ set_state_dict(
158
+ self.model,
159
+ self.optimizer,
160
+ model_state_dict=state_dict["model"],
161
+ optim_state_dict=training_state["optimizer"],
162
+ options=self.state_dict_options,
163
+ )
164
+
165
+ self.logger.info("=> Loading hook states ...")
166
+ for h in self.hooks:
167
+ h.on_load_state_dict(self, state_dict)
168
+
169
+ def train(self):
170
+ """Run the training loop until ``max_steps`` are completed."""
171
+ self._build()
172
+ self._before_train()
173
+
174
+ self.model.train()
175
+ self.optimizer.zero_grad() # just in case
176
+
177
+ # attempt to explicitly close the iterator since it likely owns resources such as worker processes
178
+ with maybe_closing(iter(self.data_loader)) as data_iter:
179
+ while self.step < self.max_steps:
180
+ self.step += 1
181
+ self.step_info = {}
182
+ self._before_step()
183
+
184
+ step_time = time.perf_counter()
185
+ self._run_step(data_iter)
186
+ self.step_info["step_time"] = time.perf_counter() - step_time
187
+
188
+ self._after_step()
189
+
190
+ self._after_train()
191
+
192
+ # essentially a standard optimizer step, plus the gradient-accumulation/no_sync context and the loss-is-None handling
193
+ def _run_step(self, data_iter: Iterator):
194
+ """
195
+ Run a single optimizer step of training.
196
+ Args:
197
+ data_iter (Iterator): Data iterator.
198
+ """
199
+
200
+ def reset_step_info():
201
+ self.step_info["loss"] = []
202
+ self.step_info["records"] = []
203
+
204
+ reset_step_info()
205
+ self.step_info["data_time"] = []
206
+ non_finite_grad_retry_count = 0
207
+ i_acc = 0
208
+ while i_acc < self.gradient_accumulation_steps:
209
+ is_accumulating = i_acc < self.gradient_accumulation_steps - 1
210
+ no_sync_accumulate = (
211
+ self.model.no_sync()
212
+ if self.no_sync_accumulate
213
+ and is_accumulating
214
+ and hasattr(self.model, "no_sync")
215
+ else nullcontext()
216
+ ) # for DDP and FSDP
217
+
218
+ data_time = time.perf_counter()
219
+ input = next(data_iter)
220
+ self.step_info["data_time"].append(time.perf_counter() - data_time)
221
+
222
+ with no_sync_accumulate:
223
+ with torch.autocast(
224
+ device_type=self.device.type,
225
+ dtype=self.mixed_precision,
226
+ enabled=bool(self.mixed_precision),
227
+ ):
228
+ self.logger.debug(f"{self.step}-{i_acc} forward()")
229
+ loss, records = self.forward(input)
230
+ if loss is None:
231
+ if isinstance(
232
+ self.model,
233
+ (
234
+ torch.nn.parallel.DistributedDataParallel,
235
+ torch.distributed.fsdp.FullyShardedDataParallel,
236
+ ),
237
+ ):
238
+ # TODO: find a better way to handle this
239
+ # It seems that each DDP forward call is expected to be followed by a backward pass, as DDP maintains internal state after the forward pass that anticipates a backward step.
240
+ # While it might work with `broadcast_buffers=False` or if the backward pass is collectively skipped across all ranks,
241
+ # this behavior is not officially documented as safe and could result in undefined behavior.
242
+ # Since `Trainer.forward` may also return None before calling `DDP.forward`, this is just a warning rather than an error.
243
+ # I think the same thing applies to FSDP, but I haven't confirmed it.
244
+ warnings.warn(
245
+ "Loss is None; skipping backward step. Ensure self.model.forward was not called in self.forward to avoid undefined behavior in DDP and FSDP.",
246
+ LossNoneWarning,
247
+ )
248
+ continue # skip the backward & optimizer step
249
+ if not torch.isfinite(
250
+ loss
251
+ ): # TODO: check if device sync slows down training
252
+ self.logger.warning(
253
+ f"Loss is non-finite ({loss.item()}). records={records}"
254
+ )
255
+ # we will handle non-finite later at the optimizer step, the warning is just for debugging
256
+ # keep in mind that at least for DDP, we must still call backward() to avoid undefined behavior!
257
+
258
+ self.step_info["loss"].append(loss.detach())
259
+ self.step_info["records"].append(records)
260
+ loss = loss / self.gradient_accumulation_steps
261
+ self.logger.debug(f"{self.step}-{i_acc} backward()")
262
+ self.grad_scaler.scale(loss).backward()
263
+ i_acc += 1 # only increment after an actual backward pass
264
+
265
+ if not is_accumulating:
266
+ if not self.grad_scaler.is_enabled():
267
+ # only skip non-finite grads if the scaler is disabled (the scaler needs to process non-finite grads to adjust the scale)
268
+ if any(
269
+ (not torch.isfinite(p.grad).all())
270
+ for p in self.model.parameters()
271
+ if p.grad is not None
272
+ ):
273
+ if self.max_non_finite_grad_retries is None or (
274
+ non_finite_grad_retry_count
275
+ < self.max_non_finite_grad_retries
276
+ ):
277
+ non_finite_grad_retry_count += 1
278
+ self.logger.warning(
279
+ f"Gradient is non-finite. Retrying step {self.step} (retry {non_finite_grad_retry_count}"
280
+ + (
281
+ f"/{self.max_non_finite_grad_retries})."
282
+ if self.max_non_finite_grad_retries is not None
283
+ else ")."
284
+ )
285
+ )
286
+ self.optimizer.zero_grad()
287
+ # TODO: check if we also need to "reset" (is that a thing?) the scaler here
288
+ reset_step_info()
289
+ i_acc = 0 # start accumulation again
290
+ continue
291
+ else:
292
+ raise RuntimeError(
293
+ "Gradient is non-finite. Exceeded maximum retries for non-finite gradients."
294
+ )
295
+ self.grad_scaler.unscale_(self.optimizer)
296
+ self._before_optimizer_step()
297
+ if self.grad_clip is not None:
298
+ self.step_info["grad_norm"] = torch.nn.utils.clip_grad_norm_(
299
+ self.model.parameters(), self.grad_clip
300
+ )
301
+ self.logger.debug(f"{self.step}-{i_acc - 1} step()")
302
+ self.grad_scaler.step(self.optimizer)
303
+ self.grad_scaler.update()
304
+ self.optimizer.zero_grad()
305
+ if self.lr_scheduler is not None:
306
+ self.lr_scheduler.step()
307
+
308
+ def forward(self, input: Any) -> tuple[torch.Tensor | None, Records]:
309
+ """
310
+ Perform a forward pass and return loss plus records for logging.
311
+
312
+ Args:
313
+ input: Batch yielded by the data loader.
314
+
315
+ Returns:
316
+ The loss (``None`` skips backward/step; if using DDP/FSDP, avoid invoking the wrapped module's ``forward`` in that case).
317
+ A nested dict of numeric metrics that will be averaged and emitted to hooks.
318
+ """
319
+ raise NotImplementedError
320
+
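To make the `None` contract concrete, here is a hypothetical subclass `forward` that skips degenerate batches; the important detail, per the warning emitted in `_run_step` above, is to bail out before `self.model` is called so DDP/FSDP never runs a forward pass without a matching backward.

```python
import torch

from trainloop import BaseTrainer


class FilteringTrainer(BaseTrainer):  # illustrative subclass; other builders omitted
    def forward(self, input):
        images, labels = input
        if labels.numel() == 0:
            # Skip the batch *before* calling self.model so DDP/FSDP never sees an
            # unmatched forward pass; records are ignored when the loss is None.
            return None, {}
        logits = self.model(images)
        loss = torch.nn.functional.cross_entropy(logits, labels)
        return loss, {"loss_ce": loss.item()}
```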
321
+ @classmethod
322
+ def unwrap(cls, module: nn.Module) -> nn.Module:
323
+ match module:
324
+ case torch._dynamo.eval_frame.OptimizedModule():
325
+ return cls.unwrap(module._orig_mod)
326
+ case (
327
+ torch.nn.parallel.DistributedDataParallel()
328
+ | torch.nn.parallel.DataParallel()
329
+ | torch.optim.swa_utils.AveragedModel()
330
+ ):
331
+ return cls.unwrap(module.module)
332
+ return module
333
+
334
+ @property
335
+ def unwrapped_model(self):
336
+ return self.unwrap(self.model)
337
+
338
+ def log(self, records: dict[str, Any], dry_run: bool = False):
339
+ """
340
+ Dispatch numeric records to hooks (e.g., trackers or stdout).
341
+
342
+ Args:
343
+ records: Nested dict of numeric metrics to log.
344
+ dry_run: If True, hooks should avoid side effects and only report intent.
345
+ """
346
+ self.logger.debug("log()")
347
+ for h in self.hooks:
348
+ h.on_log(self, records, dry_run=dry_run)
349
+
350
+ def log_images(self, records: dict[str, Any], dry_run: bool = False):
351
+ """
352
+ Dispatch image records to hooks.
353
+
354
+ Args:
355
+ records: Nested dict of images to log.
356
+ dry_run: If True, hooks should avoid side effects and only report intent.
357
+ """
358
+ self.logger.debug("log_images()")
359
+ for h in self.hooks:
360
+ h.on_log_images(self, records, dry_run=dry_run)
361
+
362
+ def _before_train(self):
363
+ self.logger.debug("_before_train()")
364
+ for h in self.hooks:
365
+ h.on_before_train(self)
366
+
367
+ def _after_train(self):
368
+ self.logger.debug("_after_train()")
369
+ for h in self.hooks:
370
+ h.on_after_train(self)
371
+
372
+ def _before_step(self):
373
+ self.logger.debug("_before_step()")
374
+ for h in self.hooks:
375
+ h.on_before_step(self)
376
+
377
+ def _after_step(self):
378
+ self.logger.debug("_after_step()")
379
+ for h in self.hooks:
380
+ h.on_after_step(self)
381
+
382
+ def _before_optimizer_step(self):
383
+ self.logger.debug("_before_optimizer_step()")
384
+ for h in self.hooks:
385
+ h.on_before_optimizer_step(self)
386
+
387
+
388
+ def maybe_closing(obj):
389
+ """Return a context manager that closes `obj` if it has a .close() method, otherwise does nothing."""
390
+ return closing(obj) if callable(getattr(obj, "close", None)) else nullcontext(obj)
391
+
392
+
393
+ def map_nested_tensor(f: Callable[[torch.Tensor], Any], obj: Any):
394
+ """Apply ``f`` to every tensor contained in a nested structure."""
395
+ if isinstance(obj, torch.Tensor):
396
+ return f(obj)
397
+ elif isinstance(obj, (list, tuple, set)):
398
+ return type(obj)(map_nested_tensor(f, o) for o in obj)
399
+ elif isinstance(obj, dict):
400
+ return type(obj)((k, map_nested_tensor(f, v)) for k, v in obj.items())
401
+ else:
402
+ return obj
403
+
404
+
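A typical use of `map_nested_tensor` (the batch layout below is hypothetical) is moving an arbitrarily nested batch onto the training device; non-tensor leaves pass through unchanged.

```python
import torch

from trainloop import map_nested_tensor

batch = {
    "images": torch.randn(4, 3, 32, 32),
    "meta": {"ids": [1, 2, 3, 4], "masks": [torch.ones(32, 32) for _ in range(4)]},
}
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = map_nested_tensor(lambda t: t.to(device, non_blocking=True), batch)
# the id list is left untouched; every tensor leaf is now on `device`
```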
405
+ class LossNoneWarning(UserWarning):
406
+ """Warning raised when ``forward`` returns ``None`` in distributed contexts."""
@@ -0,0 +1,67 @@
1
+ # modified from: https://github.com/microsoft/MoGe/blob/6b8b43db567ca4b08615c39b42cffd6c76cada29/moge/utils/tools.py
2
+
3
+ import math
4
+ from typing import Any, Generator, MutableMapping
5
+
6
+
7
+ def traverse_nested_dict_keys(
8
+ d: dict[str, dict],
9
+ ) -> Generator[tuple[str, ...], None, None]:
10
+ for k, v in d.items():
11
+ if isinstance(v, dict):
12
+ for sub_key in traverse_nested_dict_keys(v):
13
+ yield (k,) + sub_key
14
+ else:
15
+ yield (k,)
16
+
17
+
18
+ def get_nested_dict(d: dict[str, dict], keys: tuple[str, ...], default: Any = None):
19
+ for k in keys:
20
+ d = d.get(k, default)
21
+ if d is None:
22
+ break
23
+ return d
24
+
25
+
26
+ def set_nested_dict(d: dict[str, dict], keys: tuple[str, ...], value: Any):
27
+ for k in keys[:-1]:
28
+ d = d.setdefault(k, {})
29
+ d[keys[-1]] = value
30
+
31
+
32
+ def key_average(list_of_dicts: list, exclude_nan: bool = False) -> dict[str, Any]:
33
+ """
34
+ Returns a dictionary with the average value of each key in the input list of dictionaries.
35
+ """
36
+ _nested_dict_keys = set()
37
+ for d in list_of_dicts:
38
+ _nested_dict_keys.update(traverse_nested_dict_keys(d))
39
+ _nested_dict_keys = sorted(_nested_dict_keys)
40
+ result = {}
41
+ for k in _nested_dict_keys:
42
+ values = []
43
+ for d in list_of_dicts:
44
+ v = get_nested_dict(d, k)
45
+ if v is not None and (not exclude_nan or not math.isnan(v)):
46
+ values.append(v)
47
+ avg = sum(values) / len(values) if values else float("nan")
48
+ set_nested_dict(result, k, avg)
49
+ return result
50
+
51
+
52
+ def flatten_nested_dict(
53
+ d: dict[str, Any], parent_key: tuple[str, ...] | None = None
54
+ ) -> dict[tuple[str, ...], Any]:
55
+ """
56
+ Flattens a nested dictionary into a single-level dictionary, with keys as tuples.
57
+ """
58
+ items = []
59
+ if parent_key is None:
60
+ parent_key = ()
61
+ for k, v in d.items():
62
+ new_key = parent_key + (k,)
63
+ if isinstance(v, MutableMapping):
64
+ items.extend(flatten_nested_dict(v, new_key).items())
65
+ else:
66
+ items.append((new_key, v))
67
+ return dict(items)
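To illustrate how these helpers fit together (values are arbitrary): `key_average` averages leaves across a list of nested record dicts, and `flatten_nested_dict` converts the nesting into tuple keys, which the hooks join with `/` for display and for wandb panel names.

```python
from trainloop.utils import flatten_nested_dict, key_average

records = [
    {"loss": 1.0, "train": {"acc": 0.50}},
    {"loss": 0.5, "train": {"acc": 0.75}},
]
avg = key_average(records)
print(avg)                       # {'loss': 0.75, 'train': {'acc': 0.625}}
print(flatten_nested_dict(avg))  # {('loss',): 0.75, ('train', 'acc'): 0.625}
```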