aimnet 0.0.1__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
aimnet/train/pt2jpt.py DELETED
@@ -1,81 +0,0 @@
1
- import os
2
- from typing import List, Optional
3
-
4
- import click
5
- import torch
6
- from torch import nn
7
-
8
- from aimnet.config import build_module, load_yaml
9
-
10
-
11
def set_eval(model: nn.Module) -> torch.nn.Module:
    """Freeze every parameter of ``model`` and put it into inference mode.

    Args:
        model: Module to prepare for export. Modified in place.

    Returns:
        The same module object (``Module.eval()`` returns ``self``), with
        ``requires_grad`` disabled on all of its parameters.
    """
    # Module.requires_grad_ walks self.parameters() and flips each flag.
    model.requires_grad_(False)
    return model.eval()
15
-
16
-
17
def add_cutoff(
    model: nn.Module, cutoff: Optional[float] = None, cutoff_lr: Optional[float] = float("inf")
) -> nn.Module:
    """Attach short-range and (optionally) long-range cutoff attributes to ``model``.

    Args:
        model: Module to annotate. Modified in place.
        cutoff: Short-range cutoff stored as ``model.cutoff``. When ``None`` it is
            inferred as the largest value among state-dict entries whose key
            ends with ``"aev.rc_s"``.
        cutoff_lr: Long-range cutoff stored as ``model.cutoff_lr``. When ``None``
            the attribute is not set at all.

    Returns:
        The same module object.

    Raises:
        ValueError: If ``cutoff`` is ``None`` and the state dict has no
            ``aev.rc_s`` entry to infer it from.
    """
    if cutoff is None:
        candidates = [v.item() for k, v in model.state_dict().items() if k.endswith("aev.rc_s")]
        # Fail with an actionable message instead of max()'s "empty sequence" error.
        if not candidates:
            raise ValueError(
                "Cannot infer cutoff: no state-dict entry ending with 'aev.rc_s'; pass `cutoff` explicitly."
            )
        cutoff = max(candidates)
    model.cutoff = cutoff  # type: ignore[assignment]
    if cutoff_lr is not None:
        model.cutoff_lr = cutoff_lr  # type: ignore[assignment]
    return model
26
-
27
-
28
def add_sae_to_shifts(model: nn.Module, sae_file: str) -> nn.Module:
    """Add self-atomic energies (SAE) read from a YAML file to the model's atomic shifts.

    ``model.outputs.atomic_shift`` is first converted to float64. Then, for each
    ``key: value`` pair in the file, ``value`` is added to row ``key`` of
    ``model.outputs.atomic_shift.shifts.weight``.

    Args:
        model: Module exposing ``outputs.atomic_shift.shifts.weight``. Modified in place.
        sae_file: Path to a YAML file containing a mapping whose keys are used as
            row indices (presumably atomic numbers -- TODO confirm).

    Returns:
        The same module object.

    Raises:
        TypeError: If the YAML file does not contain a dictionary.
    """
    sae = load_yaml(sae_file)
    if not isinstance(sae, dict):
        raise TypeError("SAE file must contain a dictionary.")
    atomic_shift = model.outputs.atomic_shift
    atomic_shift.double()
    # Fetch the weight only after .double(), which may replace the parameter tensor.
    weight = atomic_shift.shifts.weight
    # In-place edits of a parameter that requires grad raise a RuntimeError under
    # autograd; no_grad makes this work whether or not set_eval() ran first.
    with torch.no_grad():
        for k, v in sae.items():
            weight[k] += v
    return model
36
-
37
-
38
def mask_not_implemented_species(model: nn.Module, species: List[int]) -> nn.Module:
    """Fill rows of ``model.afv.weight`` with NaN for species that are not parametrized.

    Row 0 is never touched (the loop starts at 1; presumably a padding row --
    TODO confirm). Every other row index not listed in ``species`` is set to NaN.

    Args:
        model: Module exposing an ``afv.weight`` tensor. Modified in place.
        species: Row indices (atomic numbers) to keep unmasked.

    Returns:
        The same module object.
    """
    # Build the membership set once instead of scanning the list per row.
    keep = set(species)
    weight = model.afv.weight
    # no_grad allows in-place writes even when the parameter still requires grad.
    with torch.no_grad():
        for i in range(1, weight.shape[0]):
            if i not in keep:
                weight[i, :] = torch.nan
    return model
44
-
45
-
46
# Default model definition: the bundled aimnet2.yaml located in the sibling
# "models" directory, resolved relative to this module's own directory.
_default_aimnet2_config = os.path.join(
    os.path.dirname(__file__),
    "..",
    "models",
    "aimnet2.yaml",
)
47
-
48
-
49
@click.command(short_help="Compile PyTorch model to TorchScript.")
@click.argument("pt", type=str)  # Path to the input PyTorch weights file.
@click.argument("jpt", type=str)  # Path to the output TorchScript file.
@click.option("--model", type=str, default=_default_aimnet2_config, help="Path to model definition YAML file")
@click.option("--sae", type=str, default=None, help="Path to the energy shift YAML file.")
@click.option("--species", type=str, default=None, help="Comma-separated list of parametrized atomic numbers.")
@click.option("--no-lr", is_flag=True, help="Do not add LR cutoff for model")
def jitcompile(model: str, pt: str, jpt: str, sae=None, species=None, no_lr=False):  # type: ignore
    """Build model from YAML config, load weights from PT file and write JIT-compiled JPT file.
    Plus some modifications to work with aimnet2calc.
    """
    # `model` is the path to the YAML definition; `net` is the built module.
    net: nn.Module = build_module(model)  # type: ignore[annotation-unchecked]
    net = set_eval(net)
    sd = torch.load(pt, map_location="cpu", weights_only=True)
    # Report missing/unexpected keys; strict=False tolerates partial checkpoints.
    print(net.load_state_dict(sd, strict=False))
    # Infer the cutoff only AFTER loading weights, so `aev.rc_s` reflects the
    # checkpoint rather than the freshly built config defaults.
    cutoff_lr = None if no_lr else float("inf")
    net = add_cutoff(net, cutoff_lr=cutoff_lr)
    if sae:
        net = add_sae_to_shifts(net, sae)
    numbers = None
    if species:
        # Skip empty entries so inputs such as "1,6,8," do not crash on int("").
        numbers = [int(s) for s in species.split(",") if s.strip()]
    elif sae:
        # Fall back to the species listed in the SAE file.
        numbers = list(load_yaml(sae).keys())  # type: ignore[union-attr]
    if numbers:
        net = mask_not_implemented_species(net, numbers)  # type: ignore[call-arg]
        # NOTE(review): the buffer name is misspelled ("impemented"); it is kept
        # as-is because downstream loaders may look it up by this exact name.
        # Confirm before renaming.
        net.register_buffer("impemented_species", torch.tensor(numbers, dtype=torch.int64))
    model_jit = torch.jit.script(net)
    model_jit.save(jpt)


if __name__ == "__main__":
    jitcompile()
@@ -1,78 +0,0 @@
1
- Metadata-Version: 2.1
2
- Name: aimnet
3
- Version: 0.0.1
4
- Summary: AIMNet Machine Learned Interatomic Potential
5
- Home-page: https://github.com/isayevlab/aimnetcentral
6
- Author: Roman Zubatyuk
7
- Author-email: zubatyuk@gmail.com
8
- Requires-Python: >=3.10,<4.0
9
- Classifier: Programming Language :: Python :: 3
10
- Classifier: Programming Language :: Python :: 3.10
11
- Classifier: Programming Language :: Python :: 3.11
12
- Classifier: Programming Language :: Python :: 3.12
13
- Classifier: Programming Language :: Python :: 3.13
14
- Requires-Dist: click (>=8.1.7,<9.0.0)
15
- Requires-Dist: h5py (>=3.12.1,<4.0.0)
16
- Requires-Dist: ignite (>=1.1.0,<2.0.0)
17
- Requires-Dist: jinja2 (>=3.1.4,<4.0.0)
18
- Requires-Dist: numba (>=0.60.0,<0.61.0)
19
- Requires-Dist: numpy (<2.0)
20
- Requires-Dist: omegaconf (>=2.3.0,<3.0.0)
21
- Requires-Dist: pyyaml (>=6.0.2,<7.0.0)
22
- Requires-Dist: requests (>=2.32.3,<3.0.0)
23
- Requires-Dist: torch (>=2.5.0,<3.0.0)
24
- Requires-Dist: wandb (>=0.18.5,<0.19.0)
25
- Project-URL: Documentation, https://isayevlab.github.io/aimnetcentral/
26
- Project-URL: Repository, https://github.com/isayevlab/aimnetcentral
27
- Description-Content-Type: text/markdown
28
-
29
- # aimnetcentral
30
-
31
- [![Release](https://img.shields.io/github/v/release/isayevlab/aimnetcentral)](https://img.shields.io/github/v/release/isayevlab/aimnetcentral)
32
- [![Build status](https://img.shields.io/github/actions/workflow/status/isayevlab/aimnetcentral/main.yml?branch=main)](https://github.com/isayevlab/aimnetcentral/actions/workflows/main.yml?query=branch%3Amain)
33
- [![codecov](https://codecov.io/gh/isayevlab/aimnetcentral/branch/main/graph/badge.svg)](https://codecov.io/gh/isayevlab/aimnetcentral)
34
- [![Commit activity](https://img.shields.io/github/commit-activity/m/isayevlab/aimnetcentral)](https://img.shields.io/github/commit-activity/m/isayevlab/aimnetcentral)
35
- [![License](https://img.shields.io/github/license/isayevlab/aimnetcentral)](https://img.shields.io/github/license/isayevlab/aimnetcentral)
36
-
37
- AIMNet Machine Learned Interatomic Potential
38
-
39
- - **GitHub repository**: <https://github.com/isayevlab/aimnetcentral/>
40
- - **Documentation**: <https://isayevlab.github.io/aimnetcentral/>
41
-
42
- ## Getting started with your project
43
-
44
- First, create a repository on GitHub with the same name as this project, and then run the following commands:
45
-
46
- ```bash
47
- git init -b main
48
- git add .
49
- git commit -m "init commit"
50
- git remote add origin git@github.com:isayevlab/aimnetcentral.git
51
- git push -u origin main
52
- ```
53
-
54
- Finally, install the environment and the pre-commit hooks with:
55
-
56
- ```bash
57
- make install
58
- ```
59
-
60
- You are now ready to start development on your project!
61
- The CI/CD pipeline will be triggered when you open a pull request, merge to main, or when you create a new release.
62
-
63
- To finalize the set-up for publishing to PyPI or Artifactory, see [here](https://fpgmaas.github.io/cookiecutter-poetry/features/publishing/#set-up-for-pypi).
64
- For activating the automatic documentation with MkDocs, see [here](https://fpgmaas.github.io/cookiecutter-poetry/features/mkdocs/#enabling-the-documentation-on-github).
65
- To enable the code coverage reports, see [here](https://fpgmaas.github.io/cookiecutter-poetry/features/codecov/).
66
-
67
- ## Releasing a new version
68
-
69
- - Create an API Token on [PyPI](https://pypi.org/).
70
- - Add the API Token to your project's secrets with the name `PYPI_TOKEN` by visiting [this page](https://github.com/isayevlab/aimnetcentral/settings/secrets/actions/new).
71
- - Create a [new release](https://github.com/isayevlab/aimnetcentral/releases/new) on GitHub.
72
- - Create a new tag in the form `*.*.*`.
73
- - For more details, see [here](https://fpgmaas.github.io/cookiecutter-poetry/features/cicd/#how-to-trigger-a-release).
74
-
75
- ---
76
-
77
- Repository initiated with [fpgmaas/cookiecutter-poetry](https://github.com/fpgmaas/cookiecutter-poetry).
78
-
@@ -1,41 +0,0 @@
1
- aimnet/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- aimnet/base.py,sha256=v-8ukDanaxnhYP0IbugKoixssiJTQMR6KyGtnriH1RI,1663
3
- aimnet/calculators/__init__.py,sha256=V-x9HIkk_wDkBI0W88A0St-9dmTYPZXErGT_hWLkgng,383
4
- aimnet/calculators/aimnet2ase.py,sha256=omyzJKzTpLvIltvYKOvikQ6VuSx6IViWQNZlceMrV_A,3536
5
- aimnet/calculators/aimnet2pysis.py,sha256=TDoDl-TExCQkg8EXfsTKjOu-z6Kd25xXqSH98gF8iGM,2876
6
- aimnet/calculators/calculator.py,sha256=9QbhAfw7EBi5CRhP0rT-2XVx243Ub3cq-XrqJI2LLb0,13092
7
- aimnet/calculators/model_registry.py,sha256=TaQdTUiKqKWt7fSOxODd69QgnmCuNVwHL3qBOcyMsZI,1921
8
- aimnet/calculators/model_registry.yaml,sha256=Yejl2ls8DWvwRaBIwpGfc3WsHB9jZdIH_3B-i5klmks,1410
9
- aimnet/calculators/nb_kernel_cpu.py,sha256=a6ClFEzfc7-Fr4m_TNGOprjdq0GiFZ8wheqmKeDjqjM,7309
10
- aimnet/calculators/nb_kernel_cuda.py,sha256=0FHpGJ6-MwQgoX-GiBLGZqdhN4P2uqX4CPxQIi_-S2c,6552
11
- aimnet/calculators/nbmat.py,sha256=_3LKqoHEVsM1dXBMICQYyciPoqIBUgVq6KGM7XgIMh4,7785
12
- aimnet/cli.py,sha256=2A51wv5N2i9HdmiCcam_JDY03WKeOTTn8dTxTtp1QXM,415
13
- aimnet/config.py,sha256=Ld5wyApDhI8OovodC_ibpcDL9rU-VJ8gkQZ3BaiDf0c,5920
14
- aimnet/constants.py,sha256=T6eb_CG5dkxuUg22Xf8WiF0iGB-JLoLNzHq9R8SRzrQ,8760
15
- aimnet/data/__init__.py,sha256=rD84W2qazjXUbdcYBuHkMmYp44fGRFMS3qnv1OKPsQs,87
16
- aimnet/data/sgdataset.py,sha256=N3Nwwn1GicoohcRh-DJWJFdibmV2o6YZjKaTlWyZZx4,18926
17
- aimnet/dftd3_data.pt,sha256=g9HQ_2nekMs97UnIDIFReGEG8q6n6zQlq-Yl5LGQnOI,2710647
18
- aimnet/models/__init__.py,sha256=gb6kA4dHknYOzllaz02-Bx5kBkWaSjj9wLtY0Fx9zdo,87
19
- aimnet/models/aimnet2.py,sha256=O0CcP21sJcSeZCE-H0klKfFPuqDOmxWdrX_ZbX1aFm8,7006
20
- aimnet/models/aimnet2.yaml,sha256=VmI0ub7UbiJm833zwCtCKsrBIaXZKjJsCXxXxutwdLI,1062
21
- aimnet/models/aimnet2_dftd3_wb97m.yaml,sha256=Gk869_2RQIXzlHwRNbfsai4u9wPX3oAyINK1I71eBFI,1219
22
- aimnet/models/base.py,sha256=uLcOaTiXEZPpLgBVEnT5QFcIZrA8KgVNLe_guR5o99o,1815
23
- aimnet/modules/__init__.py,sha256=c4tcXbNfRRvFgIk9RBk6X_5iKmCy6za7ooeuJfqlAHc,205
24
- aimnet/modules/aev.py,sha256=QgnFsz_k-qd_d0sufBc0B_r2iCCJezl2Pn9wbvDIeIU,7801
25
- aimnet/modules/core.py,sha256=kDhU50plXKXNXnr5BgKJREwHa2F6aawL6wI2O74cMb0,8236
26
- aimnet/modules/lr.py,sha256=x6AAdyF_Qgkp7OKcS9dk8FDU5ENoVkKYQ2aQgUajz6Y,8835
27
- aimnet/nbops.py,sha256=iR9MR3MSts5Ow6j-_PtwylqHSYEQyJDwDEvF7FHNSJc,5450
28
- aimnet/ops.py,sha256=p6_jwC_xjEw3FaNvwmNkkIOIttFRsd9WF03ovUkX3cU,7888
29
- aimnet/train/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
- aimnet/train/calc_sae.py,sha256=h-y8jQVU6O7Rmw8xJWX3DHOpFv1gtjFwtZzZyUOhz9g,1459
31
- aimnet/train/default_train.yaml,sha256=FAG3f3pFC2dkFUuVO_XE_UB6rFr-R3GkACddn7FH_58,4783
32
- aimnet/train/loss.py,sha256=BqBaMBCPBkFcm1QnpczFJY6ZU_5KfAQEBo-LCYdJ0m0,3175
33
- aimnet/train/metrics.py,sha256=HEfIoFhOc6eD8yrgxqXUGI3CfWnXfm6G6AUpncGQljs,6867
34
- aimnet/train/pt2jpt.py,sha256=XMc95PrXvlHBUqtAVNJ0JzJP0ZJK-WrJ8rbSg4yxHCU,3129
35
- aimnet/train/train.py,sha256=S34BWOhfhe8fg9qGkp-xw6O4-sgHM6R2oQ0-e8dk6hg,5711
36
- aimnet/train/utils.py,sha256=2pYZbY1L9KdmuHqmfdQDId-TWdl-WMeNoeatcGZWPPE,14730
37
- aimnet-0.0.1.dist-info/LICENSE,sha256=73sk-zg2yVRrOZQDeVbPlVB7ScZc1iK0kyCBMwNwQgA,1072
38
- aimnet-0.0.1.dist-info/METADATA,sha256=RAVsz_d3HL2QFGV9H9M2J57riF1W9VMmhsjUWvTOSaw,3717
39
- aimnet-0.0.1.dist-info/WHEEL,sha256=Nq82e9rUAnEjt98J6MlVmMCZb-t9cYE2Ir1kpBmnWfs,88
40
- aimnet-0.0.1.dist-info/entry_points.txt,sha256=3SPZbjuQFS1V1xryBwVoo8ix0DNGYzHeoDbvw0z2Mdk,164
41
- aimnet-0.0.1.dist-info/RECORD,,
@@ -1,5 +0,0 @@
1
- [console_scripts]
2
- aimnet=aimnet.cli:cli
3
- aimnet-clear-models-cache=aimnet.calculators.model_registry:clear_assets
4
- aimnet2pysis=aimnet.calculators.aimnet2pysis:main
5
-