aimnet-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aimnet/train/pt2jpt.py ADDED
@@ -0,0 +1,81 @@
+ import os
+ from typing import List, Optional
+
+ import click
+ import torch
+ from torch import nn
+
+ from aimnet.config import build_module, load_yaml
+
+
+ def set_eval(model: nn.Module) -> torch.nn.Module:
+     for p in model.parameters():
+         p.requires_grad_(False)
+     return model.eval()
+
+
+ def add_cutoff(
+     model: nn.Module, cutoff: Optional[float] = None, cutoff_lr: Optional[float] = float("inf")
+ ) -> nn.Module:
+     if cutoff is None:
+         cutoff = max(v.item() for k, v in model.state_dict().items() if k.endswith("aev.rc_s"))
+     model.cutoff = cutoff  # type: ignore[assignment]
+     if cutoff_lr is not None:
+         model.cutoff_lr = cutoff_lr  # type: ignore[assignment]
+     return model
+
+
+ def add_sae_to_shifts(model: nn.Module, sae_file: str) -> nn.Module:
+     sae = load_yaml(sae_file)
+     if not isinstance(sae, dict):
+         raise TypeError("SAE file must contain a dictionary.")
+     model.outputs.atomic_shift.double()
+     for k, v in sae.items():
+         model.outputs.atomic_shift.shifts.weight[k] += v
+     return model
+
+
+ def mask_not_implemented_species(model: nn.Module, species: List[int]) -> nn.Module:
+     weight = model.afv.weight
+     for i in range(1, weight.shape[0]):
+         if i not in species:
+             weight[i, :] = torch.nan
+     return model
+
+
+ _default_aimnet2_config = os.path.join(os.path.dirname(__file__), "..", "models", "aimnet2.yaml")
+
+
+ @click.command(short_help="Compile PyTorch model to TorchScript.")
+ @click.argument("pt", type=str)  # path to the input PyTorch weights file
+ @click.argument("jpt", type=str)  # path to the output TorchScript file
+ @click.option("--model", type=str, default=_default_aimnet2_config, help="Path to the model definition YAML file.")
+ @click.option("--sae", type=str, default=None, help="Path to the energy shift YAML file.")
+ @click.option("--species", type=str, default=None, help="Comma-separated list of parametrized atomic numbers.")
+ @click.option("--no-lr", is_flag=True, help="Do not add an LR cutoff to the model.")
+ def jitcompile(model: str, pt: str, jpt: str, sae=None, species=None, no_lr=False):  # type: ignore
+     """Build the model from a YAML config, load weights from the PT file and write a JIT-compiled JPT file,
+     with some modifications required by aimnet2calc.
+     """
+     model: nn.Module = build_module(model)  # type: ignore[annotation-unchecked]
+     model = set_eval(model)
+     cutoff_lr = None if no_lr else float("inf")
+     model = add_cutoff(model, cutoff_lr=cutoff_lr)
+     sd = torch.load(pt, map_location="cpu", weights_only=True)
+     print(model.load_state_dict(sd, strict=False))
+     if sae:
+         model = add_sae_to_shifts(model, sae)
+     numbers = None
+     if species:
+         numbers = list(map(int, species.split(",")))
+     elif sae:
+         numbers = list(load_yaml(sae).keys())  # type: ignore[union-attr]
+     if numbers:
+         model = mask_not_implemented_species(model, numbers)  # type: ignore[call-arg]
+         model.register_buffer("impemented_species", torch.tensor(numbers, dtype=torch.int64))
+     model_jit = torch.jit.script(model)
+     model_jit.save(jpt)
+
+
+ if __name__ == "__main__":
+     jitcompile()
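
For reference, the script above would typically be invoked as, e.g., `python -m aimnet.train.pt2jpt weights.pt model.jpt --sae sae.yaml --species 1,6,7,8` (the exact console entry point depends on package metadata not shown in this diff). The resulting `.jpt` file is an ordinary TorchScript module and can be reloaded with `torch.jit.load`. The sketch below is illustrative only: the input dictionary keys ("coord", "numbers", "charge") and the tensor-valued outputs are assumptions about the AIMNet2 calling convention rather than something this diff defines.

# Hypothetical usage of a model compiled by pt2jpt.py; for illustration only.
# Input keys "coord", "numbers", "charge" are assumed, not defined in this diff.
import torch

model = torch.jit.load("aimnet2.jpt", map_location="cpu")
batch = {
    "coord": torch.tensor([[[0.00, 0.00, 0.00], [0.96, 0.00, 0.00], [-0.24, 0.93, 0.00]]]),  # Angstrom
    "numbers": torch.tensor([[8, 1, 1]]),  # atomic numbers of a water molecule
    "charge": torch.tensor([0.0]),  # total molecular charge
}
out = model(batch)
print({k: tuple(v.shape) for k, v in out.items()})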
aimnet/train/train.py ADDED
@@ -0,0 +1,155 @@
+ import logging
+ import os
+
+ import click
+ import omegaconf
+ import torch
+ from omegaconf import OmegaConf
+
+ from aimnet.train import utils
+
+ _default_model = os.path.join(os.path.dirname(__file__), "..", "models", "aimnet2.yaml")
+ _default_config = os.path.join(os.path.dirname(__file__), "default_train.yaml")
+
+
+ @click.command()
+ @click.option(
+     "--config",
+     type=click.Path(exists=True),
+     default=None,
+     multiple=True,
+     help="Path to an extra configuration file (overrides values; can be specified multiple times).",
+ )
+ @click.option(
+     "--model", type=click.Path(exists=True), default=_default_model, help="Path to the model definition file."
+ )
+ @click.option("--load", type=click.Path(exists=True), default=None, help="Path to the model weights to load.")
+ @click.option("--save", type=click.Path(), default=None, help="Path to save the model weights.")
+ @click.option(
+     "--no-default-config",
+     is_flag=True,
+     default=False,
+ )
+ @click.argument("args", type=str, nargs=-1)
+ def train(config, model, load=None, save=None, args=None, no_default_config=False):
+     """Train an AIMNet2 model.
+     By default, loads the AIMNet2 model definition and the default train config.
+     ARGS are one or more parameters to overwrite in the config, in dot-separated form.
+     For example: `train.data=mydataset.h5`.
+     """
+     logging.basicConfig(level=logging.INFO)
+
+     # model config
+     logging.info("Start training")
+     logging.info(f"Using model definition: {model}")
+     model_cfg = OmegaConf.load(model)
+     logging.info("--- START model.yaml ---")
+     model_yaml = OmegaConf.to_yaml(model_cfg)
+     logging.info(model_yaml)
+     logging.info("--- END model.yaml ---")
+
+     # train config
+     if not no_default_config:
+         logging.info(f"Using default training configuration: {_default_config}")
+         train_cfg = OmegaConf.load(_default_config)
+     else:
+         train_cfg = OmegaConf.create()
+
+     for cfg in config:
+         logging.info(f"Using configuration: {cfg}")
+         train_cfg = OmegaConf.merge(train_cfg, OmegaConf.load(cfg))
+
+     if args:
+         logging.info("Overriding configuration:")
+         for arg in args:
+             logging.info(arg)
+         args_cfg = OmegaConf.from_dotlist(args)
+         train_cfg = OmegaConf.merge(train_cfg, args_cfg)
+     logging.info("--- START train.yaml ---")
+     train_yaml = OmegaConf.to_yaml(train_cfg)
+     logging.info(train_yaml)
+     logging.info("--- END train.yaml ---")
+
+     # try to build the model and print its configuration
+     logging.info("Building model")
+     model = utils.build_model(model_cfg)
+     logging.info(model)
+
+     # launch
+     num_gpus = torch.cuda.device_count()
+     logging.info(f"Start training using {num_gpus} GPU(s):")
+     for i in range(num_gpus):
+         logging.info(torch.cuda.get_device_name(i))
+     if num_gpus == 0:
+         logging.warning("No GPU available. Training will run on CPU. Use for testing only.")
+     if num_gpus > 1:
+         logging.info("Using DDP training.")
+         from ignite import distributed as idist
+
+         with idist.Parallel(backend="nccl", nproc_per_node=num_gpus) as parallel:  # type: ignore[attr-defined]
+             parallel.run(run, num_gpus, model_cfg, train_cfg, load, save)
+     else:
+         run(0, 1, model_cfg, train_cfg, load, save)
+
+
+ def run(local_rank, world_size, model_cfg, train_cfg, load, save):
+     if local_rank == 0:
+         logging.basicConfig(level=logging.INFO)
+     else:
+         logging.basicConfig(level=logging.ERROR)
+
+     # load configs
+     model_cfg = OmegaConf.create(model_cfg)
+     if not isinstance(model_cfg, omegaconf.DictConfig):
+         raise TypeError("Model configuration must be a dictionary.")
+     train_cfg = OmegaConf.create(train_cfg)
+     if not isinstance(train_cfg, omegaconf.DictConfig):
+         raise TypeError("Train configuration must be a dictionary.")
+
+     # build model
+     _force_training = "forces" in train_cfg.data.y
+     model = utils.build_model(model_cfg, forces=_force_training)
+     if world_size > 1:
+         from ignite import distributed as idist
+
+         model = idist.auto_model(model)  # type: ignore[attr-defined]
+
+     # load weights
+     if load is not None:
+         device = next(model.parameters()).device  # type: ignore[attr-defined]
+         logging.info(f"Loading weights from file {load}")
+         sd = torch.load(load, map_location=device)
+         logging.info(utils.unwrap_module(model).load_state_dict(sd, strict=False))
+
+     # data loaders
+     train_loader, val_loader = utils.get_loaders(train_cfg.data)
+
+     # optimizer, scheduler, etc.
+     model = utils.set_trainable_parameters(
+         model,  # type: ignore[attr-defined]
+         train_cfg.optimizer.force_train,
+         train_cfg.optimizer.force_no_train,
+     )
+     optimizer = utils.get_optimizer(model, train_cfg.optimizer)
+     if world_size > 1:
+         optimizer = idist.auto_optim(optimizer)  # type: ignore[attr-defined]
+     scheduler = utils.get_scheduler(optimizer, train_cfg.scheduler) if train_cfg.scheduler is not None else None  # type: ignore[attr-defined]
+     loss = utils.get_loss(train_cfg.loss)
+     metrics = utils.get_metrics(train_cfg.metrics)
+     metrics.attach_loss(loss)  # type: ignore[attr-defined]
+
+     # ignite engines
+     trainer, validator = utils.build_engine(model, optimizer, scheduler, loss, metrics, train_cfg, val_loader)
+
+     if local_rank == 0 and train_cfg.wandb is not None:
+         utils.setup_wandb(train_cfg, model_cfg, model, trainer, validator, optimizer)
+
+     trainer.run(train_loader, max_epochs=train_cfg.trainer.epochs)
+
+     if local_rank == 0 and save is not None:
+         logging.info(f"Saving model weights to file {save}")
+         torch.save(utils.unwrap_module(model).state_dict(), save)
+
+
+ if __name__ == "__main__":
+     train()
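
To illustrate how the dot-separated ARGS overrides compose with the YAML configs loaded by `train`, here is a minimal, self-contained OmegaConf sketch. The keys are made up for illustration; the real defaults live in default_train.yaml, which is not shown in this diff.

from omegaconf import OmegaConf

# Stand-in for the default training config; real values come from default_train.yaml.
base = OmegaConf.create({"data": {"train": "data.h5", "y": ["energy"]}, "trainer": {"epochs": 100}})

# Equivalent of passing `data.train=mydataset.h5 trainer.epochs=10` as ARGS to `train`.
overrides = OmegaConf.from_dotlist(["data.train=mydataset.h5", "trainer.epochs=10"])

merged = OmegaConf.merge(base, overrides)
print(OmegaConf.to_yaml(merged))
# data.train is now mydataset.h5 and trainer.epochs is 10; all other keys keep their defaults.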
aimnet/train/utils.py ADDED
@@ -0,0 +1,398 @@
+ import logging
+ import os
+ import re
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+
+ import numpy as np
+ import omegaconf
+ import torch
+ from ignite import distributed as idist
+ from ignite.engine import Engine, Events
+ from ignite.handlers import ModelCheckpoint, ProgressBar, TerminateOnNan, global_step_from_engine
+ from omegaconf import OmegaConf
+ from torch import Tensor, nn
+
+ from aimnet.config import build_module, get_init_module, get_module, load_yaml
+ from aimnet.data import SizeGroupedDataset
+ from aimnet.modules import Forces
+
+
+ def enable_tf32(enable=True):
+     if enable:
+         torch.backends.cuda.matmul.allow_tf32 = True
+         torch.backends.cudnn.allow_tf32 = True
+     else:
+         torch.backends.cuda.matmul.allow_tf32 = False
+         torch.backends.cudnn.allow_tf32 = False
+
+
+ def make_seed(all_reduce=True):
+     seed = int.from_bytes(os.urandom(2), "big")
+     if all_reduce and idist.get_world_size() > 1:
+         seed = idist.all_reduce(seed)
+     return seed
+
+
+ def load_dataset(cfg: omegaconf.DictConfig, kind="train"):
+     # only load the required subset of keys
+     keys = list(cfg.x) + list(cfg.y)
+     # in a DDP setting, each rank loads only 1/WORLD_SIZE of the data
+     if idist.get_world_size() > 1 and not cfg.ddp_load_full_dataset:
+         shard = (idist.get_local_rank(), idist.get_world_size())
+     else:
+         shard = None
+
+     extra_kwargs = {
+         "keys": keys,
+         "shard": shard,
+     }
+     cfg.datasets[kind].kwargs.update(extra_kwargs)
+     cfg.datasets[kind].args = [cfg[kind]]
+     ds = build_module(OmegaConf.to_container(cfg.datasets[kind]))  # type: ignore[arg-type]
+     ds = apply_sae(ds, cfg)  # type: ignore[arg-type]
+     return ds
+
+
+ def apply_sae(ds: SizeGroupedDataset, cfg: omegaconf.DictConfig):
+     for k, c in cfg.sae.items():
+         if c is not None and k in cfg.y:
+             sae = load_yaml(c.file)
+             unique_numbers = set(np.unique(ds.concatenate("numbers").tolist()))
+             if not unique_numbers.issubset(set(sae.keys())):  # type: ignore[attr-defined]
+                 raise ValueError(f"Keys in SAE file {c.file} do not cover all the dataset atoms")
+             if c.mode == "linreg":
+                 ds.apply_peratom_shift(k, k, sap_dict=sae)
+             elif c.mode == "logratio":
+                 ds.apply_pertype_logratio(k, k, sap_dict=sae)
+             else:
+                 raise ValueError(f"Unknown SAE mode {c.mode}")
+             for g in ds.groups:
+                 g[k] = g[k].astype("float32")
+     return ds
+
+
+ def get_sampler(ds: SizeGroupedDataset, cfg: omegaconf.DictConfig, kind="train"):
+     d = OmegaConf.to_container(cfg.samplers[kind])
+     if not isinstance(d, dict):
+         raise TypeError("Sampler configuration must be a dictionary.")
+     if "kwargs" not in d:
+         d["kwargs"] = {}
+     d["kwargs"]["ds"] = ds
+     sampler = build_module(d)
+     return sampler
+
+
+ def log_ds_group_sizes(ds):
+     logging.info("Group sizes")
+     for _n, g in ds.items():
+         logging.info(f"{_n:03d}: {len(g)}")
+
+
+ def get_loaders(cfg: omegaconf.DictConfig):
+     ds_train: SizeGroupedDataset
+     # load datasets
+     ds_train = load_dataset(cfg, kind="train")
+     logging.info(f"Loaded train dataset from {cfg.train} with {len(ds_train)} samples.")
+     log_ds_group_sizes(ds_train)
+     if cfg.val is not None:
+         ds_val = load_dataset(cfg, kind="val")
+         logging.info(f"Loaded validation dataset from {cfg.val} with {len(ds_val)} samples.")
+     else:
+         if cfg.separate_val:
+             ds_train, ds_val = ds_train.random_split(1 - cfg.val_fraction, cfg.val_fraction)
+             logging.info(
+                 f"Randomly split train dataset into train and val datasets, sizes {len(ds_train)} and {len(ds_val)} ({cfg.val_fraction * 100:.1f}%)."
+             )
+         else:
+             ds_val = ds_train.random_split(cfg.val_fraction)[0]
+             logging.info(
+                 f"Using a random fraction ({cfg.val_fraction * 100:.1f}%, {len(ds_val)} samples) of the train dataset for validation."
+             )
+
+     # merge small groups
+     ds_train.merge_groups(
+         min_size=8 * cfg.samplers.train.kwargs.batch_size, mode_atoms=cfg.samplers.train.kwargs.batch_mode == "atoms"
+     )
+     logging.info("After merging small groups in train dataset")
+     log_ds_group_sizes(ds_train)
+
+     loader_train = ds_train.get_loader(get_sampler(ds_train, cfg, kind="train"), cfg.x, cfg.y, **cfg.loaders.train)
+     loader_val = ds_val.get_loader(get_sampler(ds_val, cfg, kind="val"), cfg.x, cfg.y, **cfg.loaders.val)
+     return loader_train, loader_val
+
+
+ def get_optimizer(model: nn.Module, cfg: omegaconf.DictConfig):
+     logging.info("Building optimizer")
+     param_groups = {}
+     for k, c in cfg.param_groups.items():
+         c = OmegaConf.to_container(c)
+         if not isinstance(c, dict):
+             raise TypeError("Param groups must be a dictionary.")
+         c.pop("re")
+         param_groups[k] = {"params": [], **c}
+     param_groups["default"] = {"params": []}
+     logging.info(f"Default parameters: {cfg.kwargs}")
+     for n, p in model.named_parameters():
+         if not p.requires_grad:
+             continue
+         _matched = False
+         for k, c in cfg.param_groups.items():
+             if re.search(c.re, n):
+                 param_groups[k]["params"].append(p)
+                 logging.info(f"{n}: {c}")
+                 _matched = True
+                 break
+         if not _matched:
+             param_groups["default"]["params"].append(p)
+     d = OmegaConf.to_container(cfg)
+     if not isinstance(d, dict):
+         raise TypeError("Optimizer configuration must be a dictionary.")
+     d["args"] = [[v for v in param_groups.values() if len(v["params"])]]
+     optimizer = get_init_module(d["class"], d["args"], d["kwargs"])
+     logging.info(f"Optimizer: {optimizer}")
+     logging.info("Trainable parameters:")
+     N = 0
+     for n, p in model.named_parameters():
+         if p.requires_grad:
+             logging.info(f"{n}: {p.shape}")
+             N += p.numel()
+     logging.info(f"Total number of trainable parameters: {N}")
+     return optimizer
+
+
+ def get_scheduler(optimizer: torch.optim.Optimizer, cfg: omegaconf.DictConfig):
+     d = OmegaConf.to_container(cfg)
+     if not isinstance(d, dict):
+         raise TypeError("Scheduler configuration must be a dictionary.")
+     d["args"] = [optimizer]
+     scheduler = build_module(d)
+     return scheduler
+
+
+ def get_loss(cfg: omegaconf.DictConfig):
+     d = OmegaConf.to_container(cfg)
+     if not isinstance(d, dict):
+         raise TypeError("Loss configuration must be a dictionary.")
+     loss = build_module(d)
+     return loss
+
+
+ def set_trainable_parameters(model: nn.Module, force_train: List[str], force_no_train: List[str]) -> nn.Module:
+     for n, p in model.named_parameters():
+         if any(re.search(x, n) for x in force_no_train):
+             p.requires_grad_(False)
+             logging.info(f"requires_grad {n} {p.requires_grad}")
+         if any(re.search(x, n) for x in force_train):
+             p.requires_grad_(True)
+             logging.info(f"requires_grad {n} {p.requires_grad}")
+     return model
+
+
+ def unwrap_module(net):
+     if isinstance(net, (Forces, torch.nn.parallel.DistributedDataParallel)):
+         net = net.module
+         return unwrap_module(net)
+     else:
+         return net
+
+
+ def build_model(cfg, forces=False):
+     d = OmegaConf.to_container(cfg)
+     if not isinstance(d, dict):
+         raise TypeError("Model configuration must be a dictionary.")
+     model = build_module(d)
+     if forces:
+         model = Forces(model)  # type: ignore[attr-defined]
+     return model
+
+
+ def get_metrics(cfg: omegaconf.DictConfig):
+     d = OmegaConf.to_container(cfg)
+     if not isinstance(d, dict):
+         raise TypeError("Metrics configuration must be a dictionary.")
+     metrics = build_module(d)
+     return metrics
+
+
+ def train_step(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
+     global model
+     global optimizer
+     global prepare_batch
+     global loss_fn
+     global device
+
+     model.train()  # type: ignore
+     optimizer.zero_grad()  # type: ignore
+     x, y = prepare_batch(batch, device=device, non_blocking=True)  # type: ignore
+     y_pred = model(x)  # type: ignore
+     loss = loss_fn(y_pred, y)["loss"]  # type: ignore
+     loss.backward()
+     optimizer.step()  # type: ignore
+
+     return loss.item()
+
+
+ def val_step(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
+     global model
+     global optimizer
+     global prepare_batch
+     global loss_fn
+     global device
+
+     model.eval()  # type: ignore
+     if not next(iter(batch[0].values())).numel():
+         return None
+     x, y = prepare_batch(batch, device=device, non_blocking=True)  # type: ignore
+     with torch.no_grad():
+         y_pred = model(x)  # type: ignore
+     return y_pred, y
+
+
+ def prepare_batch(batch: Dict[str, Tensor], device="cuda", non_blocking=True) -> Dict[str, Tensor]:  # noqa: F811
+     for k, v in batch.items():
+         batch[k] = v.to(device, non_blocking=non_blocking)
+     return batch
+
+
+ def default_trainer(
+     model: torch.nn.Module,
+     optimizer: torch.optim.Optimizer,
+     loss_fn: Union[Callable, torch.nn.Module],
+     device: Optional[Union[str, torch.device]] = None,
+     non_blocking: bool = True,
+ ) -> Engine:
+     def _update(engine: Engine, batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]]) -> float:
+         model.train()
+         optimizer.zero_grad()
+         x = prepare_batch(batch[0], device=device, non_blocking=non_blocking)  # type: ignore
+         y = prepare_batch(batch[1], device=device, non_blocking=non_blocking)  # type: ignore
+         y_pred = model(x)
+         loss = loss_fn(y_pred, y)["loss"]
+         loss.backward()
+         torch.nn.utils.clip_grad_value_(model.parameters(), 0.4)
+         optimizer.step()
+         return loss.item()
+
+     return Engine(_update)
+
+
+ def default_evaluator(
+     model: torch.nn.Module, device: Optional[Union[str, torch.device]] = None, non_blocking: bool = True
+ ) -> Engine:
+     def _inference(
+         engine: Engine, batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]]
+     ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]:
+         model.eval()
+         x = prepare_batch(batch[0], device=device, non_blocking=non_blocking)  # type: ignore
+         y = prepare_batch(batch[1], device=device, non_blocking=non_blocking)  # type: ignore
+         with torch.no_grad():
+             y_pred = model(x)
+         return y_pred, y
+
+     return Engine(_inference)
+
+
+ class TerminateOnLowLR:
+     def __init__(self, optimizer, low_lr=1e-5):
+         self.low_lr = low_lr
+         self.optimizer = optimizer
+
+     def __call__(self, engine):
+         if self.optimizer.param_groups[0]["lr"] < self.low_lr:
+             engine.terminate()
+
+
+ def build_engine(model, optimizer, scheduler, loss_fn, metrics, cfg, loader_val):
+     device = next(model.parameters()).device
+
+     train_fn = get_module(cfg.trainer.trainer)
+     trainer = train_fn(model, optimizer, loss_fn, device=device, non_blocking=True)
+     # check for NaNs after each epoch
+     trainer.add_event_handler(Events.EPOCH_COMPLETED, TerminateOnNan())
+
+     # log LR
+     def log_lr(engine):
+         lr = optimizer.param_groups[0]["lr"]
+         logging.info(f"LR: {lr}")
+
+     trainer.add_event_handler(Events.EPOCH_STARTED, log_lr)
+     # write TQDM progress
+     if idist.get_local_rank() == 0:
+         pbar = ProgressBar()
+         pbar.attach(trainer, event_name=Events.ITERATION_COMPLETED(every=100))
+
+     # attach validator
+     validate_fn = get_module(cfg.trainer.evaluator)
+     validator = validate_fn(model, device=device, non_blocking=True)
+     metrics.attach(validator, "multi")
+     trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), validator.run, data=loader_val)
+
+     # scheduler
+     if scheduler is not None:
+         validator.add_event_handler(Events.COMPLETED, scheduler)
+         terminator = TerminateOnLowLR(optimizer, cfg.scheduler.terminate_on_low_lr)
+         trainer.add_event_handler(Events.EPOCH_STARTED, terminator)
+
+     # checkpoint after each epoch
+     if cfg.checkpoint and idist.get_local_rank() == 0:
+         kwargs = OmegaConf.to_container(cfg.checkpoint.kwargs) if "kwargs" in cfg.checkpoint else {}
+         if not isinstance(kwargs, dict):
+             raise TypeError("Checkpoint kwargs must be a dictionary.")
+         kwargs["global_step_transform"] = global_step_from_engine(trainer)
+         kwargs["dirname"] = cfg.checkpoint.dirname
+         kwargs["filename_prefix"] = cfg.checkpoint.filename_prefix
+         checkpointer = ModelCheckpoint(**kwargs)  # type: ignore
+         validator.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {"model": unwrap_module(model)})
+
+     return trainer, validator
+
+
+ def setup_wandb(cfg, model_cfg, model, trainer, validator, optimizer):
+     import wandb
+     from ignite.handlers import WandBLogger, global_step_from_engine
+     from ignite.handlers.wandb_logger import OptimizerParamsHandler
+
+     init_kwargs = OmegaConf.to_container(cfg.wandb.init, resolve=True)
+     wandb.init(**init_kwargs)  # type: ignore
+     wandb_logger = WandBLogger(init=False)
+
+     OmegaConf.save(model_cfg, wandb.run.dir + "/model.yaml")  # type: ignore
+     OmegaConf.save(cfg, wandb.run.dir + "/train.yaml")  # type: ignore
+
+     wandb_logger.attach_output_handler(
+         trainer,
+         event_name=Events.ITERATION_COMPLETED(every=200),
+         output_transform=lambda loss: {"loss": loss},
+         tag="train",
+     )
+     wandb_logger.attach_output_handler(
+         validator,
+         event_name=Events.EPOCH_COMPLETED,
+         global_step_transform=lambda *_: trainer.state.iteration,
+         metric_names="all",
+         tag="val",
+     )
+
+     class EpochLRLogger(OptimizerParamsHandler):
+         def __call__(self, engine, logger, event_name):
+             global_step = engine.state.iteration
+             params = {
+                 f"{self.param_name}_{i}": float(g[self.param_name]) for i, g in enumerate(self.optimizer.param_groups)
+             }
+             logger.log(params, step=global_step, sync=self.sync)
+
+     wandb_logger.attach(trainer, log_handler=EpochLRLogger(optimizer), event_name=Events.EPOCH_STARTED)
+
+     score_function = lambda engine: 1.0 / engine.state.metrics["loss"]
+     model_checkpoint = ModelCheckpoint(
+         wandb.run.dir,  # type: ignore
+         n_saved=1,
+         filename_prefix="best",  # type: ignore
+         require_empty=False,
+         score_function=score_function,
+         global_step_transform=global_step_from_engine(trainer),
+     )
+     validator.add_event_handler(Events.EPOCH_COMPLETED, model_checkpoint, {"model": unwrap_module(model)})
+
+     if cfg.wandb.watch_model:
+         wandb.watch(unwrap_module(model), **OmegaConf.to_container(cfg.wandb.watch_model, resolve=True))  # type: ignore
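
As a rough illustration of the regex-based parameter grouping performed by `get_optimizer` above, the sketch below assigns parameters of a made-up two-layer module to per-group optimizer settings. The group names, regex patterns, and learning rates are invented for the example and are not taken from this package's configs.

import re
import torch
from torch import nn

# Toy module; the layer names and pattern are invented for illustration.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 1))

# Pattern -> extra optimizer kwargs, mirroring cfg.optimizer.param_groups with its `re` key.
group_cfg = {"head": {"re": r"2\.", "lr": 1e-4}}

param_groups = {k: {"params": [], **{kk: vv for kk, vv in c.items() if kk != "re"}} for k, c in group_cfg.items()}
param_groups["default"] = {"params": []}
for name, p in model.named_parameters():
    for k, c in group_cfg.items():
        if re.search(c["re"], name):
            param_groups[k]["params"].append(p)
            break
    else:
        param_groups["default"]["params"].append(p)

# Only non-empty groups are handed to the optimizer, as in get_optimizer.
optimizer = torch.optim.AdamW([g for g in param_groups.values() if g["params"]], lr=1e-3)
print([len(g["params"]) for g in optimizer.param_groups])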
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024, Roman Zubatyuk
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.