shnitsel-tools 0.0.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. shnitsel_tools-0.0.1/LICENSE +21 -0
  2. shnitsel_tools-0.0.1/PKG-INFO +197 -0
  3. shnitsel_tools-0.0.1/README.md +157 -0
  4. shnitsel_tools-0.0.1/pyproject.toml +103 -0
  5. shnitsel_tools-0.0.1/setup.cfg +4 -0
  6. shnitsel_tools-0.0.1/shnitsel/__init__.py +14 -0
  7. shnitsel_tools-0.0.1/shnitsel/core/__init__.py +39 -0
  8. shnitsel_tools-0.0.1/shnitsel/core/ase.py +148 -0
  9. shnitsel_tools-0.0.1/shnitsel/core/datasheet/__init__.py +135 -0
  10. shnitsel_tools-0.0.1/shnitsel/core/datasheet/colormaps.py +14 -0
  11. shnitsel_tools-0.0.1/shnitsel/core/datasheet/common.py +62 -0
  12. shnitsel_tools-0.0.1/shnitsel/core/datasheet/dip_trans_hist.py +211 -0
  13. shnitsel_tools-0.0.1/shnitsel/core/datasheet/hist.py +37 -0
  14. shnitsel_tools-0.0.1/shnitsel/core/datasheet/nacs_hist.py +52 -0
  15. shnitsel_tools-0.0.1/shnitsel/core/datasheet/oop.py +501 -0
  16. shnitsel_tools-0.0.1/shnitsel/core/datasheet/per_state_hist.py +38 -0
  17. shnitsel_tools-0.0.1/shnitsel/core/datasheet/structure.py +83 -0
  18. shnitsel_tools-0.0.1/shnitsel/core/datasheet/time.py +56 -0
  19. shnitsel_tools-0.0.1/shnitsel/core/filter_unphysical.py +192 -0
  20. shnitsel_tools-0.0.1/shnitsel/core/filtre.py +98 -0
  21. shnitsel_tools-0.0.1/shnitsel/core/indexes.py +16 -0
  22. shnitsel_tools-0.0.1/shnitsel/core/parse/__init__.py +327 -0
  23. shnitsel_tools-0.0.1/shnitsel/core/parse/common.py +135 -0
  24. shnitsel_tools-0.0.1/shnitsel/core/parse/nx.py +265 -0
  25. shnitsel_tools-0.0.1/shnitsel/core/parse/pyrai2md.py +224 -0
  26. shnitsel_tools-0.0.1/shnitsel/core/parse/sharc_icond.py +491 -0
  27. shnitsel_tools-0.0.1/shnitsel/core/parse/sharc_traj.py +282 -0
  28. shnitsel_tools-0.0.1/shnitsel/core/parse/xyz.py +34 -0
  29. shnitsel_tools-0.0.1/shnitsel/core/plot/__init__.py +20 -0
  30. shnitsel_tools-0.0.1/shnitsel/core/plot/kde.py +196 -0
  31. shnitsel_tools-0.0.1/shnitsel/core/plot/p3mhelpers.py +76 -0
  32. shnitsel_tools-0.0.1/shnitsel/core/plot/pca_biplot.py +476 -0
  33. shnitsel_tools-0.0.1/shnitsel/core/plot/polychrom.py +107 -0
  34. shnitsel_tools-0.0.1/shnitsel/core/plot/select.py +42 -0
  35. shnitsel_tools-0.0.1/shnitsel/core/plot/spectra3d.py +127 -0
  36. shnitsel_tools-0.0.1/shnitsel/core/postprocess.py +937 -0
  37. shnitsel_tools-0.0.1/shnitsel/core/spectra.py +56 -0
  38. shnitsel_tools-0.0.1/shnitsel/core/xrhelpers.py +342 -0
  39. shnitsel_tools-0.0.1/shnitsel/plot/__init__.py +11 -0
  40. shnitsel_tools-0.0.1/shnitsel/rd.py +1 -0
  41. shnitsel_tools-0.0.1/shnitsel/xarray.py +267 -0
  42. shnitsel_tools-0.0.1/shnitsel_tools.egg-info/PKG-INFO +197 -0
  43. shnitsel_tools-0.0.1/shnitsel_tools.egg-info/SOURCES.txt +52 -0
  44. shnitsel_tools-0.0.1/shnitsel_tools.egg-info/dependency_links.txt +1 -0
  45. shnitsel_tools-0.0.1/shnitsel_tools.egg-info/requires.txt +29 -0
  46. shnitsel_tools-0.0.1/shnitsel_tools.egg-info/top_level.txt +1 -0
  47. shnitsel_tools-0.0.1/tests/test_accessors.py +41 -0
  48. shnitsel_tools-0.0.1/tests/test_ase.py +29 -0
  49. shnitsel_tools-0.0.1/tests/test_parallel_parsing.py +6 -0
  50. shnitsel_tools-0.0.1/tests/test_plots.py +165 -0
  51. shnitsel_tools-0.0.1/tests/test_postprocess.py +78 -0
  52. shnitsel_tools-0.0.1/tests/test_pyrai2md.py +5 -0
  53. shnitsel_tools-0.0.1/tests/test_sharc_icond.py +12 -0
  54. shnitsel_tools-0.0.1/tests/test_si_consistency.py +85 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2024 SHNITSEL
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,197 @@
1
+ Metadata-Version: 2.4
2
+ Name: shnitsel-tools
3
+ Version: 0.0.1
4
+ Summary: Display and interpret output of SHARC
5
+ Author-email: Robin Curth <robin.curth@uni-leipzig.de>, Theodor Everley Röhrkasten <theodor.roehrkasten@fau.de>, Carolin Müller <carolin.cpc.mueller@fau.de>, Julia Westermayr <julia.westermayr@uni-leipzig.de>
6
+ Classifier: Programming Language :: Python :: 3
7
+ Classifier: License :: OSI Approved :: MIT License
8
+ Classifier: Operating System :: OS Independent
9
+ Requires-Python: >=3.10
10
+ Description-Content-Type: text/markdown
11
+ License-File: LICENSE
12
+ Requires-Dist: scipy>=1.11.4
13
+ Requires-Dist: matplotlib
14
+ Requires-Dist: xarray
15
+ Requires-Dist: rdkit
16
+ Requires-Dist: ase
17
+ Requires-Dist: numpy>=1.26.2
18
+ Requires-Dist: pandas>=2.1.4
19
+ Requires-Dist: h5netcdf>=1.3.0
20
+ Requires-Dist: scikit-learn>=1.3.2
21
+ Requires-Dist: jupyter>=1.0.0
22
+ Requires-Dist: py3dmol>=2.0.4
23
+ Requires-Dist: tqdm>=4.66.2
24
+ Provides-Extra: extras
25
+ Requires-Dist: plotly>=5.22.0; extra == "extras"
26
+ Requires-Dist: seaborn>=0.13.1; extra == "extras"
27
+ Requires-Dist: bokeh; extra == "extras"
28
+ Provides-Extra: dev
29
+ Requires-Dist: pytest>=5.2; extra == "dev"
30
+ Requires-Dist: pytest-cov>=4.1.0; extra == "dev"
31
+ Requires-Dist: hypothesis; extra == "dev"
32
+ Requires-Dist: tox>=4.10.0; extra == "dev"
33
+ Requires-Dist: pre-commit>=2.20.0; extra == "dev"
34
+ Requires-Dist: ruff==0.4.1; extra == "dev"
35
+ Requires-Dist: mypy>=1.9.0; extra == "dev"
36
+ Requires-Dist: pandas-stubs; extra == "dev"
37
+ Requires-Dist: scipy-stubs; extra == "dev"
38
+ Requires-Dist: types-tqdm; extra == "dev"
39
+ Dynamic: license-file
40
+
41
+ <div align="center">
42
+ <h1>shnitsel-tools</h1>
43
+ <img src="https://github.com/SHNITSEL/shnitsel-tools/blob/main/logo_shnitsel_tools.png" alt="SHNITSEL-TOOLS Logo" width="200px">
44
+ <h3>Surface Hopping Nested Instances Training Set for Excited-state Learning Tools</h3>
45
+ <br>
46
+ <a href="https://shnitsel.github.io/">
47
+ <img src="https://img.shields.io/badge/Website-shnitsel.github.io-yellow.svg" alt="DOI">
48
+ </a>
49
+ </div>
50
+
51
+ --------------------
52
+
53
+ ## About
54
+
55
+ `shnitsel-tools` is designed to support the entire data lifecycle of surface hopping (SH) trajectory data upon simulation: data management, storage, processing, visualization and interpretation.
56
+ The tool is compatible with surface hopping data generated using the software packages [SHARC 3/4](https://sharc-md.org/), [Newton-X](https://newtonx.org/), and [PyRAI2MD](https://github.com/lopez-lab/PyRAI2MD).
57
+ The `shnitsel` code leverages Xarray to benefit from efficient multidimensional data handling, improved metadata management, and a structure that aligns naturally with the needs of quantum chemical datasets.
58
+
59
+ ## Usage
60
+
61
+ The package is organized into ...
62
+
63
+ > [!TIP]
64
+ > ### Tutorials
65
+ >
66
+ > For a quick start, see the [tutorials](https://github.com/SHNITSEL/shnitsel-tools/blob/main/tutorials) directory,
67
+ > which contains Jupyter Notebooks showing the workflow for parsing, writing and loading SHNITSEL databases as well as how to postprocess and visualize the respective data.
68
+ >
69
+ > #### Collection & storage
70
+ > - [parsing trajectory and initial condition data obtained by SHARC](https://github.com/SHNITSEL/shnitsel-tools/blob/main/tutorials/0_1_sharc2hdf5.ipynb)
71
+ > - [parsing trajectory data produced with Newton-X](https://github.com/SHNITSEL/shnitsel-tools/blob/main/tutorials/0_2_nx2hdf5.ipynb)
72
+ > - [convert ASE databases](https://github.com/SHNITSEL/shnitsel-tools/blob/main/tutorials/0_4_ase2hdf5.ipynb)
73
+ >
74
+ > #### Management
75
+ >
76
+ > [exploration of electronic properties](https://github.com/SHNITSEL/shnitsel-tools/blob/main/tutorials/2_2_PS_explore.ipynb)
77
+ >
78
+ > #### Postprocessing & visualization of data
79
+ > - [datasheet for trajectory data](https://github.com/SHNITSEL/shnitsel-tools/blob/main/tutorials/3_1_datasheet.ipynb)
80
+ > - [principal component analysis and trajectory classification](https://github.com/SHNITSEL/shnitsel-tools/blob/main/tutorials/1_1_GS_PCA.ipynb)
81
+ >
82
+ > #### Walkthroughs
83
+ >
84
+ > (Link to folder, further info in README)
85
+ >
86
+
87
+ ```bash
88
+ shnitsel
89
+ ├── core
90
+ │   ├── ase.py
91
+ │   ├── datasheet
92
+ │   │   ├── __init__.py
93
+ │   │   ├── oop.py
94
+ │   │   └── plot
95
+ │   │   ├── colormaps.py
96
+ │   │   ├── common.py
97
+ │   │   ├── dip_trans_hist.py
98
+ │   │   ├── hist.py
99
+ │   │   ├── __init__.py
100
+ │   │   ├── nacs_hist.py
101
+ │   │   ├── per_state_hist.py
102
+ │   │   ├── structure.py
103
+ │   │   └── time.py
104
+ │   ├── filter_unphysical.py
105
+ │   ├── filtre.py
106
+ │   ├── indexes.py
107
+ │   ├── __init__.py
108
+ │   ├── parse
109
+ │   │   ├── common.py
110
+ │   │   ├── __init__.py
111
+ │   │   ├── nx.py
112
+ │   │   ├── pyrai2md.py
113
+ │   │   ├── sharc_icond.py
114
+ │   │   ├── sharc_traj.py
115
+ │   │   └── xyz.py
116
+ │   ├── pca_biplot.py
117
+ │   ├── plot
118
+ │   │   ├── __init__.py
119
+ │   │   ├── kde.py
120
+ │   │   ├── p3mhelpers.py
121
+ │   │   ├── select.py
122
+ │   │   └── spectra3d.py
123
+ │   ├── plotting.py
124
+ │   ├── postprocess.py
125
+ │   ├── spectra.py
126
+ │   └── xrhelpers.py
127
+ ├── __init__.py
128
+ ├── plot
129
+ │   └── __init__.py
130
+ ├── rd.py
131
+ ├── README.md
132
+ └── xarray.py
133
+ ```
134
+
135
+ ## Installation
136
+
137
+ You can create the environment with a custom path using one of the following methods:
138
+
139
+ <details open>
140
+ <summary><strong>Option 1: Using `uv`</strong></summary>
141
+ We recommend to use the `uv` tool, available at https://docs.astral.sh/uv/.
142
+ Run the following in the `shnitsel-tools` directory:
143
+
144
+ ```bash
145
+ uv venv # create an environment under ./.venv
146
+ . .venv/bin/activate # activate the new environment
147
+ uv pip install -e . # install shnitsel in editable mode
148
+ ```
149
+
150
+ To install the optional development dependencies run
151
+
152
+ ```bash
153
+ uv pip install -e '.[dev]'
154
+ ```
155
+
156
+ </details>
157
+
158
+ <details open>
159
+ <summary><strong>Option 2: Using the `--prefix` Flag</strong></summary>
160
+
161
+ You can create the environment and specify the desired path by using the `conda env create` command with the `--prefix` flag:
162
+
163
+ ```bash
164
+ conda env create --prefix /home/user/anaconda3/envs/shnitsel -f shnitsel-tools.yml
165
+ ```
166
+ </details>
167
+
168
+ <details>
169
+ <summary><strong>Option 3: Adding the Path to the .yml File</strong></summary>
170
+
171
+ Alternatively, you can manually add the desired path to the shnitsel-tools.yml file and create the environment directly:
172
+
173
+ 1) Open the shnitsel-tools.yml file for editing:
174
+
175
+ ```bash
176
+ vi shnitsel-tools.yml
177
+ ```
178
+
179
+ 2) Add the following line to the file:
180
+
181
+
182
+ ```
183
+ prefix: /home/user/anaconda3/envs/shnitsel
184
+ ```
185
+
186
+ 3) Create the environment with a custom path.
187
+
188
+ ```bash
189
+ conda env create -f shnitsel-tools.yml
190
+ ```
191
+ </details>
192
+
193
+ ## Further Information
194
+
195
+ [![Website](https://img.shields.io/badge/Website-shnitsel.github.io-yellow.svg)](https://shnitsel.github.io/)
196
+
197
+
@@ -0,0 +1,157 @@
1
+ <div align="center">
2
+ <h1>shnitsel-tools</h1>
3
+ <img src="https://github.com/SHNITSEL/shnitsel-tools/blob/main/logo_shnitsel_tools.png" alt="SHNITSEL-TOOLS Logo" width="200px">
4
+ <h3>Surface Hopping Nested Instances Training Set for Excited-state Learning Tools</h3>
5
+ <br>
6
+ <a href="https://shnitsel.github.io/">
7
+ <img src="https://img.shields.io/badge/Website-shnitsel.github.io-yellow.svg" alt="DOI">
8
+ </a>
9
+ </div>
10
+
11
+ --------------------
12
+
13
+ ## About
14
+
15
+ `shnitsel-tools` is designed to support the entire data lifecycle of surface hopping (SH) trajectory data upon simulation: data management, storage, processing, visualization and interpretation.
16
+ The tool is compatible with surface hopping data generated using the software packages [SHARC 3/4](https://sharc-md.org/), [Newton-X](https://newtonx.org/), and [PyRAI2MD](https://github.com/lopez-lab/PyRAI2MD).
17
+ The `shnitsel` code leverages Xarray to benefit from efficient multidimensional data handling, improved metadata management, and a structure that aligns naturally with the needs of quantum chemical datasets.
18
+
19
+ ## Usage
20
+
21
+ The package is organized into ...
22
+
23
+ > [!TIP]
24
+ > ### Tutorials
25
+ >
26
+ > For a quick start, see the [tutorials](https://github.com/SHNITSEL/shnitsel-tools/blob/main/tutorials) directory,
27
+ > which contains Jupyter Notebooks showing the workflow for parsing, writing and loading SHNITSEL databases as well as how to postprocess and visualize the respective data.
28
+ >
29
+ > #### Collection & storage
30
+ > - [parsing trajectory and initial condition data obtained by SHARC](https://github.com/SHNITSEL/shnitsel-tools/blob/main/tutorials/0_1_sharc2hdf5.ipynb)
31
+ > - [parsing trajectory data produced with Newton-X](https://github.com/SHNITSEL/shnitsel-tools/blob/main/tutorials/0_2_nx2hdf5.ipynb)
32
+ > - [convert ASE databases](https://github.com/SHNITSEL/shnitsel-tools/blob/main/tutorials/0_4_ase2hdf5.ipynb)
33
+ >
34
+ > #### Management
35
+ >
36
+ > [exploration of electronic properties](https://github.com/SHNITSEL/shnitsel-tools/blob/main/tutorials/2_2_PS_explore.ipynb)
37
+ >
38
+ > #### Postprocessing & visualization of data
39
+ > - [datasheet for trajectory data](https://github.com/SHNITSEL/shnitsel-tools/blob/main/tutorials/3_1_datasheet.ipynb)
40
+ > - [principal component analysis and trajectory classification](https://github.com/SHNITSEL/shnitsel-tools/blob/main/tutorials/1_1_GS_PCA.ipynb)
41
+ >
42
+ > #### Walkthroughs
43
+ >
44
+ > (Link to folder, further info in README)
45
+ >
46
+
47
+ ```bash
48
+ shnitsel
49
+ ├── core
50
+ │   ├── ase.py
51
+ │   ├── datasheet
52
+ │   │   ├── __init__.py
53
+ │   │   ├── oop.py
54
+ │   │   └── plot
55
+ │   │   ├── colormaps.py
56
+ │   │   ├── common.py
57
+ │   │   ├── dip_trans_hist.py
58
+ │   │   ├── hist.py
59
+ │   │   ├── __init__.py
60
+ │   │   ├── nacs_hist.py
61
+ │   │   ├── per_state_hist.py
62
+ │   │   ├── structure.py
63
+ │   │   └── time.py
64
+ │   ├── filter_unphysical.py
65
+ │   ├── filtre.py
66
+ │   ├── indexes.py
67
+ │   ├── __init__.py
68
+ │   ├── parse
69
+ │   │   ├── common.py
70
+ │   │   ├── __init__.py
71
+ │   │   ├── nx.py
72
+ │   │   ├── pyrai2md.py
73
+ │   │   ├── sharc_icond.py
74
+ │   │   ├── sharc_traj.py
75
+ │   │   └── xyz.py
76
+ │   ├── pca_biplot.py
77
+ │   ├── plot
78
+ │   │   ├── __init__.py
79
+ │   │   ├── kde.py
80
+ │   │   ├── p3mhelpers.py
81
+ │   │   ├── select.py
82
+ │   │   └── spectra3d.py
83
+ │   ├── plotting.py
84
+ │   ├── postprocess.py
85
+ │   ├── spectra.py
86
+ │   └── xrhelpers.py
87
+ ├── __init__.py
88
+ ├── plot
89
+ │   └── __init__.py
90
+ ├── rd.py
91
+ ├── README.md
92
+ └── xarray.py
93
+ ```
94
+
95
+ ## Installation
96
+
97
+ You can create the environment with a custom path using one of the following methods:
98
+
99
+ <details open>
100
+ <summary><strong>Option 1: Using `uv`</strong></summary>
101
+ We recommend to use the `uv` tool, available at https://docs.astral.sh/uv/.
102
+ Run the following in the `shnitsel-tools` directory:
103
+
104
+ ```bash
105
+ uv venv # create an environment under ./.venv
106
+ . .venv/bin/activate # activate the new environment
107
+ uv pip install -e . # install shnitsel in editable mode
108
+ ```
109
+
110
+ To install the optional development dependencies run
111
+
112
+ ```bash
113
+ uv pip install -e '.[dev]'
114
+ ```
115
+
116
+ </details>
117
+
118
+ <details open>
119
+ <summary><strong>Option 2: Using the `--prefix` Flag</strong></summary>
120
+
121
+ You can create the environment and specify the desired path by using the `conda env create` command with the `--prefix` flag:
122
+
123
+ ```bash
124
+ conda env create --prefix /home/user/anaconda3/envs/shnitsel -f shnitsel-tools.yml
125
+ ```
126
+ </details>
127
+
128
+ <details>
129
+ <summary><strong>Option 3: Adding the Path to the .yml File</strong></summary>
130
+
131
+ Alternatively, you can manually add the desired path to the shnitsel-tools.yml file and create the environment directly:
132
+
133
+ 1) Open the shnitsel-tools.yml file for editing:
134
+
135
+ ```bash
136
+ vi shnitsel-tools.yml
137
+ ```
138
+
139
+ 2) Add the following line to the file:
140
+
141
+
142
+ ```
143
+ prefix: /home/user/anaconda3/envs/shnitsel
144
+ ```
145
+
146
+ 3) Create the environment with a custom path.
147
+
148
+ ```bash
149
+ conda env create -f shnitsel-tools.yml
150
+ ```
151
+ </details>
152
+
153
+ ## Further Information
154
+
155
+ [![Website](https://img.shields.io/badge/Website-shnitsel.github.io-yellow.svg)](https://shnitsel.github.io/)
156
+
157
+
@@ -0,0 +1,103 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "shnitsel-tools"
7
+ version = "0.0.1"
8
+ authors = [
9
+ { name="Robin Curth", email="robin.curth@uni-leipzig.de" },
10
+ { name="Theodor Everley Röhrkasten", email="theodor.roehrkasten@fau.de" },
11
+ { name="Carolin Müller", email="carolin.cpc.mueller@fau.de" },
12
+ { name="Julia Westermayr", email="julia.westermayr@uni-leipzig.de" },
13
+
14
+ ]
15
+ description = "Display and interpret output of SHARC"
16
+ readme = "README.md"
17
+ requires-python = ">=3.10"
18
+ classifiers = [
19
+ "Programming Language :: Python :: 3",
20
+ "License :: OSI Approved :: MIT License",
21
+ "Operating System :: OS Independent",
22
+ ]
23
+ dependencies = [
24
+ "scipy>=1.11.4",
25
+ "matplotlib",
26
+ "xarray",
27
+ "rdkit",
28
+ "ase",
29
+ "numpy>=1.26.2",
30
+ "pandas>=2.1.4",
31
+ "h5netcdf>=1.3.0",
32
+ "scikit-learn>=1.3.2",
33
+ "jupyter>=1.0.0",
34
+ "py3dmol>=2.0.4",
35
+ "tqdm>=4.66.2",
36
+ ]
37
+
38
+ [project.optional-dependencies]
39
+ extras = [
40
+ "plotly>=5.22.0",
41
+ "seaborn>=0.13.1",
42
+ "bokeh"
43
+ ]
44
+
45
+ dev = [
46
+ "pytest>=5.2",
47
+ "pytest-cov>=4.1.0",
48
+ "hypothesis", # ==6.127.3
49
+ "tox>=4.10.0",
50
+ "pre-commit>=2.20.0",
51
+ "ruff==0.4.1",
52
+ "mypy>=1.9.0",
53
+ "pandas-stubs",
54
+ "scipy-stubs",
55
+ "types-tqdm",
56
+ ]
57
+
58
+ [tool.setuptools.packages.find]
59
+ where = ["."]
60
+ include = ["shnitsel*"]
61
+
62
+ [tool.ruff]
63
+ include = ["shnitsel/**/*.py"]
64
+
65
+ [tool.ruff.format]
66
+ quote-style = "preserve"
67
+
68
+ [tool.ruff.lint]
69
+ ignore = [
70
+ "E731", # Do not assign a `lambda` expression, use a `def`
71
+ "E741", # Ambiguous variable name
72
+ "E70", # Multiple statements on one line
73
+ ]
74
+
75
+ [tool.ruff.lint.flake8-annotations]
76
+ ignore-fully-untyped = true
77
+
78
+ [tool.mypy]
79
+ packages = "shnitsel"
80
+ warn_return_any = true
81
+ warn_unused_configs = true
82
+ check_untyped_defs = true
83
+ pretty = true
84
+ color_output = true
85
+ error_summary = true
86
+
87
+ [[tool.mypy.overrides]]
88
+ module = [
89
+ "scipy",
90
+ "sklearn.*"
91
+ ]
92
+ ignore_missing_imports = true
93
+
94
+ [tool.tox]
95
+ min_version = "4.19"
96
+ env_list = ["py313", "py312", "py311", "py310"]
97
+
98
+ [tool.tox.testenv]
99
+ deps = ["pytest"]
100
+ commands = [
101
+ ["python", "--version"],
102
+ ["pytest"]
103
+ ]
@@ -0,0 +1,4 @@
1
+ [egg_info]
2
+ tag_build =
3
+ tag_date = 0
4
+
@@ -0,0 +1,14 @@
1
+ from shnitsel import (
2
+ core as core,
3
+ plot as plot,
4
+ )
5
+ from shnitsel.core import (
6
+ parse as parse,
7
+ postprocess as postprocess,
8
+ xrhelpers as xrhelpers,
9
+ )
10
+ from shnitsel.core.xrhelpers import open_frames as open_frames
11
+ from shnitsel.core.parse import read_trajs as read_trajs
12
+ from shnitsel.core.ase import read_ase as read_ase
13
+
14
+ __all__ = ['plot', 'parse', 'open_frames', 'read_trajs', 'read_ase']
@@ -0,0 +1,39 @@
1
+ from . import (
2
+ ase,
3
+ filter_unphysical,
4
+ parse,
5
+ plot,
6
+ postprocess,
7
+ xrhelpers,
8
+ )
9
+
10
+ from .xrhelpers import (
11
+ open_frames as open_frames,
12
+ save_frames as save_frames,
13
+ )
14
+ from .postprocess import (
15
+ dihedral as dihedral,
16
+ get_per_state as get_per_state,
17
+ get_inter_state as get_inter_state,
18
+ assign_fosc as assign_fosc,
19
+ )
20
+
21
+ from .plot import pca_biplot
22
+ from .plot.spectra3d import spectra_all_times as spectra_all_times
23
+
24
+ __all__ = [
25
+ 'ase',
26
+ 'parse',
27
+ 'postprocess',
28
+ 'xrhelpers',
29
+ 'filter_unphysical',
30
+ 'pca_biplot',
31
+ 'plot',
32
+ 'open_frames',
33
+ 'save_frames',
34
+ 'dihedral',
35
+ 'get_per_state',
36
+ 'get_inter_state',
37
+ 'assign_fosc',
38
+ 'spectra_all_times',
39
+ ]
@@ -0,0 +1,148 @@
1
+ import os
2
+ from typing import Collection
3
+
4
+ from ase import Atoms
5
+ from ase.db import connect
6
+ import numpy as np
7
+ import xarray as xr
8
+
9
+
10
def _prepare_for_write(frames: xr.Dataset) -> xr.Dataset:
    """Return a shallow copy of ``frames`` with all dipole data under ``'dipoles'``.

    Permanent (``'dip_perm'``) and transition (``'dip_trans'``) dipoles are
    recombined into a single ``'dipoles'`` variable, as schnetpack expects,
    with dimensions ``('frame', 'state_or_statecomb', 'direction')``.

    Parameters
    ----------
    frames
        Dataset of frames; it is not modified, a shallow copy is returned.

    Returns
    -------
    The shallow copy. If it contained ``'dip_perm'`` and/or ``'dip_trans'``
    (and no ``'dipoles'``), those variables are removed and replaced by
    ``'dipoles'``. If no dipole variable is present, nothing is added.
    """
    merged_dims = ['frame', 'state_or_statecomb', 'direction']
    frames = frames.copy(deep=False)

    # Always hold the raw array rather than a DataArray: the (dims, data)
    # assignment below relabels the axes, so the old dimension names and
    # coordinates of the source variable must not be carried along.
    dipoles = None
    if 'dipoles' in frames:
        dipoles = frames['dipoles'].data
    elif 'dip_perm' in frames and 'dip_trans' in frames:
        dip_perm = frames['dip_perm'].transpose('frame', 'state', 'direction').data
        dip_trans = (
            frames['dip_trans'].transpose('frame', 'statecomb', 'direction').data
        )
        # np.concatenate is available in every supported NumPy version,
        # whereas the np.concat alias only exists from NumPy 2.0 onwards.
        dipoles = np.concatenate((dip_perm, dip_trans), axis=1)
        del frames['dip_perm'], frames['dip_trans']
    elif 'dip_perm' in frames:
        # Transpose first so the positional relabelling below is correct
        # regardless of the axis order the variable was stored in.
        dipoles = frames['dip_perm'].transpose('frame', 'state', 'direction').data
        del frames['dip_perm']
    elif 'dip_trans' in frames:
        dipoles = (
            frames['dip_trans'].transpose('frame', 'statecomb', 'direction').data
        )
        del frames['dip_trans']

    if dipoles is not None:
        frames['dipoles'] = merged_dims, dipoles

    return frames
34
+
35
+
36
def write_ase(
    frames: xr.Dataset,
    db_path: str,
    kind: str | None,
    keys: Collection | None = None,
    preprocess: bool = True,
):
    """Write a dataset of frames to an ASE database, one row per frame.

    Parameters
    ----------
    frames
        Dataset of frames; must contain ``'atNames'`` and ``'atXYZ'``.
        The dataset passed in is not modified.
    db_path
        Path of the database to write. An existing file at this path is
        removed first.
    kind
        ``'schnet'`` or ``'spainn'`` to reorder array axes for the respective
        format, or ``None`` to leave the axis orders as they are.
    keys
        Names of the data variables to store; if empty or ``None``, all data
        variables are stored. ``'atNames'`` is never stored as row data.
    preprocess
        Whether to recombine permanent and transition dipoles into a single
        ``'dipoles'`` variable before writing.

    Raises
    ------
    ValueError
        If `kind` is not one of 'schnet', 'spainn' or None
    """
    if preprocess:
        frames = _prepare_for_write(frames)
    else:
        # The 'spainn' branch assigns to the dataset; work on a shallow copy
        # so the caller's object is never modified.
        frames = frames.copy(deep=False)

    statedims = ['state', 'statecomb', 'state_or_statecomb']
    if kind == 'schnet':
        order = ['frame', *statedims, 'atom', 'direction']
        frames = frames.transpose(*order, missing_dims='ignore')
    elif kind == 'spainn':
        # SPaiNN stores energies with an extra axis ahead of the state axis.
        frames['energy'] = frames['energy'].expand_dims('tmp', axis=1)
        order = ['frame', 'tmp', 'atom', *statedims, 'direction']
        frames = frames.transpose(*order, missing_dims='ignore')
    elif kind is not None:
        raise ValueError(
            f"'kind' should be one of 'schnet', 'spainn' or None, not '{kind}'"
        )

    if os.path.exists(db_path):
        os.remove(db_path)

    # Select the requested variables in dataset order, so that the layout of
    # each row's data dict is deterministic.
    requested = set(keys) if keys else None
    data_keys = [
        name
        for name in frames.data_vars
        if name != 'atNames' and (requested is None or name in requested)
    ]

    with connect(db_path, type='db') as db:
        for _, frame in frames.groupby('frame'):
            frame = frame.squeeze('frame')
            db.write(
                Atoms(symbols=frame['atNames'].data, positions=frame['atXYZ']),
                data={k: frame[k].data for k in data_keys},
            )
76
+
77
+
78
def read_ase(db_path: str, kind: str):
    """Reads an ASE DB containing data in the SPaiNN or SchNet format

    Parameters
    ----------
    db_path
        Path to the database
    kind
        Must be one of 'spainn' or 'schnet'; determines interpretation of array shapes

    Returns
    -------
    An `xr.Dataset` of frames

    Raises
    ------
    ValueError
        If `kind` is not one of 'spainn' or 'schnet', or if the database
        contains no rows
    FileNotFoundError
        If `db_path` is not a file
    KeyError
        If the database contains 'dipoles' but no 'energy', so that the
        number of states needed to split the dipoles is unknown
    """
    if kind == 'schnet':
        shapes = {
            'energy': ['frame', 'state'],
            'forces': ['frame', 'state', 'atom', 'direction'],
            'nacs': ['frame', 'statecomb', 'atom', 'direction'],
            'dipoles': ['frame', 'state_or_statecomb', 'direction'],
        }
    elif kind == 'spainn':
        shapes = {
            'energy': ['frame', 'tmp', 'state'],  # Note the extra dim, removed below
            'forces': ['frame', 'atom', 'state', 'direction'],
            'nacs': ['frame', 'atom', 'statecomb', 'direction'],
            'dipoles': ['frame', 'state_or_statecomb', 'direction'],
        }
    else:
        raise ValueError(f"'kind' should be one of 'schnet' or 'spainn', not '{kind}'")

    if not os.path.isfile(db_path):
        raise FileNotFoundError(db_path)

    with connect(db_path) as db:
        # Scan the database once instead of once per property.
        rows = list(db.select())
        if not rows:
            raise ValueError(f"ASE database '{db_path}' contains no rows")

        data_vars = {}
        for name, dims in shapes.items():
            try:
                stacked = np.stack([row.data[name] for row in rows])
            except KeyError:
                # This property was not stored in the database; skip it.
                continue
            data_vars[name] = dims, stacked

        atXYZ = np.stack([row.positions for row in rows])
        data_vars['atXYZ'] = ['frame', 'atom', 'direction'], atXYZ
        atNames = ['atom'], rows[0].symbols

    if 'dipoles' in data_vars:
        if 'energy' not in data_vars:
            raise KeyError(
                "'energy' is required to split 'dipoles' into permanent and "
                "transition dipoles, but is missing from the database"
            )
        # Look up the state axis by name: in the 'spainn' layout axis 1 of
        # the energies is the size-1 'tmp' axis, not the state axis.
        energy_dims, energy = data_vars['energy']
        nstates = energy.shape[energy_dims.index('state')]

        dipoles = data_vars.pop('dipoles')[1]
        dip_perm = dipoles[:, :nstates, :]
        dip_trans = dipoles[:, nstates:, :]

        data_vars['dip_perm'] = ['frame', 'state', 'direction'], dip_perm
        data_vars['dip_trans'] = ['frame', 'statecomb', 'direction'], dip_trans

    frames = xr.Dataset(data_vars).assign_coords(atNames=atNames)
    # 'tmp' only exists if SPaiNN-style energies were present in the database.
    if kind == 'spainn' and 'tmp' in frames.dims:
        frames = frames.squeeze('tmp')

    return frames