jaxion 0.0.5__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxion/{params_default.json → defaults.json} +4 -0
- jaxion/gravity.py +2 -2
- jaxion/simulation.py +104 -40
- jaxion/utils.py +1 -1
- {jaxion-0.0.5.dist-info → jaxion-0.0.6.dist-info}/METADATA +3 -3
- {jaxion-0.0.5.dist-info → jaxion-0.0.6.dist-info}/RECORD +9 -9
- {jaxion-0.0.5.dist-info → jaxion-0.0.6.dist-info}/WHEEL +0 -0
- {jaxion-0.0.5.dist-info → jaxion-0.0.6.dist-info}/licenses/LICENSE +0 -0
- {jaxion-0.0.5.dist-info → jaxion-0.0.6.dist-info}/top_level.txt +0 -0
jaxion/gravity.py
CHANGED
|
@@ -5,6 +5,6 @@ import jaxdecomp as jd
|
|
|
5
5
|
|
|
6
6
|
|
|
7
7
|
def calculate_gravitational_potential(rho, k_sq, G, rho_bar):
|
|
8
|
-
|
|
9
|
-
V = jnp.real(jd.fft.pifft3d(
|
|
8
|
+
V_hat = -jd.fft.pfft3d(4.0 * jnp.pi * G * (rho - rho_bar)) / (k_sq + (k_sq == 0))
|
|
9
|
+
V = jnp.real(jd.fft.pifft3d(V_hat))
|
|
10
10
|
return V
|
jaxion/simulation.py
CHANGED
|
@@ -62,9 +62,10 @@ class Simulation:
|
|
|
62
62
|
self.params["physics"]["hydro"]
|
|
63
63
|
or self.params["physics"]["particles"]
|
|
64
64
|
or self.params["physics"]["external_potential"]
|
|
65
|
+
or self.params["quantum"]["f_15"] != 0.0
|
|
65
66
|
):
|
|
66
67
|
raise NotImplementedError(
|
|
67
|
-
"Cosmological hydro/particles/external_potential
|
|
68
|
+
"Cosmological hydro/particles/external_potential/SI is not yet implemented."
|
|
68
69
|
)
|
|
69
70
|
|
|
70
71
|
if self.params["physics"]["hydro"] or self.params["physics"]["particles"]:
|
|
@@ -94,6 +95,12 @@ class Simulation:
|
|
|
94
95
|
xones, static_argnums=0, in_shardings=None, out_shardings=sharding
|
|
95
96
|
)
|
|
96
97
|
|
|
98
|
+
# customfunctions
|
|
99
|
+
self.custom_kick = None
|
|
100
|
+
self.custom_drift = None
|
|
101
|
+
self.custom_density = None
|
|
102
|
+
self.custom_plot = None
|
|
103
|
+
|
|
97
104
|
# simulation state
|
|
98
105
|
self.state = {}
|
|
99
106
|
self.state["t"] = 0.0
|
|
@@ -173,6 +180,23 @@ class Simulation:
|
|
|
173
180
|
/ constants["speed_of_light"] ** 2
|
|
174
181
|
)
|
|
175
182
|
|
|
183
|
+
@property
|
|
184
|
+
def scattering_length(self):
|
|
185
|
+
"""
|
|
186
|
+
Return the axion self-interaction scattering length in the simulation (kpc)
|
|
187
|
+
"""
|
|
188
|
+
f_15 = self.params["quantum"]["f_15"]
|
|
189
|
+
if f_15 == 0.0:
|
|
190
|
+
return 0.0
|
|
191
|
+
else:
|
|
192
|
+
f = f_15 * 1.0e24 * constants["electron_volt"]
|
|
193
|
+
sign = 1.0 if f > 0 else -1.0
|
|
194
|
+
hbar = constants["reduced_planck_constant"]
|
|
195
|
+
c = constants["speed_of_light"]
|
|
196
|
+
m = self.axion_mass
|
|
197
|
+
a_s = sign * (hbar * c**3 * m) / (32.0 * jnp.pi * (f**2))
|
|
198
|
+
return a_s
|
|
199
|
+
|
|
176
200
|
@property
|
|
177
201
|
def m_per_hbar(self):
|
|
178
202
|
"""
|
|
@@ -224,6 +248,13 @@ class Simulation:
|
|
|
224
248
|
"""
|
|
225
249
|
return quantum_velocity(self.state["psi"], self.box_size, self.m_per_hbar)
|
|
226
250
|
|
|
251
|
+
@property
|
|
252
|
+
def rho_bar(self):
|
|
253
|
+
"""
|
|
254
|
+
Return the mean density of the simulation
|
|
255
|
+
"""
|
|
256
|
+
return self._calc_rho_bar(self.state)
|
|
257
|
+
|
|
227
258
|
def _calc_rho_bar(self, state):
|
|
228
259
|
rho_bar = 0.0
|
|
229
260
|
if self.params["physics"]["quantum"]:
|
|
@@ -235,6 +266,8 @@ class Simulation:
|
|
|
235
266
|
n_particles = self.num_particles
|
|
236
267
|
box_size = self.box_size
|
|
237
268
|
rho_bar += m_particle * n_particles / box_size
|
|
269
|
+
if self.custom_density is not None:
|
|
270
|
+
rho_bar += jnp.mean(self.custom_density(state))
|
|
238
271
|
return rho_bar
|
|
239
272
|
|
|
240
273
|
def _calc_grav_potential(self, state, k_sq):
|
|
@@ -248,6 +281,8 @@ class Simulation:
|
|
|
248
281
|
rho_tot += state["rho"]
|
|
249
282
|
if self.params["physics"]["particles"]:
|
|
250
283
|
rho_tot += bin_particles(state["pos"], self.dx, self.resolution, m_particle)
|
|
284
|
+
if self.custom_density is not None:
|
|
285
|
+
rho_tot += self.custom_density(state)
|
|
251
286
|
if self.params["physics"]["cosmology"]:
|
|
252
287
|
scale_factor = 1.0 / (1.0 + state["redshift"])
|
|
253
288
|
rho_bar *= scale_factor
|
|
@@ -280,8 +315,9 @@ class Simulation:
|
|
|
280
315
|
|
|
281
316
|
# Simulation parameters
|
|
282
317
|
dx = self.dx
|
|
283
|
-
m_per_hbar = self.m_per_hbar
|
|
284
318
|
box_size = self.box_size
|
|
319
|
+
num_cells = self.resolution**3
|
|
320
|
+
m_per_hbar = self.m_per_hbar
|
|
285
321
|
|
|
286
322
|
dt_fac = 1.0
|
|
287
323
|
dt_kin = dt_fac * (m_per_hbar / 6.0) * (dx * dx)
|
|
@@ -290,8 +326,20 @@ class Simulation:
|
|
|
290
326
|
t_span = t_end - t_start
|
|
291
327
|
state["t"] = t_start
|
|
292
328
|
|
|
329
|
+
use_quantum = self.params["physics"]["quantum"]
|
|
330
|
+
use_gravity = self.params["physics"]["gravity"]
|
|
331
|
+
use_hydro = self.params["physics"]["hydro"]
|
|
332
|
+
use_particles = self.params["physics"]["particles"]
|
|
333
|
+
use_cosmology = self.params["physics"]["cosmology"]
|
|
334
|
+
use_external_potential = self.params["physics"]["external_potential"]
|
|
335
|
+
save = self.params["output"]["save"]
|
|
336
|
+
use_custom = self.custom_kick is not None or self.custom_drift is not None
|
|
337
|
+
if use_custom:
|
|
338
|
+
custom_kick = self.custom_kick
|
|
339
|
+
custom_drift = self.custom_drift
|
|
340
|
+
|
|
293
341
|
# cosmology
|
|
294
|
-
if
|
|
342
|
+
if use_cosmology:
|
|
295
343
|
z_start = self.params["time"]["start"]
|
|
296
344
|
z_end = self.params["time"]["end"]
|
|
297
345
|
omega_matter = self.params["cosmology"]["omega_matter"]
|
|
@@ -303,6 +351,17 @@ class Simulation:
|
|
|
303
351
|
state["t"] = 0.0
|
|
304
352
|
state["redshift"] = z_start
|
|
305
353
|
|
|
354
|
+
# self-interaction
|
|
355
|
+
a_s = self.scattering_length
|
|
356
|
+
c = constants["speed_of_light"]
|
|
357
|
+
m = self.axion_mass
|
|
358
|
+
si_coeff = None
|
|
359
|
+
si_coeff2 = None
|
|
360
|
+
do_self_interaction = a_s != 0.0
|
|
361
|
+
if do_self_interaction:
|
|
362
|
+
si_coeff = (4.0 * jnp.pi) * (a_s / m) / m_per_hbar**2
|
|
363
|
+
si_coeff2 = (32.0 * jnp.pi**2 / 3.0) * (a_s / (m * c)) ** 2 / m_per_hbar**5
|
|
364
|
+
|
|
306
365
|
# hydro
|
|
307
366
|
c_sound = self.params["hydro"]["sound_speed"]
|
|
308
367
|
|
|
@@ -317,16 +376,14 @@ class Simulation:
|
|
|
317
376
|
k_sq = None
|
|
318
377
|
|
|
319
378
|
# Fourier space variables
|
|
320
|
-
if
|
|
379
|
+
if use_gravity or use_quantum:
|
|
321
380
|
kx, ky, kz = self.kgrid
|
|
322
381
|
k_sq = kx**2 + ky**2 + kz**2
|
|
323
382
|
|
|
324
383
|
# Checkpointer
|
|
325
|
-
if
|
|
384
|
+
if save:
|
|
326
385
|
options = ocp.CheckpointManagerOptions()
|
|
327
|
-
checkpoint_dir =
|
|
328
|
-
os.getcwd(), self.params["output"]["path"]
|
|
329
|
-
)
|
|
386
|
+
checkpoint_dir = os.path.join(os.getcwd(), self.params["output"]["path"])
|
|
330
387
|
path = os.path.join(os.getcwd(), checkpoint_dir)
|
|
331
388
|
if jax.process_index() == 0:
|
|
332
389
|
path = ocp.test_utils.erase_and_create_empty(checkpoint_dir)
|
|
@@ -336,80 +393,86 @@ class Simulation:
|
|
|
336
393
|
|
|
337
394
|
def _kick(state, kx, ky, kz, k_sq, dt):
|
|
338
395
|
# Kick (half-step)
|
|
339
|
-
if
|
|
340
|
-
self.params["physics"]["gravity"]
|
|
341
|
-
and self.params["physics"]["external_potential"]
|
|
342
|
-
):
|
|
396
|
+
if use_gravity and use_external_potential:
|
|
343
397
|
V = self._calc_grav_potential(state, k_sq) + state["V_ext"]
|
|
344
|
-
elif
|
|
398
|
+
elif use_gravity:
|
|
345
399
|
V = self._calc_grav_potential(state, k_sq)
|
|
346
|
-
elif
|
|
400
|
+
elif use_external_potential:
|
|
347
401
|
V = state["V_ext"]
|
|
348
402
|
|
|
349
|
-
if
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
403
|
+
if use_gravity or use_external_potential:
|
|
404
|
+
if use_quantum:
|
|
405
|
+
if do_self_interaction:
|
|
406
|
+
rho = jnp.abs(state["psi"]) ** 2
|
|
407
|
+
V_prime = V + si_coeff * rho + si_coeff2 * rho**2
|
|
408
|
+
state["psi"] = quantum_kick(
|
|
409
|
+
state["psi"], V_prime, m_per_hbar, dt
|
|
410
|
+
)
|
|
411
|
+
else:
|
|
412
|
+
state["psi"] = quantum_kick(state["psi"], V, m_per_hbar, dt)
|
|
413
|
+
if use_hydro:
|
|
356
414
|
state["vx"], state["vy"], state["vz"] = hydro_accelerate(
|
|
357
415
|
state["vx"], state["vy"], state["vz"], V, kx, ky, kz, dt
|
|
358
416
|
)
|
|
359
|
-
if
|
|
417
|
+
if use_particles:
|
|
360
418
|
state["vel"] = particles_accelerate(
|
|
361
419
|
state["vel"], state["pos"], V, kx, ky, kz, dx, dt
|
|
362
420
|
)
|
|
421
|
+
if use_custom:
|
|
422
|
+
state = custom_kick(state, V, dt)
|
|
423
|
+
|
|
424
|
+
return state
|
|
363
425
|
|
|
364
426
|
def _drift(state, k_sq, dt):
|
|
365
427
|
# Drift (full-step)
|
|
366
|
-
if
|
|
428
|
+
if use_quantum:
|
|
367
429
|
state["psi"] = quantum_drift(state["psi"], k_sq, m_per_hbar, dt)
|
|
368
|
-
if
|
|
430
|
+
if use_hydro:
|
|
369
431
|
state["rho"], state["vx"], state["vy"], state["vz"] = hydro_fluxes(
|
|
370
432
|
state["rho"], state["vx"], state["vy"], state["vz"], dt, dx, c_sound
|
|
371
433
|
)
|
|
372
|
-
if
|
|
434
|
+
if use_particles:
|
|
373
435
|
state["pos"] = particles_drift(state["pos"], state["vel"], dt, box_size)
|
|
436
|
+
if use_custom:
|
|
437
|
+
state = custom_drift(state, k_sq, dt)
|
|
438
|
+
|
|
439
|
+
return state
|
|
374
440
|
|
|
375
|
-
@jax.jit
|
|
376
441
|
def _update(_, carry):
|
|
377
442
|
# Update the simulation state by one timestep
|
|
378
443
|
# according to a 2nd-order `kick-drift-kick` scheme
|
|
379
444
|
state, kx, ky, kz, k_sq = carry
|
|
380
|
-
_kick(state, kx, ky, kz, k_sq, 0.5 * dt)
|
|
381
|
-
_drift(state, k_sq, dt)
|
|
445
|
+
state = _kick(state, kx, ky, kz, k_sq, 0.5 * dt)
|
|
446
|
+
state = _drift(state, k_sq, dt)
|
|
382
447
|
# update time & redshift
|
|
383
448
|
state["t"] += dt
|
|
384
|
-
if
|
|
449
|
+
if use_cosmology:
|
|
385
450
|
scale_factor = get_next_scale_factor(
|
|
386
|
-
state["redshift"],
|
|
387
|
-
dt,
|
|
388
|
-
self.params["cosmology"]["omega_matter"],
|
|
389
|
-
self.params["cosmology"]["omega_lambda"],
|
|
390
|
-
self.params["cosmology"]["little_h"],
|
|
451
|
+
state["redshift"], dt, omega_matter, omega_lambda, little_h
|
|
391
452
|
)
|
|
392
453
|
state["redshift"] = 1.0 / scale_factor - 1.0
|
|
393
|
-
_kick(state, kx, ky, kz, k_sq, 0.5 * dt)
|
|
454
|
+
state = _kick(state, kx, ky, kz, k_sq, 0.5 * dt)
|
|
394
455
|
|
|
395
456
|
return state, kx, ky, kz, k_sq
|
|
396
457
|
|
|
397
458
|
# save initial state
|
|
398
459
|
if jax.process_index() == 0:
|
|
399
460
|
print(f"Starting simulation (res={self.resolution}, nt={nt}) ...")
|
|
400
|
-
if
|
|
461
|
+
if save:
|
|
401
462
|
with open(os.path.join(checkpoint_dir, "params.json"), "w") as f:
|
|
402
463
|
json.dump(self.params, f, indent=2)
|
|
403
464
|
async_checkpoint_manager.save(0, args=ocp.args.StandardSave(state))
|
|
404
465
|
plot_sim(state, checkpoint_dir, 0, self.params)
|
|
466
|
+
if self.custom_plot is not None:
|
|
467
|
+
self.custom_plot(state, checkpoint_dir, 0, self.params)
|
|
405
468
|
async_checkpoint_manager.wait_until_finished()
|
|
406
469
|
|
|
407
470
|
# Simulation Main Loop
|
|
408
471
|
t_start_timer = time.time()
|
|
409
|
-
if
|
|
472
|
+
if save:
|
|
410
473
|
for i in range(1, num_checkpoints + 1):
|
|
411
474
|
carry = jax.lax.fori_loop(0, nt_sub, _update, init_val=carry)
|
|
412
|
-
state,
|
|
475
|
+
state, _, _, _, _ = carry
|
|
413
476
|
jax.block_until_ready(state)
|
|
414
477
|
# save state
|
|
415
478
|
async_checkpoint_manager.save(i, args=ocp.args.StandardSave(state))
|
|
@@ -417,17 +480,18 @@ class Simulation:
|
|
|
417
480
|
elapsed = time.time() - t_start_timer
|
|
418
481
|
est_total = elapsed / i * num_checkpoints
|
|
419
482
|
est_remaining = est_total - elapsed
|
|
420
|
-
num_cells = self.resolution**3
|
|
421
483
|
mcups = (num_cells * (i * nt_sub)) / (elapsed * 1.0e6)
|
|
422
484
|
if jax.process_index() == 0:
|
|
423
485
|
print(
|
|
424
486
|
f"{percent:.1f}%: mcups={mcups:.1f}, estimated time left (s): {est_remaining:.1f}"
|
|
425
487
|
)
|
|
426
488
|
plot_sim(state, checkpoint_dir, i, self.params)
|
|
489
|
+
if self.custom_plot is not None:
|
|
490
|
+
self.custom_plot(state, checkpoint_dir, i, self.params)
|
|
427
491
|
async_checkpoint_manager.wait_until_finished()
|
|
428
492
|
else:
|
|
429
493
|
carry = jax.lax.fori_loop(0, nt, _update, init_val=carry)
|
|
430
|
-
state,
|
|
494
|
+
state, _, _, _, _ = carry
|
|
431
495
|
jax.block_until_ready(state)
|
|
432
496
|
if jax.process_index() == 0:
|
|
433
497
|
print("Simulation Run Time (s): ", time.time() - t_start_timer)
|
jaxion/utils.py
CHANGED
|
@@ -31,7 +31,7 @@ def print_distributed_info():
|
|
|
31
31
|
|
|
32
32
|
def set_up_parameters(user_overwrites):
|
|
33
33
|
# first load the default params
|
|
34
|
-
params_path = importlib.resources.files("jaxion") / "
|
|
34
|
+
params_path = importlib.resources.files("jaxion") / "defaults.json"
|
|
35
35
|
with params_path.open("r", encoding="utf-8") as f:
|
|
36
36
|
params = json.load(f)
|
|
37
37
|
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: jaxion
|
|
3
|
-
Version: 0.0.5
|
|
3
|
+
Version: 0.0.6
|
|
4
4
|
Summary: A differentiable simulation library for fuzzy dark matter in JAX
|
|
5
5
|
Author-email: Philip Mocz <philip.mocz@gmail.com>
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -68,7 +68,7 @@ Jaxion is built for multi-GPU scalability and is fully differentiable. It is a h
|
|
|
68
68
|
Jaxion is the simpler companion project to differentiable astrophysics code [Adirondax](https://github.com/AdirondaxProject/adirondax)
|
|
69
69
|
|
|
70
70
|
|
|
71
|
-
##
|
|
71
|
+
## Install Jaxion
|
|
72
72
|
|
|
73
73
|
Install with:
|
|
74
74
|
|
|
@@ -87,7 +87,7 @@ See the docs for more info on how to [build from source](https://jaxion.readthed
|
|
|
87
87
|
|
|
88
88
|
## Examples
|
|
89
89
|
|
|
90
|
-
Check out the `examples/` directory for demonstrations of using Jaxion.
|
|
90
|
+
Check out the [`examples/`](https://github.com/JaxionProject/jaxion/tree/main/examples/) directory for demonstrations of using Jaxion.
|
|
91
91
|
|
|
92
92
|
<p align="center">
|
|
93
93
|
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/cosmological_box">
|
|
@@ -2,16 +2,16 @@ jaxion/__init__.py,sha256=Hdji1UQ47lG24Pqcy6UUq9L0-qy6m9Ax41L0vIYzBto,164
|
|
|
2
2
|
jaxion/analysis.py,sha256=4YT9Z2dkFoXwft3fQM1HyynVPlIdtRd80VtI2vWTyq4,1568
|
|
3
3
|
jaxion/constants.py,sha256=HyY2ktKQakv78jD1yQvFdM3sklUJcPgDMYlTsSPQTxI,512
|
|
4
4
|
jaxion/cosmology.py,sha256=UC1McXNTXGoPRYXn0nI2-csVkJWL-ZBNoCa44oU1b4w,2681
|
|
5
|
-
jaxion/
|
|
5
|
+
jaxion/defaults.json,sha256=EqFHV9HlLIRvTJrfbT5-AI0GfCLBd0ViwTeUV_B8YIw,2883
|
|
6
|
+
jaxion/gravity.py,sha256=2smqy_jjmr0VkMGzLPMYjLahHBMfZ2nNUx0gUiTWcDI,293
|
|
6
7
|
jaxion/hydro.py,sha256=KoJ02tRpAc4V3Ofzw4zbHLRaE2GdIatbOBE04_LsSRw,6980
|
|
7
|
-
jaxion/params_default.json,sha256=9CJrhEPsv5zGEs7_WqFyuccCDipPCDhXgKzVdqOsOWE,2775
|
|
8
8
|
jaxion/particles.py,sha256=pMopGvoZ0J_3EviD0WnTMmiebU9h2_8IO-p6I-E5DEU,3980
|
|
9
9
|
jaxion/quantum.py,sha256=GWOpN6ipfEw-6Ah2zQpxS3oqeSt_iHMDSgnVYSjXY5E,3321
|
|
10
|
-
jaxion/simulation.py,sha256=
|
|
11
|
-
jaxion/utils.py,sha256=
|
|
10
|
+
jaxion/simulation.py,sha256=rxncry9xPAVpohGkLqSo9vft-qaGTnh4f-wKk-A-Y2A,18496
|
|
11
|
+
jaxion/utils.py,sha256=OCpJ3crZqr5VFacymYzi5BkRqFCVBcneoS44wV9mZPg,4087
|
|
12
12
|
jaxion/visualization.py,sha256=Vx_xuEE7BB1AkEGaY9KHNSIIDxJzUVT104o-3uglW8o,2966
|
|
13
|
-
jaxion-0.0.
|
|
14
|
-
jaxion-0.0.
|
|
15
|
-
jaxion-0.0.
|
|
16
|
-
jaxion-0.0.
|
|
17
|
-
jaxion-0.0.
|
|
13
|
+
jaxion-0.0.6.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
14
|
+
jaxion-0.0.6.dist-info/METADATA,sha256=uNykshiiqSIPVw2UM44_zREemfnkaJlGdVqZKyi7X-8,6440
|
|
15
|
+
jaxion-0.0.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
16
|
+
jaxion-0.0.6.dist-info/top_level.txt,sha256=S1OV2VdlDG_9UwpKOIji4itQGOS-VWUOWUi3GeXWzt0,7
|
|
17
|
+
jaxion-0.0.6.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|