off 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- off/__init__.py +23 -0
- off/atom_energies.py +151 -0
- off/config/_config.py +108 -0
- off/dft_distrax/__init__.py +27 -0
- off/dft_distrax/dft_distrax.py +216 -0
- off/flow/__init__.py +29 -0
- off/flow/equiv_flows.py +99 -0
- off/functionals/__init__.py +35 -0
- off/functionals/core_correction.py +84 -0
- off/functionals/exchange_correlation.py +174 -0
- off/functionals/external.py +49 -0
- off/functionals/functional.py +129 -0
- off/functionals/hartree.py +62 -0
- off/functionals/kinetic.py +87 -0
- off/main.py +172 -0
- off/ode_solver/__init__.py +32 -0
- off/ode_solver/eqx_ode.py +76 -0
- off/plot_binding_csv.py +63 -0
- off/plot_pes_ema.py +259 -0
- off/plot_pes_mpl.py +280 -0
- off/promolecular/__init__.py +27 -0
- off/promolecular/promolecular_dist.py +465 -0
- off/quadrature.py +261 -0
- off/quadrature_scan.py +188 -0
- off/scan_pes.py +133 -0
- off/test_fwd_rev.py +290 -0
- off/train/__init__.py +44 -0
- off/train/loop.py +228 -0
- off/train/loss.py +149 -0
- off/train/utils.py +38 -0
- off/utils.py +618 -0
- off-0.1.0.dist-info/METADATA +154 -0
- off-0.1.0.dist-info/RECORD +37 -0
- off-0.1.0.dist-info/WHEEL +5 -0
- off-0.1.0.dist-info/entry_points.txt +3 -0
- off-0.1.0.dist-info/licenses/LICENSE +21 -0
- off-0.1.0.dist-info/top_level.txt +1 -0
off/utils.py
ADDED
|
@@ -0,0 +1,618 @@
|
|
|
1
|
+
from typing import Any, Callable, Literal, Union
|
|
2
|
+
from functools import partial
|
|
3
|
+
|
|
4
|
+
import jax
|
|
5
|
+
import jax.numpy as jnp
|
|
6
|
+
import jax.random as jrnd
|
|
7
|
+
from jax import jit,grad,vmap,hessian,lax
|
|
8
|
+
from jaxtyping import Array
|
|
9
|
+
from jax._src import prng
|
|
10
|
+
|
|
11
|
+
from pyscf.data.nist import BOHR
|
|
12
|
+
|
|
13
|
+
import optax
|
|
14
|
+
|
|
15
|
+
from diffrax import (
|
|
16
|
+
Euler, Heun, Midpoint, Ralston, Bosh3, Tsit5, Dopri5, Dopri8,
|
|
17
|
+
ImplicitEuler, Kvaerno3, Kvaerno4, Kvaerno5, Sil3,
|
|
18
|
+
KenCarp3, KenCarp4, KenCarp5, SemiImplicitEuler,
|
|
19
|
+
ReversibleHeun, LeapfrogMidpoint
|
|
20
|
+
)
|
|
21
|
+
|
|
22
|
+
# Placeholder alias for dtype-typed arguments; it places no constraint.
Dtype = Any

# Solver names accepted by ``get_solver``; each one is a key of the
# name -> Diffrax class table built there.
SolverType = Literal[
    'euler', 'heun', 'midpoint', 'ralston', 'bosh3', 'tsit5', 'dopri5',
    'dopri8', 'implicit_euler', 'kvaerno3', 'kvaerno4', 'kvaerno5', 'sil3',
    'ken_carp3', 'ken_carp4', 'ken_carp5', 'semi_implicit_euler',
    'reversible_heun', 'leapfrog_midpoint',
]
|
|
32
|
+
|
|
33
|
+
@partial(jit, static_argnums=(2,))
def compute_integral(params: Any, grid_array: Any, rho: Any, Ne: int):
    """Integrate ``rho(params, .)`` over a weighted quadrature grid.

    Parameters
    ----------
    params : Any
        Passed unchanged as the first argument of ``rho``.
    grid_array : Any
        Pair ``(grid_coords, grid_weights)``.
    rho : callable
        ``rho(params, grid_coords)`` returning values that pair with
        ``grid_weights``. Static under ``jit``, so it must be hashable.
    Ne : int
        Not used by the computation; kept for call-site compatibility.

    Returns
    -------
    jax.Array
        Scalar ``vdot(grid_weights, rho(params, grid_coords))``.
    """
    coords, weights = grid_array
    return jnp.vdot(weights, rho(params, coords))
|
|
38
|
+
|
|
39
|
+
@partial(jit, static_argnums=(2,))
def laplacian(params: Any, X: Array, fun: callable) -> jax.Array:
    """Laplacian of ``fun`` with respect to its second argument, per row of ``X``.

    For every row ``x`` of ``X`` this evaluates the trace of the Hessian of
    ``fun(params, .)`` taken with respect to the point argument.

    Parameters
    ----------
    params : Any
        Passed unchanged as the first argument of ``fun``; not differentiated.
    X : Array
        Points to evaluate; the computation is vmapped over the leading axis.
    fun : callable
        ``fun(params, x)`` where ``x`` is a single point with a leading batch
        axis of size one, i.e. ``x[jnp.newaxis]``. Static under ``jit``, so it
        must be hashable.
        NOTE(review): squeezing Hessian axes (0, 2, 4) only succeeds if ``fun``
        returns a ``(1, 1)``-shaped output for such an input -- confirm
        against callers.

    Returns
    -------
    jax.Array
        One Laplacian per row of ``X``; under the ``(1, 1)`` output
        assumption above the result has shape ``(len(X), 1)``.
    """
    hessian_x = hessian(fun, argnums=1)

    def _trace_of_hessian(p, x):
        # Re-add the batch axis ``fun`` expects, drop the singleton axes,
        # then sum the diagonal of the remaining trailing (d, d) block.
        h = jnp.squeeze(hessian_x(p, jnp.expand_dims(x, 0)), axis=(0, 2, 4))
        return jnp.trace(h, axis1=-2, axis2=-1)

    return vmap(_trace_of_hessian, in_axes=(None, 0))(params, X)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
@partial(jit, static_argnums=(2,))
def score(params: Any, X: Array, fun: callable) -> jax.Array:
    """Gradient of ``fun`` with respect to its point argument, per row of ``X``.

    Uses ``jax.jacrev`` on a one-point batch, flattens the Jacobian to a
    vector with the length of the point, and vmaps over the rows of ``X``.

    Parameters
    ----------
    params : Any
        Parameters of the model, passed unchanged as the first argument of ``fun``.
    X : Array
        Points at which to evaluate the score; vmapped over the leading axis.
    fun : callable
        ``fun(params, x)`` where ``x`` is a single point with a leading batch
        axis of size one. Its Jacobian must contain exactly ``x.shape[-1]``
        elements (a single scalar output per point) for the reshape to succeed.
        Static under ``jit``, so it must be hashable.

    Returns
    -------
    jax.Array
        The score of the model, one gradient vector per row of ``X``.
    """
    jac_x = jax.jacrev(fun, argnums=1)

    def _gradient_at(p, x):
        # ``fun`` sees a batch of one point; collapse the Jacobian back to a
        # plain gradient vector of length ``x.shape[0]``.
        return jnp.reshape(jac_x(p, jnp.expand_dims(x, 0)), x.shape[0])

    return vmap(_gradient_at, in_axes=(None, 0))(params, X)
|
|
97
|
+
|
|
98
|
+
def batch_generator(key: prng.PRNGKeyArray, batch_size: int, prior_dist):
    """Endless generator of ``[samples | log_prob | score]`` batches.

    Every yielded array stacks two independently sampled half-batches of
    ``batch_size`` draws each along axis 0, so it has ``2 * batch_size`` rows.
    Each row holds a sample, its ``prior_dist.log_prob`` value and its score.

    Parameters
    ----------
    key : prng.PRNGKeyArray
        Initial PRNG key. It is only ever split; sampling always uses fresh
        subkeys, so no key is both consumed by sampling and split again.
    batch_size : int
        Number of draws per half-batch (passed as ``sample_shape``).
    prior_dist
        Object with ``sample(seed=..., sample_shape=...)`` and ``log_prob``.
        If it has a ``score`` attribute, that callable produces the score
        columns; otherwise they are ``vmap(jacrev(prior_dist.log_prob))``.
        ``log_prob`` of a batch must be 2-D so it can be concatenated along
        axis 1 (presumably ``(batch_size, 1)`` -- confirm against priors).

    Yields
    ------
    jax.Array
        ``concatenate([half0, half1], axis=0)`` where each half is
        ``concatenate([samples, log_prob, score], axis=1)``.
    """
    # Prefer a prior-provided score; fall back to autodiff of log_prob.
    if hasattr(prior_dist, 'score'):
        v_score = prior_dist.score
    else:
        v_score = vmap(jax.jacrev(lambda x: prior_dist.log_prob(x)))

    def _half_batch(subkey):
        # One ``batch_size`` block of [samples | log_prob | score] columns.
        samples = prior_dist.sample(seed=subkey, sample_shape=batch_size)
        logp_samples = prior_dist.log_prob(samples)
        return lax.concatenate((samples, logp_samples, v_score(samples)), 1)

    while True:
        # Carry ``key`` forward for future splits and sample only with the
        # two fresh subkeys (JAX keys must never be reused).
        key, key0, key1 = jrnd.split(key, 3)
        yield lax.concatenate((_half_batch(key0), _half_batch(key1)), 0)
|
|
123
|
+
|
|
124
|
+
def batch_generator_(key: prng.PRNGKeyArray, batch_size: int, prior_dist: Callable):
    """
    Generator that yields batches of samples from the prior distribution.

    Each yielded array stacks two independently sampled half-batches of
    ``batch_size`` draws along axis 0 (``2 * batch_size`` rows). Every row is
    ``[sample | log_prob | score]``, where the score is
    ``grad(sum(prior_dist.log_prob(x)))`` evaluated per sample via ``vmap``.

    Parameters
    ----------
    key : prng.PRNGKeyArray
        Key to generate random numbers. It is only ever split; sampling uses
        fresh subkeys so no key is reused.
    batch_size : int
        Size of each half-batch (passed as ``sample_shape``).
    prior_dist : Callable
        Prior distribution exposing ``sample(seed=..., sample_shape=...)`` and
        ``log_prob``. ``log_prob`` of a batch must be 2-D so it can be
        concatenated along axis 1 (presumably ``(batch_size, 1)``).

    Yields
    ------
    jax.Array
        The stacked ``[samples | log_prob | score]`` batch.
    """
    v_score = jax.vmap(jax.grad(lambda x: prior_dist.log_prob(x).sum()))

    def _half_batch(subkey):
        # One ``batch_size`` block of [samples | log_prob | score] columns.
        samples = prior_dist.sample(seed=subkey, sample_shape=batch_size)
        logp_samples = prior_dist.log_prob(samples)
        return lax.concatenate((samples, logp_samples, v_score(samples)), 1)

    while True:
        # Carry ``key`` forward and sample only with fresh subkeys.
        key, key0, key1 = jrnd.split(key, 3)
        yield lax.concatenate((_half_batch(key0), _half_batch(key1)), 0)
|
|
157
|
+
|
|
158
|
+
def batche_generator_1D(key: prng.PRNGKeyArray, batch_size: int, prior_dist: Callable):
    """
    Generator that yields batches of samples from the prior distribution.

    Variant for priors whose ``log_prob`` returns a 1-D array for a batch:
    the log-density is expanded to a column (``[:, None]``) before being
    concatenated. Each yielded array stacks two independently sampled
    half-batches of ``batch_size`` draws (``2 * batch_size`` rows) of
    ``[sample | log_prob | score]`` with score ``vmap(jacrev(log_prob))``.

    Parameters
    ----------
    key : prng.PRNGKeyArray
        Key to generate random numbers. It is only ever split; sampling uses
        fresh subkeys so no key is reused.
    batch_size : int
        Size of each half-batch (passed as ``sample_shape``).
    prior_dist : Callable
        Prior distribution exposing ``sample(seed=..., sample_shape=...)`` and
        ``log_prob``. Samples are assumed 2-D, e.g. ``(batch_size, 1)``, so
        they concatenate along axis 1 -- confirm against priors.

    Yields
    ------
    jax.Array
        The stacked ``[samples | log_prob | score]`` batch.
    """
    v_score = vmap(jax.jacrev(lambda x: prior_dist.log_prob(x)))

    def _half_batch(subkey):
        # One ``batch_size`` block of [samples | log_prob | score] columns.
        samples = prior_dist.sample(seed=subkey, sample_shape=batch_size)
        logp_samples = prior_dist.log_prob(samples)
        return lax.concatenate(
            (samples, logp_samples[:, None], v_score(samples)), 1)

    while True:
        # Carry ``key`` forward and sample only with fresh subkeys.
        key, key0, key1 = jrnd.split(key, 3)
        yield lax.concatenate((_half_batch(key0), _half_batch(key1)), 0)
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def get_scheduler(epochs: int, sched_type: str = 'zero', lr: float = 3E-4):
    """Build an optax learning-rate schedule from a short name.

    Parameters
    ----------
    epochs : int
        Number of optimisation steps the schedule is laid out over.
    sched_type : str, default 'zero'
        Either any value accepted by ``float`` (e.g. ``'1e-3'``), giving a
        constant schedule at that value, or one of:

        * ``'zero'`` / ``'one'``: constant 0.0 / 1.0;
        * ``'const'`` / ``'c'``: constant ``lr``;
        * ``'cos_decay'``: ``warmup_cosine_decay_schedule`` with
          ``init_value=peak_value=lr``, 150 warm-up steps,
          ``decay_steps=epochs`` and ``end_value=1e-5``;
        * ``'mix'``: the same warm-up/cosine schedule with
          ``decay_steps=int(2*epochs/3)`` and ``end_value=1e-6``, switching
          to a constant 1e-6 at step ``2*epochs/3``;
        * ``'mix_old'``: constant ``lr`` until ``epochs/4``, then a
          ``cosine_onecycle_schedule`` until ``epochs/2``, then constant 1e-5.
    lr : float, default 3e-4
        Base learning rate used by the named schedules.

    Returns
    -------
    optax schedule
        Callable mapping a step count to a learning rate.

    Raises
    ------
    ValueError
        If ``sched_type`` is neither numeric nor a known schedule name.
    """
    try:
        value = float(sched_type)
    except ValueError:
        pass
    else:
        return optax.constant_schedule(value)

    if sched_type == 'zero':
        return optax.constant_schedule(0.0)
    if sched_type == 'one':
        return optax.constant_schedule(1.)
    if sched_type in ('const', 'c'):
        return optax.constant_schedule(lr)
    if sched_type == 'cos_decay':
        return optax.warmup_cosine_decay_schedule(
            init_value=lr,
            peak_value=lr,
            warmup_steps=150,
            decay_steps=epochs,
            end_value=1E-5,
        )
    if sched_type == 'mix':
        switch_step = 2 * epochs / 3
        init_scheduler_min = optax.warmup_cosine_decay_schedule(
            init_value=lr,
            peak_value=lr,
            warmup_steps=150,
            decay_steps=int(switch_step),
            end_value=1E-6,
        )
        constant_scheduler_max = optax.constant_schedule(1E-6)
        # join_schedules pairs boundaries with schedules[1:], so two
        # schedules need exactly one boundary.
        return optax.join_schedules(
            [init_scheduler_min, constant_scheduler_max],
            boundaries=[switch_step])
    if sched_type == 'mix_old':
        constant_scheduler_min = optax.constant_schedule(lr)
        cosine_decay_scheduler = optax.cosine_onecycle_schedule(
            transition_steps=epochs, peak_value=lr,
            div_factor=50., final_div_factor=1.)
        constant_scheduler_max = optax.constant_schedule(1E-5)
        return optax.join_schedules(
            [constant_scheduler_min, cosine_decay_scheduler,
             constant_scheduler_max],
            boundaries=[epochs/4, 2*epochs/4])

    # Previously fell through and returned None, failing later and obscurely.
    raise ValueError(
        f"Unknown scheduler type: {sched_type!r}. Expected a number or one of "
        "'zero', 'one', 'const', 'c', 'cos_decay', 'mix', 'mix_old'."
    )
|
|
233
|
+
|
|
234
|
+
def get_solver(solver_type: SolverType, **solver_kwargs) -> Union[
    Euler, Heun, Midpoint, Ralston, Bosh3, Tsit5, Dopri5, Dopri8,
    ImplicitEuler, Kvaerno3, Kvaerno4, Kvaerno5, Sil3,
    KenCarp3, KenCarp4, KenCarp5, SemiImplicitEuler,
    ReversibleHeun, LeapfrogMidpoint
]:
    """Instantiate a Diffrax solver from its name.

    Args:
        solver_type: Solver name; matched after lower-casing against the
            registry below.
        **solver_kwargs: Forwarded unchanged to the solver constructor.

    Returns:
        A new instance of the selected Diffrax solver class.

    Raises:
        ValueError: If the lower-cased name is not in the registry.
    """
    registry = dict(
        euler=Euler,
        heun=Heun,
        midpoint=Midpoint,
        ralston=Ralston,
        bosh3=Bosh3,
        tsit5=Tsit5,
        dopri5=Dopri5,
        dopri8=Dopri8,
        implicit_euler=ImplicitEuler,
        kvaerno3=Kvaerno3,
        kvaerno4=Kvaerno4,
        kvaerno5=Kvaerno5,
        sil3=Sil3,
        ken_carp3=KenCarp3,
        ken_carp4=KenCarp4,
        ken_carp5=KenCarp5,
        semi_implicit_euler=SemiImplicitEuler,
        reversible_heun=ReversibleHeun,
        leapfrog_midpoint=LeapfrogMidpoint,
    )

    wanted = solver_type.lower()
    if wanted not in registry:
        raise ValueError(f"Unknown solver type: {solver_type}. "
                         f"Available options: {list(registry)}")
    return registry[wanted](**solver_kwargs)
|
|
277
|
+
|
|
278
|
+
def correlation_polarization_correction(
    e_tilde_PF: Array,
    den: Array,
    clip_cte: float = 1e-30
):
    r"""Spin-polarization correction to a correlation functional.

    Implements eq. 2.75 of Carsten A. Ullrich, "Time-Dependent
    Density-Functional Theory":

    .. math::

        \tilde e = \tilde e_P
            + \alpha_c(r_s) \frac{f(\zeta)}{f''(0)} (1 - \zeta^4)
            + (\tilde e_F - \tilde e_P) f(\zeta) \zeta^4,

    with :math:`\zeta = (\rho_0 - \rho_1) / (\rho_0 + \rho_1)`,
    :math:`f(\zeta) = [(1+\zeta)^{4/3} + (1-\zeta)^{4/3} - 2] / (2^{4/3} - 2)`,
    :math:`r_s = (3 / (4\pi\rho))^{1/3}`, and :math:`\alpha_c` the
    parametrised spin stiffness built from the constants below.

    Parameters
    ----------
    e_tilde_PF: Float[Array, "grid 2"]
        Energy contributions on the grid to be combined; column 0 is the
        paramagnetic (P) value and column 1 the ferromagnetic (F) value.

    den: Float[Array, "grid 2"]
        The electronic density of each spin channel at each grid point.

    clip_cte: float, defaults to 1e-30
        Lower bound applied to the total density before taking logarithms.
        Grid points whose total density is not above it are treated as
        unpolarized (``zeta = 0``).

    Returns
    -------
    e_tilde: Float[Array, "grid"]
        The ready to be integrated electronic energy density.
    """
    rho_total = den.sum(axis=1)

    # Work in log2 so the powers of r_s below become sums of logarithms.
    log_rho = jnp.log2(jnp.maximum(rho_total, clip_cte))
    log_rs = jnp.log2((3 / (4 * jnp.pi)) ** (1 / 3)) - log_rho / 3.0

    # Relative spin polarization. The denominator is made safe *before* the
    # division: jnp.where still differentiates the discarded branch, and a
    # 0/0 there would turn every gradient into NaN.
    has_density = rho_total > clip_cte
    safe_total = jnp.where(has_density, rho_total, 1.0)
    zeta = jnp.where(has_density, (den[:, 0] - den[:, 1]) / safe_total, 0.0)

    def fzeta(z):
        # Plain powers instead of 2 ** (4/3 * log2(1 -/+ z)): mathematically
        # equal, but differentiable at full polarization (|z| = 1), where the
        # log form differentiates log2(0) and produces NaN gradients.
        zm = (1 - z) ** (4 / 3)
        zp = (1 + z) ** (4 / 3)
        return (zm + zp - 2) / (2 * (2 ** (1 / 3) - 1))

    # Spin-stiffness parametrisation constants.
    A_ = 0.016887
    alpha1 = 0.11125
    beta1 = 10.357
    beta2 = 3.6231
    beta3 = 0.88026
    beta4 = 0.49671

    # alpha1*r_s and beta_i*r_s^(i/2) evaluated through log2(r_s).
    ars = 2 ** (jnp.log2(alpha1) + log_rs)
    brs_1_2 = 2 ** (jnp.log2(beta1) + log_rs / 2)
    brs = 2 ** (jnp.log2(beta2) + log_rs)
    brs_3_2 = 2 ** (jnp.log2(beta3) + 3 * log_rs / 2)
    brs2 = 2 ** (jnp.log2(beta4) + 2 * log_rs)

    alphac = 2 * A_ * (1 + ars) * jnp.log(1 + (1 / (2 * A_)) / (brs_1_2 + brs + brs_3_2 + brs2))

    fz = fzeta(zeta)
    z4 = zeta**4

    e_tilde = (
        e_tilde_PF[:, 0]
        + alphac * (fz / (grad(grad(fzeta))(0.0))) * (1 - z4)
        + (e_tilde_PF[:, 1] - e_tilde_PF[:, 0]) * fz * z4
    )

    return e_tilde
|
|
342
|
+
|
|
343
|
+
def one_hot_encode(z: Array) -> Array:
    """
    One hot encode the input array.

    Each entry of ``z`` is mapped to the position of its value among the
    sorted unique values of ``z``, and that position is one-hot encoded.

    Parameters
    ----------
    z : Array
        1-D array to one hot encode (e.g. atomic numbers).

    Returns
    -------
    Array
        One hot encoded array of shape ``(len(z), number of unique values)``.
    """
    # return_inverse gives, for every entry, its index into the sorted unique
    # values -- replacing a Python double loop over device scalars.
    z_u, zz = jnp.unique(z, return_inverse=True)
    return jax.nn.one_hot(zz.reshape(-1), z_u.shape[0])
|
|
367
|
+
|
|
368
|
+
def coordinates(mol_name: str, bond_length: float = 1.4008538753, BOHR: float = 1.8897259886) -> Array:
    """Return ``(Ne, atoms, z, coords)`` for a named test system.

    Parameters
    ----------
    mol_name : str
        One of the isolated atoms ``'H'``, ``'He'``, ``'Li'``, ``'Be'``,
        ``'B'``, ``'C'``, ``'N'``, ``'O'``, ``'F'``, ``'Ne'``; the diatomics
        ``'H2'``, ``'N2'``, ``'O2'``, ``'F2'``, ``'CO'``, ``'HF'``, ``'LiH'``;
        the linear chain ``'H10'``; or the fixed geometries ``'H2O'``,
        ``'CH4'``, ``'C6H6'``, ``'C27H46O'``.
    bond_length : float
        Used only by the diatomics (nuclei at ``z = -bond_length/2`` and
        ``+bond_length/2``) and by ``'H10'`` (spacing between neighbours).
        It is not unit-converted; it is assumed to already be in Bohr.
    BOHR : float
        Factor applied to the hard-coded geometries of ``'H2O'``, ``'CH4'``,
        ``'C6H6'`` and ``'C27H46O'``. The default is ~1/0.529177, so it
        presumably converts Angstrom input to Bohr. Note that this parameter
        shadows any module-level ``BOHR``.

    Returns
    -------
    tuple
        ``Ne`` (int, number of electrons of the neutral system), ``atoms``
        (list of element symbols), ``z`` (array of nuclear charges; integer
        dtype for H2, LiH, H2O and CH4, float otherwise) and ``coords``
        (array of shape ``(n_atoms, 3)``).

    Raises
    ------
    ValueError
        If ``mol_name`` is not one of the systems listed above.
    """
    # Isolated neutral atoms at the origin: Ne equals the nuclear charge.
    atomic_numbers = {'H': 1, 'He': 2, 'Li': 3, 'Be': 4, 'B': 5,
                      'C': 6, 'N': 7, 'O': 8, 'F': 9, 'Ne': 10}
    if mol_name in atomic_numbers:
        Z = atomic_numbers[mol_name]
        return Z, [mol_name], jnp.array([float(Z)]), jnp.array([[0., 0., 0.]])

    # Diatomics on the z axis, centred at the origin. H2 and LiH list
    # integer charges and the rest floats, which fixes the dtype of ``z``.
    diatomics = {
        'H2': (['H', 'H'], [1, 1]),
        'N2': (['N', 'N'], [7., 7.]),
        'O2': (['O', 'O'], [8., 8.]),
        'F2': (['F', 'F'], [9., 9.]),
        'CO': (['C', 'O'], [6., 8.]),
        'HF': (['H', 'F'], [1., 9.]),
        'LiH': (['Li', 'H'], [3, 1]),
    }
    if mol_name in diatomics:
        atoms, charges = diatomics[mol_name]
        coords = jnp.array([[0., 0., -bond_length/2],
                            [0., 0., bond_length/2]])
        return int(sum(charges)), atoms, jnp.array(charges), coords

    if mol_name == 'H10':
        # Evenly spaced hydrogen chain along z, centred at the origin.
        n_h = 10
        offsets = (jnp.arange(n_h) - (n_h - 1) / 2.) * bond_length
        coords = jnp.stack([jnp.zeros(n_h),
                            jnp.zeros(n_h),
                            offsets], axis=1)
        return n_h, ['H'] * n_h, jnp.ones(n_h), coords

    # Fixed geometries below are scaled by BOHR.
    if mol_name == 'H2O':
        atoms = ['O', 'H', 'H']
        coords = jnp.array([[0.0, 0.0, 0.1189120],
                            [0.0, 0.7612710, -0.4756480],
                            [0.0, -0.7612710, -0.4756480]])*BOHR
        z = jnp.array([8, 1, 1])
        return 10, atoms, z, coords

    if mol_name == 'CH4':
        atoms = ['C', 'H', 'H', 'H', 'H']
        coords = jnp.array([[0.0, 0.0, 0.0],
                            [0.0, 0.0, 1.09],
                            [1.03, 0.0, -0.363],
                            [-0.515, -0.889165, -0.363],
                            [-0.515, 0.889165, -0.363]])*BOHR
        z = jnp.array([6, 1, 1, 1, 1])
        return 10, atoms, z, coords

    if mol_name == 'C6H6':
        atoms = ['C'] * 6 + ['H'] * 6
        coords = jnp.array([[-0.6984022192, 1.2096794375, 0.0001298085],
                            [0.6983971652, 1.2096794375, 0.0001298085],
                            [1.3968148123, 0.0000100819, 0.0001298085],
                            [0.6983970907, -1.2096631450, -0.0001577731],
                            [-0.6984347101, -1.2097236297, -0.0001117213],
                            [-1.3967427862, -0.0000316846, -0.0000531302],
                            [2.4989957176, 0.0000352657, 0.0001433155],
                            [1.2492115488, 2.1643070022, 0.0000608098],
                            [-1.2497352350, 2.1640355443, 0.0000873078],
                            [1.2495192781, -2.1641211865, -0.0004652006],
                            [-1.2494488074, -2.1642720181, -0.0003281198],
                            [-2.4988922798, 0.0006052813, -0.0002941369]])*BOHR
        z = jnp.array([6.] * 6 + [1.] * 6)
        return 42, atoms, z, coords

    if mol_name == 'C27H46O':
        # Atom order: 1 O, then 27 C, then 46 H (matches ``coords`` rows).
        atoms = ['O'] + ['C'] * 27 + ['H'] * 46
        coords = jnp.array([[-7.6344919379, 0.4597041902, 0.7112354681],
                            [1.1080808630, 0.4597041902, 0.7112354681],
                            [0.2530014327, -0.8231352339, 0.7112354681],
                            [-1.1599473252, -0.6092757486, 1.2632436184],
                            [-1.8846097300, 0.4851364826, 0.4266238676],
                            [2.4346065739, -0.1264931117, 0.1695905830],
                            [0.4237780146, 1.4699206954, -0.2202286076],
                            [-3.3372747943, 0.7709975410, 0.9822345823],
                            [1.1534637485, -1.8591349538, 1.3772546332],
                            [-1.0194512849, 1.7601151269, 0.2271214576],
                            [2.5719892229, -1.4735213405, 0.9177695847],
                            [-1.9616876023, -1.9189024556, 1.2281160480],
                            [3.6706406723, 0.7518022510, 0.2998092579],
                            [1.3056592864, 1.0858012443, 2.1142034561],
                            [-4.0659099987, -0.5409527880, 1.3058897583],
                            [-4.1355207853, 1.5407857692, -0.1122402103],
                            [-3.4345565189, -1.7244824608, 1.3955388362],
                            [-3.2883872448, 1.6217121992, 2.2742973586],
                            [-5.5600747780, -0.4657287448, 1.5202105354],
                            [-5.6342682740, 1.6532127465, 0.1635040190],
                            [4.9369027993, 0.0920796760, -0.2857409192],
                            [-6.2593559003, 0.2808578004, 0.3875037320],
                            [3.4346199063, 2.0977159171, -0.3972787505],
                            [4.8597222067, -0.3194422440, -1.7619595329],
                            [6.1183294840, -1.0351005847, -2.2632694976],
                            [6.0683625521, -1.3844022946, -3.7635312418],
                            [5.8475871083, -0.1546334232, -4.6458023942],
                            [4.9580276609, -2.4067800912, -4.0202930776],
                            [0.1320531554, -1.1546414208, -0.3341704111],
                            [-1.0961065494, -0.2887741191, 2.3091814564],
                            [-2.0071557590, 0.0577891770, -0.5822428933],
                            [2.2692308941, -0.3566230816, -0.8919397659],
                            [0.4024909308, 1.0868866903, -1.2485637383],
                            [0.9326111775, 2.4332063069, -0.2550387088],
                            [0.9000265464, -2.8793549884, 1.0717133503],
                            [1.0798371703, -1.8121549020, 2.4694535032],
                            [-0.9810389048, 2.3462405294, 1.1499104917],
                            [-1.4838623070, 2.4059394028, -0.5267205531],
                            [3.2317842936, -1.3989125509, 1.7896882295],
                            [2.9883545370, -2.2408492672, 0.2555163717],
                            [-1.6014174036, -2.5802367118, 2.0254337925],
                            [-1.7980739690, -2.4474317007, 0.2806843164],
                            [3.8879474770, 0.9526354546, 1.3563558574],
                            [1.8492781899, 2.0336135308, 2.0610078045],
                            [0.3578137484, 1.3050839833, 2.6129243364],
                            [1.8665627197, 0.4298627984, 2.7876605918],
                            [-4.0090347972, 1.0378379064, -1.0807548937],
                            [-3.7338443144, 2.5534931648, -0.2361852128],
                            [-4.0082696014, -2.6236692902, 1.6068637872],
                            [-2.9172441072, 2.6329927094, 2.0790411399],
                            [-4.2766956463, 1.7447530584, 2.7297022229],
                            [-2.6449795846, 1.1683767430, 3.0353862870],
                            [-6.0051645049, -1.4649829974, 1.6125803553],
                            [-5.7715116059, 0.0256833216, 2.4782609823],
                            [-5.8206294845, 2.3008198064, 1.0291274274],
                            [-6.1301728674, 2.1483087541, -0.6807530020],
                            [5.7878953777, 0.7733743373, -0.1540084782],
                            [5.1800431538, -0.7964447756, 0.3105080195],
                            [-6.2070878208, -0.3024479474, -0.5399782530],
                            [2.9407650699, 1.9873110720, -1.3676590681],
                            [4.3895302067, 2.6050134027, -0.5818686450],
                            [2.8705286939, 2.7872275940, 0.2335427866],
                            [4.0058535847, -0.9856353615, -1.9055309667],
                            [4.6772717683, 0.5793866627, -2.3581606350],
                            [-8.0203071775, -0.4225513077, 0.8470435952],
                            [6.2790374269, -1.9486777435, -1.6774593986],
                            [6.9885972888, -0.3917002684, -2.0822296102],
                            [7.0293119718, -1.8396620422, -4.0335362510],
                            [6.5721859735, 0.6309863795, -4.4078721938],
                            [5.9798893020, -0.4194041327, -5.7008780697],
                            [4.8401657335, 0.2599872029, -4.5419245327],
                            [5.0564586193, -3.2709821863, -3.3550073831],
                            [5.0182590552, -2.7761853866, -5.0502175442],
                            [3.9582362826, -1.9813780480, -3.8886221559]])*BOHR
        z = jnp.array([8.] + [6.] * 27 + [1.] * 46)
        return 216, atoms, z, coords

    # Previously fell through and returned None, which only failed later at
    # tuple unpacking with an unrelated-looking error.
    known = [*atomic_numbers, *diatomics, 'H10', 'H2O', 'CH4', 'C6H6', 'C27H46O']
    raise ValueError(f"Unknown molecule name: {mol_name!r}. "
                     f"Available options: {known}")
|