off 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
off/utils.py ADDED
@@ -0,0 +1,618 @@
1
+ from typing import Any, Callable, Literal, Union
2
+ from functools import partial
3
+
4
+ import jax
5
+ import jax.numpy as jnp
6
+ import jax.random as jrnd
7
+ from jax import jit,grad,vmap,hessian,lax
8
+ from jaxtyping import Array
9
+ from jax._src import prng
10
+
11
+ from pyscf.data.nist import BOHR
12
+
13
+ import optax
14
+
15
+ from diffrax import (
16
+ Euler, Heun, Midpoint, Ralston, Bosh3, Tsit5, Dopri5, Dopri8,
17
+ ImplicitEuler, Kvaerno3, Kvaerno4, Kvaerno5, Sil3,
18
+ KenCarp3, KenCarp4, KenCarp5, SemiImplicitEuler,
19
+ ReversibleHeun, LeapfrogMidpoint
20
+ )
21
+
22
# Alias for dtype-valued parameters; deliberately left unconstrained as ``Any``.
Dtype = Any

# Solver names understood by ``get_solver`` (one entry per Diffrax solver it can build).
SolverType = Literal[
    'euler',
    'heun',
    'midpoint',
    'ralston',
    'bosh3',
    'tsit5',
    'dopri5',
    'dopri8',
    'implicit_euler',
    'kvaerno3',
    'kvaerno4',
    'kvaerno5',
    'sil3',
    'ken_carp3',
    'ken_carp4',
    'ken_carp5',
    'semi_implicit_euler',
    'reversible_heun',
    'leapfrog_midpoint',
]
32
+
33
@partial(jit, static_argnums=(2,))
def compute_integral(params: Any, grid_array: Any, rho: Any, Ne: int):
    """Numerically integrate ``rho`` over a weighted quadrature grid.

    Parameters
    ----------
    params : Any
        Parameters forwarded as the first argument of ``rho``.
    grid_array : Any
        Pair ``(coords, weights)``: the grid points and their quadrature weights.
    rho : Any
        Callable ``rho(params, coords)``. It is a static argument of ``jit``,
        so it must be hashable.
    Ne : int
        Not used by the computation.

    Returns
    -------
    jax.Array
        Scalar ``jnp.vdot(weights, rho(params, coords))``, i.e. the weighted
        sum of the integrand values over the grid.
    """
    coords, weights = grid_array
    return jnp.vdot(weights, rho(params, coords))
38
+
39
@partial(jit, static_argnums=(2,))
def laplacian(params: Any, X: Array, fun: Callable) -> jax.Array:
    """Laplacian of ``fun`` with respect to its input, evaluated row by row.

    For every row ``x`` of ``X``, the Hessian of ``fun(params, x[None])`` with
    respect to the input is computed, and its trace over the input coordinates
    is returned.

    Parameters
    ----------
    params : Any
        Parameters passed unchanged as the first argument of ``fun``.
    X : Array
        Points of shape ``(N, d)``; each row is evaluated separately.
    fun : Callable
        ``fun(params, x)`` evaluated on a batch of one point, ``x`` of shape
        ``(1, d)``. It must return an array of shape ``(1, k)``, because the
        squeeze below drops exactly those singleton axes. ``fun`` is a static
        argument of ``jit``, so it must be hashable.

    Returns
    -------
    jax.Array
        Array of shape ``(N, k)`` with the Laplacian of each output component.
    """
    hess_fn = hessian(fun, argnums=1)

    def _single(x: Array) -> jax.Array:
        # The Hessian has shape (1, k, 1, d, 1, d); drop the three singleton
        # batch axes to obtain (k, d, d).
        hes = jnp.squeeze(hess_fn(params, x[jnp.newaxis]), axis=(0, 2, 4))
        # Trace over the two input axes, giving shape (k,).
        return jnp.einsum('...ii', hes)

    # ``params`` and ``fun`` are closed over, so only the rows of X are mapped.
    # The enclosing ``jit`` already compiles this, so no inner ``jit`` is needed.
    return vmap(_single)(X)
67
+
68
+
69
@partial(jit, static_argnums=(2,))
def score(params: Any, X: Array, fun: callable) -> jax.Array:
    """Per-sample gradient (score) of ``fun`` with respect to its input.

    Each row of ``X`` is wrapped into a batch of one and differentiated with
    reverse-mode autodiff (``jax.jacrev``). The Jacobian is flattened back to a
    vector with as many entries as the row has, and ``vmap`` applies this to
    every row.

    Parameters
    ----------
    params : Any
        Model parameters, passed unchanged as the first argument of ``fun``.
    X : Array
        Points at which the score is evaluated, one per row.
    fun : callable
        ``fun(params, x)`` evaluated on a batch of one point. It is a static
        argument of ``jit``.

    Returns
    -------
    jax.Array
        One score vector per row of ``X``.
    """
    jacobian_fn = jax.jacrev(fun, argnums=1)

    @jit
    def _row_score(p: Any, row: Array):
        # Differentiate on a size-one batch, then flatten to row.shape[0] entries.
        return jacobian_fn(p, jnp.expand_dims(row, 0)).reshape(row.shape[0])

    return vmap(_row_score, in_axes=(None, 0))(params, X)
97
+
98
def batch_generator(key: prng.PRNGKeyArray, batch_size: int, prior_dist):
    """Yield an endless stream of batches drawn from ``prior_dist``.

    Each yielded batch stacks two independent draws of ``batch_size`` samples,
    so it has ``2 * batch_size`` rows. Along axis 1, every row concatenates the
    sample, its log-probability and its score (the gradient of the
    log-probability with respect to the sample).

    Parameters
    ----------
    key : prng.PRNGKeyArray
        Initial PRNG key. It is only ever split; fresh subkeys are consumed
        by the sampler.
    batch_size : int
        Number of samples in each of the two draws.
    prior_dist
        Object exposing ``sample(seed=..., sample_shape=...)`` and
        ``log_prob(x)``. If it also has a ``score`` attribute, that callable is
        used on the batch of samples. Otherwise the score is obtained by
        differentiating ``log_prob`` per sample with ``jax.jacrev``.

    Yields
    ------
    Array
        Array with ``2 * batch_size`` rows.
    """
    if hasattr(prior_dist, 'score'):
        v_score = prior_dist.score
    else:
        v_score = vmap(jax.jacrev(lambda x: prior_dist.log_prob(x)))

    def _draw(subkey):
        # One draw of ``batch_size`` samples: [samples | log_prob | score].
        samples = prior_dist.sample(seed=subkey, sample_shape=batch_size)
        logp = prior_dist.log_prob(samples)
        if logp.ndim == samples.ndim - 1:
            # log_prob reduced the event axis; restore it so all pieces share a rank.
            logp = logp[..., None]
        return lax.concatenate((samples, logp, v_score(samples)), 1)

    while True:
        # Keep the carried key separate from the keys given to the sampler, so
        # no key is both consumed and split again.
        key, key0, key1 = jrnd.split(key, 3)
        yield lax.concatenate((_draw(key0), _draw(key1)), 0)
123
+
124
def batch_generator_(key: prng.PRNGKeyArray, batch_size: int, prior_dist: Any):
    """Yield an endless stream of batches drawn from the prior distribution.

    Each yielded batch stacks two independent draws of ``batch_size`` samples,
    so it has ``2 * batch_size`` rows. Along axis 1, every row concatenates the
    sample, its log-probability and its score, computed as
    ``grad(log_prob(x).sum())`` per sample.

    Parameters
    ----------
    key : prng.PRNGKeyArray
        Initial PRNG key. It is only ever split; fresh subkeys are consumed
        by the sampler.
    batch_size : int
        Number of samples in each of the two draws.
    prior_dist : Any
        Object exposing ``sample(seed=..., sample_shape=...)`` and
        ``log_prob(x)``.

    Yields
    ------
    Array
        Array with ``2 * batch_size`` rows.
    """
    # Summing makes the per-sample log_prob a scalar, so ``jax.grad`` applies.
    v_score = jax.vmap(jax.grad(lambda x: prior_dist.log_prob(x).sum()))

    def _draw(subkey):
        # One draw of ``batch_size`` samples: [samples | log_prob | score].
        samples = prior_dist.sample(seed=subkey, sample_shape=batch_size)
        logp = prior_dist.log_prob(samples)
        if logp.ndim == samples.ndim - 1:
            # log_prob reduced the event axis; restore it so all pieces share a rank.
            logp = logp[..., None]
        return lax.concatenate((samples, logp, v_score(samples)), 1)

    while True:
        # Keep the carried key separate from the keys given to the sampler.
        key, key0, key1 = jrnd.split(key, 3)
        yield lax.concatenate((_draw(key0), _draw(key1)), 0)
157
+
158
def batche_generator_1D(key: prng.PRNGKeyArray, batch_size: int, prior_dist: Any):
    """Yield an endless stream of batches drawn from the prior distribution.

    Each yielded batch stacks two independent draws of ``batch_size`` samples,
    so it has ``2 * batch_size`` rows. Along axis 1, every row concatenates the
    sample, its log-probability (given a trailing axis of size one) and its
    score, computed per sample with ``jax.jacrev`` of ``log_prob``.

    Parameters
    ----------
    key : prng.PRNGKeyArray
        Initial PRNG key. It is only ever split; fresh subkeys are consumed
        by the sampler.
    batch_size : int
        Number of samples in each of the two draws.
    prior_dist : Any
        Object exposing ``sample(seed=..., sample_shape=...)`` and
        ``log_prob(x)``. ``log_prob`` of a batch must return one value per
        sample, because a column axis is appended to it.

    Yields
    ------
    Array
        Array with ``2 * batch_size`` rows.
    """
    v_score = vmap(jax.jacrev(lambda x: prior_dist.log_prob(x)))

    def _draw(subkey):
        # One draw of ``batch_size`` samples: [samples | log_prob | score].
        samples = prior_dist.sample(seed=subkey, sample_shape=batch_size)
        logp = prior_dist.log_prob(samples)
        return lax.concatenate((samples, logp[:, None], v_score(samples)), 1)

    while True:
        # Keep the carried key separate from the keys given to the sampler.
        key, key0, key1 = jrnd.split(key, 3)
        yield lax.concatenate((_draw(key0), _draw(key1)), 0)
192
+
193
+
194
+
195
def get_scheduler(epochs: int, sched_type: str = 'zero', lr: float = 3E-4):
    """Build an optax schedule from a number or a short name.

    Parameters
    ----------
    epochs : int
        Number of steps the schedule is laid out over.
    sched_type : str, default 'zero'
        Either anything accepted by ``float`` (gives a constant schedule at that
        value) or one of:

        * ``'zero'`` / ``'one'``: constant 0.0 / 1.0;
        * ``'const'`` / ``'c'``: constant ``lr``;
        * ``'cos_decay'``: ``warmup_cosine_decay_schedule`` at ``lr`` with 150
          warmup steps, decaying to 1e-5 over ``epochs`` steps;
        * ``'mix'``: the same warmup-cosine shape decaying to 1e-6 over the first
          two thirds of ``epochs``, then constant 1e-6;
        * ``'mix_old'``: constant ``lr`` until ``epochs / 4``, then a
          ``cosine_onecycle_schedule`` (``transition_steps=epochs``) until
          ``epochs / 2``, then constant 1e-5.
    lr : float, default 3e-4
        Base learning rate used by the named schedules.

    Returns
    -------
    optax.Schedule
        The requested schedule.

    Raises
    ------
    ValueError
        If ``sched_type`` is neither numeric nor a known name.
    """
    # A numeric value (e.g. '1e-3') selects a constant schedule at that value.
    try:
        value = float(sched_type)
    except ValueError:
        pass
    else:
        return optax.constant_schedule(value)

    if sched_type == 'zero':
        return optax.constant_schedule(0.0)
    if sched_type == 'one':
        return optax.constant_schedule(1.)
    if sched_type in ('const', 'c'):
        return optax.constant_schedule(lr)
    if sched_type == 'cos_decay':
        return optax.warmup_cosine_decay_schedule(
            init_value=lr,
            peak_value=lr,
            warmup_steps=150,
            decay_steps=epochs,
            end_value=1E-5,
        )
    if sched_type == 'mix':
        switch_step = 2 * epochs / 3
        decay = optax.warmup_cosine_decay_schedule(
            init_value=lr,
            peak_value=lr,
            warmup_steps=150,
            decay_steps=int(switch_step),
            end_value=1E-6,
        )
        # join_schedules takes len(schedules) - 1 boundaries: switch to the
        # constant tail once the decay has finished.
        return optax.join_schedules(
            [decay, optax.constant_schedule(1E-6)], boundaries=[switch_step])
    if sched_type == 'mix_old':
        constant_scheduler_min = optax.constant_schedule(lr)
        cosine_decay_scheduler = optax.cosine_onecycle_schedule(
            transition_steps=epochs, peak_value=lr,
            div_factor=50., final_div_factor=1.)
        constant_scheduler_max = optax.constant_schedule(1E-5)
        return optax.join_schedules(
            [constant_scheduler_min, cosine_decay_scheduler, constant_scheduler_max],
            boundaries=[epochs / 4, 2 * epochs / 4])

    # Previously an unknown name fell through and returned None, failing later
    # at a confusing place; fail loudly here instead.
    raise ValueError(
        f"Unknown schedule type: {sched_type!r}. Expected a number or one of "
        f"['zero', 'one', 'const', 'c', 'cos_decay', 'mix', 'mix_old']")
233
+
234
def get_solver(solver_type: SolverType, **solver_kwargs) -> Union[
    Euler, Heun, Midpoint, Ralston, Bosh3, Tsit5, Dopri5, Dopri8,
    ImplicitEuler, Kvaerno3, Kvaerno4, Kvaerno5, Sil3,
    KenCarp3, KenCarp4, KenCarp5, SemiImplicitEuler,
    ReversibleHeun, LeapfrogMidpoint
]:
    """Instantiate a Diffrax solver from its (case-insensitive) name.

    Args:
        solver_type: Solver name; it is lower-cased before lookup.
        **solver_kwargs: Keyword arguments forwarded to the solver constructor.

    Returns:
        A new instance of the selected Diffrax solver class.

    Raises:
        ValueError: If the lower-cased name is not a known solver.
    """
    # Insertion order matters: it is the order listed in the error message.
    registry = dict(
        euler=Euler,
        heun=Heun,
        midpoint=Midpoint,
        ralston=Ralston,
        bosh3=Bosh3,
        tsit5=Tsit5,
        dopri5=Dopri5,
        dopri8=Dopri8,
        implicit_euler=ImplicitEuler,
        kvaerno3=Kvaerno3,
        kvaerno4=Kvaerno4,
        kvaerno5=Kvaerno5,
        sil3=Sil3,
        ken_carp3=KenCarp3,
        ken_carp4=KenCarp4,
        ken_carp5=KenCarp5,
        semi_implicit_euler=SemiImplicitEuler,
        reversible_heun=ReversibleHeun,
        leapfrog_midpoint=LeapfrogMidpoint,
    )

    name = solver_type.lower()
    if name not in registry:
        raise ValueError(f"Unknown solver type: {solver_type}. "
                         f"Available options: {list(registry.keys())}")
    return registry[name](**solver_kwargs)
277
+
278
def correlation_polarization_correction(
    e_tilde_PF: Array,
    den: Array,
    clip_cte: float = 1e-30
):
    r"""Spin-polarization correction to a correlation functional, using eq. 2.75 of
    Carsten A. Ullrich, "Time-Dependent Density-Functional Theory".

    Interpolates between a paramagnetic (P) and a ferromagnetic (F) energy with
    the spin polarization ``zeta = (den_0 - den_1) / (den_0 + den_1)``::

        e = e_P + alpha_c * f(zeta) / f''(0) * (1 - zeta**4)
                + (e_F - e_P) * f(zeta) * zeta**4

    where ``f(zeta) = ((1 + zeta)**(4/3) + (1 - zeta)**(4/3) - 2) / (2 * (2**(1/3) - 1))``
    and ``alpha_c`` is evaluated from ``r_s`` with the fitted constants below.

    Parameters
    ----------
    e_tilde_PF : Float[Array, "grid 2"]
        Column 0 holds the paramagnetic and column 1 the ferromagnetic
        contribution at each grid point.
    den : Float[Array, "grid 2"]
        Electronic density of each spin channel (one per column) at each grid
        point.
    clip_cte : float, default 1e-30
        Density floor applied before taking logarithms. Grid points whose total
        density does not exceed it are treated as unpolarized (``zeta = 0``).

    Returns
    -------
    e_tilde : Float[Array, "grid"]
        The ready-to-integrate energy density at each grid point.
    """
    total = den.sum(axis=1)

    # Floor the density before the logarithm so empty grid points stay finite.
    log_rho = jnp.log2(jnp.maximum(total, clip_cte))
    # log2 of r_s = (3 / (4 pi rho))**(1/3).
    log_rs = jnp.log2((3 / (4 * jnp.pi)) ** (1 / 3)) - log_rho / 3.0

    occupied = total > clip_cte
    # jnp.where evaluates both branches. Dividing by a safe denominator keeps the
    # discarded 0/0 branch from producing NaN gradients at empty grid points.
    safe_total = jnp.where(occupied, total, 1.0)
    zeta = jnp.where(occupied, (den[:, 0] - den[:, 1]) / safe_total, 0.0)

    def fzeta(z):
        # (1 -/+ z)**(4/3), written through log2 as in the rest of the function.
        zm = 2 ** (4 * jnp.log2(1 - z) / 3)
        zp = 2 ** (4 * jnp.log2(1 + z) / 3)
        return (zm + zp - 2) / (2 * (2 ** (1 / 3) - 1))

    A_ = 0.016887
    alpha1 = 0.11125
    beta1 = 10.357
    beta2 = 3.6231
    beta3 = 0.88026
    beta4 = 0.49671

    ars = 2 ** (jnp.log2(alpha1) + log_rs)              # alpha1 * rs
    brs_1_2 = 2 ** (jnp.log2(beta1) + log_rs / 2)       # beta1 * rs**(1/2)
    brs = 2 ** (jnp.log2(beta2) + log_rs)               # beta2 * rs
    brs_3_2 = 2 ** (jnp.log2(beta3) + 3 * log_rs / 2)   # beta3 * rs**(3/2)
    brs2 = 2 ** (jnp.log2(beta4) + 2 * log_rs)          # beta4 * rs**2

    alphac = 2 * A_ * (1 + ars) * jnp.log(1 + (1 / (2 * A_)) / (brs_1_2 + brs + brs_3_2 + brs2))

    fz = fzeta(zeta)
    z4 = zeta ** 4
    # f''(0): constant normalization of the interpolation, obtained by autodiff.
    fpp0 = grad(grad(fzeta))(0.0)

    e_tilde = (
        e_tilde_PF[:, 0]
        + alphac * (fz / fpp0) * (1 - z4)
        + (e_tilde_PF[:, 1] - e_tilde_PF[:, 0]) * fz * z4
    )
    return e_tilde
342
+
343
def one_hot_encode(z: Array) -> Array:
    """One-hot encode ``z`` by the rank of each value among its unique values.

    Column ``j`` corresponds to the ``j``-th entry of ``jnp.unique(z)``, which is
    in sorted order.

    Parameters
    ----------
    z : Array
        1-D array of values to encode.

    Returns
    -------
    Array
        Array of shape ``(len(z), n_unique)`` with a single 1 per row.
    """
    # return_inverse gives each element's index into the sorted unique values in
    # one vectorized pass, replacing a Python double loop and a second unique().
    unique_vals, inverse = jnp.unique(z, return_inverse=True)
    return jax.nn.one_hot(inverse.reshape(-1), unique_vals.shape[0])
367
+
368
# Isolated neutral atoms placed at the origin: symbol -> atomic number (= Ne).
_ISOLATED_ATOMS = {
    'H': 1, 'He': 2, 'Li': 3, 'Be': 4, 'B': 5,
    'C': 6, 'N': 7, 'O': 8, 'F': 9, 'Ne': 10,
}

# Diatomics placed symmetrically on the z axis, ``bond_length`` apart:
# name -> (number of electrons, atom symbols, nuclear charges).
# The int/float literals are deliberate: they fix the dtype of the returned ``z``.
_DIATOMICS = {
    'H2': (2, ('H', 'H'), (1, 1)),
    'N2': (14, ('N', 'N'), (7., 7.)),
    'O2': (16, ('O', 'O'), (8., 8.)),
    'F2': (18, ('F', 'F'), (9., 9.)),
    'CO': (14, ('C', 'O'), (6., 8.)),
    'HF': (10, ('H', 'F'), (1., 9.)),
    'LiH': (4, ('Li', 'H'), (3, 1)),
}

# Fixed geometries whose coordinates are scaled by the ``BOHR`` argument of
# ``coordinates`` (presumably Angstrom -> Bohr):
# name -> (number of electrons, atom symbols, nuclear charges, xyz rows).
_FIXED_GEOMETRIES = {
    'H2O': (10, ('O', 'H', 'H'), (8, 1, 1), (
        (0.0, 0.0, 0.1189120),
        (0.0, 0.7612710, -0.4756480),
        (0.0, -0.7612710, -0.4756480),
    )),
    'CH4': (10, ('C', 'H', 'H', 'H', 'H'), (6, 1, 1, 1, 1), (
        (0.0, 0.0, 0.0),
        (0.0, 0.0, 1.09),
        (1.03, 0.0, -0.363),
        (-0.515, -0.889165, -0.363),
        (-0.515, 0.889165, -0.363),
    )),
    'C6H6': (42, ('C',) * 6 + ('H',) * 6, (6.,) * 6 + (1.,) * 6, (
        (-0.6984022192, 1.2096794375, 0.0001298085),
        (0.6983971652, 1.2096794375, 0.0001298085),
        (1.3968148123, 0.0000100819, 0.0001298085),
        (0.6983970907, -1.2096631450, -0.0001577731),
        (-0.6984347101, -1.2097236297, -0.0001117213),
        (-1.3967427862, -0.0000316846, -0.0000531302),
        (2.4989957176, 0.0000352657, 0.0001433155),
        (1.2492115488, 2.1643070022, 0.0000608098),
        (-1.2497352350, 2.1640355443, 0.0000873078),
        (1.2495192781, -2.1641211865, -0.0004652006),
        (-1.2494488074, -2.1642720181, -0.0003281198),
        (-2.4988922798, 0.0006052813, -0.0002941369),
    )),
    'C27H46O': (216, ('O',) + ('C',) * 27 + ('H',) * 46, (8.,) + (6.,) * 27 + (1.,) * 46, (
        (-7.6344919379, 0.4597041902, 0.7112354681),
        (1.1080808630, 0.4597041902, 0.7112354681),
        (0.2530014327, -0.8231352339, 0.7112354681),
        (-1.1599473252, -0.6092757486, 1.2632436184),
        (-1.8846097300, 0.4851364826, 0.4266238676),
        (2.4346065739, -0.1264931117, 0.1695905830),
        (0.4237780146, 1.4699206954, -0.2202286076),
        (-3.3372747943, 0.7709975410, 0.9822345823),
        (1.1534637485, -1.8591349538, 1.3772546332),
        (-1.0194512849, 1.7601151269, 0.2271214576),
        (2.5719892229, -1.4735213405, 0.9177695847),
        (-1.9616876023, -1.9189024556, 1.2281160480),
        (3.6706406723, 0.7518022510, 0.2998092579),
        (1.3056592864, 1.0858012443, 2.1142034561),
        (-4.0659099987, -0.5409527880, 1.3058897583),
        (-4.1355207853, 1.5407857692, -0.1122402103),
        (-3.4345565189, -1.7244824608, 1.3955388362),
        (-3.2883872448, 1.6217121992, 2.2742973586),
        (-5.5600747780, -0.4657287448, 1.5202105354),
        (-5.6342682740, 1.6532127465, 0.1635040190),
        (4.9369027993, 0.0920796760, -0.2857409192),
        (-6.2593559003, 0.2808578004, 0.3875037320),
        (3.4346199063, 2.0977159171, -0.3972787505),
        (4.8597222067, -0.3194422440, -1.7619595329),
        (6.1183294840, -1.0351005847, -2.2632694976),
        (6.0683625521, -1.3844022946, -3.7635312418),
        (5.8475871083, -0.1546334232, -4.6458023942),
        (4.9580276609, -2.4067800912, -4.0202930776),
        (0.1320531554, -1.1546414208, -0.3341704111),
        (-1.0961065494, -0.2887741191, 2.3091814564),
        (-2.0071557590, 0.0577891770, -0.5822428933),
        (2.2692308941, -0.3566230816, -0.8919397659),
        (0.4024909308, 1.0868866903, -1.2485637383),
        (0.9326111775, 2.4332063069, -0.2550387088),
        (0.9000265464, -2.8793549884, 1.0717133503),
        (1.0798371703, -1.8121549020, 2.4694535032),
        (-0.9810389048, 2.3462405294, 1.1499104917),
        (-1.4838623070, 2.4059394028, -0.5267205531),
        (3.2317842936, -1.3989125509, 1.7896882295),
        (2.9883545370, -2.2408492672, 0.2555163717),
        (-1.6014174036, -2.5802367118, 2.0254337925),
        (-1.7980739690, -2.4474317007, 0.2806843164),
        (3.8879474770, 0.9526354546, 1.3563558574),
        (1.8492781899, 2.0336135308, 2.0610078045),
        (0.3578137484, 1.3050839833, 2.6129243364),
        (1.8665627197, 0.4298627984, 2.7876605918),
        (-4.0090347972, 1.0378379064, -1.0807548937),
        (-3.7338443144, 2.5534931648, -0.2361852128),
        (-4.0082696014, -2.6236692902, 1.6068637872),
        (-2.9172441072, 2.6329927094, 2.0790411399),
        (-4.2766956463, 1.7447530584, 2.7297022229),
        (-2.6449795846, 1.1683767430, 3.0353862870),
        (-6.0051645049, -1.4649829974, 1.6125803553),
        (-5.7715116059, 0.0256833216, 2.4782609823),
        (-5.8206294845, 2.3008198064, 1.0291274274),
        (-6.1301728674, 2.1483087541, -0.6807530020),
        (5.7878953777, 0.7733743373, -0.1540084782),
        (5.1800431538, -0.7964447756, 0.3105080195),
        (-6.2070878208, -0.3024479474, -0.5399782530),
        (2.9407650699, 1.9873110720, -1.3676590681),
        (4.3895302067, 2.6050134027, -0.5818686450),
        (2.8705286939, 2.7872275940, 0.2335427866),
        (4.0058535847, -0.9856353615, -1.9055309667),
        (4.6772717683, 0.5793866627, -2.3581606350),
        (-8.0203071775, -0.4225513077, 0.8470435952),
        (6.2790374269, -1.9486777435, -1.6774593986),
        (6.9885972888, -0.3917002684, -2.0822296102),
        (7.0293119718, -1.8396620422, -4.0335362510),
        (6.5721859735, 0.6309863795, -4.4078721938),
        (5.9798893020, -0.4194041327, -5.7008780697),
        (4.8401657335, 0.2599872029, -4.5419245327),
        (5.0564586193, -3.2709821863, -3.3550073831),
        (5.0182590552, -2.7761853866, -5.0502175442),
        (3.9582362826, -1.9813780480, -3.8886221559),
    )),
}


def coordinates(mol_name: str, bond_length: float = 1.4008538753, BOHR: float = 1.8897259886) -> tuple:
    """Return ``(Ne, atoms, z, coords)`` for a named test system.

    Parameters
    ----------
    mol_name : str
        One of:

        * the isolated atoms 'H', 'He', 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne';
        * the diatomics 'H2', 'N2', 'O2', 'F2', 'CO', 'HF', 'LiH';
        * the hydrogen chain 'H10';
        * the fixed geometries 'H2O', 'CH4', 'C6H6', 'C27H46O'.
    bond_length : float, default 1.4008538753
        Interatomic distance, used unscaled, for the diatomics, and the spacing
        of the 'H10' chain. The same default applies to every diatomic, so pass
        an appropriate value for molecules other than 'H2'.
    BOHR : float, default 1.8897259886
        Scale factor applied only to the fixed geometries. Inside this function
        it shadows the module-level ``BOHR`` import.

    Returns
    -------
    tuple
        ``Ne`` (int, number of electrons), ``atoms`` (a new list of element
        symbols), ``z`` (nuclear charges, shape ``(n_atoms,)``) and ``coords``
        (shape ``(n_atoms, 3)``).

    Raises
    ------
    ValueError
        If ``mol_name`` is not one of the supported systems.
    """
    if mol_name in _ISOLATED_ATOMS:
        charge = _ISOLATED_ATOMS[mol_name]
        return charge, [mol_name], jnp.array([float(charge)]), jnp.array([[0., 0., 0.]])

    if mol_name in _DIATOMICS:
        Ne, atoms, z = _DIATOMICS[mol_name]
        coords = jnp.array([[0., 0., -bond_length / 2],
                            [0., 0., bond_length / 2]])
        return Ne, list(atoms), jnp.array(z), coords

    if mol_name == 'H10':
        n_h = 10
        # Equally spaced chain along z, centred on the origin.
        offsets = (jnp.arange(n_h) - (n_h - 1) / 2.) * bond_length
        coords = jnp.stack([jnp.zeros(n_h), jnp.zeros(n_h), offsets], axis=1)
        return n_h, ['H'] * n_h, jnp.ones(n_h), coords

    if mol_name in _FIXED_GEOMETRIES:
        Ne, atoms, z, xyz = _FIXED_GEOMETRIES[mol_name]
        return Ne, list(atoms), jnp.array(z), jnp.array(xyz) * BOHR

    supported = sorted([*_ISOLATED_ATOMS, *_DIATOMICS, 'H10', *_FIXED_GEOMETRIES])
    raise ValueError(f"Unknown molecule: {mol_name!r}. Available options: {supported}")