jaxion 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- jaxion/__init__.py +1 -0
- jaxion/analysis.py +43 -0
- jaxion/cosmology.py +74 -0
- jaxion/gravity.py +3 -2
- jaxion/hydro.py +5 -4
- jaxion/params_default.json +52 -15
- jaxion/particles.py +5 -4
- jaxion/quantum.py +82 -2
- jaxion/simulation.py +145 -38
- jaxion/utils.py +27 -0
- jaxion/visualization.py +73 -57
- {jaxion-0.0.2.dist-info → jaxion-0.0.3.dist-info}/METADATA +23 -3
- jaxion-0.0.3.dist-info/RECORD +17 -0
- jaxion-0.0.2.dist-info/RECORD +0 -15
- {jaxion-0.0.2.dist-info → jaxion-0.0.3.dist-info}/WHEEL +0 -0
- {jaxion-0.0.2.dist-info → jaxion-0.0.3.dist-info}/licenses/LICENSE +0 -0
- {jaxion-0.0.2.dist-info → jaxion-0.0.3.dist-info}/top_level.txt +0 -0
jaxion/__init__.py
CHANGED
jaxion/analysis.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
import jaxdecomp as jd
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def radial_power_spectrum(data_cube, kx, ky, kz, box_size):
    """
    Computes the radially (shell) averaged power spectral density of a 3D field.

    data_cube: jnp.ndarray (3D, must be cubic, shape (nx, nx, nx))
    kx, ky, kz: wavenumber grids in physical units (k = 2 pi n / box_size),
        laid out to match the output of jd.fft.pfft3d -- presumably the grids
        returned by Simulation.kgrid; confirm for other callers.
    box_size: float (physical size of box)
    Returns: Pf (radial power spectrum), k (wavenumbers), total_power

    Shell i collects the Fourier modes whose |k| / dk (dk = 2 pi / box_size)
    lies in the bin centred on i, for i = 0 .. nx // 2.  Modes beyond the
    last shell (the "corners" of the cube in k-space) are not binned into Pf
    but are included in total_power.
    """
    if data_cube.ndim != 3 or len(set(data_cube.shape)) != 1:
        raise ValueError("data_cube must be a cubic 3D array.")

    dim = data_cube.ndim
    nx = data_cube.shape[0]
    dx = box_size / nx
    dk = 2.0 * jnp.pi / box_size
    half_size = nx // 2 + 1

    # Power per Fourier mode
    data_cube_hat = jd.fft.pfft3d(data_cube)
    phi_k = 0.5 * jnp.abs(data_cube_hat) ** 2 / nx**dim * dx**dim
    total_power = jnp.sum(phi_k)

    # |k| measured in units of the fundamental mode dk, so that the integer
    # bin edges below select shells regardless of box_size.
    k_r = jnp.sqrt(kx**2 + ky**2 + kz**2) / dk

    bin_range = (-0.5, half_size - 0.5)
    Pf, _ = jnp.histogram(k_r, range=bin_range, bins=half_size, weights=phi_k)
    norm, _ = jnp.histogram(k_r, range=bin_range, bins=half_size)
    # Average over the modes in each shell; empty shells stay 0 (no 0/0).
    Pf = Pf / (norm + (norm == 0))

    k = dk * jnp.arange(half_size)
    Pf = Pf / dk**dim

    # Geometrical factor: area of the spherical shell of radius k.
    Pf = Pf * 4.0 * jnp.pi * k**2

    return Pf, k, total_power
|
jaxion/cosmology.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
import jax.numpy as jnp
|
|
2
|
+
|
|
3
|
+
# Pure functions for cosmology simulation
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def get_physical_time_interval(z_start, z_end, omega_matter, omega_lambda, little_h):
    """Return the physical (cosmic) time elapsed between two redshifts.

    Integrates dt = da / (da/dt) with the trapezoidal rule on 10000 uniformly
    spaced samples of the scale factor a, using the expansion rate
        da/dt = H0 * sqrt(omega_matter / a + omega_lambda * a^2)
    (matter + Lambda only; no curvature or radiation term).

    Args:
        z_start: starting redshift.
        z_end: ending redshift.
        omega_matter: matter density parameter.
        omega_lambda: dark-energy density parameter.
        little_h: dimensionless Hubble parameter (H0 = 100 h km/s/Mpc).

    Returns:
        The elapsed time in kpc / (km/s); negative when z_end > z_start.
    """
    hubble0 = 0.1 * little_h  # 100 h km/s/Mpc == 0.1 h km/s/kpc
    num_samples = 10000
    a_samples = jnp.linspace(1.0 / (1.0 + z_start), 1.0 / (1.0 + z_end), num_samples)
    expansion_rate = hubble0 * jnp.sqrt(
        (omega_matter / a_samples) + (omega_lambda * a_samples**2)
    )
    elapsed_time = jnp.trapezoid(1.0 / expansion_rate, a_samples)
    return elapsed_time
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_supercomoving_time_interval(
    z_start, z_end, omega_matter, omega_lambda, little_h
):
    """Return the supercomoving time elapsed between two redshifts.

    Supercomoving time is defined by d(t_hat) = a^-2 dt, hence
        t_hat = integral of a^-2 / (da/dt) da,
    with da/dt = H0 * sqrt(omega_matter / a + omega_lambda * a^2).  The
    integral is evaluated with the trapezoidal rule on 10000 uniformly spaced
    samples of the scale factor a.

    Args:
        z_start: starting redshift.
        z_end: ending redshift.
        omega_matter: matter density parameter.
        omega_lambda: dark-energy density parameter.
        little_h: dimensionless Hubble parameter (H0 = 100 h km/s/Mpc).

    Returns:
        The supercomoving time interval; negative when z_end > z_start.
    """
    hubble0 = 0.1 * little_h  # 100 h km/s/Mpc == 0.1 h km/s/kpc
    num_samples = 10000
    a_samples = jnp.linspace(1.0 / (1.0 + z_start), 1.0 / (1.0 + z_end), num_samples)
    expansion_rate = hubble0 * jnp.sqrt(
        (omega_matter / a_samples) + (omega_lambda * a_samples**2)
    )
    integrand = a_samples**-2 / expansion_rate  # dt_hat / da
    return jnp.trapezoid(integrand, a_samples)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def get_scale_factor(
    z_start,
    dt_hat,
    omega_matter,
    omega_lambda,
    little_h,
    tolerance=1e-6,
    max_iterations=100,
    upper_bound=2.0,
):
    """Compute the scale factor corresponding to a given supercomoving time,
    by root finding.

    Finds a such that get_supercomoving_time_interval(z_start, 1/a - 1, ...)
    equals dt_hat, using bisection on the bracket [0, upper_bound] with the
    first guess at a = 1 / (1 + z_start).

    Args:
        z_start: starting redshift.
        dt_hat: supercomoving time elapsed since z_start.
        omega_matter, omega_lambda, little_h: cosmological parameters.
        tolerance: stop once |t_hat(a) - dt_hat| < tolerance (absolute, in
            units of dt_hat).  Note that with float32 arrays this may be finer
            than the achievable precision, in which case all iterations run.
        max_iterations: maximum number of bisection steps.
        upper_bound: upper end of the bisection bracket for a.  If the target
            lies beyond it, the result converges towards upper_bound, so raise
            it when evolving far into the future.

    Returns:
        The scale factor a.

    Note: uses Python control flow on array values, so it must be called
    eagerly (not inside jax.jit).
    """
    a = 1.0 / (1.0 + z_start)
    lower_bound = 0.0
    for _ in range(max_iterations):
        dt_hat_guess = get_supercomoving_time_interval(
            z_start, 1.0 / a - 1.0, omega_matter, omega_lambda, little_h
        )
        error = dt_hat_guess - dt_hat
        if jnp.abs(error) < tolerance:
            break
        # With positive density parameters t_hat grows with a, so an
        # overshoot means the root lies below the current guess.
        if error > 0:
            upper_bound = a
        else:
            lower_bound = a
        a = (lower_bound + upper_bound) / 2  # midpoint of the new bracket
    return a
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def get_next_scale_factor(z_start, dt_hat, omega_matter, omega_lambda, little_h):
    """Advance the scale factor by one supercomoving time step dt_hat.

    Integrates da/dt_hat = a^2 * da/dt, where
        da/dt = H0 * sqrt(omega_matter / a + omega_lambda * a^2),
    with a single step of the second-order midpoint (RK2) method starting from
    a = 1 / (1 + z_start).

    Returns:
        The scale factor after the step.
    """
    hubble0 = 0.1 * little_h

    def da_dt_hat(scale):
        da_dt = hubble0 * jnp.sqrt((omega_matter / scale) + (omega_lambda * scale**2))
        return scale**2 * da_dt

    a_initial = 1.0 / (1.0 + z_start)
    slope_start = da_dt_hat(a_initial)
    slope_mid = da_dt_hat(a_initial + 0.5 * dt_hat * slope_start)
    return a_initial + dt_hat * slope_mid
|
jaxion/gravity.py
CHANGED
|
@@ -1,9 +1,10 @@
|
|
|
1
1
|
import jax.numpy as jnp
|
|
2
|
+
import jaxdecomp as jd
|
|
2
3
|
|
|
3
4
|
# Pure functions for gravity calculations
|
|
4
5
|
|
|
5
6
|
|
|
6
7
|
def calculate_gravitational_potential(rho, k_sq, G, rho_bar):
    """Solve the periodic Poisson equation  nabla^2 V = 4 pi G (rho - rho_bar).

    In Fourier space V_hat = -source_hat / k^2.  k_sq must match the layout of
    jd.fft.pfft3d output.  The k = 0 mode is divided by 1 instead of 0; if
    rho_bar is the mean of rho that mode of the source is zero, so V has zero
    mean.

    Returns:
        The real-space potential V.
    """
    source = 4.0 * jnp.pi * G * (rho - rho_bar)
    safe_k_sq = k_sq + (k_sq == 0)  # avoid division by zero at k = 0
    V_hat = -jd.fft.pfft3d(source) / safe_k_sq
    return jnp.real(jd.fft.pifft3d(V_hat))
|
jaxion/hydro.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import jax.numpy as jnp
|
|
2
|
+
import jaxdecomp as jd
|
|
2
3
|
|
|
3
4
|
# Pure functions for hydro simulation
|
|
4
5
|
|
|
@@ -89,11 +90,11 @@ def get_flux(rho_L, vx_L, vy_L, vz_L, rho_R, vx_R, vy_R, vz_R, cs):
|
|
|
89
90
|
|
|
90
91
|
|
|
91
92
|
def hydro_accelerate(vx, vy, vz, V, kx, ky, kz, dt):
|
|
92
|
-
V_hat =
|
|
93
|
+
V_hat = jd.fft.pfft3d(V)
|
|
93
94
|
|
|
94
|
-
ax = -jnp.real(
|
|
95
|
-
ay = -jnp.real(
|
|
96
|
-
az = -jnp.real(
|
|
95
|
+
ax = -jnp.real(jd.fft.pifft3d(1.0j * kx * V_hat))
|
|
96
|
+
ay = -jnp.real(jd.fft.pifft3d(1.0j * ky * V_hat))
|
|
97
|
+
az = -jnp.real(jd.fft.pifft3d(1.0j * kz * V_hat))
|
|
97
98
|
|
|
98
99
|
vx += ax * dt
|
|
99
100
|
vy += ay * dt
|
jaxion/params_default.json
CHANGED
|
@@ -28,7 +28,7 @@
|
|
|
28
28
|
"domain": {
|
|
29
29
|
"box_size": {
|
|
30
30
|
"default": 10.0,
|
|
31
|
-
"description": "periodic domain box size
|
|
31
|
+
"description": "periodic domain box size [kpc]."
|
|
32
32
|
},
|
|
33
33
|
"resolution_base": {
|
|
34
34
|
"default": 32,
|
|
@@ -40,38 +40,75 @@
|
|
|
40
40
|
}
|
|
41
41
|
},
|
|
42
42
|
"time": {
|
|
43
|
-
"start":
|
|
44
|
-
|
|
45
|
-
|
|
43
|
+
"start": {
|
|
44
|
+
"default": 0.0,
|
|
45
|
+
"description": "simulation start time [kpc/(km/s)] or [redshift] (cosmology=true)."
|
|
46
|
+
},
|
|
47
|
+
"end": {
|
|
48
|
+
"default": 1.0,
|
|
49
|
+
"description": "simulation end time [kpc/(km/s)] or [redshift] (cosmology=true)."
|
|
50
|
+
},
|
|
51
|
+
"adaptive": {
|
|
52
|
+
"default": false,
|
|
53
|
+
"description": "switch on for adaptive time stepping."
|
|
54
|
+
}
|
|
46
55
|
},
|
|
47
56
|
"output": {
|
|
48
|
-
"path":
|
|
49
|
-
|
|
50
|
-
|
|
51
|
-
|
|
57
|
+
"path": {
|
|
58
|
+
"default": "./checkpoints",
|
|
59
|
+
"description": "path to output directory."
|
|
60
|
+
},
|
|
61
|
+
"num_checkpoints": {
|
|
62
|
+
"default": 100,
|
|
63
|
+
"description": "number of checkpoints to save."
|
|
64
|
+
},
|
|
65
|
+
"save": {
|
|
66
|
+
"default": true,
|
|
67
|
+
"description": "switch on to save checkpoints."
|
|
68
|
+
},
|
|
69
|
+
"plot_dynamic_range": {
|
|
70
|
+
"default": 100.0,
|
|
71
|
+
"description": "dynamic range for plotting."
|
|
72
|
+
}
|
|
52
73
|
},
|
|
53
74
|
"quantum": {
|
|
54
75
|
"m_22": {
|
|
55
76
|
"default": 1.0,
|
|
56
|
-
"description": "axion mass
|
|
77
|
+
"description": "axion mass [10^{-22} eV]."
|
|
57
78
|
}
|
|
58
79
|
},
|
|
59
80
|
"hydro": {
|
|
60
81
|
"sound_speed": {
|
|
61
82
|
"default": 1.0,
|
|
62
|
-
"description": "
|
|
83
|
+
"description": "isothermal sound speed [km/s]."
|
|
63
84
|
}
|
|
64
85
|
},
|
|
65
86
|
"particles": {
|
|
66
|
-
"num_particles":
|
|
67
|
-
|
|
87
|
+
"num_particles": {
|
|
88
|
+
"default": 0,
|
|
89
|
+
"description": "number of particles."
|
|
90
|
+
},
|
|
91
|
+
"particle_mass": {
|
|
92
|
+
"default": 1.0,
|
|
93
|
+
"description": "particle mass [M_sun]."
|
|
94
|
+
}
|
|
68
95
|
},
|
|
69
96
|
"cosmology": {
|
|
70
|
-
"
|
|
71
|
-
|
|
97
|
+
"omega_matter": {
|
|
98
|
+
"default": 0.3,
|
|
99
|
+
"description": "matter density parameter."
|
|
100
|
+
},
|
|
101
|
+
"omega_lambda": {
|
|
102
|
+
"default": 0.7,
|
|
103
|
+
"description": "dark energy density parameter."
|
|
104
|
+
},
|
|
105
|
+
"little_h": {
|
|
106
|
+
"default": 0.7,
|
|
107
|
+
"description": "Hubble parameter little h (H_0 = 100 h km/s/Mpc)."
|
|
108
|
+
}
|
|
72
109
|
},
|
|
73
110
|
"version": {
|
|
74
111
|
"default": "unknown",
|
|
75
|
-
"description": "
|
|
112
|
+
"description": "jaxion version used (auto detected)."
|
|
76
113
|
}
|
|
77
114
|
}
|
jaxion/particles.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import jax
|
|
2
2
|
import jax.numpy as jnp
|
|
3
|
+
import jaxdecomp as jd
|
|
3
4
|
|
|
4
5
|
# Pure functions for particle-mesh calculations (stars, BHs, ...)
|
|
5
6
|
|
|
@@ -64,10 +65,10 @@ def get_acceleration(pos, V, kx, ky, kz, dx):
|
|
|
64
65
|
i, ip1, w_i, w_ip1 = get_cic_indices_and_weights(pos, dx, resolution)
|
|
65
66
|
|
|
66
67
|
# find accelerations on the grid
|
|
67
|
-
V_hat =
|
|
68
|
-
ax = -jnp.real(
|
|
69
|
-
ay = -jnp.real(
|
|
70
|
-
az = -jnp.real(
|
|
68
|
+
V_hat = jd.fft.pfft3d(V)
|
|
69
|
+
ax = -jnp.real(jd.fft.pifft3d(1.0j * kx * V_hat))
|
|
70
|
+
ay = -jnp.real(jd.fft.pifft3d(1.0j * ky * V_hat))
|
|
71
|
+
az = -jnp.real(jd.fft.pifft3d(1.0j * kz * V_hat))
|
|
71
72
|
a_grid = jnp.stack((ax, ay, az), axis=-1)
|
|
72
73
|
|
|
73
74
|
# interpolate the accelerations to the particle positions
|
jaxion/quantum.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import jax.numpy as jnp
|
|
2
|
+
import jaxdecomp as jd
|
|
2
3
|
|
|
3
4
|
# Pure functions for quantum simulation
|
|
4
5
|
|
|
@@ -9,7 +10,86 @@ def quantum_kick(psi, V, m_per_hbar, dt):
|
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
def quantum_drift(psi, k_sq, m_per_hbar, dt):
    """Apply the kinetic ("drift") step of the spectral split-step solver.

    Multiplies the Fourier transform of psi by exp(-i dt k^2 / (2 m/hbar)) and
    transforms back to real space.  k_sq must match the layout of
    jd.fft.pfft3d output.
    """
    propagator = jnp.exp(dt * (-1.0j * k_sq / m_per_hbar / 2.0))
    return jd.fft.pifft3d(propagator * jd.fft.pfft3d(psi))
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def get_gradient(psi, kx, ky, kz):
    """
    Computes the spectral gradient of the wavefunction psi.

    psi: jnp.ndarray (3D)
    kx, ky, kz: wavenumber grids matching the layout of jd.fft.pfft3d(psi)
    Returns: grad_psi_x, grad_psi_y, grad_psi_z (complex, in real space)
    """
    psi_hat = jd.fft.pfft3d(psi)
    # d/dx_j  <->  multiplication by i k_j in Fourier space
    return tuple(jd.fft.pifft3d(1.0j * k_comp * psi_hat) for k_comp in (kx, ky, kz))
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def quantum_velocity(psi, box_size, m_per_hbar):
    """
    Compute the velocity from the wave-function.

    With psi = sqrt(rho) exp(i S / hbar), the velocity is v = nabla S / m.
    The phase S / hbar = angle(psi) is only known modulo 2 pi, so each
    component is the average of a forward and a backward periodic finite
    difference, each wrapped into (-pi, pi].

    psi: jnp.ndarray (3D, cubic; grid spacing dx = box_size / psi.shape[0])
    box_size: float (physical size of box)
    m_per_hbar: float (m / hbar)
    Returns: vx, vy, vz
    """
    dx = box_size / psi.shape[0]
    S_per_hbar = jnp.angle(psi)

    def _wrap(dphase):
        # Differences of two angles in (-pi, pi] lie in (-2 pi, 2 pi);
        # shift by 2 pi to bring them into (-pi, pi].
        dphase = jnp.where(dphase > jnp.pi, dphase - 2 * jnp.pi, dphase)
        return jnp.where(dphase <= -jnp.pi, dphase + 2 * jnp.pi, dphase)

    def _component(axis):
        # Average of forward and backward one-sided derivatives along `axis`.
        forward = _wrap(jnp.roll(S_per_hbar, -1, axis=axis) - S_per_hbar)
        backward = _wrap(S_per_hbar - jnp.roll(S_per_hbar, 1, axis=axis))
        forward = forward / dx / m_per_hbar
        backward = backward / dx / m_per_hbar
        return 0.5 * (forward + backward)

    vx = _component(0)
    vy = _component(1)
    vz = _component(2)
    return vx, vy, vz
|
jaxion/simulation.py
CHANGED
|
@@ -6,11 +6,19 @@ import json
|
|
|
6
6
|
import time
|
|
7
7
|
|
|
8
8
|
from .constants import constants
|
|
9
|
-
from .quantum import quantum_kick, quantum_drift
|
|
9
|
+
from .quantum import quantum_kick, quantum_drift, quantum_velocity
|
|
10
10
|
from .gravity import calculate_gravitational_potential
|
|
11
11
|
from .hydro import hydro_fluxes, hydro_accelerate
|
|
12
12
|
from .particles import particles_accelerate, particles_drift, bin_particles
|
|
13
|
-
from .
|
|
13
|
+
from .cosmology import get_supercomoving_time_interval, get_next_scale_factor
|
|
14
|
+
from .utils import (
|
|
15
|
+
set_up_parameters,
|
|
16
|
+
print_parameters,
|
|
17
|
+
xmeshgrid,
|
|
18
|
+
xmeshgrid_transpose,
|
|
19
|
+
xzeros,
|
|
20
|
+
xones,
|
|
21
|
+
)
|
|
14
22
|
from .visualization import plot_sim
|
|
15
23
|
|
|
16
24
|
|
|
@@ -21,10 +29,20 @@ class Simulation:
|
|
|
21
29
|
Parameters
|
|
22
30
|
----------
|
|
23
31
|
params (dict): The Python dictionary that contains the simulation parameters.
|
|
32
|
+
Params can also be a string path to a checkpoint directory to load a saved simulation.
|
|
24
33
|
|
|
25
34
|
"""
|
|
26
35
|
|
|
27
|
-
def __init__(self, params):
|
|
36
|
+
def __init__(self, params, sharding=None):
|
|
37
|
+
# allow loading directly from a checkpoint path
|
|
38
|
+
load_from_checkpoint = False
|
|
39
|
+
checkpoint_dir = ""
|
|
40
|
+
if isinstance(params, str):
|
|
41
|
+
load_from_checkpoint = True
|
|
42
|
+
checkpoint_dir = os.path.join(os.getcwd(), params)
|
|
43
|
+
with open(os.path.join(checkpoint_dir, "params.json"), "r") as f:
|
|
44
|
+
params = json.load(f)
|
|
45
|
+
|
|
28
46
|
# start from default simulation parameters and update with user params
|
|
29
47
|
self._params = set_up_parameters(params)
|
|
30
48
|
|
|
@@ -32,38 +50,80 @@ class Simulation:
|
|
|
32
50
|
if self.resolution % 2 != 0:
|
|
33
51
|
raise ValueError("Resolution must be divisible by 2.")
|
|
34
52
|
|
|
53
|
+
if self.params["time"]["adaptive"]:
|
|
54
|
+
raise NotImplementedError("Adaptive time stepping is not yet implemented.")
|
|
55
|
+
|
|
56
|
+
if self.params["physics"]["cosmology"]:
|
|
57
|
+
if (
|
|
58
|
+
self.params["physics"]["hydro"]
|
|
59
|
+
or self.params["physics"]["particles"]
|
|
60
|
+
or self.params["physics"]["external_potential"]
|
|
61
|
+
):
|
|
62
|
+
raise NotImplementedError(
|
|
63
|
+
"Cosmological hydro/particles/external_potential physics is not yet implemented."
|
|
64
|
+
)
|
|
65
|
+
|
|
66
|
+
if self.params["physics"]["hydro"] or self.params["physics"]["particles"]:
|
|
67
|
+
if sharding is not None:
|
|
68
|
+
raise NotImplementedError(
|
|
69
|
+
"hydro/particles sharding is not yet implemented."
|
|
70
|
+
)
|
|
71
|
+
|
|
35
72
|
if self.params["output"]["save"]:
|
|
36
|
-
|
|
37
|
-
|
|
73
|
+
if jax.process_index() == 0:
|
|
74
|
+
print("Simulation parameters:")
|
|
75
|
+
print_parameters(self.params)
|
|
76
|
+
|
|
77
|
+
# jitted functions
|
|
78
|
+
self.xmeshgrid_jit = jax.jit(
|
|
79
|
+
xmeshgrid, in_shardings=None, out_shardings=sharding
|
|
80
|
+
)
|
|
81
|
+
self.xmeshgrid_transpose_jit = jax.jit(
|
|
82
|
+
xmeshgrid_transpose, in_shardings=None, out_shardings=sharding
|
|
83
|
+
)
|
|
84
|
+
self.xzeros_jit = jax.jit(
|
|
85
|
+
xzeros, static_argnums=0, in_shardings=None, out_shardings=sharding
|
|
86
|
+
)
|
|
87
|
+
self.xones_jit = jax.jit(
|
|
88
|
+
xones, static_argnums=0, in_shardings=None, out_shardings=sharding
|
|
89
|
+
)
|
|
38
90
|
|
|
39
91
|
# simulation state
|
|
40
92
|
self.state = {}
|
|
41
93
|
self.state["t"] = 0.0
|
|
94
|
+
if self.params["physics"]["cosmology"]:
|
|
95
|
+
self.state["redshift"] = 0.0
|
|
42
96
|
if self.params["physics"]["quantum"]:
|
|
43
|
-
self.state["psi"] = (
|
|
44
|
-
jnp.zeros((self.resolution, self.resolution, self.resolution)) * 1j
|
|
45
|
-
)
|
|
97
|
+
self.state["psi"] = self.xzeros_jit(self.resolution) * 1j
|
|
46
98
|
if self.params["physics"]["external_potential"]:
|
|
47
|
-
self.state["V_ext"] =
|
|
48
|
-
(self.resolution, self.resolution, self.resolution)
|
|
49
|
-
)
|
|
99
|
+
self.state["V_ext"] = self.xzeros_jit(self.resolution)
|
|
50
100
|
if self.params["physics"]["hydro"]:
|
|
51
101
|
self.state["rho"] = jnp.zeros(
|
|
52
|
-
(self.resolution, self.resolution, self.resolution)
|
|
102
|
+
(self.resolution, self.resolution, self.resolution),
|
|
53
103
|
)
|
|
54
104
|
self.state["vx"] = jnp.zeros(
|
|
55
|
-
(self.resolution, self.resolution, self.resolution)
|
|
105
|
+
(self.resolution, self.resolution, self.resolution),
|
|
56
106
|
)
|
|
57
107
|
self.state["vy"] = jnp.zeros(
|
|
58
|
-
(self.resolution, self.resolution, self.resolution)
|
|
108
|
+
(self.resolution, self.resolution, self.resolution),
|
|
59
109
|
)
|
|
60
110
|
self.state["vz"] = jnp.zeros(
|
|
61
|
-
(self.resolution, self.resolution, self.resolution)
|
|
111
|
+
(self.resolution, self.resolution, self.resolution),
|
|
62
112
|
)
|
|
63
113
|
if self.params["physics"]["particles"]:
|
|
64
114
|
self.state["pos"] = jnp.zeros((self.num_particles, 3))
|
|
65
115
|
self.state["vel"] = jnp.zeros((self.num_particles, 3))
|
|
66
116
|
|
|
117
|
+
if load_from_checkpoint:
|
|
118
|
+
options = ocp.CheckpointManagerOptions()
|
|
119
|
+
async_checkpoint_manager = ocp.CheckpointManager(
|
|
120
|
+
checkpoint_dir, options=options
|
|
121
|
+
)
|
|
122
|
+
step = async_checkpoint_manager.latest_step()
|
|
123
|
+
self.state = async_checkpoint_manager.restore(
|
|
124
|
+
step, args=ocp.args.StandardRestore(self.state)
|
|
125
|
+
)
|
|
126
|
+
|
|
67
127
|
@property
|
|
68
128
|
def resolution(self):
|
|
69
129
|
return (
|
|
@@ -104,19 +164,24 @@ class Simulation:
|
|
|
104
164
|
def grid(self):
|
|
105
165
|
hx = 0.5 * self.dx
|
|
106
166
|
x_lin = jnp.linspace(hx, self.box_size - hx, self.resolution)
|
|
107
|
-
|
|
108
|
-
return
|
|
167
|
+
xx, yy, zz = self.xmeshgrid_jit(x_lin)
|
|
168
|
+
return xx, yy, zz
|
|
109
169
|
|
|
110
170
|
@property
|
|
111
171
|
def kgrid(self):
|
|
112
172
|
nx = self.resolution
|
|
113
173
|
k_lin = (2.0 * jnp.pi / self.box_size) * jnp.arange(-nx / 2, nx / 2)
|
|
114
|
-
kx, ky, kz =
|
|
174
|
+
kx, ky, kz = self.xmeshgrid_transpose_jit(k_lin)
|
|
115
175
|
kx = jnp.fft.ifftshift(kx)
|
|
116
176
|
ky = jnp.fft.ifftshift(ky)
|
|
117
177
|
kz = jnp.fft.ifftshift(kz)
|
|
118
178
|
return kx, ky, kz
|
|
119
179
|
|
|
180
|
+
@property
|
|
181
|
+
def quantum_velocity(self):
|
|
182
|
+
m_per_hbar = self.axion_mass / constants["reduced_planck_constant"]
|
|
183
|
+
return quantum_velocity(self.state["psi"], self.box_size, m_per_hbar)
|
|
184
|
+
|
|
120
185
|
def _calc_rho_bar(self, state):
|
|
121
186
|
rho_bar = 0.0
|
|
122
187
|
if self.params["physics"]["quantum"]:
|
|
@@ -133,7 +198,7 @@ class Simulation:
|
|
|
133
198
|
def _calc_grav_potential(self, state, k_sq):
|
|
134
199
|
G = constants["gravitational_constant"]
|
|
135
200
|
m_particle = self.params["particles"]["particle_mass"]
|
|
136
|
-
rho_bar = self._calc_rho_bar(
|
|
201
|
+
rho_bar = self._calc_rho_bar(state)
|
|
137
202
|
rho_tot = 0.0
|
|
138
203
|
if self.params["physics"]["quantum"]:
|
|
139
204
|
rho_tot += jnp.abs(state["psi"]) ** 2
|
|
@@ -141,6 +206,10 @@ class Simulation:
|
|
|
141
206
|
rho_tot += state["rho"]
|
|
142
207
|
if self.params["physics"]["particles"]:
|
|
143
208
|
rho_tot += bin_particles(state["pos"], self.dx, self.resolution, m_particle)
|
|
209
|
+
if self.params["physics"]["cosmology"]:
|
|
210
|
+
scale_factor = 1.0 / (1.0 + state["redshift"])
|
|
211
|
+
rho_bar *= scale_factor
|
|
212
|
+
rho_tot *= scale_factor
|
|
144
213
|
return calculate_gravitational_potential(rho_tot, k_sq, G, rho_bar)
|
|
145
214
|
|
|
146
215
|
@property
|
|
@@ -167,19 +236,40 @@ class Simulation:
|
|
|
167
236
|
# Simulation parameters
|
|
168
237
|
dx = self.dx
|
|
169
238
|
m_per_hbar = self.axion_mass / constants["reduced_planck_constant"]
|
|
239
|
+
box_size = self.box_size
|
|
170
240
|
|
|
171
241
|
dt_fac = 1.0
|
|
172
242
|
dt_kin = dt_fac * (m_per_hbar / 6.0) * (dx * dx)
|
|
243
|
+
t_start = self.params["time"]["start"]
|
|
173
244
|
t_end = self.params["time"]["end"]
|
|
245
|
+
t_span = t_end - t_start
|
|
246
|
+
state["t"] = t_start
|
|
247
|
+
|
|
248
|
+
# cosmology
|
|
249
|
+
if self.params["physics"]["cosmology"]:
|
|
250
|
+
z_start = self.params["time"]["start"]
|
|
251
|
+
z_end = self.params["time"]["end"]
|
|
252
|
+
omega_matter = self.params["cosmology"]["omega_matter"]
|
|
253
|
+
omega_lambda = self.params["cosmology"]["omega_lambda"]
|
|
254
|
+
little_h = self.params["cosmology"]["little_h"]
|
|
255
|
+
t_span = get_supercomoving_time_interval(
|
|
256
|
+
z_start, z_end, omega_matter, omega_lambda, little_h
|
|
257
|
+
)
|
|
258
|
+
state["t"] = 0.0
|
|
259
|
+
state["redshift"] = z_start
|
|
174
260
|
|
|
261
|
+
# hydro
|
|
175
262
|
c_sound = self.params["hydro"]["sound_speed"]
|
|
176
|
-
box_size = self.box_size
|
|
177
263
|
|
|
178
264
|
# round up to the nearest multiple of num_checkpoints
|
|
179
265
|
num_checkpoints = self.params["output"]["num_checkpoints"]
|
|
180
|
-
nt = int(round(round(
|
|
266
|
+
nt = int(round(round(t_span / dt_kin) / num_checkpoints) * num_checkpoints)
|
|
181
267
|
nt_sub = int(round(nt / num_checkpoints))
|
|
182
|
-
dt =
|
|
268
|
+
dt = t_span / nt
|
|
269
|
+
|
|
270
|
+
# distributed arrays (fixed) needed for calculations
|
|
271
|
+
kx, ky, kz = None, None, None
|
|
272
|
+
k_sq = None
|
|
183
273
|
|
|
184
274
|
# Fourier space variables
|
|
185
275
|
if self.params["physics"]["gravity"] or self.params["physics"]["quantum"]:
|
|
@@ -192,10 +282,14 @@ class Simulation:
|
|
|
192
282
|
checkpoint_dir = checkpoint_dir = os.path.join(
|
|
193
283
|
os.getcwd(), self.params["output"]["path"]
|
|
194
284
|
)
|
|
195
|
-
path =
|
|
285
|
+
path = os.path.join(os.getcwd(), checkpoint_dir)
|
|
286
|
+
if jax.process_index() == 0:
|
|
287
|
+
path = ocp.test_utils.erase_and_create_empty(checkpoint_dir)
|
|
196
288
|
async_checkpoint_manager = ocp.CheckpointManager(path, options=options)
|
|
197
289
|
|
|
198
|
-
|
|
290
|
+
carry = (state, kx, ky, kz, k_sq)
|
|
291
|
+
|
|
292
|
+
def _kick(state, kx, ky, kz, k_sq, dt):
|
|
199
293
|
# Kick (half-step)
|
|
200
294
|
if (
|
|
201
295
|
self.params["physics"]["gravity"]
|
|
@@ -222,7 +316,7 @@ class Simulation:
|
|
|
222
316
|
state["vel"], state["pos"], V, kx, ky, kz, dx, dt
|
|
223
317
|
)
|
|
224
318
|
|
|
225
|
-
def _drift(state, dt):
|
|
319
|
+
def _drift(state, k_sq, dt):
|
|
226
320
|
# Drift (full-step)
|
|
227
321
|
if self.params["physics"]["quantum"]:
|
|
228
322
|
state["psi"] = quantum_drift(state["psi"], k_sq, m_per_hbar, dt)
|
|
@@ -234,20 +328,30 @@ class Simulation:
|
|
|
234
328
|
state["pos"] = particles_drift(state["pos"], state["vel"], dt, box_size)
|
|
235
329
|
|
|
236
330
|
@jax.jit
|
|
237
|
-
def _update(_,
|
|
331
|
+
def _update(_, carry):
|
|
238
332
|
# Update the simulation state by one timestep
|
|
239
333
|
# according to a 2nd-order `kick-drift-kick` scheme
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
244
|
-
# update time
|
|
334
|
+
state, kx, ky, kz, k_sq = carry
|
|
335
|
+
_kick(state, kx, ky, kz, k_sq, 0.5 * dt)
|
|
336
|
+
_drift(state, k_sq, dt)
|
|
337
|
+
# update time & redshift
|
|
245
338
|
state["t"] += dt
|
|
339
|
+
if self.params["physics"]["cosmology"]:
|
|
340
|
+
scale_factor = get_next_scale_factor(
|
|
341
|
+
state["redshift"],
|
|
342
|
+
dt,
|
|
343
|
+
self.params["cosmology"]["omega_matter"],
|
|
344
|
+
self.params["cosmology"]["omega_lambda"],
|
|
345
|
+
self.params["cosmology"]["little_h"],
|
|
346
|
+
)
|
|
347
|
+
state["redshift"] = 1.0 / scale_factor - 1.0
|
|
348
|
+
_kick(state, kx, ky, kz, k_sq, 0.5 * dt)
|
|
246
349
|
|
|
247
|
-
return state
|
|
350
|
+
return state, kx, ky, kz, k_sq
|
|
248
351
|
|
|
249
352
|
# save initial state
|
|
250
|
-
|
|
353
|
+
if jax.process_index() == 0:
|
|
354
|
+
print(f"Starting simulation (res={self.resolution}, nt={nt}) ...")
|
|
251
355
|
if self.params["output"]["save"]:
|
|
252
356
|
with open(os.path.join(checkpoint_dir, "params.json"), "w") as f:
|
|
253
357
|
json.dump(self.params, f, indent=2)
|
|
@@ -259,7 +363,8 @@ class Simulation:
|
|
|
259
363
|
t_start_timer = time.time()
|
|
260
364
|
if self.params["output"]["save"]:
|
|
261
365
|
for i in range(1, num_checkpoints + 1):
|
|
262
|
-
|
|
366
|
+
carry = jax.lax.fori_loop(0, nt_sub, _update, init_val=carry)
|
|
367
|
+
state, kx, ky, kz, k_sq = carry
|
|
263
368
|
jax.block_until_ready(state)
|
|
264
369
|
# save state
|
|
265
370
|
async_checkpoint_manager.save(i, args=ocp.args.StandardSave(state))
|
|
@@ -267,15 +372,17 @@ class Simulation:
|
|
|
267
372
|
elapsed = time.time() - t_start_timer
|
|
268
373
|
est_total = elapsed / i * num_checkpoints
|
|
269
374
|
est_remaining = est_total - elapsed
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
375
|
+
if jax.process_index() == 0:
|
|
376
|
+
print(
|
|
377
|
+
f"{percent:.1f}%: estimated time remaining (s): {est_remaining:.1f}"
|
|
378
|
+
)
|
|
273
379
|
plot_sim(state, checkpoint_dir, i, self.params)
|
|
274
380
|
async_checkpoint_manager.wait_until_finished()
|
|
275
381
|
else:
|
|
276
382
|
state = jax.lax.fori_loop(0, nt, _update, init_val=state)
|
|
277
383
|
jax.block_until_ready(state)
|
|
278
|
-
|
|
384
|
+
if jax.process_index() == 0:
|
|
385
|
+
print("Simulation Run Time (s): ", time.time() - t_start_timer)
|
|
279
386
|
|
|
280
387
|
return state
|
|
281
388
|
|
jaxion/utils.py
CHANGED
|
@@ -5,6 +5,7 @@ from pathlib import Path
|
|
|
5
5
|
import importlib.resources
|
|
6
6
|
import json
|
|
7
7
|
from importlib.metadata import version
|
|
8
|
+
import jax.numpy as jnp
|
|
8
9
|
|
|
9
10
|
|
|
10
11
|
def print_parameters(params):
|
|
@@ -90,3 +91,29 @@ def run_example_main(example_path, argv=None):
|
|
|
90
91
|
os.chdir(old_cwd)
|
|
91
92
|
sys.argv = old_argv
|
|
92
93
|
return result
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
# 'ij'-indexed 3D meshgrid; it becomes distributed when wrapped in jax.jit with out_shardings
|
|
97
|
+
def xmeshgrid(x_lin):
    """Return the three 'ij'-indexed 3D coordinate grids built from x_lin.

    For n = len(x_lin) each grid has shape (n, n, n), with
    xx[i, j, k] = x_lin[i], yy[i, j, k] = x_lin[j], zz[i, j, k] = x_lin[k].
    """
    grids = jnp.meshgrid(x_lin, x_lin, x_lin, indexing="ij")
    return tuple(grids)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
# NOTE: jaxdecomp's (jd) pfft3d transposes the axes (X, Y, Z) --> (Y, Z, X), and pifft3d undoes this,
|
|
103
|
+
# so Fourier-space variables (e.g. kx, ky, kz) must be transposed to match.
|
|
104
|
+
|
|
105
|
+
|
|
106
|
+
def xmeshgrid_transpose(x_lin):
    """Return xmeshgrid-style 'ij' grids with axes permuted (X, Y, Z) -> (Y, Z, X).

    The permutation (1, 2, 0) matches the axis layout of jaxdecomp's pfft3d
    output, so the results can be used as Fourier-space grids.
    """
    axis_order = (1, 2, 0)
    return tuple(
        jnp.transpose(grid, axis_order)
        for grid in jnp.meshgrid(x_lin, x_lin, x_lin, indexing="ij")
    )
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def xzeros(nx):
    """Return an (nx, nx, nx) array of zeros (default float dtype)."""
    cube_shape = (nx,) * 3
    return jnp.zeros(cube_shape)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def xones(nx):
    """Return an (nx, nx, nx) array of ones (default float dtype)."""
    cube_shape = (nx,) * 3
    return jnp.ones(cube_shape)
|
jaxion/visualization.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
1
|
+
import jax
|
|
1
2
|
import jax.numpy as jnp
|
|
2
3
|
import matplotlib.pyplot as plt
|
|
3
4
|
import os
|
|
@@ -8,67 +9,82 @@ def plot_sim(state, checkpoint_dir, i, params):
|
|
|
8
9
|
|
|
9
10
|
dynamic_range = params["output"]["plot_dynamic_range"]
|
|
10
11
|
|
|
12
|
+
# process distributed data
|
|
11
13
|
if params["physics"]["quantum"]:
|
|
12
|
-
plt.clf()
|
|
13
|
-
|
|
14
|
-
# DM projection
|
|
15
14
|
nx = state["psi"].shape[0]
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
|
|
15
|
+
rho_bar_dm = jnp.mean(jnp.abs(state["psi"]) ** 2)
|
|
16
|
+
rho_proj_dm = jax.experimental.multihost_utils.process_allgather(
|
|
17
|
+
jnp.log10(jnp.mean(jnp.abs(state["psi"]) ** 2, axis=2).T)
|
|
18
|
+
).reshape(nx, nx)
|
|
19
|
+
if params["physics"]["hydro"]:
|
|
20
|
+
nx = state["rho"].shape[0]
|
|
21
|
+
rho_bar_gas = jnp.mean(state["rho"])
|
|
22
|
+
rho_proj_gas = jax.experimental.multihost_utils.process_allgather(
|
|
23
|
+
jnp.log10(jnp.mean(jnp.abs(state["rho"]) ** 2, axis=2).T)
|
|
24
|
+
).reshape(nx, nx)
|
|
19
25
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
cmap="inferno",
|
|
25
|
-
origin="lower",
|
|
26
|
-
vmin=vmin,
|
|
27
|
-
vmax=vmax,
|
|
28
|
-
extent=[0, nx, 0, nx],
|
|
29
|
-
)
|
|
30
|
-
if params["physics"]["particles"]:
|
|
31
|
-
# draw particles
|
|
32
|
-
box_size = params["domain"]["box_size"]
|
|
33
|
-
sx = (state["pos"][:, 0] / box_size) * nx
|
|
34
|
-
sy = (state["pos"][:, 1] / box_size) * nx
|
|
35
|
-
plt.plot(sx, sy, color="cyan", marker=".", linestyle="None", markersize=2)
|
|
36
|
-
ax.set_aspect("equal")
|
|
37
|
-
ax.get_xaxis().set_visible(False)
|
|
38
|
-
ax.get_yaxis().set_visible(False)
|
|
26
|
+
# create plot on process 0
|
|
27
|
+
if jax.process_index() == 0:
|
|
28
|
+
if params["physics"]["quantum"]:
|
|
29
|
+
plt.clf()
|
|
39
30
|
|
|
40
|
-
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
)
|
|
45
|
-
plt.close()
|
|
31
|
+
# DM projection
|
|
32
|
+
nx = state["psi"].shape[0]
|
|
33
|
+
vmin = jnp.log10(rho_bar_dm / dynamic_range)
|
|
34
|
+
vmax = jnp.log10(rho_bar_dm * dynamic_range)
|
|
46
35
|
|
|
47
|
-
|
|
48
|
-
|
|
36
|
+
ax = plt.gca()
|
|
37
|
+
ax.imshow(
|
|
38
|
+
rho_proj_dm,
|
|
39
|
+
cmap="inferno",
|
|
40
|
+
origin="lower",
|
|
41
|
+
vmin=vmin,
|
|
42
|
+
vmax=vmax,
|
|
43
|
+
extent=[0, nx, 0, nx],
|
|
44
|
+
)
|
|
45
|
+
if params["physics"]["particles"]:
|
|
46
|
+
# draw particles
|
|
47
|
+
box_size = params["domain"]["box_size"]
|
|
48
|
+
sx = (state["pos"][:, 0] / box_size) * nx
|
|
49
|
+
sy = (state["pos"][:, 1] / box_size) * nx
|
|
50
|
+
plt.plot(
|
|
51
|
+
sx, sy, color="cyan", marker=".", linestyle="None", markersize=2
|
|
52
|
+
)
|
|
53
|
+
ax.set_aspect("equal")
|
|
54
|
+
ax.get_xaxis().set_visible(False)
|
|
55
|
+
ax.get_yaxis().set_visible(False)
|
|
49
56
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
57
|
+
plt.savefig(
|
|
58
|
+
os.path.join(checkpoint_dir, f"dm{i:03d}.png"),
|
|
59
|
+
bbox_inches="tight",
|
|
60
|
+
pad_inches=0,
|
|
61
|
+
)
|
|
62
|
+
plt.close()
|
|
63
|
+
|
|
64
|
+
if params["physics"]["hydro"]:
|
|
65
|
+
plt.clf()
|
|
66
|
+
|
|
67
|
+
# gas projection
|
|
68
|
+
nx = state["rho"].shape[0]
|
|
69
|
+
vmin = jnp.log10(rho_bar_gas / dynamic_range)
|
|
70
|
+
vmax = jnp.log10(rho_bar_gas * dynamic_range)
|
|
71
|
+
|
|
72
|
+
ax = plt.gca()
|
|
73
|
+
ax.imshow(
|
|
74
|
+
rho_proj_gas,
|
|
75
|
+
cmap="viridis",
|
|
76
|
+
origin="lower",
|
|
77
|
+
vmin=vmin,
|
|
78
|
+
vmax=vmax,
|
|
79
|
+
extent=[0, nx, 0, nx],
|
|
80
|
+
)
|
|
81
|
+
ax.set_aspect("equal")
|
|
82
|
+
ax.get_xaxis().set_visible(False)
|
|
83
|
+
ax.get_yaxis().set_visible(False)
|
|
68
84
|
|
|
69
|
-
|
|
70
|
-
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
|
|
85
|
+
plt.savefig(
|
|
86
|
+
os.path.join(checkpoint_dir, f"gas{i:03d}.png"),
|
|
87
|
+
bbox_inches="tight",
|
|
88
|
+
pad_inches=0,
|
|
89
|
+
)
|
|
90
|
+
plt.close()
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: jaxion
|
|
3
|
-
Version: 0.0.
|
|
3
|
+
Version: 0.0.3
|
|
4
4
|
Summary: A differentiable simulation library for fuzzy dark matter in JAX
|
|
5
5
|
Author-email: Philip Mocz <philip.mocz@gmail.com>
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -21,6 +21,12 @@ Provides-Extra: cuda12
|
|
|
21
21
|
Requires-Dist: jax[cuda12]==0.5.3; extra == "cuda12"
|
|
22
22
|
Dynamic: license-file
|
|
23
23
|
|
|
24
|
+
<p align="center">
|
|
25
|
+
<a href="https://jaxion.readthedocs.io">
|
|
26
|
+
<img src="docs/_static/jaxion-logo.svg" alt="jaxion logo" width="128"/>
|
|
27
|
+
</a>
|
|
28
|
+
</p>
|
|
29
|
+
|
|
24
30
|
# jaxion
|
|
25
31
|
|
|
26
32
|
[![Repo Status][status-badge]][status-link]
|
|
@@ -44,9 +50,9 @@ A simple JAX-powered simulation library for numerical experiments of fuzzy dark
|
|
|
44
50
|
|
|
45
51
|
Author: [Philip Mocz (@pmocz)](https://github.com/pmocz/)
|
|
46
52
|
|
|
47
|
-
|
|
53
|
+
Jaxion is built for multi-GPU scalability and is fully differentiable. It is a high-performance JAX-based simulation library for modeling fuzzy dark matter alongside stars, gas, and cosmological dynamics. Being differentiable, Jaxion can seamlessly integrate with pipelines for inverse-problems, inference, optimization, and coupling to ML models.
|
|
48
54
|
|
|
49
|
-
Jaxion is
|
|
55
|
+
Jaxion is the simpler companion project to differentiable astrophysics code [Adirondax](https://github.com/AdirondaxProject/adirondax)
|
|
50
56
|
|
|
51
57
|
|
|
52
58
|
## Getting started
|
|
@@ -66,6 +72,20 @@ pip install jaxion[cuda12]
|
|
|
66
72
|
Check out the `examples/` directory for demonstrations of using Jaxion.
|
|
67
73
|
|
|
68
74
|
|
|
75
|
+
## Examples
|
|
76
|
+
|
|
77
|
+
<p align="center">
|
|
78
|
+
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/dynamical_friction">
|
|
79
|
+
<img src="examples/dynamical_friction/movie.gif" alt="dynamical_friction" width="128"/>
|
|
80
|
+
</a>
|
|
81
|
+
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/heating_gas">
|
|
82
|
+
<img src="examples/heating_gas/movie.gif" alt="heating_gas" width="128"/>
|
|
83
|
+
</a>
|
|
84
|
+
<a href="https://github.com/JaxionProject/jaxion/tree/main/examples/tidal_stripping">
|
|
85
|
+
<img src="examples/tidal_stripping/movie.gif" alt="tidal_stripping" width="128"/>
|
|
86
|
+
</a>
|
|
87
|
+
</p>
|
|
88
|
+
|
|
69
89
|
## Links
|
|
70
90
|
|
|
71
91
|
* [Code repository](https://github.com/JaxionProject/jaxion) on GitHub (this page).
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
jaxion/__init__.py,sha256=Hdji1UQ47lG24Pqcy6UUq9L0-qy6m9Ax41L0vIYzBto,164
|
|
2
|
+
jaxion/analysis.py,sha256=UZpCvDr2ISDAgRoZFtuhvDk0QtN3mn5NxIoYYY8wzs8,1304
|
|
3
|
+
jaxion/constants.py,sha256=b0vgqtezIEGp4O5JRJo3YK3eQC8AN6hmV84-4pYWgmc,528
|
|
4
|
+
jaxion/cosmology.py,sha256=UC1McXNTXGoPRYXn0nI2-csVkJWL-ZBNoCa44oU1b4w,2681
|
|
5
|
+
jaxion/gravity.py,sha256=3brRZelKm-soXqk_Lt3SqhbZ00woJCraqwdMuR-KooA,291
|
|
6
|
+
jaxion/hydro.py,sha256=KoJ02tRpAc4V3Ofzw4zbHLRaE2GdIatbOBE04_LsSRw,6980
|
|
7
|
+
jaxion/params_default.json,sha256=9CJrhEPsv5zGEs7_WqFyuccCDipPCDhXgKzVdqOsOWE,2775
|
|
8
|
+
jaxion/particles.py,sha256=pMopGvoZ0J_3EviD0WnTMmiebU9h2_8IO-p6I-E5DEU,3980
|
|
9
|
+
jaxion/quantum.py,sha256=GWOpN6ipfEw-6Ah2zQpxS3oqeSt_iHMDSgnVYSjXY5E,3321
|
|
10
|
+
jaxion/simulation.py,sha256=3UrrHPJr0_39BiP86nQSl4ui8EsPHf9RrZ5hDobit5k,14884
|
|
11
|
+
jaxion/utils.py,sha256=rT7NM0FNEgFwN7oTgTb-jkR66Iw0xYTHHxcoikYd1ag,3572
|
|
12
|
+
jaxion/visualization.py,sha256=MJ3iqrvAem4RbozM79Ri2cNTIwKUV8o5X8M32FiDfbo,2928
|
|
13
|
+
jaxion-0.0.3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
14
|
+
jaxion-0.0.3.dist-info/METADATA,sha256=1C3fwJHJEBZYMEhxmTkKd7z_ZUrF15E3P0oS5XxrdjA,3579
|
|
15
|
+
jaxion-0.0.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
16
|
+
jaxion-0.0.3.dist-info/top_level.txt,sha256=S1OV2VdlDG_9UwpKOIji4itQGOS-VWUOWUi3GeXWzt0,7
|
|
17
|
+
jaxion-0.0.3.dist-info/RECORD,,
|
jaxion-0.0.2.dist-info/RECORD
DELETED
|
@@ -1,15 +0,0 @@
|
|
|
1
|
-
jaxion/__init__.py,sha256=gHrU_SvIBIbuYt-MeEeyiMyuJKwLjTTiz513OZ7CDHQ,95
|
|
2
|
-
jaxion/constants.py,sha256=b0vgqtezIEGp4O5JRJo3YK3eQC8AN6hmV84-4pYWgmc,528
|
|
3
|
-
jaxion/gravity.py,sha256=DdVOhAaoIPLXtW7ALrMK2umLPkStzxJ778l1SD2yMkY,266
|
|
4
|
-
jaxion/hydro.py,sha256=3ZFNDhIaPBw4uTBqFbjJotCLkuKG8AvRcY1T78QQ96U,6953
|
|
5
|
-
jaxion/params_default.json,sha256=4-XNPN52phGxYsP6FOeIJH0ALtICjrQxApotz-NhJEc,1782
|
|
6
|
-
jaxion/particles.py,sha256=7oJBVyddfOXUDYxQuGIs8_ByL2Hu6jkuWE1MWtshFjk,3953
|
|
7
|
-
jaxion/quantum.py,sha256=91tEotQhX1CB3HCaRCDARRPX_dEx58GIUCluY5RyPvA,377
|
|
8
|
-
jaxion/simulation.py,sha256=KcIi91LHJO5K4M0gxU4qc7e4JEeOyW1HUZOYwt8FOG0,10385
|
|
9
|
-
jaxion/utils.py,sha256=LCVx5_N9-6HsbENENOEz8tI6Czl0Z8Bit2icf1k4024,2880
|
|
10
|
-
jaxion/visualization.py,sha256=WKID2_LFbbqXwmW_TLTnVznKYHfjiDKNbjD9xS4GfLc,2209
|
|
11
|
-
jaxion-0.0.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
12
|
-
jaxion-0.0.2.dist-info/METADATA,sha256=uUe69xVa678F2uV28-CczduNdfoOtBpwhfuM_mJKXpM,2811
|
|
13
|
-
jaxion-0.0.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
14
|
-
jaxion-0.0.2.dist-info/top_level.txt,sha256=S1OV2VdlDG_9UwpKOIji4itQGOS-VWUOWUi3GeXWzt0,7
|
|
15
|
-
jaxion-0.0.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|