aimnet-0.0.1-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aimnet/__init__.py +0 -0
- aimnet/base.py +41 -0
- aimnet/calculators/__init__.py +15 -0
- aimnet/calculators/aimnet2ase.py +98 -0
- aimnet/calculators/aimnet2pysis.py +76 -0
- aimnet/calculators/calculator.py +320 -0
- aimnet/calculators/model_registry.py +60 -0
- aimnet/calculators/model_registry.yaml +33 -0
- aimnet/calculators/nb_kernel_cpu.py +222 -0
- aimnet/calculators/nb_kernel_cuda.py +217 -0
- aimnet/calculators/nbmat.py +220 -0
- aimnet/cli.py +22 -0
- aimnet/config.py +170 -0
- aimnet/constants.py +467 -0
- aimnet/data/__init__.py +1 -0
- aimnet/data/sgdataset.py +517 -0
- aimnet/dftd3_data.pt +0 -0
- aimnet/models/__init__.py +2 -0
- aimnet/models/aimnet2.py +188 -0
- aimnet/models/aimnet2.yaml +44 -0
- aimnet/models/aimnet2_dftd3_wb97m.yaml +51 -0
- aimnet/models/base.py +51 -0
- aimnet/modules/__init__.py +3 -0
- aimnet/modules/aev.py +201 -0
- aimnet/modules/core.py +237 -0
- aimnet/modules/lr.py +243 -0
- aimnet/nbops.py +151 -0
- aimnet/ops.py +208 -0
- aimnet/train/__init__.py +0 -0
- aimnet/train/calc_sae.py +43 -0
- aimnet/train/default_train.yaml +166 -0
- aimnet/train/loss.py +83 -0
- aimnet/train/metrics.py +188 -0
- aimnet/train/pt2jpt.py +81 -0
- aimnet/train/train.py +155 -0
- aimnet/train/utils.py +398 -0
- aimnet-0.0.1.dist-info/LICENSE +21 -0
- aimnet-0.0.1.dist-info/METADATA +78 -0
- aimnet-0.0.1.dist-info/RECORD +41 -0
- aimnet-0.0.1.dist-info/WHEEL +4 -0
- aimnet-0.0.1.dist-info/entry_points.txt +5 -0
aimnet/train/pt2jpt.py
ADDED
@@ -0,0 +1,81 @@
+import os
+from typing import List, Optional
+
+import click
+import torch
+from torch import nn
+
+from aimnet.config import build_module, load_yaml
+
+
+def set_eval(model: nn.Module) -> torch.nn.Module:
+    for p in model.parameters():
+        p.requires_grad_(False)
+    return model.eval()
+
+
+def add_cutoff(
+    model: nn.Module, cutoff: Optional[float] = None, cutoff_lr: Optional[float] = float("inf")
+) -> nn.Module:
+    if cutoff is None:
+        cutoff = max(v.item() for k, v in model.state_dict().items() if k.endswith("aev.rc_s"))
+    model.cutoff = cutoff  # type: ignore[assignment]
+    if cutoff_lr is not None:
+        model.cutoff_lr = cutoff_lr  # type: ignore[assignment]
+    return model
+
+
+def add_sae_to_shifts(model: nn.Module, sae_file: str) -> nn.Module:
+    sae = load_yaml(sae_file)
+    if not isinstance(sae, dict):
+        raise TypeError("SAE file must contain a dictionary.")
+    model.outputs.atomic_shift.double()
+    for k, v in sae.items():
+        model.outputs.atomic_shift.shifts.weight[k] += v
+    return model
+
+
+def mask_not_implemented_species(model: nn.Module, species: List[int]) -> nn.Module:
+    weight = model.afv.weight
+    for i in range(1, weight.shape[0]):
+        if i not in species:
+            weight[i, :] = torch.nan
+    return model
+
+
+_default_aimnet2_config = os.path.join(os.path.dirname(__file__), "..", "models", "aimnet2.yaml")
+
+
+@click.command(short_help="Compile PyTorch model to TorchScript.")
+@click.argument("pt", type=str)  # , help='Path to the input PyTorch weights file.')
+@click.argument("jpt", type=str)  # , help='Path to the output TorchScript file.')
+@click.option("--model", type=str, default=_default_aimnet2_config, help="Path to model definition YAML file")
+@click.option("--sae", type=str, default=None, help="Path to the energy shift YAML file.")
+@click.option("--species", type=str, default=None, help="Comma-separated list of parametrized atomic numbers.")
+@click.option("--no-lr", is_flag=True, help="Do not add LR cutoff for model")
+def jitcompile(model: str, pt: str, jpt: str, sae=None, species=None, no_lr=False):  # type: ignore
+    """Build model from YAML config, load weight from PT file and write JIT-compiled JPT file.
+    Plus some modifications to work with aimnet2calc.
+    """
+    model: nn.Module = build_module(model)  # type: ignore[annotation-unchecked]
+    model = set_eval(model)
+    cutoff_lr = None if no_lr else float("inf")
+    model = add_cutoff(model, cutoff_lr=cutoff_lr)
+    sd = torch.load(pt, map_location="cpu", weights_only=True)
+    print(model.load_state_dict(sd, strict=False))
+    if sae:
+        model = add_sae_to_shifts(model, sae)
+    numbers = None
+    if species:
+        numbers = list(map(int, species.split(",")))
+    elif sae:
+        numbers = list(load_yaml(sae).keys())  # type: ignore[union-attr]
+    if numbers:
+        model = mask_not_implemented_species(model, numbers)  # type: ignore[call-arg]
+        model.register_buffer("impemented_species", torch.tensor(numbers, dtype=torch.int64))
+    model_jit = torch.jit.script(model)
+    model_jit.save(jpt)
+
+
+if __name__ == "__main__":
+    jitcompile()
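
A minimal sketch (not part of the package) of how the helpers above behave: DummyModel is a hypothetical stand-in that only mimics the `afv` embedding indexed by atomic number; in the real workflow the model comes from build_module() on the AIMNet2 YAML config and the weights are loaded from the PT checkpoint.

import torch
from torch import nn

from aimnet.train.pt2jpt import mask_not_implemented_species, set_eval


class DummyModel(nn.Module):
    """Hypothetical stand-in: only provides an `afv` embedding indexed by atomic number."""

    def __init__(self):
        super().__init__()
        self.afv = nn.Embedding(10, 4)  # rows 0..9 stand for atomic numbers up to F


model = set_eval(DummyModel())  # freeze all parameters and switch to eval mode
model = mask_not_implemented_species(model, [1, 6, 7, 8])  # keep H, C, N, O rows
print(torch.isnan(model.afv.weight).any(dim=1))  # True only for rows of masked species
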
aimnet/train/train.py
ADDED
@@ -0,0 +1,155 @@
+import logging
+import os
+
+import click
+import omegaconf
+import torch
+from omegaconf import OmegaConf
+
+from aimnet.train import utils
+
+_default_model = os.path.join(os.path.dirname(__file__), "..", "models", "aimnet2.yaml")
+_default_config = os.path.join(os.path.dirname(__file__), "default_train.yaml")
+
+
+@click.command()
+@click.option(
+    "--config",
+    type=click.Path(exists=True),
+    default=None,
+    multiple=True,
+    help="Path to the extra configuration file (overrides values, can be specified multiple times).",
+)
+@click.option(
+    "--model", type=click.Path(exists=True), default=_default_model, help="Path to the model definition file."
+)
+@click.option("--load", type=click.Path(exists=True), default=None, help="Path to the model weights to load.")
+@click.option("--save", type=click.Path(), default=None, help="Path to save the model weights.")
+@click.option(
+    "--no-default-config",
+    is_flag=True,
+    default=False,
+)
+@click.argument("args", type=str, nargs=-1)
+def train(config, model, load=None, save=None, args=None, no_default_config=False):
+    """Train AIMNet2 model.
+    By default, will load AIMNet2 model and default train config.
+    ARGS are one or more parameters to overwrite in the config, in dot-separated form.
+    For example: `train.data=mydataset.h5`.
+    """
+    logging.basicConfig(level=logging.INFO)
+
+    # model config
+    logging.info("Start training")
+    logging.info(f"Using model definition: {model}")
+    model_cfg = OmegaConf.load(model)
+    logging.info("--- START model.yaml ---")
+    model_yaml = OmegaConf.to_yaml(model_cfg)
+    logging.info(model_yaml)
+    logging.info("--- END model.yaml ---")
+
+    # train config
+    if not no_default_config:
+        logging.info(f"Using default training configuration: {_default_config}")
+        train_cfg = OmegaConf.load(_default_config)
+    else:
+        train_cfg = OmegaConf.create()
+
+    for cfg in config:
+        logging.info(f"Using configuration: {cfg}")
+        train_cfg = OmegaConf.merge(train_cfg, OmegaConf.load(cfg))
+
+    if args:
+        logging.info("Overriding configuration:")
+        for arg in args:
+            logging.info(arg)
+        args_cfg = OmegaConf.from_dotlist(args)
+        train_cfg = OmegaConf.merge(train_cfg, args_cfg)
+    logging.info("--- START train.yaml ---")
+    train_cfg = OmegaConf.to_yaml(train_cfg)
+    logging.info(train_cfg)
+    logging.info("--- END train.yaml ---")
+
+    # try to load the model and print its configuration
+    logging.info("Building model")
+    model = utils.build_model(model_cfg)
+    logging.info(model)
+
+    # launch
+    num_gpus = torch.cuda.device_count()
+    logging.info(f"Start training using {num_gpus} GPU(s):")
+    for i in range(num_gpus):
+        logging.info(torch.cuda.get_device_name(i))
+    if num_gpus == 0:
+        logging.warning("No GPU available. Training will run on CPU. Use for testing only.")
+    if num_gpus > 1:
+        logging.info("Using DDP training.")
+        from ignite import distributed as idist
+
+        with idist.Parallel(backend="nccl", nproc_per_node=num_gpus) as parallel:  # type: ignore[attr-defined]
+            parallel.run(run, num_gpus, model_cfg, train_cfg, load, save)
+    else:
+        run(0, 1, model_cfg, train_cfg, load, save)
+
+
+def run(local_rank, world_size, model_cfg, train_cfg, load, save):
+    if local_rank == 0:
+        logging.basicConfig(level=logging.INFO)
+    else:
+        logging.basicConfig(level=logging.ERROR)
+
+    # load configs
+    model_cfg = OmegaConf.create(model_cfg)
+    if not isinstance(model_cfg, omegaconf.DictConfig):
+        raise TypeError("Model configuration must be a dictionary.")
+    train_cfg = OmegaConf.create(train_cfg)
+    if not isinstance(train_cfg, omegaconf.DictConfig):
+        raise TypeError("Train configuration must be a dictionary.")
+
+    # build model
+    _force_training = "forces" in train_cfg.data.y
+    model = utils.build_model(model_cfg, forces=_force_training)
+    if world_size > 1:
+        from ignite import distributed as idist
+
+        model = idist.auto_model(model)  # type: ignore[attr-defined]
+
+    # load weights
+    if load is not None:
+        device = next(model.parameters()).device  # type: ignore[attr-defined]
+        logging.info(f"Loading weights from file {load}")
+        sd = torch.load(load, map_location=device)
+        logging.info(utils.unwrap_module(model).load_state_dict(sd, strict=False))
+
+    # data loaders
+    train_loader, val_loader = utils.get_loaders(train_cfg.data)
+
+    # optimizer, scheduler, etc
+    model = utils.set_trainable_parameters(
+        model,  # type: ignore[attr-defined]
+        train_cfg.optimizer.force_train,
+        train_cfg.optimizer.force_no_train,
+    )
+    optimizer = utils.get_optimizer(model, train_cfg.optimizer)
+    if world_size > 1:
+        optimizer = idist.auto_optim(optimizer)  # type: ignore[attr-defined]
+    scheduler = utils.get_scheduler(optimizer, train_cfg.scheduler) if train_cfg.scheduler is not None else None  # type: ignore[attr-defined]
+    loss = utils.get_loss(train_cfg.loss)
+    metrics = utils.get_metrics(train_cfg.metrics)
+    metrics.attach_loss(loss)  # type: ignore[attr-defined]
+
+    # ignite engine
+    trainer, validator = utils.build_engine(model, optimizer, scheduler, loss, metrics, train_cfg, val_loader)
+
+    if local_rank == 0 and train_cfg.wandb is not None:
+        utils.setup_wandb(train_cfg, model_cfg, model, trainer, validator, optimizer)
+
+    trainer.run(train_loader, max_epochs=train_cfg.trainer.epochs)
+
+    if local_rank == 0 and save is not None:
+        logging.info(f"Saving model weights to file {save}")
+        torch.save(utils.unwrap_module(model).state_dict(), save)
+
+
+if __name__ == "__main__":
+    train()
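
A minimal sketch (not part of the package) of the configuration layering that train() performs with OmegaConf: the bundled default_train.yaml, then any --config files, then dot-list ARGS overrides, with later layers winning. The keys and values below are hypothetical; only the OmegaConf calls mirror what train() does.

from omegaconf import OmegaConf

base = OmegaConf.create({"data": {"train": None}, "trainer": {"epochs": 500}})  # stands in for default_train.yaml
extra = OmegaConf.create({"trainer": {"epochs": 100}})  # stands in for a file passed via --config
overrides = OmegaConf.from_dotlist(["data.train=mydataset.h5"])  # stands in for positional ARGS

cfg = OmegaConf.merge(base, extra, overrides)
print(OmegaConf.to_yaml(cfg))  # data.train -> mydataset.h5, trainer.epochs -> 100
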
aimnet/train/utils.py
ADDED
@@ -0,0 +1,398 @@
+import logging
+import os
+import re
+from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
+
+import numpy as np
+import omegaconf
+import torch
+from ignite import distributed as idist
+from ignite.engine import Engine, Events
+from ignite.handlers import ModelCheckpoint, ProgressBar, TerminateOnNan, global_step_from_engine
+from omegaconf import OmegaConf
+from torch import Tensor, nn
+
+from aimnet.config import build_module, get_init_module, get_module, load_yaml
+from aimnet.data import SizeGroupedDataset
+from aimnet.modules import Forces
+
+
+def enable_tf32(enable=True):
+    if enable:
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+    else:
+        torch.backends.cuda.matmul.allow_tf32 = False
+        torch.backends.cudnn.allow_tf32 = False
+
+
+def make_seed(all_reduce=True):
+    # create seed
+    seed = int.from_bytes(os.urandom(2), "big")
+    if all_reduce and idist.get_world_size() > 1:
+        seed = idist.all_reduce(seed)
+
+
+def load_dataset(cfg: omegaconf.DictConfig, kind="train"):
+    # only load required subset of keys
+    keys = list(cfg.x) + list(cfg.y)
+    # in DDP setting, will only load 1/WORLD_SIZE of the data
+    if idist.get_world_size() > 1 and not cfg.ddp_load_full_dataset:
+        shard = (idist.get_local_rank(), idist.get_world_size())
+    else:
+        shard = None
+
+    extra_kwargs = {
+        "keys": keys,
+        "shard": shard,
+    }
+    cfg.datasets[kind].kwargs.update(extra_kwargs)
+    cfg.datasets[kind].args = [cfg[kind]]
+    ds = build_module(OmegaConf.to_container(cfg.datasets[kind]))  # type: ignore[arg-type]
+    ds = apply_sae(ds, cfg)  # type: ignore[arg-type]
+    return ds
+
+
+def apply_sae(ds: SizeGroupedDataset, cfg: omegaconf.DictConfig):
+    for k, c in cfg.sae.items():
+        if c is not None and k in cfg.y:
+            sae = load_yaml(c.file)
+            unique_numbers = set(np.unique(ds.concatenate("numbers").tolist()))
+            if not set(sae.keys()).issubset(unique_numbers):  # type: ignore[attr-defined]
+                raise ValueError(f"Keys in SAE file {c.file} do not cover all the dataset atoms")
+            if c.mode == "linreg":
+                ds.apply_peratom_shift(k, k, sap_dict=sae)
+            elif c.mode == "logratio":
+                ds.apply_pertype_logratio(k, k, sap_dict=sae)
+            else:
+                raise ValueError(f"Unknown SAE mode {c.mode}")
+            for g in ds.groups:
+                g[k] = g[k].astype("float32")
+    return ds
+
+
+def get_sampler(ds: SizeGroupedDataset, cfg: omegaconf.DictConfig, kind="train"):
+    d = OmegaConf.to_container(cfg.samplers[kind])
+    if not isinstance(d, dict):
+        raise TypeError("Sampler configuration must be a dictionary.")
+    if "kwargs" not in d:
+        d["kwargs"] = {}
+    d["kwargs"]["ds"] = ds
+    sampler = build_module(d)
+    return sampler
+
+
+def log_ds_group_sizes(ds):
+    logging.info("Group sizes")
+    for _n, g in ds.items():
+        logging.info(f"{_n:03d}: {len(g)}")
+
+
+def get_loaders(cfg: omegaconf.DictConfig):
+    ds_train: SizeGroupedDataset
+    # load datasets
+    ds_train = load_dataset(cfg, kind="train")
+    logging.info(f"Loaded train dataset from {cfg.train} with {len(ds_train)} samples.")
+    log_ds_group_sizes(ds_train)
+    if cfg.val is not None:
+        ds_val = load_dataset(cfg, kind="val")
+        logging.info(f"Loaded validation dataset from {cfg.val} with {len(ds_val)} samples.")
+    else:
+        if cfg.separate_val:
+            ds_train, ds_val = ds_train.random_split(1 - cfg.val_fraction, cfg.val_fraction)
+            logging.info(
+                f"Randomly split train dataset into train and val datasets, sizes {len(ds_train)} and {len(ds_val)} ({cfg.val_fraction * 100:.1f}%)."
+            )
+        else:
+            ds_val = ds_train.random_split(cfg.val_fraction)[0]
+            logging.info(
+                f"Using a random fraction ({cfg.val_fraction * 100:.1f}%, {len(ds_val)} samples) of train dataset for validation."
+            )
+
+    # merge small groups
+    ds_train.merge_groups(
+        min_size=8 * cfg.samplers.train.kwargs.batch_size, mode_atoms=cfg.samplers.train.kwargs.batch_mode == "atoms"
+    )
+    logging.info("After merging small groups in train dataset")
+    log_ds_group_sizes(ds_train)
+
+    loader_train = ds_train.get_loader(get_sampler(ds_train, cfg, kind="train"), cfg.x, cfg.y, **cfg.loaders.train)
+    loader_val = ds_val.get_loader(get_sampler(ds_val, cfg, kind="val"), cfg.x, cfg.y, **cfg.loaders.val)
+    return loader_train, loader_val
+
+
+def get_optimizer(model: nn.Module, cfg: omegaconf.DictConfig):
+    logging.info("Building optimizer")
+    param_groups = {}
+    for k, c in cfg.param_groups.items():
+        c = OmegaConf.to_container(c)
+        if not isinstance(c, dict):
+            raise TypeError("Param groups must be a dictionary.")
+        c.pop("re")
+        param_groups[k] = {"params": [], **c}
+    param_groups["default"] = {"params": []}
+    logging.info(f"Default parameters: {cfg.kwargs}")
+    for n, p in model.named_parameters():
+        if not p.requires_grad:
+            continue
+        _matched = False
+        for k, c in cfg.param_groups.items():
+            if re.search(c.re, n):
+                param_groups[k]["params"].append(p)
+                logging.info(f"{n}: {c}")
+                _matched = True
+                break
+        if not _matched:
+            param_groups["default"]["params"].append(p)
+    d = OmegaConf.to_container(cfg)
+    if not isinstance(d, dict):
+        raise TypeError("Optimizer configuration must be a dictionary.")
+    d["args"] = [[v for v in param_groups.values() if len(v["params"])]]
+    optimizer = get_init_module(d["class"], d["args"], d["kwargs"])
+    logging.info(f"Optimizer: {optimizer}")
+    logging.info("Trainable parameters:")
+    N = 0
+    for n, p in model.named_parameters():
+        if p.requires_grad:
+            logging.info(f"{n}: {p.shape}")
+            N += p.numel()
+    logging.info(f"Total number of trainable parameters: {N}")
+    return optimizer
+
+
+def get_scheduler(optimizer: torch.optim.Optimizer, cfg: omegaconf.DictConfig):
+    d = OmegaConf.to_container(cfg)
+    if not isinstance(d, dict):
+        raise TypeError("Scheduler configuration must be a dictionary.")
+    d["args"] = [optimizer]
+    scheduler = build_module(d)
+    return scheduler
+
+
+def get_loss(cfg: omegaconf.DictConfig):
+    d = OmegaConf.to_container(cfg)
+    if not isinstance(d, dict):
+        raise TypeError("Loss configuration must be a dictionary.")
+    loss = build_module(d)
+    return loss
+
+
+def set_trainable_parameters(model: nn.Module, force_train: List[str], force_no_train: List[str]) -> nn.Module:
+    for n, p in model.named_parameters():
+        if any(re.search(x, n) for x in force_no_train):
+            p.requires_grad_(False)
+            logging.info(f"requires_grad {n} {p.requires_grad}")
+        if any(re.search(x, n) for x in force_train):
+            p.requires_grad_(True)
+            logging.info(f"requires_grad {n} {p.requires_grad}")
+    return model
+
+
+def unwrap_module(net):
+    if isinstance(net, (Forces, torch.nn.parallel.DistributedDataParallel)):
+        net = net.module
+        return unwrap_module(net)
+    else:
+        return net
+
+
+def build_model(cfg, forces=False):
+    d = OmegaConf.to_container(cfg)
+    if not isinstance(d, dict):
+        raise TypeError("Model configuration must be a dictionary.")
+    model = build_module(d)
+    if forces is not None:
+        model = Forces(model)  # type: ignore[attr-defined]
+    return model
+
+
+def get_metrics(cfg: omegaconf.DictConfig):
+    d = OmegaConf.to_container(cfg)
+    if not isinstance(d, dict):
+        raise TypeError("Metrics configuration must be a dictionary.")
+    metrics = build_module(d)
+    return metrics
+
+
+def train_step(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
+    global model
+    global optimizer
+    global prepare_batch
+    global loss_fn
+    global device
+
+    model.train()  # type: ignore
+    optimizer.zero_grad()  # type: ignore
+    x, y = prepare_batch(batch, device=device, non_blocking=True)  # type: ignore
+    y_pred = model(x)  # type: ignore
+    loss = loss_fn(y_pred, y)["loss"]  # type: ignore
+    loss.backward()
+    optimizer.step()  # type: ignore
+
+    return loss.item()
+
+
+def val_step(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
+    global model
+    global optimizer
+    global prepare_batch
+    global loss_fn
+    global device
+
+    model.eval()  # type: ignore
+    if not next(iter(batch[0].values())).numel():
+        return None
+    x, y = prepare_batch(batch, device=device, non_blocking=True)  # type: ignore
+    with torch.no_grad():
+        y_pred = model(x)  # type: ignore
+    return y_pred, y
+
+
+def prepare_batch(batch: Dict[str, Tensor], device="cuda", non_blocking=True) -> Dict[str, Tensor]:  # noqa: F811
+    for k, v in batch.items():
+        batch[k] = v.to(device, non_blocking=non_blocking)
+    return batch
+
+
+def default_trainer(
+    model: torch.nn.Module,
+    optimizer: torch.optim.Optimizer,
+    loss_fn: Union[Callable, torch.nn.Module],
+    device: Optional[Union[str, torch.device]] = None,
+    non_blocking: bool = True,
+) -> Engine:
+    def _update(engine: Engine, batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]]) -> float:
+        model.train()
+        optimizer.zero_grad()
+        x = prepare_batch(batch[0], device=device, non_blocking=non_blocking)  # type: ignore
+        y = prepare_batch(batch[1], device=device, non_blocking=non_blocking)  # type: ignore
+        y_pred = model(x)
+        loss = loss_fn(y_pred, y)["loss"]
+        loss.backward()
+        torch.nn.utils.clip_grad_value_(model.parameters(), 0.4)
+        optimizer.step()
+        return loss.item()
+
+    return Engine(_update)
+
+
+def default_evaluator(
+    model: torch.nn.Module, device: Optional[Union[str, torch.device]] = None, non_blocking: bool = True
+) -> Engine:
+    def _inference(
+        engine: Engine, batch: Tuple[Dict[str, Tensor], Dict[str, Tensor]]
+    ) -> Tuple[Dict[str, Tensor], Dict[str, Tensor]]:
+        model.eval()
+        x = prepare_batch(batch[0], device=device, non_blocking=non_blocking)  # type: ignore
+        y = prepare_batch(batch[1], device=device, non_blocking=non_blocking)  # type: ignore
+        with torch.no_grad():
+            y_pred = model(x)
+        return y_pred, y
+
+    return Engine(_inference)
+
+
+class TerminateOnLowLR:
+    def __init__(self, optimizer, low_lr=1e-5):
+        self.low_lr = low_lr
+        self.optimizer = optimizer
+
+    def __call__(self, engine):
+        if self.optimizer.param_groups[0]["lr"] < self.low_lr:
+            engine.terminate()
+
+
+def build_engine(model, optimizer, scheduler, loss_fn, metrics, cfg, loader_val):
+    device = next(model.parameters()).device
+
+    train_fn = get_module(cfg.trainer.trainer)
+    trainer = train_fn(model, optimizer, loss_fn, device=device, non_blocking=True)
+    # check for NaNs after each epoch
+    trainer.add_event_handler(Events.EPOCH_COMPLETED, TerminateOnNan())
+
+    # log LR
+    def log_lr(engine):
+        lr = optimizer.param_groups[0]["lr"]
+        logging.info(f"LR: {lr}")
+
+    trainer.add_event_handler(Events.EPOCH_STARTED, log_lr)
+    # write TQDM progress
+    if idist.get_local_rank() == 0:
+        pbar = ProgressBar()
+        pbar.attach(trainer, event_name=Events.ITERATION_COMPLETED(every=100))
+
+    # attach validator
+    validate_fn = get_module(cfg.trainer.evaluator)
+    validator = validate_fn(model, device=device, non_blocking=True)
+    metrics.attach(validator, "multi")
+    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), validator.run, data=loader_val)
+
+    # scheduler
+    if scheduler is not None:
+        validator.add_event_handler(Events.COMPLETED, scheduler)
+        terminator = TerminateOnLowLR(optimizer, cfg.scheduler.terminate_on_low_lr)
+        trainer.add_event_handler(Events.EPOCH_STARTED, terminator)
+
+    # checkpoint after each epoch
+    if cfg.checkpoint and idist.get_local_rank() == 0:
+        kwargs = OmegaConf.to_container(cfg.checkpoint.kwargs) if "kwargs" not in cfg.checkpoint else {}
+        if not isinstance(kwargs, dict):
+            raise TypeError("Checkpoint kwargs must be a dictionary.")
+        kwargs["global_step_transform"] = global_step_from_engine(trainer)
+        kwargs["dirname"] = cfg.checkpoint.dirname
+        kwargs["filename_prefix"] = cfg.checkpoint.filename_prefix
+        checkpointer = ModelCheckpoint(**kwargs)  # type: ignore
+        validator.add_event_handler(Events.EPOCH_COMPLETED, checkpointer, {"model": unwrap_module(model)})
+
+    return trainer, validator
+
+
+def setup_wandb(cfg, model_cfg, model, trainer, validator, optimizer):
+    import wandb
+    from ignite.handlers import WandBLogger, global_step_from_engine
+    from ignite.handlers.wandb_logger import OptimizerParamsHandler
+
+    init_kwargs = OmegaConf.to_container(cfg.wandb.init, resolve=True)
+    wandb.init(**init_kwargs)  # type: ignore
+    wandb_logger = WandBLogger(init=False)
+
+    OmegaConf.save(model_cfg, wandb.run.dir + "/model.yaml")  # type: ignore
+    OmegaConf.save(cfg, wandb.run.dir + "/train.yaml")  # type: ignore
+
+    wandb_logger.attach_output_handler(
+        trainer,
+        event_name=Events.ITERATION_COMPLETED(every=200),
+        output_transform=lambda loss: {"loss": loss},
+        tag="train",
+    )
+    wandb_logger.attach_output_handler(
+        validator,
+        event_name=Events.EPOCH_COMPLETED,
+        global_step_transform=lambda *_: trainer.state.iteration,
+        metric_names="all",
+        tag="val",
+    )
+
+    class EpochLRLogger(OptimizerParamsHandler):
+        def __call__(self, engine, logger, event_name):
+            global_step = engine.state.iteration
+            params = {
+                f"{self.param_name}_{i}": float(g[self.param_name]) for i, g in enumerate(self.optimizer.param_groups)
+            }
+            logger.log(params, step=global_step, sync=self.sync)
+
+    wandb_logger.attach(trainer, log_handler=EpochLRLogger(optimizer), event_name=Events.EPOCH_STARTED)
+
+    score_function = lambda engine: 1.0 / engine.state.metrics["loss"]
+    model_checkpoint = ModelCheckpoint(
+        wandb.run.dir,  # type: ignore
+        n_saved=1,
+        filename_prefix="best",  # type: ignore
+        require_empty=False,
+        score_function=score_function,
+        global_step_transform=global_step_from_engine(trainer),
+    )
+    validator.add_event_handler(Events.EPOCH_COMPLETED, model_checkpoint, {"model": unwrap_module(model)})

    if cfg.wandb.watch_model:
        wandb.watch(unwrap_module(model), **OmegaConf.to_container(cfg.wandb.watch_model, resolve=True))  # type: ignore
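
A minimal sketch (not part of the package) of how set_trainable_parameters above applies the force_no_train and force_train regex lists to toggle requires_grad before get_optimizer builds its parameter groups. The toy model and regex patterns are hypothetical; only the function signature reflects the code above.

import logging

from torch import nn

from aimnet.train.utils import set_trainable_parameters

logging.basicConfig(level=logging.INFO)

# toy two-layer model; parameter names are "0.weight", "0.bias", "1.weight", "1.bias"
toy = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 1))
toy = set_trainable_parameters(toy, force_train=[r"1\.weight"], force_no_train=[r"^0\."])
for name, p in toy.named_parameters():
    print(name, p.requires_grad)
# "0.*" parameters are frozen, "1.weight" is forced trainable, "1.bias" is left unchanged
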
aimnet-0.0.1.dist-info/LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024, Roman Zubatyuk
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.