nshtrainer-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. nshtrainer/__init__.py +64 -0
  2. nshtrainer/_experimental/__init__.py +2 -0
  3. nshtrainer/_experimental/flops/__init__.py +48 -0
  4. nshtrainer/_experimental/flops/flop_counter.py +787 -0
  5. nshtrainer/_experimental/flops/module_tracker.py +140 -0
  6. nshtrainer/_snoop.py +216 -0
  7. nshtrainer/_submit/print_environment_info.py +31 -0
  8. nshtrainer/_submit/session/_output.py +12 -0
  9. nshtrainer/_submit/session/_script.py +109 -0
  10. nshtrainer/_submit/session/lsf.py +467 -0
  11. nshtrainer/_submit/session/slurm.py +573 -0
  12. nshtrainer/_submit/session/unified.py +350 -0
  13. nshtrainer/actsave/__init__.py +7 -0
  14. nshtrainer/actsave/_callback.py +75 -0
  15. nshtrainer/actsave/_loader.py +144 -0
  16. nshtrainer/actsave/_saver.py +337 -0
  17. nshtrainer/callbacks/__init__.py +35 -0
  18. nshtrainer/callbacks/_throughput_monitor_callback.py +549 -0
  19. nshtrainer/callbacks/base.py +113 -0
  20. nshtrainer/callbacks/early_stopping.py +112 -0
  21. nshtrainer/callbacks/ema.py +383 -0
  22. nshtrainer/callbacks/finite_checks.py +75 -0
  23. nshtrainer/callbacks/gradient_skipping.py +103 -0
  24. nshtrainer/callbacks/interval.py +322 -0
  25. nshtrainer/callbacks/latest_epoch_checkpoint.py +45 -0
  26. nshtrainer/callbacks/log_epoch.py +35 -0
  27. nshtrainer/callbacks/norm_logging.py +187 -0
  28. nshtrainer/callbacks/on_exception_checkpoint.py +44 -0
  29. nshtrainer/callbacks/print_table.py +90 -0
  30. nshtrainer/callbacks/throughput_monitor.py +56 -0
  31. nshtrainer/callbacks/timer.py +157 -0
  32. nshtrainer/callbacks/wandb_watch.py +103 -0
  33. nshtrainer/config.py +289 -0
  34. nshtrainer/data/__init__.py +4 -0
  35. nshtrainer/data/balanced_batch_sampler.py +132 -0
  36. nshtrainer/data/transform.py +67 -0
  37. nshtrainer/lr_scheduler/__init__.py +18 -0
  38. nshtrainer/lr_scheduler/_base.py +101 -0
  39. nshtrainer/lr_scheduler/linear_warmup_cosine.py +138 -0
  40. nshtrainer/lr_scheduler/reduce_lr_on_plateau.py +73 -0
  41. nshtrainer/model/__init__.py +44 -0
  42. nshtrainer/model/base.py +641 -0
  43. nshtrainer/model/config.py +2064 -0
  44. nshtrainer/model/modules/callback.py +157 -0
  45. nshtrainer/model/modules/debug.py +42 -0
  46. nshtrainer/model/modules/distributed.py +70 -0
  47. nshtrainer/model/modules/logger.py +170 -0
  48. nshtrainer/model/modules/profiler.py +24 -0
  49. nshtrainer/model/modules/rlp_sanity_checks.py +202 -0
  50. nshtrainer/model/modules/shared_parameters.py +72 -0
  51. nshtrainer/nn/__init__.py +19 -0
  52. nshtrainer/nn/mlp.py +106 -0
  53. nshtrainer/nn/module_dict.py +66 -0
  54. nshtrainer/nn/module_list.py +50 -0
  55. nshtrainer/nn/nonlinearity.py +157 -0
  56. nshtrainer/optimizer.py +62 -0
  57. nshtrainer/runner.py +21 -0
  58. nshtrainer/scripts/check_env.py +41 -0
  59. nshtrainer/scripts/find_packages.py +51 -0
  60. nshtrainer/trainer/__init__.py +1 -0
  61. nshtrainer/trainer/signal_connector.py +208 -0
  62. nshtrainer/trainer/trainer.py +340 -0
  63. nshtrainer/typecheck.py +144 -0
  64. nshtrainer/util/environment.py +119 -0
  65. nshtrainer/util/seed.py +11 -0
  66. nshtrainer/util/singleton.py +89 -0
  67. nshtrainer/util/slurm.py +49 -0
  68. nshtrainer/util/typed.py +2 -0
  69. nshtrainer/util/typing_utils.py +19 -0
  70. nshtrainer-0.1.0.dist-info/METADATA +18 -0
  71. nshtrainer-0.1.0.dist-info/RECORD +72 -0
  72. nshtrainer-0.1.0.dist-info/WHEEL +4 -0
nshtrainer/callbacks/interval.py
@@ -0,0 +1,322 @@
+ from collections.abc import Callable
+ from typing import Literal
+
+ from lightning.pytorch import LightningModule, Trainer
+ from lightning.pytorch.callbacks import Callback
+ from typing_extensions import override
+
+ Split = Literal["train", "val", "test", "predict"]
+
+
+ def _check_step(step: int, interval: int, skip_first: bool = False):
+     if step % interval != 0:
+         return False
+     if skip_first and step == 0:
+         return False
+     return True
+
+
+ class StepIntervalCallback(Callback):
+     def __init__(
+         self,
+         function: Callable[[Trainer, LightningModule], None],
+         *,
+         interval: int,
+         skip_first: bool = False,
+         splits: list[Split] = ["train", "val", "test", "predict"],
+     ):
+         super().__init__()
+
+         self.function = function
+         self.interval = interval
+         self.skip_first = skip_first
+         self.splits = set(splits)
+
+     @override
+     def on_train_batch_start(
+         self,
+         trainer,
+         pl_module,
+         batch,
+         batch_idx,
+         dataloader_idx=0,
+     ):
+         if (
+             not _check_step(
+                 trainer.global_step,
+                 self.interval,
+                 skip_first=self.skip_first,
+             )
+             or "train" not in self.splits
+         ):
+             return
+         self.function(trainer, pl_module)
+
+     @override
+     def on_validation_batch_start(
+         self,
+         trainer,
+         pl_module,
+         batch,
+         batch_idx,
+         dataloader_idx=0,
+     ):
+         if (
+             not _check_step(
+                 trainer.global_step,
+                 self.interval,
+                 skip_first=self.skip_first,
+             )
+             or "val" not in self.splits
+         ):
+             return
+         self.function(trainer, pl_module)
+
+     @override
+     def on_test_batch_start(
+         self,
+         trainer,
+         pl_module,
+         batch,
+         batch_idx,
+         dataloader_idx=0,
+     ):
+         if (
+             not _check_step(
+                 trainer.global_step,
+                 self.interval,
+                 skip_first=self.skip_first,
+             )
+             or "test" not in self.splits
+         ):
+             return
+         self.function(trainer, pl_module)
+
+     @override
+     def on_predict_batch_start(
+         self,
+         trainer,
+         pl_module,
+         batch,
+         batch_idx,
+         dataloader_idx=0,
+     ):
+         if (
+             not _check_step(
+                 trainer.global_step,
+                 self.interval,
+                 skip_first=self.skip_first,
+             )
+             or "predict" not in self.splits
+         ):
+             return
+         self.function(trainer, pl_module)
+
+
+ class EpochIntervalCallback(Callback):
+     def __init__(
+         self,
+         function: Callable[[Trainer, LightningModule], None],
+         *,
+         interval: int,
+         skip_first: bool = False,
+         splits: list[Split] = ["train", "val", "test", "predict"],
+     ):
+         super().__init__()
+
+         self.function = function
+         self.interval = interval
+         self.skip_first = skip_first
+         self.splits = set(splits)
+
+     @override
+     def on_train_epoch_start(self, trainer, pl_module):
+         if (
+             not _check_step(
+                 trainer.current_epoch,
+                 self.interval,
+                 skip_first=self.skip_first,
+             )
+             or "train" not in self.splits
+         ):
+             return
+         self.function(trainer, pl_module)
+
+     @override
+     def on_validation_epoch_start(self, trainer, pl_module):
+         if (
+             not _check_step(
+                 trainer.current_epoch,
+                 self.interval,
+                 skip_first=self.skip_first,
+             )
+             or "val" not in self.splits
+         ):
+             return
+         self.function(trainer, pl_module)
+
+     @override
+     def on_test_epoch_start(self, trainer, pl_module):
+         if (
+             not _check_step(
+                 trainer.current_epoch,
+                 self.interval,
+                 skip_first=self.skip_first,
+             )
+             or "test" not in self.splits
+         ):
+             return
+         self.function(trainer, pl_module)
+
+     @override
+     def on_predict_epoch_start(self, trainer, pl_module):
+         if (
+             not _check_step(
+                 trainer.current_epoch,
+                 self.interval,
+                 skip_first=self.skip_first,
+             )
+             or "predict" not in self.splits
+         ):
+             return
+         self.function(trainer, pl_module)
+
+
+ class IntervalCallback(Callback):
+     def __init__(
+         self,
+         function: Callable[[Trainer, LightningModule], None],
+         *,
+         step_interval: int | None = None,
+         epoch_interval: int | None = None,
+         skip_first: bool = False,
+         splits: list[Split] = ["train", "val", "test", "predict"],
+     ):
+         super().__init__()
+
+         self.callback = None
+
+         if step_interval is not None:
+             self.callback = StepIntervalCallback(
+                 function,
+                 interval=step_interval,
+                 splits=splits,
+                 skip_first=skip_first,
+             )
+         elif epoch_interval is not None:
+             self.callback = EpochIntervalCallback(
+                 function,
+                 interval=epoch_interval,
+                 splits=splits,
+                 skip_first=skip_first,
+             )
+         else:
+             raise ValueError("Either step_interval or epoch_interval must be specified")
+
+     @override
+     def on_train_batch_start(
+         self,
+         trainer,
+         pl_module,
+         batch,
+         batch_idx,
+         dataloader_idx=0,
+     ):
+         if not isinstance(self.callback, StepIntervalCallback):
+             return
+
+         self.callback.on_train_batch_start(
+             trainer,
+             pl_module,
+             batch,
+             batch_idx,
+             dataloader_idx=dataloader_idx,
+         )
+
+     @override
+     def on_validation_batch_start(
+         self,
+         trainer,
+         pl_module,
+         batch,
+         batch_idx,
+         dataloader_idx=0,
+     ):
+         if not isinstance(self.callback, StepIntervalCallback):
+             return
+
+         self.callback.on_validation_batch_start(
+             trainer,
+             pl_module,
+             batch,
+             batch_idx,
+             dataloader_idx=dataloader_idx,
+         )
+
+     @override
+     def on_test_batch_start(
+         self,
+         trainer,
+         pl_module,
+         batch,
+         batch_idx,
+         dataloader_idx=0,
+     ):
+         if not isinstance(self.callback, StepIntervalCallback):
+             return
+
+         self.callback.on_test_batch_start(
+             trainer,
+             pl_module,
+             batch,
+             batch_idx,
+             dataloader_idx=dataloader_idx,
+         )
+
+     @override
+     def on_predict_batch_start(
+         self,
+         trainer,
+         pl_module,
+         batch,
+         batch_idx,
+         dataloader_idx=0,
+     ):
+         if not isinstance(self.callback, StepIntervalCallback):
+             return
+
+         self.callback.on_predict_batch_start(
+             trainer,
+             pl_module,
+             batch,
+             batch_idx,
+             dataloader_idx=dataloader_idx,
+         )
+
+     @override
+     def on_train_epoch_start(self, trainer, pl_module):
+         if not isinstance(self.callback, EpochIntervalCallback):
+             return
+
+         self.callback.on_train_epoch_start(trainer, pl_module)
+
+     @override
+     def on_validation_epoch_start(self, trainer, pl_module):
+         if not isinstance(self.callback, EpochIntervalCallback):
+             return
+
+         self.callback.on_validation_epoch_start(trainer, pl_module)
+
+     @override
+     def on_test_epoch_start(self, trainer, pl_module):
+         if not isinstance(self.callback, EpochIntervalCallback):
+             return
+
+         self.callback.on_test_epoch_start(trainer, pl_module)
+
+     @override
+     def on_predict_epoch_start(self, trainer, pl_module):
+         if not isinstance(self.callback, EpochIntervalCallback):
+             return
+
+         self.callback.on_predict_epoch_start(trainer, pl_module)
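
For reference, a minimal sketch of wiring `IntervalCallback` from the hunk above into a Lightning `Trainer`; the import path (taken from the file listing) and the hook body are illustrative assumptions, not part of the packaged code:

from lightning.pytorch import Trainer

from nshtrainer.callbacks.interval import IntervalCallback  # path assumed from the file listing


def print_lr(trainer, pl_module):
    # Hypothetical periodic hook: print the first optimizer's learning rate.
    lr = trainer.optimizers[0].param_groups[0]["lr"]
    print(f"step={trainer.global_step} lr={lr}")


# Call `print_lr` every 100 global steps during training only, skipping step 0.
callback = IntervalCallback(
    print_lr,
    step_interval=100,
    skip_first=True,
    splits=["train"],
)
trainer = Trainer(callbacks=[callback])

Passing `epoch_interval` instead selects the epoch-based delegate; passing neither raises a `ValueError`.
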
nshtrainer/callbacks/latest_epoch_checkpoint.py
@@ -0,0 +1,45 @@
+ import logging
+ from pathlib import Path
+
+ from lightning.fabric.utilities.types import _PATH
+ from lightning.pytorch import LightningModule, Trainer
+ from lightning.pytorch.callbacks import Checkpoint
+ from typing_extensions import override
+
+ log = logging.getLogger(__name__)
+
+
+ class LatestEpochCheckpoint(Checkpoint):
+     DEFAULT_FILENAME = "latest_epoch{epoch:02d}_step{step:04d}.ckpt"
+
+     def __init__(
+         self,
+         dirpath: _PATH,
+         filename: str | None = None,
+         save_weights_only: bool = False,
+     ):
+         super().__init__()
+
+         self._dirpath = Path(dirpath)
+         self._filename = filename or self.DEFAULT_FILENAME
+         self._save_weights_only = save_weights_only
+
+         # Also, we hold a reference to the last checkpoint path
+         # to be able to remove it when a new checkpoint is saved.
+         self._last_ckpt_path: Path | None = None
+
+     def _ckpt_path(self, trainer: Trainer):
+         return self._dirpath / self._filename.format(
+             epoch=trainer.current_epoch, step=trainer.global_step
+         )
+
+     @override
+     def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
+         # Remove the last checkpoint if it exists
+         if self._last_ckpt_path is not None:
+             trainer.strategy.remove_checkpoint(self._last_ckpt_path)
+
+         # Save the new checkpoint
+         filepath = self._ckpt_path(trainer)
+         trainer.save_checkpoint(filepath, self._save_weights_only)
+         self._last_ckpt_path = filepath
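
A minimal usage sketch for `LatestEpochCheckpoint` (import path assumed from the file listing); at the end of each training epoch it removes the previous latest checkpoint and writes a new one named from the default filename template:

from lightning.pytorch import Trainer

from nshtrainer.callbacks.latest_epoch_checkpoint import LatestEpochCheckpoint  # path assumed

# Keeps a single rolling checkpoint, e.g. checkpoints/latest_epoch03_step1200.ckpt.
callback = LatestEpochCheckpoint(dirpath="checkpoints", save_weights_only=False)
trainer = Trainer(callbacks=[callback])
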
nshtrainer/callbacks/log_epoch.py
@@ -0,0 +1,35 @@
+ import logging
+ import math
+ from typing import Any
+
+ from lightning.pytorch import LightningModule, Trainer
+ from lightning.pytorch.callbacks import Callback
+ from typing_extensions import override
+
+ log = logging.getLogger(__name__)
+
+
+ class LogEpochCallback(Callback):
+     def __init__(self, metric_name: str = "computed_epoch"):
+         super().__init__()
+
+         self.metric_name = metric_name
+
+     @override
+     def on_train_batch_start(
+         self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int
+     ):
+         if trainer.logger is None:
+             return
+
+         # If trainer.num_training_batches is not set or is nan/inf, we cannot calculate the epoch
+         if (
+             not trainer.num_training_batches
+             or math.isnan(trainer.num_training_batches)
+             or math.isinf(trainer.num_training_batches)
+         ):
+             log.warning("Trainer has no valid num_training_batches. Cannot log epoch.")
+             return
+
+         epoch = pl_module.global_step / trainer.num_training_batches
+         pl_module.log(self.metric_name, epoch, on_step=True, on_epoch=False)
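
The fractional epoch logged above is simply `global_step / num_training_batches`; for example, at global step 150 with 100 training batches per epoch the metric is 1.5. A minimal sketch of attaching the callback (import path assumed from the file listing):

from lightning.pytorch import Trainer

from nshtrainer.callbacks.log_epoch import LogEpochCallback  # path assumed

# Logs "computed_epoch" on every training step whenever a logger is configured.
trainer = Trainer(callbacks=[LogEpochCallback()])
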
nshtrainer/callbacks/norm_logging.py
@@ -0,0 +1,187 @@
+ from logging import getLogger
+ from typing import Literal, cast
+
+ import torch
+ import torch.nn as nn
+ from lightning.pytorch import Callback, LightningModule, Trainer
+ from torch.optim import Optimizer
+ from typing_extensions import override
+
+ from .base import CallbackConfigBase
+
+ log = getLogger(__name__)
+
+
+ def grad_norm(
+     module: nn.Module,
+     norm_type: float | int | str,
+     group_separator: str = "/",
+     grad: bool = True,
+ ) -> dict[str, torch.Tensor | float]:
+     """Compute each parameter's gradient's norm and their overall norm.
+
+     The overall norm is computed over all gradients together, as if they
+     were concatenated into a single vector.
+
+     Args:
+         module: :class:`torch.nn.Module` to inspect.
+         norm_type: The type of the used p-norm, cast to float if necessary.
+             Can be ``'inf'`` for infinity norm.
+         group_separator: The separator string used by the logger to group
+             the gradients norms in their own subfolder instead of the logs one.
+
+     Return:
+         norms: The dictionary of p-norms of each parameter's gradient and
+             a special entry for the total p-norm of the gradients viewed
+             as a single vector.
+     """
+     norm_type = float(norm_type)
+     if norm_type <= 0:
+         raise ValueError(
+             f"`norm_type` must be a positive number or 'inf' (infinity norm). Got {norm_type}"
+         )
+
+     if grad:
+         norms = {
+             f"grad_{norm_type}_norm{group_separator}{name}": p.grad.data.norm(norm_type)
+             for name, p in module.named_parameters()
+             if p.grad is not None
+         }
+         if norms:
+             total_norm = torch.tensor(list(norms.values())).norm(norm_type)
+             norms[f"grad_{norm_type}_norm_total"] = total_norm
+     else:
+         norms = {
+             f"param_{norm_type}_norm{group_separator}{name}": p.data.norm(norm_type)
+             for name, p in module.named_parameters()
+             if p.grad is not None
+         }
+         if norms:
+             total_norm = torch.tensor(list(norms.values())).norm(norm_type)
+             norms[f"param_{norm_type}_norm_total"] = total_norm
+
+     return norms
+
+
+ def _to_norm_type(log_grad_norm_per_param: float | str | Literal[True]):
+     norm_type = 2.0
+     if log_grad_norm_per_param is not True:
+         norm_type = log_grad_norm_per_param
+     return norm_type
+
+
+ def compute_norm(
+     pl_module: LightningModule,
+     optimizer: Optimizer | None = None,
+     p: float | str = 2.0,
+     *,
+     grad: bool,
+ ) -> torch.Tensor:
+     if optimizer is not None:
+         tensors = [
+             cast(torch.Tensor, p.grad if grad else p)
+             for group in optimizer.param_groups
+             for p in group["params"]
+             if p.grad is not None
+         ]
+     else:
+         tensors = [
+             p.grad if grad else p for p in pl_module.parameters() if p.grad is not None
+         ]
+
+     if not tensors:
+         return torch.tensor(0.0, device=pl_module.device)
+
+     return torch.norm(torch.stack([torch.norm(g, p=p) for g in tensors]), p=p)
+
+
+ class NormLoggingCallback(Callback):
+     def __init__(self, config: "NormLoggingConfig"):
+         super().__init__()
+
+         self.config = config
+
+     def _perform_norm_logging(
+         self,
+         pl_module: LightningModule,
+         optimizer: Optimizer,
+         prefix: str,
+     ):
+         # Gradient norm logging
+         if log_grad_norm := self.config.log_grad_norm:
+             norm = compute_norm(
+                 pl_module,
+                 optimizer,
+                 _to_norm_type(log_grad_norm),
+                 grad=True,
+             )
+             pl_module.log(f"{prefix}grad_norm", norm, on_step=True, on_epoch=False)
+         if log_grad_norm_per_param := self.config.log_grad_norm_per_param:
+             norm_type = _to_norm_type(log_grad_norm_per_param)
+             pl_module.log_dict(
+                 {
+                     f"{prefix}{k}": v
+                     for k, v in grad_norm(pl_module, norm_type, grad=True).items()
+                 }
+             )
+
+         # Parameter norm logging
+         if log_param_norm := self.config.log_param_norm:
+             norm = compute_norm(
+                 pl_module,
+                 optimizer,
+                 _to_norm_type(log_param_norm),
+                 grad=False,
+             )
+             pl_module.log(f"{prefix}param_norm", norm, on_step=True, on_epoch=False)
+         if log_param_norm_per_param := self.config.log_param_norm_per_param:
+             norm_type = _to_norm_type(log_param_norm_per_param)
+             pl_module.log_dict(
+                 {
+                     f"{prefix}{k}": v
+                     for k, v in grad_norm(pl_module, norm_type, grad=False).items()
+                 }
+             )
+
+     @override
+     def on_after_backward(self, trainer: Trainer, pl_module: LightningModule):
+         if len(trainer.optimizers) == 1:
+             optimizer = trainer.optimizers[0]
+             self._perform_norm_logging(pl_module, optimizer, prefix="train/")
+         else:
+             for i, optimizer in enumerate(trainer.optimizers):
+                 self._perform_norm_logging(
+                     pl_module, optimizer, prefix=f"train/optimizer_{i}/"
+                 )
+
+
+ class NormLoggingConfig(CallbackConfigBase):
+     name: Literal["norm_logging"] = "norm_logging"
+
+     log_grad_norm: bool | str | float = False
+     """If enabled, will log the gradient norm (averaged across all model parameters) to the logger."""
+     log_grad_norm_per_param: bool | str | float = False
+     """If enabled, will log the gradient norm for each model parameter to the logger."""
+
+     log_param_norm: bool | str | float = False
+     """If enabled, will log the parameter norm (averaged across all model parameters) to the logger."""
+     log_param_norm_per_param: bool | str | float = False
+     """If enabled, will log the parameter norm for each model parameter to the logger."""
+
+     def __bool__(self):
+         return any(
+             v
+             for v in (
+                 self.log_grad_norm,
+                 self.log_grad_norm_per_param,
+                 self.log_param_norm,
+                 self.log_param_norm_per_param,
+             )
+         )
+
+     @override
+     def construct_callbacks(self, root_config):
+         if not self:
+             return
+
+         yield NormLoggingCallback(self)
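
`NormLoggingConfig` evaluates as falsy until at least one flag is enabled, and `construct_callbacks` yields a `NormLoggingCallback` only in that case. A minimal sketch, assuming the config accepts keyword construction like a standard config object and that the import path matches the file listing:

from nshtrainer.callbacks.norm_logging import NormLoggingConfig  # path assumed

# Log the total 2-norm of gradients and parameters on every backward pass.
config = NormLoggingConfig(log_grad_norm=True, log_param_norm=True)
assert bool(config)  # at least one flag is set

# `construct_callbacks` is a generator; `root_config` is unused by this method,
# so a placeholder is passed purely for illustration.
callbacks = list(config.construct_callbacks(root_config=None))
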
nshtrainer/callbacks/on_exception_checkpoint.py
@@ -0,0 +1,44 @@
+ import datetime
+ import logging
+ import os
+ from typing import Any
+
+ from lightning.pytorch import Trainer
+ from lightning.pytorch.callbacks import OnExceptionCheckpoint as _OnExceptionCheckpoint
+ from typing_extensions import override
+
+ log = logging.getLogger(__name__)
+
+
+ class OnExceptionCheckpoint(_OnExceptionCheckpoint):
+     @property
+     @override
+     def ckpt_path(self) -> str:
+         ckpt_path = super().ckpt_path
+
+         # Remove the extension and add the current timestamp
+         timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+         ckpt_path, ext = os.path.splitext(ckpt_path)
+         return f"{ckpt_path}_{timestamp}{ext}"
+
+     @override
+     def on_exception(self, trainer: Trainer, *_: Any, **__: Any) -> None:
+         # We override this to checkpoint the model manually,
+         # without calling the dist barrier.
+
+         # trainer.save_checkpoint(self.ckpt_path)
+
+         if trainer.model is None:
+             raise AttributeError(
+                 "Saving a checkpoint is only possible if a model is attached to the Trainer. Did you call"
+                 " `Trainer.save_checkpoint()` before calling `Trainer.{fit,validate,test,predict}`?"
+             )
+         checkpoint = trainer._checkpoint_connector.dump_checkpoint(weights_only=False)
+         trainer.strategy.save_checkpoint(
+             checkpoint, self.ckpt_path, storage_options=None
+         )
+         # self.strategy.barrier("Trainer.save_checkpoint")  # <-- This is disabled
+
+     @override
+     def teardown(self, trainer: Trainer, *_: Any, **__: Any) -> None:
+         trainer.strategy.remove_checkpoint(self.ckpt_path)
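
A minimal sketch of attaching the patched `OnExceptionCheckpoint` (import path assumed from the file listing); the base Lightning class supplies `dirpath` and `filename`, and this subclass appends a timestamp to the path and saves without the distributed barrier when an exception escapes a `Trainer` run:

from lightning.pytorch import Trainer

from nshtrainer.callbacks.on_exception_checkpoint import OnExceptionCheckpoint  # path assumed

# On an uncaught exception, writes e.g. checkpoints/on_exception_2024-01-01_12-00-00.ckpt.
callback = OnExceptionCheckpoint(dirpath="checkpoints", filename="on_exception")
trainer = Trainer(callbacks=[callback])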