aimnet 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- aimnet/__init__.py +7 -0
- aimnet/base.py +24 -8
- aimnet/calculators/__init__.py +4 -4
- aimnet/calculators/aimnet2ase.py +19 -6
- aimnet/calculators/calculator.py +868 -108
- aimnet/calculators/model_registry.py +2 -5
- aimnet/calculators/model_registry.yaml +55 -17
- aimnet/cli.py +62 -6
- aimnet/config.py +8 -9
- aimnet/data/sgdataset.py +23 -22
- aimnet/kernels/__init__.py +66 -0
- aimnet/kernels/conv_sv_2d_sp_wp.py +478 -0
- aimnet/models/__init__.py +13 -1
- aimnet/models/aimnet2.py +19 -22
- aimnet/models/base.py +183 -15
- aimnet/models/convert.py +30 -0
- aimnet/models/utils.py +735 -0
- aimnet/modules/__init__.py +1 -1
- aimnet/modules/aev.py +49 -48
- aimnet/modules/core.py +14 -13
- aimnet/modules/lr.py +520 -115
- aimnet/modules/ops.py +537 -0
- aimnet/nbops.py +105 -15
- aimnet/ops.py +90 -18
- aimnet/train/export_model.py +226 -0
- aimnet/train/loss.py +7 -7
- aimnet/train/metrics.py +5 -6
- aimnet/train/train.py +4 -1
- aimnet/train/utils.py +42 -13
- aimnet-0.1.0.dist-info/METADATA +308 -0
- aimnet-0.1.0.dist-info/RECORD +43 -0
- {aimnet-0.0.1.dist-info → aimnet-0.1.0.dist-info}/WHEEL +1 -1
- aimnet-0.1.0.dist-info/entry_points.txt +3 -0
- aimnet/calculators/nb_kernel_cpu.py +0 -222
- aimnet/calculators/nb_kernel_cuda.py +0 -217
- aimnet/calculators/nbmat.py +0 -220
- aimnet/train/pt2jpt.py +0 -81
- aimnet-0.0.1.dist-info/METADATA +0 -78
- aimnet-0.0.1.dist-info/RECORD +0 -41
- aimnet-0.0.1.dist-info/entry_points.txt +0 -5
- {aimnet-0.0.1.dist-info → aimnet-0.1.0.dist-info/licenses}/LICENSE +0 -0
aimnet/train/pt2jpt.py
DELETED
|
@@ -1,81 +0,0 @@
|
|
|
1
|
-
import os
|
|
2
|
-
from typing import List, Optional
|
|
3
|
-
|
|
4
|
-
import click
|
|
5
|
-
import torch
|
|
6
|
-
from torch import nn
|
|
7
|
-
|
|
8
|
-
from aimnet.config import build_module, load_yaml
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def set_eval(model: nn.Module) -> torch.nn.Module:
|
|
12
|
-
for p in model.parameters():
|
|
13
|
-
p.requires_grad_(False)
|
|
14
|
-
return model.eval()
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
def add_cutoff(
|
|
18
|
-
model: nn.Module, cutoff: Optional[float] = None, cutoff_lr: Optional[float] = float("inf")
|
|
19
|
-
) -> nn.Module:
|
|
20
|
-
if cutoff is None:
|
|
21
|
-
cutoff = max(v.item() for k, v in model.state_dict().items() if k.endswith("aev.rc_s"))
|
|
22
|
-
model.cutoff = cutoff # type: ignore[assignment]
|
|
23
|
-
if cutoff_lr is not None:
|
|
24
|
-
model.cutoff_lr = cutoff_lr # type: ignore[assignment]
|
|
25
|
-
return model
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
def add_sae_to_shifts(model: nn.Module, sae_file: str) -> nn.Module:
|
|
29
|
-
sae = load_yaml(sae_file)
|
|
30
|
-
if not isinstance(sae, dict):
|
|
31
|
-
raise TypeError("SAE file must contain a dictionary.")
|
|
32
|
-
model.outputs.atomic_shift.double()
|
|
33
|
-
for k, v in sae.items():
|
|
34
|
-
model.outputs.atomic_shift.shifts.weight[k] += v
|
|
35
|
-
return model
|
|
36
|
-
|
|
37
|
-
|
|
38
|
-
def mask_not_implemented_species(model: nn.Module, species: List[int]) -> nn.Module:
|
|
39
|
-
weight = model.afv.weight
|
|
40
|
-
for i in range(1, weight.shape[0]):
|
|
41
|
-
if i not in species:
|
|
42
|
-
weight[i, :] = torch.nan
|
|
43
|
-
return model
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
_default_aimnet2_config = os.path.join(os.path.dirname(__file__), "..", "models", "aimnet2.yaml")
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
@click.command(short_help="Compile PyTorch model to TorchScript.")
|
|
50
|
-
@click.argument("pt", type=str) # , help='Path to the input PyTorch weights file.')
|
|
51
|
-
@click.argument("jpt", type=str) # , help='Path to the output TorchScript file.')
|
|
52
|
-
@click.option("--model", type=str, default=_default_aimnet2_config, help="Path to model definition YAML file")
|
|
53
|
-
@click.option("--sae", type=str, default=None, help="Path to the energy shift YAML file.")
|
|
54
|
-
@click.option("--species", type=str, default=None, help="Comma-separated list of parametrized atomic numbers.")
|
|
55
|
-
@click.option("--no-lr", is_flag=True, help="Do not add LR cutoff for model")
|
|
56
|
-
def jitcompile(model: str, pt: str, jpt: str, sae=None, species=None, no_lr=False): # type: ignore
|
|
57
|
-
"""Build model from YAML config, load weight from PT file and write JIT-compiled JPT file.
|
|
58
|
-
Plus some modifications to work with aimnet2calc.
|
|
59
|
-
"""
|
|
60
|
-
model: nn.Module = build_module(model) # type: ignore[annotation-unchecked]
|
|
61
|
-
model = set_eval(model)
|
|
62
|
-
cutoff_lr = None if no_lr else float("inf")
|
|
63
|
-
model = add_cutoff(model, cutoff_lr=cutoff_lr)
|
|
64
|
-
sd = torch.load(pt, map_location="cpu", weights_only=True)
|
|
65
|
-
print(model.load_state_dict(sd, strict=False))
|
|
66
|
-
if sae:
|
|
67
|
-
model = add_sae_to_shifts(model, sae)
|
|
68
|
-
numbers = None
|
|
69
|
-
if species:
|
|
70
|
-
numbers = list(map(int, species.split(",")))
|
|
71
|
-
elif sae:
|
|
72
|
-
numbers = list(load_yaml(sae).keys()) # type: ignore[union-attr]
|
|
73
|
-
if numbers:
|
|
74
|
-
model = mask_not_implemented_species(model, numbers) # type: ignore[call-arg]
|
|
75
|
-
model.register_buffer("impemented_species", torch.tensor(numbers, dtype=torch.int64))
|
|
76
|
-
model_jit = torch.jit.script(model)
|
|
77
|
-
model_jit.save(jpt)
|
|
78
|
-
|
|
79
|
-
|
|
80
|
-
if __name__ == "__main__":
|
|
81
|
-
jitcompile()
|
aimnet-0.0.1.dist-info/METADATA
DELETED
|
@@ -1,78 +0,0 @@
|
|
|
1
|
-
Metadata-Version: 2.1
|
|
2
|
-
Name: aimnet
|
|
3
|
-
Version: 0.0.1
|
|
4
|
-
Summary: AIMNet Machine Learned Interatomic Potential
|
|
5
|
-
Home-page: https://github.com/isayevlab/aimnetcentral
|
|
6
|
-
Author: Roman Zubatyuk
|
|
7
|
-
Author-email: zubatyuk@gmail.com
|
|
8
|
-
Requires-Python: >=3.10,<4.0
|
|
9
|
-
Classifier: Programming Language :: Python :: 3
|
|
10
|
-
Classifier: Programming Language :: Python :: 3.10
|
|
11
|
-
Classifier: Programming Language :: Python :: 3.11
|
|
12
|
-
Classifier: Programming Language :: Python :: 3.12
|
|
13
|
-
Classifier: Programming Language :: Python :: 3.13
|
|
14
|
-
Requires-Dist: click (>=8.1.7,<9.0.0)
|
|
15
|
-
Requires-Dist: h5py (>=3.12.1,<4.0.0)
|
|
16
|
-
Requires-Dist: ignite (>=1.1.0,<2.0.0)
|
|
17
|
-
Requires-Dist: jinja2 (>=3.1.4,<4.0.0)
|
|
18
|
-
Requires-Dist: numba (>=0.60.0,<0.61.0)
|
|
19
|
-
Requires-Dist: numpy (<2.0)
|
|
20
|
-
Requires-Dist: omegaconf (>=2.3.0,<3.0.0)
|
|
21
|
-
Requires-Dist: pyyaml (>=6.0.2,<7.0.0)
|
|
22
|
-
Requires-Dist: requests (>=2.32.3,<3.0.0)
|
|
23
|
-
Requires-Dist: torch (>=2.5.0,<3.0.0)
|
|
24
|
-
Requires-Dist: wandb (>=0.18.5,<0.19.0)
|
|
25
|
-
Project-URL: Documentation, https://isayevlab.github.io/aimnetcentral/
|
|
26
|
-
Project-URL: Repository, https://github.com/isayevlab/aimnetcentral
|
|
27
|
-
Description-Content-Type: text/markdown
|
|
28
|
-
|
|
29
|
-
# aimnetcentral
|
|
30
|
-
|
|
31
|
-
[](https://img.shields.io/github/v/release/isayevlab/aimnetcentral)
|
|
32
|
-
[](https://github.com/isayevlab/aimnetcentral/actions/workflows/main.yml?query=branch%3Amain)
|
|
33
|
-
[](https://codecov.io/gh/isayevlab/aimnetcentral)
|
|
34
|
-
[](https://img.shields.io/github/commit-activity/m/isayevlab/aimnetcentral)
|
|
35
|
-
[](https://img.shields.io/github/license/isayevlab/aimnetcentral)
|
|
36
|
-
|
|
37
|
-
AIMNet Machine Learned Interatomic Potential
|
|
38
|
-
|
|
39
|
-
- **Github repository**: <https://github.com/isayevlab/aimnetcentral/>
|
|
40
|
-
- **Documentation** <https://isayevlab.github.io/aimnetcentral/>
|
|
41
|
-
|
|
42
|
-
## Getting started with your project
|
|
43
|
-
|
|
44
|
-
First, create a repository on GitHub with the same name as this project, and then run the following commands:
|
|
45
|
-
|
|
46
|
-
```bash
|
|
47
|
-
git init -b main
|
|
48
|
-
git add .
|
|
49
|
-
git commit -m "init commit"
|
|
50
|
-
git remote add origin git@github.com:isayevlab/aimnetcentral.git
|
|
51
|
-
git push -u origin main
|
|
52
|
-
```
|
|
53
|
-
|
|
54
|
-
Finally, install the environment and the pre-commit hooks with
|
|
55
|
-
|
|
56
|
-
```bash
|
|
57
|
-
make install
|
|
58
|
-
```
|
|
59
|
-
|
|
60
|
-
You are now ready to start development on your project!
|
|
61
|
-
The CI/CD pipeline will be triggered when you open a pull request, merge to main, or when you create a new release.
|
|
62
|
-
|
|
63
|
-
To finalize the set-up for publishing to PyPI or Artifactory, see [here](https://fpgmaas.github.io/cookiecutter-poetry/features/publishing/#set-up-for-pypi).
|
|
64
|
-
For activating the automatic documentation with MkDocs, see [here](https://fpgmaas.github.io/cookiecutter-poetry/features/mkdocs/#enabling-the-documentation-on-github).
|
|
65
|
-
To enable the code coverage reports, see [here](https://fpgmaas.github.io/cookiecutter-poetry/features/codecov/).
|
|
66
|
-
|
|
67
|
-
## Releasing a new version
|
|
68
|
-
|
|
69
|
-
- Create an API Token on [PyPI](https://pypi.org/).
|
|
70
|
-
- Add the API Token to your projects secrets with the name `PYPI_TOKEN` by visiting [this page](https://github.com/isayevlab/aimnetcentral/settings/secrets/actions/new).
|
|
71
|
-
- Create a [new release](https://github.com/isayevlab/aimnetcentral/releases/new) on Github.
|
|
72
|
-
- Create a new tag in the form `*.*.*`.
|
|
73
|
-
- For more details, see [here](https://fpgmaas.github.io/cookiecutter-poetry/features/cicd/#how-to-trigger-a-release).
|
|
74
|
-
|
|
75
|
-
---
|
|
76
|
-
|
|
77
|
-
Repository initiated with [fpgmaas/cookiecutter-poetry](https://github.com/fpgmaas/cookiecutter-poetry).
|
|
78
|
-
|
aimnet-0.0.1.dist-info/RECORD
DELETED
|
@@ -1,41 +0,0 @@
|
|
|
1
|
-
aimnet/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
aimnet/base.py,sha256=v-8ukDanaxnhYP0IbugKoixssiJTQMR6KyGtnriH1RI,1663
|
|
3
|
-
aimnet/calculators/__init__.py,sha256=V-x9HIkk_wDkBI0W88A0St-9dmTYPZXErGT_hWLkgng,383
|
|
4
|
-
aimnet/calculators/aimnet2ase.py,sha256=omyzJKzTpLvIltvYKOvikQ6VuSx6IViWQNZlceMrV_A,3536
|
|
5
|
-
aimnet/calculators/aimnet2pysis.py,sha256=TDoDl-TExCQkg8EXfsTKjOu-z6Kd25xXqSH98gF8iGM,2876
|
|
6
|
-
aimnet/calculators/calculator.py,sha256=9QbhAfw7EBi5CRhP0rT-2XVx243Ub3cq-XrqJI2LLb0,13092
|
|
7
|
-
aimnet/calculators/model_registry.py,sha256=TaQdTUiKqKWt7fSOxODd69QgnmCuNVwHL3qBOcyMsZI,1921
|
|
8
|
-
aimnet/calculators/model_registry.yaml,sha256=Yejl2ls8DWvwRaBIwpGfc3WsHB9jZdIH_3B-i5klmks,1410
|
|
9
|
-
aimnet/calculators/nb_kernel_cpu.py,sha256=a6ClFEzfc7-Fr4m_TNGOprjdq0GiFZ8wheqmKeDjqjM,7309
|
|
10
|
-
aimnet/calculators/nb_kernel_cuda.py,sha256=0FHpGJ6-MwQgoX-GiBLGZqdhN4P2uqX4CPxQIi_-S2c,6552
|
|
11
|
-
aimnet/calculators/nbmat.py,sha256=_3LKqoHEVsM1dXBMICQYyciPoqIBUgVq6KGM7XgIMh4,7785
|
|
12
|
-
aimnet/cli.py,sha256=2A51wv5N2i9HdmiCcam_JDY03WKeOTTn8dTxTtp1QXM,415
|
|
13
|
-
aimnet/config.py,sha256=Ld5wyApDhI8OovodC_ibpcDL9rU-VJ8gkQZ3BaiDf0c,5920
|
|
14
|
-
aimnet/constants.py,sha256=T6eb_CG5dkxuUg22Xf8WiF0iGB-JLoLNzHq9R8SRzrQ,8760
|
|
15
|
-
aimnet/data/__init__.py,sha256=rD84W2qazjXUbdcYBuHkMmYp44fGRFMS3qnv1OKPsQs,87
|
|
16
|
-
aimnet/data/sgdataset.py,sha256=N3Nwwn1GicoohcRh-DJWJFdibmV2o6YZjKaTlWyZZx4,18926
|
|
17
|
-
aimnet/dftd3_data.pt,sha256=g9HQ_2nekMs97UnIDIFReGEG8q6n6zQlq-Yl5LGQnOI,2710647
|
|
18
|
-
aimnet/models/__init__.py,sha256=gb6kA4dHknYOzllaz02-Bx5kBkWaSjj9wLtY0Fx9zdo,87
|
|
19
|
-
aimnet/models/aimnet2.py,sha256=O0CcP21sJcSeZCE-H0klKfFPuqDOmxWdrX_ZbX1aFm8,7006
|
|
20
|
-
aimnet/models/aimnet2.yaml,sha256=VmI0ub7UbiJm833zwCtCKsrBIaXZKjJsCXxXxutwdLI,1062
|
|
21
|
-
aimnet/models/aimnet2_dftd3_wb97m.yaml,sha256=Gk869_2RQIXzlHwRNbfsai4u9wPX3oAyINK1I71eBFI,1219
|
|
22
|
-
aimnet/models/base.py,sha256=uLcOaTiXEZPpLgBVEnT5QFcIZrA8KgVNLe_guR5o99o,1815
|
|
23
|
-
aimnet/modules/__init__.py,sha256=c4tcXbNfRRvFgIk9RBk6X_5iKmCy6za7ooeuJfqlAHc,205
|
|
24
|
-
aimnet/modules/aev.py,sha256=QgnFsz_k-qd_d0sufBc0B_r2iCCJezl2Pn9wbvDIeIU,7801
|
|
25
|
-
aimnet/modules/core.py,sha256=kDhU50plXKXNXnr5BgKJREwHa2F6aawL6wI2O74cMb0,8236
|
|
26
|
-
aimnet/modules/lr.py,sha256=x6AAdyF_Qgkp7OKcS9dk8FDU5ENoVkKYQ2aQgUajz6Y,8835
|
|
27
|
-
aimnet/nbops.py,sha256=iR9MR3MSts5Ow6j-_PtwylqHSYEQyJDwDEvF7FHNSJc,5450
|
|
28
|
-
aimnet/ops.py,sha256=p6_jwC_xjEw3FaNvwmNkkIOIttFRsd9WF03ovUkX3cU,7888
|
|
29
|
-
aimnet/train/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
30
|
-
aimnet/train/calc_sae.py,sha256=h-y8jQVU6O7Rmw8xJWX3DHOpFv1gtjFwtZzZyUOhz9g,1459
|
|
31
|
-
aimnet/train/default_train.yaml,sha256=FAG3f3pFC2dkFUuVO_XE_UB6rFr-R3GkACddn7FH_58,4783
|
|
32
|
-
aimnet/train/loss.py,sha256=BqBaMBCPBkFcm1QnpczFJY6ZU_5KfAQEBo-LCYdJ0m0,3175
|
|
33
|
-
aimnet/train/metrics.py,sha256=HEfIoFhOc6eD8yrgxqXUGI3CfWnXfm6G6AUpncGQljs,6867
|
|
34
|
-
aimnet/train/pt2jpt.py,sha256=XMc95PrXvlHBUqtAVNJ0JzJP0ZJK-WrJ8rbSg4yxHCU,3129
|
|
35
|
-
aimnet/train/train.py,sha256=S34BWOhfhe8fg9qGkp-xw6O4-sgHM6R2oQ0-e8dk6hg,5711
|
|
36
|
-
aimnet/train/utils.py,sha256=2pYZbY1L9KdmuHqmfdQDId-TWdl-WMeNoeatcGZWPPE,14730
|
|
37
|
-
aimnet-0.0.1.dist-info/LICENSE,sha256=73sk-zg2yVRrOZQDeVbPlVB7ScZc1iK0kyCBMwNwQgA,1072
|
|
38
|
-
aimnet-0.0.1.dist-info/METADATA,sha256=RAVsz_d3HL2QFGV9H9M2J57riF1W9VMmhsjUWvTOSaw,3717
|
|
39
|
-
aimnet-0.0.1.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
|
|
40
|
-
aimnet-0.0.1.dist-info/entry_points.txt,sha256=3SPZbjuQFS1V1xryBwVoo8ix0DNGYzHeoDbvw0z2Mdk,164
|
|
41
|
-
aimnet-0.0.1.dist-info/RECORD,,
|
|
File without changes
|